From d9f5dd179644fb2fa0909bc23df742f4edc6e549 Mon Sep 17 00:00:00 2001 From: PrimozGodec Date: Fri, 9 Jun 2023 12:15:47 +0200 Subject: [PATCH] Group By - Fix error with hidden attributes --- Orange/widgets/data/owgroupby.py | 48 ++++++++-------- Orange/widgets/data/tests/test_owgroupby.py | 62 ++++++++++++++++----- 2 files changed, 69 insertions(+), 41 deletions(-) diff --git a/Orange/widgets/data/owgroupby.py b/Orange/widgets/data/owgroupby.py index 0f3ce30f84c..01b332cb451 100644 --- a/Orange/widgets/data/owgroupby.py +++ b/Orange/widgets/data/owgroupby.py @@ -18,11 +18,10 @@ QCheckBox, QGridLayout, QHeaderView, - QListView, QTableView, ) from orangewidget.settings import ContextSetting, Setting -from orangewidget.utils.listview import ListViewSearch +from orangewidget.utils.listview import ListViewFilter from orangewidget.utils.signals import Input, Output from orangewidget.utils import enum_as_int from orangewidget.widget import Msg @@ -40,7 +39,6 @@ from Orange.data.aggregate import OrangeTableGroupBy from Orange.util import wrap_callback from Orange.widgets import gui -from Orange.widgets.data.oweditdomain import disconnected from Orange.widgets.settings import DomainContextHandler from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState from Orange.widgets.utils.itemmodels import DomainModel @@ -246,7 +244,7 @@ def headerData(self, i, orientation, role=Qt.DisplayRole) -> str: return super().headerData(i, orientation, role) -class AggregateListViewSearch(ListViewSearch): +class AggregateListViewSearch(ListViewFilter): """ListViewSearch that disables unselecting all items in the list""" def selectionCommand( @@ -372,13 +370,16 @@ def __init__(self): def __init_control_area(self) -> None: """Init all controls in the control area""" - box = gui.vBox(self.controlArea, "Group by") - self.gb_attrs_view = AggregateListViewSearch( - selectionMode=QListView.ExtendedSelection + gui.listView( + self.controlArea, + self, + "gb_attrs", + box="Group by", + model=self.gb_attrs_model, + viewType=AggregateListViewSearch, + callback=self.__gb_changed, + selectionMode=ListViewFilter.ExtendedSelection, ) - self.gb_attrs_view.setModel(self.gb_attrs_model) - self.gb_attrs_view.selectionModel().selectionChanged.connect(self.__gb_changed) - box.layout().addWidget(self.gb_attrs_view) gui.auto_send(self.buttonsArea, self, "auto_commit") @@ -434,14 +435,7 @@ def __rows_selected(self) -> None: ) def __gb_changed(self) -> None: - """ - Callback for Group-by attributes selection change; update attribute - and call commit - """ - rows = self.gb_attrs_view.selectionModel().selectedRows() - values = self.gb_attrs_view.model()[:] - self.gb_attrs = [values[row.row()] for row in sorted(rows)] - # everything cached in result should be recomputed on gb change + """Callback for Group-by attributes selection change""" self.result = Result() self.commit.deferred() @@ -471,7 +465,7 @@ def set_data(self, data: Table) -> None: self.result = Result() self.Outputs.data.send(None) self.gb_attrs_model.set_domain(data.domain if data else None) - self.gb_attrs = data.domain[:1] if data else [] + self.gb_attrs = self.gb_attrs_model[:1] if self.gb_attrs_model else [] self.aggregations = ( { attr: DEFAULT_AGGREGATIONS[type(attr)].copy() @@ -526,14 +520,16 @@ def get_selected_attributes(self): return [vars_[index.row()] for index in sel_rows] def _set_gb_selection(self) -> None: - """Set selection in groupby list according to self.gb_attrs""" - sm = self.gb_attrs_view.selectionModel() + """ + Update selected attributes. When context includes variable hidden in + data, it will match and gb_attrs may include hidden attribute. Remove it + since otherwise widget groups by attribute that is not present in view. + """ values = self.gb_attrs_model[:] - with disconnected(sm.selectionChanged, self.__gb_changed): - for val in self.gb_attrs: - index = values.index(val) - model_index = self.gb_attrs_model.index(index, 0) - sm.select(model_index, QItemSelectionModel.Select) + self.gb_attrs = [var_ for var_ in self.gb_attrs if var_ in values] + if not self.gb_attrs and self.gb_attrs_model: + # if gb_attrs empty select first + self.gb_attrs = self.gb_attrs_model[:1] @staticmethod def __aggregation_compatible(agg, attr): diff --git a/Orange/widgets/data/tests/test_owgroupby.py b/Orange/widgets/data/tests/test_owgroupby.py index fdfdfed18ce..94cfbc6891b 100644 --- a/Orange/widgets/data/tests/test_owgroupby.py +++ b/Orange/widgets/data/tests/test_owgroupby.py @@ -44,7 +44,7 @@ def test_data(self): self.assertEqual(self.widget.gb_attrs_model.rowCount(), 5) output = self.get_output(self.widget.Outputs.data) - self.assertEqual(len(output), 35) + self.assertEqual(3, len(output)) self.send_signal(self.widget.Inputs.data, None) self.assertIsNone(self.get_output(self.widget.Outputs.data)) @@ -63,27 +63,30 @@ def _set_selection(view: QListView, indices: List[int]): sm = view.selectionModel() model = view.model() for ind in indices: - sm.select(model.index(ind), QItemSelectionModel.Select) + sm.select(model.index(ind, 0), QItemSelectionModel.Select) def test_groupby_attr_selection(self): + gb_view = self.widget.controls.gb_attrs self.send_signal(self.widget.Inputs.data, self.iris) + self._set_selection(gb_view, [1]) # sepal length + self.wait_until_finished() output = self.get_output(self.widget.Outputs.data) - self.assertEqual(len(output), 35) + self.assertEqual(35, len(output)) # select iris attribute with index 0 - self._set_selection(self.widget.gb_attrs_view, [0]) + self._set_selection(gb_view, [0]) self.wait_until_finished() output = self.get_output(self.widget.Outputs.data) - self.assertEqual(len(output), 3) + self.assertEqual(3, len(output)) - # select iris attribute with index 0 - self._set_selection(self.widget.gb_attrs_view, [0, 1]) + # select iris and sepal length attribute + self._set_selection(gb_view, [0, 1]) self.wait_until_finished() output = self.get_output(self.widget.Outputs.data) - self.assertEqual(len(output), 57) + self.assertEqual(57, len(output)) def assert_enabled_cbs(self, enabled_true): enabled_actual = set( @@ -399,6 +402,7 @@ def test_aggregations_change(self): def test_aggregation(self): """Test aggregation results""" self.send_signal(self.widget.Inputs.data, self.data) + self._set_selection(self.widget.controls.gb_attrs, [1]) # a var output = self.get_output(self.widget.Outputs.data) np.testing.assert_array_almost_equal( @@ -422,7 +426,7 @@ def test_aggregation(self): ) # select all aggregations for all features except a and b - self._set_selection(self.widget.gb_attrs_view, [1, 2]) + self._set_selection(self.widget.controls.gb_attrs, [1, 2]) self.select_table_rows(self.widget.agg_table_view, [2, 3, 4]) # select all aggregations for cb in self.widget.agg_checkboxes.values(): @@ -525,7 +529,7 @@ def test_aggregation(self): def test_metas_results(self): """Test if variable that is in meta in input table remains in metas""" self.send_signal(self.widget.Inputs.data, self.data) - self._set_selection(self.widget.gb_attrs_view, [0, 1]) + self._set_selection(self.widget.controls.gb_attrs, [0, 1]) output = self.get_output(self.widget.Outputs.data) self.assertIn(self.data.domain["svar"], output.domain.metas) @@ -544,7 +548,7 @@ def test_context(self): ["Mean, Median", "Mean", "Mean, Median", "Mode", "Concatenate"] ) - self._set_selection(self.widget.gb_attrs_view, [1, 2]) + self._set_selection(self.widget.controls.gb_attrs, [1, 2]) self.assertListEqual([d["a"], d["b"]], self.widget.gb_attrs) self.assertDictEqual( { @@ -564,7 +568,7 @@ def test_context(self): self.assert_aggregations_equal( ["Mean, Median", "Mean", "Mean, Median", "Mode", "Concatenate"] ) - self._set_selection(self.widget.gb_attrs_view, [1, 2]) + self._set_selection(self.widget.controls.gb_attrs, [1, 2]) self.assertListEqual([d["a"], d["b"]], self.widget.gb_attrs) self.assertDictEqual( { @@ -624,14 +628,14 @@ def test_time_variable(self): # time variable as a group by variable self.send_signal(self.widget.Inputs.data, data) - self._set_selection(self.widget.gb_attrs_view, [3]) + self._set_selection(self.widget.controls.gb_attrs, [3]) output = self.get_output(self.widget.Outputs.data) self.assertEqual(3, len(output)) # time variable as a grouped variable attributes = [data.domain["c2"], data.domain["d2"]] self.send_signal(self.widget.Inputs.data, data[:, attributes]) - self._set_selection(self.widget.gb_attrs_view, [1]) # d2 + self._set_selection(self.widget.controls.gb_attrs, [1]) # d2 # check all aggregations self.assert_aggregations_equal(["Mean", "Mode"]) self.select_table_rows(self.widget.agg_table_view, [0]) # c2 @@ -855,7 +859,7 @@ def test_only_nan_in_group(self): self.send_signal(self.widget.Inputs.data, data) # select feature A as group-by - self._set_selection(self.widget.gb_attrs_view, [0]) + self._set_selection(self.widget.controls.gb_attrs, [0]) # select all aggregations for feature B self.select_table_rows(self.widget.agg_table_view, [1]) for cb in self.widget.agg_checkboxes.values(): @@ -908,6 +912,34 @@ def test_only_nan_in_group(self): check_categorical=False, ) + def test_hidden_attributes(self): + domain = self.iris.domain + data = self.iris.transform(domain.copy()) + + data.domain.attributes[0].attributes["hidden"] = True + self.send_signal(self.widget.Inputs.data, data) + self.assertListEqual([data.domain["iris"]], self.widget.gb_attrs) + + data = self.iris.transform(domain.copy()) + data.domain.class_vars[0].attributes["hidden"] = True + self.send_signal(self.widget.Inputs.data, data) + # iris is hidden now so sepal length is selected + self.assertListEqual([data.domain["sepal length"]], self.widget.gb_attrs) + + d = domain.copy() + data = self.iris.transform(Domain(d.attributes[:3], metas=d.attributes[3:])) + data.domain.metas[0].attributes["hidden"] = True + self.send_signal(self.widget.Inputs.data, data) + # sepal length still selected because of context + self.assertListEqual([data.domain["sepal length"]], self.widget.gb_attrs) + + # test case when one of two selected attributes is hidden + self._set_selection(self.widget.controls.gb_attrs, [0, 1]) # sep l, sep w + data.domain.attributes[0].attributes["hidden"] = True + self.send_signal(self.widget.Inputs.data, data) + # sepal length is hidden - only sepal width remain selected + self.assertListEqual([data.domain["sepal width"]], self.widget.gb_attrs) + if __name__ == "__main__": unittest.main()