Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Distributions: Add sorting by category size #4959

Merged
merged 1 commit into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions Orange/widgets/visualize/owdistributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class Warning(OWWidget.Warning):
show_probs = settings.Setting(False)
stacked_columns = settings.Setting(False)
cumulative_distr = settings.Setting(False)
sort_by_freq = settings.Setting(False)
kde_smoothing = settings.Setting(10)

auto_apply = settings.Setting(True)
Expand Down Expand Up @@ -314,11 +315,14 @@ def __init__(self):
self.key_operation = None
self._user_var_bins = {}

gui.listView(
varview = gui.listView(
self.controlArea, self, "var", box="Variable",
model=DomainModel(valid_types=DomainModel.PRIMITIVE,
separators=False),
callback=self._on_var_changed)
gui.checkBox(
varview.box, self, "sort_by_freq", "Sort categories by frequency",
callback=self._on_sort_by_freq, stateWhenDisabled=False)

box = self.continuous_box = gui.vBox(self.controlArea, "Distribution")
slider = gui.hSlider(
Expand Down Expand Up @@ -466,6 +470,10 @@ def _on_show_cumulative(self):
self.replot()
self.apply()

def _on_sort_by_freq(self):
self.replot()
self.apply()

def _on_bins_changed(self):
self.reset_select()
self._set_bin_width_slider_label()
Expand Down Expand Up @@ -581,6 +589,7 @@ def _set_axis_names(self):

def _update_controls_state(self):
assert self.is_valid # called only from replot, so assumes data is OK
self.controls.sort_by_freq.setDisabled(self.var.is_continuous)
self.continuous_box.setDisabled(self.var.is_discrete)
self.controls.show_probs.setDisabled(self.cvar is None)
self.controls.stacked_columns.setDisabled(self.cvar is None)
Expand Down Expand Up @@ -610,11 +619,18 @@ def _add_bar(self, x, width, padding, freqs, colors, stacked, expanded,

def _disc_plot(self):
var = self.var
self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))])
colors = [QColor(0, 128, 255)]
dist = distribution.get_distribution(self.data, self.var)
for i, freq in enumerate(dist):
desc = var.values[i]
dist = np.array(dist) # Distribution misbehaves in further operations
if self.sort_by_freq:
order = np.argsort(dist)[::-1]
else:
order = np.arange(len(dist))

ordered_values = np.array(var.values)[order]
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])

colors = [QColor(0, 128, 255)]
for i, freq, desc in zip(count(), dist[order], ordered_values):
tooltip = \
"<p style='white-space:pre;'>" \
f"<b>{escape(desc)}</b>: {int(freq)} " \
Expand All @@ -625,13 +641,20 @@ def _disc_plot(self):

def _disc_split_plot(self):
var = self.var
self.ploti.getAxis("bottom").setTicks([list(enumerate(var.values))])
conts = contingency.get_contingency(self.data, self.cvar, self.var)
conts = np.array(conts) # Contingency misbehaves in further operations
if self.sort_by_freq:
order = np.argsort(conts.sum(axis=1))[::-1]
else:
order = np.arange(len(conts))

ordered_values = np.array(var.values)[order]
self.ploti.getAxis("bottom").setTicks([list(enumerate(ordered_values))])

