Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] OWPredictions: Allow classification when data has no target column #2183

Merged
merged 1 commit into from
Apr 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ class Warning(OWWidget.Warning):
empty_data = Msg("Empty data set")

class Error(OWWidget.Error):
predictor_failed = Msg("One or more predictors failed (see more...)\n{}")
predictor_failed = \
Msg("One or more predictors failed (see more...)\n{}")
predictors_target_mismatch = \
Msg("Predictors do not have the same target.")
data_target_mismatch = \
Msg("Data does not have the same target as predictors.")

settingsHandler = settings.ClassValuesContextHandler()
#: Display the full input dataset or only the target variable columns (if
Expand Down Expand Up @@ -182,32 +187,18 @@ def set_data(self, data):

self.data = data
if data is None:
self.class_var = class_var = None
self.dataview.setModel(None)
self.predictionsview.setModel(None)
self.predictionsview.setItemDelegate(PredictionsItemDelegate())
else:
# force full reset of the view's HeaderView state
self.class_var = class_var = data.domain.class_var
self.dataview.setModel(None)
model = TableModel(data, parent=None)
modelproxy = TableSortProxyModel()
modelproxy.setSourceModel(model)
self.dataview.setModel(modelproxy)
self._update_column_visibility()

discrete_class = class_var is not None and class_var.is_discrete
self.classification_options.setVisible(discrete_class)

self.closeContext()
if discrete_class:
self.class_values = list(class_var.values)
self.selected_classes = list(range(len(self.class_values)))
self.openContext(self.class_var)
else:
self.class_values = []
self.selected_classes = []

self._invalidate_predictions()

def set_predictor(self, predictor=None, id=None):
Expand All @@ -221,7 +212,36 @@ def set_predictor(self, predictor=None, id=None):
self.predictors[id] = \
PredictorSlot(predictor, predictor.name, None)

def set_class_var(self):
pred_classes = set(pred.predictor.domain.class_var
for pred in self.predictors.values())
self.Error.predictors_target_mismatch.clear()
self.Error.data_target_mismatch.clear()
self.class_var = None
if len(pred_classes) > 1:
self.Error.predictors_target_mismatch()
if len(pred_classes) == 1:
self.class_var = pred_classes.pop()
if self.data is not None and \
self.data.domain.class_var is not None and \
self.class_var != self.data.domain.class_var:
self.Error.data_target_mismatch()
self.class_var = None

discrete_class = self.class_var is not None \
and self.class_var.is_discrete
self.classification_options.setVisible(discrete_class)
self.closeContext()
if discrete_class:
self.class_values = list(self.class_var.values)
self.selected_classes = list(range(len(self.class_values)))
self.openContext(self.class_var)
else:
self.class_values = []
self.selected_classes = []

def handleNewSignals(self):
self.set_class_var()
if self.data is not None:
self._call_predictors()
self._update_predictions_model()
Expand All @@ -232,14 +252,9 @@ def handleNewSignals(self):

def _call_predictors(self):
for inputid, pred in self.predictors.items():
if pred.results is None:
if pred.results is None or numpy.isnan(pred.results[0]).all():
try:
predictor_class = pred.predictor.domain.class_var
if predictor_class != self.class_var:
results = "{}: mismatching target ({})".format(
pred.predictor.name, predictor_class.name)
else:
results = self.predict(pred.predictor, self.data)
results = self.predict(pred.predictor, self.data)
except ValueError as err:
results = "{}: {}".format(pred.predictor.name, err)
self.predictors[inputid] = pred._replace(results=results)
Expand Down Expand Up @@ -285,12 +300,16 @@ def _invalidate_predictions(self):
self.predictors[inputid] = pred._replace(results=None)

def _valid_predictors(self):
return [p for p in self.predictors.values()
if p.results is not None and not isinstance(p.results, str)]
if self.class_var is not None and \
self.data is not None:
return [p for p in self.predictors.values()
if p.results is not None and not isinstance(p.results, str)]
else:
return []

def _update_predictions_model(self):
"""Update the prediction view model."""
if self.data is not None:
if self.data is not None and self.class_var is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't the second part (class_var not None) also imply the first (data not None).

(The same if appears in a couple of places)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.class_var comes from predictors and can be not None even if no data is present. See https://github.com/janezd/orange3/blob/21ba4679bf19efde4498363a56c69bf3798ea0f9/Orange/widgets/evaluate/owpredictions.py#L224.

slots = self._valid_predictors()
results = []
class_var = self.class_var
Expand Down Expand Up @@ -323,7 +342,7 @@ def _update_predictions_model(self):

