From d9f5dd179644fb2fa0909bc23df742f4edc6e549 Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Fri, 9 Jun 2023 12:15:47 +0200
Subject: [PATCH] Group By - Fix error with hidden attributes
---
Orange/widgets/data/owgroupby.py | 48 ++++++++--------
Orange/widgets/data/tests/test_owgroupby.py | 62 ++++++++++++++++-----
2 files changed, 69 insertions(+), 41 deletions(-)
diff --git a/Orange/widgets/data/owgroupby.py b/Orange/widgets/data/owgroupby.py
index 0f3ce30f84c..01b332cb451 100644
--- a/Orange/widgets/data/owgroupby.py
+++ b/Orange/widgets/data/owgroupby.py
@@ -18,11 +18,10 @@
QCheckBox,
QGridLayout,
QHeaderView,
- QListView,
QTableView,
)
from orangewidget.settings import ContextSetting, Setting
-from orangewidget.utils.listview import ListViewSearch
+from orangewidget.utils.listview import ListViewFilter
from orangewidget.utils.signals import Input, Output
from orangewidget.utils import enum_as_int
from orangewidget.widget import Msg
@@ -40,7 +39,6 @@
from Orange.data.aggregate import OrangeTableGroupBy
from Orange.util import wrap_callback
from Orange.widgets import gui
-from Orange.widgets.data.oweditdomain import disconnected
from Orange.widgets.settings import DomainContextHandler
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.utils.itemmodels import DomainModel
@@ -246,7 +244,7 @@ def headerData(self, i, orientation, role=Qt.DisplayRole) -> str:
return super().headerData(i, orientation, role)
-class AggregateListViewSearch(ListViewSearch):
+class AggregateListViewSearch(ListViewFilter):
"""ListViewSearch that disables unselecting all items in the list"""
def selectionCommand(
@@ -372,13 +370,16 @@ def __init__(self):
def __init_control_area(self) -> None:
"""Init all controls in the control area"""
- box = gui.vBox(self.controlArea, "Group by")
- self.gb_attrs_view = AggregateListViewSearch(
- selectionMode=QListView.ExtendedSelection
+ gui.listView(
+ self.controlArea,
+ self,
+ "gb_attrs",
+ box="Group by",
+ model=self.gb_attrs_model,
+ viewType=AggregateListViewSearch,
+ callback=self.__gb_changed,
+ selectionMode=ListViewFilter.ExtendedSelection,
)
- self.gb_attrs_view.setModel(self.gb_attrs_model)
- self.gb_attrs_view.selectionModel().selectionChanged.connect(self.__gb_changed)
- box.layout().addWidget(self.gb_attrs_view)
gui.auto_send(self.buttonsArea, self, "auto_commit")
@@ -434,14 +435,7 @@ def __rows_selected(self) -> None:
)
def __gb_changed(self) -> None:
- """
- Callback for Group-by attributes selection change; update attribute
- and call commit
- """
- rows = self.gb_attrs_view.selectionModel().selectedRows()
- values = self.gb_attrs_view.model()[:]
- self.gb_attrs = [values[row.row()] for row in sorted(rows)]
- # everything cached in result should be recomputed on gb change
+ """Callback for Group-by attributes selection change"""
self.result = Result()
self.commit.deferred()
@@ -471,7 +465,7 @@ def set_data(self, data: Table) -> None:
self.result = Result()
self.Outputs.data.send(None)
self.gb_attrs_model.set_domain(data.domain if data else None)
- self.gb_attrs = data.domain[:1] if data else []
+ self.gb_attrs = self.gb_attrs_model[:1] if self.gb_attrs_model else []
self.aggregations = (
{
attr: DEFAULT_AGGREGATIONS[type(attr)].copy()
@@ -526,14 +520,16 @@ def get_selected_attributes(self):
return [vars_[index.row()] for index in sel_rows]
def _set_gb_selection(self) -> None:
- """Set selection in groupby list according to self.gb_attrs"""
- sm = self.gb_attrs_view.selectionModel()
+ """
+ Update selected attributes. A saved context also matches when one of its
+ variables is hidden, so gb_attrs may include a hidden attribute. Remove
+ it; otherwise the widget groups by an attribute missing from the view.
+ """
values = self.gb_attrs_model[:]
- with disconnected(sm.selectionChanged, self.__gb_changed):
- for val in self.gb_attrs:
- index = values.index(val)
- model_index = self.gb_attrs_model.index(index, 0)
- sm.select(model_index, QItemSelectionModel.Select)
+ self.gb_attrs = [var_ for var_ in self.gb_attrs if var_ in values]
+ if not self.gb_attrs and self.gb_attrs_model:
+ # if gb_attrs is empty, select the first attribute in the model
+ self.gb_attrs = self.gb_attrs_model[:1]
@staticmethod
def __aggregation_compatible(agg, attr):
diff --git a/Orange/widgets/data/tests/test_owgroupby.py b/Orange/widgets/data/tests/test_owgroupby.py
index fdfdfed18ce..94cfbc6891b 100644
--- a/Orange/widgets/data/tests/test_owgroupby.py
+++ b/Orange/widgets/data/tests/test_owgroupby.py
@@ -44,7 +44,7 @@ def test_data(self):
self.assertEqual(self.widget.gb_attrs_model.rowCount(), 5)
output = self.get_output(self.widget.Outputs.data)
- self.assertEqual(len(output), 35)
+ self.assertEqual(3, len(output))
self.send_signal(self.widget.Inputs.data, None)
self.assertIsNone(self.get_output(self.widget.Outputs.data))
@@ -63,27 +63,30 @@ def _set_selection(view: QListView, indices: List[int]):
sm = view.selectionModel()
model = view.model()
for ind in indices:
- sm.select(model.index(ind), QItemSelectionModel.Select)
+ sm.select(model.index(ind, 0), QItemSelectionModel.Select)
def test_groupby_attr_selection(self):
+ gb_view = self.widget.controls.gb_attrs
self.send_signal(self.widget.Inputs.data, self.iris)
+ self._set_selection(gb_view, [1]) # sepal length
+ self.wait_until_finished()
output = self.get_output(self.widget.Outputs.data)
- self.assertEqual(len(output), 35)
+ self.assertEqual(35, len(output))
# select iris attribute with index 0
- self._set_selection(self.widget.gb_attrs_view, [0])
+ self._set_selection(gb_view, [0])
self.wait_until_finished()
output = self.get_output(self.widget.Outputs.data)
- self.assertEqual(len(output), 3)
+ self.assertEqual(3, len(output))
- # select iris attribute with index 0
- self._set_selection(self.widget.gb_attrs_view, [0, 1])
+ # select iris and sepal length attributes
+ self._set_selection(gb_view, [0, 1])
self.wait_until_finished()
output = self.get_output(self.widget.Outputs.data)
- self.assertEqual(len(output), 57)
+ self.assertEqual(57, len(output))
def assert_enabled_cbs(self, enabled_true):
enabled_actual = set(
@@ -399,6 +402,7 @@ def test_aggregations_change(self):
def test_aggregation(self):
"""Test aggregation results"""
self.send_signal(self.widget.Inputs.data, self.data)
+ self._set_selection(self.widget.controls.gb_attrs, [1]) # a var
output = self.get_output(self.widget.Outputs.data)
np.testing.assert_array_almost_equal(
@@ -422,7 +426,7 @@ def test_aggregation(self):
)
# select all aggregations for all features except a and b
- self._set_selection(self.widget.gb_attrs_view, [1, 2])
+ self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.select_table_rows(self.widget.agg_table_view, [2, 3, 4])
# select all aggregations
for cb in self.widget.agg_checkboxes.values():
@@ -525,7 +529,7 @@ def test_aggregation(self):
def test_metas_results(self):
"""Test if variable that is in meta in input table remains in metas"""
self.send_signal(self.widget.Inputs.data, self.data)
- self._set_selection(self.widget.gb_attrs_view, [0, 1])
+ self._set_selection(self.widget.controls.gb_attrs, [0, 1])
output = self.get_output(self.widget.Outputs.data)
self.assertIn(self.data.domain["svar"], output.domain.metas)
@@ -544,7 +548,7 @@ def test_context(self):
["Mean, Median", "Mean", "Mean, Median", "Mode", "Concatenate"]
)
- self._set_selection(self.widget.gb_attrs_view, [1, 2])
+ self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.assertListEqual([d["a"], d["b"]], self.widget.gb_attrs)
self.assertDictEqual(
{
@@ -564,7 +568,7 @@ def test_context(self):
self.assert_aggregations_equal(
["Mean, Median", "Mean", "Mean, Median", "Mode", "Concatenate"]
)
- self._set_selection(self.widget.gb_attrs_view, [1, 2])
+ self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.assertListEqual([d["a"], d["b"]], self.widget.gb_attrs)
self.assertDictEqual(
{
@@ -624,14 +628,14 @@ def test_time_variable(self):
# time variable as a group by variable
self.send_signal(self.widget.Inputs.data, data)
- self._set_selection(self.widget.gb_attrs_view, [3])
+ self._set_selection(self.widget.controls.gb_attrs, [3])
output = self.get_output(self.widget.Outputs.data)
self.assertEqual(3, len(output))
# time variable as a grouped variable
attributes = [data.domain["c2"], data.domain["d2"]]
self.send_signal(self.widget.Inputs.data, data[:, attributes])
- self._set_selection(self.widget.gb_attrs_view, [1]) # d2
+ self._set_selection(self.widget.controls.gb_attrs, [1]) # d2
# check all aggregations
self.assert_aggregations_equal(["Mean", "Mode"])
self.select_table_rows(self.widget.agg_table_view, [0]) # c2
@@ -855,7 +859,7 @@ def test_only_nan_in_group(self):
self.send_signal(self.widget.Inputs.data, data)
# select feature A as group-by
- self._set_selection(self.widget.gb_attrs_view, [0])
+ self._set_selection(self.widget.controls.gb_attrs, [0])
# select all aggregations for feature B
self.select_table_rows(self.widget.agg_table_view, [1])
for cb in self.widget.agg_checkboxes.values():
@@ -908,6 +912,34 @@ def test_only_nan_in_group(self):
check_categorical=False,
)
+ def test_hidden_attributes(self):
+ domain = self.iris.domain
+ data = self.iris.transform(domain.copy())
+
+ data.domain.attributes[0].attributes["hidden"] = True
+ self.send_signal(self.widget.Inputs.data, data)
+ self.assertListEqual([data.domain["iris"]], self.widget.gb_attrs)
+
+ data = self.iris.transform(domain.copy())
+ data.domain.class_vars[0].attributes["hidden"] = True
+ self.send_signal(self.widget.Inputs.data, data)
+ # iris is hidden now so sepal length is selected
+ self.assertListEqual([data.domain["sepal length"]], self.widget.gb_attrs)
+
+ d = domain.copy()
+ data = self.iris.transform(Domain(d.attributes[:3], metas=d.attributes[3:]))
+ data.domain.metas[0].attributes["hidden"] = True
+ self.send_signal(self.widget.Inputs.data, data)
+ # sepal length still selected because of context
+ self.assertListEqual([data.domain["sepal length"]], self.widget.gb_attrs)
+
+ # test case when one of two selected attributes is hidden
+ self._set_selection(self.widget.controls.gb_attrs, [0, 1]) # sep l, sep w
+ data.domain.attributes[0].attributes["hidden"] = True
+ self.send_signal(self.widget.Inputs.data, data)
+ # sepal length is hidden - only sepal width remains selected
+ self.assertListEqual([data.domain["sepal width"]], self.widget.gb_attrs)
+
if __name__ == "__main__":
unittest.main()