Skip to content

Commit

Permalink
Predictions: Replace list view with a combo
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Feb 12, 2022
1 parent 6c17080 commit 69c5f5c
Show file tree
Hide file tree
Showing 2 changed files with 312 additions and 84 deletions.
230 changes: 153 additions & 77 deletions Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

import numpy
from AnyQt.QtWidgets import (
QTableView, QListWidget, QSplitter, QToolTip, QStyle, QApplication,
QSizePolicy
)
QTableView, QSplitter, QToolTip, QStyle, QApplication, QSizePolicy)
from AnyQt.QtGui import QPainter, QStandardItem, QPen, QColor
from AnyQt.QtCore import (
Qt, QSize, QRect, QRectF, QPoint, QLocale,
QModelIndex, pyqtSignal, QTimer,
QItemSelectionModel, QItemSelection)
from AnyQt.QtWidgets import QPushButton

from orangewidget.report import plural
from orangewidget.utils.itemmodels import AbstractSortTableModel
Expand Down Expand Up @@ -51,7 +50,7 @@ class OWPredictions(OWWidget):
description = "Display predictions of models for an input dataset."
keywords = []

buttons_area_orientation = None
want_control_area = False

class Inputs:
data = Input("Data", Orange.data.Table)
Expand All @@ -76,7 +75,17 @@ class Error(OWWidget.Error):
score_table = settings.SettingProvider(ScoreTable)

#: List of selected class value indices in the `class_values` list
selected_classes = settings.ContextSetting([])
PROB_OPTS = ["(None)",
"Classes in data", "Classes known to the model", "Classes in data and model"]
PROB_TOOLTIPS = ["Don't show probabilities",
"Show probabilities for classes in the data",
"Show probabilities for classes known to the model,\n"
"including those that don't appear in this data",
"Show probabilities for classes in data that are also\n"
"known to the model"
]
NO_PROBS, DATA_PROBS, MODEL_PROBS, BOTH_PROBS = range(4)
shown_probs = settings.ContextSetting(NO_PROBS)
selection = settings.Setting([], schema_only=True)
show_scores = settings.Setting(True)
TARGET_AVERAGE = "(Average over classes)"
Expand All @@ -94,31 +103,38 @@ def __init__(self):
self.selection_store = None
self.__pending_selection = self.selection

self.reset_button = gui.button(
self.controlArea, self, "Restore Original Order",
callback=self._reset_order,
tooltip="Show rows in the original order")
gui.separator(self.controlArea, 16)
self._prob_controls = []

gui.listBox(
self.controlArea, self, "selected_classes", "class_values",
box="Show probabilities",
callback=self._update_prediction_delegate,
selectionMode=QListWidget.ExtendedSelection,
sizePolicy=(QSizePolicy.Preferred, QSizePolicy.MinimumExpanding),
minimumHeight=100, maximumHeight=150)

gui.rubber(self.controlArea)

box = gui.vBox(self.controlArea, "Model Performance")
predopts = gui.hBox(
None, sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed))
self._prob_controls = [
predopts,
gui.widgetLabel(predopts, "Show probabilities for"),
gui.comboBox(
predopts, self, "shown_probs", contentsLength=30,
callback=self._update_prediction_delegate)
]
gui.rubber(predopts)
self.reset_button = button = QPushButton("Restore Original Order")
button.clicked.connect(self._reset_order)
button.setToolTip("Show rows in the original order")
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
predopts.layout().addWidget(self.reset_button)

scoreopts = gui.hBox(
None, sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed))
gui.checkBox(
box, self, "show_scores", "Show perfomance scores",
scoreopts, self, "show_scores", "Show perfomance scores",
callback=self._update_score_table_visibility
)
self.target_selection = gui.comboBox(
box, self, "target_class", items=[], label="Target class:",
sendSelectedValue=True, callback=self._on_target_changed
)
gui.separator(scoreopts, 32)
self._target_controls = [
gui.widgetLabel(scoreopts, "Target class:"),
gui.comboBox(
scoreopts, self, "target_class", items=[], contentsLength=30,
sendSelectedValue=True, callback=self._on_target_changed)
]
gui.rubber(scoreopts)

