Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Projections: Retain embedding if non-relevant variables change #3428

Merged
merged 3 commits into from
Nov 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions Orange/widgets/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,13 +894,22 @@ def test_none_data(self):
"""Test widget for empty dataset"""
self.send_signal(self.widget.Inputs.data, self.data[:0])

def test_subset_data(self, timeout=DEFAULT_TIMEOUT):
"""Test widget for subset data"""
def test_plot_once(self, timeout=DEFAULT_TIMEOUT):
"""Test if data is plotted only once but committed on every input change"""
self.widget.setup_plot = Mock()
self.widget.commit = Mock()
self.send_signal(self.widget.Inputs.data, self.data)
self.widget.setup_plot.assert_called_once()
self.widget.commit.assert_called_once()

if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.commit.reset_mock()
self.send_signal(self.widget.Inputs.data_subset, self.data[::10])
self.widget.setup_plot.assert_called_once()
self.widget.commit.assert_called_once()

def test_class_density(self, timeout=DEFAULT_TIMEOUT):
"""Check class density update"""
Expand Down Expand Up @@ -932,6 +941,24 @@ def test_sparse_data(self, timeout=DEFAULT_TIMEOUT):
self.send_signal(self.widget.Inputs.data_subset, table[::30])
self.assertEqual(len(self.widget.subset_indices), 5)

def test_invalidated_embedding(self, timeout=DEFAULT_TIMEOUT):
"""Check if graph has been replotted when sending same data"""
self.widget.graph.update_coordinates = Mock()
self.widget.graph.update_point_props = Mock()
self.send_signal(self.widget.Inputs.data, self.data)
self.widget.graph.update_coordinates.assert_called_once()
self.widget.graph.update_point_props.assert_called_once()

if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.graph.update_coordinates.reset_mock()
self.widget.graph.update_point_props.reset_mock()
self.send_signal(self.widget.Inputs.data, self.data)
self.widget.graph.update_coordinates.assert_not_called()
self.widget.graph.update_point_props.assert_called_once()

def test_send_report(self, timeout=DEFAULT_TIMEOUT):
"""Test report """
self.send_signal(self.widget.Inputs.data, self.data)
Expand Down
42 changes: 22 additions & 20 deletions Orange/widgets/unsupervised/owmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(self):
#: Input data table
self.signal_data = None

self._invalidated = False
self.__invalidated = True
self.embedding = None
self.effective_matrix = None

Expand Down Expand Up @@ -212,15 +212,6 @@ def set_data(self, data):

self.signal_data = data

if self.matrix is not None and data is not None and \
len(self.matrix) == len(data):
self.closeContext()
self.data = data
self.init_attr_values()
self.openContext(data)
else:
self._invalidated = True

