Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Heatmap: option to center color palette at 0 #4218

Merged
merged 1 commit into from
Nov 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions Orange/widgets/visualize/owheatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,12 @@ def color_palette_table(colors,
return np.c_[r, g, b]


def levels_with_thresholds(low, high, threshold_low, threshold_high):
def levels_with_thresholds(low, high, threshold_low, threshold_high, center_palette):
lt = low + (high - low) * threshold_low
ht = low + (high - low) * threshold_high
if center_palette:
ht = max(abs(lt), abs(ht))
lt = -max(abs(lt), abs(ht))
return lt, ht


Expand Down Expand Up @@ -317,6 +320,8 @@ class Outputs:
selected_data = Output("Selected Data", Table, default=True)
annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

settings_version = 2

settingsHandler = settings.DomainContextHandler()

NoPosition, PositionTop, PositionBottom = 0, 1, 2
Expand All @@ -327,6 +332,7 @@ class Outputs:
gamma = settings.Setting(0)
threshold_low = settings.Setting(0.0)
threshold_high = settings.Setting(1.0)
center_palette = settings.Setting(False)

merge_kmeans = settings.Setting(False)
merge_kmeans_k = settings.Setting(50)
Expand Down Expand Up @@ -447,6 +453,9 @@ def __init__(self):

colorbox.layout().addLayout(form)

gui.checkBox(colorbox, self, 'center_palette', 'Center colors at 0',
callback=self.update_color_schema)

mergebox = gui.vBox(self.controlArea, "Merge",)
gui.checkBox(mergebox, self, "merge_kmeans", "Merge by k-means",
callback=self.update_sorting_examples)
Expand Down Expand Up @@ -982,7 +991,7 @@ def setup_scene(self, parts, data):

hw.set_levels(parts.levels)
hw.set_thresholds(self.threshold_low, self.threshold_high)
hw.set_color_table(palette)
hw.set_color_table(palette, self.center_palette)
hw.set_show_averages(self.averages)
hw.set_heatmap_data(X_part)

Expand Down Expand Up @@ -1057,7 +1066,7 @@ def setup_scene(self, parts, data):
parts.levels[0], parts.levels[1], self.threshold_low, self.threshold_high,
parent=widget)

legend.set_color_table(palette)
legend.set_color_table(palette, self.center_palette)
legend.setMinimumSize(QSizeF(100, 20))
legend.setVisible(self.legend)

Expand Down Expand Up @@ -1318,11 +1327,11 @@ def update_color_schema(self):
palette = self.color_palette()
for heatmap in self.heatmap_widgets():
heatmap.set_thresholds(self.threshold_low, self.threshold_high)
heatmap.set_color_table(palette)
heatmap.set_color_table(palette, self.center_palette)

for legend in self.legend_widgets():
legend.set_thresholds(self.threshold_low, self.threshold_high)
legend.set_color_table(palette)
legend.set_color_table(palette, self.center_palette)

def update_sorting_examples(self):
self.update_heatmaps()
Expand Down Expand Up @@ -1601,6 +1610,7 @@ def __init__(self, parent=None, data=None, **kwargs):

self.__levels = None
self.__threshold_low, self.__threshold_high = 0., 1.
self.__center_palette = False
self.__colortable = None
self.__data = data

Expand Down Expand Up @@ -1677,8 +1687,9 @@ def set_show_averages(self, show):
self.layout().invalidate()
self.update()

def set_color_table(self, table):
def set_color_table(self, table, center):
self.__colortable = table
self.__center_palette = center
self._update_pixmap()
self.update()

Expand All @@ -1699,7 +1710,8 @@ def _update_pixmap(self):
lut = None

ll, lh = self.__levels
ll, lh = levels_with_thresholds(ll, lh, self.__threshold_low, self.__threshold_high)
ll, lh = levels_with_thresholds(ll, lh, self.__threshold_low, self.__threshold_high,
self.__center_palette)

argb, _ = pg.makeARGB(
self.__data, lut=lut, levels=(ll, lh))
Expand Down Expand Up @@ -2058,6 +2070,7 @@ def __init__(self, low, high, threshold_low, threshold_high, parent=None):
self.high = high
self.threshold_low = threshold_low
self.threshold_high = threshold_high
self.center_palette = False
self.color_table = None

layout = QGraphicsLinearLayout(Qt.Vertical)
Expand All @@ -2084,8 +2097,9 @@ def __init__(self, low, high, threshold_low, threshold_high, parent=None):
layout.addItem(self.__pixitem)
self.__update()

def set_color_table(self, color_table):
def set_color_table(self, color_table, center):
self.color_table = color_table
self.center_palette = center
self.__update()

def set_thresholds(self, threshold_low, threshold_high):
Expand All @@ -2097,7 +2111,8 @@ def __update(self):
data = np.linspace(self.low, self.high, num=1000)
data = data.reshape((1, -1))
ll, lh = levels_with_thresholds(self.low, self.high,
self.threshold_low, self.threshold_high)
self.threshold_low, self.threshold_high,
self.center_palette)
argb, _ = pg.makeARGB(data, lut=self.color_table,
levels=(ll, lh))
qimg = pg.makeQImage(argb, transpose=False)
Expand Down
36 changes: 32 additions & 4 deletions Orange/widgets/visualize/tests/test_owheatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin, datasets


def image_row_colors(image):
colors = np.full((image.height(), 3), np.nan)
for r in range(image.height()):
c = image.pixelColor(0, r)
colors[r] = c.red(), c.green(), c.blue()
return colors


class TestOWHeatMap(WidgetTest, WidgetOutputsTestMixin):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -164,10 +172,7 @@ def test_use_enough_colors(self):
self.widget.update_color_schema()
heatmap_widget = self.widget.heatmap_widget_grid[0][0]
image = heatmap_widget.heatmap_item.pixmap().toImage()
colors = np.full((len(data), 3), np.nan)
for r in range(len(data)):
c = image.pixelColor(0, r)
colors[r] = c.red(), c.green(), c.blue()
colors = image_row_colors(image)
unique_colors = len(np.unique(colors, axis=0))
self.assertLessEqual(len(data)*self.widget.threshold_low, unique_colors)

Expand Down Expand Up @@ -209,6 +214,29 @@ def test_set_split_var(self):
self.assertIs(w.split_by_var, None)
self.assertEqual(len(w.heatmapparts.rows), 1)

def test_center_palette(self):
data = np.arange(2).reshape(-1, 1)
table = Table.from_numpy(Domain([ContinuousVariable("y")]), data)
self.send_signal(self.widget.Inputs.data, table)

cb_model = self.widget.color_cb.model()
ind = cb_model.indexFromItem(cb_model.findItems("Green-Black-Red")[0]).row()
self.widget.palette_index = ind

desired_uncentered = [[0, 255, 0],
[255, 0, 0]]

desired_centered = [[0, 0, 0],
[255, 0, 0]]

for center, desired in [(False, desired_uncentered), (True, desired_centered)]:
self.widget.center_palette = center
self.widget.update_color_schema()
heatmap_widget = self.widget.heatmap_widget_grid[0][0]
image = heatmap_widget.heatmap_item.pixmap().toImage()
colors = image_row_colors(image)
np.testing.assert_almost_equal(colors, desired)


if __name__ == "__main__":
unittest.main()