Skip to content

Commit

Permalink
Merge pull request #4959 from janezd/distributions-sort
Browse files Browse the repository at this point in the history
[ENH] Distributions: Add sorting by category size
  • Loading branch information
ajdapretnar authored Sep 4, 2020
2 parents e812868 + 8a48cbd commit 6ef5f7d
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 18 deletions.
41 changes: 32 additions & 9 deletions Orange/widgets/visualize/owdistributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class Warning(OWWidget.Warning):
show_probs = settings.Setting(False)
stacked_columns = settings.Setting(False)
cumulative_distr = settings.Setting(False)
sort_by_freq = settings.Setting(False)
kde_smoothing = settings.Setting(10)

auto_apply = settings.Setting(True)
Expand Down Expand Up @@ -314,11 +315,14 @@ def __init__(self):
self.key_operation = None
self._user_var_bins = {}

gui.listView(
varview = gui.listView(
self.controlArea, self, "var", box="Variable",
model=DomainModel(valid_types=DomainModel.PRIMITIVE,
separators=False),
callback=self._on_var_changed)
gui.checkBox(
varview.box, self, "sort_by_freq", "Sort categories by frequency",
callback=self._on_sort_by_freq, stateWhenDisabled=False)

box = self.continuous_box = gui.vBox(self.controlArea, "Distribution")
slider = gui.hSlider(
Expand Down Expand Up @@ -466,6 +470,10 @@ def _on_show_cumulative(self):
self.replot()
self.apply()

def _on_sort_by_freq(self):
self.replot()
self.apply()

def _on_bins_changed(self):
self.reset_select()
self._set_bin_width_slider_label()
Expand Down Expand Up @@ -581,6 +589,7 @@ def _set_axis_names(self):

def _update_controls_state(self):
assert self.is_valid # called only from replot, so assumes data is OK
self.controls.sort_by_freq.setDisabled(self.var.is_continuous)
self.continuous_box.setDisabled(self.var.is_discrete)
self.controls.show_probs.setDisabled(self.cvar is None)
self.controls.stacked_columns.setDisabled(self.cvar is None)
Expand Down Expand Up @@ -610,11 +619,18 @@ def _add_bar(self, x, width, padding, freqs, colors, stacked, expanded,

def _disc_plot(self):
var = self.var
self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))])
colors = [QColor(0, 128, 255)]
dist = distribution.get_distribution(self.data, self.var)
for i, freq in enumerate(dist):
desc = var.values[i]
dist = np.array(dist) # Distribution misbehaves in further operations
if self.sort_by_freq:
order = np.argsort(dist)[::-1]
else:
order = np.arange(len(dist))

ordered_values = np.array(var.values)[order]
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])

colors = [QColor(0, 128, 255)]
for i, freq, desc in zip(count(), dist[order], ordered_values):
tooltip = \
"<p style='white-space:pre;'>" \
f"<b>{escape(desc)}</b>: {int(freq)} " \
Expand All @@ -625,13 +641,20 @@ def _disc_plot(self):

def _disc_split_plot(self):
var = self.var
self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))])
conts = contingency.get_contingency(self.data, self.cvar, self.var)
conts = np.array(conts) # Contingency misbehaves in further operations
if self.sort_by_freq:
order = np.argsort(conts.sum(axis=1))[::-1]
else:
order = np.arange(len(conts))

ordered_values = np.array(var.values)[order]
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])