@Inputs.distances
def set_disimilarity(self, matrix):
"""Set the dissimilarity (distance) matrix.
Expand All @@ -238,30 +229,33 @@ def set_disimilarity(self, matrix):

self.matrix = matrix
self.matrix_data = matrix.row_items if matrix is not None else None
self._invalidated = True

def clear(self):
super().clear()
self.embedding = None
self.effective_matrix = None
self.graph.set_effective_matrix(None)
self.__set_update_loop(None)
self.__state = OWMDS.Waiting

def _initialize(self):
matrix_existed = self.effective_matrix is not None
effective_matrix = self.effective_matrix
self.__invalidated = True
self.data = None
self.effective_matrix = None
self.closeContext()
self.clear()
self.clear_messages()

# if no data nor matrix is present reset plot
if self.signal_data is None and self.matrix is None:
self.data = None
self.clear()
self.init_attr_values()
return

if self.signal_data is not None and self.matrix is not None and \
len(self.signal_data) != len(self.matrix):
self.Error.mismatching_dimensions()
self.clear()
self.init_attr_values()
return

Expand All @@ -279,11 +273,18 @@ def _initialize(self):
self.effective_matrix = Euclidean(preprocessed_data)
else:
self.Error.no_attributes()
self.clear()
self.init_attr_values()
return

self.init_attr_values()
self.openContext(self.data)
self.__invalidated = not (matrix_existed and
self.effective_matrix is not None and
np.array_equal(effective_matrix,
self.effective_matrix))
if self.__invalidated:
self.clear()
self.graph.set_effective_matrix(self.effective_matrix)

def _toggle_run(self):
Expand Down Expand Up @@ -407,7 +408,6 @@ def __next_step(self):
self.__set_update_loop(None)
self.unconditional_commit()
self.graph.resume_drawing_pairs()
self.graph.update_coordinates()
except MemoryError:
self.Error.out_of_memory()
self.__set_update_loop(None)
Expand Down Expand Up @@ -446,6 +446,7 @@ def jitter_coord(part):
self.__set_update_loop(None)

if self.effective_matrix is None:
self.graph.reset_graph()
return

X = self.effective_matrix
Expand Down Expand Up @@ -477,15 +478,16 @@ def __invalidate_refresh(self):
self.__start()

def handleNewSignals(self):
if self._invalidated:
self._initialize()
if self.__invalidated:
self.graph.pause_drawing_pairs()
self._invalidated = False
self._initialize()
self.__invalidated = False
self.__invalidate_embedding()
self.cb_class_density.setEnabled(self.can_draw_density())
self.start()

super().handleNewSignals()
else:
self.graph.update_point_props()
self.commit()

def _invalidate_output(self):
self.commit()
Expand Down
28 changes: 8 additions & 20 deletions Orange/widgets/unsupervised/owtsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(self):
super().__init__()
self.pca_data = None
self.projection = None
self.__invalidated = True
self.__update_loop = None
# timer for scheduling updates
self.__timer = QTimer(self, singleShot=True, interval=1,
Expand Down Expand Up @@ -112,11 +111,6 @@ def _add_controls_start_box(self):
gui.hSlider(box, self, "pca_components", label="PCA components:",
minValue=2, maxValue=50, step=1)

def set_data(self, data):
self.__invalidated = not (self.data and data and
np.array_equal(self.data.X, data.X))
super().set_data(data)

def check_data(self):
def error(err):
err()
Expand Down Expand Up @@ -251,14 +245,9 @@ def __next_step(self):

self.__in_next_step = False

def handleNewSignals(self):
if self.__invalidated:
self.__invalidated = False
self.setup_plot()
self.start()
else:
self.graph.update_point_props()
self.commit()
def setup_plot(self):
super().setup_plot()
self.start()

def commit(self):
super().commit()
Expand All @@ -281,12 +270,11 @@ def send_preprocessor(self):
self.Outputs.preprocessor.send(prep)

def clear(self):
if self.__invalidated:
super().clear()
self.__set_update_loop(None)
self.__state = OWtSNE.Waiting
self.pca_data = None
self.projection = None
super().clear()
self.__set_update_loop(None)
self.__state = OWtSNE.Waiting
self.pca_data = None
self.projection = None

@classmethod
def migrate_settings(cls, settings, version):
Expand Down
19 changes: 0 additions & 19 deletions Orange/widgets/unsupervised/tests/test_owtsne.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import unittest
import numpy as np

from AnyQt.QtTest import QSignalSpy

from Orange.data import DiscreteVariable, ContinuousVariable, Domain, Table
from Orange.preprocess import Preprocess
from Orange.widgets.tests.base import (
Expand Down Expand Up @@ -98,23 +96,6 @@ def test_output_preprocessor(self):
self.assertEqual([a.name for a in transformed.domain.attributes],
[m.name for m in output.domain.metas[:2]])

def test_invalidated_embedding(self):
self.widget.graph.update_coordinates = unittest.mock.Mock()
self.widget.graph.update_point_props = unittest.mock.Mock()
self.send_signal(self.widget.Inputs.data, self.data)
self.widget.graph.update_coordinates.assert_called_once()
self.widget.graph.update_point_props.assert_called_once()

if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(5000))

self.widget.graph.update_coordinates.reset_mock()
self.widget.graph.update_point_props.reset_mock()
self.send_signal(self.widget.Inputs.data, self.data)
self.widget.graph.update_coordinates.assert_not_called()
self.widget.graph.update_point_props.assert_called_once()


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions Orange/widgets/utils/itemmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ class PyListModel(QAbstractListModel):
"""
MIME_TYPE = "application/x-Orange-PyListModelData"
Separator = object()
removed = Signal()