table_opts = dict(horizontalScrollBarPolicy=Qt.ScrollBarAlwaysOn,
horizontalScrollMode=QTableView.ScrollPerPixel,
Expand Down Expand Up @@ -153,9 +169,12 @@ def __init__(self):
self.splitter.addWidget(self.dataview)

self.score_table = ScoreTable(self)
self.vsplitter = gui.vBox(self.mainArea)
self.vsplitter.layout().addWidget(self.splitter)
self.vsplitter.layout().addWidget(self.score_table.view)
self.mainArea.layout().setSpacing(0)
self.mainArea.layout().setContentsMargins(4, 0, 4, 4)
self.mainArea.layout().addWidget(predopts)
self.mainArea.layout().addWidget(self.splitter)
self.mainArea.layout().addWidget(scoreopts)
self.mainArea.layout().addWidget(self.score_table.view)

def get_selection_store(self, model):
# Both models map the same, so it doesn't matter which one is used
Expand All @@ -168,6 +187,7 @@ def get_selection_store(self, model):
@check_sql_input
def set_data(self, data):
self.Warning.empty_data(shown=data is not None and not data)
self.closeContext()
self.data = data
self.selection_store = None
if not data:
Expand All @@ -194,14 +214,21 @@ def set_data(self, data):
self._update_data_sort_order, self.dataview,
self.predictionsview))

self._set_target_combos()
if self.is_discrete_class:
self.openContext(self.class_var.values)
self._invalidate_predictions()

def _store_selection(self):
self.selection = list(self.selection_store.rows)

@property
def class_var(self):
return self.data and self.data.domain.class_var
return self.data is not None and self.data.domain.class_var

@property
def is_discrete_class(self):
return bool(self.class_var) and self.class_var.is_discrete

@Inputs.predictors
def set_predictor(self, index, predictor: Model):
Expand All @@ -219,31 +246,43 @@ def insert_predictor(self, index, predictor: Model):
def remove_predictor(self, index):
self.predictors.pop(index)

def _set_target_combos(self):
prob_combo = self.controls.shown_probs
target_combo = self.controls.target_class
prob_combo.clear()
target_combo.clear()

self._update_control_visibility()

# Set these to prevent warnings when setting self.shown_probs
target_combo.addItem(self.TARGET_AVERAGE)
prob_combo.addItems(self.PROB_OPTS)

if self.is_discrete_class:
target_combo.addItems(self.class_var.values)
prob_combo.addItems(self.class_var.values)
for i, tip in enumerate(self.PROB_TOOLTIPS):
prob_combo.setItemData(i, tip, Qt.ToolTipRole)
self.shown_probs = self.DATA_PROBS
self.target_class = self.TARGET_AVERAGE
else:
self.shown_probs = self.NO_PROBS

def _update_control_visibility(self):
for widget in self._prob_controls:
widget.setVisible(self.is_discrete_class)

for widget in self._target_controls:
widget.setVisible(self.is_discrete_class and self.show_scores)

def _set_class_values(self):
class_values = []
self.class_values = []
for slot in self.predictors:
class_var = slot.predictor.domain.class_var
if class_var and class_var.is_discrete:
if class_var.is_discrete:
for value in class_var.values:
if value not in class_values:
class_values.append(value)

self.target_selection.clear()
self.target_selection.addItem(self.TARGET_AVERAGE)
if self.class_var and self.class_var.is_discrete:
values = self.class_var.values
self.target_selection.addItems(values)
self.target_selection.box.setVisible(True)
self.class_values = sorted(
class_values, key=lambda val: val not in values)
self.selected_classes = [
i for i, name in enumerate(class_values) if name in values]
self.controls.selected_classes.box.setVisible(True)
else:
self.class_values = class_values # This assignment updates listview
self.selected_classes = []
self.controls.selected_classes.box.setVisible(False)
self.target_selection.box.setVisible(False)
if value not in self.class_values:
self.class_values.append(value)

def handleNewSignals(self):
# Disconnect the model: the model and the delegate will be inconsistent
Expand Down Expand Up @@ -317,8 +356,7 @@ def _call_predictors(self):

def _update_scores(self):
model = self.score_table.model
if self.class_var and self.class_var.is_discrete \
and self.target_class != self.TARGET_AVERAGE:
if self.is_discrete_class and self.target_class != self.TARGET_AVERAGE:
target = self.class_var.values.index(self.target_class)
else:
target = None
Expand Down Expand Up @@ -368,6 +406,7 @@ def _update_scores(self):
self._update_score_table_visibility()

def _update_score_table_visibility(self):
self._update_control_visibility()
view = self.score_table.view
nmodels = self.score_table.model.rowCount()
if nmodels and self.show_scores:
Expand All @@ -385,7 +424,6 @@ def _update_score_table_visibility(self):
view.setVisible(False)
self.Error.scorer_failed.clear()


def _set_errors(self):
# Not all predictors are run every time, so errors can't be collected
# in _call_predictors
Expand Down Expand Up @@ -563,22 +601,49 @@ def _get_colors(self):
return colors

def _update_prediction_delegate(self):
def index(value):
if value in target.values:
return self.class_values.index(value)
else:
return None