gcolors = [QColor(*col) for col in self.cvar.colors]
gvalues = self.cvar.values
conts = contingency.get_contingency(self.data, self.cvar, self.var)
total = len(self.data)
for i, freqs in enumerate(conts):
desc = var.values[i]
for i, freqs, desc in zip(count(), conts[order], ordered_values):
self._add_bar(
i - 0.5, 1, 0.1, freqs, gcolors,
stacked=self.stacked_columns, expanded=self.show_probs,
Expand Down
76 changes: 67 additions & 9 deletions Orange/widgets/visualize/tests/test_owdistributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,35 +376,40 @@ def test_controls_disabling(self):
cont = self.iris.domain[0]
disc = self.iris.domain.class_var
cont_box = widget.continuous_box
sort_by_freq = widget.controls.sort_by_freq
show_probs = widget.controls.show_probs
stacked = widget.controls.stacked_columns

self._set_var(cont)
self._set_cvar(disc)
self.assertFalse(sort_by_freq.isEnabled())
self.assertTrue(cont_box.isEnabled())
self.assertTrue(show_probs.isEnabled())
self.assertTrue(stacked.isEnabled())

self._set_var(cont)
self._set_cvar(None)
self.assertFalse(sort_by_freq.isEnabled())
self.assertTrue(cont_box.isEnabled())
self.assertFalse(show_probs.isEnabled())
self.assertFalse(stacked.isEnabled())

self._set_var(disc)
self._set_cvar(None)
self.assertTrue(sort_by_freq.isEnabled())
self.assertFalse(cont_box.isEnabled())
self.assertFalse(show_probs.isEnabled())
self.assertFalse(stacked.isEnabled())

self._set_var(disc)
self._set_cvar(disc)
self.assertTrue(sort_by_freq.isEnabled())
self.assertFalse(cont_box.isEnabled())
self.assertTrue(show_probs.isEnabled())
self.assertTrue(stacked.isEnabled())

if os.getenv("CI"):
# Testing all combinations takes 10-15 seconds; this should take < 2s
# Testing all combinations takes almost a minute; this should take < 2s
# Code for fitter, stacked_columns and show_probs is independent, so
# changing them simultaneously doesn't significantly degrade the tests
def test_plot_types_combinations(self):
Expand All @@ -424,6 +429,7 @@ def test_plot_types_combinations(self):
self._set_fitter(2 * b)
self._set_check(c.stacked_columns, b)
self._set_check(c.show_probs, b)
self._set_check(c.sort_by_freq, b)
qApp.processEvents()
else:
def test_plot_types_combinations(self):
Expand All @@ -433,6 +439,7 @@ def test_plot_types_combinations(self):

widget = self.widget
c = widget.controls
set_chk = self._set_check
self.send_signal(widget.Inputs.data, self.iris)
cont = self.iris.domain[0]
disc = self.iris.domain.class_var
Expand All @@ -442,14 +449,15 @@ def test_plot_types_combinations(self):
for cumulative in [False, True]:
for stack in [False, True]:
for show_probs in [False, True]:
self._set_var(var)
self._set_cvar(cvar)
self._set_fitter(fitter)
self._set_check(c.cumulative_distr,
cumulative)
self._set_check(c.stacked_columns, stack)
self._set_check(c.show_probs, show_probs)
qApp.processEvents()
for sort_by_freq in [False, True]:
self._set_var(var)
self._set_cvar(cvar)
self._set_fitter(fitter)
set_chk(c.cumulative_distr, cumulative)
set_chk(c.stacked_columns, stack)
set_chk(c.show_probs, show_probs)
set_chk(c.sort_by_freq, sort_by_freq)
qApp.processEvents()

def test_selection_grouping(self):
"""Widget groups consecutive selected bars"""
Expand Down Expand Up @@ -543,6 +551,56 @@ def test_summary(self):
self.assertEqual(info._StateInfo__output_summary.brief, "")
self.assertEqual(info._StateInfo__output_summary.details, no_output)

def test_sort_by_freq_no_split(self):
data = Table("heart_disease")
domain = data.domain
sort_by_freq = self.widget.controls.sort_by_freq

self.send_signal(self.widget.Inputs.data, data)
self._set_var(domain["gender"])
self._set_cvar(None)

self._set_check(sort_by_freq, False)
out = self.get_output(self.widget.Outputs.histogram_data)
self.assertEqual(out[0][0], "female")
self.assertEqual(out[0][1], 97)
self.assertEqual(out[1][0], "male")
self.assertEqual(out[1][1], 206)

self._set_check(sort_by_freq, True)
out = self.get_output(self.widget.Outputs.histogram_data)
self.assertEqual(out[0][0], "male")
self.assertEqual(out[0][1], 206)
self.assertEqual(out[1][0], "female")
self.assertEqual(out[1][1], 97)

def test_sort_by_freq_split(self):
data = Table("heart_disease")
domain = data.domain
sort_by_freq = self.widget.controls.sort_by_freq

self.send_signal(self.widget.Inputs.data, data)
self._set_var(domain["gender"])
self._set_cvar(domain["rest ECG"])

self._set_check(sort_by_freq, False)
out = self.get_output(self.widget.Outputs.histogram_data)
self.assertEqual(out[0][0], "female")
self.assertEqual(out[0][1], "normal")
self.assertEqual(out[0][2], 49)
self.assertEqual(out[4][0], "male")
self.assertEqual(out[4][1], "left vent hypertrophy")
self.assertEqual(out[4][2], 103)

self._set_check(sort_by_freq, True)
out = self.get_output(self.widget.Outputs.histogram_data)
self.assertEqual(out[0][0], "male")
self.assertEqual(out[0][1], "normal")
self.assertEqual(out[0][2], 102)
self.assertEqual(out[4][0], "female")
self.assertEqual(out[4][1], "left vent hypertrophy")
self.assertEqual(out[4][2], 45)


if __name__ == "__main__":
unittest.main()

0 comments on commit 6ef5f7d

Please sign in to comment.