From 716f25fcba25b275ef5ffaca83c2b3f9bfddda50 Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Mon, 7 Oct 2019 08:54:24 +0200 Subject: [PATCH 1/2] TestOWKMeans: Replace 'iris' dataset with 'heart_disease' dataset --- .../unsupervised/tests/test_owkmeans.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/Orange/widgets/unsupervised/tests/test_owkmeans.py b/Orange/widgets/unsupervised/tests/test_owkmeans.py index ced45791758..e09b1245d5a 100644 --- a/Orange/widgets/unsupervised/tests/test_owkmeans.py +++ b/Orange/widgets/unsupervised/tests/test_owkmeans.py @@ -44,8 +44,7 @@ def setUp(self): self.widget = self.create_widget( OWKMeans, stored_settings={"auto_commit": False, "version": 2} ) # type: OWKMeans - self.iris = Table("iris") - self.iris.X[0, 0] = np.nan + self.data = Table("heart_disease") def tearDown(self): self.widget.onDeleteWidget() @@ -61,7 +60,7 @@ def test_migrate_version_1_settings(self): def test_optimization_report_display(self): """Check visibility of the table after selecting number of clusters""" self.widget.auto_commit = True - self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) + self.send_signal(self.widget.Inputs.data, self.data, wait=5000) self.widget.optimize_k = True radio_buttons = self.widget.controls.optimize_k.findChildren(QRadioButton) @@ -80,7 +79,7 @@ def test_optimization_report_display(self): def test_changing_k_changes_radio(self): widget = self.widget widget.auto_commit = True - self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) + self.send_signal(self.widget.Inputs.data, self.data, wait=5000) widget.optimize_k = True @@ -118,7 +117,7 @@ def test_no_data_hides_main_area(self): self.send_signal(self.widget.Inputs.data, None, wait=5000) self.assertTrue(self.widget.mainArea.isHidden()) - self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) + self.send_signal(self.widget.Inputs.data, self.data, wait=5000) self.assertFalse(self.widget.mainArea.isHidden()) self.send_signal(self.widget.Inputs.data, None, wait=5000) self.assertTrue(self.widget.mainArea.isHidden()) @@ -129,7 +128,7 @@ def test_data_limits(self): widget = self.widget widget.auto_commit = False - self.send_signal(self.widget.Inputs.data, self.iris[:5]) + self.send_signal(self.widget.Inputs.data, self.data[:5]) widget.k = 10 self.commit_and_wait() @@ -159,7 +158,7 @@ def test_use_cache(self): """Cache various clusterings for the dataset until data changes.""" widget = self.widget widget.auto_commit = False - self.send_signal(self.widget.Inputs.data, self.iris) + self.send_signal(self.widget.Inputs.data, self.data) with patch.object(widget, "_compute_clustering", wraps=widget._compute_clustering) as compute: @@ -191,7 +190,7 @@ def test_use_cache(self): def test_data_on_output(self): """Check if data is on output after create widget and run""" self.widget.auto_commit = True - self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) + self.send_signal(self.widget.Inputs.data, self.data, wait=5000) self.widget.apply_button.button.click() self.assertNotEqual(self.widget.data, None) # Disconnect the data @@ -203,19 +202,19 @@ def test_centroids_on_output(self): widget = self.widget widget.optimize_k = False widget.k = 4 - self.send_signal(widget.Inputs.data, self.iris) + self.send_signal(widget.Inputs.data, self.data) self.commit_and_wait() - widget.clusterings[widget.k].labels = np.array([0] * 50 + [1] * 100).flatten() + widget.clusterings[widget.k].labels = np.array([0] * 100 + [1] * 203).flatten() widget.samples_scores = lambda x: np.arctan( - np.arange(150) / 150) / np.pi + 0.5 + np.arange(303) / 303) / np.pi + 0.5 widget.send_data() out = self.get_output(widget.Outputs.centroids) np.testing.assert_array_almost_equal( - np.array([[0, np.mean(np.arctan(np.arange(50) / 150)) / np.pi + 0.5], - [1, np.mean(np.arctan(np.arange(50, 150) / 150)) / np.pi + 0.5], + np.array([[0, np.mean(np.arctan(np.arange(100) / 303)) / np.pi + 0.5], + [1, np.mean(np.arctan(np.arange(100, 303) / 303)) / np.pi + 0.5], [2, 0], [3, 0]]), out.metas.astype(float)) - self.assertEqual(out.name, "iris centroids") + self.assertEqual(out.name, "heart_disease centroids") def test_centroids_domain_on_output(self): widget = self.widget @@ -262,13 +261,13 @@ def test_optimization_fails(self): with patch.object( model, "set_scores", wraps=model.set_scores) as set_scores: - self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) + self.send_signal(self.widget.Inputs.data, self.data, wait=5000) scores, start_k = set_scores.call_args[0] - X = self.widget.preproces(self.iris).X + X = self.widget.preproces(self.data).X self.assertEqual( scores, [km if isinstance(km, str) else silhouette_score( - X, km(self.iris)) + X, km(self.data)) for km in (widget.clusterings[k] for k in range(3, 9))] ) self.assertEqual(start_k, 3) @@ -302,7 +301,7 @@ def test_run_fails(self): self.widget.auto_commit = True self.widget.optimize_k = False self.KMeansFail.fail_on = {3} - self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) + self.send_signal(self.widget.Inputs.data, self.data, wait=5000) self.assertTrue(self.widget.Error.failed.is_shown()) self.assertIsNone(self.get_output(self.widget.Outputs.annotated_data)) @@ -362,7 +361,7 @@ def test_not_enough_rows(self): Widget should not crash when there is less rows than k_from. GH-2172 """ - table = self.iris[0:1, :] + table = self.data[0:1, :] self.widget.controls.k_from.setValue(2) self.widget.controls.k_to.setValue(9) self.send_signal(self.widget.Inputs.data, table) @@ -374,7 +373,7 @@ def test_from_to_table(self): """ k_from, k_to = 2, 9 self.widget.controls.k_from.setValue(k_from) - self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) + self.send_signal(self.widget.Inputs.data, self.data, wait=5000) check = lambda x: 2 if x - k_from + 1 < 2 else x - k_from + 1 for i in range(k_from, k_to): self.widget.controls.k_to.setValue(i) @@ -415,7 +414,7 @@ def test_invalidate_clusterings_cancels_jobs(self): widget.auto_commit = False # Send the data without waiting - self.send_signal(widget.Inputs.data, self.iris) + self.send_signal(widget.Inputs.data, self.data) widget.unconditional_commit() # Now, invalidate by changing max_iter widget.max_iterations = widget.max_iterations + 1 @@ -460,7 +459,7 @@ def test_do_not_recluster_on_same_data(self): def test_correct_smart_init(self): # due to a bug where wrong init was passed to _compute_clustering - self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000) + self.send_signal(self.widget.Inputs.data, self.data[::10], wait=5000) self.widget.smart_init = 0 self.widget.clusterings = {} with patch.object(self.widget, "_compute_clustering", @@ -476,7 +475,7 @@ def test_correct_smart_init(self): def test_always_same_cluster(self): """The same random state should always return the same clusters""" - self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000) + self.send_signal(self.widget.Inputs.data, self.data[::10], wait=5000) def cluster(): self.widget.invalidate() # reset caches From a193f02c52605082a126284d0aa271082b4e6fc5 Mon Sep 17 00:00:00 2001 From: Vesna Tanko Date: Mon, 7 Oct 2019 10:49:27 +0200 Subject: [PATCH 2/2] KMeans: Save selection --- Orange/widgets/unsupervised/owkmeans.py | 13 ++++++++++++- .../widgets/unsupervised/tests/test_owkmeans.py | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/Orange/widgets/unsupervised/owkmeans.py b/Orange/widgets/unsupervised/owkmeans.py index 71ccc7f4e3a..691062b695c 100644 --- a/Orange/widgets/unsupervised/owkmeans.py +++ b/Orange/widgets/unsupervised/owkmeans.py @@ -142,6 +142,7 @@ class Warning(widget.OWWidget.Warning): max_iterations = Setting(300) n_init = Setting(10) smart_init = Setting(0) # KMeans++ + selection = Setting(None, schema_only=True) # type: Optional[int] auto_commit = Setting(True) settings_version = 2 @@ -158,6 +159,7 @@ def __init__(self): super().__init__() self.data = None # type: Optional[Table] + self.__pending_selection = self.selection # type: Optional[int] self.clusterings = {} self.__executor = ThreadExecutor(parent=self) @@ -443,10 +445,17 @@ def update_results(self): key=lambda x: 0 if isinstance(scores[x], str) else scores[x] ) self.table_model.set_scores(scores, self.k_from) - self.table_view.selectRow(best_row) + self.apply_selection(best_row) self.table_view.setFocus(Qt.OtherFocusReason) self.table_view.resizeRowsToContents() + def apply_selection(self, best_row): + pending = best_row + if self.__pending_selection is not None: + pending = self.__pending_selection + self.__pending_selection = None + self.table_view.selectRow(pending) + def selected_row(self): indices = self.table_view.selectedIndexes() if not indices: @@ -454,6 +463,7 @@ def selected_row(self): return indices[0].row() def select_row(self): + self.selection = self.selected_row() self.send_data() def preproces(self, data): @@ -535,6 +545,7 @@ def send_data(self): @check_sql_input def set_data(self, data): self.data, old_data = data, self.data + self.selection = None # Do not needlessly recluster the data if X hasn't changed if old_data and self.data and array_equal(self.data.X, old_data.X): diff --git a/Orange/widgets/unsupervised/tests/test_owkmeans.py b/Orange/widgets/unsupervised/tests/test_owkmeans.py index e09b1245d5a..9ddb5de7eb3 100644 --- a/Orange/widgets/unsupervised/tests/test_owkmeans.py +++ b/Orange/widgets/unsupervised/tests/test_owkmeans.py @@ -499,6 +499,23 @@ def test_error_no_attributes(self): self.send_signal(self.widget.Inputs.data, table) self.assertTrue(self.widget.Error.no_attributes.is_shown()) + def test_saved_selection(self): + self.widget.send_data = Mock() + self.widget.optimize_k = True + self.send_signal(self.widget.Inputs.data, self.data) + self.wait_until_stop_blocking() + self.widget.table_view.selectRow(2) + self.assertEqual(self.widget.selected_row(), 2) + self.assertEqual(self.widget.send_data.call_count, 3) + settings = self.widget.settingsHandler.pack_data(self.widget) + + w = self.create_widget(OWKMeans, stored_settings=settings) + w.send_data = Mock() + self.send_signal(w.Inputs.data, self.data, widget=w) + self.wait_until_stop_blocking(widget=w) + self.assertEqual(w.send_data.call_count, 2) + self.assertEqual(self.widget.selected_row(), w.selected_row()) + if __name__ == "__main__": unittest.main()