gcolors = [QColor(*col) for col in self.cvar.colors]
gvalues = self.cvar.values
conts = contingency.get_contingency(self.data, self.cvar, self.var)
total = len(self.data)
for i, freqs in enumerate(conts):
desc = var.values[i]
for i, freqs, desc in zip(count(), conts[order], ordered_values):
self._add_bar(
i - 0.5, 1, 0.1, freqs, gcolors,
stacked=self.stacked_columns, expanded=self.show_probs,
Expand Down
76 changes: 67 additions & 9 deletions Orange/widgets/visualize/tests/test_owdistributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,35 +376,40 @@ def test_controls_disabling(self):
cont = self.iris.domain[0]
disc = self.iris.domain.class_var
cont_box = widget.continuous_box
sort_by_freq = widget.controls.sort_by_freq
show_probs = widget.controls.show_probs
stacked = widget.controls.stacked_columns

self._set_var(cont)
self._set_cvar(disc)
self.assertFalse(sort_by_freq.isEnabled())
self.assertTrue(cont_box.isEnabled())
self.assertTrue(show_probs.isEnabled())
self.assertTrue(stacked.isEnabled())

self._set_var(cont)
self._set_cvar(None)
self.assertFalse(sort_by_freq.isEnabled())
self.assertTrue(cont_box.isEnabled())
self.assertFalse(show_probs.isEnabled())
self.assertFalse(stacked.isEnabled())

self._set_var(disc)
self._set_cvar(None)
self.assertTrue(sort_by_freq.isEnabled())
self.assertFalse(cont_box.isEnabled())
self.assertFalse(show_probs.isEnabled())
self.assertFalse(stacked.isEnabled())

self._set_var(disc)
self._set_cvar(disc)
self.assertTrue(sort_by_freq.isEnabled())
self.assertFalse(cont_box.isEnabled())
self.assertTrue(show_probs.isEnabled())
self.assertTrue(stacked.isEnabled())

if os.getenv("CI"):
# Testing all combinations takes 10-15 seconds; this should take < 2s
# Testing all combinations takes almost a minute; this should take < 2s
# Code for fitter, stacked_columns and show_probs is independent, so
# changing them simultaneously doesn't significantly degrade the tests
def test_plot_types_combinations(self):
Expand All @@ -424,6 +429,7 @@ def test_plot_types_combinations(self):
self._set_fitter(2 * b)
self._set_check(c.stacked_columns, b)
self._set_check(c.show_probs, b)
self._set_check(c.sort_by_freq, b)
qApp.processEvents()
else:
def test_plot_types_combinations(self):
Expand All @@ -433,6 +439,7 @@ def test_plot_types_combinations(self):

widget = self.widget
c = widget.controls
set_chk = self._set_check
self.send_signal(widget.Inputs.data, self.iris)
cont = self.iris.domain[0]
disc = self.iris.domain.class_var
Expand All @@ -442,14 +449,15 @@ def test_plot_types_combinations(self):
for cumulative in [False, True]:
for stack in [False, True]:
for show_probs in [False, True]:
self._set_var(var)
self._set_cvar(cvar)
self._set_fitter(fitter)
self._set_check(c.cumulative_distr,
cumulative)
self._set_check(c.stacked_columns, stack)
self._set_check(c.show_probs, show_probs)
qApp.processEvents()
for sort_by_freq in [False, True]:
self._set_var(var)
self._set_cvar(cvar)
self._set_fitter(fitter)
set_chk(c.cumulative_distr, cumulative)
set_chk(c.stacked_columns, stack)
set_chk(c.show_probs, show_probs)
set_chk(c.sort_by_freq, sort_by_freq)
qApp.processEvents()

def test_selection_grouping(self):
"""Widget groups consecutive selected bars"""
Expand Down Expand Up @@ -543,6 +551,56 @@ def test_summary(self):
self.assertEqual(info._StateInfo__output_summary.brief, "")
self.assertEqual(info._StateInfo__output_summary.details, no_output)

def test_sort_by_freq_no_split(self):
    """Check the histogram output order for "gender" without a split variable.

    With the checkbox off, the first two rows are expected to be
    female (97) then male (206); with it on, male (206) then female (97).
    """
    heart = Table("heart_disease")
    attrs = heart.domain
    checkbox = self.widget.controls.sort_by_freq

    self.send_signal(self.widget.Inputs.data, heart)
    self._set_var(attrs["gender"])
    self._set_cvar(None)

    # (checkbox state, expected (value, count) for output rows 0 and 1)
    scenarios = [
        (False, [("female", 97), ("male", 206)]),
        (True, [("male", 206), ("female", 97)]),
    ]
    for checked, expected_rows in scenarios:
        self._set_check(checkbox, checked)
        out = self.get_output(self.widget.Outputs.histogram_data)
        for row, (value, frequency) in enumerate(expected_rows):
            self.assertEqual(out[row][0], value)
            self.assertEqual(out[row][1], frequency)

def test_sort_by_freq_split(self):
    """Check the histogram output order for "gender" split by "rest ECG".

    Only output rows 0 and 4 are inspected; each is compared on its
    first three columns (value, split value, count).
    """
    heart = Table("heart_disease")
    attrs = heart.domain
    checkbox = self.widget.controls.sort_by_freq

    self.send_signal(self.widget.Inputs.data, heart)
    self._set_var(attrs["gender"])
    self._set_cvar(attrs["rest ECG"])

    # (checkbox state, [(row index, expected first three columns), ...])
    scenarios = [
        (False, [(0, ("female", "normal", 49)),
                 (4, ("male", "left vent hypertrophy", 103))]),
        (True, [(0, ("male", "normal", 102)),
                (4, ("female", "left vent hypertrophy", 45))]),
    ]
    for checked, expected_rows in scenarios:
        self._set_check(checkbox, checked)
        out = self.get_output(self.widget.Outputs.histogram_data)
        for row, expected in expected_rows:
            for column, cell in enumerate(expected):
                self.assertEqual(out[row][column], cell)


# Run this module's tests when it is executed as a script.
if __name__ == "__main__":
    unittest.main()