Skip to content

Commit

Permalink
Merge pull request #4082 from VesnaT/kmeans_selection
Browse files Browse the repository at this point in the history
[FIX] K-means: Save Silhouette Scores selection
  • Loading branch information
janezd authored Oct 11, 2019
2 parents e1f5e9f + a193f02 commit a283eb1
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 24 deletions.
13 changes: 12 additions & 1 deletion Orange/widgets/unsupervised/owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class Warning(widget.OWWidget.Warning):
max_iterations = Setting(300)
n_init = Setting(10)
smart_init = Setting(0) # KMeans++
selection = Setting(None, schema_only=True) # type: Optional[int]
auto_commit = Setting(True)

settings_version = 2
Expand All @@ -158,6 +159,7 @@ def __init__(self):
super().__init__()

self.data = None # type: Optional[Table]
self.__pending_selection = self.selection # type: Optional[int]
self.clusterings = {}

self.__executor = ThreadExecutor(parent=self)
Expand Down Expand Up @@ -443,17 +445,25 @@ def update_results(self):
key=lambda x: 0 if isinstance(scores[x], str) else scores[x]
)
self.table_model.set_scores(scores, self.k_from)
self.table_view.selectRow(best_row)
self.apply_selection(best_row)
self.table_view.setFocus(Qt.OtherFocusReason)
self.table_view.resizeRowsToContents()

# Select a row in the scores table view.
#
# `best_row` is the row to select by default.  If a pending selection is set
# (it is initialised from the `selection` setting in `__init__`), that row is
# selected instead and the pending value is cleared, so it is applied only once.
def apply_selection(self, best_row):
pending = best_row
if self.__pending_selection is not None:
# A pending selection takes precedence over the default row.
pending = self.__pending_selection
# Consume it: later calls fall back to `best_row`.
self.__pending_selection = None
self.table_view.selectRow(pending)

# Return the row of the first index reported as selected by the table view,
# or None when the view reports no selected indexes.
def selected_row(self):
indices = self.table_view.selectedIndexes()
if not indices:
# Nothing is selected in the view.
return None
return indices[0].row()

# Store the currently selected table row (or None) in the `selection`
# setting, then call `send_data()`.
def select_row(self):
self.selection = self.selected_row()
self.send_data()

def preproces(self, data):
Expand Down Expand Up @@ -535,6 +545,7 @@ def send_data(self):
@check_sql_input
def set_data(self, data):
self.data, old_data = data, self.data
self.selection = None

# Do not needlessly recluster the data if X hasn't changed
if old_data and self.data and array_equal(self.data.X, old_data.X):
Expand Down
62 changes: 39 additions & 23 deletions Orange/widgets/unsupervised/tests/test_owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def setUp(self):
self.widget = self.create_widget(
OWKMeans, stored_settings={"auto_commit": False, "version": 2}
) # type: OWKMeans
self.iris = Table("iris")
self.iris.X[0, 0] = np.nan
self.data = Table("heart_disease")

def tearDown(self):
self.widget.onDeleteWidget()
Expand All @@ -61,7 +60,7 @@ def test_migrate_version_1_settings(self):
def test_optimization_report_display(self):
"""Check visibility of the table after selecting number of clusters"""
self.widget.auto_commit = True
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
self.widget.optimize_k = True
radio_buttons = self.widget.controls.optimize_k.findChildren(QRadioButton)

Expand All @@ -80,7 +79,7 @@ def test_optimization_report_display(self):
def test_changing_k_changes_radio(self):
widget = self.widget
widget.auto_commit = True
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)

widget.optimize_k = True

Expand Down Expand Up @@ -118,7 +117,7 @@ def test_no_data_hides_main_area(self):

self.send_signal(self.widget.Inputs.data, None, wait=5000)
self.assertTrue(self.widget.mainArea.isHidden())
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
self.assertFalse(self.widget.mainArea.isHidden())
self.send_signal(self.widget.Inputs.data, None, wait=5000)
self.assertTrue(self.widget.mainArea.isHidden())
Expand All @@ -129,7 +128,7 @@ def test_data_limits(self):
widget = self.widget
widget.auto_commit = False

self.send_signal(self.widget.Inputs.data, self.iris[:5])
self.send_signal(self.widget.Inputs.data, self.data[:5])

widget.k = 10
self.commit_and_wait()
Expand Down Expand Up @@ -159,7 +158,7 @@ def test_use_cache(self):
"""Cache various clusterings for the dataset until data changes."""
widget = self.widget
widget.auto_commit = False
self.send_signal(self.widget.Inputs.data, self.iris)
self.send_signal(self.widget.Inputs.data, self.data)

with patch.object(widget, "_compute_clustering",
wraps=widget._compute_clustering) as compute:
Expand Down Expand Up @@ -191,7 +190,7 @@ def test_use_cache(self):
def test_data_on_output(self):
"""Check if data is on output after create widget and run"""
self.widget.auto_commit = True
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
self.widget.apply_button.button.click()
self.assertNotEqual(self.widget.data, None)
# Disconnect the data
Expand All @@ -203,19 +202,19 @@ def test_centroids_on_output(self):
widget = self.widget
widget.optimize_k = False
widget.k = 4
self.send_signal(widget.Inputs.data, self.iris)
self.send_signal(widget.Inputs.data, self.data)
self.commit_and_wait()
widget.clusterings[widget.k].labels = np.array([0] * 50 + [1] * 100).flatten()
widget.clusterings[widget.k].labels = np.array([0] * 100 + [1] * 203).flatten()

