From 4dccf71fec54b573dbd2230eea4129c5315e53e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Primo=C5=BE=20Godec?=
Date: Thu, 9 Apr 2020 16:41:42 +0200
Subject: [PATCH] Pythagoranstree/forest: Remove unnecessary setting combo box
indexes
---
Orange/widgets/visualize/owpythagorastree.py | 15 +++++++++------
.../widgets/visualize/owpythagoreanforest.py | 19 +++++++++++--------
.../visualize/tests/test_owpythagorastree.py | 15 +++++++++++++++
.../tests/test_owpythagoreanforest.py | 17 +++++++++++++++++
4 files changed, 52 insertions(+), 14 deletions(-)
diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py
index 7da38a1f0aa..ecfb08a43a5 100644
--- a/Orange/widgets/visualize/owpythagorastree.py
+++ b/Orange/widgets/visualize/owpythagorastree.py
@@ -57,7 +57,7 @@ class Outputs:
graph_name = 'scene'
# Settings
- settingsHandler = settings.DomainContextHandler()
+ settingsHandler = settings.ClassValuesContextHandler()
depth_limit = settings.ContextSetting(10)
target_class_index = settings.ContextSetting(0)
@@ -155,6 +155,8 @@ def set_tree(self, model=None):
if model is not None:
self.data = model.instances
+
+ self._update_target_class_combo()
self.tree_adapter = self._get_tree_adapter(self.model)
self.ptree.clear()
@@ -169,11 +171,12 @@ def set_tree(self, model=None):
self._update_legend_colors()
self._update_legend_visibility()
self._update_info_box()
- self._update_target_class_combo()
self._update_main_area()
- self.openContext(self.model)
+ self.openContext(
+ model.domain.class_var if model.domain is not None else None
+ )
self.update_depth()
@@ -277,8 +280,7 @@ def _clear_depth_slider(self):
def _clear_target_class_combo(self):
self.target_class_combo.clear()
- self.target_class_index = 0
- self.target_class_combo.setCurrentIndex(self.target_class_index)
+ self.target_class_index = -1
def _set_max_depth(self):
"""Set the depth to the max depth and update appropriate actors."""
@@ -339,7 +341,8 @@ def _update_target_class_combo(self):
values = list(ContinuousTreeNode.COLOR_METHODS.keys())
label.setText(label_text)
self.target_class_combo.addItems(values)
- self.target_class_combo.setCurrentIndex(self.target_class_index)
+ # set it to 0, context will change if required
+ self.target_class_index = 0
def _update_legend_colors(self):
if self.legend is not None:
diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py
index 02dfc2074b8..f3df48f17eb 100644
--- a/Orange/widgets/visualize/owpythagoreanforest.py
+++ b/Orange/widgets/visualize/owpythagoreanforest.py
@@ -174,9 +174,9 @@ class Outputs:
graph_name = 'scene'
# Settings
- settingsHandler = settings.DomainContextHandler()
+ settingsHandler = settings.ClassValuesContextHandler()
- depth_limit = settings.ContextSetting(10)
+ depth_limit = settings.Setting(10)
target_class_index = settings.ContextSetting(0)
size_calc_idx = settings.Setting(0)
zoom = settings.Setting(200)
@@ -274,15 +274,18 @@ def set_rf(self, model=None):
self.rf_model = model
if model is not None:
+ self.instances = model.instances
+ self._update_target_class_combo()
+
self.forest = self._get_forest_adapter(self.rf_model)
self.forest_model[:] = self.forest.trees
- self.instances = model.instances
self._update_info_box()
- self._update_target_class_combo()
self._update_depth_slider()
- self.openContext(model)
+ self.openContext(
+ model.domain.class_var if model.domain is not None else None
+ )
# Restore item selection
if self.selected_index is not None:
index = self.list_view.model().index(self.selected_index)
@@ -324,15 +327,15 @@ def _update_target_class_combo(self):
values = list(ContinuousTreeNode.COLOR_METHODS.keys())
label.setText(label_text)
self.ui_target_class_combo.addItems(values)
- self.ui_target_class_combo.setCurrentIndex(self.target_class_index)
+ # set it to 0, context will change if required
+ self.target_class_index = 0
def _clear_info_box(self):
self.ui_info.setText('No forest on input.')
def _clear_target_class_combo(self):
self.ui_target_class_combo.clear()
- self.target_class_index = 0
- self.ui_target_class_combo.setCurrentIndex(self.target_class_index)
+ self.target_class_index = -1
def _clear_depth_slider(self):
self.ui_depth_slider.parent().setEnabled(False)
diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py
index eb8fb4db3fe..112d57e8f8f 100644
--- a/Orange/widgets/visualize/tests/test_owpythagorastree.py
+++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py
@@ -396,6 +396,21 @@ def test_changing_data_restores_depth_from_previous_settings(self):
self.send_signal(self.widget.Inputs.tree, forest.trees[1])
self.assertEqual(self.widget.ptree._depth_limit, 1)
+ def test_context(self):
+ iris_tree = TreeLearner()(Table("iris"))
+ self.send_signal(self.widget.Inputs.tree, self.titanic)
+ self.widget.target_class_index = 1
+
+ self.send_signal(self.widget.Inputs.tree, iris_tree)
+ self.assertEqual(0, self.widget.target_class_index)
+
+ self.widget.target_class_index = 2
+ self.send_signal(self.widget.Inputs.tree, self.titanic)
+ self.assertEqual(1, self.widget.target_class_index)
+
+ self.send_signal(self.widget.Inputs.tree, iris_tree)
+ self.assertEqual(2, self.widget.target_class_index)
+
if __name__ == "__main__":
unittest.main()
diff --git a/Orange/widgets/visualize/tests/test_owpythagoreanforest.py b/Orange/widgets/visualize/tests/test_owpythagoreanforest.py
index c1c62f2e344..f2c88b36212 100644
--- a/Orange/widgets/visualize/tests/test_owpythagoreanforest.py
+++ b/Orange/widgets/visualize/tests/test_owpythagoreanforest.py
@@ -221,3 +221,20 @@ def test_storing_selection(self):
output = self.get_output(self.widget.Outputs.tree)
self.assertIsNotNone(output)
self.assertIs(output.skl_model, self.titanic.trees[idx].skl_model)
+
+ def test_context(self):
+ iris = Table("iris")
+ iris_tree = RandomForestLearner()(iris)
+ iris_tree.instances = iris
+ self.send_signal(self.widget.Inputs.random_forest, self.titanic)
+ self.widget.target_class_index = 1
+
+ self.send_signal(self.widget.Inputs.random_forest, iris_tree)
+ self.assertEqual(0, self.widget.target_class_index)
+
+ self.widget.target_class_index = 2
+ self.send_signal(self.widget.Inputs.random_forest, self.titanic)
+ self.assertEqual(1, self.widget.target_class_index)
+
+ self.send_signal(self.widget.Inputs.random_forest, iris_tree)
+ self.assertEqual(2, self.widget.target_class_index)