Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Group By - Fix error with hidden attributes #6473

Merged
merged 1 commit into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions Orange/widgets/data/owgroupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
QCheckBox,
QGridLayout,
QHeaderView,
QListView,
QTableView,
)
from orangewidget.settings import ContextSetting, Setting
from orangewidget.utils.listview import ListViewSearch
from orangewidget.utils.listview import ListViewFilter
from orangewidget.utils.signals import Input, Output
from orangewidget.utils import enum_as_int
from orangewidget.widget import Msg
Expand All @@ -40,7 +39,6 @@
from Orange.data.aggregate import OrangeTableGroupBy
from Orange.util import wrap_callback
from Orange.widgets import gui
from Orange.widgets.data.oweditdomain import disconnected
from Orange.widgets.settings import DomainContextHandler
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.utils.itemmodels import DomainModel
Expand Down Expand Up @@ -246,7 +244,7 @@ def headerData(self, i, orientation, role=Qt.DisplayRole) -> str:
return super().headerData(i, orientation, role)


class AggregateListViewSearch(ListViewSearch):
class AggregateListViewSearch(ListViewFilter):
"""ListViewSearch that disables unselecting all items in the list"""

def selectionCommand(
Expand Down Expand Up @@ -372,13 +370,16 @@ def __init__(self):

def __init_control_area(self) -> None:
"""Init all controls in the control area"""
box = gui.vBox(self.controlArea, "Group by")
self.gb_attrs_view = AggregateListViewSearch(
selectionMode=QListView.ExtendedSelection
gui.listView(
self.controlArea,
self,
"gb_attrs",
box="Group by",
model=self.gb_attrs_model,
viewType=AggregateListViewSearch,
callback=self.__gb_changed,
selectionMode=ListViewFilter.ExtendedSelection,
)
self.gb_attrs_view.setModel(self.gb_attrs_model)
self.gb_attrs_view.selectionModel().selectionChanged.connect(self.__gb_changed)
box.layout().addWidget(self.gb_attrs_view)

gui.auto_send(self.buttonsArea, self, "auto_commit")

Expand Down Expand Up @@ -434,14 +435,7 @@ def __rows_selected(self) -> None:
)

def __gb_changed(self) -> None:
"""
Callback for Group-by attributes selection change; update attribute
and call commit
"""
rows = self.gb_attrs_view.selectionModel().selectedRows()
values = self.gb_attrs_view.model()[:]
self.gb_attrs = [values[row.row()] for row in sorted(rows)]
# everything cached in result should be recomputed on gb change
"""Callback for Group-by attributes selection change"""
self.result = Result()
self.commit.deferred()

Expand Down Expand Up @@ -471,7 +465,7 @@ def set_data(self, data: Table) -> None:
self.result = Result()
self.Outputs.data.send(None)
self.gb_attrs_model.set_domain(data.domain if data else None)
self.gb_attrs = data.domain[:1] if data else []
self.gb_attrs = self.gb_attrs_model[:1] if self.gb_attrs_model else []
self.aggregations = (
{
attr: DEFAULT_AGGREGATIONS[type(attr)].copy()
Expand Down Expand Up @@ -526,14 +520,16 @@ def get_selected_attributes(self):
return [vars_[index.row()] for index in sel_rows]

def _set_gb_selection(self) -> None:
"""Set selection in groupby list according to self.gb_attrs"""
sm = self.gb_attrs_view.selectionModel()
"""
Update selected attributes. When context includes variable hidden in
data, it will match and gb_attrs may include hidden attribute. Remove it
since otherwise widget groups by attribute that is not present in view.
"""
values = self.gb_attrs_model[:]
with disconnected(sm.selectionChanged, self.__gb_changed):
for val in self.gb_attrs:
index = values.index(val)
model_index = self.gb_attrs_model.index(index, 0)
sm.select(model_index, QItemSelectionModel.Select)
self.gb_attrs = [var_ for var_ in self.gb_attrs if var_ in values]
if not self.gb_attrs and self.gb_attrs_model:
# if gb_attrs empty select first
self.gb_attrs = self.gb_attrs_model[:1]

@staticmethod
def __aggregation_compatible(agg, attr):
Expand Down
62 changes: 47 additions & 15 deletions Orange/widgets/data/tests/test_owgroupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_data(self):
self.assertEqual(self.widget.gb_attrs_model.rowCount(), 5)

output = self.get_output(self.widget.Outputs.data)
self.assertEqual(len(output), 35)
self.assertEqual(3, len(output))

self.send_signal(self.widget.Inputs.data, None)
self.assertIsNone(self.get_output(self.widget.Outputs.data))
Expand All @@ -63,27 +63,30 @@ def _set_selection(view: QListView, indices: List[int]):
sm = view.selectionModel()
model = view.model()
for ind in indices:
sm.select(model.index(ind), QItemSelectionModel.Select)
sm.select(model.index(ind, 0), QItemSelectionModel.Select)

def test_groupby_attr_selection(self):
gb_view = self.widget.controls.gb_attrs
self.send_signal(self.widget.Inputs.data, self.iris)

self._set_selection(gb_view, [1]) # sepal length
self.wait_until_finished()
output = self.get_output(self.widget.Outputs.data)
self.assertEqual(len(output), 35)
self.assertEqual(35, len(output))

# select iris attribute with index 0
self._set_selection(self.widget.gb_attrs_view, [0])
self._set_selection(gb_view, [0])
self.wait_until_finished()

output = self.get_output(self.widget.Outputs.data)
self.assertEqual(len(output), 3)
self.assertEqual(3, len(output))

# select iris attribute with index 0
self._set_selection(self.widget.gb_attrs_view, [0, 1])
# select iris and sepal length attribute
self._set_selection(gb_view, [0, 1])
self.wait_until_finished()

output = self.get_output(self.widget.Outputs.data)
self.assertEqual(len(output), 57)
self.assertEqual(57, len(output))

def assert_enabled_cbs(self, enabled_true):
enabled_actual = set(
Expand Down Expand Up @@ -399,6 +402,7 @@ def test_aggregations_change(self):
def test_aggregation(self):
"""Test aggregation results"""
self.send_signal(self.widget.Inputs.data, self.data)
self._set_selection(self.widget.controls.gb_attrs, [1]) # a var
output = self.get_output(self.widget.Outputs.data)

np.testing.assert_array_almost_equal(
Expand All @@ -422,7 +426,7 @@ def test_aggregation(self):
)

# select all aggregations for all features except a and b
self._set_selection(self.widget.gb_attrs_view, [1, 2])
self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.select_table_rows(self.widget.agg_table_view, [2, 3, 4])
# select all aggregations
for cb in self.widget.agg_checkboxes.values():
Expand Down Expand Up @@ -525,7 +529,7 @@ def test_aggregation(self):
def test_metas_results(self):
"""Test if variable that is in meta in input table remains in metas"""
self.send_signal(self.widget.Inputs.data, self.data)
self._set_selection(self.widget.gb_attrs_view, [0, 1])
self._set_selection(self.widget.controls.gb_attrs, [0, 1])

output = self.get_output(self.widget.Outputs.data)
self.assertIn(self.data.domain["svar"], output.domain.metas)
Expand All @@ -544,7 +548,7 @@ def test_context(self):
["Mean, Median", "Mean", "Mean, Median", "Mode", "Concatenate"]
)

self._set_selection(self.widget.gb_attrs_view, [1, 2])
self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.assertListEqual([d["a"], d["b"]], self.widget.gb_attrs)
self.assertDictEqual(
{
Expand All @@ -564,7 +568,7 @@ def test_context(self):
self.assert_aggregations_equal(
["Mean, Median", "Mean", "Mean, Median", "Mode", "Concatenate"]
)
self._set_selection(self.widget.gb_attrs_view, [1, 2])
self._set_selection(self.widget.controls.gb_attrs, [1, 2])
self.assertListEqual([d["a"], d["b"]], self.widget.gb_attrs)
self.assertDictEqual(
{
Expand Down Expand Up @@ -624,14 +628,14 @@ def test_time_variable(self):

# time variable as a group by variable
self.send_signal(self.widget.Inputs.data, data)
self._set_selection(self.widget.gb_attrs_view, [3])
self._set_selection(self.widget.controls.gb_attrs, [3])
output = self.get_output(self.widget.Outputs.data)
self.assertEqual(3, len(output))

# time variable as a grouped variable
attributes = [data.domain["c2"], data.domain["d2"]]
self.send_signal(self.widget.Inputs.data, data[:, attributes])
self._set_selection(self.widget.gb_attrs_view, [1]) # d2
self._set_selection(self.widget.controls.gb_attrs, [1]) # d2
# check all aggregations
self.assert_aggregations_equal(["Mean", "Mode"])
self.select_table_rows(self.widget.agg_table_view, [0]) # c2
Expand Down Expand Up @@ -855,7 +859,7 @@ def test_only_nan_in_group(self):
self.send_signal(self.widget.Inputs.data, data)

# select feature A as group-by
self._set_selection(self.widget.gb_attrs_view, [0])
self._set_selection(self.widget.controls.gb_attrs, [0])
# select all aggregations for feature B
self.select_table_rows(self.widget.agg_table_view, [1])
for cb in self.widget.agg_checkboxes.values():
Expand Down Expand Up @@ -908,6 +912,34 @@ def test_only_nan_in_group(self):
check_categorical=False,
)

def test_hidden_attributes(self):
domain = self.iris.domain
data = self.iris.transform(domain.copy())

data.domain.attributes[0].attributes["hidden"] = True
self.send_signal(self.widget.Inputs.data, data)
self.assertListEqual([data.domain["iris"]], self.widget.gb_attrs)

data = self.iris.transform(domain.copy())
data.domain.class_vars[0].attributes["hidden"] = True
self.send_signal(self.widget.Inputs.data, data)
# iris is hidden now so sepal length is selected
self.assertListEqual([data.domain["sepal length"]], self.widget.gb_attrs)

d = domain.copy()
data = self.iris.transform(Domain(d.attributes[:3], metas=d.attributes[3:]))
data.domain.metas[0].attributes["hidden"] = True
self.send_signal(self.widget.Inputs.data, data)
# sepal length still selected because of context
self.assertListEqual([data.domain["sepal length"]], self.widget.gb_attrs)

# test case when one of two selected attributes is hidden
self._set_selection(self.widget.controls.gb_attrs, [0, 1]) # sep l, sep w
data.domain.attributes[0].attributes["hidden"] = True
self.send_signal(self.widget.Inputs.data, data)
# sepal length is hidden - only sepal width remain selected
self.assertListEqual([data.domain["sepal width"]], self.widget.gb_attrs)


if __name__ == "__main__":
unittest.main()