Skip to content

Commit

Permalink
Group By - Fix error with hidden attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Jun 9, 2023
1 parent a2c8956 commit d9f5dd1
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 41 deletions.
48 changes: 22 additions & 26 deletions Orange/widgets/data/owgroupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
QCheckBox,
QGridLayout,
QHeaderView,
QListView,
QTableView,
)
from orangewidget.settings import ContextSetting, Setting
from orangewidget.utils.listview import ListViewSearch
from orangewidget.utils.listview import ListViewFilter
from orangewidget.utils.signals import Input, Output
from orangewidget.utils import enum_as_int
from orangewidget.widget import Msg
Expand All @@ -40,7 +39,6 @@
from Orange.data.aggregate import OrangeTableGroupBy
from Orange.util import wrap_callback
from Orange.widgets import gui
from Orange.widgets.data.oweditdomain import disconnected
from Orange.widgets.settings import DomainContextHandler
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.utils.itemmodels import DomainModel
Expand Down Expand Up @@ -246,7 +244,7 @@ def headerData(self, i, orientation, role=Qt.DisplayRole) -> str:
return super().headerData(i, orientation, role)


class AggregateListViewSearch(ListViewSearch):
class AggregateListViewSearch(ListViewFilter):
"""ListViewSearch that disables unselecting all items in the list"""

