Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] K-means: Save Silhouette Scores selection #4082

Merged
merged 2 commits into from
Oct 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion Orange/widgets/unsupervised/owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class Warning(widget.OWWidget.Warning):
max_iterations = Setting(300)
n_init = Setting(10)
smart_init = Setting(0) # KMeans++
selection = Setting(None, schema_only=True) # type: Optional[int]
auto_commit = Setting(True)

settings_version = 2
Expand All @@ -158,6 +159,7 @@ def __init__(self):
super().__init__()

self.data = None # type: Optional[Table]
self.__pending_selection = self.selection # type: Optional[int]
self.clusterings = {}

self.__executor = ThreadExecutor(parent=self)
Expand Down Expand Up @@ -443,17 +445,25 @@ def update_results(self):
key=lambda x: 0 if isinstance(scores[x], str) else scores[x]
)
self.table_model.set_scores(scores, self.k_from)
self.table_view.selectRow(best_row)
self.apply_selection(best_row)
self.table_view.setFocus(Qt.OtherFocusReason)
self.table_view.resizeRowsToContents()

def apply_selection(self, best_row):
pending = best_row
if self.__pending_selection is not None:
pending = self.__pending_selection
self.__pending_selection = None
self.table_view.selectRow(pending)

def selected_row(self):
indices = self.table_view.selectedIndexes()
if not indices:
return None
return indices[0].row()

def select_row(self):
self.selection = self.selected_row()
self.send_data()

def preproces(self, data):
Expand Down Expand Up @@ -535,6 +545,7 @@ def send_data(self):
@check_sql_input
def set_data(self, data):
self.data, old_data = data, self.data
self.selection = None

# Do not needlessly recluster the data if X hasn't changed
if old_data and self.data and array_equal(self.data.X, old_data.X):
Expand Down
62 changes: 39 additions & 23 deletions Orange/widgets/unsupervised/tests/test_owkmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def setUp(self):
self.widget = self.create_widget(
OWKMeans, stored_settings={"auto_commit": False, "version": 2}
) # type: OWKMeans
self.iris = Table("iris")
self.iris.X[0, 0] = np.nan
self.data = Table("heart_disease")

def tearDown(self):
self.widget.onDeleteWidget()
Expand All @@ -61,7 +60,7 @@ def test_migrate_version_1_settings(self):
def test_optimization_report_display(self):
"""Check visibility of the table after selecting number of clusters"""
self.widget.auto_commit = True
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
self.widget.optimize_k = True
radio_buttons = self.widget.controls.optimize_k.findChildren(QRadioButton)

Expand All @@ -80,7 +79,7 @@ def test_optimization_report_display(self):
def test_changing_k_changes_radio(self):
widget = self.widget
widget.auto_commit = True
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)

widget.optimize_k = True

Expand Down Expand Up @@ -118,7 +117,7 @@ def test_no_data_hides_main_area(self):

self.send_signal(self.widget.Inputs.data, None, wait=5000)
self.assertTrue(self.widget.mainArea.isHidden())
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
self.assertFalse(self.widget.mainArea.isHidden())
self.send_signal(self.widget.Inputs.data, None, wait=5000)
self.assertTrue(self.widget.mainArea.isHidden())
Expand All @@ -129,7 +128,7 @@ def test_data_limits(self):
widget = self.widget
widget.auto_commit = False

self.send_signal(self.widget.Inputs.data, self.iris[:5])
self.send_signal(self.widget.Inputs.data, self.data[:5])

widget.k = 10
self.commit_and_wait()
Expand Down Expand Up @@ -159,7 +158,7 @@ def test_use_cache(self):
"""Cache various clusterings for the dataset until data changes."""
widget = self.widget
widget.auto_commit = False
self.send_signal(self.widget.Inputs.data, self.iris)
self.send_signal(self.widget.Inputs.data, self.data)

with patch.object(widget, "_compute_clustering",
wraps=widget._compute_clustering) as compute:
Expand Down Expand Up @@ -191,7 +190,7 @@ def test_use_cache(self):
def test_data_on_output(self):
"""Check if data is on output after create widget and run"""
self.widget.auto_commit = True
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
self.widget.apply_button.button.click()
self.assertNotEqual(self.widget.data, None)
# Disconnect the data
Expand All @@ -203,19 +202,19 @@ def test_centroids_on_output(self):
widget = self.widget
widget.optimize_k = False
widget.k = 4
self.send_signal(widget.Inputs.data, self.iris)
self.send_signal(widget.Inputs.data, self.data)
self.commit_and_wait()
widget.clusterings[widget.k].labels = np.array([0] * 50 + [1] * 100).flatten()
widget.clusterings[widget.k].labels = np.array([0] * 100 + [1] * 203).flatten()