widget.samples_scores = lambda x: np.arctan(
np.arange(150) / 150) / np.pi + 0.5
np.arange(303) / 303) / np.pi + 0.5
widget.send_data()
out = self.get_output(widget.Outputs.centroids)
np.testing.assert_array_almost_equal(
np.array([[0, np.mean(np.arctan(np.arange(50) / 150)) / np.pi + 0.5],
[1, np.mean(np.arctan(np.arange(50, 150) / 150)) / np.pi + 0.5],
np.array([[0, np.mean(np.arctan(np.arange(100) / 303)) / np.pi + 0.5],
[1, np.mean(np.arctan(np.arange(100, 303) / 303)) / np.pi + 0.5],
[2, 0], [3, 0]]), out.metas.astype(float))
self.assertEqual(out.name, "iris centroids")
self.assertEqual(out.name, "heart_disease centroids")

def test_centroids_domain_on_output(self):
widget = self.widget
Expand Down Expand Up @@ -262,13 +261,13 @@ def test_optimization_fails(self):

with patch.object(
model, "set_scores", wraps=model.set_scores) as set_scores:
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
scores, start_k = set_scores.call_args[0]
X = self.widget.preproces(self.iris).X
X = self.widget.preproces(self.data).X
self.assertEqual(
scores,
[km if isinstance(km, str) else silhouette_score(
X, km(self.iris))
X, km(self.data))
for km in (widget.clusterings[k] for k in range(3, 9))]
)
self.assertEqual(start_k, 3)
Expand Down Expand Up @@ -302,7 +301,7 @@ def test_run_fails(self):
self.widget.auto_commit = True
self.widget.optimize_k = False
self.KMeansFail.fail_on = {3}
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
self.assertTrue(self.widget.Error.failed.is_shown())
self.assertIsNone(self.get_output(self.widget.Outputs.annotated_data))

Expand Down Expand Up @@ -362,7 +361,7 @@ def test_not_enough_rows(self):
Widget should not crash when there are fewer rows than k_from.
GH-2172
"""
table = self.iris[0:1, :]
table = self.data[0:1, :]
self.widget.controls.k_from.setValue(2)
self.widget.controls.k_to.setValue(9)
self.send_signal(self.widget.Inputs.data, table)
Expand All @@ -374,7 +373,7 @@ def test_from_to_table(self):
"""
k_from, k_to = 2, 9
self.widget.controls.k_from.setValue(k_from)
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
check = lambda x: 2 if x - k_from + 1 < 2 else x - k_from + 1
for i in range(k_from, k_to):
self.widget.controls.k_to.setValue(i)
Expand Down Expand Up @@ -415,7 +414,7 @@ def test_invalidate_clusterings_cancels_jobs(self):
widget.auto_commit = False

# Send the data without waiting
self.send_signal(widget.Inputs.data, self.iris)
self.send_signal(widget.Inputs.data, self.data)
widget.unconditional_commit()
# Now, invalidate by changing max_iter
widget.max_iterations = widget.max_iterations + 1
Expand Down Expand Up @@ -460,7 +459,7 @@ def test_do_not_recluster_on_same_data(self):

def test_correct_smart_init(self):
# due to a bug where wrong init was passed to _compute_clustering
self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000)
self.send_signal(self.widget.Inputs.data, self.data[::10], wait=5000)
self.widget.smart_init = 0
self.widget.clusterings = {}
with patch.object(self.widget, "_compute_clustering",
Expand All @@ -476,7 +475,7 @@ def test_correct_smart_init(self):

def test_always_same_cluster(self):
"""The same random state should always return the same clusters"""
self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000)
self.send_signal(self.widget.Inputs.data, self.data[::10], wait=5000)

def cluster():
self.widget.invalidate() # reset caches
Expand All @@ -500,6 +499,23 @@ def test_error_no_attributes(self):
self.send_signal(self.widget.Inputs.data, table)
self.assertTrue(self.widget.Error.no_attributes.is_shown())

# Check that a row selected in the scores table is carried over to a second
# widget created from the first widget's packed settings.
def test_saved_selection(self):
# Replace send_data with a Mock so the number of calls can be asserted.
self.widget.send_data = Mock()
self.widget.optimize_k = True
self.send_signal(self.widget.Inputs.data, self.data)
self.wait_until_stop_blocking()
# Select a specific row (index 2) in the table view.
self.widget.table_view.selectRow(2)
self.assertEqual(self.widget.selected_row(), 2)
self.assertEqual(self.widget.send_data.call_count, 3)
# Pack the first widget's current settings for reuse below.
settings = self.widget.settingsHandler.pack_data(self.widget)

# Create a second widget from the packed settings and feed it the same data.
w = self.create_widget(OWKMeans, stored_settings=settings)
w.send_data = Mock()
self.send_signal(w.Inputs.data, self.data, widget=w)
self.wait_until_stop_blocking(widget=w)
self.assertEqual(w.send_data.call_count, 2)
# Both widgets must report the same selected row.
self.assertEqual(self.widget.selected_row(), w.selected_row())


if __name__ == "__main__":
unittest.main()

0 comments on commit a283eb1

Please sign in to comment.