diff --git a/Orange/widgets/visualize/owdistributions.py b/Orange/widgets/visualize/owdistributions.py index e3906df4642..9c1f6ce36c2 100644 --- a/Orange/widgets/visualize/owdistributions.py +++ b/Orange/widgets/visualize/owdistributions.py @@ -280,6 +280,7 @@ class Warning(OWWidget.Warning): show_probs = settings.Setting(False) stacked_columns = settings.Setting(False) cumulative_distr = settings.Setting(False) + sort_by_freq = settings.Setting(False) kde_smoothing = settings.Setting(10) auto_apply = settings.Setting(True) @@ -314,11 +315,14 @@ def __init__(self): self.key_operation = None self._user_var_bins = {} - gui.listView( + varview = gui.listView( self.controlArea, self, "var", box="Variable", model=DomainModel(valid_types=DomainModel.PRIMITIVE, separators=False), callback=self._on_var_changed) + gui.checkBox( + varview.box, self, "sort_by_freq", "Sort categories by frequency", + callback=self._on_sort_by_freq, stateWhenDisabled=False) box = self.continuous_box = gui.vBox(self.controlArea, "Distribution") slider = gui.hSlider( @@ -466,6 +470,10 @@ def _on_show_cumulative(self): self.replot() self.apply() + def _on_sort_by_freq(self): + self.replot() + self.apply() + def _on_bins_changed(self): self.reset_select() self._set_bin_width_slider_label() @@ -581,6 +589,7 @@ def _set_axis_names(self): def _update_controls_state(self): assert self.is_valid # called only from replot, so assumes data is OK + self.controls.sort_by_freq.setDisabled(self.var.is_continuous) self.continuous_box.setDisabled(self.var.is_discrete) self.controls.show_probs.setDisabled(self.cvar is None) self.controls.stacked_columns.setDisabled(self.cvar is None) @@ -610,11 +619,18 @@ def _add_bar(self, x, width, padding, freqs, colors, stacked, expanded, def _disc_plot(self): var = self.var - self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))]) - colors = [QColor(0, 128, 255)] dist = distribution.get_distribution(self.data, self.var) - for i, freq in enumerate(dist): - desc = var.values[i] + dist = np.array(dist) # Distribution misbehaves in further operations + if self.sort_by_freq: + order = np.argsort(dist)[::-1] + else: + order = np.arange(len(dist)) + + ordered_values = np.array(var.values)[order] + self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))]) + + colors = [QColor(0, 128, 255)] + for i, freq, desc in zip(count(), dist[order], ordered_values): tooltip = \ "
" \ f"{escape(desc)}: {int(freq)} " \ @@ -625,13 +641,20 @@ def _disc_plot(self): def _disc_split_plot(self): var = self.var - self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))]) + conts = contingency.get_contingency(self.data, self.cvar, self.var) + conts = np.array(conts) # Contingency misbehaves in further operations + if self.sort_by_freq: + order = np.argsort(conts.sum(axis=1))[::-1] + else: + order = np.arange(len(conts)) + + ordered_values = np.array(var.values)[order] + self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))]) + gcolors = [QColor(*col) for col in self.cvar.colors] gvalues = self.cvar.values - conts = contingency.get_contingency(self.data, self.cvar, self.var) total = len(self.data) - for i, freqs in enumerate(conts): - desc = var.values[i] + for i, freqs, desc in zip(count(), conts[order], ordered_values): self._add_bar( i - 0.5, 1, 0.1, freqs, gcolors, stacked=self.stacked_columns, expanded=self.show_probs, diff --git a/Orange/widgets/visualize/tests/test_owdistributions.py b/Orange/widgets/visualize/tests/test_owdistributions.py index 3d4096732c1..1524d927814 100644 --- a/Orange/widgets/visualize/tests/test_owdistributions.py +++ b/Orange/widgets/visualize/tests/test_owdistributions.py @@ -376,35 +376,40 @@ def test_controls_disabling(self): cont = self.iris.domain[0] disc = self.iris.domain.class_var cont_box = widget.continuous_box + sort_by_freq = widget.controls.sort_by_freq show_probs = widget.controls.show_probs stacked = widget.controls.stacked_columns self._set_var(cont) self._set_cvar(disc) + self.assertFalse(sort_by_freq.isEnabled()) self.assertTrue(cont_box.isEnabled()) self.assertTrue(show_probs.isEnabled()) self.assertTrue(stacked.isEnabled()) self._set_var(cont) self._set_cvar(None) + self.assertFalse(sort_by_freq.isEnabled()) self.assertTrue(cont_box.isEnabled()) self.assertFalse(show_probs.isEnabled()) self.assertFalse(stacked.isEnabled()) self._set_var(disc) self._set_cvar(None) + self.assertTrue(sort_by_freq.isEnabled()) self.assertFalse(cont_box.isEnabled()) self.assertFalse(show_probs.isEnabled()) self.assertFalse(stacked.isEnabled()) self._set_var(disc) self._set_cvar(disc) + self.assertTrue(sort_by_freq.isEnabled()) self.assertFalse(cont_box.isEnabled()) self.assertTrue(show_probs.isEnabled()) self.assertTrue(stacked.isEnabled()) if os.getenv("CI"): - # Testing all combinations takes 10-15 seconds; this should take < 2s + # Testing all combinations takes almost a minute; this should take < 2s # Code for fitter, stacked_columns and show_probs is independent, so # changing them simultaneously doesn't significantly degrade the tests def test_plot_types_combinations(self): @@ -424,6 +429,7 @@ def test_plot_types_combinations(self): self._set_fitter(2 * b) self._set_check(c.stacked_columns, b) self._set_check(c.show_probs, b) + self._set_check(c.sort_by_freq, b) qApp.processEvents() else: def test_plot_types_combinations(self): @@ -433,6 +439,7 @@ def test_plot_types_combinations(self): widget = self.widget c = widget.controls + set_chk = self._set_check self.send_signal(widget.Inputs.data, self.iris) cont = self.iris.domain[0] disc = self.iris.domain.class_var @@ -442,14 +449,15 @@ def test_plot_types_combinations(self): for cumulative in [False, True]: for stack in [False, True]: for show_probs in [False, True]: - self._set_var(var) - self._set_cvar(cvar) - self._set_fitter(fitter) - self._set_check(c.cumulative_distr, - cumulative) - self._set_check(c.stacked_columns, stack) - self._set_check(c.show_probs, show_probs) - qApp.processEvents() + for sort_by_freq in [False, True]: + self._set_var(var) + self._set_cvar(cvar) + self._set_fitter(fitter) + set_chk(c.cumulative_distr, cumulative) + set_chk(c.stacked_columns, stack) + set_chk(c.show_probs, show_probs) + set_chk(c.sort_by_freq, sort_by_freq) + qApp.processEvents() def test_selection_grouping(self): """Widget groups consecutive selected bars""" @@ -543,6 +551,56 @@ def test_summary(self): self.assertEqual(info._StateInfo__output_summary.brief, "") self.assertEqual(info._StateInfo__output_summary.details, no_output) + def test_sort_by_freq_no_split(self): + data = Table("heart_disease") + domain = data.domain + sort_by_freq = self.widget.controls.sort_by_freq + + self.send_signal(self.widget.Inputs.data, data) + self._set_var(domain["gender"]) + self._set_cvar(None) + + self._set_check(sort_by_freq, False) + out = self.get_output(self.widget.Outputs.histogram_data) + self.assertEqual(out[0][0], "female") + self.assertEqual(out[0][1], 97) + self.assertEqual(out[1][0], "male") + self.assertEqual(out[1][1], 206) + + self._set_check(sort_by_freq, True) + out = self.get_output(self.widget.Outputs.histogram_data) + self.assertEqual(out[0][0], "male") + self.assertEqual(out[0][1], 206) + self.assertEqual(out[1][0], "female") + self.assertEqual(out[1][1], 97) + + def test_sort_by_freq_split(self): + data = Table("heart_disease") + domain = data.domain + sort_by_freq = self.widget.controls.sort_by_freq + + self.send_signal(self.widget.Inputs.data, data) + self._set_var(domain["gender"]) + self._set_cvar(domain["rest ECG"]) + + self._set_check(sort_by_freq, False) + out = self.get_output(self.widget.Outputs.histogram_data) + self.assertEqual(out[0][0], "female") + self.assertEqual(out[0][1], "normal") + self.assertEqual(out[0][2], 49) + self.assertEqual(out[4][0], "male") + self.assertEqual(out[4][1], "left vent hypertrophy") + self.assertEqual(out[4][2], 103) + + self._set_check(sort_by_freq, True) + out = self.get_output(self.widget.Outputs.histogram_data) + self.assertEqual(out[0][0], "male") + self.assertEqual(out[0][1], "normal") + self.assertEqual(out[0][2], 102) + self.assertEqual(out[4][0], "female") + self.assertEqual(out[4][1], "left vent hypertrophy") + self.assertEqual(out[4][2], 45) + if __name__ == "__main__": unittest.main()