def __init__(self, iterable=None, parent=None,
flags=Qt.ItemIsSelectable | Qt.ItemIsEnabled,
Expand Down Expand Up @@ -616,6 +617,7 @@ def removeRows(self, row, count, parent=QModelIndex()):
"""
if not parent.isValid():
del self[row:row + count]
self.removed.emit()
return True
else:
return False
Expand Down
16 changes: 13 additions & 3 deletions Orange/widgets/utils/plot/owplotgui.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from AnyQt.QtWidgets import QWidget, QToolButton, QVBoxLayout, QHBoxLayout, QGridLayout, QMenu, QAction,\
QDialog, QSizePolicy, QPushButton, QListView, QLabel
from AnyQt.QtGui import QIcon, QKeySequence
from AnyQt.QtCore import Qt, pyqtSignal, QPoint, QSize
from AnyQt.QtCore import Qt, pyqtSignal, QPoint, QSize, QObject

from Orange.data import ContinuousVariable, DiscreteVariable
from Orange.widgets import gui
Expand All @@ -52,6 +52,8 @@


class AddVariablesDialog(QDialog):
add = pyqtSignal()

def __init__(self, master, model):
QDialog.__init__(self)

Expand Down Expand Up @@ -134,10 +136,16 @@ def add_variables(self):
del model[i]

self.master.model_selected.extend(variables)
self.add.emit()


class VariablesSelection(QObject):
added = pyqtSignal()
removed = pyqtSignal()

class VariablesSelection:
def __init__(self, master, model_selected, model_other, widget=None):
def __init__(self, master, model_selected, model_other,
widget=None, parent=None):
super().__init__(parent)
self.master = master
self.model_selected = model_selected
self.model_other = model_other
Expand Down Expand Up @@ -197,9 +205,11 @@ def __deactivate_selection(self):
del model[i]

self.model_other.extend(variables)
self.removed.emit()

def _action_add(self):
self.add_variables_dialog = AddVariablesDialog(self, self.model_other)
self.add_variables_dialog.add.connect(lambda: self.added.emit())


class OrientedWidget(QWidget):
Expand Down
9 changes: 5 additions & 4 deletions Orange/widgets/visualize/owlinearprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,7 @@ class Error(OWAnchorProjectionWidget.Error):

def __init__(self):
self.model_selected = VariableListModel(enable_dnd=True)
self.model_selected.rowsInserted.connect(self.__model_selected_changed)
self.model_selected.rowsRemoved.connect(self.__model_selected_changed)
self.model_selected.removed.connect(self.__model_selected_changed)
self.model_other = VariableListModel(enable_dnd=True)

self.vizrank, self.btn_vizrank = LinearProjectionVizRank.add_vizrank(
Expand All @@ -296,6 +295,8 @@ def _add_controls_variables(self):
self.variables_selection = VariablesSelection(
self, self.model_selected, self.model_other, self.controlArea
)
self.variables_selection.added.connect(self.__model_selected_changed)
self.variables_selection.removed.connect(self.__model_selected_changed)
self.variables_selection.add_remove.layout().addWidget(
self.btn_vizrank
)
Expand Down Expand Up @@ -328,6 +329,7 @@ def __vizrank_set_attrs(self, attrs):
self.model_selected[:] = attrs[:]
self.model_other[:] = [var for var in self.continuous_variables
if var not in attrs]
self.__model_selected_changed()

def __model_selected_changed(self):
self.selected_vars = [(var.name, vartype(var)) for var
Expand Down Expand Up @@ -368,6 +370,7 @@ def set_data(self, data):

self._check_options()
self._init_vizrank()
self.init_projection()

def _check_options(self):
buttons = self.radio_placement.buttons
Expand Down Expand Up @@ -423,8 +426,6 @@ def init_attr_values(self):
self.selected_vars = []

def init_projection(self):
if not len(self.effective_variables):
return
if self.placement == self.Placement.Circular:
self.projector = CircularPlacement()
elif self.placement == self.Placement.LDA:
Expand Down
Loading