self._delegates.clear()
colors = self._get_colors()
if self.shown_probs >= len(self.PROB_OPTS):
shown_class = self.class_values[self.shown_probs
- len(self.PROB_OPTS)]
else:
shown_class = "" # just to silence warnings about undefined var
sort_col_indices = []
for col, slot in enumerate(self.predictors):
target = slot.predictor.domain.class_var
shown_probs = (
() if target.is_continuous else
[val if self.class_values[val] in target.values else None
for val in self.selected_classes]
)
delegate = PredictionsItemDelegate(
None if target.is_continuous else self.class_values,
colors,
shown_probs,
target.format_str if target.is_continuous else None,
parent=self.predictionsview
)
if target.is_continuous:
delegate = PredictionsItemDelegate(
None, colors, (), target.format_str,
parent=self.predictionsview
)
sort_col_indices.append(None)
else:
if self.shown_probs == self.NO_PROBS:
shown_probs = []
elif self.shown_probs == self.DATA_PROBS:
shown_probs = [
index(value) for value in self.class_var.values]
elif self.shown_probs == self.MODEL_PROBS:
shown_probs = [
index(value) for value in target.values]
elif self.shown_probs == self.BOTH_PROBS:
shown_probs = [
index(value) for value in self.class_var.values
if value in target.values]
else:
shown_probs = [index(shown_class)]
delegate = PredictionsItemDelegate(
self.class_values, colors, shown_probs, None,
parent=self.predictionsview
)
sort_col_indices.append([col for col in shown_probs
if col is not None])
# QAbstractItemView does not take ownership of delegates, so we must
self._delegates.append(delegate)
self.predictionsview.setItemDelegateForColumn(col, delegate)
Expand All @@ -587,7 +652,7 @@ def _update_prediction_delegate(self):
self.predictionsview.resizeColumnsToContents()
self._recompute_splitter_sizes()
if self.predictionsview.model() is not None:
self.predictionsview.model().setProbInd(self.selected_classes)
self.predictionsview.model().setProbInd(sort_col_indices)

def _recompute_splitter_sizes(self):
if not self.data:
Expand Down Expand Up @@ -623,7 +688,7 @@ def _commit_evaluation_results(self):
results.actual = data.Y.ravel()
results.predicted = numpy.vstack(
tuple(p.results.predicted[0][~nanmask] for p in slots))
if self.class_var and self.class_var.is_discrete:
if self.is_discrete_class:
results.probabilities = numpy.array(
[p.results.probabilities[0][~nanmask] for p in slots])
results.learner_names = [p.name for p in slots]
Expand Down Expand Up @@ -730,10 +795,18 @@ def merge_data_with_predictions():

if self.data:
text = self._get_details().replace('\n', '<br>')
if self.selected_classes:
text += '<br>Showing probabilities for: '
text += ', '. join([self.class_values[i]
for i in self.selected_classes])
if self.is_discrete_class and self.shown_probs != self.NO_PROBS:
text += '<br>Showing probabilities for '
if self.shown_probs == self.MODEL_PROBS:
text += "all classes known to the model"
elif self.shown_probs == self.DATA_PROBS:
text += "all classes that appear in the data"
elif self.shown_probs == self.BOTH_PROBS:
text += "all classes that appear in the data " \
"and are known to the model"
else:
class_idx = self.shown_probs - len(self.PROB_OPTS)
text += f"'{self.class_var[class_idx]}'"
self.report_paragraph('Info', text)
self.report_table("Data & Predictions", merge_data_with_predictions(),
header_rows=1, header_columns=1)
Expand Down Expand Up @@ -952,17 +1025,20 @@ def headerData(self, section, orientation, role=Qt.DisplayRole):
return self._header[section]
return None

def setProbInd(self, indices):
self.__probInd = indices
def setProbInd(self, indicess):
self.__probInd = indicess
self.sort(self.sortColumn())

def sortColumnData(self, column):
values = self._values[column]
probs = self._probs[column]
# Let us assume that probs can be None, numpy array or list of arrays
# self.__probInd[column] can be None (numeric) or empty (no probs
# shown for particular model)
if probs is not None and len(probs) and len(probs[0]) \
and self.__probInd is not None and len(self.__probInd):
return probs[:, self.__probInd]
and self.__probInd is not None \
and self.__probInd[column]:
return probs[:, self.__probInd[column]]
else:
return values

Expand Down Expand Up @@ -1281,5 +1357,5 @@ def pred_error(data, *args, **kwargs):
predictors_ = [pred_error]

WidgetPreview(OWPredictions).run(
set_data=iris2,
set_data=iris,
insert_predictor=list(enumerate(predictors_)))
Loading

0 comments on commit 69c5f5c

Please sign in to comment.