Skip to content

Commit

Permalink
Merge pull request #4165 from aturanjanin/data_sampler
Browse files Browse the repository at this point in the history
DataSampler: data info displayed in status bar
  • Loading branch information
VesnaT authored Nov 11, 2019
2 parents dee9bba + 2435635 commit 05c8ae0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
25 changes: 8 additions & 17 deletions Orange/widgets/data/owdatasampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ def __init__(self):
self.indices = None
self.sampled_instances = self.remaining_instances = None

box = gui.vBox(self.controlArea, "Information")
self.dataInfoLabel = gui.widgetLabel(box, 'No data on input.')
self.outputInfoLabel = gui.widgetLabel(box, ' ')
self.info.set_input_summary(self.info.NoInput)
self.info.set_output_summary(self.info.NoInput)

self.sampling_box = gui.vBox(self.controlArea, "Sampling Type")
sampling = gui.radioButtons(self.sampling_box, self, "sampling_type",
Expand Down Expand Up @@ -180,16 +179,14 @@ def set_data(self, dataset):
self.cb_seed.setVisible(not sql)
self.cb_stratify.setVisible(not sql)
self.cb_sql_dl.setVisible(sql)
self.dataInfoLabel.setText(
'{}{} instances in input dataset.'.format(*(
('~', dataset.approx_len()) if sql else
('', len(dataset)))))
self.info.set_input_summary(str(len(dataset)))

if not sql:
self._update_sample_max_size()
self.updateindices()
else:
self.dataInfoLabel.setText('No data on input.')
self.outputInfoLabel.setText('')
self.info.set_input_summary(self.info.NoInput)
self.info.set_output_summary(self.info.NoInput)
self.indices = None
self.clear_messages()
self.commit()
Expand All @@ -205,7 +202,6 @@ def commit(self):
if self.data is None:
sample = other = None
self.sampled_instances = self.remaining_instances = None
self.outputInfoLabel.setText("")
elif isinstance(self.data, SqlTable):
other = None
if self.sampling_type == self.SqlProportion:
Expand All @@ -226,15 +222,10 @@ def commit(self):
if self.sampling_type in (
self.FixedProportion, self.FixedSize, self.Bootstrap):
remaining, sample = self.indices
self.outputInfoLabel.setText(
'Outputting %d instance%s.' %
(len(sample), "s" * (len(sample) != 1)))
elif self.sampling_type == self.CrossValidation:
remaining, sample = self.indices[self.selectedFold - 1]
self.outputInfoLabel.setText(
'Outputting fold %d, %d instance%s.' %
(self.selectedFold, len(sample), "s" * (len(sample) != 1))
)
self.info.set_output_summary(str(len(sample)))

sample = self.data[sample]
other = self.data[remaining]
self.sampled_instances = len(sample)
Expand Down
26 changes: 24 additions & 2 deletions Orange/widgets/data/tests/test_owdatasampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
# pylint: disable=missing-docstring,unsubscriptable-object
from unittest.mock import Mock

from Orange.data import Table
from Orange.widgets.data.owdatasampler import OWDataSampler
from Orange.widgets.tests.base import WidgetTest
Expand Down Expand Up @@ -36,6 +38,7 @@ def test_stratified_on_unbalanced_data(self):
self.assertTrue(self.widget.Warning.could_not_stratify.is_shown())

def test_bootstrap(self):
output_sum = self.widget.info.set_output_summary = Mock()
self.select_sampling_type(self.widget.Bootstrap)

self.send_signal("Data", self.iris)
Expand All @@ -56,16 +59,20 @@ def test_bootstrap(self):
# high probability (1-(1/150*2/150*...*150/150) ~= 1-2e-64)
self.assertGreater(len(in_sample), 0)
self.assertGreater(len(in_remaining), 0)
# Check that the status bar shows the correct number of output instances
output_sum.assert_called_with(str(len(sample)))

def select_sampling_type(self, sampling_type):
    """Click the button at index ``sampling_type`` of the sampling control.

    The buttons are taken from the button group behind the widget's
    ``sampling_type`` control, so the selection goes through the GUI.
    """
    radio_group = self.widget.controls.sampling_type.group
    radio_group.buttons()[sampling_type].click()

def test_no_intersection_in_outputs(self):
""" Check whether outputs intersect and whether length of outputs sums
to length of original data """
to length of original data, and whether the
status bar displays the correct output for each sampling type"""
self.send_signal("Data", self.iris)
w = self.widget
output_sum = self.widget.info.set_output_summary = Mock()
sampling_types = [w.FixedProportion, w.FixedSize, w.CrossValidation]

for replicable in [True, False]:
Expand All @@ -80,6 +87,7 @@ def test_no_intersection_in_outputs(self):
other = self.get_output("Remaining Data")
self.assertEqual(len(self.iris), len(sample) + len(other))
self.assertNoIntersection(sample, other)
output_sum.assert_called_with(str(len(sample)))

def test_bigger_size_with_replacement(self):
"""Allow bigger output with replacement."""
Expand Down Expand Up @@ -116,6 +124,20 @@ def test_shuffling(self):
self.assertTrue((self.iris.ids != sample.ids).any())
self.assertEqual(set(self.iris.ids), set(sample.ids))

def test_summary(self):
    """Check that the input summary is refreshed whenever the input changes."""
    # Swap the real setter for a mock so every update can be inspected.
    summary_setter = Mock()
    self.widget.info.set_input_summary = summary_setter
    input_signal = self.widget.Inputs.data

    # A full slice of the data on the input: the summary must read "150"
    # (presumably the number of instances in the iris fixture).
    summary_setter.reset_mock()
    self.send_signal(input_signal, self.iris[:])
    summary_setter.assert_called_with("150")

    # Input removed: exactly one update, whose brief text is empty.
    summary_setter.reset_mock()
    self.send_signal(input_signal, None)
    summary_setter.assert_called_once()
    reported = summary_setter.call_args[0][0]
    self.assertEqual(reported.brief, "")

def set_fixed_sample_size(self, sample_size, with_replacement=False):
"""Set fixed sample size and return the number of gui spin.
Expand Down

0 comments on commit 05c8ae0

Please sign in to comment.