diff --git a/Orange/widgets/visualize/owheatmap.py b/Orange/widgets/visualize/owheatmap.py index 4531f1bf34a..90a823248bb 100644 --- a/Orange/widgets/visualize/owheatmap.py +++ b/Orange/widgets/visualize/owheatmap.py @@ -118,9 +118,12 @@ def color_palette_table(colors, return np.c_[r, g, b] -def levels_with_thresholds(low, high, threshold_low, threshold_high): +def levels_with_thresholds(low, high, threshold_low, threshold_high, center_palette): lt = low + (high - low) * threshold_low ht = low + (high - low) * threshold_high + if center_palette: + ht = max(abs(lt), abs(ht)) + lt = -max(abs(lt), abs(ht)) return lt, ht @@ -317,6 +320,8 @@ class Outputs: selected_data = Output("Selected Data", Table, default=True) annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) + settings_version = 2 + settingsHandler = settings.DomainContextHandler() NoPosition, PositionTop, PositionBottom = 0, 1, 2 @@ -327,6 +332,7 @@ class Outputs: gamma = settings.Setting(0) threshold_low = settings.Setting(0.0) threshold_high = settings.Setting(1.0) + center_palette = settings.Setting(False) merge_kmeans = settings.Setting(False) merge_kmeans_k = settings.Setting(50) @@ -447,6 +453,9 @@ def __init__(self): colorbox.layout().addLayout(form) + gui.checkBox(colorbox, self, 'center_palette', 'Center colors at 0', + callback=self.update_color_schema) + mergebox = gui.vBox(self.controlArea, "Merge",) gui.checkBox(mergebox, self, "merge_kmeans", "Merge by k-means", callback=self.update_sorting_examples) @@ -982,7 +991,7 @@ def setup_scene(self, parts, data): hw.set_levels(parts.levels) hw.set_thresholds(self.threshold_low, self.threshold_high) - hw.set_color_table(palette) + hw.set_color_table(palette, self.center_palette) hw.set_show_averages(self.averages) hw.set_heatmap_data(X_part) @@ -1057,7 +1066,7 @@ def setup_scene(self, parts, data): parts.levels[0], parts.levels[1], self.threshold_low, self.threshold_high, parent=widget) - legend.set_color_table(palette) + legend.set_color_table(palette, self.center_palette) legend.setMinimumSize(QSizeF(100, 20)) legend.setVisible(self.legend) @@ -1318,11 +1327,11 @@ def update_color_schema(self): palette = self.color_palette() for heatmap in self.heatmap_widgets(): heatmap.set_thresholds(self.threshold_low, self.threshold_high) - heatmap.set_color_table(palette) + heatmap.set_color_table(palette, self.center_palette) for legend in self.legend_widgets(): legend.set_thresholds(self.threshold_low, self.threshold_high) - legend.set_color_table(palette) + legend.set_color_table(palette, self.center_palette) def update_sorting_examples(self): self.update_heatmaps() @@ -1601,6 +1610,7 @@ def __init__(self, parent=None, data=None, **kwargs): self.__levels = None self.__threshold_low, self.__threshold_high = 0., 1. + self.__center_palette = False self.__colortable = None self.__data = data @@ -1677,8 +1687,9 @@ def set_show_averages(self, show): self.layout().invalidate() self.update() - def set_color_table(self, table): + def set_color_table(self, table, center): self.__colortable = table + self.__center_palette = center self._update_pixmap() self.update() @@ -1699,7 +1710,8 @@ def _update_pixmap(self): lut = None ll, lh = self.__levels - ll, lh = levels_with_thresholds(ll, lh, self.__threshold_low, self.__threshold_high) + ll, lh = levels_with_thresholds(ll, lh, self.__threshold_low, self.__threshold_high, + self.__center_palette) argb, _ = pg.makeARGB( self.__data, lut=lut, levels=(ll, lh)) @@ -2058,6 +2070,7 @@ def __init__(self, low, high, threshold_low, threshold_high, parent=None): self.high = high self.threshold_low = threshold_low self.threshold_high = threshold_high + self.center_palette = False self.color_table = None layout = QGraphicsLinearLayout(Qt.Vertical) @@ -2084,8 +2097,9 @@ def __init__(self, low, high, threshold_low, threshold_high, parent=None): layout.addItem(self.__pixitem) self.__update() - def set_color_table(self, color_table): + def set_color_table(self, color_table, center): self.color_table = color_table + self.center_palette = center self.__update() def set_thresholds(self, threshold_low, threshold_high): @@ -2097,7 +2111,8 @@ def __update(self): data = np.linspace(self.low, self.high, num=1000) data = data.reshape((1, -1)) ll, lh = levels_with_thresholds(self.low, self.high, - self.threshold_low, self.threshold_high) + self.threshold_low, self.threshold_high, + self.center_palette) argb, _ = pg.makeARGB(data, lut=self.color_table, levels=(ll, lh)) qimg = pg.makeQImage(argb, transpose=False) diff --git a/Orange/widgets/visualize/tests/test_owheatmap.py b/Orange/widgets/visualize/tests/test_owheatmap.py index 39fd1860b9c..7699f1c508b 100644 --- a/Orange/widgets/visualize/tests/test_owheatmap.py +++ b/Orange/widgets/visualize/tests/test_owheatmap.py @@ -13,6 +13,14 @@ from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin, datasets +def image_row_colors(image): + colors = np.full((image.height(), 3), np.nan) + for r in range(image.height()): + c = image.pixelColor(0, r) + colors[r] = c.red(), c.green(), c.blue() + return colors + + class TestOWHeatMap(WidgetTest, WidgetOutputsTestMixin): @classmethod def setUpClass(cls): @@ -164,10 +172,7 @@ def test_use_enough_colors(self): self.widget.update_color_schema() heatmap_widget = self.widget.heatmap_widget_grid[0][0] image = heatmap_widget.heatmap_item.pixmap().toImage() - colors = np.full((len(data), 3), np.nan) - for r in range(len(data)): - c = image.pixelColor(0, r) - colors[r] = c.red(), c.green(), c.blue() + colors = image_row_colors(image) unique_colors = len(np.unique(colors, axis=0)) self.assertLessEqual(len(data)*self.widget.threshold_low, unique_colors) @@ -209,6 +214,29 @@ def test_set_split_var(self): self.assertIs(w.split_by_var, None) self.assertEqual(len(w.heatmapparts.rows), 1) + def test_center_palette(self): + data = np.arange(2).reshape(-1, 1) + table = Table.from_numpy(Domain([ContinuousVariable("y")]), data) + self.send_signal(self.widget.Inputs.data, table) + + cb_model = self.widget.color_cb.model() + ind = cb_model.indexFromItem(cb_model.findItems("Green-Black-Red")[0]).row() + self.widget.palette_index = ind + + desired_uncentered = [[0, 255, 0], + [255, 0, 0]] + + desired_centered = [[0, 0, 0], + [255, 0, 0]] + + for center, desired in [(False, desired_uncentered), (True, desired_centered)]: + self.widget.center_palette = center + self.widget.update_color_schema() + heatmap_widget = self.widget.heatmap_widget_grid[0][0] + image = heatmap_widget.heatmap_item.pixmap().toImage() + colors = image_row_colors(image) + np.testing.assert_almost_equal(colors, desired) + if __name__ == "__main__": unittest.main()