From 4717a0635b98a5d23356cfaef4d2347ddc1e18ca Mon Sep 17 00:00:00 2001 From: Andreja Date: Thu, 27 Sep 2018 10:56:09 +0200 Subject: [PATCH 1/3] scaling disabled --- Orange/canvas/application/application.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Orange/canvas/application/application.py b/Orange/canvas/application/application.py index f826ccf1c26..06fdaea770a 100644 --- a/Orange/canvas/application/application.py +++ b/Orange/canvas/application/application.py @@ -14,7 +14,8 @@ class CanvasApplication(QApplication): def __init__(self, argv): if hasattr(Qt, "AA_EnableHighDpiScaling"): # Turn on HighDPI support when available - QApplication.setAttribute(Qt.AA_EnableHighDpiScaling) + #QApplication.setAttribute(Qt.AA_EnableHighDpiScaling) + pass QApplication.__init__(self, argv) self.setAttribute(Qt.AA_DontShowIconsInMenus, True) From 4efecba4d04e1900c3d3dea5c0f310730d3347f9 Mon Sep 17 00:00:00 2001 From: Andreja Date: Fri, 26 Oct 2018 14:22:32 +0200 Subject: [PATCH 2/3] fixed losing data attributes when merging --- Orange/canvas/application/application.py | 4 ++-- Orange/data/table.py | 6 ++++-- Orange/widgets/data/owmergedata.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Orange/canvas/application/application.py b/Orange/canvas/application/application.py index 06fdaea770a..a893f73714b 100644 --- a/Orange/canvas/application/application.py +++ b/Orange/canvas/application/application.py @@ -14,8 +14,8 @@ class CanvasApplication(QApplication): def __init__(self, argv): if hasattr(Qt, "AA_EnableHighDpiScaling"): # Turn on HighDPI support when available - #QApplication.setAttribute(Qt.AA_EnableHighDpiScaling) - pass + QApplication.setAttribute(Qt.AA_EnableHighDpiScaling) + QApplication.__init__(self, argv) self.setAttribute(Qt.AA_DontShowIconsInMenus, True) diff --git a/Orange/data/table.py b/Orange/data/table.py index c8492fbd456..7bd7c5a294f 100644 --- a/Orange/data/table.py +++ b/Orange/data/table.py @@ -465,7 +465,7 @@ 
def from_table_rows(cls, source, row_indices): return self @classmethod - def from_numpy(cls, domain, X, Y=None, metas=None, W=None): + def from_numpy(cls, domain, X, Y=None, metas=None, W=None, data_attributes=None): """ Construct a table from numpy arrays with the given domain. The number of variables in the domain must match the number of columns in the @@ -482,6 +482,8 @@ def from_numpy(cls, domain, X, Y=None, metas=None, W=None): :type metas: np.array :param W: array with weights :type W: np.array + :param data_attributes: dictionary of data attributes + :type data_attributes: OrderedDict :return: """ X, Y, W = _check_arrays(X, Y, W, dtype='float64') @@ -532,7 +534,7 @@ def from_numpy(cls, domain, X, Y=None, metas=None, W=None): self.W = W self.n_rows = self.X.shape[0] cls._init_ids(self) - self.attributes = {} + self.attributes = data_attributes return self @classmethod diff --git a/Orange/widgets/data/owmergedata.py b/Orange/widgets/data/owmergedata.py index eb276976a61..ab3a35e02d0 100644 --- a/Orange/widgets/data/owmergedata.py +++ b/Orange/widgets/data/owmergedata.py @@ -369,7 +369,7 @@ def _join_table_by_indices(self, reduced_extra, indices): string_cols = [i for i, var in enumerate(domain.metas) if var.is_string] metas = self._join_array_by_indices( self.data.metas, reduced_extra.metas, indices, string_cols) - return Orange.data.Table.from_numpy(domain, X, Y, metas) + return Orange.data.Table.from_numpy(domain, X, Y, metas, data_attributes = getattr(self.data, "attributes")) @staticmethod def _join_array_by_indices(left, right, indices, string_cols=None): From c32369a4ffbeebd87d88d116ed581135d7e0e3d5 Mon Sep 17 00:00:00 2001 From: Andreja Date: Fri, 23 Nov 2018 10:07:26 +0100 Subject: [PATCH 3/3] fixed losing data attributes when merging --- Orange/canvas/application/application.py | 1 + Orange/data/table.py | 6 +- Orange/projection/radviz.py | 8 +- Orange/tests/test_pca.py | 8 +- Orange/widgets/data/owmergedata.py | 2 +- 
Orange/widgets/report/tests/test_report.py | 18 - Orange/widgets/tests/base.py | 129 +- Orange/widgets/unsupervised/owmds.py | 197 +-- .../widgets/unsupervised/tests/test_owmds.py | 31 +- Orange/widgets/utils/plot/owplotgui.py | 54 +- Orange/widgets/visualize/owfreeviz.py | 520 +++----- .../widgets/visualize/owlinearprojection.py | 958 ++++++-------- Orange/widgets/visualize/owradviz.py | 589 +++------ Orange/widgets/visualize/owscatterplot.py | 307 ++--- .../widgets/visualize/owscatterplotgraph.py | 1142 ++++++----------- .../widgets/visualize/tests/test_owfreeviz.py | 40 +- .../tests/test_owlinearprojection.py | 171 +-- .../tests/test_owprojectionwidget.py | 146 +++ .../widgets/visualize/tests/test_owradviz.py | 63 +- .../visualize/tests/test_owscatterplot.py | 61 +- .../visualize/tests/test_owscatterplotbase.py | 961 ++++++++++++++ Orange/widgets/visualize/utils/__init__.py | 6 +- Orange/widgets/visualize/utils/component.py | 124 +- Orange/widgets/visualize/utils/plotutils.py | 62 +- Orange/widgets/visualize/utils/widget.py | 671 ++++++++++ 25 files changed, 3443 insertions(+), 2832 deletions(-) create mode 100644 Orange/widgets/visualize/tests/test_owprojectionwidget.py create mode 100644 Orange/widgets/visualize/tests/test_owscatterplotbase.py create mode 100644 Orange/widgets/visualize/utils/widget.py diff --git a/Orange/canvas/application/application.py b/Orange/canvas/application/application.py index f826ccf1c26..a893f73714b 100644 --- a/Orange/canvas/application/application.py +++ b/Orange/canvas/application/application.py @@ -15,6 +15,7 @@ def __init__(self, argv): if hasattr(Qt, "AA_EnableHighDpiScaling"): # Turn on HighDPI support when available QApplication.setAttribute(Qt.AA_EnableHighDpiScaling) + QApplication.__init__(self, argv) self.setAttribute(Qt.AA_DontShowIconsInMenus, True) diff --git a/Orange/data/table.py b/Orange/data/table.py index c8492fbd456..7bd7c5a294f 100644 --- a/Orange/data/table.py +++ b/Orange/data/table.py @@ -465,7 +465,7 @@ 
def from_table_rows(cls, source, row_indices): return self @classmethod - def from_numpy(cls, domain, X, Y=None, metas=None, W=None): + def from_numpy(cls, domain, X, Y=None, metas=None, W=None, data_attributes=None): """ Construct a table from numpy arrays with the given domain. The number of variables in the domain must match the number of columns in the @@ -482,6 +482,8 @@ def from_numpy(cls, domain, X, Y=None, metas=None, W=None): :type metas: np.array :param W: array with weights :type W: np.array + :param data_attributes: dictionary of data attributes + :type data_attributes: OrderedDict :return: """ X, Y, W = _check_arrays(X, Y, W, dtype='float64') @@ -532,7 +534,7 @@ def from_numpy(cls, domain, X, Y=None, metas=None, W=None): self.W = W self.n_rows = self.X.shape[0] cls._init_ids(self) - self.attributes = {} + self.attributes = data_attributes return self @classmethod diff --git a/Orange/projection/radviz.py b/Orange/projection/radviz.py index efd026ac278..fdddf5fb481 100644 --- a/Orange/projection/radviz.py +++ b/Orange/projection/radviz.py @@ -18,7 +18,7 @@ def radviz(data, attrs, points=None): m = x.shape[1] if points is not None: - s = points[:, :2] + s = points else: s = np.array([(np.cos(t), np.sin(t)) for t in [2.0 * np.pi * (i / float(m)) @@ -33,7 +33,7 @@ def radviz(data, attrs, points=None): r_x[i] = y[0] r_y[i] = y[1] - return np.stack((r_x, r_y), axis=1), np.column_stack((s, attrs)), mask + return np.stack((r_x, r_y), axis=1), s, mask def normalize(x): @@ -42,4 +42,6 @@ def normalize(x): """ a = x.min(axis=0) b = x.max(axis=0) - return (x - a[np.newaxis, :]) / ((b - a)[np.newaxis, :]) + diff = b - a + diff[diff == 0] = 1 + return (x - a[np.newaxis, :]) / diff[np.newaxis, :] diff --git a/Orange/tests/test_pca.py b/Orange/tests/test_pca.py index 5ecb3ee2729..397686abebc 100644 --- a/Orange/tests/test_pca.py +++ b/Orange/tests/test_pca.py @@ -1,13 +1,15 @@ # Test methods with long descriptive names can omit docstrings # pylint: 
disable=missing-docstring -import unittest import pickle +import unittest + import numpy as np +from sklearn import __version__ as sklearn_version +from Orange.data import Table from Orange.preprocess import Continuize, Normalize from Orange.projection import PCA, SparsePCA, IncrementalPCA, TruncatedSVD -from Orange.data import Table class TestPCA(unittest.TestCase): @@ -62,6 +64,8 @@ def __rnd_pca_test_helper(self, data, n_com, min_xpl_var): proj = np.dot(data.X - pca_model.mean_, pca_model.components_.T) np.testing.assert_almost_equal(pca_model(data).X, proj) + @unittest.skipIf(sklearn_version.startswith('0.20'), + "https://github.com/scikit-learn/scikit-learn/issues/12234") def test_incremental_pca(self): data = self.ionosphere self.__ipca_test_helper(data, n_com=3, min_xpl_var=0.49) diff --git a/Orange/widgets/data/owmergedata.py b/Orange/widgets/data/owmergedata.py index eb276976a61..ab3a35e02d0 100644 --- a/Orange/widgets/data/owmergedata.py +++ b/Orange/widgets/data/owmergedata.py @@ -369,7 +369,7 @@ def _join_table_by_indices(self, reduced_extra, indices): string_cols = [i for i, var in enumerate(domain.metas) if var.is_string] metas = self._join_array_by_indices( self.data.metas, reduced_extra.metas, indices, string_cols) - return Orange.data.Table.from_numpy(domain, X, Y, metas) + return Orange.data.Table.from_numpy(domain, X, Y, metas, data_attributes = getattr(self.data, "attributes")) @staticmethod def _join_array_by_indices(left, right, indices, string_cols=None): diff --git a/Orange/widgets/report/tests/test_report.py b/Orange/widgets/report/tests/test_report.py index fa8a5858e45..a56f0ddcfa6 100644 --- a/Orange/widgets/report/tests/test_report.py +++ b/Orange/widgets/report/tests/test_report.py @@ -14,7 +14,6 @@ from Orange.classification.tree import TreeLearner from Orange.evaluation import CrossValidation from Orange.distance import Euclidean -from Orange.util import OrangeDeprecationWarning from Orange.widgets.report.owreport import OWReport from 
Orange.widgets import gui from Orange.widgets.widget import OWWidget @@ -34,7 +33,6 @@ from Orange.widgets.unsupervised.owmds import OWMDS from Orange.widgets.unsupervised.owpca import OWPCA from Orange.widgets.utils.itemmodels import PyTableModel -from Orange.widgets.visualize.owlinearprojection import OWLinearProjection def get_owwidgets(top_module_name): @@ -230,34 +228,18 @@ def test_report_widgets_unsupervised_dist(self): self._create_report(widgets, rep, dist) def test_report_widgets_visualize(self): - _warnings = warnings.catch_warnings() - _warnings.__enter__() - warnings.simplefilter("ignore", OrangeDeprecationWarning) rep = OWReport.get_instance() data = Table("zoo") widgets = self.visu_widgets self._create_report(widgets, rep, data) - _warnings.__exit__() - - def test_deprecated_graph(self): - # Remove this test and lines 17, 37, 233 - 235 and 252 -254 - # since the widget is not using deprecate class any more - with warnings.catch_warnings(): - warnings.simplefilter("error", OrangeDeprecationWarning) - self.assertRaises(OrangeDeprecationWarning, - lambda: self.create_widget(OWLinearProjection)) @unittest.skipIf(AnyQt.USED_API == "pyqt5", "Segfaults on PyQt5") def test_report_widgets_all(self): - _warnings = warnings.catch_warnings() - _warnings.__enter__() - warnings.simplefilter("ignore", OrangeDeprecationWarning) rep = OWReport.get_instance() widgets = self.model_widgets + self.data_widgets + self.eval_widgets + \ self.unsu_widgets + self.dist_widgets + self.visu_widgets + \ self.spec_widgets self._create_report(widgets, rep, None) - _warnings.__exit__() def test_disable_saving_empty(self): """Test if save and print buttons are disabled on empty report""" diff --git a/Orange/widgets/tests/base.py b/Orange/widgets/tests/base.py index e0ecebdc60e..eca7438bc28 100644 --- a/Orange/widgets/tests/base.py +++ b/Orange/widgets/tests/base.py @@ -12,31 +12,33 @@ pass import numpy as np +import scipy.sparse as sp import sip -from AnyQt.QtCore import Qt +from 
AnyQt.QtCore import Qt, QRectF, QPointF from AnyQt.QtTest import QTest, QSignalSpy from AnyQt.QtWidgets import ( QApplication, QComboBox, QSpinBox, QDoubleSpinBox, QSlider ) from Orange.base import SklModel, Model -from Orange.widgets.report.owreport import OWReport from Orange.classification.base_classification import ( LearnerClassification, ModelClassification ) -from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable,\ - Variable +from Orange.data import ( + Table, Domain, DiscreteVariable, ContinuousVariable, Variable +) from Orange.modelling import Fitter -from Orange.preprocess import RemoveNaNColumns, Randomize +from Orange.preprocess import RemoveNaNColumns, Randomize, Continuize from Orange.preprocess.preprocess import PreprocessorList from Orange.regression.base_regression import ( LearnerRegression, ModelRegression ) +from Orange.widgets.report.owreport import OWReport +from Orange.widgets.tests.utils import simulate from Orange.widgets.utils.annotated_data import ( ANNOTATED_DATA_FEATURE_NAME, ANNOTATED_DATA_SIGNAL_NAME ) -from Orange.widgets.tests.utils import simulate from Orange.widgets.utils.owlearnerwidget import OWBaseLearner from Orange.widgets.utils.plot import OWPlotGUI from Orange.widgets.widget import OWWidget @@ -805,15 +807,28 @@ def _compare_selected_annotated_domains(self, selected, annotated): class ProjectionWidgetTestMixin: - """Class for projection widget testing. 
- - It init method to set up testing parameters and some test methods - """ + """Class for projection widget testing""" def init(self): Variable._clear_all_caches() self.data = Table("iris") + def _select_data(self): + rect = QRectF(QPointF(-20, -20), QPointF(20, 20)) + self.widget.graph.select_by_rectangle(rect) + return self.widget.graph.get_selection() + + def _compare_selected_annotated_domains(self, selected, annotated): + selected_vars = selected.domain.variables + annotated_vars = annotated.domain.variables + self.assertLessEqual(set(selected_vars), set(annotated_vars)) + + def test_setup_graph(self): + """Plot should exist after data has been sent in order to be + properly set/updated""" + self.send_signal(self.widget.Inputs.data, self.data) + self.assertIsNotNone(self.widget.graph.scatterplot_item) + def test_default_attrs(self, timeout=DEFAULT_TIMEOUT): """Check default values for 'Color', 'Shape', 'Size' and 'Label'""" self.send_signal(self.widget.Inputs.data, self.data) @@ -859,7 +874,8 @@ def test_overlap(self): def test_attr_label_metas(self, timeout=DEFAULT_TIMEOUT): """Set 'Label' from string meta attribute""" - data = Table("zoo") + cont = Continuize(multinomial_treatment=Continuize.AsOrdinal) + data = cont(Table("zoo")) self.send_signal(self.widget.Inputs.data, data) if self.widget.isBlocking(): spy = QSignalSpy(self.widget.blockingStateChanged) @@ -877,28 +893,12 @@ def test_handle_primitive_metas(self): def test_datasets(self, timeout=DEFAULT_TIMEOUT): """Test widget for datasets with missing values and constant features""" - for ds in self.__datasets(): + for ds in datasets.datasets(): self.send_signal(self.widget.Inputs.data, ds) if self.widget.isBlocking(): spy = QSignalSpy(self.widget.blockingStateChanged) self.assertTrue(spy.wait(timeout)) - @staticmethod - def __datasets(): - ds_cls = Table(datasets.path("testing_dataset_cls")) - ds_reg = Table(datasets.path("testing_dataset_reg")) - for ds in (ds_cls, ds_reg): - d, a = ds.domain, 
ds.domain.attributes - for i in range(0, len(a), 2): - yield ds.transform(Domain(a[i: i + 2], d.class_vars, d.metas)) - yield datasets.missing_data_1() - yield datasets.missing_data_2() - yield datasets.missing_data_3() - yield datasets.data_one_column_nans() - yield datasets.data_one_column_infs() - yield ds_cls - yield ds_reg - def test_none_data(self): """Test widget for empty dataset""" self.send_signal(self.widget.Inputs.data, self.data[:0]) @@ -921,6 +921,26 @@ def test_class_density(self, timeout=DEFAULT_TIMEOUT): self.send_signal(self.widget.Inputs.data, None) self.widget.cb_class_density.click() + def test_dragging_tooltip(self): + """Dragging tooltip depends on data being jittered""" + text = self.widget.graph.tiptexts[0] + self.send_signal(self.widget.Inputs.data, Table("heart_disease")) + self.assertEqual(self.widget.graph.tip_textitem.toPlainText(), text) + self.widget.graph.controls.jitter_size.setValue(1) + self.assertGreater(self.widget.graph.tip_textitem.toPlainText(), text) + + def test_sparse_data(self, timeout=DEFAULT_TIMEOUT): + """Test widget for sparse data""" + table = Table("iris") + table.X = sp.csr_matrix(table.X) + self.assertTrue(sp.issparse(table.X)) + self.send_signal(self.widget.Inputs.data, table) + if self.widget.isBlocking(): + spy = QSignalSpy(self.widget.blockingStateChanged) + self.assertTrue(spy.wait(timeout)) + self.send_signal(self.widget.Inputs.data_subset, table[::30]) + self.assertEqual(len(self.widget.subset_indices), 5) + def test_send_report(self, timeout=DEFAULT_TIMEOUT): """Test report """ self.send_signal(self.widget.Inputs.data, self.data) @@ -932,6 +952,35 @@ def test_send_report(self, timeout=DEFAULT_TIMEOUT): self.widget.report_button.click() +class AnchorProjectionWidgetTestMixin(ProjectionWidgetTestMixin): + def test_sparse_data(self): + table = Table("iris") + table.X = sp.csr_matrix(table.X) + self.assertTrue(sp.issparse(table.X)) + self.send_signal(self.widget.Inputs.data, table) + 
self.assertTrue(self.widget.Error.sparse_data.is_shown()) + self.send_signal(self.widget.Inputs.data_subset, table[::30]) + self.assertEqual(len(self.widget.subset_indices), 5) + self.send_signal(self.widget.Inputs.data, None) + self.assertFalse(self.widget.Error.sparse_data.is_shown()) + + def test_manual_move(self): + self.send_signal(self.widget.Inputs.data, self.data) + self.widget.graph.select_by_indices(list(range(0, len(self.data), 10))) + selection = self.widget.graph.selection + components = self.get_output(self.widget.Outputs.components) + self.widget._manual_move_start() + self.widget._manual_move(0, 1, 1) + self.assertEqual(len(self.widget.graph.scatterplot_item.data), + self.widget.SAMPLE_SIZE) + self.widget._manual_move_finish(0, 1, 2) + self.assertEqual(len(self.widget.graph.scatterplot_item.data), + len(self.data)) + self.assertNotEqual(components, + self.get_output(self.widget.Outputs.components)) + np.testing.assert_equal(self.widget.graph.selection, selection) + + class datasets: @staticmethod def path(filename): @@ -1022,3 +1071,27 @@ def data_one_column_nans(cls): @classmethod def data_one_column_infs(cls): return cls.data_one_column_vals(value=np.inf) + + @classmethod + def datasets(cls): + """ + Yields multiple datasets. 
+ + Returns + ------- + data : Generator of Orange.data.Table + """ + ds_cls = Table(cls.path("testing_dataset_cls")) + ds_reg = Table(cls.path("testing_dataset_reg")) + for ds in (ds_cls, ds_reg): + d, a = ds.domain, ds.domain.attributes + for i in range(0, len(a), 2): + yield ds.transform(Domain(a[i: i + 2], d.class_vars, d.metas)) + yield ds.transform(Domain(a[:2] + a[8: 10], d.class_vars, d.metas)) + yield cls.missing_data_1() + yield cls.missing_data_2() + yield cls.missing_data_3() + yield cls.data_one_column_nans() + yield cls.data_one_column_infs() + yield ds_cls + yield ds_reg diff --git a/Orange/widgets/unsupervised/owmds.py b/Orange/widgets/unsupervised/owmds.py index 0940f884578..1ca05a94ef5 100644 --- a/Orange/widgets/unsupervised/owmds.py +++ b/Orange/widgets/unsupervised/owmds.py @@ -8,22 +8,16 @@ import pyqtgraph as pg -from Orange.data import ContinuousVariable, Domain, Table, Variable +from Orange.data import ContinuousVariable, Domain, Table from Orange.distance import Euclidean from Orange.misc import DistMatrix from Orange.projection.manifold import torgerson, MDS -from Orange.widgets import gui, settings, report +from Orange.widgets import gui, settings from Orange.widgets.settings import SettingProvider -from Orange.widgets.utils.sql import check_sql_input -from Orange.widgets.visualize.owscatterplotgraph import ( - OWScatterPlotBase, OWProjectionWidget -) -from Orange.widgets.widget import Msg, OWWidget, Input, Output -from Orange.widgets.utils.annotated_data import ( - ANNOTATED_DATA_SIGNAL_NAME, create_annotated_table, create_groups_table, - get_unique_names -) +from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase +from Orange.widgets.visualize.utils.widget import OWDataProjectionWidget +from Orange.widgets.widget import Msg, OWWidget, Input def stress(X, distD): @@ -110,21 +104,15 @@ def update_pairs(self, reconnect): self.plot_widget.addItem(self.pairs_curve) -class OWMDS(OWProjectionWidget): +class 
OWMDS(OWDataProjectionWidget): name = "MDS" description = "Two-dimensional data projection by multidimensional " \ "scaling constructed from a distance matrix." icon = "icons/MDS.svg" keywords = ["multidimensional scaling", "multi dimensional scaling"] - class Inputs: - data = Input("Data", Table, default=True) + class Inputs(OWDataProjectionWidget.Inputs): distances = Input("Distances", DistMatrix) - data_subset = Input("Data Subset", Table) - - class Outputs: - selected_data = Output("Selected Data", Table, default=True) - annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) settings_version = 3 @@ -144,18 +132,15 @@ class Outputs: #: Runtime state Running, Finished, Waiting = 1, 2, 3 - settingsHandler = settings.DomainContextHandler() - max_iter = settings.Setting(300) initialization = settings.Setting(PCA) refresh_rate = settings.Setting(3) - auto_commit = settings.Setting(True) - + GRAPH_CLASS = OWMDSGraph graph = SettingProvider(OWMDSGraph) - graph_name = "graph.plot_widget.plotItem" + embedding_variables_names = ("mds-x", "mds-y") - class Error(OWWidget.Error): + class Error(OWDataProjectionWidget.Error): not_enough_rows = Msg("Input data needs at least 2 rows") matrix_too_small = Msg("Input matrix must be at least 2x2") no_attributes = Msg("Data has no attributes") @@ -168,15 +153,13 @@ def __init__(self): super().__init__() #: Input dissimilarity matrix self.matrix = None # type: Optional[DistMatrix] - #: Input subset data table - self.subset_data = None # type: Optional[Table] #: Data table from the `self.matrix.row_items` (if present) self.matrix_data = None # type: Optional[Table] #: Input data table self.signal_data = None - self._subset_mask = None # type: Optional[np.ndarray] self._invalidated = False + self.embedding = None self.effective_matrix = None self.__update_loop = None @@ -186,13 +169,24 @@ def __init__(self): self.__state = OWMDS.Waiting self.__in_next_step = False - box = gui.vBox(self.mainArea, True, margin=0) - self.graph = 
OWMDSGraph(self, box) self.graph.pause_drawing_pairs() - box.layout().addWidget(self.graph.plot_widget) - self.plot = self.graph.plot_widget + g = self.graph.gui + self.size_model = g.points_models[2] + self.size_model.order = g.points_models[2].order[:1] + ("Stress", ) + \ + g.points_models[2].order[1:] + # self._initialize() + + def _add_controls(self): + self._add_controls_optimization() + super()._add_controls() + self.graph.gui.add_control( + self._effects_box, gui.hSlider, "Show similar pairs:", + master=self.graph, value="connected_pairs", minValue=0, + maxValue=20, createLabel=False, callback=self._on_connected_changed + ) + def _add_controls_optimization(self): box = gui.vBox(self.controlArea, box=True) self.runbutton = gui.button(box, self, "Run optimization", callback=self._toggle_run) @@ -205,31 +199,6 @@ def __init__(self): gui.button(hbox, self, "Randomize", callback=self.do_random) gui.button(hbox, self, "Jitter", callback=self.do_jitter) - g.point_properties_box(self.controlArea) - box = g.effects_box(self.controlArea) - g.add_control(box, gui.hSlider, "Show similar pairs:", - master=self.graph, value="connected_pairs", - minValue=0, maxValue=20, createLabel=False, - callback=self._on_connected_changed - ) - g.plot_properties_box(self.controlArea) - - self.size_model = g.points_models[2] - self.size_model.order = g.points_models[2].order[:1] + ("Stress", ) + \ - g.points_models[2].order[1:] - - self.controlArea.layout().addStretch(100) - self.graph.box_zoom_select(self.controlArea) - gui.auto_commit(self.controlArea, self, "auto_commit", - "Send Selection", "Send Automatically") - - self._initialize() - - def selection_changed(self): - self.commit() - - @Inputs.data - @check_sql_input def set_data(self, data): """Set the input dataset. 
@@ -249,6 +218,7 @@ def set_data(self, data): len(self.matrix) == len(data): self.closeContext() self.data = data + self.init_attr_values() self.openContext(data) else: self._invalidated = True @@ -272,34 +242,21 @@ def set_disimilarity(self, matrix): self.matrix_data = matrix.row_items if matrix is not None else None self._invalidated = True - @Inputs.data_subset - def set_subset_data(self, subset_data): - """Set a subset of `data` input to highlight in the plot. - - Parameters - ---------- - subset_data: Optional[Table] - """ - self.subset_data = subset_data - # invalidate the pen/brush when the subset is changed - self._subset_mask = None # type: Optional[np.ndarray] - self.controls.graph.alpha_value.setEnabled(subset_data is None) - - def _clear(self): + def clear(self): + super().clear() + self.embedding = None + self.effective_matrix = None self.graph.set_effective_matrix(None) self.__set_update_loop(None) self.__state = OWMDS.Waiting def _initialize(self): - # clear everything self.closeContext() - self._clear() - self.Error.clear() - self.effective_matrix = None - self.embedding = None + self.clear() + self.clear_messages() # if no data nor matrix is present reset plot - if self.signal_data is None and self.matrix_data is None: + if self.signal_data is None and self.matrix is None: self.data = None self.init_attr_values() return @@ -307,7 +264,7 @@ def _initialize(self): if self.signal_data is not None and self.matrix is not None and \ len(self.signal_data) != len(self.matrix): self.Error.mismatching_dimensions() - self._update_plot() + self.init_attr_values() return if self.signal_data is not None: @@ -324,6 +281,7 @@ def _initialize(self): self.effective_matrix = Euclidean(preprocessed_data) else: self.Error.no_attributes() + self.init_attr_values() return self.init_attr_values() @@ -485,12 +443,13 @@ def jitter_coord(part): # reset/invalidate the MDS embedding, to the default initialization # (Random or PCA), restarting the optimization if necessary. 
- if self.embedding is None: - return state = self.__state if self.__update_loop is not None: self.__set_update_loop(None) + if self.effective_matrix is None: + return + X = self.effective_matrix if initialization == OWMDS.PCA: @@ -501,7 +460,7 @@ def jitter_coord(part): jitter_coord(self.embedding[:, 0]) jitter_coord(self.embedding[:, 1]) - self._update_plot() + self.setup_plot() # restart the optimization if it was interrupted. if state == OWMDS.Running: @@ -524,10 +483,11 @@ def handleNewSignals(self): self.graph.pause_drawing_pairs() self._invalidated = False self._initialize() + self.__invalidate_embedding() + self.cb_class_density.setEnabled(self.can_draw_density()) self.start() - self._update_plot() - self.unconditional_commit() + super().handleNewSignals() def _invalidate_output(self): self.commit() @@ -536,8 +496,8 @@ def _on_connected_changed(self): self.graph.set_effective_matrix(self.effective_matrix) self.graph.update_pairs(reconnect=True) - def _update_plot(self): - self.graph.reset_graph() + def setup_plot(self): + super().setup_plot() if self.embedding is not None: self.graph.update_pairs(reconnect=True) @@ -547,64 +507,20 @@ def get_size_data(self): else: return super().get_size_data() - def get_coordinates_data(self): - return self.embedding.T if self.embedding is not None else (None, None) + def get_embedding(self): + self.valid_data = np.ones(len(self.embedding), dtype=bool) \ + if self.embedding is not None else None + return self.embedding - def get_subset_mask(self): - if self.data is not None and self.subset_data is not None: - return np.in1d(self.data.ids, self.subset_data.ids) - - def commit(self): - if self.embedding is not None: - names = get_unique_names([v.name for v in self.data.domain.variables], - ["mds-x", "mds-y"]) - domain = Domain([ContinuousVariable(names[0]), - ContinuousVariable(names[1])]) - output = embedding = Table.from_numpy(domain, self.embedding) - else: - output = embedding = None - - if self.embedding is not None and 
self.data is not None: - domain = self.data.domain - domain = Domain(domain.attributes, domain.class_vars, - domain.metas + embedding.domain.attributes) - output = self.data.transform(domain) - output.metas[:, -2:] = embedding.X - - selection = self.graph.get_selection() - if output is not None and len(selection) > 0: - selected = output[selection] - else: - selected = None - if self.graph.selection is not None and np.max(self.graph.selection) > 1: - annotated = create_groups_table(output, self.graph.selection) - else: - annotated = create_annotated_table(output, selection) - self.Outputs.selected_data.send(selected) - self.Outputs.annotated_data.send(annotated) - - def onDeleteWidget(self): - super().onDeleteWidget() - self.graph.clear() - self._clear() + def _get_projection_data(self): + if self.embedding is None: + return None - def send_report(self): if self.data is None: - return - - def name(var): - return var.name if isinstance(var, Variable) else var - - caption = report.render_items_vert(( - ("Color", name(self.attr_color)), - ("Label", name(self.attr_label)), - ("Shape", name(self.attr_shape)), - ("Size", name(self.attr_size)), - ("Jittering", self.graph.jitter_size != 0 and "{} %".format( - self.graph.jitter_size)))) - self.report_plot() - if caption: - self.report_caption(caption) + x_name, y_name = self.embedding_variables_names + variables = ContinuousVariable(x_name), ContinuousVariable(y_name) + return Table(Domain(variables), self.embedding) + return super()._get_projection_data() @classmethod def migrate_settings(cls, settings_, version): @@ -684,5 +600,6 @@ def main(argv=None): app.processEvents() return rval + if __name__ == "__main__": sys.exit(main()) diff --git a/Orange/widgets/unsupervised/tests/test_owmds.py b/Orange/widgets/unsupervised/tests/test_owmds.py index 1d49a393b7a..99710cd453f 100644 --- a/Orange/widgets/unsupervised/tests/test_owmds.py +++ b/Orange/widgets/unsupervised/tests/test_owmds.py @@ -6,21 +6,23 @@ from unittest.mock 
import patch, Mock import numpy as np -from AnyQt.QtCore import QRectF, QPointF + from AnyQt.QtTest import QSignalSpy from Orange.data import Table -from Orange.misc import DistMatrix from Orange.distance import Euclidean +from Orange.misc import DistMatrix from Orange.widgets.settings import Context -from Orange.widgets.tests.base import (WidgetTest, WidgetOutputsTestMixin, - datasets, ProjectionWidgetTestMixin) +from Orange.widgets.tests.base import ( + WidgetTest, WidgetOutputsTestMixin, datasets, ProjectionWidgetTestMixin +) from Orange.widgets.tests.utils import simulate from Orange.widgets.unsupervised.owmds import OWMDS from Orange.widgets.utils.plot import OWPlotGUI -class TestOWMDS(WidgetTest, WidgetOutputsTestMixin, ProjectionWidgetTestMixin): +class TestOWMDS(WidgetTest, ProjectionWidgetTestMixin, + WidgetOutputsTestMixin): @classmethod def setUpClass(cls): super().setUpClass() @@ -49,10 +51,6 @@ def tearDown(self): self.widget.onDeleteWidget() super().tearDown() - def _select_data(self): - self.widget.graph.select_by_rectangle(QRectF(QPointF(-20, -20), QPointF(20, 20))) - return self.widget.graph.get_selection() - def test_pca_init(self): self.send_signal(self.signal_name, self.signal_data) output = self.get_output(self.widget.Outputs.annotated_data, wait=1000) @@ -262,6 +260,21 @@ def test_attr_label_matrix_and_data(self): self.assertTrue(set(chain(data.domain.variables, data.domain.metas)) < set(w.controls.attr_label.model())) + def test_saved_matrix_and_data(self): + towns_data = self.towns.row_items + attr_label = self.widget.controls.attr_label + self.widget.start = Mock() + self.towns.row_items = None + + # Matrix without data + self.send_signal(self.widget.Inputs.distances, self.towns) + self.assertIsNotNone(self.widget.graph.scatterplot_item) + self.assertEqual(list(attr_label.model()), [None]) + + # Data + self.send_signal(self.widget.Inputs.data, towns_data) + self.assertIn(towns_data.domain["label"], attr_label.model()) + def 
test_overlap(self): self.send_signal(self.signal_name, self.signal_data) if self.widget.isBlocking(): diff --git a/Orange/widgets/utils/plot/owplotgui.py b/Orange/widgets/utils/plot/owplotgui.py index 68e05276d5c..a16d081eeca 100644 --- a/Orange/widgets/utils/plot/owplotgui.py +++ b/Orange/widgets/utils/plot/owplotgui.py @@ -136,8 +136,10 @@ def add_variables(self): class VariablesSelection: - def __call__(self, master, model_selected, model_other, widget=None): + def __init__(self, master, model_selected, model_other, widget=None): self.master = master + self.model_selected = model_selected + self.model_other = model_other params_view = {"sizePolicy": QSizePolicy(*SIZE_POLICY_ADAPTING), "selectionMode": QListView.ExtendedSelection, @@ -159,14 +161,7 @@ def __call__(self, master, model_selected, model_other, widget=None): triggered=self.__deactivate_selection ) view.addAction(delete) - - self.model_selected = model = model_selected - - model.rowsInserted.connect(master.invalidate_plot) - model.rowsRemoved.connect(master.invalidate_plot) - model.rowsMoved.connect(master.invalidate_plot) - - view.setModel(model) + view.setModel(self.model_selected) addClassLabel = QAction("+", master, toolTip="Add new class label", @@ -175,7 +170,8 @@ def __call__(self, master, model_selected, model_other, widget=None): toolTip="Remove selected class label", triggered=self.__deactivate_selection) - add_remove = itemmodels.ModelActionsWidget([addClassLabel, removeClassLabel], master) + add_remove = itemmodels.ModelActionsWidget( + [addClassLabel, removeClassLabel], master) add_remove.layout().addStretch(10) add_remove.layout().setSpacing(1) add_remove.setSizePolicy(*SIZE_POLICY_FIXED) @@ -184,21 +180,11 @@ def __call__(self, master, model_selected, model_other, widget=None): self.add_remove = add_remove self.box = add_remove.buttons[1] - self.model_other = model_other - def set_enabled(self, is_enabled): self.view_selected.setEnabled(is_enabled) for btn in self.add_remove.buttons: 
btn.setEnabled(is_enabled) - def display_all(self): - self.model_selected[:] += self.model_other[:] - self.model_other[:] = [] - - def display_none(self): - self.model_other[:] += self.model_selected[:] - self.model_selected[:] = [] - def __deactivate_selection(self): view = self.view_selected model = self.model_selected @@ -214,24 +200,6 @@ def __deactivate_selection(self): def _action_add(self): self.add_variables_dialog = AddVariablesDialog(self, self.model_other) - @staticmethod - def encode_var_state(lists): - return {(type(var), var.name): (source_ind, pos) - for source_ind, var_list in enumerate(lists) - for pos, var in enumerate(var_list) - if isinstance(var, Variable)} - - @staticmethod - def decode_var_state(state, lists): - all_vars = reduce(list.__iadd__, lists, []) - - newlists = [[] for _ in lists] - for var in all_vars: - source, pos = state[(type(var), var.name)] - newlists[source].append((pos, var)) - return [[var for _, var in sorted(newlist, key=itemgetter(0))] - for newlist in newlists] - class OrientedWidget(QWidget): ''' @@ -606,19 +574,20 @@ def jitter_size_slider(self, widget): widget, gui.valueSlider, "Jittering", master=self._plot, value='jitter_size', values=getattr(self._plot, "jitter_sizes", self.JITTER_SIZES), - callback=self._plot.update_coordinates) + callback=self._plot.update_jittering) def jitter_numeric_check_box(self, widget): self._check_box( widget=widget, value="jitter_continuous", label="Jitter numeric values", - cb_name="update_coordinates") + cb_name="update_jittering") def show_legend_check_box(self, widget): ''' Creates a check box that shows and hides the plot legend ''' - self._check_box(widget, 'show_legend', 'Show legend', 'update_legend') + self._check_box(widget, 'show_legend', 'Show legend', + 'update_legend_visibility') def tooltip_shows_all_check_box(self, widget): gui.checkBox( @@ -647,7 +616,8 @@ def filled_symbols_check_box(self, widget): 'update_filled_symbols') def grid_lines_check_box(self, widget): - 
self._check_box(widget, 'show_grid', 'Show gridlines', 'update_grid') + self._check_box(widget, 'show_grid', 'Show gridlines', + 'update_grid_visibility') def animations_check_box(self, widget): ''' diff --git a/Orange/widgets/visualize/owfreeviz.py b/Orange/widgets/visualize/owfreeviz.py index 6dd9ca087f9..c1f9583c4c5 100644 --- a/Orange/widgets/visualize/owfreeviz.py +++ b/Orange/widgets/visualize/owfreeviz.py @@ -2,27 +2,22 @@ import sys import numpy as np -from scipy.spatial import distance from AnyQt.QtCore import ( Qt, QObject, QEvent, QRectF, QLineF, QTimer, QPoint, pyqtSignal as Signal, pyqtSlot as Slot ) from AnyQt.QtGui import QColor -from AnyQt.QtWidgets import QApplication, QGraphicsEllipseItem +from AnyQt.QtWidgets import QApplication import pyqtgraph as pg -from Orange.data import Table, Domain, StringVariable, ContinuousVariable +from Orange.data import Table, Domain, StringVariable from Orange.projection.freeviz import FreeViz -from Orange.widgets import widget, gui, settings, report -from Orange.widgets.utils.annotated_data import ( - create_annotated_table, ANNOTATED_DATA_SIGNAL_NAME, create_groups_table -) -from Orange.widgets.visualize.owscatterplotgraph import OWProjectionWidget -from Orange.widgets.visualize.utils.component import OWVizGraph +from Orange.widgets import widget, gui, settings +from Orange.widgets.visualize.utils.component import OWGraphWithAnchors from Orange.widgets.visualize.utils.plotutils import AnchorItem -from Orange.widgets.widget import Input, Output +from Orange.widgets.visualize.utils.widget import OWAnchorProjectionWidget class AsyncUpdateLoop(QObject): @@ -143,75 +138,64 @@ def customEvent(self, event): super().customEvent(event) -class OWFreeVizGraph(OWVizGraph): - radius = settings.Setting(0) +class OWFreeVizGraph(OWGraphWithAnchors): + hide_radius = settings.Setting(0) - def __init__(self, scatter_widget, parent): - super().__init__(scatter_widget, parent) - self._points = [] - self._point_items = [] + @property + 
def scaled_radius(self): + return self.hide_radius / 100 + 1e-5 def update_radius(self): - if self._circle_item is None: - return - - r = self.radius / 100 + 1e-5 - for point, axitem in zip(self._points, self._point_items): - axitem.setVisible(np.linalg.norm(point) > r) - self._circle_item.setRect(QRectF(-r, -r, 2 * r, 2 * r)) + self.update_circle() + self.update_anchors() def set_view_box_range(self): - self.view_box.setRange(RANGE) - - def can_show_indicator(self, pos): - if not len(self._points): - return False, None - - r = self.radius / 100 + 1e-5 - mask = np.zeros((len(self._points)), dtype=bool) - mask[np.linalg.norm(self._points, axis=1) > r] = True - distances = distance.cdist([[pos.x(), pos.y()]], self._points)[0] - distances = distances[mask] - if len(distances) and np.min(distances) < self.DISTANCE_DIFF: - return True, np.flatnonzero(mask)[np.argmin(distances)] - return False, None - - def _remove_point_items(self): - for item in self._point_items: - self.plot_widget.removeItem(item) - self._point_items = [] - - def _add_point_items(self): - r = self.radius / 100 + 1e-5 - for point, var in zip(self._points, self._attributes): - axitem = AnchorItem(line=QLineF(0, 0, *point), text=var.name) - axitem.setVisible(np.linalg.norm(point) > r) - axitem.setPen(pg.mkPen((100, 100, 100))) - self.plot_widget.addItem(axitem) - self._point_items.append(axitem) - - def _add_circle_item(self): - if not len(self._points): + self.view_box.setRange(QRectF(-1.05, -1.05, 2.1, 2.1)) + + def closest_draggable_item(self, pos): + points, *_ = self.master.get_anchors() + if points is None or not len(points): + return None + mask = np.linalg.norm(points, axis=1) > self.scaled_radius + xi, yi = points[mask].T + distances = (xi - pos.x()) ** 2 + (yi - pos.y()) ** 2 + if len(distances) and np.min(distances) < self.DISTANCE_DIFF ** 2: + return np.flatnonzero(mask)[np.argmin(distances)] + return None + + def update_anchors(self): + points, labels = self.master.get_anchors() + if points 
is None: return - r = self.radius / 100 + 1e-5 - pen = pg.mkPen(QColor(Qt.lightGray), width=1, cosmetic=True) - self._circle_item = QGraphicsEllipseItem() - self._circle_item.setRect(QRectF(-r, -r, 2 * r, 2 * r)) - self._circle_item.setPen(pen) - self.plot_widget.addItem(self._circle_item) - - def _add_indicator_item(self, point_i): - x, y = self._points[point_i] + r = self.scaled_radius + if self.anchor_items is None: + self.anchor_items = [] + for point, label in zip(points, labels): + anchor = AnchorItem(line=QLineF(0, 0, *point), text=label) + anchor.setVisible(np.linalg.norm(point) > r) + anchor.setPen(pg.mkPen((100, 100, 100))) + self.plot_widget.addItem(anchor) + self.anchor_items.append(anchor) + else: + for anchor, point, label in zip(self.anchor_items, points, labels): + anchor.setLine(QLineF(0, 0, *point)) + anchor.setText(label) + anchor.setVisible(np.linalg.norm(point) > r) + + def update_circle(self): + super().update_circle() + if self.circle_item is not None: + r = self.scaled_radius + self.circle_item.setRect(QRectF(-r, -r, 2 * r, 2 * r)) + pen = pg.mkPen(QColor(Qt.lightGray), width=1, cosmetic=True) + self.circle_item.setPen(pen) + + def _add_indicator_item(self, anchor_idx): + x, y = self.anchor_items[anchor_idx].get_xy() dx = (self.view_box.childGroup.mapToDevice(QPoint(1, 0)) - self.view_box.childGroup.mapToDevice(QPoint(-1, 0))).x() - self._indicator_item = MoveIndicator(x, y, 600 / dx) - self.plot_widget.addItem(self._indicator_item) - - -MAX_ITERATIONS = 1000 -MAX_POINTS = 300 -MAX_INSTANCES = 10000 -RANGE = QRectF(-1.05, -1.05, 2.1, 2.1) + self.indicator_item = MoveIndicator(x, y, 600 / dx) + self.plot_widget.addItem(self.indicator_item) class InitType(IntEnum): @@ -222,83 +206,34 @@ def items(): return ["Circular", "Random"] -class OWFreeViz(OWProjectionWidget): +class OWFreeViz(OWAnchorProjectionWidget): + MAX_ITERATIONS = 1000 + MAX_INSTANCES = 10000 + name = "FreeViz" description = "Displays FreeViz projection" icon = 
"icons/Freeviz.svg" priority = 240 keywords = ["viz"] - class Inputs: - data = Input("Data", Table, default=True) - data_subset = Input("Data Subset", Table) - - class Outputs: - selected_data = Output("Selected Data", Table, default=True) - annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) - components = Output("Components", Table) - settings_version = 3 - settingsHandler = settings.DomainContextHandler() - initialization = settings.Setting(InitType.Circular) - auto_commit = settings.Setting(True) - + GRAPH_CLASS = OWFreeVizGraph graph = settings.SettingProvider(OWFreeVizGraph) - graph_name = "graph.plot_widget.plotItem" + embedding_variables_names = ("freeviz-x", "freeviz-y") - class Error(OWProjectionWidget.Error): - sparse_data = widget.Msg("Sparse data is not supported") - no_class_var = widget.Msg("Need a class variable") + class Error(OWAnchorProjectionWidget.Error): + no_class_var = widget.Msg("Data has no target variable") not_enough_class_vars = widget.Msg( - "Needs discrete class variable with at lest 2 values" - ) + "Target variable is not at least binary") features_exceeds_instances = widget.Msg( - "Algorithm should not be used when number of features " - "exceeds the number of instances." 
- ) - too_many_data_instances = widget.Msg("Cannot handle so large data.") - no_valid_data = widget.Msg("No valid data.") + "Number of features exceeds the number of instances.") + too_many_data_instances = widget.Msg("Data is too large.") def __init__(self): super().__init__() - - self.data = None - self.subset_data = None - self.subset_indices = None - self._embedding_coords = None self._X = None self._Y = None - self._rand_indices = None - self.variable_x = ContinuousVariable("freeviz-x") - self.variable_y = ContinuousVariable("freeviz-y") - - box = gui.vBox(self.mainArea, True, margin=0) - self.graph = OWFreeVizGraph(self, box) - box.layout().addWidget(self.graph.plot_widget) - - box = gui.vBox(self.controlArea, box=True) - gui.comboBox(box, self, "initialization", label="Initialization:", - items=InitType.items(), orientation=Qt.Horizontal, - labelWidth=90, callback=self.__init_combo_changed) - self.btn_start = gui.button(box, self, "Optimize", self.__toggle_start, - enabled=False) - - g = self.graph.gui - g.point_properties_box(self.controlArea) - box = g.effects_box(self.controlArea) - g.add_control(box, gui.hSlider, "Hide radius:", - master=self.graph, value="radius", - minValue=0, maxValue=100, - step=10, createLabel=False, - callback=self.__radius_slider_changed) - g.plot_properties_box(self.controlArea) - - self.controlArea.layout().addStretch(100) - self.graph.box_zoom_select(self.controlArea) - - gui.auto_commit(self.controlArea, self, "auto_commit", - "Send Selection", "Send Automatically") # FreeViz self._loop = AsyncUpdateLoop(parent=self) @@ -306,9 +241,23 @@ def __init__(self): self._loop.finished.connect(self.__freeviz_finished) self._loop.raised.connect(self.__on_error) - self.graph.view_box.started.connect(self._randomize_indices) - self.graph.view_box.moved.connect(self._manual_move) - self.graph.view_box.finished.connect(self._finish_manual_move) + def _add_controls(self): + self.__add_controls_start_box() + super()._add_controls() + 
self.graph.gui.add_control( + self._effects_box, gui.hSlider, "Hide radius:", master=self.graph, + value="hide_radius", minValue=0, maxValue=100, step=10, + createLabel=False, callback=self.__radius_slider_changed + ) + + def __add_controls_start_box(self): + box = gui.vBox(self.controlArea, box=True) + gui.comboBox( + box, self, "initialization", label="Initialization:", + items=InitType.items(), orientation=Qt.Horizontal, + labelWidth=90, callback=self.__init_combo_changed) + self.btn_start = gui.button( + box, self, "Optimize", self.__toggle_start, enabled=False) def __radius_slider_changed(self): self.graph.update_radius() @@ -322,30 +271,29 @@ def __toggle_start(self): self._start() def __init_combo_changed(self): + if self.data is None: + return running = self._loop.isRunning() if running: self._loop.cancel() - if self.data is not None: - self.setup_plot() + self.init_embedding_coords() + self.graph.update_coordinates() if running: self._start() def _start(self): - """ - Start the projection optimization. - """ def update_freeviz(anchors): while True: - projection = FreeViz.freeviz( + _, projection, *_ = FreeViz.freeviz( self._X, self._Y, scale=False, center=False, - initial=anchors, maxiter=10 - ) - yield projection[0], projection[1] - if np.allclose(anchors, projection[1], rtol=1e-5, atol=1e-4): + initial=anchors, maxiter=10) + yield projection + if np.allclose(anchors, projection, rtol=1e-5, atol=1e-4): return - anchors = projection[1] + anchors = projection - self._loop.setCoroutine(update_freeviz(self.graph.get_points())) + self.graph.set_sample_size(self.SAMPLE_SIZE) + self._loop.setCoroutine(update_freeviz(self.projection)) self.btn_start.setText("Stop") self.progressBarInit() self.setBlocking(True) @@ -353,13 +301,12 @@ def update_freeviz(anchors): def __set_projection(self, projection): # Set/update the projection matrix and coordinate embeddings - self.progressBarAdvance(100. 
/ MAX_ITERATIONS) - self._embedding_coords = projection[0] - self.graph.set_points(projection[1]) - self._update_xy() + self.progressBarAdvance(100. / self.MAX_ITERATIONS) + self.projection = projection + self.graph.update_coordinates() def __freeviz_finished(self): - # Projection optimization has finished + self.graph.set_sample_size(None) self.btn_start.setText("Optimize") self.setStatusMessage("") self.setBlocking(False) @@ -369,228 +316,91 @@ def __freeviz_finished(self): def __on_error(self, err): sys.excepthook(type(err), err, getattr(err, "__traceback__")) - def _update_xy(self): - coords = self._embedding_coords - self._embedding_coords /= np.max(np.linalg.norm(coords, axis=1)) - self.graph.update_coordinates() - - def clear(self): - self._loop.cancel() - self.data = None - self.valid_data = None - self._embedding_coords = None - self._X = None - self._Y = None - self._rand_indices = None - - self.graph.set_attributes(()) - self.graph.set_points([]) - self.graph.update_coordinates() - self.graph.clear() - - @Inputs.data - def set_data(self, data): - self.clear_messages() - self.closeContext() - self.clear() - self.data = data - self._check_data() - self.init_attr_values() - self.openContext(data) - self.btn_start.setEnabled(self.data is not None) - self.cb_class_density.setEnabled(self.can_draw_density()) + def check_data(self): + def error(err): + err() + self.data = None - def _check_data(self): + super().check_data() if self.data is not None: - if self.data.is_sparse(): - self.Error.sparse_data() - self.data = None - elif self.data.domain.class_var is None: - self.Error.no_class_var() - self.data = None - elif self.data.domain.class_var.is_discrete and \ - len(self.data.domain.class_var.values) < 2: - self.Error.not_enough_class_vars() - self.data = None + class_var = self.data.domain.class_var + if class_var is None: + error(self.Error.no_class_var) + elif class_var.is_discrete and len(np.unique(self.data.Y)) < 2: + 
error(self.Error.not_enough_class_vars) + elif len(self.data.domain.attributes) < 2: + error(self.Error.not_enough_features) elif len(self.data.domain.attributes) > self.data.X.shape[0]: - self.Error.features_exceeds_instances() - self.data = None + error(self.Error.features_exceeds_instances) else: - self._prepare_freeviz_data() - if self._X is not None: - if len(self._X) > MAX_INSTANCES: - self.Error.too_many_data_instances() - self.data = None - elif np.allclose(np.nan_to_num(self._X - self._X[0]), 0) \ - or not len(self._X): - self.Error.no_valid_data() - self.data = None - else: - self.Error.no_valid_data() - self.data = None - - def _prepare_freeviz_data(self): - valid_mask = np.all(np.isfinite(self.data.X), axis=1) & \ - np.isfinite(self.data.Y) - X, Y = self.data.X[valid_mask], self.data.Y[valid_mask] - if not len(X): - self.valid_data = None - return + self.valid_data = np.all(np.isfinite(self.data.X), axis=1) & \ + np.isfinite(self.data.Y) + n_valid = np.sum(self.valid_data) + if n_valid > self.MAX_INSTANCES: + error(self.Error.too_many_data_instances) + elif n_valid == 0: + error(self.Error.no_valid_data) + self.btn_start.setEnabled(self.data is not None) - if self.data.domain.class_var.is_discrete: - Y = Y.astype(int) - X = (X - np.mean(X, axis=0)) - span = np.ptp(X, axis=0) - X[:, span > 0] /= span[span > 0].reshape(1, -1) - self._X, self._Y, self.valid_data = X, Y, valid_mask - - @Inputs.data_subset - def set_subset_data(self, subset): - self.subset_data = subset - self.subset_indices = {e.id for e in subset} \ - if subset is not None else {} - self.controls.graph.alpha_value.setEnabled(subset is None) - - def handleNewSignals(self): - if self.data is not None and self.valid_data is not None: - self.setup_plot() - self.commit() + def set_data(self, data): + super().set_data(data) + if self.data is not None: + self.prepare_projection_data() + self.init_embedding_coords() + + def prepare_projection_data(self): + if not np.any(self.valid_data): + self._X 
= self._Y = self.valid_data = None + return - def get_coordinates_data(self): - return (self._embedding_coords[:, 0], self._embedding_coords[:, 1]) \ - if self._embedding_coords is not None else (None, None) + self._X = self.data.X.copy() + self._X -= np.nanmean(self._X, axis=0) + span = np.ptp(self._X[self.valid_data], axis=0) + self._X[:, span > 0] /= span[span > 0].reshape(1, -1) - def get_subset_mask(self): - if self.subset_indices: - return np.array([ex.id in self.subset_indices - for ex in self.data[self.valid_data]]) + self._Y = self.data.Y + if self.data.domain.class_var.is_discrete: + self._Y = self._Y.astype(int) - def setup_plot(self): - points = FreeViz.init_radial(self._X.shape[1]) \ + def init_embedding_coords(self): + self.projection = FreeViz.init_radial(self._X.shape[1]) \ if self.initialization == InitType.Circular \ else FreeViz.init_random(self._X.shape[1], 2) - self.graph.set_points(points) - self.__set_embedding_coords() - self.graph.set_attributes(self.data.domain.attributes) - self.graph.reset_graph() - - def _randomize_indices(self): - n = len(self._X) - if n > MAX_POINTS: - self._rand_indices = np.random.choice(n, MAX_POINTS, replace=False) - self._rand_indices = sorted(self._rand_indices) - - def _manual_move(self): - self.__set_embedding_coords() - if self._rand_indices is not None: - # save widget state - selection = self.graph.selection - valid_data = self.valid_data.copy() - data = self.data.copy() - ec = self._embedding_coords.copy() - - # plot subset - self.__plot_random_subset(selection) - - # restore widget state - self.graph.selection = selection - self.valid_data = valid_data - self.data = data - self._embedding_coords = ec - else: - self.graph.update_coordinates() - - def __plot_random_subset(self, selection): - self._embedding_coords = self._embedding_coords[self._rand_indices] - self.data = self.data[self._rand_indices] - self.valid_data = self.valid_data[self._rand_indices] - self.graph.reset_graph() - if selection is not 
None: - self.graph.selection = selection[self._rand_indices] - self.graph.update_selection_colors() - - def _finish_manual_move(self): - if self._rand_indices is not None: - selection = self.graph.selection - self.graph.reset_graph() - if selection is not None: - self.graph.selection = selection - self.graph.select_by_index(self.graph.get_selection()) - - def __set_embedding_coords(self): - points = self.graph.get_points() - ex = np.dot(self._X, points) - self._embedding_coords = (ex / np.max(np.linalg.norm(ex, axis=1))) - - def selection_changed(self): - self.commit() - def commit(self): - selected = annotated = components = None + def get_embedding(self): + if self.data is None: + return None + embedding = np.dot(self._X, self.projection) + embedding /= \ + np.max(np.linalg.norm(embedding[self.valid_data], axis=1)) or 1 + return embedding + + def get_anchors(self): + if self.projection is None: + return None, None + return self.projection, [a.name for a in self.data.domain.attributes] + + def send_components(self): + components = None if self.data is not None and self.valid_data is not None: - name = self.data.name - domain = self.data.domain - metas = domain.metas + (self.variable_x, self.variable_y) - domain = Domain(domain.attributes, domain.class_vars, metas) - embedding_coords = np.zeros((len(self.data), 2), dtype=np.float) - embedding_coords[self.valid_data] = self._embedding_coords - - data = self.data.transform(domain) - data[:, self.variable_x] = embedding_coords[:, 0][:, None] - data[:, self.variable_y] = embedding_coords[:, 1][:, None] - - selection = self.graph.get_selection() - if len(selection): - selected = data[selection] - selected.name = name + ": selected" - selected.attributes = self.data.attributes - if self.graph.selection is not None and \ - np.max(self.graph.selection) > 1: - annotated = create_groups_table(data, self.graph.selection) - else: - annotated = create_annotated_table(data, selection) - annotated.attributes = 
self.data.attributes - annotated.name = name + ": annotated" - - comp_domain = Domain( - self.data.domain.attributes, - metas=[StringVariable(name='component')]) - + meta_attrs = [StringVariable(name='component')] + domain = Domain(self.data.domain.attributes, metas=meta_attrs) metas = np.array([["FreeViz 1"], ["FreeViz 2"]]) - components = Table.from_numpy( - comp_domain, - X=self.graph.get_points().T, - metas=metas) - - components.name = name + ": components" - - self.Outputs.selected_data.send(selected) - self.Outputs.annotated_data.send(annotated) + components = Table(domain, self.projection.T, metas=metas) + components.name = self.data.name self.Outputs.components.send(components) - def send_report(self): - if self.data is None: - return - - def name(var): - return var and var.name - - caption = report.render_items_vert(( - ("Color", name(self.attr_color)), - ("Label", name(self.attr_label)), - ("Shape", name(self.attr_shape)), - ("Size", name(self.attr_size)), - ("Jittering", self.graph.jitter_size != 0 and - "{} %".format(self.graph.jitter_size)))) - self.report_plot() - if caption: - self.report_caption(caption) + def clear(self): + super().clear() + self._loop.cancel() + self._X = None + self._Y = None @classmethod def migrate_settings(cls, _settings, version): if version < 3: if "radius" in _settings: - _settings["graph"]["radius"] = _settings["radius"] + _settings["graph"]["hide_radius"] = _settings["radius"] @classmethod def migrate_context(cls, context, version): @@ -621,8 +431,6 @@ def boundingRect(self): def main(argv=None): - import sip - argv = sys.argv[1:] if argv is None else argv if argv: filename = argv[0] @@ -641,7 +449,7 @@ def main(argv=None): r = app.exec() w.set_data(None) w.saveSettings() - sip.delete(w) + del w return r diff --git a/Orange/widgets/visualize/owlinearprojection.py b/Orange/widgets/visualize/owlinearprojection.py index afed9de51d0..d977d2c4d12 100644 --- a/Orange/widgets/visualize/owlinearprojection.py +++ 
b/Orange/widgets/visualize/owlinearprojection.py @@ -5,7 +5,6 @@ from itertools import islice, permutations, chain from math import factorial -from types import SimpleNamespace as namespace import numpy as np @@ -13,35 +12,33 @@ from sklearn.neighbors import NearestNeighbors from sklearn.metrics import r2_score -from AnyQt.QtWidgets import QGraphicsEllipseItem, QApplication, QSizePolicy -from AnyQt.QtGui import QPen, QStandardItem -from AnyQt.QtCore import Qt, QEvent, QRectF, QLineF -from AnyQt.QtCore import pyqtSignal as Signal +from AnyQt.QtWidgets import QApplication, QSizePolicy +from AnyQt.QtGui import QStandardItem, QColor +from AnyQt.QtCore import Qt, QRectF, QLineF, pyqtSignal as Signal import pyqtgraph as pg -from Orange.data import Table, Domain, StringVariable, ContinuousVariable -from Orange.data.sql.table import SqlTable +from Orange.data import Table, Domain, StringVariable from Orange.preprocess import Normalize +from Orange.preprocess.score import ReliefF, RReliefF from Orange.projection import PCA from Orange.util import Enum -from Orange.widgets import widget, gui, settings +from Orange.widgets import gui, report from Orange.widgets.gui import OWComponent -from Orange.widgets.utils.annotated_data import ( - create_annotated_table, ANNOTATED_DATA_SIGNAL_NAME, create_groups_table, get_unique_names) +from Orange.widgets.settings import Setting, ContextSetting, SettingProvider +from Orange.widgets.utils import vartype from Orange.widgets.utils.itemmodels import VariableListModel from Orange.widgets.utils.plot import VariablesSelection from Orange.widgets.visualize.utils import VizRankDialog +from Orange.widgets.visualize.utils.component import OWGraphWithAnchors from Orange.widgets.visualize.utils.plotutils import AnchorItem -from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotGraph, InteractiveViewBox -from Orange.widgets.widget import Input, Output -from Orange.widgets import report -from Orange.preprocess.score import ReliefF, 
RReliefF +from Orange.widgets.visualize.utils.widget import OWAnchorProjectionWidget +from Orange.widgets.widget import Input, Msg class LinearProjectionVizRank(VizRankDialog, OWComponent): captionTitle = "Score Plots" - n_attrs = settings.Setting(3) + n_attrs = Setting(3) minK = 10 attrsSelected = Signal([]) @@ -59,11 +56,11 @@ def __init__(self, master): controlWidth=50, alignment=Qt.AlignRight, callback=self._n_attrs_changed) gui.rubber(box) self.last_run_n_attrs = None - self.attr_color = master.graph.attr_color + self.attr_color = master.attr_color def initialize(self): super().initialize() - self.attr_color = self.master.graph.attr_color + self.attr_color = self.master.attr_color def before_running(self): """ @@ -89,7 +86,9 @@ def check_preconditions(self): return False elif not master.btn_vizrank.isEnabled(): return False - self.n_attrs_spin.setMaximum(self.master.n_cont_var) + n_cont_var = len([v for v in master.continuous_variables + if v is not master.attr_color]) + self.n_attrs_spin.setMaximum(n_cont_var) return True def state_count(self): @@ -121,10 +120,11 @@ def combinations(n, s): def compute_score(self, state): master = self.master - _, ec, _ = master.prepare_plot_data([self.attrs[i] for i in state]) + _, ec, _ = master.prepare_projection_data( + [self.attrs[i] for i in state]) y = column_data(master.data, self.attr_color, dtype=float) if ec.shape[0] < self.minK: - return + return None n_neighbors = min(self.minK, len(ec) - 1) knn = NearestNeighbors(n_neighbors=n_neighbors).fit(ec) ind = knn.kneighbors(return_distance=False) @@ -145,7 +145,7 @@ def normalized(col): span = col_max - col_min return (col - col_min) / (span or 1) domain = self.master.data.domain - attr_color = self.master.graph.attr_color + attr_color = self.master.attr_color domain = Domain( attributes=[v for v in chain(domain.variables, domain.metas) if v.is_continuous and v is not attr_color], @@ -177,41 +177,75 @@ def _n_attrs_changed(self): 
self.button.setEnabled(self.check_preconditions()) -class LinProjInteractiveViewBox(InteractiveViewBox): - def _dragtip_pos(self): - return 10, 10 +class OWLinProjGraph(OWGraphWithAnchors): + hide_radius = Setting(0) + + @property + def always_show_axes(self): + return self.master.placement == self.master.Placement.Circular + @property + def scaled_radius(self): + return self.hide_radius / 100 + 1e-5 -class OWLinProjGraph(OWScatterPlotGraph): - jitter_size = settings.Setting(0) + def update_radius(self): + self.update_circle() + self.update_anchors() - def hide_axes(self): - for axis in ["left", "bottom"]: - self.plot_widget.hideAxis(axis) + def set_view_box_range(self): + def min_max(a, b): + return (min(np.amin(a), np.amin(b), -1.05), + max(np.amax(a), np.amax(b), 1.05)) - def update_data(self, attr_x, attr_y, reset_view=True): - axes = self.master.plotdata.axes - axes_x, axes_y = axes[:, 0], axes[:, 1] - x_data, y_data = self.get_xy_data_positions(attr_x, attr_y, self.valid_data) - f = lambda a, b: (min(np.nanmin(a), np.nanmin(b)), max(np.nanmax(a), np.nanmax(b))) - min_x, max_x = f(axes_x, x_data) - min_y, max_y = f(axes_y, y_data) - self.view_box.setRange(QRectF(min_x, min_y, max_x - min_x, max_y - min_y), padding=0.025) - self.view_box.setAspectLocked(True, 1) + points, _ = self.master.get_anchors() + coords = self.master.get_coordinates_data() + if points is None or coords is None: + return - super().update_data(attr_x, attr_y, reset_view=False) - self.hide_axes() + min_x, max_x = min_max(points[:, 0], coords[0]) + min_y, max_y = min_max(points[:, 1], coords[1]) + rect = QRectF(min_x, min_y, max_x - min_x, max_y - min_y) + self.view_box.setRange(rect, padding=0.025) - def update_labels(self): - if self.master.model_selected[:]: - super().update_labels() + def update_anchors(self): + points, labels = self.master.get_anchors() + if points is None: + return + r = self.scaled_radius * np.max(np.linalg.norm(points, axis=1)) + if self.anchor_items is None: + 
self.anchor_items = [] + for point, label in zip(points, labels): + anchor = AnchorItem(line=QLineF(0, 0, *point), text=label) + visible = self.always_show_axes or np.linalg.norm(point) > r + anchor.setVisible(visible) + anchor.setPen(pg.mkPen((100, 100, 100))) + self.plot_widget.addItem(anchor) + self.anchor_items.append(anchor) + else: + for anchor, point, label in zip(self.anchor_items, points, labels): + anchor.setLine(QLineF(0, 0, *point)) + visible = self.always_show_axes or np.linalg.norm(point) > r + anchor.setVisible(visible) - def update_shapes(self): - if self.master.model_selected[:]: - super().update_shapes() + def update_circle(self): + super().update_circle() + if self.always_show_axes: + self.plot_widget.removeItem(self.circle_item) + self.circle_item = None -class OWLinearProjection(widget.OWWidget): + if self.circle_item is not None: + points, _ = self.master.get_anchors() + if points is None: + return + + r = self.scaled_radius * np.max(np.linalg.norm(points, axis=1)) + self.circle_item.setRect(QRectF(-r, -r, 2 * r, 2 * r)) + pen = pg.mkPen(QColor(Qt.lightGray), width=1, cosmetic=True) + self.circle_item.setPen(pen) + + +class OWLinearProjection(OWAnchorProjectionWidget): name = "Linear Projection" description = "A multi-axis projection of data onto " \ "a two-dimensional plane." 
@@ -219,526 +253,295 @@ class OWLinearProjection(widget.OWWidget): priority = 240 keywords = [] - selection_indices = settings.Setting(None, schema_only=True) - - class Inputs: - data = Input("Data", Table, default=True) - data_subset = Input("Data Subset", Table) - projection = Input("Projection", Table) + class Inputs(OWAnchorProjectionWidget.Inputs): + projection_input = Input("Projection", Table) - class Outputs: - selected_data = Output("Selected Data", Table, default=True) - annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) - components = Output("Components", Table) + Placement = Enum("Placement", dict(Circular=0, LDA=1, PCA=2, Projection=3), + type=int, qualname="OWLinearProjection.Placement") - Placement = Enum("Placement", - dict(Circular=0, - LDA=1, - PCA=2, - Projection=3), - type=int, - qualname="OWLinearProjection.Placement") - - Component_name = {Placement.Circular: "C", Placement.LDA: "LD", Placement.PCA: "PC"} + Component_name = {Placement.Circular: "C", Placement.LDA: "LD", + Placement.PCA: "PC"} Variable_name = {Placement.Circular: "circular", Placement.LDA: "lda", Placement.PCA: "pca", Placement.Projection: "projection"} - - jitter_sizes = [0, 0.1, 0.5, 1.0, 2.0] - - settings_version = 3 - settingsHandler = settings.DomainContextHandler() - - variable_state = settings.ContextSetting({}) - placement = settings.Setting(Placement.Circular) - radius = settings.Setting(0) - auto_commit = settings.Setting(True) - - resolution = 256 - - graph = settings.SettingProvider(OWLinProjGraph) - ReplotRequest = QEvent.registerEventType() - vizrank = settings.SettingProvider(LinearProjectionVizRank) - graph_name = "graph.plot_widget.plotItem" - - class Warning(widget.OWWidget.Warning): - no_cont_features = widget.Msg("Plotting requires numeric features") - not_enough_components = widget.Msg("Input projection has less than 2 components") - trivial_components = widget.Msg( - "All components of the PCA are trivial (explain 0 variance). 
" + Projection_name = {Placement.Circular: "Circular Placement", + Placement.LDA: "Linear Discriminant Analysis", + Placement.PCA: "Principal Component Analysis", + Placement.Projection: "Use input projection"} + + settings_version = 4 + + placement = Setting(Placement.Circular) + selected_vars = ContextSetting([]) + vizrank = SettingProvider(LinearProjectionVizRank) + GRAPH_CLASS = OWLinProjGraph + graph = SettingProvider(OWLinProjGraph) + + class Warning(OWAnchorProjectionWidget.Warning): + not_enough_comp = Msg("Input projection has less than two components") + trivial_components = Msg( + "All components of the PCA are trivial (explain zero variance). " "Input data is constant (or near constant).") - class Error(widget.OWWidget.Error): - proj_and_domain_match = widget.Msg("Projection and Data domains do not match") - no_valid_data = widget.Msg("No projection due to invalid data") + class Error(OWAnchorProjectionWidget.Error): + no_cont_features = Msg("Plotting requires numeric features") + proj_and_domain_match = Msg("Projection and Data domains do not match") def __init__(self): - super().__init__() - - self.data = None - self.projection = None - self.subset_data = None - self._subset_mask = None - self._selection = None - self.__replot_requested = False - self.n_cont_var = 0 - #: Remember the saved state to restore - self.__pending_selection_restore = self.selection_indices - self.selection_indices = None - - self.variable_x = None - self.variable_y = None - - box = gui.vBox(self.mainArea, True, margin=0) - self.graph = OWLinProjGraph(self, box, "Plot", view_box=LinProjInteractiveViewBox) - box.layout().addWidget(self.graph.plot_widget) - plot = self.graph.plot_widget - - SIZE_POLICY = (QSizePolicy.Minimum, QSizePolicy.Maximum) - - self.variables_selection = VariablesSelection() self.model_selected = VariableListModel(enable_dnd=True) + self.model_selected.rowsInserted.connect(self.__model_selected_changed) + 
self.model_selected.rowsRemoved.connect(self.__model_selected_changed) self.model_other = VariableListModel(enable_dnd=True) - self.variables_selection(self, self.model_selected, self.model_other) self.vizrank, self.btn_vizrank = LinearProjectionVizRank.add_vizrank( - self.controlArea, self, "Suggest Features", self._vizrank) - self.variables_selection.add_remove.layout().addWidget(self.btn_vizrank) + None, self, "Suggest Features", self.__vizrank_set_attrs) + super().__init__() + self.projection_input = None + self.variables = None + + def _add_controls(self): + self._add_controls_variables() + self._add_controls_placement() + super()._add_controls() + self.graph.gui.add_control( + self._effects_box, gui.hSlider, "Hide radius:", master=self.graph, + value="hide_radius", minValue=0, maxValue=100, step=10, + createLabel=False, callback=self.__radius_slider_changed + ) + self.controlArea.layout().removeWidget(self.control_area_stretch) + self.control_area_stretch.setParent(None) + + def _add_controls_variables(self): + self.variables_selection = VariablesSelection( + self, self.model_selected, self.model_other, self.controlArea + ) + self.variables_selection.add_remove.layout().addWidget( + self.btn_vizrank + ) + + def _add_controls_placement(self): box = gui.widgetBox( - self.controlArea, "Placement", sizePolicy=SIZE_POLICY) + self.controlArea, True, + sizePolicy=(QSizePolicy.Minimum, QSizePolicy.Maximum) + ) self.radio_placement = gui.radioButtonsInBox( box, self, "placement", - btnLabels=["Circular Placement", - "Linear Discriminant Analysis", - "Principal Component Analysis", - "Use input projection"], - callback=self._change_placement + btnLabels=[self.Projection_name[x] for x in self.Placement], + callback=self.__placement_radio_changed ) - self.viewbox = plot.getViewBox() - self.replot = None + @property + def continuous_variables(self): + if self.data is None or self.data.domain is None: + return [] + dom = self.data.domain + return [v for v in 
chain(dom.variables, dom.metas) if v.is_continuous] - g = self.graph.gui - box = g.point_properties_box(self.controlArea) - self.models = g.points_models - g.add_widget(g.JitterSizeSlider, box) - box.setSizePolicy(*SIZE_POLICY) - - box = gui.widgetBox(self.controlArea, "Hide axes", sizePolicy=SIZE_POLICY) - self.rslider = gui.hSlider( - box, self, "radius", minValue=0, maxValue=100, - step=5, label="Radius", createLabel=False, ticks=True, - callback=self.update_radius) - self.rslider.setTickInterval(0) - self.rslider.setPageStep(10) - - box = gui.vBox(self.controlArea, "Plot Properties") - box.setSizePolicy(*SIZE_POLICY) - - g.add_widgets([g.ShowLegend, - g.ToolTipShowsAll, - g.ClassDensity, - g.LabelOnlySelected], box) - - box = self.graph.box_zoom_select(self.controlArea) - box.setSizePolicy(*SIZE_POLICY) - - self.icons = gui.attributeIconDict - - gui.auto_commit(self.controlArea, self, "auto_commit", "Send Selection", - auto_label="Send Automatically") - self.graph.zoom_actions(self) - - self._new_plotdata() - self._change_placement() - self.graph.jitter_continuous = True - - def reset_graph_data(self): - if self.data is not None: - self.graph.rescale_data() - self._update_graph(reset_view=True) - - def keyPressEvent(self, event): - super().keyPressEvent(event) - self.graph.update_tooltip(event.modifiers()) - - def keyReleaseEvent(self, event): - super().keyReleaseEvent(event) - self.graph.update_tooltip(event.modifiers()) - - def _vizrank(self, attrs): - self.variables_selection.display_none() - self.model_selected[:] = attrs[:] - self.model_other[:] = [var for var in self.model_other if var not in attrs] - - def _change_placement(self): - placement = self.placement - p_Circular = self.Placement.Circular - p_LDA = self.Placement.LDA - self.variables_selection.set_enabled(placement in [p_Circular, p_LDA]) - self._vizrank_color_change() - self.rslider.setEnabled(placement != p_Circular) - self._setup_plot() - self.commit() - - def _get_min_radius(self): - return 
self.radius * np.max(np.linalg.norm(self.plotdata.axes, axis=1)) / 100 + 1e-5 - - def update_radius(self): - # Update the anchor/axes visibility - pd = self.plotdata - assert pd is not None - if pd.hidecircle is None: + def __vizrank_set_attrs(self, attrs): + if not attrs: return - min_radius = self._get_min_radius() - for anchor, item in zip(pd.axes, pd.axisitems): - item.setVisible(np.linalg.norm(anchor) > min_radius) - pd.hidecircle.setRect(QRectF(-min_radius, -min_radius, 2 * min_radius, 2 * min_radius)) - - def _new_plotdata(self): - self.plotdata = namespace( - valid_mask=None, - embedding_coords=None, - axisitems=[], - axes=[], - variables=[], - data=None, - hidecircle=None - ) - - def _anchor_circle(self, variables): - # minimum visible anchor radius (radius) - min_radius = self._get_min_radius() - axisitems = [] - for anchor, var in zip(self.plotdata.axes, variables[:]): - axitem = AnchorItem(line=QLineF(0, 0, *anchor), text=var.name,) - axitem.setVisible(np.linalg.norm(anchor) > min_radius) - axitem.setPen(pg.mkPen((100, 100, 100))) - axitem.setArrowVisible(True) - self.viewbox.addItem(axitem) - axisitems.append(axitem) - - self.plotdata.axisitems = axisitems - if self.placement == self.Placement.Circular: - return - - hidecircle = QGraphicsEllipseItem() - hidecircle.setRect(QRectF(-min_radius, -min_radius, 2 * min_radius, 2 * min_radius)) - - _pen = QPen(Qt.lightGray, 1) - _pen.setCosmetic(True) - hidecircle.setPen(_pen) + self.model_selected[:] = attrs[:] + self.model_other[:] = [var for var in self.continuous_variables + if var not in attrs] - self.viewbox.addItem(hidecircle) - self.plotdata.hidecircle = hidecircle + def __model_selected_changed(self): + self.selected_vars = [(var.name, vartype(var)) for var + in self.model_selected] + self.projection = None + self.variables = None + self._check_options() + self.setup_plot() + self.commit() - def update_colors(self): - self._vizrank_color_change() + def __placement_radio_changed(self): + 
self.variables_selection.set_enabled( + self.placement in [self.Placement.Circular, self.Placement.LDA]) + self.controls.graph.hide_radius.setEnabled( + self.placement != self.Placement.Circular) + self.projection = None + self.variables = None + self._init_vizrank() + self.setup_plot() + self.commit() - def clear(self): - # Clear/reset the widget state - self.data = None - self.model_selected.clear() - self.model_other.clear() - self._clear_plot() - self.selection_indices = None - - def _clear_plot(self): - self.Warning.trivial_components.clear() - for axisitem in self.plotdata.axisitems: - self.viewbox.removeItem(axisitem) - if self.plotdata.hidecircle: - self.viewbox.removeItem(self.plotdata.hidecircle) - self._new_plotdata() - self.graph.hide_axes() - - def invalidate_plot(self): - """ - Schedule a delayed replot. - """ - if not self.__replot_requested: - self.__replot_requested = True - QApplication.postEvent(self, QEvent(self.ReplotRequest), Qt.LowEventPriority - 10) + def __radius_slider_changed(self): + self.graph.update_radius() - def init_attr_values(self): - self.graph.set_domain(self.data) + def colors_changed(self): + super().colors_changed() + self._init_vizrank() - def _vizrank_color_change(self): - is_enabled = False + def set_data(self, data): + super().set_data(data) + if self.data is not None and len(self.selected_vars): + d, selected = self.data.domain, [v[0] for v in self.selected_vars] + self.model_selected[:] = [d[attr] for attr in selected] + self.model_other[:] = [d[attr.name] for attr in + self.continuous_variables + if attr.name not in selected] + elif self.data is not None: + self.model_selected[:] = self.continuous_variables[:3] + self.model_other[:] = self.continuous_variables[3:] + + self._check_options() + self._init_vizrank() + + def _check_options(self): + buttons = self.radio_placement.buttons + for btn in buttons: + btn.setEnabled(True) + if self.data is not None: + has_discrete_class = self.data.domain.has_discrete_class + if 
not has_discrete_class or len(np.unique(self.data.Y)) < 2: + buttons[self.Placement.LDA].setEnabled(False) + if self.placement == self.Placement.LDA: + self.placement = self.Placement.Circular + if not self.projection_input: + buttons[self.Placement.Projection].setEnabled(False) + if self.placement == self.Placement.Projection: + self.placement = self.Placement.Circular + + self.variables_selection.set_enabled( + self.placement in [self.Placement.Circular, self.Placement.LDA]) + self.controls.graph.hide_radius.setEnabled( + self.placement != self.Placement.Circular) + + def _init_vizrank(self): + is_enabled, msg = False, "" if self.data is None: - self.btn_vizrank.setToolTip("There is no data.") - return - vars = [v for v in chain(self.data.domain.variables, self.data.domain.metas) if - v.is_primitive and v is not self.graph.attr_color] - self.n_cont_var = len(vars) - if self.placement not in [self.Placement.Circular, self.Placement.LDA]: + msg = "There is no data." + elif self.placement not in [self.Placement.Circular, + self.Placement.LDA]: msg = "Suggest Features works only for Circular and " \ "Linear Discriminant Analysis Projection" - elif self.graph.attr_color is None: + elif self.attr_color is None: msg = "Color variable has to be selected" - elif self.graph.attr_color.is_continuous and self.placement == self.Placement.LDA: - msg = "Suggest Features does not work for Linear Discriminant Analysis Projection " \ - "when continuous color variable is selected." - elif len(vars) < 3: + elif self.attr_color.is_continuous and \ + self.placement == self.Placement.LDA: + msg = "Suggest Features does not work for Linear " \ + "Discriminant Analysis Projection when " \ + "continuous color variable is selected." 
+ elif len([v for v in self.continuous_variables + if v is not self.attr_color]) < 3: msg = "Not enough available continuous variables" + elif len(self.data[self.valid_data]) < 2: + msg = "Not enough valid data instances" else: - is_enabled = True - msg = "" + is_enabled = not np.isnan(self.data.get_column_view( + self.attr_color)[0].astype(float)).all() self.btn_vizrank.setToolTip(msg) self.btn_vizrank.setEnabled(is_enabled) - self.vizrank.stop_and_reset(is_enabled) + if is_enabled: + self.vizrank.initialize() + + def check_data(self): + def error(err): + err() + self.data = None + + super().check_data() + if self.data is not None: + if not len(self.continuous_variables): + error(self.Error.no_cont_features) - @Inputs.projection + def init_attr_values(self): + super().init_attr_values() + self.selected_vars = [] + + @Inputs.projection_input def set_projection(self, projection): - self.Warning.not_enough_components.clear() + self.Warning.not_enough_comp.clear() if projection and len(projection) < 2: - self.Warning.not_enough_components() + self.Warning.not_enough_comp() projection = None if projection is not None: self.placement = self.Placement.Projection - self.projection = projection + self.projection_input = projection + self._check_options() - @Inputs.data - def set_data(self, data): - """ - Set the input dataset. 
+ def get_embedding(self): + self.valid_data = None + if self.data is None or not self.variables: + return None - Args: - data (Orange.data.table): data instances - """ - def sql(data): - if isinstance(data, SqlTable): - if data.approx_len() < 4000: - data = Table(data) - else: - self.information("Data has been sampled") - data_sample = data.sample_time(1, no_cache=True) - data_sample.download_data(2000, partial=True) - data = Table(data_sample) - return data - - def settings(data): - # get the default encoded state, replacing the position with Inf - state = VariablesSelection.encode_var_state( - [list(self.model_selected), list(self.model_other)] - ) - state = {key: (source_ind, np.inf) for key, (source_ind, _) in state.items()} - - self.openContext(data.domain) - selected_keys = [key for key, (sind, _) in self.variable_state.items() if sind == 0] - - if set(selected_keys).issubset(set(state.keys())): - pass - - if self.__pending_selection_restore is not None: - self._selection = np.array(self.__pending_selection_restore, dtype=int) - self.__pending_selection_restore = None - - # update the defaults state (the encoded state must contain - # all variables in the input domain) - state.update(self.variable_state) - # ... 
and restore it with saved positions taking precedence over - # the defaults - selected, other = VariablesSelection.decode_var_state( - state, [list(self.model_selected), list(self.model_other)]) - return selected, other - - self.closeContext() - self.clear() - self.Warning.no_cont_features.clear() - self.information() - data = sql(data) - if data is not None: - domain = data.domain - vars = [var for var in chain(domain.variables, domain.metas) if var.is_continuous] - if not len(vars): - self.Warning.no_cont_features() - data = None - self.data = data - self.init_attr_values() - if data is not None and len(data): - self._initialize(data) - self.model_selected[:], self.model_other[:] = settings(data) - self.vizrank.stop_and_reset() - self.vizrank.attrs = self.data.domain.attributes if self.data is not None else [] - - def _check_possible_opt(self): - def set_enabled(is_enabled): - for btn in self.radio_placement.buttons: - btn.setEnabled(is_enabled) - self.variables_selection.set_enabled(is_enabled) - - p_Circular = self.Placement.Circular - p_LDA = self.Placement.LDA - p_Input = self.Placement.Projection - if self.data: - set_enabled(True) - domain = self.data.domain - if not domain.has_discrete_class or len(domain.class_var.values) < 2: - self.radio_placement.buttons[p_LDA].setEnabled(False) - if self.placement == p_LDA: - self.placement = p_Circular - if not self.projection: - self.radio_placement.buttons[p_Input].setEnabled(False) - if self.placement == p_Input: - self.placement = p_Circular - self._setup_plot() - else: - self.graph.new_data(None) - self.rslider.setEnabled(False) - set_enabled(False) - self.commit() - - @Inputs.data_subset - def set_subset_data(self, subset): - """ - Set the supplementary input subset dataset. 
- - Args: - subset (Orange.data.table): subset of data instances - """ - self.subset_data = subset - self._subset_mask = None - self.controls.graph.alpha_value.setEnabled(subset is None) - - def handleNewSignals(self): - if self.data is not None and self.subset_data is not None: - # Update the plot's highlight items - dataids = self.data.ids.ravel() - subsetids = np.unique(self.subset_data.ids) - self._subset_mask = np.in1d(dataids, subsetids, assume_unique=True) - self._check_possible_opt() - self._change_placement() - self.commit() - - def customEvent(self, event): - if event.type() == OWLinearProjection.ReplotRequest: - self.__replot_requested = False - self._setup_plot() - self.commit() + if self.placement == self.Placement.PCA: + self.valid_data, ec, self.projection = self._get_pca() + self.variables = self._pca.orig_domain.attributes else: - super().customEvent(event) + self.valid_data, ec, self.projection = \ + self.prepare_projection_data(self.variables) - def closeContext(self): - self.variable_state = VariablesSelection.encode_var_state( - [list(self.model_selected), list(self.model_other)] - ) - super().closeContext() - - def _initialize(self, data): - # Initialize the GUI controls from data's domain. 
- vars = [v for v in chain(data.domain.metas, data.domain.attributes) if v.is_continuous] - self.model_other[:] = vars[3:] - self.model_selected[:] = vars[:3] - - def prepare_plot_data(self, variables): - def projection(variables): - if set(self.projection.domain.attributes).issuperset(variables): - axes = self.projection[:2, variables].X - elif set(f.name for f in - self.projection.domain.attributes).issuperset(f.name for f in variables): - axes = self.projection[:2, [f.name for f in variables]].X + self.Error.no_valid_data.clear() + if self.valid_data is None or not sum(self.valid_data) or \ + self.projection is None or ec is None: + self.Error.no_valid_data() + return None + + embedding = np.zeros((len(self.data), 2), dtype=np.float) + embedding[self.valid_data] = ec + return embedding + + def prepare_projection_data(self, variables): + def projection(_vars): + attrs = self.projection_input.domain.attributes + if set(attrs).issuperset(_vars): + return self.projection_input[:2, _vars].X + elif set(f.name for f in attrs).issuperset(f.name for f in _vars): + return self.projection_input[:2, [f.name for f in _vars]].X else: self.Error.proj_and_domain_match() - axes = None - return axes + return None - def get_axes(variables): + def get_axes(_vars): self.Error.proj_and_domain_match.clear() - axes = None if self.placement == self.Placement.Circular: - axes = LinProj.defaultaxes(len(variables)) + return LinProj.defaultaxes(len(_vars)) elif self.placement == self.Placement.LDA: - axes = self._get_lda(self.data, variables) - elif self.placement == self.Placement.Projection and self.projection: - axes = projection(variables) - return axes - - coords = [column_data(self.data, var, dtype=float) for var in variables] - coords = np.vstack(coords) - p, N = coords.shape - assert N == len(self.data), p == len(variables) + return self._get_lda(self.data, _vars) + elif self.placement == self.Placement.Projection and \ + self.projection_input is not None: + return projection(_vars) 
+ else: + return None + coords = np.vstack(column_data(self.data, v, float) for v in variables) axes = get_axes(variables) if axes is None: return None, None, None - assert axes.shape == (2, p) valid_mask = ~np.isnan(coords).any(axis=0) - coords = coords[:, valid_mask] - - X, Y = np.dot(axes, coords) + X, Y = np.dot(axes, coords[:, valid_mask]) if X.size and Y.size: X = normalized(X) Y = normalized(Y) - return valid_mask, np.stack((X, Y), axis=1), axes.T - def _setup_plot(self): - self._clear_plot() + def get_anchors(self): + if self.projection is None: + return None, None + return self.projection, [v.name for v in self.variables] + + def setup_plot(self): + self.init_projection_variables() + super().setup_plot() + + def init_projection_variables(self): + self.variables = None if self.data is None: return - self.__replot_requested = False - names = get_unique_names([v.name for v in chain(self.data.domain.variables, - self.data.domain.metas)], - ["{}-x".format(self.Variable_name[self.placement]), - "{}-y".format(self.Variable_name[self.placement])]) - self.variable_x = ContinuousVariable(names[0]) - self.variable_y = ContinuousVariable(names[1]) + if self.placement in [self.Placement.Circular, self.Placement.LDA]: - variables = list(self.model_selected) + self.variables = self.model_selected[:] elif self.placement == self.Placement.Projection: - variables = self.model_selected[:] + self.model_other[:] + self.variables = self.model_selected[:] + self.model_other[:] elif self.placement == self.Placement.PCA: - variables = [var for var in self.data.domain.attributes if var.is_continuous] - if not variables: - self.graph.new_data(None) - return - if self.placement == self.Placement.PCA: - valid_mask, ec, axes = self._get_pca() - variables = self._pca.orig_domain.attributes - else: - valid_mask, ec, axes = self.prepare_plot_data(variables) - - self.plotdata.variables = variables - self.plotdata.valid_mask = valid_mask - self.plotdata.embedding_coords = ec - 
self.plotdata.axes = axes - if any(e is None for e in (valid_mask, ec, axes)): - return - - if not sum(valid_mask): - self.Error.no_valid_data() - self.graph.new_data(None, None) - return - self.Error.no_valid_data.clear() - - self._anchor_circle(variables=variables) - self._plot() - - def _plot(self): - domain = self.data.domain - new_metas = domain.metas + (self.variable_x, self.variable_y) - domain = Domain(attributes=domain.attributes, class_vars=domain.class_vars, metas=new_metas) - valid_mask = self.plotdata.valid_mask - array = np.zeros((len(self.data), 2), dtype=np.float) - array[valid_mask] = self.plotdata.embedding_coords - self.plotdata.data = data = self.data.transform(domain) - data[:, self.variable_x] = array[:, 0].reshape(-1, 1) - data[:, self.variable_y] = array[:, 1].reshape(-1, 1) - subset_data = data[self._subset_mask & valid_mask]\ - if self._subset_mask is not None and len(self._subset_mask) else None - self.plotdata.data = data - self.graph.new_data(data[valid_mask], subset_data) - if self._selection is not None: - self.graph.selection = self._selection[valid_mask] - self.graph.update_data(self.variable_x, self.variable_y, False) + self.variables = [var for var in self.data.domain.attributes + if var.is_continuous] def _get_lda(self, data, variables): - domain = Domain(attributes=variables, class_vars=data.domain.class_vars) - data = data.transform(domain) + data = data.transform(Domain(variables, data.domain.class_vars)) lda = LinearDiscriminantAnalysis(solver='eigen', n_components=2) lda.fit(data.X, data.Y) scalings = lda.scalings_[:, :2].T @@ -747,16 +550,11 @@ def _get_lda(self, data, variables): return scalings def _get_pca(self): - data = self.data - MAX_COMPONENTS = 2 - ncomponents = 2 - DECOMPOSITIONS = [PCA] # TruncatedSVD - cls = DECOMPOSITIONS[0] - pca_projector = cls(n_components=MAX_COMPONENTS) - pca_projector.component = ncomponents - pca_projector.preprocessors = cls.preprocessors + [Normalize()] - - pca = pca_projector(data) + 
pca_projector = PCA(n_components=2) + pca_projector.component = 2 + pca_projector.preprocessors = PCA.preprocessors + [Normalize()] + + pca = pca_projector(self.data) variance_ratio = pca.explained_variance_ratio_ cumulative = np.cumsum(variance_ratio) @@ -764,108 +562,62 @@ def _get_pca(self): if not np.isfinite(cumulative[-1]): self.Warning.trivial_components() - coords = pca(data).X + coords = pca(self.data).X valid_mask = ~np.isnan(coords).any(axis=1) # scale axes - max_radius = np.min([np.abs(np.min(coords, axis=0)), np.max(coords, axis=0)]) + max_radius = np.min([np.abs(np.min(coords, axis=0)), + np.max(coords, axis=0)]) axes = pca.components_.T.copy() axes *= max_radius / np.max(np.linalg.norm(axes, axis=1)) return valid_mask, coords, axes - def _update_graph(self, reset_view=False): - self.graph.zoomStack = [] - if self.graph.data is None: - return - self.graph.update_data(self.variable_x, self.variable_y, reset_view) - - def update_density(self): - self._update_graph(reset_view=False) - - def selection_changed(self): - if self.graph.selection is not None: - self._selection = np.zeros(len(self.data), dtype=np.uint8) - self._selection[self.plotdata.valid_mask] = self.graph.selection - self.selection_indices = self._selection.tolist() - else: - self._selection = self.selection_indices = None - self.commit() - - def prepare_data(self): - pass - - def commit(self): - def prepare_components(): + def send_components(self): + components = None + if self.data is not None and self.valid_data is not None and \ + self.projection is not None: if self.placement in [self.Placement.Circular, self.Placement.LDA]: - attrs = [a for a in self.model_selected[:]] - axes = self.plotdata.axes + axes = self.projection + attrs = self.model_selected elif self.placement == self.Placement.PCA: axes = self._pca.components_.T - attrs = [a for a in self._pca.orig_domain.attributes] + attrs = self._pca.orig_domain.attributes if self.placement != self.Placement.Projection: - domain = 
Domain([ContinuousVariable(a.name, compute_value=lambda _: None) - for a in attrs], - metas=[StringVariable(name='component')]) - metas = np.array([["{}{}".format(self.Component_name[self.placement], i + 1) - for i in range(axes.shape[1])]], - dtype=object).T - components = Table(domain, axes.T, metas=metas) - components.name = 'components' - else: - components = self.projection - return components - - selected = annotated = components = None - if self.data is not None and self.plotdata.data is not None: - components = prepare_components() - - graph = self.graph - mask = self.plotdata.valid_mask.astype(int) - mask[mask == 1] = graph.selection if graph.selection is not None \ - else [False * len(mask)] - - selection = np.array([], dtype=np.uint8) if mask is None else np.flatnonzero(mask) - name = self.data.name - data = self.plotdata.data - if len(selection): - selected = data[selection] - selected.name = name + ": selected" - selected.attributes = self.data.attributes - - if graph.selection is not None and np.max(graph.selection) > 1: - annotated = create_groups_table(data, mask) + meta_attrs = [StringVariable(name='component')] + metas = np.array( + [["{}{}".format(self.Component_name[self.placement], i + 1) + for i in range(axes.shape[1])]], dtype=object).T + components = Table(Domain(attrs, metas=meta_attrs), + axes.T, metas=metas) + components.name = self.data.name else: - annotated = create_annotated_table(data, selection) - annotated.attributes = self.data.attributes - annotated.name = name + ": annotated" - - self.Outputs.selected_data.send(selected) - self.Outputs.annotated_data.send(annotated) + components = self.projection_input self.Outputs.components.send(components) - def send_report(self): - if self.data is None: - return - - def name(var): - return var and var.name + def _get_projection_variables(self): + pn = self.Variable_name[self.placement] + self.embedding_variables_names = ("{}-x".format(pn), "{}-y".format(pn)) + return 
super()._get_projection_variables() + def _get_send_report_caption(self): def projection_name(): - name = ("Circular Placement", - "Linear Discriminant Analysis", - "Principal Component Analysis", - "Input projection") - return name[self.placement] + return self.Projection_name[self.placement] - caption = report.render_items_vert(( + return report.render_items_vert(( ("Projection", projection_name()), - ("Color", name(self.graph.attr_color)), - ("Label", name(self.graph.attr_label)), - ("Shape", name(self.graph.attr_shape)), - ("Size", name(self.graph.attr_size)), - ("Jittering", self.graph.jitter_size != 0 and "{} %".format(self.graph.jitter_size)))) - self.report_plot() - if caption: - self.report_caption(caption) + ("Color", self._get_caption_var_name(self.attr_color)), + ("Label", self._get_caption_var_name(self.attr_label)), + ("Shape", self._get_caption_var_name(self.attr_shape)), + ("Size", self._get_caption_var_name(self.attr_size)), + ("Jittering", self.graph.jitter_size != 0 and + "{} %".format(self.graph.jitter_size)))) + + def clear(self): + self.variables = None + if self.model_selected: + self.model_selected.clear() + if self.model_other: + self.model_other.clear() + super().clear() @classmethod def migrate_settings(cls, settings_, version): @@ -878,6 +630,14 @@ def migrate_settings(cls, settings_, version): settings_graph["alpha_value"] = settings_["alpha_value"] settings_graph["class_density"] = settings_["class_density"] settings_["graph"] = settings_graph + if version < 4: + if "radius" in settings_: + settings_["graph"]["hide_radius"] = settings_["radius"] + if "selection_indices" in settings_ and \ + settings_["selection_indices"] is not None: + selection = settings_["selection_indices"] + settings_["selection"] = [(i, 1) for i, selected in + enumerate(selection) if selected] @classmethod def migrate_context(cls, context, version): @@ -897,6 +657,12 @@ def migrate_context(cls, context, version): "attr_shape": context.values["attr_shape"], 
"attr_size": context.values["attr_size"] } + if version == 3: + values = context.values + values["attr_color"] = values["graph"]["attr_color"] + values["attr_size"] = values["graph"]["attr_size"] + values["attr_shape"] = values["graph"]["attr_shape"] + values["attr_label"] = values["graph"]["attr_label"] def column_data(table, var, dtype): @@ -937,6 +703,7 @@ def defaultaxes(naxes): def project(axes, X): return np.dot(axes, X) + def normalized(a): if not a.size: return a.copy() @@ -950,7 +717,6 @@ def normalized(a): def main(argv=None): import sys - import sip argv = sys.argv[1:] if argv is None else argv if argv: @@ -967,12 +733,10 @@ def main(argv=None): w.handleNewSignals() w.show() w.raise_() - r = app.exec() + app.exec() w.set_data(None) w.saveSettings() - sip.delete(w) del w - return r if __name__ == "__main__": diff --git a/Orange/widgets/visualize/owradviz.py b/Orange/widgets/visualize/owradviz.py index 6169154da9d..52cdb6f2d3a 100644 --- a/Orange/widgets/visualize/owradviz.py +++ b/Orange/widgets/visualize/owradviz.py @@ -8,40 +8,30 @@ from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor from AnyQt.QtGui import QStandardItem, QColor -from AnyQt.QtCore import ( - Qt, QEvent, QRectF, QPoint, pyqtSignal as Signal -) -from AnyQt.QtWidgets import ( - qApp, QApplication, QToolTip, QGraphicsEllipseItem -) +from AnyQt.QtCore import Qt, QRectF, QPoint, pyqtSignal as Signal +from AnyQt.QtWidgets import qApp, QApplication import pyqtgraph as pg from pyqtgraph.graphicsItems.ScatterPlotItem import ScatterPlotItem -from Orange.data import Table, Domain, ContinuousVariable, StringVariable -from Orange.data.sql.table import SqlTable +from Orange.data import Table, Domain, StringVariable from Orange.preprocess.score import ReliefF, RReliefF from Orange.projection import radviz -from Orange.widgets import widget, gui, settings, report +from Orange.widgets import widget, gui from Orange.widgets.gui import OWComponent -from Orange.widgets.settings import 
Setting -from Orange.widgets.utils.annotated_data import ( - create_annotated_table, ANNOTATED_DATA_SIGNAL_NAME, create_groups_table -) +from Orange.widgets.settings import Setting, ContextSetting, SettingProvider +from Orange.widgets.utils import vartype from Orange.widgets.utils.itemmodels import VariableListModel from Orange.widgets.utils.plot import VariablesSelection -from Orange.widgets.visualize.owscatterplotgraph import OWProjectionWidget from Orange.widgets.visualize.utils import VizRankDialog -from Orange.widgets.visualize.utils.component import OWVizGraph -from Orange.widgets.visualize.utils.plotutils import ( - TextItem, VizInteractiveViewBox -) -from Orange.widgets.widget import Input, Output +from Orange.widgets.visualize.utils.component import OWGraphWithAnchors +from Orange.widgets.visualize.utils.plotutils import TextItem +from Orange.widgets.visualize.utils.widget import OWAnchorProjectionWidget class RadvizVizRank(VizRankDialog, OWComponent): captionTitle = "Score Plots" - n_attrs = settings.Setting(3) + n_attrs = Setting(3) minK = 10 attrsSelected = Signal([]) @@ -208,444 +198,237 @@ def stopped(self): self.n_attrs_spin.setDisabled(False) -class RadvizInteractiveViewBox(VizInteractiveViewBox): - def mouseDragEvent(self, ev, axis=None): - super().mouseDragEvent(ev, axis) - if ev.finish: - self.setCursor(Qt.ArrowCursor) - self.graph.show_indicator(None) - - def _show_tooltip(self, ev): - pos = self.childGroup.mapFromParent(ev.pos()) - angle = np.arctan2(pos.y(), pos.x()) - point = QPoint(ev.screenPos().x(), ev.screenPos().y()) - QToolTip.showText(point, "{:.2f}".format(np.rad2deg(angle))) - - -class OWRadvizGraph(OWVizGraph): +class OWRadvizGraph(OWGraphWithAnchors): def __init__(self, scatter_widget, parent): - super().__init__(scatter_widget, parent, RadvizInteractiveViewBox) - self._text_items = [] + super().__init__(scatter_widget, parent) + self.anchors_scatter_item = None - def set_point(self, i, x, y): - angle = np.arctan2(y, x) - 
super().set_point(i, np.cos(angle), np.sin(angle)) + def clear(self): + super().clear() + self.anchors_scatter_item = None def set_view_box_range(self): - self.view_box.setRange(RANGE, padding=0.025) - - def can_show_indicator(self, pos): - if self._points is None: - return False, None + self.view_box.setRange(QRectF(-1.2, -1.05, 2.4, 2.1), padding=0.025) + def closest_draggable_item(self, pos): + points, _ = self.master.get_anchors() + if points is None: + return None np_pos = np.array([[pos.x(), pos.y()]]) - distances = distance.cdist(np_pos, self._points[:, :2])[0] + distances = distance.cdist(np_pos, points[:, :2])[0] if len(distances) and np.min(distances) < self.DISTANCE_DIFF: - return True, np.argmin(distances) - return False, None - - def update_items(self): - super().update_items() - self._update_text_items() - - def _update_text_items(self): - self._remove_text_items() - self._add_text_items() + return np.argmin(distances) + return None - def _remove_text_items(self): - for item in self._text_items: - self.plot_widget.removeItem(item) - self._text_items = [] - - def _add_text_items(self): - if self._points is None: - return - for point in self._points: - ti = TextItem() - ti.setText(point[2].name) - ti.setColor(QColor(0, 0, 0)) - ti.setPos(point[0], point[1]) - self._text_items.append(ti) - self.plot_widget.addItem(ti) - - def _add_point_items(self): - if self._points is None: + def update_anchors(self): + points, labels = self.master.get_anchors() + if points is None: return - x, y = self._points[:, 0], self._points[:, 1] - self._point_items = ScatterPlotItem(x=x, y=y) - self.plot_widget.addItem(self._point_items) - - def _add_circle_item(self): - if self._points is None: - return - self._circle_item = QGraphicsEllipseItem() - self._circle_item.setRect(QRectF(-1., -1., 2., 2.)) - self._circle_item.setPen(pg.mkPen(QColor(0, 0, 0), width=2)) - self.plot_widget.addItem(self._circle_item) - - def _add_indicator_item(self, point_i): - if point_i is None: + if 
self.anchor_items is None: + self.anchor_items = [] + for point, label in zip(points, labels): + anchor = TextItem() + anchor.setText(label) + anchor.setColor(QColor(0, 0, 0)) + anchor.setPos(*point) + self.plot_widget.addItem(anchor) + self.anchor_items.append(anchor) + else: + for anchor, point in zip(self.anchor_items, points): + anchor.setPos(*point) + self._update_anchors_scatter_item(points) + + def _update_anchors_scatter_item(self, points): + if self.anchors_scatter_item is not None: + self.plot_widget.removeItem(self.anchors_scatter_item) + self.anchors_scatter_item = None + self.anchors_scatter_item = ScatterPlotItem(x=points[:, 0], + y=points[:, 1]) + self.plot_widget.addItem(self.anchors_scatter_item) + + def _add_indicator_item(self, anchor_idx): + if anchor_idx is None: return - x, y = self._points[point_i][:2] + x, y = self.anchor_items[anchor_idx].get_xy() col = self.view_box.mouse_state dx = (self.view_box.childGroup.mapToDevice(QPoint(1, 0)) - self.view_box.childGroup.mapToDevice(QPoint(-1, 0))).x() - self._indicator_item = MoveIndicator(np.arctan2(y, x), col, 6000 / dx) - self.plot_widget.addItem(self._indicator_item) - + self.indicator_item = MoveIndicator(np.arctan2(y, x), col, 6000 / dx) + self.plot_widget.addItem(self.indicator_item) -RANGE = QRectF(-1.2, -1.05, 2.4, 2.1) -MAX_POINTS = 100 - -class OWRadviz(OWProjectionWidget): +class OWRadviz(OWAnchorProjectionWidget): name = "Radviz" description = "Display Radviz projection" icon = "icons/Radviz.svg" priority = 241 keywords = ["viz"] - class Inputs: - data = Input("Data", Table, default=True) - data_subset = Input("Data Subset", Table) - - class Outputs: - selected_data = Output("Selected Data", Table, default=True) - annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) - components = Output("Components", Table) - settings_version = 2 - settingsHandler = settings.DomainContextHandler() - - variable_state = settings.ContextSetting({}) - auto_commit = settings.Setting(True) - vizrank = 
settings.SettingProvider(RadvizVizRank) - graph = settings.SettingProvider(OWRadvizGraph) - graph_name = "graph.plot_widget.plotItem" + selected_vars = ContextSetting([]) + vizrank = SettingProvider(RadvizVizRank) + GRAPH_CLASS = OWRadvizGraph + graph = SettingProvider(OWRadvizGraph) + embedding_variables_names = ("radviz-x", "radviz-y") - ReplotRequest = QEvent.registerEventType() - - class Information(OWProjectionWidget.Information): - sql_sampled_data = widget.Msg("Data has been sampled") - - class Warning(OWProjectionWidget.Warning): - no_features = widget.Msg("At least 2 features have to be chosen") + class Warning(OWAnchorProjectionWidget.Warning): + no_features = widget.Msg("Radviz requires at least two features.") invalid_embedding = widget.Msg("No projection for selected features") - class Error(OWProjectionWidget.Error): - sparse_data = widget.Msg("Sparse data is not supported") + class Error(OWAnchorProjectionWidget.Error): no_features = widget.Msg( - "At least 3 numeric or categorical variables are required" + "At least three numeric or categorical variables are required" ) - no_instances = widget.Msg("At least 2 data instances are required") + no_instances = widget.Msg("At least two data instances are required") def __init__(self): - super().__init__() - - self.data = None - self.subset_data = None - self.subset_indices = None - self._embedding_coords = None - self._rand_indices = None - - self.__replot_requested = False - - self.variable_x = ContinuousVariable("radviz-x") - self.variable_y = ContinuousVariable("radviz-y") - - box = gui.vBox(self.mainArea, True, margin=0) - self.graph = OWRadvizGraph(self, box) - box.layout().addWidget(self.graph.plot_widget) - - self.variables_selection = VariablesSelection() - self.model_selected = selected = VariableListModel(enable_dnd=True) - self.model_other = other = VariableListModel(enable_dnd=True) - self.variables_selection(self, selected, other, self.controlArea) + self.model_selected = 
VariableListModel(enable_dnd=True) + self.model_selected.rowsInserted.connect(self.__model_selected_changed) + self.model_selected.rowsRemoved.connect(self.__model_selected_changed) + self.model_other = VariableListModel(enable_dnd=True) self.vizrank, self.btn_vizrank = RadvizVizRank.add_vizrank( - None, self, "Suggest features", self.vizrank_set_attrs) - # Todo: this button introduces some margin at the bottom?! - self.variables_selection.add_remove.layout().addWidget(self.btn_vizrank) - - g = self.graph.gui - g.point_properties_box(self.controlArea) - g.effects_box(self.controlArea) - g.plot_properties_box(self.controlArea) - - self.graph.box_zoom_select(self.controlArea) - - gui.auto_commit(self.controlArea, self, "auto_commit", - "Send Selection", "Send Automatically") - - self.graph.view_box.started.connect(self._randomize_indices) - self.graph.view_box.moved.connect(self._manual_move) - self.graph.view_box.finished.connect(self._finish_manual_move) + None, self, "Suggest features", self.__vizrank_set_attrs + ) + super().__init__() - def vizrank_set_attrs(self, attrs): + def _add_controls(self): + self.variables_selection = VariablesSelection( + self, self.model_selected, self.model_other, self.controlArea + ) + self.variables_selection.add_remove.layout().addWidget( + self.btn_vizrank + ) + super()._add_controls() + self.controlArea.layout().removeWidget(self.control_area_stretch) + self.control_area_stretch.setParent(None) + + @property + def primitive_variables(self): + if self.data is None or self.data.domain is None: + return [] + dom = self.data.domain + return [v for v in chain(dom.variables, dom.metas) if v.is_primitive()] + + def __vizrank_set_attrs(self, attrs): if not attrs: return - self.variables_selection.display_none() self.model_selected[:] = attrs[:] - self.model_other[:] = [v for v in self.model_other if v not in attrs] + self.model_other[:] = [var for var in self.primitive_variables + if var not in attrs] - def update_colors(self): - 
self._vizrank_color_change() - self.cb_class_density.setEnabled(self.can_draw_density()) + def __model_selected_changed(self): + self.selected_vars = [(var.name, vartype(var)) for var + in self.model_selected] + self.projection = None + self.setup_plot() + self.commit() - def invalidate_plot(self): - """ - Schedule a delayed replot. - """ - if not self.__replot_requested: - self.__replot_requested = True - QApplication.postEvent(self, QEvent(self.ReplotRequest), Qt.LowEventPriority - 10) + def colors_changed(self): + super().colors_changed() + self._init_vizrank() - def _vizrank_color_change(self): - is_enabled = self.data is not None and not self.data.is_sparse() and \ - len(self.model_other) + len(self.model_selected) > 3 and \ + def set_data(self, data): + super().set_data(data) + if self.data is not None and len(self.selected_vars): + d, selected = self.data.domain, [v[0] for v in self.selected_vars] + self.model_selected[:] = [d[name] for name in selected] + self.model_other[:] = [d[attr.name] for attr in + self.primitive_variables + if attr.name not in selected] + elif self.data is not None: + d, variables = self.data.domain, self.primitive_variables + class_var = [variables.pop(variables.index(d.class_var))] \ + if d.class_var in variables else [] + self.model_selected[:] = variables[:5] + self.model_other[:] = variables[5:] + class_var + + self._init_vizrank() + + def _init_vizrank(self): + is_enabled = self.data is not None and \ + len(self.primitive_variables) > 3 and \ + self.attr_color is not None and \ + not np.isnan(self.data.get_column_view( + self.attr_color)[0].astype(float)).all() and \ len(self.data[self.valid_data]) > 1 and \ np.all(np.nan_to_num(np.nanstd(self.data.X, 0)) != 0) - self.btn_vizrank.setEnabled( - is_enabled and self.attr_color is not None - and not np.isnan(self.data.get_column_view( - self.attr_color)[0].astype(float)).all()) - self.vizrank.initialize() - - def clear(self): - self.data = None - self.valid_data = None - 
self._embedding_coords = None - self._rand_indices = None - self.model_selected.clear() - self.model_other.clear() - - self.graph.set_attributes(()) - self.graph.set_points(None) - self.graph.update_coordinates() - self.graph.clear() + self.btn_vizrank.setEnabled(is_enabled) + if is_enabled: + self.vizrank.initialize() - @Inputs.data - def set_data(self, data): - self.clear_messages() - self.btn_vizrank.setEnabled(False) - self.closeContext() - self.clear() - self.data = data - self._check_data() - self.init_attr_values() - self.openContext(self.data) - if self.data is not None: - self.model_selected[:], self.model_other[:] = self._load_settings() + def check_data(self): + def error(err): + err() + self.data = None - def _check_data(self): + super().check_data() if self.data is not None: - domain = self.data.domain - if self.data.is_sparse(): - self.Error.sparse_data() - self.data = None - elif isinstance(self.data, SqlTable): - if self.data.approx_len() < 4000: - self.data = Table(self.data) - else: - self.Information.sql_sampled_data() - data_sample = self.data.sample_time(1, no_cache=True) - data_sample.download_data(2000, partial=True) - self.data = Table(data_sample) - elif len(self.data) < 2: - self.Error.no_instances() - self.data = None - elif len([v for v in domain.variables + - domain.metas if v.is_primitive()]) < 3: - self.Error.no_features() - self.data = None - - def _load_settings(self): - domain = self.data.domain - variables = [v for v in domain.attributes + domain.metas - if v.is_primitive()] - self.model_selected[:] = variables[:5] - self.model_other[:] = variables[5:] + list(domain.class_vars) - - state = VariablesSelection.encode_var_state( - [list(self.model_selected), list(self.model_other)] - ) - state = {key: (ind, np.inf) for key, (ind, _) in state.items()} - state.update(self.variable_state) - return VariablesSelection.decode_var_state( - state, [list(self.model_selected), list(self.model_other)]) - - @Inputs.data_subset - def 
set_subset_data(self, subset): - self.subset_data = subset - self.subset_indices = {e.id for e in subset} \ - if subset is not None else {} - self.controls.graph.alpha_value.setEnabled(subset is None) - - def handleNewSignals(self): - self.setup_plot() - self._vizrank_color_change() - self.commit() - - def get_coordinates_data(self): - ec = self._embedding_coords - if ec is None or np.any(np.isnan(ec)): - return None, None - return ec[:, 0], ec[:, 1] + if len(self.data) < 2: + error(self.Error.no_instances) + elif len(self.primitive_variables) < 3: + error(self.Error.no_features) - def get_subset_mask(self): - if self.subset_indices: - return np.array([ex.id in self.subset_indices - for ex in self.data[self.valid_data]]) + def init_attr_values(self): + super().init_attr_values() + self.selected_vars = [] - def customEvent(self, event): - if event.type() == OWRadviz.ReplotRequest: - self.__replot_requested = False - self.setup_plot() - else: - super().customEvent(event) - - def closeContext(self): - self.variable_state = VariablesSelection.encode_var_state( - [list(self.model_selected), list(self.model_other)] - ) - super().closeContext() - - def setup_plot(self): + def get_embedding(self): + self.valid_data = None if self.data is None: - return - self.__replot_requested = False + return None - self.clear_messages() + self.Warning.no_features.clear() if len(self.model_selected) < 2: self.Warning.no_features() - self.graph.clear() - return + return None - r = radviz(self.data, self.model_selected) - self._embedding_coords = r[0] - self.graph.set_points(r[1]) - self.valid_data = r[2] - if self._embedding_coords is None or \ - np.any(np.isnan(self._embedding_coords)): + ec, proj, msk = radviz(self.data, self.model_selected, self.projection) + angle = np.arctan2(*proj.T[::-1]) + self.projection = np.vstack((np.cos(angle), np.sin(angle))).T + self.valid_data = msk + + self.Warning.invalid_embedding.clear() + if ec is None or np.any(np.isnan(ec)): 
self.Warning.invalid_embedding() - self.graph.reset_graph() - - def _randomize_indices(self): - n = len(self._embedding_coords) - if n > MAX_POINTS: - self._rand_indices = np.random.choice(n, MAX_POINTS, replace=False) - self._rand_indices = sorted(self._rand_indices) - - def _manual_move(self): - self.__replot_requested = False - - res = radviz(self.data, self.model_selected, self.graph.get_points()) - self._embedding_coords = res[0] - if self._rand_indices is not None: - # save widget state - selection = self.graph.selection - valid_data = self.valid_data.copy() - data = self.data.copy() - ec = self._embedding_coords.copy() - - # plot subset - self.__plot_random_subset(selection) - - # restore widget state - self.graph.selection = selection - self.valid_data = valid_data - self.data = data - self._embedding_coords = ec - else: - self.graph.update_coordinates() - - def __plot_random_subset(self, selection): - self._embedding_coords = self._embedding_coords[self._rand_indices] - self.data = self.data[self._rand_indices] - self.valid_data = self.valid_data[self._rand_indices] - self.graph.reset_graph() - if selection is not None: - self.graph.selection = selection[self._rand_indices] - self.graph.update_selection_colors() - - def _finish_manual_move(self): - if self._rand_indices is not None: - selection = self.graph.selection - self.graph.reset_graph() - if selection is not None: - self.graph.selection = selection - self.graph.select_by_index(self.graph.get_selection()) - - def selection_changed(self): - self.commit() + return None - def commit(self): - selected = annotated = components = None - if self.data is not None and np.sum(self.valid_data): - name = self.data.name - domain = self.data.domain - metas = domain.metas + (self.variable_x, self.variable_y) - domain = Domain(domain.attributes, domain.class_vars, metas) - embedding_coords = np.zeros((len(self.data), 2), dtype=np.float) - embedding_coords[self.valid_data] = self._embedding_coords - - data = 
self.data.transform(domain) - data[:, self.variable_x] = embedding_coords[:, 0][:, None] - data[:, self.variable_y] = embedding_coords[:, 1][:, None] - - selection = self.graph.get_selection() - if len(selection): - selected = data[selection] - selected.name = name + ": selected" - selected.attributes = self.data.attributes - if self.graph.selection is not None and \ - np.max(self.graph.selection) > 1: - annotated = create_groups_table(data, self.graph.selection) - else: - annotated = create_annotated_table(data, selection) - annotated.attributes = self.data.attributes - annotated.name = name + ": annotated" - - points = self.graph.get_points() - comp_domain = Domain( - points[:, 2], - metas=[StringVariable(name='component')]) - - metas = np.array([["RX"], ["RY"], ["angle"]]) - angle = np.arctan2(np.array(points[:, 1].T, dtype=float), - np.array(points[:, 0].T, dtype=float)) - components = Table.from_numpy( - comp_domain, - X=np.row_stack((points[:, :2].T, angle)), - metas=metas) - components.name = name + ": components" - - self.Outputs.selected_data.send(selected) - self.Outputs.annotated_data.send(annotated) - self.Outputs.components.send(components) + embedding = np.zeros((len(self.data), 2), dtype=np.float) + embedding[self.valid_data] = ec + return embedding - def send_report(self): - if self.data is None: - return + def get_anchors(self): + if self.projection is None: + return None, None + return self.projection, [a.name for a in self.model_selected] - def name(var): - return var and var.name + def _manual_move(self, anchor_idx, x, y): + angle = np.arctan2(y, x) + super()._manual_move(anchor_idx, np.cos(angle), np.sin(angle)) + + def send_components(self): + components = None + if self.data is not None and self.valid_data is not None and \ + self.projection is not None: + angle = np.arctan2(*self.projection.T[::-1]) + meta_attrs = [StringVariable(name='component')] + components = Table(Domain(self.model_selected, metas=meta_attrs), + 
np.row_stack((self.projection.T, angle)), + metas=np.array([["RX"], ["RY"], ["angle"]])) + components.name = self.data.name + self.Outputs.components.send(components) - caption = report.render_items_vert(( - ("Color", name(self.attr_color)), - ("Label", name(self.attr_label)), - ("Shape", name(self.attr_shape)), - ("Size", name(self.attr_size)), - ("Jittering", self.graph.jitter_size != 0 and - "{} %".format(self.graph.jitter_size)))) - self.report_plot() - if caption: - self.report_caption(caption) + def clear(self): + if self.model_selected: + self.model_selected.clear() + if self.model_other: + self.model_other.clear() + super().clear() @classmethod def migrate_context(cls, context, version): - if version < 3: + if version < 2: values = context.values values["attr_color"] = values["graph"]["attr_color"] values["attr_size"] = values["graph"]["attr_size"] @@ -690,7 +473,6 @@ def boundingRect(self): def main(argv=None): import sys - import sip argv = sys.argv[1:] if argv is None else argv if argv: @@ -706,13 +488,8 @@ def main(argv=None): w.set_subset_data(data[::10]) w.handleNewSignals() w.show() - w.raise_() - r = app.exec() - w.set_data(None) + app.exec() w.saveSettings() - sip.delete(w) - del w - return r if __name__ == "__main__": diff --git a/Orange/widgets/visualize/owscatterplot.py b/Orange/widgets/visualize/owscatterplot.py index 66cd62d5777..a806d94651a 100644 --- a/Orange/widgets/visualize/owscatterplot.py +++ b/Orange/widgets/visualize/owscatterplot.py @@ -19,17 +19,13 @@ from Orange.widgets import gui, report from Orange.widgets.io import MatplotlibFormat, MatplotlibPDFFormat from Orange.widgets.settings import ( - DomainContextHandler, Setting, ContextSetting, SettingProvider + Setting, ContextSetting, SettingProvider ) from Orange.widgets.utils import get_variable_values_sorted -from Orange.widgets.utils.annotated_data import ( - create_annotated_table, create_groups_table, ANNOTATED_DATA_SIGNAL_NAME -) from Orange.widgets.utils.itemmodels import 
DomainModel -from Orange.widgets.visualize.owscatterplotgraph import ( - OWScatterPlotBase, OWProjectionWidget -) +from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase from Orange.widgets.visualize.utils import VizRankDialogAttrPair +from Orange.widgets.visualize.utils.widget import OWDataProjectionWidget from Orange.widgets.widget import AttributeList, Msg, Input, Output @@ -70,7 +66,7 @@ def compute_score(self, state): data = data.transform(Domain(attrs, self.attr_color)) data = data[~np.isnan(data.X).any(axis=1) & ~np.isnan(data.Y).T] if len(data) < self.minK: - return + return None n_neighbors = min(self.minK, len(data) - 1) knn = NearestNeighbors(n_neighbors=n_neighbors).fit(data.X) ind = knn.kneighbors(return_distance=False) @@ -101,7 +97,6 @@ def score_heuristic(self): class OWScatterPlotGraph(OWScatterPlotBase): show_reg_line = Setting(False) - jitter_size = Setting(10) jitter_continuous = Setting(False) def __init__(self, scatter_widget, parent): @@ -125,6 +120,16 @@ def set_axis_title(self, axis, title): def update_coordinates(self): super().update_coordinates() self.update_regression_line() + self.update_tooltip() + + def _get_jittering_tooltip(self): + def is_discrete(attr): + return attr and attr.is_discrete + + if self.jitter_continuous or is_discrete(self.master.attr_x) or \ + is_discrete(self.master.attr_y): + return super()._get_jittering_tooltip() + return "" def jitter_coordinates(self, x, y): def get_span(attr): @@ -172,7 +177,7 @@ def update_regression_line(self): self.plot_widget.addItem(self.reg_line_item) -class OWScatterPlot(OWProjectionWidget): +class OWScatterPlot(OWDataProjectionWidget): """Scatterplot visualization with explorative analysis and intelligent data visualization enhancements.""" @@ -183,59 +188,57 @@ class OWScatterPlot(OWProjectionWidget): priority = 140 keywords = [] - class Inputs: - data = Input("Data", Table, default=True) - data_subset = Input("Data Subset", Table) + class 
Inputs(OWDataProjectionWidget.Inputs): features = Input("Features", AttributeList) - class Outputs: - selected_data = Output("Selected Data", Table, default=True) - annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) + class Outputs(OWDataProjectionWidget.Outputs): features = Output("Features", AttributeList, dynamic=False) settings_version = 3 - settingsHandler = DomainContextHandler() - - auto_send_selection = Setting(True) auto_sample = Setting(True) - attr_x = ContextSetting(None) attr_y = ContextSetting(None) tooltip_shows_all = Setting(True) - #: Serialized selection state to be restored - selection_group = Setting(None, schema_only=True) - + GRAPH_CLASS = OWScatterPlotGraph graph = SettingProvider(OWScatterPlotGraph) - graph_name = "graph.plot_widget.plotItem" + embedding_variables_names = None - class Warning(OWProjectionWidget.Warning): + class Warning(OWDataProjectionWidget.Warning): missing_coords = Msg( "Plot cannot be displayed because '{}' or '{}' " "is missing for all data points") - class Information(OWProjectionWidget.Information): + class Information(OWDataProjectionWidget.Information): sampled_sql = Msg("Large SQL table; showing a sample.") missing_coords = Msg( "Points with missing '{}' or '{}' are not displayed") def __init__(self): - super().__init__() - - box = gui.vBox(self.mainArea, True, margin=0) - self.graph = OWScatterPlotGraph(self, box) - box.layout().addWidget(self.graph.plot_widget) - - self.subset_data = None # Orange.data.Table - self.subset_indices = None self.sql_data = None # Orange.data.sql.table.SqlTable self.attribute_selection_list = None # list of Orange.data.Variable self.__timer = QTimer(self, interval=1200) self.__timer.timeout.connect(self.add_data) - #: Remember the saved state to restore - self.__pending_selection_restore = self.selection_group - self.selection_group = None + super().__init__() + # manually register Matplotlib file writers + self.graph_writers = self.graph_writers.copy() + for w in 
[MatplotlibFormat, MatplotlibPDFFormat]: + for ext in w.EXTENSIONS: + self.graph_writers[ext] = w + + def _add_controls(self): + self._add_controls_axis() + self._add_controls_sampling() + super()._add_controls() + self.graph.gui.add_widget(self.graph.gui.JitterNumericValues, + self._effects_box) + self.graph.gui.add_widgets([self.graph.gui.ShowGridLines, + self.graph.gui.ToolTipShowsAll, + self.graph.gui.RegressionLine], + self._plot_box) + + def _add_controls_axis(self): common_options = dict( labelWidth=50, orientation=Qt.Horizontal, sendSelectedValue=True, valueType=str, contentsLength=14 @@ -249,41 +252,16 @@ def __init__(self): self.cb_attr_y = gui.comboBox( box, self, "attr_y", label="Axis y:", callback=self.attr_changed, model=self.xy_model, **common_options) - vizrank_box = gui.hBox(box) - #gui.separator(vizrank_box, width=common_options["labelWidth"]) self.vizrank, self.vizrank_button = ScatterPlotVizRank.add_vizrank( vizrank_box, self, "Find Informative Projections", self.set_attr) - g = self.graph.gui - + def _add_controls_sampling(self): self.sampling = gui.auto_commit( self.controlArea, self, "auto_sample", "Sample", box="Sampling", callback=self.switch_sampling, commit=lambda: self.add_data(1)) self.sampling.setVisible(False) - g.point_properties_box(self.controlArea) - - box = g.effects_box(self.controlArea) - g.add_widget(g.JitterNumericValues, box) - - box_plot_prop = g.plot_properties_box(self.controlArea) - g.add_widgets([ - g.ShowGridLines, - g.ToolTipShowsAll, - g.RegressionLine], box_plot_prop) - - self.controlArea.layout().addStretch(100) - self.graph.box_zoom_select(self.controlArea) - gui.auto_commit(self.controlArea, self, "auto_send_selection", - "Send Selection", "Send Automatically") - - # manually register Matplotlib file writers - self.graph_writers = self.graph_writers.copy() - for w in [MatplotlibFormat, MatplotlibPDFFormat]: - for ext in w.EXTENSIONS: - self.graph_writers[ext] = w - def _vizrank_color_change(self): 
self.vizrank.initialize() is_enabled = self.data is not None and not self.data.is_sparse() and \ @@ -297,39 +275,10 @@ def _vizrank_color_change(self): if is_enabled and self.attr_color is None else "" self.vizrank_button.setToolTip(text) - @Inputs.data def set_data(self, data): - self.clear_messages() - self.Information.sampled_sql.clear() - self.__timer.stop() - self.sampling.setVisible(False) - self.sql_data = None - if isinstance(data, SqlTable): - if data.approx_len() < 4000: - data = Table(data) - else: - self.Information.sampled_sql() - self.sql_data = data - data_sample = data.sample_time(0.8, no_cache=True) - data_sample.download_data(2000, partial=True) - data = Table(data_sample) - self.sampling.setVisible(True) - if self.auto_sample: - self.__timer.start() - - if data is not None and (len(data) == 0 or len(data.domain) == 0): - data = None if self.data and data and self.data.checksum() == data.checksum(): return - - self.closeContext() - same_domain = (self.data and data and - data.domain.checksum() == self.data.domain.checksum()) - self.data = data - - if not same_domain: - self.init_attr_values() - self.openContext(self.data) + super().set_data(data) def findvar(name, iterable): """Find a Orange.data.Variable in `iterable` by name""" @@ -357,25 +306,45 @@ def findvar(name, iterable): self.attr_size = findvar( self.attr_size, self.graph.gui.size_model) - def get_coordinates_data(self): - self.Warning.missing_coords.clear() - self.Information.missing_coords.clear() + def check_data(self): + self.clear_messages() + self.__timer.stop() + self.sampling.setVisible(False) + self.sql_data = None + if isinstance(self.data, SqlTable): + if self.data.approx_len() < 4000: + self.data = Table(self.data) + else: + self.Information.sampled_sql() + self.sql_data = self.data + data_sample = self.data.sample_time(0.8, no_cache=True) + data_sample.download_data(2000, partial=True) + self.data = Table(data_sample) + self.sampling.setVisible(True) + if self.auto_sample: + 
self.__timer.start() + + if self.data is not None and (len(self.data) == 0 or + len(self.data.domain) == 0): + self.data = None + + def get_embedding(self): + self.valid_data = None + if self.data is None: + return None x_data = self.get_column(self.attr_x, filter_valid=False) y_data = self.get_column(self.attr_y, filter_valid=False) if x_data is None or y_data is None: - self.valid_data = None - return None, None + return None + + self.Warning.missing_coords.clear() + self.Information.missing_coords.clear() self.valid_data = np.isfinite(x_data) & np.isfinite(y_data) - if not np.all(self.valid_data): + if self.valid_data is not None and not np.all(self.valid_data): msg = self.Information if np.any(self.valid_data) else self.Warning msg.missing_coords(self.attr_x.name, self.attr_y.name) - return x_data[self.valid_data], y_data[self.valid_data] - - def get_subset_mask(self): - if self.subset_indices: - return np.array([ex.id in self.subset_indices - for ex in self.data[self.valid_data]]) + return np.vstack((x_data, y_data)).T # Tooltip def _point_tooltip(self, point_id, skip_attrs=()): @@ -385,7 +354,7 @@ def _point_tooltip(self, point_id, skip_attrs=()): escape('{} = {}'.format(var.name, point_data[var])) for var in xy_attrs) if self.tooltip_shows_all: - others = super()._point_tooltip(point_id, skip_attrs=xy_attrs) + others = super()._point_tooltip(point_id, skip_attrs=xy_attrs) if others: text = "{}

{}".format(text, others) return text @@ -398,7 +367,8 @@ def can_draw_regresssion_line(self): def add_data(self, time=0.4): if self.data and len(self.data) > 2000: - return self.__timer.stop() + self.__timer.stop() + return data_sample = self.sql_data.sample_time(time, no_cache=True) if data_sample: data_sample.download_data(2000, partial=True) @@ -421,7 +391,6 @@ def switch_sampling(self): self.add_data() self.__timer.start() - @Inputs.data_subset def set_subset_data(self, subset_data): self.warning() if isinstance(subset_data, SqlTable): @@ -430,37 +399,20 @@ def set_subset_data(self, subset_data): else: self.warning("Data subset does not support large Sql tables") subset_data = None - self.subset_data = subset_data - self.subset_indices = {e.id for e in subset_data} \ - if subset_data is not None else {} - self.controls.graph.alpha_value.setEnabled(subset_data is None) + super().set_subset_data(subset_data) # called when all signals are received, so the graph is updated only once def handleNewSignals(self): if self.attribute_selection_list and self.data is not None and \ self.data.domain is not None and \ all(attr in self.data.domain for attr - in self.attribute_selection_list): + in self.attribute_selection_list): self.attr_x = self.attribute_selection_list[0] self.attr_y = self.attribute_selection_list[1] self.attribute_selection_list = None - self.attr_changed() + super().handleNewSignals() self._vizrank_color_change() - self.cb_class_density.setEnabled(self.can_draw_density()) self.cb_reg_line.setEnabled(self.can_draw_regresssion_line()) - if self.data is not None and self.__pending_selection_restore is not None: - self.apply_selection(self.__pending_selection_restore) - self.__pending_selection_restore = None - self.unconditional_commit() - - def apply_selection(self, selection): - """Apply `selection` to the current plot.""" - if self.data is not None: - self.graph.selection = np.zeros(self.graph.n_points, dtype=np.uint8) - self.selection_group = [x for x 
in selection if x[0] < len(self.data)] - selection_array = np.array(self.selection_group).T - self.graph.selection[selection_array[0]] = selection_array[1] - self.graph.update_selection_colors() @Inputs.features def set_shown_attributes(self, attributes): @@ -474,101 +426,57 @@ def set_attr(self, attr_x, attr_y): self.attr_changed() def attr_changed(self): - self.graph.reset_graph() + self.cb_reg_line.setEnabled(self.can_draw_regresssion_line()) + self.setup_plot() + self.commit() + + def setup_plot(self): + super().setup_plot() for axis, var in (("bottom", self.attr_x), ("left", self.attr_y)): self.graph.set_axis_title(axis, var) if var and var.is_discrete: - self.graph.set_axis_labels(axis, get_variable_values_sorted(var)) + self.graph.set_axis_labels(axis, + get_variable_values_sorted(var)) else: self.graph.set_axis_labels(axis, None) - self.cb_class_density.setEnabled(self.can_draw_density()) - self.cb_reg_line.setEnabled(self.can_draw_regresssion_line()) - self.send_features() - - # The color combo's callback calls self.graph.update_colors AND this method - def update_colors(self): + def colors_changed(self): + super().colors_changed() self._vizrank_color_change() - self.cb_class_density.setEnabled(self.can_draw_density()) - - def selection_changed(self): - # Store current selection in a setting that is stored in workflow - if isinstance(self.data, SqlTable): - selection = None - elif self.data is not None: - selection = self.graph.get_selection() - else: - selection = None - if selection is not None and len(selection): - self.selection_group = list(zip(selection, self.graph.selection[selection])) - else: - self.selection_group = None - - self.commit() - - def send_data(self): - # TODO: Implement selection for sql data - def _get_selected(): - if not len(selection): - return None - return create_groups_table(data, group_selection, False, "Group") - - def _get_annotated(): - if graph.selection is not None and np.max(graph.selection) > 1: - return 
create_groups_table(data, group_selection) - else: - return create_annotated_table(data, selection) - graph = self.graph - data = self.data - selection = graph.get_selection() - if graph.selection is not None: - group_selection = np.zeros(len(self.data), dtype=int) - group_selection[self.valid_data] = graph.selection - self.Outputs.annotated_data.send(_get_annotated()) - self.Outputs.selected_data.send(_get_selected()) + def commit(self): + super().commit() + self.send_features() def send_features(self): features = [attr for attr in [self.attr_x, self.attr_y] if attr] self.Outputs.features.send(features or None) - def commit(self): - self.send_data() - self.send_features() - def get_widget_name_extension(self): if self.data is not None: return "{} vs {}".format(self.attr_x.name, self.attr_y.name) - - def send_report(self): - if self.data is None: - return - - def name(var): - return var and var.name - - caption = report.render_items_vert(( - ("Color", name(self.attr_color)), - ("Label", name(self.attr_label)), - ("Shape", name(self.attr_shape)), - ("Size", name(self.attr_size)), + return None + + def _get_send_report_caption(self): + return report.render_items_vert(( + ("Color", self._get_caption_var_name(self.attr_color)), + ("Label", self._get_caption_var_name(self.attr_label)), + ("Shape", self._get_caption_var_name(self.attr_shape)), + ("Size", self._get_caption_var_name(self.attr_size)), ("Jittering", (self.attr_x.is_discrete or self.attr_y.is_discrete or self.graph.jitter_continuous) and self.graph.jitter_size))) - self.report_plot() - if caption: - self.report_caption(caption) - - def onDeleteWidget(self): - super().onDeleteWidget() - self.graph.plot_widget.getViewBox().deleteLater() - self.graph.plot_widget.clear() @classmethod def migrate_settings(cls, settings, version): if version < 2 and "selection" in settings and settings["selection"]: settings["selection_group"] = [(a, 1) for a in settings["selection"]] + if version < 3: + if "auto_send_selection" in 
settings: + settings["auto_commit"] = settings["auto_send_selection"] + if "selection_group" in settings: + settings["selection"] = settings["selection_group"] @classmethod def migrate_context(cls, context, version): @@ -589,7 +497,7 @@ def main(argv=None): if len(argv) > 1: filename = argv[1] else: - filename = "heart_disease" + filename = "iris" ow = OWScatterPlot() ow.show() @@ -609,5 +517,6 @@ def main(argv=None): return rval + if __name__ == "__main__": main() diff --git a/Orange/widgets/visualize/owscatterplotgraph.py b/Orange/widgets/visualize/owscatterplotgraph.py index 40b4f0b438b..bfc14f073f8 100644 --- a/Orange/widgets/visualize/owscatterplotgraph.py +++ b/Orange/widgets/visualize/owscatterplotgraph.py @@ -1,14 +1,13 @@ -from collections import Counter, defaultdict import sys import itertools import warnings +import threading from xml.sax.saxutils import escape -from math import log2, log10, floor, ceil +from math import log10, floor, ceil import numpy as np -from scipy import sparse as sp -from AnyQt.QtCore import Qt, QRectF, QPointF, QSize +from AnyQt.QtCore import Qt, QRectF, QSize from AnyQt.QtGui import ( QStaticText, QColor, QPen, QBrush, QPainterPath, QTransform, QPainter ) @@ -18,31 +17,27 @@ import pyqtgraph as pg import pyqtgraph.graphicsItems.ScatterPlotItem -from pyqtgraph.graphicsItems.LegendItem import LegendItem, ItemSample +from pyqtgraph.graphicsItems.LegendItem import ( + LegendItem as PgLegendItem, ItemSample +) from pyqtgraph.graphicsItems.TextItem import TextItem -from Orange.statistics.util import bincount from Orange.util import OrangeDeprecationWarning from Orange.widgets import gui +from Orange.widgets.settings import Setting from Orange.widgets.utils import classdensity -from Orange.widgets.utils.colorpalette import ( - ColorPaletteGenerator, ContinuousPaletteGenerator, DefaultRGBColors -) +from Orange.widgets.utils.colorpalette import ColorPaletteGenerator from Orange.widgets.utils.plot import OWPalette, OWPlotGUI +from 
Orange.widgets.visualize.owscatterplotgraph_obsolete import ( + OWScatterPlotGraph as OWScatterPlotGraphObs +) from Orange.widgets.visualize.utils.plotutils import ( HelpEventDelegate as EventDelegate, InteractiveViewBox as ViewBox ) -from Orange.widgets.visualize.owscatterplotgraph_obsolete import ( - OWScatterPlotGraph as OWScatterPlotGraphObs -) -from Orange.widgets.settings import Setting, ContextSetting -from Orange.widgets.widget import OWWidget, Msg SELECTION_WIDTH = 5 -MAX = 11 # maximum number of colors or shapes (including Other) -MAX_POINTS_IN_TOOLTIP = 5 class PaletteItemSample(ItemSample): @@ -85,7 +80,7 @@ def paint(self, p, *args): p.drawStaticText(20, i * 15 + 1, label) -class LegendItem(LegendItem): +class LegendItem(PgLegendItem): def __init__(self, size=None, offset=None, pen=None, brush=None): super().__init__(size, offset) @@ -102,15 +97,6 @@ def __init__(self, size=None, offset=None, pen=None, brush=None): brush = QBrush(QColor(232, 232, 232, 100)) self.__brush = brush - def storeAnchor(self): - """ - Return the current relative anchor position (relative to the parent) - """ - anchor = legend_anchor_pos(self) - if anchor is None: - anchor = ((1, 0), (1, 0)) - return anchor - def restoreAnchor(self, anchors): """ Restore (parent) relative position from stored anchors. 
@@ -120,29 +106,6 @@ def restoreAnchor(self, anchors): anchor, parentanchor = anchors self.anchor(*bound_anchor_pos(anchor, parentanchor)) - def setPen(self, pen): - """Set the legend frame pen.""" - pen = QPen(pen) - if pen != self.__pen: - self.prepareGeometryChange() - self.__pen = pen - self.updateGeometry() - - def pen(self): - """Pen used to draw the legend frame.""" - return QPen(self.__pen) - - def setBrush(self, brush): - """Set background brush""" - brush = QBrush(brush) - if brush != self.__brush: - self.__brush = brush - self.update() - - def brush(self): - """Background brush.""" - return QBrush(self._brush) - def paint(self, painter, option, widget=None): painter.setPen(self.__pen) painter.setBrush(self.__brush) @@ -170,42 +133,6 @@ def clear(self): self.updateSize() -ANCHORS = { - Qt.TopLeftCorner: (0, 0), - Qt.TopRightCorner: (1, 0), - Qt.BottomLeftCorner: (0, 1), - Qt.BottomRightCorner: (1, 1) -} - - -def corner_anchor(corner): - """Return the relative corner coordinates for Qt.Corner - """ - return ANCHORS[corner] - - -def legend_anchor_pos(legend): - """ - Return the legend's anchor positions relative to it's parent (if defined). - - Return `None` if legend does not have a parent or the parent's size - is empty. - - .. seealso:: LegendItem.anchor, rect_anchor_pos - - """ - parent = legend.parentItem() - if parent is None or parent.size().isEmpty(): - return None - - rect = legend.geometry() # in parent coordinates. - parent_rect = QRectF(QPointF(0, 0), parent.size()) - - # Find the closest corner of rect to parent rect - c1, _, *parentPos = rect_anchor_pos(rect, parent_rect) - return corner_anchor(c1), tuple(parentPos) - - def bound_anchor_pos(corner, parentpos): corner = np.clip(corner, 0, 1) parentpos = np.clip(parentpos, 0, 1) @@ -224,59 +151,6 @@ def bound_anchor_pos(corner, parentpos): return (irx, iry), (prx, pry) -def rect_anchor_pos(rect, parent_rect): - """ - Find the 'best' anchor corners of rect within parent_rect. 
- - Return a tuple of (rect_corner, parent_corner, rx, ry), - where rect/parent_corners are Qt.Corners which are closest and - rx, ry are the relative positions of the rect_corner within - parent_rect. If the parent_rect is empty return `None`. - - """ - if parent_rect.isEmpty(): - return None - - # Find the closest corner of rect to parent rect - corners = (Qt.TopLeftCorner, Qt.TopRightCorner, - Qt.BottomRightCorner, Qt.BottomLeftCorner) - - def rect_corner(rect, corner): - if corner == Qt.TopLeftCorner: - return rect.topLeft() - elif corner == Qt.TopRightCorner: - return rect.topRight() - elif corner == Qt.BottomLeftCorner: - return rect.bottomLeft() - elif corner == Qt.BottomRightCorner: - return rect.bottomRight() - else: - assert False - - def corner_dist(c1, c2): - d = (rect_corner(rect, c1) - rect_corner(parent_rect, c2)) - return d.x() ** 2 + d.y() ** 2 - - if parent_rect.contains(rect): - closest = min(corners, - key=lambda corner: corner_dist(corner, corner)) - p = rect_corner(rect, closest) - - return (closest, closest, - (p.x() - parent_rect.left()) / parent_rect.width(), - (p.y() - parent_rect.top()) / parent_rect.height()) - else: - - c1, c2 = min(itertools.product(corners, corners), - key=lambda pair: corner_dist(*pair)) - - p = rect_corner(rect, c1) - - return (c1, c2, - (p.x() - parent_rect.left()) / parent_rect.width(), - (p.y() - parent_rect.top()) / parent_rect.height()) - - class DiscretizedScale: """ Compute suitable bins for continuous value from its minimal and @@ -327,20 +201,6 @@ def __init__(self, min_v, max_v): self.decimals = max(decimals, 0) self.width = resolution - def compute_bins(self, a): - """ - Compute bin number(s) for the given value(s). 
- - :param a: value(s) - :type a: a number or numpy.ndarray - """ - a = (a - self.offset) / self.width - if isinstance(a, np.ndarray): - a.clip(0, self.bins - 1) - else: - a = min(self.bins - 1, max(0, a)) - return a - class InteractiveViewBox(ViewBox): def __init__(self, graph, enable_menu=False): @@ -359,10 +219,18 @@ def __init__(self, scatter_widget, parent=None, _="None", view_box=InteractiveVi class ScatterPlotItem(pg.ScatterPlotItem): + def __init__(self, *args, **kwargs): + self.lock = threading.Lock() + super().__init__(*args, **kwargs) + def paint(self, painter, option, widget=None): - painter.setRenderHint(QPainter.SmoothPixmapTransform, True) - super().paint(painter, option, widget) + with self.lock: + painter.setRenderHint(QPainter.SmoothPixmapTransform, True) + super().paint(painter, option, widget) + def setData(self, *args, **kwargs): + with self.lock: + super().setData(*args ,**kwargs) def _define_symbols(): """ @@ -414,11 +282,9 @@ class OWScatterPlotBase(gui.OWComponent): - `get_subset_mask` returns a bool array indicating whether a data point is in the subset or not (e.g. in the 'Data Subset' signal in the Scatter plot and similar widgets); - - `set_palette` sets the plot's palette appropriate for visualizing the + - `get_palette` returns a palette appropriate for visualizing the current color data; - `is_continuous_color` decides the type of the color legend; - - `combined_legend` tells whether the color and shape legend should be - combined into one (usually because they represent the same data). The widget (in a role of controller) must also provide methods - `selection_changed` @@ -433,10 +299,10 @@ class OWScatterPlotBase(gui.OWComponent): that the widget (in a role of a controler) should call when any of these properties are changed. 
If the widget calls, for instance, the plot's `update_colors`, the plot will react by calling the widget's - `get_color_data` as well as the widget's methods needed to contruct the + `get_color_data` as well as the widget's methods needed to construct the legend. - The view also provides a method `reset`, which should be called only + The view also provides a method `reset_graph`, which should be called only when - the widget gets entirely new data - the number of points may have changed, for instance when selecting @@ -453,7 +319,7 @@ def update_shapes(self): if self.scatterplot_item: shape_data = self.get_shapes() self.scatterplot_item.setSymbol(shape_data) - self.make_legend() + self.update_legends() def get_shapes(self): shape_data = self.master.get_shape_data() @@ -474,6 +340,12 @@ def get_size_data(self): two cases, "shapes" for the view and "sizes" for the model. The colors for the view are more complicated since they deal with discrete and continuous palettes, and the shapes for the view merge infrequent shapes.) + + The plot can also show just a random sample of the data. The sample size is + set by `set_sample_size`, and the rest is taken care by the plot: the + widget keeps providing the data for all points, selection indices refer + to the entire set etc. Internally, sampling happens as early as possible + (in methods `get_`). 
""" label_only_selected = Setting(False) point_width = Setting(10) @@ -490,6 +362,10 @@ def get_size_data(self): DarkerValue = 120 UnknownColor = (168, 50, 168) + COLOR_NOT_SUBSET = (128, 128, 128, 0) + COLOR_SUBSET = (128, 128, 128, 255) + COLOR_DEFAULT = (128, 128, 128, 0) + def __init__(self, scatter_widget, parent=None, view_box=ViewBox): super().__init__(scatter_widget) @@ -503,27 +379,28 @@ def __init__(self, scatter_widget, parent=None, view_box=ViewBox): self.plot_widget.getPlotItem().buttonsHidden = True self.plot_widget.setAntialiasing(True) self.plot_widget.sizeHint = lambda: QSize(500, 500) - scene = self.plot_widget.scene() - self._create_drag_tooltip(scene) - self.replot = self.plot_widget.replot self.density_img = None self.scatterplot_item = None self.scatterplot_item_sel = None - self.labels = [] self.master = scatter_widget + self._create_drag_tooltip(self.plot_widget.scene()) self.selection = None # np.ndarray - self.n_points = 0 + + self.n_valid = 0 + self.n_shown = 0 + self.sample_size = None + self.sample_indices = None self.gui = OWPlotGUI(self) self.palette = None - self.legend = self.color_legend = None - self.__legend_anchor = (1, 0), (1, 0) - self.__color_legend_anchor = (1, 1), (1, 1) + self.shape_legend = self._create_legend(((1, 0), (1, 0))) + self.color_legend = self._create_legend(((1, 1), (1, 1))) + self.update_legend_visibility() self.scale = None # DiscretizedScale @@ -531,10 +408,17 @@ def __init__(self, scatter_widget, parent=None, view_box=ViewBox): # self.grabGesture(QPinchGesture) # self.grabGesture(QPanGesture) - self.update_grid() + self.update_grid_visibility() self._tooltip_delegate = EventDelegate(self.help_event) self.plot_widget.scene().installEventFilter(self._tooltip_delegate) + self.view_box.sigTransformChanged.connect(self.update_density) + + def _create_legend(self, anchor): + legend = LegendItem() + legend.setParentItem(self.plot_widget.getViewBox()) + legend.restoreAnchor(anchor) + return legend def 
_create_drag_tooltip(self, scene): tip_parts = [ @@ -559,23 +443,43 @@ def _create_drag_tooltip(self, scene): rect = QGraphicsRectItem(0, 0, r.width() + 8, r.height() + 4) rect.setBrush(QColor(224, 224, 224, 212)) rect.setPen(QPen(Qt.NoPen)) - self.update_tooltip(Qt.NoModifier) + self.update_tooltip() scene.drag_tooltip = scene.createItemGroup([rect, text]) scene.drag_tooltip.hide() - def update_tooltip(self, modifiers): + def update_tooltip(self, modifiers=Qt.NoModifier): modifiers &= Qt.ShiftModifier + Qt.ControlModifier + Qt.AltModifier text = self.tiptexts.get(int(modifiers), self.tiptexts[0]) - self.tip_textitem.setHtml(text) + self.tip_textitem.setHtml(text + self._get_jittering_tooltip()) + + def _get_jittering_tooltip(self): + warn_jittered = "" + if self.jitter_size: + warn_jittered = \ + '

' \ + '' \ + ' Warning: Selection is applied to unjittered data ' \ + '' + return warn_jittered + + def update_jittering(self): + self.update_tooltip() + x, y = self.get_coordinates() + if x is None or not len(x) or self.scatterplot_item is None: + return + self._update_plot_coordinates(self.scatterplot_item, x, y) + self._update_plot_coordinates(self.scatterplot_item_sel, x, y) + self._update_label_coords(x, y) + # TODO: Rename to remove_plot_items def clear(self): """ Remove all graphical elements from the plot - Calls the pyqtgraph's plot widget's clear, removes the legend(s) and - resets the(ir) anchors, sets all handles to `None`, removes labels and - selections. + Calls the pyqtgraph's plot widget's clear, sets all handles to `None`, + removes labels and selections. This method should generally not be called by the widget. If the data is gone (*e.g.* upon receiving `None` as an input data signal), this @@ -587,16 +491,22 @@ def clear(self): `self.reg_line_item = None` (the line in the plot is already removed in this method). """ - self.remove_legend() self.plot_widget.clear() self.density_img = None self.scatterplot_item = None self.scatterplot_item_sel = None self.labels = [] - self.selection = None + self.view_box.init_history() + self.view_box.tag_history() - def reset_graph(self): + # TODO: I hate `keep_something` and `reset_something` arguments + # __keep_selection is used exclusively be set_sample size which would + # otherwise just repeat the code from reset_graph except for resetting + # the selection. I'm uncomfortable with this; we may prefer to have a + # method _reset_graph which does everything except resetting the selection, + # and reset_graph would call it. + def reset_graph(self, __keep_selection=False): """ Reset the graph to new data (or no data) @@ -608,11 +518,27 @@ def reset_graph(self): The method must also be called when the data is gone. The method calls `clear`, followed by calls of all update methods. + + NB. 
Argument `__keep_selection` is for internal use only """ self.clear() + if not __keep_selection: + self.selection = None + self.sample_indices = None self.update_coordinates() self.update_point_props() + def set_sample_size(self, sample_size): + """ + Set the sample size + + Args: + sample_size (int or None): sample size or `None` to show all points + """ + if self.sample_size != sample_size: + self.sample_size = sample_size + self.reset_graph(True) + def update_point_props(self): """ Update the sizes, colors, shapes and labels @@ -627,6 +553,12 @@ def update_point_props(self): self.update_labels() # Coordinates + # TODO: It could be nice if this method was run on entire data, not just + # a sample. For this, however, it would need to either be called from + # `get_coordinates` before sampling (very ugly) or call + # `self.master.get_coordinates_data` (beyond ugly) or the widget would + # have to store the ranges of unsampled data (ugly). + # Maybe we leave it as it is. def _reset_view(self, x_data, y_data): """ Set the range of the view box @@ -635,13 +567,18 @@ def _reset_view(self, x_data, y_data): x_data (np.ndarray): x coordinates y_data (np.ndarray) y coordinates """ - min_x, max_x = np.nanmin(x_data), np.nanmax(x_data) - min_y, max_y = np.nanmin(y_data), np.nanmax(y_data) + min_x, max_x = np.min(x_data), np.max(x_data) + min_y, max_y = np.min(y_data), np.max(y_data) self.view_box.setRange( - QRectF(min_x, min_y, max_x - min_x, max_y - min_y), + QRectF(min_x, min_y, max_x - min_x or 1, max_y - min_y or 1), padding=0.025) - self.view_box.init_history() - self.view_box.tag_history() + + def _filter_visible(self, data): + """Return the sample from the data using the stored sample_indices""" + if data is None or self.sample_indices is None: + return data + else: + return np.asarray(data[self.sample_indices]) def get_coordinates(self): """ @@ -650,21 +587,49 @@ def get_coordinates(self): The method is called by `update_coordinates`. 
It gets the coordinates from the widget, jitters them and return them. - The method also stores the number of points. + The methods also initializes the sample indices if neededd and stores + the original and sampled number of points. Returns: - (tuple): a pair of numpy arrays containing coordinates, + (tuple): a pair of numpy arrays containing (sampled) coordinates, or `(None, None)`. """ x, y = self.master.get_coordinates_data() if x is None: - self.n_points = 0 + self.n_valid = self.n_shown = 0 return None, None + self.n_valid = len(x) + self._create_sample() + x = self._filter_visible(x) + y = self._filter_visible(y) + # Jittering after sampling is OK if widgets do not change the sample + # semi-permanently, e.g. take a sample for the duration of some + # animation. If the sample size changes dynamically (like by adding + # a "sample size" slider), points would move around when the sample + # size changes. To prevent this, jittering should be done before + # sampling (i.e. two lines earlier). This would slow it down somewhat. x, y = self.jitter_coordinates(x, y) - self.n_points = len(x) if x is not None else 0 return x, y + def _create_sample(self): + """ + Create a random sample if the data is larger than the set sample size + """ + self.n_shown = min(self.n_valid, self.sample_size or self.n_valid) + if self.sample_size is not None \ + and self.sample_indices is None \ + and self.n_valid != self.n_shown: + random = np.random.RandomState(seed=0) + self.sample_indices = random.choice( + self.n_valid, self.n_shown, replace=False) + # TODO: Is this really needed? 
+ np.sort(self.sample_indices) + def jitter_coordinates(self, x, y): + """ + Display coordinates to random positions within ellipses with + radiuses of `self.jittter_size` percents of spans + """ if self.jitter_size == 0: return x, y return self._jitter_data(x, y) @@ -681,8 +646,7 @@ def _jitter_data(self, x, y, span_x=None, span_y=None): return (x + magnitude * span_x * rs * np.cos(phis), y + magnitude * span_y * rs * np.sin(phis)) - @classmethod - def _update_plot_coordinates(cls, plot, x, y): + def _update_plot_coordinates(self, plot, x, y): """ Change the coordinates of points while keeping other properites @@ -695,7 +659,8 @@ def _update_plot_coordinates(cls, plot, x, y): can be time consuming. """ data = dict(x=x, y=y) - for prop in ('pen', 'brush', 'size', 'symbol', 'data'): + for prop in ('pen', 'brush', 'size', 'symbol', 'data', + 'sourceRect', 'targetRect'): data[prop] = plot.data[prop] plot.setData(**data) @@ -713,18 +678,21 @@ def update_coordinates(self): if x is None or not len(x): return if self.scatterplot_item is None: - kwargs = {"x": x, "y": y, "data": np.arange(self.n_points)} + if self.sample_indices is None: + indices = np.arange(self.n_valid) + else: + indices = self.sample_indices + kwargs = dict(x=x, y=y, data=indices) self.scatterplot_item = ScatterPlotItem(**kwargs) self.scatterplot_item.sigClicked.connect(self.select_by_click) self.scatterplot_item_sel = ScatterPlotItem(**kwargs) self.plot_widget.addItem(self.scatterplot_item_sel) self.plot_widget.addItem(self.scatterplot_item) - self.update_point_props() else: self._update_plot_coordinates(self.scatterplot_item, x, y) self._update_plot_coordinates(self.scatterplot_item_sel, x, y) - self.update_label_coords(x, y) + self._update_label_coords(x, y) self.update_density() # Todo: doesn't work: try MDS with density on self._reset_view(x, y) @@ -741,13 +709,17 @@ def get_sizes(self): """ size_column = self.master.get_size_data() if size_column is None: - return np.ones(self.n_points) * 
self.point_width + return np.full((self.n_shown,), + self.MinShapeSize + (5 + self.point_width) * 0.5) + size_column = self._filter_visible(size_column) size_column = size_column.copy() - size_column -= np.min(size_column) - mx = np.max(size_column) + size_column -= np.nanmin(size_column) + mx = np.nanmax(size_column) if mx > 0: size_column /= mx - return self.MinShapeSize + self.point_width * size_column + else: + size_column[:] = 0.5 + return self.MinShapeSize + (5 + self.point_width) * size_column def update_sizes(self): """ @@ -759,13 +731,32 @@ def update_sizes(self): """ if self.scatterplot_item: size_data = self.get_sizes() - self.master.impute_sizes(size_data) + size_imputer = getattr( + self.master, "impute_sizes", self.default_impute_sizes) + size_imputer(size_data) self.scatterplot_item.setSize(size_data) self.scatterplot_item_sel.setSize(size_data + SELECTION_WIDTH) update_point_size = update_sizes # backward compatibility (needed?!) update_size = update_sizes + @classmethod + def default_impute_sizes(cls, size_data): + """ + Fallback imputation for sizes. 
+ + Set the size to two pixels smaller than the minimal size + + Returns: + (bool): True if there was any missing data + """ + nans = np.isnan(size_data) + if np.any(nans): + size_data[nans] = cls.MinShapeSize - 2 + return True + else: + return False + # Colors def get_colors(self): """ @@ -788,9 +779,11 @@ def get_colors(self): Returns: (tuple): a list of pens and list of brushes """ - self.master.set_palette() + self.palette = self.master.get_palette() c_data = self.master.get_color_data() + c_data = self._filter_visible(c_data) subset = self.master.get_subset_mask() + subset = self._filter_visible(subset) self.subset_is_shown = subset is not None if c_data is None: # same color return self._get_same_colors(subset) @@ -813,14 +806,16 @@ def _get_same_colors(self, subset): (tuple): a list of pens and list of brushes """ color = self.plot_widget.palette().color(OWPalette.Data) - pen = [_make_pen(color, 1.5) for _ in range(self.n_points)] + pen = [_make_pen(color, 1.5) for _ in range(self.n_shown)] if subset is not None: - brush = [(QBrush(QColor(128, 128, 128, 0)), - QBrush(QColor(128, 128, 128, 255)))[s] - for s in subset] + brush = np.where( + subset, + *(QBrush(QColor(*col)) + for col in (self.COLOR_SUBSET, self.COLOR_NOT_SUBSET))) else: - color = QColor(128, 128, 128, self.alpha_value) - brush = [QBrush(color) for _ in range(self.n_points)] + color = QColor(*self.COLOR_DEFAULT) + color.setAlpha(self.alpha_value) + brush = [QBrush(color) for _ in range(self.n_shown)] return pen, brush def _get_continuous_colors(self, c_data, subset): @@ -886,35 +881,18 @@ def update_colors(self): The method calls `self.get_colors`, which in turn calls the widget's `get_color_data` to get the indices in the pallette. `get_colors` returns a list of pens and brushes to which this method uses to - update the colors. Finally, the method triggets the update of the + update the colors. Finally, the method triggers the update of the legend and the density plot. 
""" - if self.scatterplot_item is None: - return - pen_data, brush_data = self.get_colors() - self.scatterplot_item.setPen(pen_data, update=False, mask=None) - self.scatterplot_item.setBrush(brush_data, mask=None) - self.make_legend() + if self.scatterplot_item is not None: + pen_data, brush_data = self.get_colors() + self.scatterplot_item.setPen(pen_data, update=False, mask=None) + self.scatterplot_item.setBrush(brush_data, mask=None) + self.update_legends() self.update_density() update_alpha_value = update_colors - def update_selection_colors(self): - """ - Trigger an update of selection markers - - This update method is usually not called by the widget but by the - plot, since it is the plot that handles the selections. - - Like other update methods, it calls the corresponding get method - (`get_colors_sel`) which returns a list of pens and brushes. - """ - if self.scatterplot_item_sel is None: - return - pen, brush = self.get_colors_sel() - self.scatterplot_item_sel.setPen(pen, update=False, mask=None) - self.scatterplot_item_sel.setBrush(brush, mask=None) - def update_density(self): """ Remove the existing density plot (if there is one) and replace it @@ -925,18 +903,35 @@ def update_density(self): """ if self.density_img: self.plot_widget.removeItem(self.density_img) - if self.scatterplot_item is not None \ - and self.master.can_draw_density() and self.class_density: + self.density_img = None + if self.class_density and self.scatterplot_item is not None: + rgb_data = [ + pen.color().getRgb()[:3] if pen is not None else (255, 255, 255) + for pen in self.scatterplot_item.data['pen']] + if len(set(rgb_data)) <= 1: + return [min_x, max_x], [min_y, max_y] = self.view_box.viewRange() x_data, y_data = self.scatterplot_item.getData() - rgb_data = [pen.color().getRgb()[:3] - for pen in self.scatterplot_item.data['pen']] self.density_img = classdensity.class_density_image( min_x, max_x, min_y, max_y, self.resolution, x_data, y_data, rgb_data) 
self.plot_widget.addItem(self.density_img) - else: - self.density_img = None + + def update_selection_colors(self): + """ + Trigger an update of selection markers + + This update method is usually not called by the widget but by the + plot, since it is the plot that handles the selections. + + Like other update methods, it calls the corresponding get method + (`get_colors_sel`) which returns a list of pens and brushes. + """ + if self.scatterplot_item_sel is None: + return + pen, brush = self.get_colors_sel() + self.scatterplot_item_sel.setPen(pen, update=False, mask=None) + self.scatterplot_item_sel.setBrush(brush, mask=None) def get_colors_sel(self): """ @@ -951,21 +946,21 @@ def get_colors_sel(self): """ nopen = QPen(Qt.NoPen) if self.selection is None: - pen = [nopen] * self.n_points + pen = [nopen] * self.n_shown else: sels = np.max(self.selection) if sels == 1: pen = np.where( - self.selection, + self._filter_visible(self.selection), _make_pen(QColor(255, 190, 0, 255), SELECTION_WIDTH + 1), nopen) else: palette = ColorPaletteGenerator(number_of_colors=sels + 1) pen = np.choose( - self.selection, + self._filter_visible(self.selection), [nopen] + [_make_pen(palette[i], SELECTION_WIDTH + 1) for i in range(sels)]) - return pen, [QBrush(QColor(255, 255, 255, 0))] * self.n_points + return pen, [QBrush(QColor(255, 255, 255, 0))] * self.n_shown # Labels def get_labels(self): @@ -977,7 +972,7 @@ def get_labels(self): Returns: (labels): a sequence of labels """ - return self.master.get_label_data() + return self._filter_visible(self.master.get_label_data()) def update_labels(self): """ @@ -987,56 +982,36 @@ def update_labels(self): `get_label_data`. The obtained labels are shown if the corresponding points are selected or if `label_only_selected` is `false`. 
""" - if self.label_only_selected and self.selection is None: - label_data = None - else: - label_data = self.get_labels() - if label_data is None: - for label in self.labels: - label.setText("") + for label in self.labels: + self.plot_widget.removeItem(label) + self.labels = [] + if self.scatterplot_item is None \ + or self.label_only_selected and self.selection is None: + return + labels = self.get_labels() + if labels is None: return - if not self.labels: - self._create_labels() black = pg.mkColor(0, 0, 0) + x, y = self.scatterplot_item.getData() if self.label_only_selected: - for label, text, selected \ - in zip(self.labels, label_data, self.selection): - label.setText(text if selected else "", black) - else: - for label, text in zip(self.labels, label_data): - label.setText(text, black) - - def _create_labels(self): - """ - Create a `TextItem` for each point and store them in `self.labels` - """ - if not self.scatterplot_item: - return - for x, y in zip(*self.scatterplot_item.getData()): - ti = TextItem() + selected = np.nonzero(self._filter_visible(self.selection)) + labels = labels[selected] + x = x[selected] + y = y[selected] + for label, xp, yp in zip(labels, x, y): + ti = TextItem(label, black) + ti.setPos(xp, yp) self.plot_widget.addItem(ti) - ti.setPos(x, y) self.labels.append(ti) - def update_label_coords(self, x, y): - """ - Update the coordinates of labels - - The method is currently called exclusively be `update_coordinates` - - Args: - x (np.ndarray): x coordinates - y (np.ndarray): y coordinates - """ + def _update_label_coords(self, x, y): + """Update label coordinates""" if self.label_only_selected: - if self.selection is not None: - for label, selected, xp, yp in zip( - self.labels, self.selection, x, y): - if selected: - label.setPos(xp, yp) - else: - for label, xp, yp in zip(self.labels, x, y): - label.setPos(xp, yp) + selected = np.nonzero(self._filter_visible(self.selection)) + x = x[selected] + y = y[selected] + for label, xp, yp in 
zip(self.labels, x, y): + label.setPos(xp, yp) # Shapes def get_shapes(self): @@ -1051,17 +1026,39 @@ def get_shapes(self): (np.ndarray): an array of symbols (e.g. o, x, + ...) """ shape_data = self.master.get_shape_data() + shape_data = self._filter_visible(shape_data) # Data has to be copied so the imputation can change it in-place # TODO: Try avoiding this when we move imputation to the widget if shape_data is not None: shape_data = np.copy(shape_data) - self.master.impute_shapes(shape_data, len(self.CurveSymbols) - 1) + shape_imputer = getattr( + self.master, "impute_shapes", self.default_impute_shapes) + shape_imputer(shape_data, len(self.CurveSymbols) - 1) if isinstance(shape_data, np.ndarray): shape_data = shape_data.astype(int) else: - shape_data = np.zeros(self.n_points, dtype=int) + shape_data = np.zeros(self.n_shown, dtype=int) return self.CurveSymbols[shape_data] + @staticmethod + def default_impute_shapes(shape_data, default_symbol): + """ + Fallback imputation for shapes. + + Use the default symbol, usually the last symbol in the list. 
+ + Returns: + (bool): True if there was any missing data + """ + if shape_data is None: + return False + nans = np.isnan(shape_data) + if np.any(nans): + shape_data[nans] = default_symbol + return True + else: + return False + def update_shapes(self): """ Trigger an update of point symbols @@ -1074,103 +1071,83 @@ def update_shapes(self): if self.scatterplot_item: shape_data = self.get_shapes() self.scatterplot_item.setSymbol(shape_data) - self.make_legend() + self.update_legends() - def update_grid(self): + def update_grid_visibility(self): """Show or hide the grid""" self.plot_widget.showGrid(x=self.show_grid, y=self.show_grid) - def update_legend(self): - """Show or hide the legend""" - if self.legend: - self.legend.setVisible(self.show_legend) - - def create_legend(self): - """Create a legend""" - self.legend = LegendItem() - self.legend.setParentItem(self.plot_widget.getViewBox()) - self.legend.restoreAnchor(self.__legend_anchor) - - def remove_legend(self): - """Remove the legend and reset its position""" - if self.legend: - anchor = legend_anchor_pos(self.legend) - if anchor is not None: - self.__legend_anchor = anchor - self.legend.setParent(None) - self.legend = None - if self.color_legend: - anchor = legend_anchor_pos(self.color_legend) - if anchor is not None: - self.__color_legend_anchor = anchor - self.color_legend.setParent(None) - self.color_legend = None - - def make_legend(self): - """Create the color and shape legends""" - self.remove_legend() - if not self.legend: - self.create_legend() - self._make_color_legend() - self._make_shape_legend() - self.update_legend() - - def _make_color_legend(self): - """ - Adds items representing the colors to the legend - - - If the legend is continuous (which is checked by calling the - widget's `is_continuous_color`), the legend is a colored strip. 
- - Otherwise, if the same attribute is used for shape and color - (which is checked by the widget's method `combined_legend`), - this method returns a legend with different shapes in shown - in the corresponding color. - - Otherwise, a normal legend for colors is created. - """ - if self.master.is_continuous_color(): - if not self.scale: - return - legend = self.color_legend = LegendItem() - legend.setParentItem(self.plot_widget.getViewBox()) - legend.restoreAnchor(self.__color_legend_anchor) + def update_legend_visibility(self): + """ + Show or hide legends based on whether they are enabled and non-empty + """ + self.shape_legend.setVisible( + self.show_legend and bool(self.shape_legend.items)) + self.color_legend.setVisible( + self.show_legend and bool(self.color_legend.items)) - label = PaletteItemSample(self.palette, self.scale) - legend.addItem(label, "") - legend.setGeometry(label.boundingRect()) + def update_legends(self): + """Update content of legends and their visibility""" + cont_color = self.master.is_continuous_color() + shape_labels = self.master.get_shape_labels() + color_labels = None if cont_color else self.master.get_color_labels() + if shape_labels == color_labels and shape_labels is not None: + self._update_combined_legend(shape_labels) else: - labels = self.master.get_color_labels() - if not labels or not self.palette: - return - use_shape = self.master.combined_legend() - for i, value in enumerate(labels): - color = QColor(*self.palette.getRGB(i)) - pen = _make_pen(color.darker(self.DarkerValue), 1.5) - color.setAlpha(255 if self.subset_is_shown else self.alpha_value) - brush = QBrush(color) - self.legend.addItem( - ScatterPlotItem( - pen=pen, brush=brush, size=10, - symbol=self.CurveSymbols[i] if use_shape else "o"), - escape(value)) - - def _make_shape_legend(self): - """ - Adds items representing the shapes to the legend - - If the color and shape legends are combined (checked by the widget's - method `combined_legends`), this method does 
nothing. - """ - if self.master.combined_legend(): - return - labels = self.master.get_shape_labels() - if labels is None: + self._update_shape_legend(shape_labels) + if cont_color: + self._update_continuous_color_legend() + else: + self._update_color_legend(color_labels) + self.update_legend_visibility() + + def _update_shape_legend(self, labels): + self.shape_legend.clear() + if labels is None or self.scatterplot_item is None: return color = QColor(0, 0, 0) color.setAlpha(self.alpha_value) - for i, value in enumerate(labels): - self.legend.addItem( - ScatterPlotItem(pen=color, brush=color, size=10, - symbol=self.CurveSymbols[i]), escape(value)) + for label, symbol in zip(labels, self.CurveSymbols): + self.shape_legend.addItem( + ScatterPlotItem(pen=color, brush=color, size=10, symbol=symbol), + escape(label)) + + def _update_continuous_color_legend(self): + self.color_legend.clear() + if self.scale is None or self.scatterplot_item is None: + return + label = PaletteItemSample(self.palette, self.scale) + self.color_legend.addItem(label, "") + self.color_legend.setGeometry(label.boundingRect()) + + def _update_color_legend(self, labels): + self.color_legend.clear() + if labels is None: + return + self._update_colored_legend(self.color_legend, labels, 'o') + + def _update_combined_legend(self, labels): + # update_colored_legend will already clear the shape legend + # so we remove colors here + use_legend = \ + self.shape_legend if self.shape_legend.items else self.color_legend + self.color_legend.clear() + self.shape_legend.clear() + self._update_colored_legend(use_legend, labels, self.CurveSymbols) + + def _update_colored_legend(self, legend, labels, symbols): + if self.scatterplot_item is None or not self.palette: + return + if isinstance(symbols, str): + symbols = itertools.repeat(symbols, times=len(labels)) + for i, (label, symbol) in enumerate(zip(labels, symbols)): + color = QColor(*self.palette.getRGB(i)) + pen = _make_pen(color.darker(self.DarkerValue), 1.5) 
+ color.setAlpha(255 if self.subset_is_shown else self.alpha_value) + brush = QBrush(color) + legend.addItem( + ScatterPlotItem(pen=pen, brush=brush, size=10, symbol=symbol), + escape(label)) def zoom_button_clicked(self): self.plot_widget.getViewBox().setMouseMode( @@ -1193,51 +1170,74 @@ def select_by_click(self, _, points): def select_by_rectangle(self, value_rect): if self.scatterplot_item is not None: - points = [point - for point in self.scatterplot_item.points() - if value_rect.contains(QPointF(point.pos()))] - self.select(points) - - def select_by_index(self, indices): - if self.scatterplot_item is not None: - points = [point for point in self.scatterplot_item.points() - if point.data() in indices] - self.select(points) + x0, y0 = value_rect.topLeft().x(), value_rect.topLeft().y() + x1, y1 = value_rect.bottomRight().x(), value_rect.bottomRight().y() + x, y = self.master.get_coordinates_data() + indices = np.flatnonzero( + (x0 <= x) & (x <= x1) & (y0 <= y) & (y <= y1)) + self.select_by_indices(indices.astype(int)) def unselect_all(self): - self.selection = None - self.select([]) - self.update_selection_colors() - if self.label_only_selected: - self.update_labels() - self.master.selection_changed() + if self.selection is not None: + self.selection = None + self.update_selection_colors() + if self.label_only_selected: + self.update_labels() + self.master.selection_changed() def select(self, points): # noinspection PyArgumentList - if self.n_points == 0: + if self.scatterplot_item is None: return - if self.selection is None: - self.selection = np.zeros(self.n_points, dtype=np.uint8) indices = [p.data() for p in points] + self.select_by_indices(indices) + + def select_by_indices(self, indices): + if self.selection is None: + self.selection = np.zeros(self.n_valid, dtype=np.uint8) keys = QApplication.keyboardModifiers() - # Remove from selection if keys & Qt.AltModifier: - self.selection[indices] = 0 - # Append to the last group + self.selection_remove(indices) 
elif keys & Qt.ShiftModifier and keys & Qt.ControlModifier: - self.selection[indices] = np.max(self.selection) - # Create a new group + self.selection_append(indices) elif keys & Qt.ShiftModifier: - self.selection[indices] = np.max(self.selection) + 1 - # No modifiers: new selection + self.selection_new_group(indices) else: - self.selection = np.zeros(self.n_points, dtype=np.uint8) - self.selection[indices] = 1 + self.selection_select(indices) + + def selection_select(self, indices): + self.selection = np.zeros(self.n_valid, dtype=np.uint8) + self.selection[indices] = 1 + self._update_after_selection() + + def selection_append(self, indices): + self.selection[indices] = np.max(self.selection) + self._update_after_selection() + + def selection_new_group(self, indices): + self.selection[indices] = np.max(self.selection) + 1 + self._update_after_selection() + + def selection_remove(self, indices): + self.selection[indices] = 0 + self._update_after_selection() + + def _update_after_selection(self): + self._compress_indices() self.update_selection_colors() if self.label_only_selected: self.update_labels() self.master.selection_changed() + def _compress_indices(self): + indices = sorted(set(self.selection) | {0}) + if len(indices) == max(indices) + 1: + return + mapping = np.zeros((max(indices) + 1,), dtype=int) + for i, ind in enumerate(indices): + mapping[ind] = i + self.selection = mapping[self.selection] + def get_selection(self): if self.selection is None: return np.array([], dtype=np.uint8) @@ -1281,345 +1281,3 @@ def __init__(self, delegate, parent=None): warnings.warn("HelpEventDelegate class has been deprecated since 3.17." " Use Orange.widgets.visualize.utils.plotutils." "HelpEventDelegate instead.", OrangeDeprecationWarning) - - -class OWProjectionWidget(OWWidget): - """ - Base widget for widgets that use attribute data to set the colors, labels, - shapes and sizes of points. 
- - The widgets defines settings `attr_color`, `attr_label`, `attr_shape` - and `attr_size`, but leaves defining the gui to the derived widgets. - These are expected to have controls that manipulate these settings, - and the controls are expected to use attribute models. - - The widgets also defines attributes `data` and `valid_data` and expects - the derived widgets to use them to store an instances of `data.Table` - and a bool `np.ndarray` with indicators of valid (that is, shown) - data points. - """ - attr_color = ContextSetting(None, required=ContextSetting.OPTIONAL) - attr_label = ContextSetting(None, required=ContextSetting.OPTIONAL) - attr_shape = ContextSetting(None, required=ContextSetting.OPTIONAL) - attr_size = ContextSetting(None, required=ContextSetting.OPTIONAL) - - class Information(OWWidget.Information): - missing_size = Msg( - "Points with undefined '{}' are shown in smaller size") - missing_shape = Msg( - "Points with undefined '{}' are shown as crossed circles") - - def __init__(self): - super().__init__() - self.data = None - self.valid_data = None - - self.set_palette() - - def init_attr_values(self): - """ - Set the models for `attr_color`, `attr_shape`, `attr_size` and - `attr_label`. All values are set to `None`, except `attr_color` - which is set to the class variable if it exists. - """ - data = self.data - domain = data.domain if data and len(data) else None - for attr in ("attr_color", "attr_shape", "attr_size", "attr_label"): - getattr(self.controls, attr).model().set_domain(domain) - setattr(self, attr, None) - if domain is not None: - self.attr_color = domain.class_var - - def get_coordinates_data(self): - """A get coordinated method that returns no coordinates. - - Derived classes must override this method. 
- """ - return None, None - - def get_subset_mask(self): - """ - Return the bool array indicating the points in the subset - - The base method does nothing and would usually be overridden by - a method that returns indicators from the subset signal. - - Do not confuse the subset with selection. - - Returns: - (np.ndarray or `None`): a bool array of indicators - """ - return None - - @staticmethod - def __get_overlap_groups(x, y): - coord_to_id = defaultdict(list) - for i, xy in enumerate(zip(x, y)): - coord_to_id[xy].append(i) - return coord_to_id - - def get_column(self, attr, filter_valid=True, - merge_infrequent=False, return_labels=False): - """ - Retrieve the data from the given column in the data table - - The method: - - densifies sparse data, - - converts arrays with dtype object to floats if the attribute is - actually primitive, - - filters out invalid data (if `filter_valid` is `True`), - - merges infrequent (discrete) values into a single value - (if `merge_infrequent` is `True`). - - Tha latter feature is used for shapes and labels, where only a - set number (`MAX`) of different values is shown, and others are - merged into category 'Other'. In this case, the method may return - either the data (e.g. color indices, shape indices) or the list - of retained values, followed by `['Other']`. - - Args: - attr (:obj:~Orange.data.Variable): the column to extract - filter_valid (bool): filter out invalid data (default: `True`) - merge_infrequent (bool): merge infrequent values (default: `False`) - return_labels (bool): return a list of labels instead of data - (default: `False`) - - Returns: - (np.ndarray): (valid) data from the column, or a list of labels - """ - if attr is None: - return None - all_data = self.data.get_column_view(attr)[0] - if sp.issparse(all_data): - all_data = all_data.toDense() # TODO -- just guessing; fix this! 
- elif all_data.dtype == object and attr.is_primitive(): - all_data = all_data.astype(float) - if filter_valid and self.valid_data is not None: - all_data = all_data[self.valid_data] - if not merge_infrequent or attr.is_continuous \ - or len(attr.values) <= MAX: - return attr.values if return_labels else all_data - dist = bincount(all_data, max_val=len(attr.values) - 1) - infrequent = np.zeros(len(attr.values), dtype=bool) - infrequent[np.argsort(dist[0])[:-(MAX-1)]] = True - # If discrete variable has more than maximium allowed values, - # less used values are joined as "Other" - if return_labels: - return [value for value, infreq in zip(attr.values, infrequent) - if not infreq] + ["Other"] - else: - result = all_data.copy() - freq_vals = [i for i, f in enumerate(infrequent) if not f] - for i, f in enumerate(infrequent): - result[all_data == i] = MAX - 1 if f else freq_vals.index(i) - return result - - # Sizes - def get_size_data(self): - """Return the column corresponding to `attr_size`""" - if self.attr_size == OWPlotGUI.SizeByOverlap: - x, y = self.get_coordinates_data() - coord_to_id = self.__get_overlap_groups(x, y) - overlaps = [len(coord_to_id[xy]) for xy in zip(x, y)] - return [1 + log2(o) for o in overlaps] - return self.get_column(self.attr_size) - - def impute_sizes(self, size_data): - """ - Default imputation for size data - - Missing values are replaced by `MinShapeSize - 2`. Imputation is - done in place. 
- - Args: - size_data (np.ndarray): scaled points sizes - """ - nans = np.isnan(size_data) - if np.any(nans): - size_data[nans] = self.graph.MinShapeSize - 2 - self.Information.missing_size(self.attr_size) - else: - self.Information.missing_size.clear() - - def sizes_changed(self): - self.graph.update_sizes() - self.graph.update_colors() - - # Colors - def get_color_data(self): - """Return the column corresponding to color data""" - colors = self.get_column(self.attr_color, merge_infrequent=True) - if self.attr_size == OWPlotGUI.SizeByOverlap: - # color overlapping points by most frequent color - x, y = self.get_coordinates_data() - coord_to_id = self.__get_overlap_groups(x, y) - majority_colors = np.empty(len(x)) - for i, xy in enumerate(zip(x, y)): - cnt = Counter(colors[j] for j in coord_to_id[xy]) - majority_colors[i] = cnt.most_common(1)[0][0] - return majority_colors - return colors - - def get_color_labels(self): - """ - Return labels for the color legend - - Returns: - (list of str): labels - """ - return self.get_column(self.attr_color, merge_infrequent=True, - return_labels=True) - - - def is_continuous_color(self): - """ - Tells whether the color is continuous - - Returns: - (bool): - """ - return self.attr_color is not None and self.attr_color.is_continuous - - def set_palette(self): - """ - Set the graph palette suitable for the current `attr_color` - - This method is invoked by the plot's `get_data` and must be overridden - if the widget offers coloring that is not based on attribute values. 
- """ - if self.attr_color is None: - self.graph.palette = None - return - colors = self.attr_color.colors - if self.attr_color.is_discrete: - self.graph.palette = ColorPaletteGenerator( - number_of_colors=min(len(colors), MAX), - rgb_colors=colors if len(colors) <= MAX - else DefaultRGBColors) - else: - self.graph.palette = ContinuousPaletteGenerator(*colors) - - def can_draw_density(self): - """ - Tells whether the current data and settings are suitable for drawing - densities - - Returns: - (bool): - """ - return self.data is not None and \ - self.data.domain is not None and \ - len(self.data) > 1 and \ - self.attr_color is not None - - def colors_changed(self): - self.graph.update_colors() - - # Labels - def get_label_data(self, formatter=None): - """Return the column corresponding to label data""" - if self.attr_label: - label_data = self.get_column(self.attr_label) - return map(formatter or self.attr_label.str_val, label_data) - - def labels_changed(self): - self.graph.update_labels() - - # Shapes - def get_shape_data(self): - """ - Return labels for the shape legend - - Returns: - (list of str): labels - """ - return self.get_column(self.attr_shape, merge_infrequent=True) - - def get_shape_labels(self): - return self.get_column(self.attr_shape, merge_infrequent=True, - return_labels=True) - - def impute_shapes(self, shape_data, default_symbol): - """ - Default imputation for shape data - - Missing values are replaced by `default_symbol`. Imputation is - done in place. 
- - Args: - shape_data (np.ndarray): scaled points sizes - default_symbol (str): a string representing the symbol - """ - if shape_data is None: - return 0 - nans = np.isnan(shape_data) - if np.any(nans): - shape_data[nans] = default_symbol - self.Information.missing_shape(self.attr_shape) - else: - self.Information.missing_shape.clear() - return shape_data - - def shapes_changed(self): - self.graph.update_shapes() - - # Tooltip - def _point_tooltip(self, point_id, skip_attrs=()): - def show_part(point_data, singular, plural, max_shown, vars): - cols = [escape('{} = {}'.format(var.name, point_data[var])) - for var in vars[:max_shown + 2] - if vars == domain.class_vars - or var not in skip_attrs][:max_shown] - if not cols: - return "" - n_vars = len(vars) - if n_vars > max_shown: - cols[-1] = "... and {} others".format(n_vars - max_shown + 1) - return \ - "{}:
".format(singular if n_vars < 2 else plural) \ - + "
".join(cols) - - domain = self.data.domain - parts = (("Class", "Classes", 4, domain.class_vars), - ("Meta", "Metas", 4, domain.metas), - ("Feature", "Features", 10, domain.attributes)) - - point_data = self.data[point_id] - return "
".join(show_part(point_data, *columns) - for columns in parts) - - def get_tooltip(self, point_ids): - """ - Return the tooltip string for the given points - - The method is called by the plot on mouse hover - - Args: - point_ids (list): indices into `data` - - Returns: - (str): - """ - text = "
".join(self._point_tooltip(point_id) - for point_id in point_ids[:MAX_POINTS_IN_TOOLTIP]) - if len(point_ids) > MAX_POINTS_IN_TOOLTIP: - text = "{} instances
{}
...".format(len(point_ids), text) - return text - - def keyPressEvent(self, event): - """Update the tip about using the modifier keys when selecting""" - super().keyPressEvent(event) - self.graph.update_tooltip(event.modifiers()) - - def keyReleaseEvent(self, event): - """Update the tip about using the modifier keys when selecting""" - super().keyReleaseEvent(event) - self.graph.update_tooltip(event.modifiers()) - - # Legend - def combined_legend(self): - """Tells whether the shape and color legends are combined into one""" - return self.attr_shape == self.attr_color - - def sizeHint(self): - return QSize(1132, 708) diff --git a/Orange/widgets/visualize/tests/test_owfreeviz.py b/Orange/widgets/visualize/tests/test_owfreeviz.py index a4d41de6603..0120a8c746c 100644 --- a/Orange/widgets/visualize/tests/test_owfreeviz.py +++ b/Orange/widgets/visualize/tests/test_owfreeviz.py @@ -1,19 +1,17 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring -import scipy.sparse as sp - -from AnyQt.QtCore import QRectF, QPointF +import numpy as np from Orange.data import Table from Orange.widgets.tests.base import ( - WidgetTest, WidgetOutputsTestMixin, ProjectionWidgetTestMixin + WidgetTest, WidgetOutputsTestMixin, AnchorProjectionWidgetTestMixin ) from Orange.widgets.tests.utils import simulate from Orange.widgets.visualize.owfreeviz import OWFreeViz -class TestOWFreeViz(WidgetTest, WidgetOutputsTestMixin, - ProjectionWidgetTestMixin): +class TestOWFreeViz(WidgetTest, AnchorProjectionWidgetTestMixin, + WidgetOutputsTestMixin): @classmethod def setUpClass(cls): super().setUpClass() @@ -34,20 +32,12 @@ def test_error_msg(self): self.send_signal(self.widget.Inputs.data, data) self.assertTrue(self.widget.Error.no_class_var.is_shown()) data = self.data[:40] - domain = self.data.domain.copy() - domain.class_var.values = self.data.domain.class_var.values[:1] - data = data.transform(domain) self.send_signal(self.widget.Inputs.data, data) 
self.assertTrue(self.widget.Error.not_enough_class_vars.is_shown()) self.send_signal(self.widget.Inputs.data, None) self.assertFalse(self.widget.Error.no_class_var.is_shown()) self.assertFalse(self.widget.Error.not_enough_class_vars.is_shown()) - def _select_data(self): - rect = QRectF(QPointF(-20, -20), QPointF(20, 20)) - self.widget.graph.select_by_rectangle(rect) - return self.widget.graph.get_selection() - def test_optimization(self): self.send_signal(self.widget.Inputs.data, self.data) self.widget.btn_start.click() @@ -62,14 +52,6 @@ def test_optimization_reset(self): simulate.combobox_activate_index(init, 0) simulate.combobox_activate_index(init, 1) - def test_sparse(self): - table = Table("iris") - table.X = sp.csr_matrix(table.X) - self.assertTrue(sp.issparse(table.X)) - self.assertFalse(self.widget.Error.sparse_data.is_shown()) - self.send_signal(self.widget.Inputs.data, table) - self.assertTrue(self.widget.Error.sparse_data.is_shown()) - def test_set_radius_no_data(self): """ Widget should not crash when there is no data and radius slider is moved. 
@@ -77,4 +59,16 @@ def test_set_radius_no_data(self): """ w = self.widget self.send_signal(w.Inputs.data, None) - self.widget.graph.controls.radius.setSliderPosition(3) + self.widget.graph.controls.hide_radius.setSliderPosition(3) + + def test_output_components(self): + self.send_signal(self.widget.Inputs.data, self.data) + components = self.get_output(self.widget.Outputs.components) + domain = components.domain + self.assertEqual(domain.attributes, self.data.domain.attributes) + self.assertEqual(domain.class_vars, ()) + self.assertEqual([m.name for m in domain.metas], ["component"]) + X = np.array([[1, 0, -1, 0], [0, 1, 0, -1]]).astype(float) + np.testing.assert_array_almost_equal(components.X, X) + metas = [["FreeViz 1"], ["FreeViz 2"]] + np.testing.assert_array_equal(components.metas, metas) diff --git a/Orange/widgets/visualize/tests/test_owlinearprojection.py b/Orange/widgets/visualize/tests/test_owlinearprojection.py index b4debc9e9e7..aa27c7226e7 100644 --- a/Orange/widgets/visualize/tests/test_owlinearprojection.py +++ b/Orange/widgets/visualize/tests/test_owlinearprojection.py @@ -1,21 +1,26 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring -import time -import warnings import numpy as np -from AnyQt.QtCore import QRectF, QPointF +from AnyQt.QtCore import QItemSelectionModel -from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable, StringVariable -from Orange.util import OrangeDeprecationWarning +from Orange.data import ( + Table, Domain, ContinuousVariable, DiscreteVariable, StringVariable +) from Orange.widgets.settings import Context -from Orange.widgets.visualize.owlinearprojection import OWLinearProjection -from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin, datasets -from Orange.widgets.tests.utils import EventSpy, excepthook_catch, simulate +from Orange.widgets.tests.base import ( + WidgetTest, WidgetOutputsTestMixin, datasets, + 
AnchorProjectionWidgetTestMixin +) +from Orange.widgets.tests.utils import simulate +from Orange.widgets.visualize.owlinearprojection import ( + OWLinearProjection, LinearProjectionVizRank +) from Orange.widgets.visualize.utils import Worker -class TestOWLinearProjection(WidgetTest, WidgetOutputsTestMixin): +class TestOWLinearProjection(WidgetTest, AnchorProjectionWidgetTestMixin, + WidgetOutputsTestMixin): @classmethod def setUpClass(cls): super().setUpClass() @@ -27,59 +32,20 @@ def setUpClass(cls): cls.projection_table = cls._get_projection_table() def setUp(self): - self._warnings = warnings.catch_warnings() - self._warnings.__enter__() - warnings.simplefilter("ignore", OrangeDeprecationWarning) self.widget = self.create_widget(OWLinearProjection) # type: OWLinearProjection - def tearDown(self): - super().tearDown() - self._warnings.__exit__() - - def test_deprecated_graph(self): - # Remove this test and lines 30 - 32 and 35 - 38, since the widget - # is not using deprecate class any more - with warnings.catch_warnings(): - warnings.simplefilter("error", OrangeDeprecationWarning) - self.assertRaises(OrangeDeprecationWarning, - lambda: self.create_widget(OWLinearProjection)) - - def _select_data(self): - self.widget.graph.select_by_rectangle(QRectF(QPointF(-20, -20), QPointF(20, 20))) - return self.widget.graph.get_selection() - - def test_no_data(self): - """Check that the widget doesn't crash on empty data""" - self.send_signal(self.widget.Inputs.data, Table(Table("iris").domain)) - def test_nan_plot(self): data = datasets.missing_data_1() - espy = EventSpy(self.widget, OWLinearProjection.ReplotRequest) - with excepthook_catch(): - self.send_signal(self.widget.Inputs.data, data) - # ensure delayed replot request is processed - if not espy.events(): - assert espy.wait(1000) - - cb = self.widget.graph.controls - simulate.combobox_run_through_all(cb.attr_color) - simulate.combobox_run_through_all(cb.attr_size) + self.send_signal(self.widget.Inputs.data, data) + 
simulate.combobox_run_through_all(self.widget.controls.attr_color) + simulate.combobox_run_through_all(self.widget.controls.attr_size) - data = data.copy() data.X[:, 0] = np.nan data.Y[:] = np.nan - - spy = EventSpy(self.widget, OWLinearProjection.ReplotRequest) self.send_signal(self.widget.Inputs.data, data) self.send_signal(self.widget.Inputs.data_subset, data[2:3]) - if not spy.events(): - assert spy.wait() - - with excepthook_catch(): - simulate.combobox_activate_item(cb.attr_color, "X1") - - with excepthook_catch(): - simulate.combobox_activate_item(cb.attr_size, "X1") + simulate.combobox_run_through_all(self.widget.controls.attr_color) + simulate.combobox_run_through_all(self.widget.controls.attr_size) def test_buttons(self): for btn in self.widget.radio_placement.buttons[:3]: @@ -87,14 +53,26 @@ def test_buttons(self): self.assertTrue(btn.isEnabled()) btn.click() - def test_rank(self): - self.send_signal(self.widget.Inputs.data, self.data) - self.widget.vizrank.button.click() - time.sleep(1) + def test_btn_vizrank(self): + def check_vizrank(data): + self.send_signal(self.widget.Inputs.data, data) + if data is not None and data.domain.class_var in \ + self.widget.controls.attr_color.model(): + self.widget.attr_color = data.domain.class_var + if self.widget.btn_vizrank.isEnabled(): + vizrank = LinearProjectionVizRank(self.widget) + states = [state for state in vizrank.iterate_states(None)] + self.assertIsNotNone(vizrank.compute_score(states[0])) + + check_vizrank(self.data) + check_vizrank(self.data[:, :3]) + check_vizrank(None) + for ds in datasets.datasets(): + check_vizrank(ds) @classmethod def _get_projection_table(cls): - domain = Domain(attributes=[ContinuousVariable("Attr {}".format(i)) for i in range(4)], + domain = Domain(cls.data.domain.attributes, metas=[StringVariable("Component")]) table = Table.from_numpy(domain, X=np.array([[0.522, -0.263, 0.581, 0.566], @@ -105,10 +83,11 @@ def _get_projection_table(cls): def test_projection(self): 
self.send_signal(self.widget.Inputs.data, self.data) self.assertFalse(self.widget.radio_placement.buttons[3].isEnabled()) - self.send_signal(self.widget.Inputs.projection, self.projection_table) + self.send_signal(self.widget.Inputs.projection_input, + self.projection_table) self.assertTrue(self.widget.radio_placement.buttons[3].isEnabled()) self.widget.radio_placement.buttons[3].click() - self.send_signal(self.widget.Inputs.projection, None) + self.send_signal(self.widget.Inputs.projection_input, None) self.assertFalse(self.widget.radio_placement.buttons[3].isChecked()) self.assertTrue(self.widget.radio_placement.buttons[0].isChecked()) @@ -118,9 +97,9 @@ def test_projection_error(self): table = Table.from_numpy(domain, X=np.array([[0.522, -0.263, 0.581, 0.566]]), metas=[["PC1"]]) - self.assertFalse(self.widget.Warning.not_enough_components.is_shown()) - self.send_signal(self.widget.Inputs.projection, table) - self.assertTrue(self.widget.Warning.not_enough_components.is_shown()) + self.assertFalse(self.widget.Warning.not_enough_comp.is_shown()) + self.send_signal(self.widget.Inputs.projection_input, table) + self.assertTrue(self.widget.Warning.not_enough_comp.is_shown()) def test_bad_data(self): w = self.widget @@ -134,38 +113,27 @@ def test_bad_data(self): self.assertFalse(w.radio_placement.buttons[1].isEnabled()) def test_no_data_for_lda(self): + buttons = self.widget.radio_placement.buttons self.send_signal(self.widget.Inputs.data, self.data) self.widget.radio_placement.buttons[self.widget.Placement.LDA].click() - self.assertTrue(self.widget.radio_placement.buttons[self.widget.Placement.LDA].isEnabled()) - data = Table("housing") - self.send_signal(self.widget.Inputs.data, data) - self.assertFalse(self.widget.radio_placement.buttons[self.widget.Placement.LDA].isEnabled()) + self.assertTrue(buttons[self.widget.Placement.LDA].isEnabled()) + self.send_signal(self.widget.Inputs.data, Table("housing")) + 
self.assertFalse(buttons[self.widget.Placement.LDA].isEnabled()) + self.send_signal(self.widget.Inputs.data, None) + self.assertTrue(buttons[self.widget.Placement.LDA].isEnabled()) def test_data_no_cont_features(self): data = Table("titanic") - self.assertFalse(self.widget.Warning.no_cont_features.is_shown()) + self.assertFalse(self.widget.Error.no_cont_features.is_shown()) self.send_signal(self.widget.Inputs.data, data) - self.assertTrue(self.widget.Warning.no_cont_features.is_shown()) + self.assertTrue(self.widget.Error.no_cont_features.is_shown()) self.send_signal(self.widget.Inputs.data, None) - self.assertFalse(self.widget.Warning.no_cont_features.is_shown()) - - def test_send_report(self): - self.send_signal(self.widget.Inputs.data, self.data) - self.widget.send_report() + self.assertFalse(self.widget.Error.no_cont_features.is_shown()) def test_radius(self): self.send_signal(self.widget.Inputs.data, self.data) self.widget.radio_placement.buttons[self.widget.Placement.LDA].click() - self.widget.rslider.setValue(5) - - def test_metas(self): - data = Table("iris") - domain = data.domain - domain = Domain(attributes=domain.attributes[:3], - class_vars=domain.class_vars, - metas=domain.attributes[3:]) - data = data.transform(domain) - self.send_signal(self.widget.Inputs.data, data) + self.widget.controls.graph.hide_radius.setValue(5) def test_invalid_data(self): def assertErrorShown(data, is_shown): @@ -209,9 +177,9 @@ def test_migrate_settings_from_version_1(self): iris = Table("iris") self.send_signal(w.Inputs.data, iris, widget=w) self.assertEqual(w.graph.point_width, 8) - self.assertEqual(w.graph.attr_color, iris.domain["iris"]) - self.assertEqual(w.graph.attr_shape, iris.domain["iris"]) - self.assertEqual(w.graph.attr_size, iris.domain["sepal length"]) + self.assertEqual(w.attr_color, iris.domain["iris"]) + self.assertEqual(w.attr_shape, iris.domain["iris"]) + self.assertEqual(w.attr_size, iris.domain["sepal length"]) def test_add_variables(self): w = 
self.widget @@ -223,7 +191,7 @@ def test_set_radius_no_data(self): """ w = self.widget self.send_signal(w.Inputs.data, None) - w.rslider.setSliderPosition(3) + w.controls.graph.hide_radius.setSliderPosition(3) class LinProjVizRankTests(WidgetTest): @@ -239,24 +207,9 @@ def setUpClass(cls): # cls.iris_no_class = Table(dom, cls.iris) def setUp(self): - self._warnings = warnings.catch_warnings() - self._warnings.__enter__() - warnings.simplefilter("ignore", OrangeDeprecationWarning) self.widget = self.create_widget(OWLinearProjection) self.vizrank = self.widget.vizrank - def tearDown(self): - super().tearDown() - self._warnings.__exit__() - - def test_deprecated_graph(self): - # Remove this test and lines 242 - 244 and 248 - 251, since the widget - # is not using deprecate class any more - with warnings.catch_warnings(): - warnings.simplefilter("error", OrangeDeprecationWarning) - self.assertRaises(OrangeDeprecationWarning, - lambda: self.create_widget(OWLinearProjection)) - def test_discrete_class(self): self.send_signal(self.widget.Inputs.data, self.data) worker = Worker(self.vizrank) @@ -269,3 +222,15 @@ def test_continuous_class(self): worker = Worker(self.vizrank) self.vizrank.keep_running = True worker.do_work() + + def test_set_attrs(self): + self.send_signal(self.widget.Inputs.data, self.data) + model_selected = self.widget.model_selected[:] + self.vizrank.toggle() + self.process_events(until=lambda: not self.vizrank.keep_running) + self.assertEqual(len(self.vizrank.scores), self.vizrank.state_count()) + self.vizrank.rank_table.selectionModel().select( + self.vizrank.rank_model.item(0, 0).index(), + QItemSelectionModel.ClearAndSelect + ) + self.assertNotEqual(self.widget.model_selected[:], model_selected) diff --git a/Orange/widgets/visualize/tests/test_owprojectionwidget.py b/Orange/widgets/visualize/tests/test_owprojectionwidget.py new file mode 100644 index 00000000000..cd4c6e0c51c --- /dev/null +++ 
b/Orange/widgets/visualize/tests/test_owprojectionwidget.py @@ -0,0 +1,146 @@ +# Test methods with long descriptive names can omit docstrings +# pylint: disable=missing-docstring +from unittest.mock import patch + +import numpy as np + +from Orange.data import ( + Table, ContinuousVariable, DiscreteVariable, StringVariable, Domain +) +from Orange.widgets.tests.base import ( + WidgetTest, WidgetOutputsTestMixin, ProjectionWidgetTestMixin +) +from Orange.widgets.visualize.utils.widget import ( + OWDataProjectionWidget, OWProjectionWidgetBase +) + + +class TestOWProjectionWidget(WidgetTest): + def setUp(self): + self.widget = self.create_widget(OWProjectionWidgetBase) + + def test_get_column(self): + widget = self.widget + get_column = widget.get_column + + cont = ContinuousVariable("cont") + disc = DiscreteVariable("disc", list("abcdefghijklmno")) + disc2 = DiscreteVariable("disc2", list("abc")) + disc3 = DiscreteVariable("disc3", list("abc")) + string = StringVariable("string") + domain = Domain([cont, disc], disc2, [disc3, string]) + + widget.data = Table.from_numpy( + domain, + np.array([[1, 4], [2, 15], [6, 7]], dtype=float), + np.array([2, 1, 0], dtype=float), + np.array([[0, "foo"], [2, "bar"], [1, "baz"]]) + ) + + self.assertIsNone(get_column(None)) + np.testing.assert_almost_equal(get_column(cont), [1, 2, 6]) + np.testing.assert_almost_equal(get_column(disc), [4, 15, 7]) + np.testing.assert_almost_equal(get_column(disc2), [2, 1, 0]) + np.testing.assert_almost_equal(get_column(disc3), [0, 2, 1]) + self.assertEqual(list(get_column(string)), ["foo", "bar", "baz"]) + + widget.valid_data = np.array([True, False, True]) + + self.assertIsNone(get_column(None)) + np.testing.assert_almost_equal(get_column(cont), [1, 6]) + self.assertEqual(list(get_column(string)), ["foo", "baz"]) + + self.assertIsNone(get_column(None, False)) + np.testing.assert_almost_equal(get_column(cont, False), [1, 2, 6]) + self.assertEqual(list(get_column(string, False)), ["foo", "bar", "baz"]) 
+ + self.assertIsNone(get_column(None, return_labels=True)) + self.assertEqual(get_column(disc, return_labels=True), disc.values) + self.assertEqual(get_column(disc2, return_labels=True), disc2.values) + self.assertEqual(get_column(disc3, return_labels=True), disc3.values) + with self.assertRaises(AssertionError): + get_column(cont, return_labels=True) + get_column(cont, return_labels=True, merge_infrequent=True) + get_column(cont, merge_infrequent=True) + with self.assertRaises(AssertionError): + get_column(string, return_labels=True) + get_column(string, return_labels=True, merge_infrequent=True) + get_column(string, merge_infrequent=True) + + @patch("Orange.widgets.visualize.utils.widget.MAX_CATEGORIES", 4) + def test_get_column_merge_infrequent(self): + widget = self.widget + get_column = widget.get_column + + disc = DiscreteVariable("disc", list("abcdefghijklmno")) + disc2 = DiscreteVariable("disc2", list("abc")) + domain = Domain([disc], disc2) + + x = np.array( + [1, 1, 1, 5, 4, 1, 1, 5, 8, 5, 5, 0, 0, 0, 4, 5, 10], dtype=float) + y = np.ones(len(x)) + widget.data = Table.from_numpy(domain, np.atleast_2d(x).T, y) + + np.testing.assert_almost_equal(get_column(disc), x) + self.assertEqual(get_column(disc, return_labels=True), disc.values) + np.testing.assert_almost_equal(get_column(disc2), y) + self.assertEqual(get_column(disc2, return_labels=True), disc2.values) + + np.testing.assert_almost_equal( + get_column(disc, merge_infrequent=True), + [1, 1, 1, 2, 3, 1, 1, 2, 3, 2, 2, 0, 0, 0, 3, 2, 3]) + self.assertEqual( + get_column(disc, merge_infrequent=True, return_labels=True), + [disc.values[0], disc.values[1], disc.values[5], "Other"]) + np.testing.assert_almost_equal( + get_column(disc2, merge_infrequent=True), y) + self.assertEqual( + get_column(disc2, return_labels=True, merge_infrequent=True), + disc2.values) + + # Test that get_columns modify a copy of the data and not the data + np.testing.assert_almost_equal(get_column(disc), x) + 
self.assertEqual(get_column(disc, return_labels=True), disc.values) + + +class TestableDataProjectionWidget(OWDataProjectionWidget): + def get_embedding(self): + self.valid_data = None + if self.data is None: + return None + + x_data = self.data.X.toarray() if self.data.is_sparse() \ + else self.data.X + self.valid_data = np.any(np.isfinite(x_data), 1) + if not len(x_data[self.valid_data]): + return None + + x_data[x_data == np.inf] = np.nan + x_data = np.nanmean(x_data[self.valid_data], 1) + y_data = np.ones(len(x_data)) + return np.vstack((x_data, y_data)).T + + +class TestOWDataProjectionWidget(WidgetTest, ProjectionWidgetTestMixin, + WidgetOutputsTestMixin): + @classmethod + def setUpClass(cls): + super().setUpClass() + WidgetOutputsTestMixin.init(cls) + + cls.signal_name = "Data" + cls.signal_data = cls.data + cls.same_input_output_domain = False + + def setUp(self): + self.widget = self.create_widget(TestableDataProjectionWidget) + + def test_saved_selection(self): + self.send_signal(self.widget.Inputs.data, self.data) + self.widget.graph.select_by_indices(list(range(0, len(self.data), 10))) + settings = self.widget.settingsHandler.pack_data(self.widget) + w = self.create_widget(TestableDataProjectionWidget, + stored_settings=settings) + self.send_signal(self.widget.Inputs.data, self.data, widget=w) + self.assertEqual(np.sum(w.graph.selection), 15) + np.testing.assert_equal(self.widget.graph.selection, w.graph.selection) diff --git a/Orange/widgets/visualize/tests/test_owradviz.py b/Orange/widgets/visualize/tests/test_owradviz.py index 62bae95f289..cd3718e6cbf 100644 --- a/Orange/widgets/visualize/tests/test_owradviz.py +++ b/Orange/widgets/visualize/tests/test_owradviz.py @@ -1,16 +1,17 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring -from AnyQt.QtCore import QRectF, QPointF +import numpy as np from Orange.data import Table, Domain from Orange.widgets.tests.base import ( - WidgetTest, 
WidgetOutputsTestMixin, ProjectionWidgetTestMixin + WidgetTest, WidgetOutputsTestMixin, + AnchorProjectionWidgetTestMixin, datasets ) -from Orange.widgets.visualize.owradviz import OWRadviz +from Orange.widgets.visualize.owradviz import OWRadviz, RadvizVizRank -class TestOWRadviz(WidgetTest, WidgetOutputsTestMixin, - ProjectionWidgetTestMixin): +class TestOWRadviz(WidgetTest, AnchorProjectionWidgetTestMixin, + WidgetOutputsTestMixin): @classmethod def setUpClass(cls): super().setUpClass() @@ -25,19 +26,21 @@ def setUp(self): self.widget = self.create_widget(OWRadviz) def test_btn_vizrank(self): - # TODO: fix this - w = self.widget - def assertEnabled(data, is_enabled): - self.send_signal(w.Inputs.data, data) - self.assertEqual(is_enabled, w.btn_vizrank.isEnabled()) - - data = self.data - for data, is_enabled in zip([data[:, :3], data, None], [False, False, False]): - assertEnabled(data, is_enabled) + def check_vizrank(data): + self.send_signal(self.widget.Inputs.data, data) + if data is not None and data.domain.class_var in \ + self.widget.controls.attr_color.model(): + self.widget.attr_color = data.domain.class_var + if self.widget.btn_vizrank.isEnabled(): + vizrank = RadvizVizRank(self.widget) + states = [state for state in vizrank.iterate_states(None)] + self.assertIsNotNone(vizrank.compute_score(states[0])) - def _select_data(self): - self.widget.graph.select_by_rectangle(QRectF(QPointF(-20, -20), QPointF(20, 20))) - return self.widget.graph.get_selection() + check_vizrank(self.data) + check_vizrank(self.data[:, :3]) + check_vizrank(None) + for ds in datasets.datasets(): + check_vizrank(ds) def test_no_features(self): w = self.widget @@ -56,3 +59,29 @@ def test_not_enough_instances(self): self.assertTrue(w.Error.no_instances.is_shown()) self.send_signal(w.Inputs.data, self.data) self.assertFalse(w.Error.no_instances.is_shown()) + + def test_saved_features(self): + self.send_signal(self.widget.Inputs.data, self.data) + self.widget.model_selected.pop(0) + 
settings = self.widget.settingsHandler.pack_data(self.widget) + w = self.create_widget(OWRadviz, stored_settings=settings) + self.send_signal(self.widget.Inputs.data, self.data, widget=w) + selected = [a.name for a in self.widget.model_selected] + self.assertListEqual(selected, [a.name for a in w.model_selected]) + self.send_signal(self.widget.Inputs.data, self.heart_disease) + selected = [a.name for a in self.widget.model_selected] + names = [a.name for a in self.heart_disease.domain.attributes[:5]] + self.assertListEqual(selected, names) + + def test_output_components(self): + self.send_signal(self.widget.Inputs.data, self.data) + components = self.get_output(self.widget.Outputs.components) + domain = components.domain + self.assertEqual(domain.attributes, self.data.domain.attributes) + self.assertEqual(domain.class_vars, ()) + self.assertEqual([m.name for m in domain.metas], ["component"]) + X = np.array([[1, 0, -1, 0], [0, 1, 0, -1], + [0, 1.57, 3.14, -1.57]]) + np.testing.assert_array_almost_equal(components.X, X, 2) + metas = [["RX"], ["RY"], ["angle"]] + np.testing.assert_array_equal(components.metas, metas) diff --git a/Orange/widgets/visualize/tests/test_owscatterplot.py b/Orange/widgets/visualize/tests/test_owscatterplot.py index 629d12d29be..156ae3bbd1a 100644 --- a/Orange/widgets/visualize/tests/test_owscatterplot.py +++ b/Orange/widgets/visualize/tests/test_owscatterplot.py @@ -2,7 +2,6 @@ # pylint: disable=missing-docstring from unittest.mock import MagicMock, patch import numpy as np -import scipy.sparse as sp from AnyQt.QtCore import QRectF, Qt from AnyQt.QtWidgets import QToolTip @@ -15,12 +14,12 @@ from Orange.widgets.visualize.owscatterplot import ( OWScatterPlot, ScatterPlotVizRank ) -from Orange.widgets.visualize.owscatterplotgraph import MAX +from Orange.widgets.visualize.utils.widget import MAX_CATEGORIES from Orange.widgets.widget import AttributeList -class TestOWScatterPlot(WidgetTest, WidgetOutputsTestMixin, - ProjectionWidgetTestMixin): 
+class TestOWScatterPlot(WidgetTest, ProjectionWidgetTestMixin, + WidgetOutputsTestMixin): @classmethod def setUpClass(cls): super().setUpClass() @@ -33,11 +32,6 @@ def setUpClass(cls): def setUp(self): self.widget = self.create_widget(OWScatterPlot) - def _compare_selected_annotated_domains(self, selected, annotated): - # Base class tests that selected.domain is a subset of annotated.domain - # In scatter plot, the two domains are unrelated, so we disable the test - pass - def test_set_data(self): # Connect iris to scatter plot self.send_signal(self.widget.Inputs.data, self.data) @@ -62,7 +56,7 @@ def test_set_data(self): self.assertIsNone(self.widget.attr_color) # and remove the legend - self.assertIsNone(self.widget.graph.color_legend) + self.assertEqual(len(self.widget.graph.color_legend.items), 0) # Connect iris again # same attributes that were used last time should be selected @@ -94,10 +88,6 @@ def test_optional_combos(self): t2 = Table(d2, self.data) self.send_signal(self.widget.Inputs.data, t2) - def _select_data(self): - self.widget.graph.select_by_rectangle(QRectF(4, 3, 3, 1)) - return self.widget.graph.get_selection() - def test_error_message(self): """Check if error message appears and then disappears when data is removed from input""" @@ -262,7 +252,8 @@ def test_saving_selection(self): self.widget.graph.select_by_rectangle(QRectF(4, 3, 3, 1)) selected_inds = np.flatnonzero(self.widget.graph.selection) settings = self.widget.settingsHandler.pack_data(self.widget) - np.testing.assert_equal(selected_inds, [i for i, g in settings["selection_group"]]) + np.testing.assert_equal(selected_inds, + [i for i, g in settings["selection"]]) def test_points_selection(self): # Opening widget with saved selection should restore it @@ -316,21 +307,6 @@ def test_set_strings_settings(self): self.assertEqual(w.attr_shape.name, "iris") self.assertEqual(w.attr_size.name, "petal width") - def test_sparse(self): - """ - Test sparse data. 
- GH-2152 - GH-2157 - """ - table = Table("iris").to_sparse(sparse_attributes=True, - sparse_class=True) - self.send_signal(self.widget.Inputs.data, table) - self.widget.set_subset_data(table[:30]) - data = self.get_output("Data") - - self.assertTrue(data.is_sparse()) - self.assertEqual(len(data.domain), 5) - def test_features_and_no_data(self): """ Prevent crashing when features are sent but no data. @@ -416,12 +392,13 @@ def test_auto_send_selection(self): """ data = Table("iris") self.send_signal(self.widget.Inputs.data, data) - self.widget.controls.auto_send_selection.setChecked(False) - self.assertEqual(False, self.widget.controls.auto_send_selection.isChecked()) + self.widget.controls.auto_commit.setChecked(False) + self.assertFalse(self.widget.controls.auto_commit.isChecked()) self._select_data() self.assertIsNone(self.get_output(self.widget.Outputs.selected_data)) - self.widget.controls.auto_send_selection.setChecked(True) - self.assertIsInstance(self.get_output(self.widget.Outputs.selected_data), Table) + self.widget.controls.auto_commit.setChecked(True) + output = self.get_output(self.widget.Outputs.selected_data) + self.assertIsInstance(output, Table) def test_color_is_optional(self): zoo = Table("zoo") @@ -506,18 +483,6 @@ def test_subset_data(self): self.send_signal(w.Inputs.data_subset, data[::30]) self.assertEqual(len(w.subset_indices), 5) - def test_sparse_subset_data(self): - """ - Scatter Plot can handle sparse subset data. - GH-2773 - """ - data = Table("iris") - w = self.widget - data.X = sp.csr_matrix(data.X) - self.send_signal(w.Inputs.data, data) - self.send_signal(w.Inputs.data_subset, data[::30]) - self.assertEqual(len(w.subset_indices), 5) - def test_metas_zero_column(self): """ Prevent crash when metas column is zero. 
@@ -606,11 +571,11 @@ def assert_equal(data, max): pen_data, brush_data = self.widget.graph.get_colors() self.assertEqual(max, len(np.unique([id(p) for p in pen_data])), ) - assert_equal(prepare_data(), MAX) + assert_equal(prepare_data(), MAX_CATEGORIES) # data with nan value data = prepare_data() data.Y[42] = np.nan - assert_equal(data, MAX + 1) + assert_equal(data, MAX_CATEGORIES + 1) if __name__ == "__main__": diff --git a/Orange/widgets/visualize/tests/test_owscatterplotbase.py b/Orange/widgets/visualize/tests/test_owscatterplotbase.py new file mode 100644 index 00000000000..e966a4fc664 --- /dev/null +++ b/Orange/widgets/visualize/tests/test_owscatterplotbase.py @@ -0,0 +1,961 @@ +# Test methods with long descriptive names can omit docstrings +# pylint: disable=missing-docstring +from unittest.mock import patch, Mock +import numpy as np + +from AnyQt.QtCore import QRectF, Qt +from AnyQt.QtGui import QColor + +from pyqtgraph import mkPen + +from Orange.widgets.tests.base import GuiTest +from Orange.widgets.utils.colorpalette import ColorPaletteGenerator, \ + ContinuousPaletteGenerator, NAN_GREY +from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase, \ + ScatterPlotItem, SELECTION_WIDTH +from Orange.widgets.widget import OWWidget + + +class MockWidget(OWWidget): + name = "Mock" + + get_coordinates_data = Mock(return_value=(None, None)) + get_size_data = Mock(return_value=None) + get_shape_data = Mock(return_value=None) + get_color_data = Mock(return_value=None) + get_label_data = Mock(return_value=None) + get_color_labels = Mock(return_value=None) + get_shape_labels = Mock(return_value=None) + get_subset_mask = Mock(return_value=None) + get_tooltip = Mock(return_value="") + + is_continuous_color = Mock(return_value=False) + can_draw_density = Mock(return_value=True) + combined_legend = Mock(return_value=False) + selection_changed = Mock(return_value=None) + + def get_palette(self): + if self.is_continuous_color(): + return 
ContinuousPaletteGenerator(Qt.white, Qt.black, False) + else: + return ColorPaletteGenerator(12) + + +class TestOWScatterPlotBase(GuiTest): + def setUp(self): + self.master = MockWidget() + self.graph = OWScatterPlotBase(self.master) + + self.xy = (np.arange(10, dtype=float), np.arange(10, dtype=float)) + self.master.get_coordinates_data = lambda: self.xy + + # pylint: disable=keyword-arg-before-vararg + def setRange(self, rect=None, *_, **__): + if isinstance(rect, QRectF): + self.last_setRange = [[rect.left(), rect.right()], + [rect.top(), rect.bottom()]] + + def test_update_coordinates_no_data(self): + self.xy = None, None + self.graph.reset_graph() + self.assertIsNone(self.graph.scatterplot_item) + self.assertIsNone(self.graph.scatterplot_item_sel) + + self.xy = [], [] + self.graph.reset_graph() + self.assertIsNone(self.graph.scatterplot_item) + self.assertIsNone(self.graph.scatterplot_item_sel) + + def test_update_coordinates(self): + graph = self.graph + xy = self.xy = (np.array([1, 2]), np.array([3, 4])) + graph.reset_graph() + + scatterplot_item = graph.scatterplot_item + scatterplot_item_sel = graph.scatterplot_item_sel + data = scatterplot_item.data + + np.testing.assert_almost_equal(scatterplot_item.getData(), xy) + np.testing.assert_almost_equal(scatterplot_item_sel.getData(), xy) + scatterplot_item.setSize([5, 6]) + scatterplot_item.setSymbol([7, 8]) + scatterplot_item.setPen([mkPen(9), mkPen(10)]) + scatterplot_item.setBrush([11, 12]) + data["data"] = np.array([13, 14]) + + xy[0][0] = 0 + graph.update_coordinates() + np.testing.assert_almost_equal(graph.scatterplot_item.getData(), xy) + np.testing.assert_almost_equal(graph.scatterplot_item_sel.getData(), xy) + + # Graph updates coordinates instead of creating new items + self.assertIs(scatterplot_item, graph.scatterplot_item) + self.assertIs(scatterplot_item_sel, graph.scatterplot_item_sel) + np.testing.assert_almost_equal(data["size"], [5, 6]) + np.testing.assert_almost_equal(data["symbol"], [7, 8]) 
+ self.assertEqual(data["pen"][0], mkPen(9)) + self.assertEqual(data["pen"][1], mkPen(10)) + np.testing.assert_almost_equal(data["brush"], [11, 12]) + np.testing.assert_almost_equal(data["data"], [13, 14]) + + def test_update_coordinates_and_labels(self): + graph = self.graph + xy = self.xy = (np.array([1, 2]), np.array([3, 4])) + self.master.get_label_data = lambda: ["a", "b"] + graph.reset_graph() + self.assertEqual(graph.labels[0].pos().x(), 1) + xy[0][0] = 0 + graph.update_coordinates() + self.assertEqual(graph.labels[0].pos().x(), 0) + + def test_update_coordinates_and_density(self): + graph = self.graph + xy = self.xy = (np.array([1, 2]), np.array([3, 4])) + self.master.get_label_data = lambda: ["a", "b"] + graph.reset_graph() + self.assertEqual(graph.labels[0].pos().x(), 1) + xy[0][0] = 0 + graph.update_density = Mock() + graph.update_coordinates() + graph.update_density.assert_called_with() + + def test_update_coordinates_reset_view(self): + graph = self.graph + graph.view_box.setRange = self.setRange + xy = self.xy = (np.array([2, 1]), np.array([3, 10])) + self.master.get_label_data = lambda: ["a", "b"] + graph.reset_graph() + self.assertEqual(self.last_setRange, [[1, 2], [3, 10]]) + + xy[0][1] = 0 + graph.update_coordinates() + self.assertEqual(self.last_setRange, [[0, 2], [3, 10]]) + + def test_reset_graph_no_data(self): + self.xy = (None, None) + self.graph.scatterplot_item = ScatterPlotItem([1, 2], [3, 4]) + self.graph.reset_graph() + self.assertIsNone(self.graph.scatterplot_item) + self.assertIsNone(self.graph.scatterplot_item_sel) + + def test_update_coordinates_indices(self): + graph = self.graph + self.xy = (np.array([2, 1]), np.array([3, 10])) + graph.reset_graph() + np.testing.assert_almost_equal( + graph.scatterplot_item.data["data"], [0, 1]) + + def test_sampling(self): + graph = self.graph + master = self.master + + # Enable sampling before getting the data + graph.set_sample_size(3) + xy = self.xy = (np.arange(10, dtype=float), + np.arange(0, 
30, 3, dtype=float)) + d = np.arange(10, dtype=float) + master.get_size_data = lambda: d + master.get_shape_data = lambda: d + master.get_color_data = lambda: d + master.get_label_data = lambda: \ + np.array([str(x) for x in d], dtype=object) + graph.reset_graph() + + # Check proper sampling + scatterplot_item = graph.scatterplot_item + x, y = scatterplot_item.getData() + self.assertEqual(len(x), 3) + self.assertNotEqual(x[0], x[1]) + self.assertNotEqual(x[0], x[2]) + self.assertNotEqual(x[1], x[2]) + np.testing.assert_almost_equal(3 * x, y) + + data = scatterplot_item.data + s0, s1, s2 = data["size"] - graph.MinShapeSize + np.testing.assert_almost_equal( + (s2 - s1) / (s1 - s0), + (x[2] - x[1]) / (x[1] - x[0])) + self.assertEqual( + list(data["symbol"]), + [graph.CurveSymbols[int(xi)] for xi in x]) + self.assertEqual( + [pen.color().hue() for pen in data["pen"]], + [graph.palette[xi].hue() for xi in x]) + self.assertEqual( + [label.textItem.toPlainText() for label in graph.labels], + [str(xi) for xi in x]) + + # Check that sample is extended when sample size is changed + graph.set_sample_size(4) + scatterplot_item = graph.scatterplot_item + x, y = scatterplot_item.getData() + data = scatterplot_item.data + s0, s1, s2, s3 = data["size"] - graph.MinShapeSize + np.testing.assert_almost_equal( + (s2 - s1) / (s1 - s0), + (x[2] - x[1]) / (x[1] - x[0])) + np.testing.assert_almost_equal( + (s2 - s1) / (s1 - s3), + (x[2] - x[1]) / (x[1] - x[3])) + self.assertEqual( + list(data["symbol"]), + [graph.CurveSymbols[int(xi)] for xi in x]) + self.assertEqual( + [pen.color().hue() for pen in data["pen"]], + [graph.palette[xi].hue() for xi in x]) + self.assertEqual( + [label.textItem.toPlainText() for label in graph.labels], + [str(xi) for xi in x]) + + # Disable sampling + graph.set_sample_size(None) + scatterplot_item = graph.scatterplot_item + x, y = scatterplot_item.getData() + data = scatterplot_item.data + np.testing.assert_almost_equal(x, xy[0]) + 
np.testing.assert_almost_equal(y, xy[1]) + self.assertEqual( + list(data["symbol"]), + [graph.CurveSymbols[int(xi)] for xi in d]) + self.assertEqual( + [pen.color().hue() for pen in data["pen"]], + [graph.palette[xi].hue() for xi in d]) + self.assertEqual( + [label.textItem.toPlainText() for label in graph.labels], + [str(xi) for xi in d]) + + # Enable sampling when data is already present and not sampled + graph.set_sample_size(3) + scatterplot_item = graph.scatterplot_item + x, y = scatterplot_item.getData() + data = scatterplot_item.data + s0, s1, s2 = data["size"] - graph.MinShapeSize + np.testing.assert_almost_equal( + (s2 - s1) / (s1 - s0), + (x[2] - x[1]) / (x[1] - x[0])) + self.assertEqual( + list(data["symbol"]), + [graph.CurveSymbols[int(xi)] for xi in x]) + self.assertEqual( + [pen.color().hue() for pen in data["pen"]], + [graph.palette[xi].hue() for xi in x]) + self.assertEqual( + [label.textItem.toPlainText() for label in graph.labels], + [str(xi) for xi in x]) + + # Update data when data is present and sampling is enabled + xy[0][:] = np.arange(9, -1, -1, dtype=float) + d = xy[0] + graph.update_coordinates() + x1, _ = scatterplot_item.getData() + np.testing.assert_almost_equal(9 - x, x1) + graph.update_sizes() + data = scatterplot_item.data + s0, s1, s2 = data["size"] - graph.MinShapeSize + np.testing.assert_almost_equal( + (s2 - s1) / (s1 - s0), + (x[2] - x[1]) / (x[1] - x[0])) + + # Reset graph when data is present and sampling is enabled + self.xy = (np.arange(100, 105, dtype=float), + np.arange(100, 105, dtype=float)) + d = self.xy[0] - 100 + graph.reset_graph() + scatterplot_item = graph.scatterplot_item + x, y = scatterplot_item.getData() + self.assertEqual(len(x), 3) + self.assertTrue(np.all(x > 99)) + data = scatterplot_item.data + s0, s1, s2 = data["size"] - graph.MinShapeSize + np.testing.assert_almost_equal( + (s2 - s1) / (s1 - s0), + (x[2] - x[1]) / (x[1] - x[0])) + + # Don't sample when unnecessary + self.xy = (np.arange(100, 
dtype=float), ) * 2 + d = None + delattr(master, "get_label_data") + graph.reset_graph() + graph.set_sample_size(120) + scatterplot_item = graph.scatterplot_item + x, y = scatterplot_item.getData() + np.testing.assert_almost_equal(x, np.arange(100)) + + def test_sampling_keeps_selection(self): + graph = self.graph + + self.xy = (np.arange(100, dtype=float), + np.arange(100, dtype=float)) + graph.reset_graph() + graph.select_by_indices(np.arange(1, 100, 2)) + graph.set_sample_size(30) + np.testing.assert_almost_equal(graph.selection, np.arange(100) % 2) + graph.set_sample_size(None) + np.testing.assert_almost_equal(graph.selection, np.arange(100) % 2) + + base = "Orange.widgets.visualize.owscatterplotgraph.OWScatterPlotBase." + + @patch(base + "update_sizes") + @patch(base + "update_colors") + @patch(base + "update_selection_colors") + @patch(base + "update_shapes") + @patch(base + "update_labels") + def test_reset_calls_all_updates_and_update_doesnt(self, *mocks): + master = MockWidget() + graph = OWScatterPlotBase(master) + for mock in mocks: + mock.assert_not_called() + + graph.reset_graph() + for mock in mocks: + mock.assert_called_with() + mock.reset_mock() + + graph.update_coordinates() + for mock in mocks: + mock.assert_not_called() + + def test_jittering(self): + graph = self.graph + graph.jitter_size = 10 + graph.reset_graph() + scatterplot_item = graph.scatterplot_item + x, y = scatterplot_item.getData() + a10 = np.arange(10) + self.assertTrue(np.any(np.nonzero(a10 - x))) + self.assertTrue(np.any(np.nonzero(a10 - y))) + np.testing.assert_array_less(a10 - x, 1) + np.testing.assert_array_less(a10 - y, 1) + + graph.jitter_size = 0 + graph.update_coordinates() + scatterplot_item = graph.scatterplot_item + x, y = scatterplot_item.getData() + np.testing.assert_equal(a10, x) + + def test_size_normalization(self): + graph = self.graph + + self.master.get_size_data = lambda: d + d = np.arange(10, dtype=float) + + graph.reset_graph() + scatterplot_item = 
graph.scatterplot_item + size = scatterplot_item.data["size"] + diffs = [round(y - x, 2) for x, y in zip(size, size[1:])] + self.assertEqual(len(set(diffs)), 1) + self.assertGreater(diffs[0], 0) + + d = np.arange(10, 20, dtype=float) + graph.update_sizes() + self.assertIs(scatterplot_item, graph.scatterplot_item) + size = scatterplot_item.data["size"] + diffs2 = [round(y - x, 2) for x, y in zip(size, size[1:])] + self.assertEqual(diffs, diffs2) + + def test_size_with_nans(self): + graph = self.graph + + self.master.get_size_data = lambda: d + d = np.arange(10, dtype=float) + + graph.reset_graph() + scatterplot_item = graph.scatterplot_item + sizes = scatterplot_item.data["size"] + + d[4] = np.nan + graph.update_sizes() + sizes2 = scatterplot_item.data["size"] + + self.assertEqual(sizes[1] - sizes[0], sizes2[1] - sizes2[0]) + self.assertLess(sizes2[4], self.graph.MinShapeSize) + + d[:] = np.nan + graph.update_sizes() + sizes3 = scatterplot_item.data["size"] + np.testing.assert_almost_equal(sizes, sizes3) + + def test_sizes_all_same_or_nan(self): + graph = self.graph + + self.master.get_size_data = lambda: d + d = np.full((10, ), 3.0) + + graph.reset_graph() + scatterplot_item = graph.scatterplot_item + sizes = scatterplot_item.data["size"] + self.assertEqual(len(set(sizes)), 1) + self.assertGreater(sizes[0], self.graph.MinShapeSize) + + d = None + graph.update_sizes() + scatterplot_item = graph.scatterplot_item + sizes2 = scatterplot_item.data["size"] + np.testing.assert_almost_equal(sizes, sizes2) + + def test_sizes_point_width_is_linear(self): + graph = self.graph + + self.master.get_size_data = lambda: d + d = np.arange(10, dtype=float) + + graph.point_width = 1 + graph.reset_graph() + sizes1 = graph.scatterplot_item.data["size"] + + graph.point_width = 2 + graph.update_sizes() + sizes2 = graph.scatterplot_item.data["size"] + + graph.point_width = 3 + graph.update_sizes() + sizes3 = graph.scatterplot_item.data["size"] + + np.testing.assert_almost_equal(2 * 
(sizes2 - sizes1), sizes3 - sizes1) + + def test_sizes_custom_imputation(self): + + def impute_max(size_data): + size_data[np.isnan(size_data)] = np.nanmax(size_data) + + graph = self.graph + + self.master.get_size_data = lambda: d + self.master.impute_sizes = impute_max + d = np.arange(10, dtype=float) + d[4] = np.nan + graph.reset_graph() + sizes = graph.scatterplot_item.data["size"] + self.assertAlmostEqual(sizes[4], sizes[9]) + + def test_sizes_selection(self): + graph = self.graph + graph.get_size = lambda: np.arange(10, dtype=float) + graph.reset_graph() + np.testing.assert_almost_equal( + graph.scatterplot_item_sel.data["size"] + - graph.scatterplot_item.data["size"], + SELECTION_WIDTH) + + def test_colors_discrete(self): + self.master.is_continuous_color = lambda: False + palette = self.master.get_palette() + graph = self.graph + + self.master.get_color_data = lambda: d + d = np.arange(10, dtype=float) % 2 + + graph.reset_graph() + self.assertTrue( + all(pen.color().hue() is palette[i % 2].hue() + for i, pen in enumerate(graph.scatterplot_item.data["pen"]))) + self.assertTrue( + all(pen.color().hue() is palette[i % 2].hue() + for i, pen in enumerate(graph.scatterplot_item.data["brush"]))) + + def test_colors_discrete_nan(self): + self.master.is_continuous_color = lambda: False + palette = self.master.get_palette() + graph = self.graph + + d = np.arange(10, dtype=float) % 2 + d[4] = np.nan + self.master.get_color_data = lambda: d + graph.reset_graph() + pens = graph.scatterplot_item.data["pen"] + brushes = graph.scatterplot_item.data["brush"] + self.assertEqual(pens[0].color().hue(), palette[0].hue()) + self.assertEqual(pens[1].color().hue(), palette[1].hue()) + self.assertEqual(brushes[0].color().hue(), palette[0].hue()) + self.assertEqual(brushes[1].color().hue(), palette[1].hue()) + self.assertEqual(pens[4].color().hue(), QColor(128, 128, 128).hue()) + self.assertEqual(brushes[4].color().hue(), QColor(128, 128, 128).hue()) + + def 
test_colors_continuous(self): + self.master.is_continuous_color = lambda: True + graph = self.graph + + d = np.arange(10, dtype=float) + self.master.get_color_data = lambda: d + graph.reset_graph() # I don't have a good test ... just don't crash + + d[4] = np.nan + graph.update_colors() # Ditto + + def test_colors_continuous_nan(self): + self.master.is_continuous_color = lambda: True + graph = self.graph + + d = np.arange(10, dtype=float) % 2 + d[4] = np.nan + self.master.get_color_data = lambda: d + graph.reset_graph() + pens = graph.scatterplot_item.data["pen"] + brushes = graph.scatterplot_item.data["brush"] + nan_color = QColor(*NAN_GREY) + self.assertEqual(pens[4].color().hue(), nan_color.hue()) + self.assertEqual(brushes[4].color().hue(), nan_color.hue()) + + def test_colors_subset(self): + def run_tests(): + self.master.get_subset_mask = lambda: None + + graph.alpha_value = 42 + graph.reset_graph() + brushes = graph.scatterplot_item.data["brush"] + self.assertEqual(brushes[0].color().alpha(), 42) + self.assertEqual(brushes[1].color().alpha(), 42) + self.assertEqual(brushes[4].color().alpha(), 42) + + graph.alpha_value = 123 + graph.update_colors() + brushes = graph.scatterplot_item.data["brush"] + self.assertEqual(brushes[0].color().alpha(), 123) + self.assertEqual(brushes[1].color().alpha(), 123) + self.assertEqual(brushes[4].color().alpha(), 123) + + self.master.get_subset_mask = lambda: np.arange(10) >= 5 + graph.update_colors() + brushes = graph.scatterplot_item.data["brush"] + self.assertEqual(brushes[0].color().alpha(), 0) + self.assertEqual(brushes[1].color().alpha(), 0) + self.assertEqual(brushes[4].color().alpha(), 0) + self.assertEqual(brushes[5].color().alpha(), 255) + self.assertEqual(brushes[6].color().alpha(), 255) + self.assertEqual(brushes[7].color().alpha(), 255) + + graph = self.graph + + self.master.get_color_data = lambda: None + self.master.is_continuous_color = lambda: True + graph.reset_graph() + run_tests() + + 
self.master.is_continuous_color = lambda: False + graph.reset_graph() + run_tests() + + d = np.arange(10, dtype=float) % 2 + d[4:6] = np.nan + self.master.get_color_data = lambda: d + + self.master.is_continuous_color = lambda: True + graph.reset_graph() + run_tests() + + self.master.is_continuous_color = lambda: False + graph.reset_graph() + run_tests() + + def test_colors_none(self): + graph = self.graph + graph.reset_graph() + hue = QColor(128, 128, 128).hue() + + data = graph.scatterplot_item.data + self.assertTrue(all(pen.color().hue() == hue for pen in data["pen"])) + self.assertTrue(all(pen.color().hue() == hue for pen in data["brush"])) + + self.master.get_subset_mask = lambda: np.arange(10) < 5 + graph.update_colors() + data = graph.scatterplot_item.data + self.assertTrue(all(pen.color().hue() == hue for pen in data["pen"])) + self.assertTrue(all(pen.color().hue() == hue for pen in data["brush"])) + + def test_colors_update_legend_and_density(self): + graph = self.graph + graph.update_legends = Mock() + graph.update_density = Mock() + graph.reset_graph() + graph.update_legends.assert_called_with() + graph.update_density.assert_called_with() + + graph.update_legends.reset_mock() + graph.update_density.reset_mock() + + graph.update_coordinates() + graph.update_legends.assert_not_called() + + graph.update_colors() + graph.update_legends.assert_called_with() + graph.update_density.assert_called_with() + + def test_selection_colors(self): + graph = self.graph + graph.reset_graph() + data = graph.scatterplot_item_sel.data + + # One group + graph.select_by_indices(np.array([0, 1, 2, 3])) + graph.update_selection_colors() + pens = data["pen"] + for i in range(4): + self.assertNotEqual(pens[i].style(), Qt.NoPen) + for i in range(4, 10): + self.assertEqual(pens[i].style(), Qt.NoPen) + + # Two groups + with patch("AnyQt.QtWidgets.QApplication.keyboardModifiers", + lambda: Qt.ShiftModifier): + graph.select_by_indices(np.array([4, 5, 6])) + + 
graph.update_selection_colors() + pens = data["pen"] + for i in range(7): + self.assertNotEqual(pens[i].style(), Qt.NoPen) + for i in range(7, 10): + self.assertEqual(pens[i].style(), Qt.NoPen) + self.assertEqual(len({pen.color().hue() for pen in pens[:4]}), 1) + self.assertEqual(len({pen.color().hue() for pen in pens[4:7]}), 1) + color1 = pens[3].color().hue() + color2 = pens[4].color().hue() + self.assertNotEqual(color1, color2) + + # Two groups + sampling + graph.set_sample_size(7) + x = graph.scatterplot_item.getData()[0] + pens = graph.scatterplot_item_sel.data["pen"] + for xi, pen in zip(x, pens): + if xi < 4: + self.assertEqual(pen.color().hue(), color1) + elif xi < 7: + self.assertEqual(pen.color().hue(), color2) + else: + self.assertEqual(pen.style(), Qt.NoPen) + + def test_density(self): + graph = self.graph + density = object() + with patch("Orange.widgets.utils.classdensity.class_density_image", + return_value=density): + graph.reset_graph() + self.assertIsNone(graph.density_img) + + graph.plot_widget.addItem = Mock() + graph.plot_widget.removeItem = Mock() + + graph.class_density = True + graph.update_colors() + self.assertIsNone(graph.density_img) + + d = np.ones((10, ), dtype=float) + self.master.get_color_data = lambda: d + graph.update_colors() + self.assertIsNone(graph.density_img) + + d = np.arange(10) % 2 + graph.update_colors() + self.assertIs(graph.density_img, density) + self.assertIs(graph.plot_widget.addItem.call_args[0][0], density) + + graph.class_density = False + graph.update_colors() + self.assertIsNone(graph.density_img) + self.assertIs(graph.plot_widget.removeItem.call_args[0][0], density) + + graph.class_density = True + graph.update_colors() + self.assertIs(graph.density_img, density) + self.assertIs(graph.plot_widget.addItem.call_args[0][0], density) + + graph.update_coordinates = lambda: (None, None) + graph.reset_graph() + self.assertIsNone(graph.density_img) + self.assertIs(graph.plot_widget.removeItem.call_args[0][0], density) 
+ + def test_labels(self): + graph = self.graph + graph.reset_graph() + + self.assertEqual(graph.labels, []) + + self.master.get_label_data = lambda: \ + np.array([str(x) for x in range(10)], dtype=object) + graph.update_labels() + self.assertEqual( + [label.textItem.toPlainText() for label in graph.labels], + [str(i) for i in range(10)]) + + # Label only selected + selected = [1, 3, 5] + graph.select_by_indices(selected) + self.graph.label_only_selected = True + graph.update_labels() + self.assertEqual( + [label.textItem.toPlainText() for label in graph.labels], + [str(x) for x in selected]) + x, y = graph.scatterplot_item.getData() + for i, index in enumerate(selected): + self.assertEqual(x[index], graph.labels[i].x()) + self.assertEqual(y[index], graph.labels[i].y()) + + # Disable label only selected + self.graph.label_only_selected = False + graph.update_labels() + self.assertEqual( + [label.textItem.toPlainText() for label in graph.labels], + [str(i) for i in range(10)]) + x, y = graph.scatterplot_item.getData() + for xi, yi, label in zip(x, y, graph.labels): + self.assertEqual(xi, label.x()) + self.assertEqual(yi, label.y()) + + # Label only selected + sampling + selected = [1, 3, 4, 5, 6, 7, 9] + graph.select_by_indices(selected) + self.graph.label_only_selected = True + graph.update_labels() + graph.set_sample_size(5) + for label in graph.labels: + ind = int(label.textItem.toPlainText()) + self.assertIn(ind, selected) + self.assertEqual(label.x(), x[ind]) + self.assertEqual(label.y(), y[ind]) + + def test_labels_update_coordinates(self): + graph = self.graph + self.master.get_label_data = lambda: \ + np.array([str(x) for x in range(10)], dtype=object) + + graph.reset_graph() + graph.set_sample_size(7) + x, y = graph.scatterplot_item.getData() + for xi, yi, label in zip(x, y, graph.labels): + self.assertEqual(xi, label.x()) + self.assertEqual(yi, label.y()) + + self.master.get_coordinates_data = \ + lambda: (np.arange(10, 20), np.arange(50, 60)) + 
graph.update_coordinates() + x, y = graph.scatterplot_item.getData() + for xi, yi, label in zip(x, y, graph.labels): + self.assertEqual(xi, label.x()) + self.assertEqual(yi, label.y()) + + def test_shapes(self): + graph = self.graph + + self.master.get_shape_data = lambda: d + d = np.arange(10, dtype=float) % 3 + + graph.reset_graph() + scatterplot_item = graph.scatterplot_item + symbols = scatterplot_item.data["symbol"] + self.assertTrue(all(symbol == graph.CurveSymbols[i % 3] + for i, symbol in enumerate(symbols))) + + d = np.arange(10, dtype=float) % 2 + graph.update_shapes() + symbols = scatterplot_item.data["symbol"] + self.assertTrue(all(symbol == graph.CurveSymbols[i % 2] + for i, symbol in enumerate(symbols))) + + d = None + graph.update_shapes() + symbols = scatterplot_item.data["symbol"] + self.assertEqual(len(set(symbols)), 1) + + def test_shapes_nan(self): + graph = self.graph + + self.master.get_shape_data = lambda: d + d = np.arange(10, dtype=float) % 3 + d[2] = np.nan + + graph.reset_graph() + self.assertEqual(graph.scatterplot_item.data["symbol"][2], '?') + + d[:] = np.nan + graph.update_shapes() + self.assertTrue( + all(symbol == '?' 
+ for symbol in graph.scatterplot_item.data["symbol"])) + + def impute0(data, _): + data[np.isnan(data)] = 0 + + self.master.impute_shapes = impute0 + d = np.arange(10, dtype=float) % 3 + d[2] = np.nan + graph.update_shapes() + self.assertEqual(graph.scatterplot_item.data["symbol"][2], + graph.CurveSymbols[0]) + + def test_show_grid(self): + graph = self.graph + show_grid = self.graph.plot_widget.showGrid = Mock() + graph.show_grid = False + graph.update_grid_visibility() + self.assertEqual(show_grid.call_args[1], dict(x=False, y=False)) + + graph.show_grid = True + graph.update_grid_visibility() + self.assertEqual(show_grid.call_args[1], dict(x=True, y=True)) + + def test_show_legend(self): + graph = self.graph + graph.reset_graph() + + shape_legend = self.graph.shape_legend.setVisible = Mock() + color_legend = self.graph.color_legend.setVisible = Mock() + shape_labels = color_labels = None # Avoid pylint warning + self.master.get_shape_labels = lambda: shape_labels + self.master.get_color_labels = lambda: color_labels + for shape_labels in (None, ["a", "b"]): + for color_labels in (None, ["c", "d"], None): + for visible in (True, False, True): + graph.show_legend = visible + graph.update_legends() + self.assertIs( + shape_legend.call_args[0][0], + visible and bool(shape_labels), + msg="error at {}, {}".format(visible, shape_labels)) + self.assertIs( + color_legend.call_args[0][0], + visible and bool(color_labels), + msg="error at {}, {}".format(visible, color_labels)) + + def test_show_legend_no_data(self): + graph = self.graph + self.master.get_shape_labels = lambda: ["a", "b"] + self.master.get_color_labels = lambda: ["c", "d"] + self.master.get_shape_data = lambda: np.arange(10) % 2 + self.master.get_color_data = lambda: np.arange(10) < 6 + graph.reset_graph() + + shape_legend = self.graph.shape_legend.setVisible = Mock() + color_legend = self.graph.color_legend.setVisible = Mock() + self.master.get_coordinates_data = lambda: (None, None) + graph.reset_graph() 
+ self.assertFalse(shape_legend.call_args[0][0]) + self.assertFalse(color_legend.call_args[0][0]) + + def test_legend_combine(self): + master = self.master + graph = self.graph + graph.reset_graph() + + shape_legend = self.graph.shape_legend.setVisible = Mock() + color_legend = self.graph.color_legend.setVisible = Mock() + + master.get_shape_labels = lambda: ["a", "b"] + master.get_color_labels = lambda: ["c", "d"] + graph.update_legends() + self.assertTrue(shape_legend.call_args[0][0]) + self.assertTrue(color_legend.call_args[0][0]) + + master.get_color_labels = lambda: ["a", "b"] + graph.update_legends() + self.assertTrue(shape_legend.call_args[0][0]) + self.assertFalse(color_legend.call_args[0][0]) + self.assertEqual(len(graph.shape_legend.items), 2) + + master.is_continuous_color = lambda: True + master.get_color_data = lambda: np.arange(10, dtype=float) + graph.update_colors() + self.assertTrue(shape_legend.call_args[0][0]) + self.assertTrue(color_legend.call_args[0][0]) + self.assertEqual(len(graph.shape_legend.items), 2) + + def test_select_by_click(self): + graph = self.graph + graph.reset_graph() + points = graph.scatterplot_item.points() + graph.select_by_click(None, [points[2]]) + np.testing.assert_almost_equal(graph.get_selection(), [2]) + with patch("AnyQt.QtWidgets.QApplication.keyboardModifiers", + lambda: Qt.ShiftModifier): + graph.select_by_click(None, points[3:6]) + np.testing.assert_almost_equal( + list(graph.get_selection()), [2, 3, 4, 5]) + np.testing.assert_almost_equal( + graph.selection, [0, 0, 1, 2, 2, 2, 0, 0, 0, 0]) + + def test_select_by_rectangle(self): + graph = self.graph + coords = np.array( + [(x, y) for y in range(10) for x in range(10)], dtype=float).T + self.master.get_coordinates_data = lambda: coords + + graph.reset_graph() + graph.select_by_rectangle(QRectF(3, 5, 3.9, 2.9)) + self.assertTrue( + all(selected == (3 <= coords[0][i] <= 6 and 5 <= coords[1][i] <= 7) + for i, selected in enumerate(graph.selection))) + + def 
test_select_by_indices(self): + graph = self.graph + graph.reset_graph() + graph.label_only_selected = True + + def select(modifiers, indices): + with patch("AnyQt.QtWidgets.QApplication.keyboardModifiers", + lambda: modifiers): + graph.update_selection_colors = Mock() + graph.update_labels = Mock() + self.master.selection_changed = Mock() + + graph.select_by_indices(np.array(indices)) + graph.update_selection_colors.assert_called_with() + if graph.label_only_selected: + graph.update_labels.assert_called_with() + else: + graph.update_labels.assert_not_called() + self.master.selection_changed.assert_called_with() + + select(0, [7, 8, 9]) + np.testing.assert_almost_equal( + graph.selection, [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]) + + select(Qt.ShiftModifier | Qt.ControlModifier, [5, 6]) + np.testing.assert_almost_equal( + graph.selection, [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + + select(Qt.ShiftModifier, [3, 4, 5]) + np.testing.assert_almost_equal( + graph.selection, [0, 0, 0, 2, 2, 2, 1, 1, 1, 1]) + + select(Qt.AltModifier, [1, 3, 7]) + np.testing.assert_almost_equal( + graph.selection, [0, 0, 0, 0, 2, 2, 1, 0, 1, 1]) + + select(0, [1, 8]) + np.testing.assert_almost_equal( + graph.selection, [0, 1, 0, 0, 0, 0, 0, 0, 1, 0]) + + graph.label_only_selected = False + select(0, [3, 4]) + + def test_unselect_all(self): + graph = self.graph + graph.reset_graph() + graph.label_only_selected = True + + graph.select_by_indices([3, 4, 5]) + np.testing.assert_almost_equal( + graph.selection, [0, 0, 0, 1, 1, 1, 0, 0, 0, 0]) + + graph.update_selection_colors = Mock() + graph.update_labels = Mock() + self.master.selection_changed = Mock() + + graph.unselect_all() + self.assertIsNone(graph.selection) + graph.update_selection_colors.assert_called_with() + graph.update_labels.assert_called_with() + self.master.selection_changed.assert_called_with() + + graph.update_selection_colors.reset_mock() + graph.update_labels.reset_mock() + self.master.selection_changed.reset_mock() + + graph.unselect_all() 
+ self.assertIsNone(graph.selection) + graph.update_selection_colors.assert_not_called() + graph.update_labels.assert_not_called() + self.master.selection_changed.assert_not_called() + + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/Orange/widgets/visualize/utils/__init__.py b/Orange/widgets/visualize/utils/__init__.py index 4543584ddfb..e2f42328d9c 100644 --- a/Orange/widgets/visualize/utils/__init__.py +++ b/Orange/widgets/visualize/utils/__init__.py @@ -116,12 +116,14 @@ def __init__(self, master): self.setFocus(Qt.ActiveWindowFocusReason) self.rank_model = QStandardItemModel(self) - self.model_proxy = QSortFilterProxyModel(self) + self.model_proxy = QSortFilterProxyModel( + self, filterCaseSensitivity=False) self.model_proxy.setSourceModel(self.rank_model) self.rank_table = view = QTableView( selectionBehavior=QTableView.SelectRows, selectionMode=QTableView.SingleSelection, - showGrid=False) + showGrid=False, + editTriggers=gui.TableView.NoEditTriggers) if self._has_bars: view.setItemDelegate(TableBarItem()) else: diff --git a/Orange/widgets/visualize/utils/component.py b/Orange/widgets/visualize/utils/component.py index 8e6a36b30aa..a34cdd811f9 100644 --- a/Orange/widgets/visualize/utils/component.py +++ b/Orange/widgets/visualize/utils/component.py @@ -1,46 +1,54 @@ """Common gui.OWComponent components.""" -from AnyQt.QtCore import Qt +from AnyQt.QtCore import Qt, QRectF +from AnyQt.QtGui import QColor +from AnyQt.QtWidgets import QGraphicsEllipseItem +import pyqtgraph as pg from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase from Orange.widgets.visualize.utils.plotutils import ( - MouseEventDelegate, VizInteractiveViewBox + MouseEventDelegate, DraggableItemsViewBox ) -class OWVizGraph(OWScatterPlotBase): - """Class is used as a graph base class for OWFreeViz and OWRadviz.""" +class OWGraphWithAnchors(OWScatterPlotBase): + """ + Graph for projections in which dimensions can be manually moved + + Class is 
used as a graph base class for OWFreeViz and OWRadviz.""" DISTANCE_DIFF = 0.08 - def __init__(self, scatter_widget, parent, view_box=VizInteractiveViewBox): + def __init__(self, scatter_widget, parent, view_box=DraggableItemsViewBox): super().__init__(scatter_widget, parent, view_box) - self._attributes = () - self._points = None - self._point_items = None - self._circle_item = None - self._indicator_item = None + self.anchor_items = None + self.circle_item = None + self.indicator_item = None self._tooltip_delegate = MouseEventDelegate(self.help_event, self.show_indicator_event) self.plot_widget.scene().installEventFilter(self._tooltip_delegate) - self.view_box.sigResized.connect(self.update_density) - - def set_attributes(self, attributes): - self._attributes = attributes - - def set_point(self, i, x, y): - self._points[i][0] = x - self._points[i][1] = y - def set_points(self, points): - self._points = points - - def get_points(self): - return self._points + def clear(self): + super().clear() + self.anchor_items = None + self.circle_item = None + self.indicator_item = None def update_coordinates(self): super().update_coordinates() - self.update_items() - self.set_view_box_range() - self.view_box.setAspectLocked(True, 1) + if self.scatterplot_item is not None: + self.update_anchors() + self.update_circle() + self.set_view_box_range() + self.view_box.setAspectLocked(True, 1) + + def update_anchors(self): + raise NotImplementedError + + def update_circle(self): + if self.scatterplot_item is not None and not self.circle_item: + self.circle_item = QGraphicsEllipseItem() + self.circle_item.setRect(QRectF(-1, -1, 2, 2)) + self.circle_item.setPen(pg.mkPen(QColor(0, 0, 0), width=2)) + self.plot_widget.addItem(self.circle_item) def reset_button_clicked(self): self.set_view_box_range() @@ -48,11 +56,11 @@ def reset_button_clicked(self): def set_view_box_range(self): raise NotImplementedError - def can_show_indicator(self, pos): - raise NotImplementedError + def 
closest_draggable_item(self, pos): + return None - def show_indicator(self, point_i): - self._update_indicator_item(point_i) + def show_indicator(self, anchor_idx): + self._update_indicator_item(anchor_idx) def show_indicator_event(self, ev): scene = self.plot_widget.scene() @@ -62,54 +70,22 @@ def show_indicator_event(self, ev): return True pos = self.scatterplot_item.mapFromScene(ev.scenePos()) - can_show, point_i = self.can_show_indicator(pos) - if can_show: - self._update_indicator_item(point_i) + anchor_idx = self.closest_draggable_item(pos) + if anchor_idx is not None: + self._update_indicator_item(anchor_idx) if self.view_box.mouse_state == 0: self.view_box.setCursor(Qt.OpenHandCursor) else: - if self._indicator_item is not None: - self.plot_widget.removeItem(self._indicator_item) - self._indicator_item = None + if self.indicator_item is not None: + self.plot_widget.removeItem(self.indicator_item) + self.indicator_item = None self.view_box.setCursor(Qt.ArrowCursor) return True - def update_items(self): - self._update_point_items() - self._update_circle_item() - - def _update_point_items(self): - self._remove_point_items() - self._add_point_items() + def _update_indicator_item(self, anchor_idx): + if self.indicator_item is not None: + self.plot_widget.removeItem(self.indicator_item) + self._add_indicator_item(anchor_idx) - def _update_circle_item(self): - self._remove_circle_item() - self._add_circle_item() - - def _update_indicator_item(self, point_i): - self._remove_indicator_item() - self._add_indicator_item(point_i) - - def _remove_point_items(self): - if self._point_items is not None: - self.plot_widget.removeItem(self._point_items) - self._point_items = None - - def _remove_circle_item(self): - if self._circle_item is not None: - self.plot_widget.removeItem(self._circle_item) - self._circle_item = None - - def _remove_indicator_item(self): - if self._indicator_item is not None: - self.plot_widget.removeItem(self._indicator_item) - self._indicator_item = 
None - - def _add_point_items(self): - raise NotImplementedError - - def _add_circle_item(self): - raise NotImplementedError - - def _add_indicator_item(self, point_i): - raise NotImplementedError + def _add_indicator_item(self, anchor_idx): + pass diff --git a/Orange/widgets/visualize/utils/plotutils.py b/Orange/widgets/visualize/utils/plotutils.py index 355463ea81e..6f024090742 100644 --- a/Orange/widgets/visualize/utils/plotutils.py +++ b/Orange/widgets/visualize/utils/plotutils.py @@ -21,6 +21,10 @@ def setAnchor(self, anchor): self.anchor = pg.Point(anchor) self.updateText() + def get_xy(self): + point = self.pos() + return point.x(), point.y() + class AnchorItem(pg.GraphicsObject): def __init__(self, parent=None, line=QLineF(), text="", **kwargs): @@ -38,11 +42,15 @@ def __init__(self, parent=None, line=QLineF(), text="", **kwargs): self._label = TextItem(text=text, color=(10, 10, 10)) self._label.setParentItem(self) - self._label.setPos(self._spine.line().p2()) + self._label.setPos(*self.get_xy()) if parent is not None: self.setParentItem(parent) + def get_xy(self): + point = self._spine.line().p2() + return point.x(), point.y() + def setText(self, text): if text != self._text: self._text = text @@ -236,26 +244,41 @@ def gestureEvent(self, event): return True -class VizInteractiveViewBox(InteractiveViewBox): - started = Signal() - moved = Signal() - finished = Signal() +class DraggableItemsViewBox(InteractiveViewBox): + """ + A viewbox with draggable items + + Graph that uses it must provide two methods: + - `closest_draggable_item(pos)` returns an int representing the id of the + draggable item that is closest (and close enough) to `QPoint` pos, or + `None`; + - `show_indicator(item_id)` shows or updates an indicator for moving + the item with the given `item_id`. 
+ + Viewbox emits three signals: + - `started = Signal(item_id)` + - `moved = Signal(item_id, x, y)` + - `finished = Signal(item_id, x, y)` + """ + started = Signal(int) + moved = Signal(int, float, float) + finished = Signal(int, float, float) def __init__(self, graph, enable_menu=False): self.mouse_state = 0 - self.point_i = None + self.item_id = None super().__init__(graph, enable_menu) def mousePressEvent(self, ev): super().mousePressEvent(ev) pos = self.childGroup.mapFromParent(ev.pos()) - if self.graph.can_show_indicator(pos)[0]: + if self.graph.closest_draggable_item(pos) is not None: self.setCursor(Qt.ClosedHandCursor) def mouseDragEvent(self, ev, axis=None): pos = self.childGroup.mapFromParent(ev.pos()) - can_show, point_i = self.graph.can_show_indicator(pos) - if ev.button() != Qt.LeftButton or (ev.start and not can_show): + item_id = self.graph.closest_draggable_item(pos) + if ev.button() != Qt.LeftButton or (ev.start and item_id is None): self.mouse_state = 2 if self.mouse_state == 2: if ev.finish: @@ -267,19 +290,18 @@ def mouseDragEvent(self, ev, axis=None): if ev.start: self.setCursor(Qt.ClosedHandCursor) self.mouse_state = 1 - self.point_i = point_i - self.started.emit() + self.item_id = item_id + self.started.emit(self.item_id) if self.mouse_state == 1: - self.graph.set_point(self.point_i, pos.x(), pos.y()) if ev.finish: - self.setCursor(Qt.OpenHandCursor) self.mouse_state = 0 - self.finished.emit() + self.finished.emit(self.item_id, pos.x(), pos.y()) + if self.graph.closest_draggable_item(pos) is not None: + self.setCursor(Qt.OpenHandCursor) + else: + self.setCursor(Qt.ArrowCursor) + self.item_id = None else: - self._show_tooltip(ev) - self.moved.emit() - self.graph.show_indicator(self.point_i) - - def _show_tooltip(self, ev): - pass + self.moved.emit(self.item_id, pos.x(), pos.y()) + self.graph.show_indicator(self.item_id) diff --git a/Orange/widgets/visualize/utils/widget.py b/Orange/widgets/visualize/utils/widget.py new file mode 100644 index 
00000000000..968548c68d5 --- /dev/null +++ b/Orange/widgets/visualize/utils/widget.py @@ -0,0 +1,671 @@ +from collections import Counter, defaultdict +from xml.sax.saxutils import escape +from math import log2 + +import numpy as np + +from AnyQt.QtCore import QSize +from AnyQt.QtWidgets import QApplication + +from Orange.data import Table, ContinuousVariable, Domain, Variable +from Orange.data.sql.table import SqlTable +from Orange.statistics.util import bincount + +from Orange.widgets import gui, report +from Orange.widgets.settings import ( + Setting, ContextSetting, DomainContextHandler, SettingProvider +) +from Orange.widgets.utils.annotated_data import ( + create_annotated_table, ANNOTATED_DATA_SIGNAL_NAME, create_groups_table, + get_unique_names +) +from Orange.widgets.utils.colorpalette import ( + ColorPaletteGenerator, ContinuousPaletteGenerator, DefaultRGBColors +) +from Orange.widgets.utils.plot import OWPlotGUI +from Orange.widgets.utils.sql import check_sql_input +from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase +from Orange.widgets.visualize.utils.component import OWGraphWithAnchors +from Orange.widgets.widget import OWWidget, Input, Output, Msg + +MAX_CATEGORIES = 11 # maximum number of colors or shapes (including Other) +MAX_POINTS_IN_TOOLTIP = 5 + + +class OWProjectionWidgetBase(OWWidget): + """ + Base widget for widgets that use attribute data to set the colors, labels, + shapes and sizes of points. + + The widget defines settings `attr_color`, `attr_label`, `attr_shape` + and `attr_size`, but leaves defining the gui to the derived widgets. + These are expected to have controls that manipulate these settings, + and the controls are expected to use attribute models. + + The widget also defines attributes `data` and `valid_data` and expects + the derived widgets to use them to store an instance of `data.Table` + and a bool `np.ndarray` with indicators of valid (that is, shown) + data points. 
+ """ + attr_color = ContextSetting(None, required=ContextSetting.OPTIONAL) + attr_label = ContextSetting(None, required=ContextSetting.OPTIONAL) + attr_shape = ContextSetting(None, required=ContextSetting.OPTIONAL) + attr_size = ContextSetting(None, required=ContextSetting.OPTIONAL) + + class Information(OWWidget.Information): + missing_size = Msg( + "Points with undefined '{}' are shown in smaller size") + missing_shape = Msg( + "Points with undefined '{}' are shown as crossed circles") + + def __init__(self): + super().__init__() + self.data = None + self.valid_data = None + + def init_attr_values(self): + """ + Set the models for `attr_color`, `attr_shape`, `attr_size` and + `attr_label`. All values are set to `None`, except `attr_color` + which is set to the class variable if it exists. + """ + data = self.data + domain = data.domain if data and len(data) else None + for attr in ("attr_color", "attr_shape", "attr_size", "attr_label"): + getattr(self.controls, attr).model().set_domain(domain) + setattr(self, attr, None) + if domain is not None: + self.attr_color = domain.class_var + + def get_coordinates_data(self): + """A placeholder method that returns no coordinates. + + Derived classes must override this method. + """ + return None, None + + def get_subset_mask(self): + """ + Return the bool array indicating the points in the subset + + The base method does nothing and would usually be overridden by + a method that returns indicators from the subset signal. + + Do not confuse the subset with selection. 
+ + Returns: + (np.ndarray or `None`): a bool array of indicators + """ + return None + + @staticmethod + def __get_overlap_groups(x, y): + coord_to_id = defaultdict(list) + for i, xy in enumerate(zip(x, y)): + coord_to_id[xy].append(i) + return coord_to_id + + def get_column(self, attr, filter_valid=True, + merge_infrequent=False, return_labels=False): + """ + Retrieve the data from the given column in the data table + + The method: + - densifies sparse data, + - converts arrays with dtype object to floats if the attribute is + actually primitive, + - filters out invalid data (if `filter_valid` is `True`), + - merges infrequent (discrete) values into a single value + (if `merge_infrequent` is `True`). + + The latter feature is used for colors and shapes, where only a + set number (`MAX_CATEGORIES`) of different values is shown, and others are + merged into category 'Other'. In this case, the method may return + either the data (e.g. color indices, shape indices) or the list + of retained values, followed by `['Other']`. 
+ + Args: + attr (:obj:~Orange.data.Variable): the column to extract + filter_valid (bool): filter out invalid data (default: `True`) + merge_infrequent (bool): merge infrequent values (default: `False`); + ignored for non-discrete attributes + return_labels (bool): return a list of labels instead of data + (default: `False`) + + Returns: + (np.ndarray): (valid) data from the column, or a list of labels + """ + if attr is None: + return None + + needs_merging = \ + attr.is_discrete \ + and merge_infrequent and len(attr.values) >= MAX_CATEGORIES + if return_labels and not needs_merging: + assert attr.is_discrete + return attr.values + + all_data = self.data.get_column_view(attr)[0] + if all_data.dtype == object and attr.is_primitive(): + all_data = all_data.astype(float) + if filter_valid and self.valid_data is not None: + all_data = all_data[self.valid_data] + if not needs_merging: + return all_data + + dist = bincount(all_data, max_val=len(attr.values) - 1)[0] + infrequent = np.zeros(len(attr.values), dtype=bool) + infrequent[np.argsort(dist)[:-(MAX_CATEGORIES-1)]] = True + if return_labels: + return [value for value, infreq in zip(attr.values, infrequent) + if not infreq] + ["Other"] + else: + result = all_data.copy() + freq_vals = [i for i, f in enumerate(infrequent) if not f] + for i, infreq in enumerate(infrequent): + if infreq: + result[all_data == i] = MAX_CATEGORIES - 1 + else: + result[all_data == i] = freq_vals.index(i) + return result + + # Sizes + def get_size_data(self): + """Return the column corresponding to `attr_size`""" + if self.attr_size == OWPlotGUI.SizeByOverlap: + x, y = self.get_coordinates_data() + coord_to_id = self.__get_overlap_groups(x, y) + overlaps = [len(coord_to_id[xy]) for xy in zip(x, y)] + return [1 + log2(o) for o in overlaps] + return self.get_column(self.attr_size) + + def impute_sizes(self, size_data): + """ + Default imputation for size data + + Let the graph handle it, but add a warning if needed. 
+ + Args: + size_data (np.ndarray): scaled points sizes + """ + if self.graph.default_impute_sizes(size_data): + self.Information.missing_size(self.attr_size) + else: + self.Information.missing_size.clear() + + def sizes_changed(self): + self.graph.update_sizes() + self.graph.update_colors() # Needed for overlapping + + # Colors + def get_color_data(self): + """Return the column corresponding to color data""" + colors = self.get_column(self.attr_color, merge_infrequent=True) + if self.attr_size == OWPlotGUI.SizeByOverlap: + # color overlapping points by most frequent color + x, y = self.get_coordinates_data() + coord_to_id = self.__get_overlap_groups(x, y) + majority_colors = np.empty(len(x)) + for i, xy in enumerate(zip(x, y)): + cnt = Counter(colors[j] for j in coord_to_id[xy]) + majority_colors[i] = cnt.most_common(1)[0][0] + return majority_colors + return colors + + def get_color_labels(self): + """ + Return labels for the color legend + + Returns: + (list of str): labels + """ + return self.get_column(self.attr_color, merge_infrequent=True, + return_labels=True) + + def is_continuous_color(self): + """ + Tells whether the color is continuous + + Returns: + (bool): + """ + return self.attr_color is not None and self.attr_color.is_continuous + + def get_palette(self): + """ + Return a palette suitable for the current `attr_color` + + This method must be overridden if the widget offers coloring that is + not based on attribute values. 
+ """ + if self.attr_color is None: + return None + colors = self.attr_color.colors + if self.attr_color.is_discrete: + return ColorPaletteGenerator( + number_of_colors=min(len(colors), MAX_CATEGORIES), + rgb_colors=colors if len(colors) <= MAX_CATEGORIES + else DefaultRGBColors) + else: + return ContinuousPaletteGenerator(*colors) + + def can_draw_density(self): + """ + Tells whether the current data and settings are suitable for drawing + densities + + Returns: + (bool): + """ + return self.data is not None and self.data.domain is not None and \ + len(self.data) > 1 and self.attr_color is not None + + def colors_changed(self): + self.graph.update_colors() + self.cb_class_density.setEnabled(self.can_draw_density()) + + # Labels + def get_label_data(self, formatter=None): + """Return the column corresponding to label data""" + if self.attr_label: + label_data = self.get_column(self.attr_label) + if formatter is None: + formatter = self.attr_label.str_val + return np.array([formatter(x) for x in label_data]) + return None + + def labels_changed(self): + self.graph.update_labels() + + # Shapes + def get_shape_data(self): + """ + Return the column corresponding to shape data + + Returns: + (np.ndarray or `None`): shape data + """ + return self.get_column(self.attr_shape, merge_infrequent=True) + + def get_shape_labels(self): + return self.get_column(self.attr_shape, merge_infrequent=True, + return_labels=True) + + def impute_shapes(self, shape_data, default_symbol): + """ + Default imputation for shape data + + Let the graph handle it, but add a warning if needed. 
+ + Args: + shape_data (np.ndarray): scaled points sizes + default_symbol (str): a string representing the symbol + """ + if self.graph.default_impute_shapes(shape_data, default_symbol): + self.Information.missing_shape(self.attr_shape) + else: + self.Information.missing_shape.clear() + + def shapes_changed(self): + self.graph.update_shapes() + + # Tooltip + def _point_tooltip(self, point_id, skip_attrs=()): + def show_part(_point_data, singular, plural, max_shown, _vars): + cols = [escape('{} = {}'.format(var.name, _point_data[var])) + for var in _vars[:max_shown + 2] + if _vars == domain.class_vars + or var not in skip_attrs][:max_shown] + if not cols: + return "" + n_vars = len(_vars) + if n_vars > max_shown: + cols[-1] = "... and {} others".format(n_vars - max_shown + 1) + return \ + "{}:
".format(singular if n_vars < 2 else plural) \ + + "
".join(cols) + + domain = self.data.domain + parts = (("Class", "Classes", 4, domain.class_vars), + ("Meta", "Metas", 4, domain.metas), + ("Feature", "Features", 10, domain.attributes)) + + point_data = self.data[point_id] + return "
".join(show_part(point_data, *columns) + for columns in parts) + + def get_tooltip(self, point_ids): + """ + Return the tooltip string for the given points + + The method is called by the plot on mouse hover + + Args: + point_ids (list): indices into `data` + + Returns: + (str): + """ + text = "
".join(self._point_tooltip(point_id) + for point_id in point_ids[:MAX_POINTS_IN_TOOLTIP]) + if len(point_ids) > MAX_POINTS_IN_TOOLTIP: + text = "{} instances
{}
...".format(len(point_ids), text) + return text + + def keyPressEvent(self, event): + """Update the tip about using the modifier keys when selecting""" + super().keyPressEvent(event) + self.graph.update_tooltip(event.modifiers()) + + def keyReleaseEvent(self, event): + """Update the tip about using the modifier keys when selecting""" + super().keyReleaseEvent(event) + self.graph.update_tooltip(event.modifiers()) + + +class OWDataProjectionWidget(OWProjectionWidgetBase): + """ + Base widget for widgets that get Data and Data Subset (both + Orange.data.Table) on the input, and output Selected Data and Data + (both Orange.data.Table). + + Beside that the widget displays data as two-dimensional projection + of points. + """ + class Inputs: + data = Input("Data", Table, default=True) + data_subset = Input("Data Subset", Table) + + class Outputs: + selected_data = Output("Selected Data", Table, default=True) + annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) + + settingsHandler = DomainContextHandler() + selection = Setting(None, schema_only=True) + auto_commit = Setting(True) + + GRAPH_CLASS = OWScatterPlotBase + graph = SettingProvider(OWScatterPlotBase) + graph_name = "graph.plot_widget.plotItem" + embedding_variables_names = ("proj-x", "proj-y") + + def __init__(self): + super().__init__() + self.subset_data = None + self.subset_indices = None + self.__pending_selection = self.selection + self.setup_gui() + + # GUI + def setup_gui(self): + self._add_graph() + self._add_controls() + + def _add_graph(self): + box = gui.vBox(self.mainArea, True, margin=0) + self.graph = self.GRAPH_CLASS(self, box) + box.layout().addWidget(self.graph.plot_widget) + + def _add_controls(self): + self._point_box = self.graph.gui.point_properties_box(self.controlArea) + self._effects_box = self.graph.gui.effects_box(self.controlArea) + self._plot_box = self.graph.gui.plot_properties_box(self.controlArea) + self.control_area_stretch = gui.widgetBox(self.controlArea) + 
self.control_area_stretch.layout().addStretch(100) + self.graph.box_zoom_select(self.controlArea) + gui.auto_commit(self.controlArea, self, "auto_commit", + "Send Selection", "Send Automatically") + + # Input + @Inputs.data + @check_sql_input + def set_data(self, data): + same_domain = (self.data and data and + data.domain.checksum() == self.data.domain.checksum()) + self.closeContext() + self.clear() + self.data = data + self.check_data() + if not same_domain: + self.init_attr_values() + self.openContext(self.data) + self.cb_class_density.setEnabled(self.can_draw_density()) + + def check_data(self): + self.clear_messages() + + @Inputs.data_subset + @check_sql_input + def set_subset_data(self, subset): + self.subset_data = subset + self.subset_indices = {e.id for e in subset} \ + if subset is not None else {} + self.controls.graph.alpha_value.setEnabled(subset is None) + + def handleNewSignals(self): + self.setup_plot() + self.commit() + + def get_subset_mask(self): + if self.subset_indices: + return np.array([ex.id in self.subset_indices + for ex in self.data[self.valid_data]]) + return None + + # Plot + def get_embedding(self): + """A get embedding method. + + Derived classes must override this method. The overridden method + should return embedding for all data (valid and invalid). Invalid + data embedding coordinates should be set to 0 (in some cases to Nan). + + The method should also sets self.valid_data. 
+ + Returns: + np.array: Array of embedding coordinates with shape + len(self.data) x 2 + """ + raise NotImplementedError + + def get_coordinates_data(self): + embedding = self.get_embedding() + return embedding[self.valid_data].T[:2] if embedding is not None \ + else (None, None) + + def setup_plot(self): + self.graph.reset_graph() + self.__pending_selection = self.selection or self.__pending_selection + self.apply_selection() + + # Selection + def apply_selection(self): + if self.data is not None and self.__pending_selection is not None \ + and self.graph.n_valid: + index_group = [(index, group) for index, group in + self.__pending_selection if index < len(self.data)] + index_group = np.array(index_group).T + selection = np.zeros(self.graph.n_valid, dtype=np.uint8) + selection[index_group[0]] = index_group[1] + + self.selection = self.__pending_selection + self.__pending_selection = None + self.graph.selection = selection + self.graph.update_selection_colors() + + def selection_changed(self): + sel = None if self.data and isinstance(self.data, SqlTable) \ + else self.graph.selection + self.selection = [(i, x) for i, x in enumerate(sel) if x] \ + if sel is not None else None + self.commit() + + # Output + def commit(self): + self.send_data() + + def send_data(self): + group_sel, data, graph = None, self._get_projection_data(), self.graph + if graph.selection is not None: + group_sel = np.zeros(len(data), dtype=int) + group_sel[self.valid_data] = graph.selection + self.Outputs.selected_data.send( + self._get_selected_data(data, graph.get_selection(), group_sel)) + self.Outputs.annotated_data.send( + self._get_annotated_data(data, graph.get_selection(), group_sel, + graph.selection)) + + def _get_projection_data(self): + if self.data is None or self.embedding_variables_names is None: + return self.data + variables = self._get_projection_variables() + data = self.data.transform(Domain(self.data.domain.attributes, + self.data.domain.class_vars, + 
self.data.domain.metas + variables)) + data.metas[:, -2:] = self.get_embedding() + return data + + def _get_projection_variables(self): + domain = self.data.domain + names = get_unique_names( + [v.name for v in domain.variables + domain.metas], + self.embedding_variables_names + ) + return ContinuousVariable(names[0]), ContinuousVariable(names[1]) + + @staticmethod + def _get_selected_data(data, selection, group_sel): + return create_groups_table(data, group_sel, False, "Group") \ + if len(selection) else None + + @staticmethod + def _get_annotated_data(data, selection, group_sel, graph_sel): + if graph_sel is not None and np.max(graph_sel) > 1: + return create_groups_table(data, group_sel) + else: + return create_annotated_table(data, selection) + + # Report + def send_report(self): + if self.data is None: + return + + caption = self._get_send_report_caption() + self.report_plot() + if caption: + self.report_caption(caption) + + def _get_send_report_caption(self): + return report.render_items_vert(( + ("Color", self._get_caption_var_name(self.attr_color)), + ("Label", self._get_caption_var_name(self.attr_label)), + ("Shape", self._get_caption_var_name(self.attr_shape)), + ("Size", self._get_caption_var_name(self.attr_size)), + ("Jittering", self.graph.jitter_size != 0 and + "{} %".format(self.graph.jitter_size)))) + + @staticmethod + def _get_caption_var_name(var): + return var.name if isinstance(var, Variable) else var + + # Misc + def sizeHint(self): + return QSize(1132, 708) + + def clear(self): + self.data = None + self.valid_data = None + self.selection = None + + def onDeleteWidget(self): + super().onDeleteWidget() + self.graph.plot_widget.getViewBox().deleteLater() + self.graph.plot_widget.clear() + + +class OWAnchorProjectionWidget(OWDataProjectionWidget): + """ Base widget for widgets with graphs with anchors. 
""" + SAMPLE_SIZE = 100 + + GRAPH_CLASS = OWGraphWithAnchors + graph = SettingProvider(OWGraphWithAnchors) + + class Outputs(OWDataProjectionWidget.Outputs): + components = Output("Components", Table) + + class Error(OWDataProjectionWidget.Error): + sparse_data = Msg("Sparse data is not supported") + no_valid_data = Msg("No projection due to no valid data") + not_enough_features = Msg("At least two features are required") + + def __init__(self): + super().__init__() + self.projection = None + self.graph.view_box.started.connect(self._manual_move_start) + self.graph.view_box.moved.connect(self._manual_move) + self.graph.view_box.finished.connect(self._manual_move_finish) + + def check_data(self): + def error(err): + err() + self.data = None + + super().check_data() + if self.data is not None: + if self.data.is_sparse(): + error(self.Error.sparse_data) + else: + self.valid_data = np.all(np.isfinite(self.data.X), axis=1) + if not np.sum(self.valid_data): + error(self.Error.no_valid_data) + + def get_anchors(self): + raise NotImplementedError + + def _manual_move_start(self): + self.graph.set_sample_size(self.SAMPLE_SIZE) + + def _manual_move(self, anchor_idx, x, y): + self.projection[anchor_idx] = [x, y] + self.graph.update_coordinates() + + def _manual_move_finish(self, anchor_idx, x, y): + self._manual_move(anchor_idx, x, y) + self.graph.set_sample_size(None) + self.commit() + + def commit(self): + super().commit() + self.send_components() + + def send_components(self): + raise NotImplementedError + + def clear(self): + super().clear() + self.projection = None + + +if __name__ == "__main__": + class OWProjectionWidgetWithName(OWDataProjectionWidget): + name = "projection" + + def get_embedding(self): + if self.data is None: + return None + self.valid_data = np.any(np.isfinite(self.data.X), 1) + x_data = self.data.X + x_data[x_data == np.inf] = np.nan + x_data = np.nanmean(x_data[self.valid_data], 1) + y_data = np.ones(len(x_data)) + return np.vstack((x_data, 
y_data)).T + + app = QApplication([]) + ow = OWProjectionWidgetWithName() + table = Table("iris") + ow.set_data(table) + ow.set_subset_data(table[::10]) + ow.handleNewSignals() + ow.show() + app.exec_() + ow.saveSettings()