Skip to content

Commit

Permalink
add multi_target_input tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JakaKokosar committed Mar 21, 2022
1 parent e1e7419 commit f820dea
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 6 deletions.
26 changes: 24 additions & 2 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def set_input(data, model):
(self.widget.Inputs.data, data),
(self.widget.Inputs.predictors, model)
])

iris = self.iris
learner = ConstantLearner()
heart_disease = Table("heart_disease")
Expand Down Expand Up @@ -253,6 +254,7 @@ def test_sort_predictions(self):
"""
Test whether sorting of probabilities by FilterSortProxy is correct.
"""

def get_items_order(model):
return model.mapToSourceRows(np.arange(model.rowCount()))

Expand Down Expand Up @@ -878,6 +880,27 @@ def test_change_target(self):
self.assertEqual(float(table.model.data(table.model.index(0, 3))),
idx)

def test_multi_target_input(self):
widget = self.widget

domain = Domain([ContinuousVariable('var1')],
class_vars=[
ContinuousVariable('c1'),
DiscreteVariable('c2', values=('no', 'yes'))
])
data = Table.from_list(domain, [[1, 5, 0], [2, 10, 1]])

mock_model = Mock(spec=Model, return_value=np.asarray([0.2, 0.1]))
mock_model.name = 'Mockery'
mock_model.domain = domain
mock_learner = Mock(return_value=mock_model)
model = mock_learner(data)

self.send_signal(widget.Inputs.data, data)
self.send_signal(widget.Inputs.predictors, model, 1)
pred = self.get_output(widget.Outputs.predictions)
self.assertIsInstance(pred, Table)

def test_report(self):
widget = self.widget

Expand Down Expand Up @@ -1022,7 +1045,6 @@ def assert_called(exp_selected, exp_deselected):
self.assertEqual(list(selected), exp_selected)
self.assertEqual(list(deselected), exp_deselected)


store.model.setSortIndices([4, 0, 1, 2, 3])
store.select_rows({3, 4}, QItemSelectionModel.Select)
assert_called([4, 0], [])
Expand Down Expand Up @@ -1132,7 +1154,7 @@ def setUpClass(cls) -> None:
cls.probs = [np.array([[80, 10, 10],
[30, 70, 0],
[15, 80, 5],
[0, 10, 90],
[0, 10, 90],
[55, 40, 5]]) / 100,
np.array([[80, 0, 20],
[90, 5, 5],
Expand Down
45 changes: 41 additions & 4 deletions Orange/widgets/evaluate/tests/test_owtestandscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from Orange.evaluation import Results, TestOnTestData, scoring
from Orange.evaluation.scoring import ClassificationScore, RegressionScore, \
Score
from Orange.base import Learner
from Orange.base import Learner, Model
from Orange.modelling import ConstantLearner
from Orange.regression import MeanLearner
from Orange.widgets.evaluate.owtestandscore import (
Expand Down Expand Up @@ -178,7 +178,7 @@ def test_one_class_value(self):
table = Table.from_list(
Domain(
[ContinuousVariable("a"), ContinuousVariable("b")],
[DiscreteVariable("c", values=("y", ))]),
[DiscreteVariable("c", values=("y",))]),
list(zip(
[42.48, 16.84, 15.23, 23.8],
[1., 2., 3., 4.],
Expand All @@ -192,6 +192,7 @@ def test_one_class_value(self):

def test_data_errors(self):
""" Test all data_errors """

def assertErrorShown(data, is_shown, message):
self.send_signal("Data", data)
self.assertEqual(is_shown, self.widget.Error.train_data_error.is_shown())
Expand Down Expand Up @@ -378,7 +379,7 @@ def test_scores_log_reg_overfitted(self):
self.assertTupleEqual(self._test_scores(
table, table, LogisticRegressionLearner(),
OWTestAndScore.TestOnTest, None),
(1, 1, 1, 1, 1))
(1, 1, 1, 1, 1))

def test_scores_log_reg_bad(self):
table_train = Table.from_list(
Expand All @@ -393,7 +394,7 @@ def test_scores_log_reg_bad(self):
self.assertTupleEqual(self._test_scores(
table_train, table_test, LogisticRegressionLearner(),
OWTestAndScore.TestOnTest, None),
(0, 0, 0, 0, 0))
(0, 0, 0, 0, 0))

def test_scores_log_reg_bad2(self):
table_train = Table.from_list(
Expand Down Expand Up @@ -724,6 +725,42 @@ def test_copy_to_clipboard(self):
for i in (0, 3, 4, 5, 6, 7)]) + "\r\n"
self.assertEqual(clipboard_text, view_text)

def test_multi_target_input(self):
class NewScorer(Score):
class_types = (
ContinuousVariable,
DiscreteVariable,
)

@staticmethod
def is_compatible(domain: Domain) -> bool:
return True

def compute_score(self, results):
return [0.75]

domain = Domain([ContinuousVariable('var1')],
class_vars=[
ContinuousVariable('c1'),
DiscreteVariable('c2', values=('no', 'yes'))
])
data = Table.from_list(domain, [[1, 5, 0], [2, 10, 1], [2, 10, 1]])

mock_model = Mock(spec=Model, return_value=np.asarray([[0.2, 0.1, 0.2]]))
mock_model.name = 'Mockery'
mock_model.domain = domain
mock_learner = Mock(spec=Learner, return_value=mock_model)
mock_learner.name = 'Mockery'

self.widget.resampling = OWTestAndScore.TestOnTrain
self.send_signal(self.widget.Inputs.train_data, data)
self.send_signal(self.widget.Inputs.learner, MajorityLearner(), 0)
self.send_signal(self.widget.Inputs.learner, mock_learner, 1)
_ = self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
self.assertTrue(len(self.widget.scorers) == 1)
self.assertTrue(NewScorer in self.widget.scorers)
self.assertTrue(len(self.widget._successful_slots()) == 1)


class TestHelpers(unittest.TestCase):
def test_results_one_vs_rest(self):
Expand Down

0 comments on commit f820dea

Please sign in to comment.