Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX][ENH] Scatter Plot VizRank: some fixes and regard to color #2787

Merged
merged 2 commits on Nov 27, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 52 additions & 42 deletions Orange/widgets/visualize/owscatterplot.py
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
from itertools import chain

import numpy as np

from AnyQt.QtCore import Qt, QTimer
Expand All @@ -8,7 +10,7 @@
from sklearn.metrics import r2_score

import Orange
from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable
from Orange.data import Table, Domain, DiscreteVariable
from Orange.canvas import report
from Orange.data.sql.table import SqlTable, AUTO_DL_LIMIT
from Orange.preprocess.score import ReliefF, RReliefF
Expand All @@ -26,15 +28,24 @@
class ScatterPlotVizRank(VizRankDialogAttrPair):
captionTitle = "Score Plots"
minK = 10
attr_color = None

def __init__(self, master):
super().__init__(master)
self.attr_color = self.master.graph.attr_color

def initialize(self):
self.attr_color = self.master.graph.attr_color
super().initialize()

def check_preconditions(self):
self.Information.add_message(
"class_required", "Data with a class variable is required.")
self.Information.class_required.clear()
"color_required", "Color variable must be selected")
self.Information.color_required.clear()
if not super().check_preconditions():
return False
if not self.master.data.domain.class_var:
self.Information.class_required()
if not self.attr_color:
self.Information.color_required()
return False
return True

Expand All @@ -46,43 +57,38 @@ def iterate_states(self, initial_state):
yield from super().iterate_states(initial_state)

def compute_score(self, state):
graph = self.master.graph
attrs = [self.attrs[x] for x in state]
valid = graph.get_valid_list(attrs)
cols = []
for var in attrs:
cols.append(graph.jittered_data.get_column_view(var)[0][valid])
X = np.column_stack(cols)
Y = self.master.data.Y[valid]
if X.shape[0] < self.minK:
attrs = [self.attrs[i] for i in state]
data = self.master.graph.scaled_data
data = data.transform(Domain(attrs, self.attr_color))
data = data[~np.isnan(data.X).any(axis=1) & ~np.isnan(data.Y).T]
if len(data) < self.minK:
return
n_neighbors = min(self.minK, len(X) - 1)
knn = NearestNeighbors(n_neighbors=n_neighbors).fit(X)
n_neighbors = min(self.minK, len(data) - 1)
knn = NearestNeighbors(n_neighbors=n_neighbors).fit(data.X)
ind = knn.kneighbors(return_distance=False)
if self.master.data.domain.has_discrete_class:
return -np.sum(Y[ind] == Y.reshape(-1, 1)) / n_neighbors / len(Y)
if data.domain.has_discrete_class:
return -np.sum(data.Y[ind] == data.Y.reshape(-1, 1)) / n_neighbors / len(data.Y)
else:
return -r2_score(Y, np.mean(Y[ind], axis=1)) * \
(len(Y) / len(self.master.data))
return -r2_score(data.Y, np.mean(data.Y[ind], axis=1)) * \
(len(data.Y) / len(self.master.data))

def bar_length(self, score):
return max(0, -score)

def score_heuristic(self):
X = self.master.graph.jittered_data.X
Y = self.master.data.Y
mdomain = self.master.data.domain
dom = Domain([ContinuousVariable(str(i)) for i in range(X.shape[1])],
mdomain.class_vars)
data = Table(dom, X, Y)
relief = ReliefF if isinstance(dom.class_var, DiscreteVariable) \
else RReliefF
assert self.attr_color is not None
master_domain = self.master.graph.scaled_data.domain
vars = [v for v in chain(master_domain.variables, master_domain.metas)
if v is not self.attr_color]
domain = Domain(attributes=vars, class_vars=self.attr_color)
data = self.master.graph.scaled_data.transform(domain)
relief = ReliefF if isinstance(domain.class_var, DiscreteVariable) else RReliefF
weights = relief(n_iterations=100, k_nearest=self.minK)(data)
attrs = sorted(zip(weights, mdomain.attributes),
key=lambda x: (-x[0], x[1].name))
attrs = sorted(zip(weights, domain.attributes), key=lambda x: (-x[0], x[1].name))
return [a for _, a in attrs]



class OWScatterPlot(OWWidget):
"""Scatterplot visualization with explorative analysis and intelligent
data visualization enhancements."""
Expand Down Expand Up @@ -216,6 +222,20 @@ def reset_graph_data(self, *_):
self.graph.rescale_data()
self.update_graph()

def _vizrank_color_change(self):
self.vizrank.initialize()
is_enabled = self.data is not None and not self.data.is_sparse() and \
len([v for v in chain(self.data.domain.variables, self.data.domain.metas)
if v.is_primitive]) > 2\
and len(self.data) > 1
self.vizrank_button.setEnabled(
is_enabled and self.graph.attr_color is not None and
not np.isnan(self.data.get_column_view(self.graph.attr_color)[0].astype(float)).all())
if is_enabled and self.graph.attr_color is None:
self.vizrank_button.setToolTip("Color variable has to be selected.")
else:
self.vizrank_button.setToolTip("")

@Inputs.data
def set_data(self, data):
self.clear_messages()
Expand Down Expand Up @@ -248,19 +268,8 @@ def set_data(self, data):

if not same_domain:
self.init_attr_values()
self.vizrank.initialize()
self.vizrank.attrs = self.data.domain.attributes if self.data is not None else []
self.vizrank_button.setEnabled(
self.data is not None and not self.data.is_sparse() and
self.data.domain.class_var is not None and not np.isnan(self.data.Y).all() and
len(self.data.domain.attributes) > 1 and len(self.data) > 1)
if self.data is not None and self.data.domain.class_var is None \
and len(self.data.domain.attributes) > 1 and len(self.data) > 1:
self.vizrank_button.setToolTip(
"Data with a class variable is required.")
else:
self.vizrank_button.setToolTip("")
self.openContext(self.data)
self._vizrank_color_change()

def findvar(name, iterable):
"""Find a Orange.data.Variable in `iterable` by name"""
Expand Down Expand Up @@ -372,6 +381,7 @@ def update_attr(self):
self.send_features()

def update_colors(self):
self._vizrank_color_change()
self.cb_class_density.setEnabled(self.graph.can_draw_density())

def update_density(self):
Expand Down