widget.samples_scores = lambda x: np.arctan(
np.arange(150) / 150) / np.pi + 0.5
np.arange(303) / 303) / np.pi + 0.5
widget.send_data()
out = self.get_output(widget.Outputs.centroids)
np.testing.assert_array_almost_equal(
np.array([[0, np.mean(np.arctan(np.arange(50) / 150)) / np.pi + 0.5],
[1, np.mean(np.arctan(np.arange(50, 150) / 150)) / np.pi + 0.5],
np.array([[0, np.mean(np.arctan(np.arange(100) / 303)) / np.pi + 0.5],
[1, np.mean(np.arctan(np.arange(100, 303) / 303)) / np.pi + 0.5],
[2, 0], [3, 0]]), out.metas.astype(float))
self.assertEqual(out.name, "iris centroids")
self.assertEqual(out.name, "heart_disease centroids")

def test_centroids_domain_on_output(self):
widget = self.widget
Expand Down Expand Up @@ -262,13 +261,13 @@ def test_optimization_fails(self):

with patch.object(
model, "set_scores", wraps=model.set_scores) as set_scores:
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
scores, start_k = set_scores.call_args[0]
X = self.widget.preproces(self.iris).X
X = self.widget.preproces(self.data).X
self.assertEqual(
scores,
[km if isinstance(km, str) else silhouette_score(
X, km(self.iris))
X, km(self.data))
for km in (widget.clusterings[k] for k in range(3, 9))]
)
self.assertEqual(start_k, 3)
Expand Down Expand Up @@ -302,7 +301,7 @@ def test_run_fails(self):
self.widget.auto_commit = True
self.widget.optimize_k = False
self.KMeansFail.fail_on = {3}
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
self.assertTrue(self.widget.Error.failed.is_shown())
self.assertIsNone(self.get_output(self.widget.Outputs.annotated_data))

Expand Down Expand Up @@ -362,7 +361,7 @@ def test_not_enough_rows(self):
Widget should not crash when there is less rows than k_from.
GH-2172
"""
table = self.iris[0:1, :]
table = self.data[0:1, :]
self.widget.controls.k_from.setValue(2)
self.widget.controls.k_to.setValue(9)
self.send_signal(self.widget.Inputs.data, table)
Expand All @@ -374,7 +373,7 @@ def test_from_to_table(self):
"""
k_from, k_to = 2, 9
self.widget.controls.k_from.setValue(k_from)
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.send_signal(self.widget.Inputs.data, self.data, wait=5000)
check = lambda x: 2 if x - k_from + 1 < 2 else x - k_from + 1
for i in range(k_from, k_to):
self.widget.controls.k_to.setValue(i)
Expand Down Expand Up @@ -415,7 +414,7 @@ def test_invalidate_clusterings_cancels_jobs(self):
widget.auto_commit = False

# Send the data without waiting
self.send_signal(widget.Inputs.data, self.iris)
self.send_signal(widget.Inputs.data, self.data)
widget.unconditional_commit()
# Now, invalidate by changing max_iter
widget.max_iterations = widget.max_iterations + 1
Expand Down Expand Up @@ -460,7 +459,7 @@ def test_do_not_recluster_on_same_data(self):

def test_correct_smart_init(self):
# due to a bug where wrong init was passed to _compute_clustering
self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000)
self.send_signal(self.widget.Inputs.data, self.data[::10], wait=5000)
self.widget.smart_init = 0
self.widget.clusterings = {}
with patch.object(self.widget, "_compute_clustering",
Expand All @@ -476,7 +475,7 @@ def test_correct_smart_init(self):

def test_always_same_cluster(self):
"""The same random state should always return the same clusters"""
self.send_signal(self.widget.Inputs.data, self.iris[::10], wait=5000)
self.send_signal(self.widget.Inputs.data, self.data[::10], wait=5000)

def cluster():
self.widget.invalidate() # reset caches
Expand All @@ -500,6 +499,23 @@ def test_error_no_attributes(self):
self.send_signal(self.widget.Inputs.data, table)
self.assertTrue(self.widget.Error.no_attributes.is_shown())

def test_saved_selection(self):
self.widget.send_data = Mock()
self.widget.optimize_k = True
self.send_signal(self.widget.Inputs.data, self.data)
self.wait_until_stop_blocking()
self.widget.table_view.selectRow(2)
self.assertEqual(self.widget.selected_row(), 2)
self.assertEqual(self.widget.send_data.call_count, 3)
settings = self.widget.settingsHandler.pack_data(self.widget)

w = self.create_widget(OWKMeans, stored_settings=settings)
w.send_data = Mock()
self.send_signal(w.Inputs.data, self.data, widget=w)
self.wait_until_stop_blocking(widget=w)
self.assertEqual(w.send_data.call_count, 2)
self.assertEqual(self.widget.selected_row(), w.selected_row())


if __name__ == "__main__":
unittest.main()