def selectionCommand(
Expand Down Expand Up @@ -372,13 +370,16 @@ def __init__(self):

def __init_control_area(self) -> None:
"""Init all controls in the control area"""
box = gui.vBox(self.controlArea, "Group by")
self.gb_attrs_view = AggregateListViewSearch(
selectionMode=QListView.ExtendedSelection
gui.listView(
self.controlArea,
self,
"gb_attrs",
box="Group by",
model=self.gb_attrs_model,
viewType=AggregateListViewSearch,
callback=self.__gb_changed,
selectionMode=ListViewFilter.ExtendedSelection,
)
self.gb_attrs_view.setModel(self.gb_attrs_model)
self.gb_attrs_view.selectionModel().selectionChanged.connect(self.__gb_changed)
box.layout().addWidget(self.gb_attrs_view)

gui.auto_send(self.buttonsArea, self, "auto_commit")

Expand Down Expand Up @@ -434,14 +435,7 @@ def __rows_selected(self) -> None:
)

def __gb_changed(self) -> None:
"""
Callback for Group-by attributes selection change; update attribute
and call commit
"""
rows = self.gb_attrs_view.selectionModel().selectedRows()
values = self.gb_attrs_view.model()[:]
self.gb_attrs = [values[row.row()] for row in sorted(rows)]
# everything cached in result should be recomputed on gb change
"""Callback for Group-by attributes selection change"""
self.result = Result()
self.commit.deferred()

Expand Down Expand Up @@ -471,7 +465,7 @@ def set_data(self, data: Table) -> None:
self.result = Result()
self.Outputs.data.send(None)
self.gb_attrs_model.set_domain(data.domain if data else None)
self.gb_attrs = data.domain[:1] if data else []
self.gb_attrs = self.gb_attrs_model[:1] if self.gb_attrs_model else []
self.aggregations = (
{
attr: DEFAULT_AGGREGATIONS[type(attr)].copy()
Expand Down Expand Up @@ -526,14 +520,16 @@ def get_selected_attributes(self):
return [vars_[index.row()] for index in sel_rows]

def _set_gb_selection(self) -> None:
"""Set selection in groupby list according to self.gb_attrs"""
sm = self.gb_attrs_view.selectionModel()
"""
Update selected attributes. When context includes variable hidden in
data, it will match and gb_attrs may include hidden attribute. Remove it
since otherwise widget groups by attribute that is not present in view.
"""
values = self.gb_attrs_model[:]
with disconnected(sm.selectionChanged, self.__gb_changed):
for val in self.gb_attrs:
index = values.index(val)
model_index = self.gb_attrs_model.index(index, 0)
sm.select(model_index, QItemSelectionModel.Select)
self.gb_attrs = [var_ for var_ in self.gb_attrs if var_ in values]
if not self.gb_attrs and self.gb_attrs_model:
# if gb_attrs empty select first
self.gb_attrs = self.gb_attrs_model[:1]

@staticmethod
def __aggregation_compatible(agg, attr):
Expand Down
62 changes: 47 additions & 15 deletions Orange/widgets/data/tests/test_owgroupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_data(self):
self.assertEqual(self.widget.gb_attrs_model.rowCount(), 5)

output = self.get_output(self.widget.Outputs.data)
self.assertEqual(len(output), 35)
self.assertEqual(3, len(output))

self.send_signal(self.widget.Inputs.data, None)
self.assertIsNone(self.get_output(self.widget.Outputs.data))
Expand All @@ -63,27 +63,30 @@ def _set_selection(view: QListView, indices: List[int]):
sm = view.selectionModel()
model = view.model()
for ind in indices:
sm.select(model.index(ind), QItemSelectionModel.Select)
sm.select(model.index(ind, 0), QItemSelectionModel.Select)

def test_groupby_attr_selection(self):
gb_view = self.widget.controls.gb_attrs
self.send_signal(self.widget.Inputs.data, self.iris)

self._set_selection(gb_view, [1]) # sepal length
self.wait_until_finished()
output = self.get_output(self.widget.Outputs.data)
self.assertEqual(len(output), 35)
self.assertEqual(35, len(output))

# select iris attribute with index 0
self._set_selection(self.widget.gb_attrs_view, [0])
self._set_selection(gb_view, [0])
self.wait_until_finished()

output = self.get_output(self.widget.Outputs.data)
self.assertEqual(len(output), 3)
self.assertEqual(3, len(output))

# select iris attribute with index 0
self._set_selection(self.widget.gb_attrs_view, [0, 1])
# select iris and sepal length attribute
self._set_selection(gb_view, [0, 1])
self.wait_until_finished()

output = self.get_output(self.widget.Outputs.data)
self.assertEqual(len(output), 57)
self.assertEqual(57, len(output))

def assert_enabled_cbs(self, enabled_true):
enabled_actual = set(
Expand Down Expand Up @@ -399,6 +402,7 @@ def test_aggregations_change(self):
def test_aggregation(self):
"""Test aggregation results"""
self.send_signal(self.widget.Inputs.data, self.data)
self._set_selection(self.widget.controls.gb_attrs, [1]) # a var
output = self.get_output(self.widget.Outputs.data)

np.testing.assert_array_almost_equal(
Expand All @@ -422,7 +426,7 @@ def test_aggregation(self):
)

# select all aggregations for all features except a and b
self._set_selection(self.widget.gb_attrs_view, [1, 2])
self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.select_table_rows(self.widget.agg_table_view, [2, 3, 4])
# select all aggregations
for cb in self.widget.agg_checkboxes.values():
Expand Down Expand Up @@ -525,7 +529,7 @@ def test_aggregation(self):
def test_metas_results(self):
"""Test if variable that is in meta in input table remains in metas"""
self.send_signal(self.widget.Inputs.data, self.data)
self._set_selection(self.widget.gb_attrs_view, [0, 1])
self._set_selection(self.widget.controls.gb_attrs, [0, 1])

output = self.get_output(self.widget.Outputs.data)
self.assertIn(self.data.domain["svar"], output.domain.metas)
Expand All @@ -544,7 +548,7 @@ def test_context(self):
["Mean, Median", "Mean", "Mean, Median", "Mode", "Concatenate"]
)

self._set_selection(self.widget.gb_attrs_view, [1, 2])
self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.assertListEqual([d["a"], d["b"]], self.widget.gb_attrs)
self.assertDictEqual(
{
Expand All @@ -564,7 +568,7 @@ def test_context(self):
self.assert_aggregations_equal(
["Mean, Median", "Mean", "Mean, Median", "Mode", "Concatenate"]
)
self._set_selection(self.widget.gb_attrs_view, [1, 2])
self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.assertListEqual([d["a"], d["b"]], self.widget.gb_attrs)
self.assertDictEqual(
{
Expand Down Expand Up @@ -624,14 +628,14 @@ def test_time_variable(self):

# time variable as a group by variable
self.send_signal(self.widget.Inputs.data, data)
self._set_selection(self.widget.gb_attrs_view, [3])
self._set_selection(self.widget.controls.gb_attrs, [3])
output = self.get_output(self.widget.Outputs.data)
self.assertEqual(3, len(output))

# time variable as a grouped variable
attributes = [data.domain["c2"], data.domain["d2"]]
self.send_signal(self.widget.Inputs.data, data[:, attributes])
self._set_selection(self.widget.gb_attrs_view, [1]) # d2
self._set_selection(self.widget.controls.gb_attrs, [1]) # d2
# check all aggregations
self.assert_aggregations_equal(["Mean", "Mode"])
self.select_table_rows(self.widget.agg_table_view, [0]) # c2
Expand Down Expand Up @@ -855,7 +859,7 @@ def test_only_nan_in_group(self):
self.send_signal(self.widget.Inputs.data, data)

# select feature A as group-by
self._set_selection(self.widget.gb_attrs_view, [0])
self._set_selection(self.widget.controls.gb_attrs, [0])
# select all aggregations for feature B
self.select_table_rows(self.widget.agg_table_view, [1])
for cb in self.widget.agg_checkboxes.values():
Expand Down Expand Up @@ -908,6 +912,34 @@ def test_only_nan_in_group(self):
check_categorical=False,
)

def test_hidden_attributes(self):
domain = self.iris.domain
data = self.iris.transform(domain.copy())

data.domain.attributes[0].attributes["hidden"] = True
self.send_signal(self.widget.Inputs.data, data)
self.assertListEqual([data.domain["iris"]], self.widget.gb_attrs)

data = self.iris.transform(domain.copy())
data.domain.class_vars[0].attributes["hidden"] = True
self.send_signal(self.widget.Inputs.data, data)
# iris is hidden now so sepal length is selected
self.assertListEqual([data.domain["sepal length"]], self.widget.gb_attrs)

d = domain.copy()
data = self.iris.transform(Domain(d.attributes[:3], metas=d.attributes[3:]))
data.domain.metas[0].attributes["hidden"] = True
self.send_signal(self.widget.Inputs.data, data)
# sepal length still selected because of context
self.assertListEqual([data.domain["sepal length"]], self.widget.gb_attrs)

# test case when one of two selected attributes is hidden
self._set_selection(self.widget.controls.gb_attrs, [0, 1]) # sep l, sep w
data.domain.attributes[0].attributes["hidden"] = True
self.send_signal(self.widget.Inputs.data, data)
# sepal length is hidden - only sepal width remain selected
self.assertListEqual([data.domain["sepal width"]], self.widget.gb_attrs)


if __name__ == "__main__":
unittest.main()

0 comments on commit d9f5dd1

Please sign in to comment.