def _update_column_visibility(self):
"""Update data column visibility."""
if self.data is not None:
if self.data is not None and self.class_var is not None:
domain = self.data.domain
first_attr = len(domain.class_vars) + len(domain.metas)

Expand Down Expand Up @@ -415,12 +434,12 @@ def commit(self):
self._commit_evaluation_results()

def _commit_evaluation_results(self):
class_var = self.class_var
slots = self._valid_predictors()
if not slots:
if not slots or self.data.domain.class_var is None:
self.send("Evaluation Results", None)
return

class_var = self.class_var
nanmask = numpy.isnan(self.data.get_column_view(class_var)[0])
data = self.data[~nanmask]
N = len(data)
Expand All @@ -442,15 +461,15 @@ def _commit_predictions(self):
self.send("Predictions", None)
return

class_var = self.class_var
if class_var and class_var.is_discrete:
if self.class_var and self.class_var.is_discrete:
newmetas, newcolumns = self._classification_output_columns()
else:
newmetas, newcolumns = self._regression_output_columns()

attrs = list(self.data.domain.attributes) if self.output_attrs else []
metas = list(self.data.domain.metas) + newmetas
domain = Orange.data.Domain(attrs, class_var, metas=metas)
domain = \
Orange.data.Domain(attrs, self.data.domain.class_var, metas=metas)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we assure at the beginning that self.data.domain.class_var must be equal to self.class_var, then the second (shorter/direct) version would look nicer here. And the line break would probably not be necessary :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I recalled: I did have self.class_var and it fit into a single line. But this fails exactly in the case that this PR is fixing: if self.data.domain.class_var is None, the output data should have no class, either. Changing this to self.class_var can add a column of unknown values if the original data has no class.

predictions = self.data.from_table(domain, self.data)
if newcolumns:
newcolumns = numpy.hstack(
Expand Down Expand Up @@ -506,7 +525,7 @@ def merge_data_with_predictions():
[data_model.data(data_model.index(i, j))
for j in iter_data_cols]

if self.data is not None:
if self.data is not None and self.class_var is not None:
text = self.infolabel.text().replace('\n', '<br>')
if self.show_probabilities and self.selected_classes:
text += '<br>Showing probabilities for: '
Expand Down
57 changes: 45 additions & 12 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.evaluate.owpredictions import OWPredictions

from Orange.data import Table
from Orange.data import Table, Domain
from Orange.classification import MajorityLearner
from Orange.evaluation import Results

Expand Down Expand Up @@ -47,42 +47,75 @@ def test_nan_target_input(self):
self.assertEqual(len(evres.data), 0)

def test_mismatching_targets(self):
error = self.widget.Error

titanic = Table("titanic")
majority_titanic = MajorityLearner()(titanic)
majority_iris = MajorityLearner()(self.iris)

self.send_signal("Data", self.iris)
self.send_signal("Predictors", majority_iris, 1)
self.send_signal("Predictors", majority_titanic, 2)
self.assertTrue(self.widget.Error.predictor_failed.is_shown())
output = self.get_output("Predictions")
self.assertEqual(len(output.domain.metas), 4)
self.assertTrue(error.predictors_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", None, 1)
self.assertTrue(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertTrue(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Data", None)
self.assertFalse(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", None, 2)
self.assertFalse(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", majority_titanic, 2)
self.assertFalse(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Data", self.iris)
self.assertTrue(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertTrue(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", majority_iris, 2)
self.assertFalse(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
output = self.get_output("Predictions")
self.assertEqual(len(output.domain.metas), 4)

self.send_signal("Predictors", majority_iris, 1)
self.send_signal("Predictors", majority_titanic, 3)
output = self.get_output("Predictions")
self.assertEqual(len(output.domain.metas), 8)
self.assertTrue(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

def test_no_class_on_test(self):
"""Allow test data with no class"""
error = self.widget.Error

titanic = Table("titanic")
majority_titanic = MajorityLearner()(titanic)
majority_iris = MajorityLearner()(self.iris)

no_class = Table(Domain(titanic.domain.attributes, None), titanic)
self.send_signal("Predictors", majority_titanic, 1)
self.send_signal("Data", no_class)
out = self.get_output("Predictions")
np.testing.assert_allclose(out.get_column_view("majority")[0], 0)

self.send_signal("Predictors", majority_iris, 2)
self.assertTrue(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", None, 2)
self.send_signal("Data", titanic)
out = self.get_output("Predictions")
np.testing.assert_allclose(out.get_column_view("majority")[0], 0)