Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] owkaplanmeier: add subset selection, refactor #50

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 161 additions & 44 deletions orangecontrib/survival_analysis/widgets/owkaplanmeier.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import html
import pandas as pd
import numpy as np
import pyqtgraph as pg

from typing import Dict, List, Optional, NamedTuple
from itertools import zip_longest
from collections import defaultdict, Counter

from AnyQt.QtGui import QBrush, QColor, QPainterPath, QPalette
from AnyQt.QtCore import Qt, QSize, QEvent
from AnyQt.QtCore import pyqtSignal as Signal
from pyqtgraph.functions import mkPen
from pyqtgraph.functions import mkPen, mkBrush
from pyqtgraph.graphicsItems.ViewBox import ViewBox
from pyqtgraph.graphicsItems.LegendItem import ItemSample, LabelItem
from lifelines import KaplanMeierFitter
Expand All @@ -17,7 +19,7 @@

from Orange.data import Table, DiscreteVariable, Domain
from Orange.widgets import gui
from Orange.widgets.widget import Input, Output, OWWidget
from Orange.widgets.widget import Input, Output, OWWidget, Msg
from Orange.widgets.settings import (
Setting,
ContextSetting,
Expand Down Expand Up @@ -61,7 +63,8 @@ def generate_curve_coordinates(timeline, probabilities):
)
return np.array(x), np.array(y)

def __init__(self, time, events, label=None, color=None):
def __init__(self, time, events, data, label=None, color=None):
self.data = data
self._kmf = KaplanMeierFitter().fit(time.astype(np.float64), events.astype(np.float64))

self._label: str = label
Expand Down Expand Up @@ -96,7 +99,6 @@ def __init__(self, time, events, label=None, color=None):
)

censored_data = self.get_censored_data()

self.censored_data = pg.ScatterPlotItem(
x=censored_data[:, 0],
y=censored_data[:, 1],
Expand All @@ -110,6 +112,36 @@ def __init__(self, time, events, label=None, color=None):
self.num_of_samples = len(events)
self.num_of_censored_samples = len(censored_data)

self.points = self.generate_scatter_points()

def generate_scatter_points(self):
    """Build one scatter point per sample, positioned on this curve.

    Returns:
        A list of ``(x, y, id, pen, brush)`` tuples. ``x`` is the sample's
        duration, ``y`` the height at which its marker is drawn and ``id``
        the row id read from ``self.data``.

    Samples that share a duration at which an event was observed are
    spread evenly along the vertical drop of the curve at that time; all
    other samples are placed on the survival estimate itself.
    """
    # Select the column first so ``.loc`` returns a scalar instead of a
    # one-element Series.
    survival = self._kmf.survival_function_['KM_estimate']

    # Map every estimate to the value the curve held right before dropping
    # to it (zip stops at the shorter sequence).
    levels = pd.unique(survival)
    dropped_from = dict(zip(levels[1:], levels))

    durations = np.asarray(self._kmf.durations)
    observed = np.asarray(self._kmf.event_observed)
    event_times = set(durations[observed == 1])

    pen = self.get_pen(alpha=255, width=1)
    brush = self.get_brush(alpha=255)

    points = []
    for duration, count in Counter(durations).items():
        level = float(survival.loc[duration])
        ids = self.data[self.data.Y == duration].ids

        if count > 1 and duration in event_times:
            # A drop at the very first timeline entry has no earlier level
            # recorded; the estimate starts from 1.0 in that case.
            upper = float(dropped_from.get(level, 1.0))
            heights = np.linspace(level, upper, num=count + 2)[1:-1]
        else:
            heights = [level] * count

        points.extend((duration, y_, id_, pen, brush) for id_, y_ in zip(ids, heights))

    return points

@property
def label(self):
    """Text shown for this curve; ``'All'`` when no label is set."""
    return self._label or 'All'
Expand All @@ -130,6 +162,9 @@ def get_color(self, alpha=255) -> QColor:
def get_pen(self, width=3, alpha=100) -> mkPen:
    """Create a pen from ``self.get_color(alpha)`` with the given width."""
    pen_color = self.get_color(alpha)
    return mkPen(color=pen_color, width=width)

def get_brush(self, alpha=100) -> mkBrush:
    """Create a brush from ``self.get_color(alpha)``."""
    brush_color = self.get_color(alpha)
    return mkBrush(color=brush_color)

def set_highlighted(self, highlighted):
if highlighted:
estimated_fun_alpha = 200
Expand Down Expand Up @@ -280,6 +315,8 @@ def __init__(self, parent: OWWidget = None):
self.parent: OWWidget = parent
self.highlighted_curve: Optional[int] = None
self.curves: Dict[int, EstimatedFunctionCurve] = {}
self.log_rank_test = None
self._scatter_item = None
self.__selection_items: Dict[int, Optional[pg.PlotDataItem]] = {}

self.view_box: KaplanMeierViewBox = self.getViewBox()
Expand All @@ -294,7 +331,65 @@ def __init__(self, parent: OWWidget = None):
self.legend.restoreAnchor(((1, 0), (1, 0)))
self.legend.hide()

self.setLabels(left='Survival Probability', bottom=self.parent.time_var_name)
self.setLabels(left='Survival Probability', bottom='Time')

def set_scatter_item_size(self, item_size):
    """Apply ``item_size`` to the scatter item; do nothing if there is none."""
    scatter = self._scatter_item
    if not scatter:
        return
    scatter.setSize(item_size)

def generate_plot_curves(self):
    """Rebuild the curves, the log-rank test and the scatter item.

    Reads the table, the time/event variables and the optional grouping
    variable from ``self.parent``. Without a grouping variable a single
    curve is fitted to all rows; otherwise one curve is created per group
    value that has at least one row, and the result of
    ``multivariate_logrank_test`` is stored in ``self.log_rank_test``.

    ``self.curves``, ``self.log_rank_test`` and ``self._scatter_item`` are
    always reset first, so they stay empty/``None`` when there is no data
    or the time/event variables are missing.
    """
    self.curves = {}
    self.log_rank_test = None
    self._scatter_item = None

    table = self.parent.data
    if not table:
        return

    time_var = self.parent.time_var
    event_var = self.parent.event_var
    if time_var is None or event_var is None:
        return

    time, _ = table.get_column_view(time_var)
    events, _ = table.get_column_view(event_var)
    group_var = self.parent.group_var

    if group_var:
        groups, _ = table.get_column_view(group_var.name)
        # One boolean row mask per group value (broadcast over a column of indexes).
        group_indexes = np.arange(len(group_var.values))
        masks = groups == group_indexes.reshape(-1, 1)
        self.log_rank_test = multivariate_logrank_test(time, groups, events)
        curves = [
            EstimatedFunctionCurve(
                time[mask],
                events[mask],
                data=table[mask][:, time_var],
                color=list(group_var.colors[index]),
                label=label,
            )
            for index, (mask, label) in enumerate(zip(masks, group_var.values))
            if mask.any()
        ]
    else:
        curves = [EstimatedFunctionCurve(time, events, data=table[:, time_var])]

    self.curves = dict(enumerate(curves))

    # Merge the points of all curves into a single scatter item.
    scatter_kwargs = defaultdict(list)
    for curve in curves:
        if not curve.points:
            # zip(*[]) yields nothing, so the five-way unpack below would raise.
            continue
        xs, ys, ids, pens, brushes = zip(*curve.points)
        scatter_kwargs['x'] += xs
        scatter_kwargs['y'] += ys
        scatter_kwargs['data'] += ids
        scatter_kwargs['pen'] += pens
        scatter_kwargs['brush'] += brushes

    self._scatter_item = pg.ScatterPlotItem(**scatter_kwargs)

def mouseMovedEvent(self, ev):
pos = self.view_box.mapSceneToView(ev[0])
Expand Down Expand Up @@ -430,6 +525,13 @@ def update_plot(self, confidence_interval=False, median=False, censored=False):
if censored:
self.addItem(curve.censored_data)

visible = np.zeros(len(self._scatter_item.data['data']))
mask = np.in1d(self._scatter_item.data['data'], list(self.parent.subset_indices))
visible[mask] = 1
self._scatter_item.setPointsVisible(visible)
self.set_scatter_item_size(self.parent.scatter_item_size)
self.addItem(self._scatter_item)

self.set_selection()
self.update_legend()
self.setLabels(bottom=self.parent.time_var_name)
Expand All @@ -441,7 +543,7 @@ def update_legend(self):
for curve in [c for c in self.curves.values()]:
self.legend.set_curve(curve)
if len(self.curves) > 1:
self.legend.set_footer(format_p_value(2)(self.parent.log_rank_test.p_value))
self.legend.set_footer(format_p_value(2)(self.log_rank_test.p_value))
self.legend.updateSize()

if bool(len(self.legend.items)):
Expand Down Expand Up @@ -480,6 +582,9 @@ class OWKaplanMeier(OWWidget):
show_censored_data: bool
show_censored_data = Setting(False)

scatter_item_size: int
scatter_item_size = Setting(4)

settingsHandler = DomainContextHandler()
group_var: Optional[DiscreteVariable] = ContextSetting(None, schema_only=True)

Expand All @@ -490,15 +595,24 @@ class OWKaplanMeier(OWWidget):

class Inputs:
data = Input('Data', Table)
data_subset = Input('Data Subset', Table)

class Outputs:
selected_data = Output('Selected Data', Table, default=True)
annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

class Warning(OWWidget.Warning):
subset_not_subset = Msg(
"Subset data contains some instances that do not appear in " "input data"
)
subset_independent = Msg("No subset data instances appear in input data")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.data: Optional[Table] = None
self.subset_data: Optional[Table] = None
self.subset_indices: set = set()
self.plot_curves = None
self.log_rank_test = None

Expand Down Expand Up @@ -532,6 +646,20 @@ def __init__(self, *args, **kwargs):
callback=self.on_display_option_changed,
)

self.slider_box = gui.vBox(self.controlArea, 'Symbol size:')
gui.hSlider(
self.slider_box,
self,
'scatter_item_size',
None,
minValue=1,
maxValue=20,
createLabel=False,
callback=self.update_scatter_item_size,
addToLayout=True,
)
self.slider_box.setVisible(False)

self.graph: KaplanMeierPlot = KaplanMeierPlot(parent=self)
self.graph.selection_changed.connect(self.commit.deferred)
self.mainArea.layout().addWidget(self.graph)
Expand All @@ -545,6 +673,9 @@ def __init__(self, *args, **kwargs):
self.controlArea, self, 'auto_commit', '&Commit', box=False
)

def update_scatter_item_size(self):
    """Forward the current ``scatter_item_size`` setting to the graph."""
    size = self.scatter_item_size
    self.graph.set_scatter_item_size(size)

@property
def time_var(self):
if not self.data:
Expand Down Expand Up @@ -594,14 +725,30 @@ def set_data(self, data: Table):

self.graph.selection = {}
self.openContext(domain)

self.graph.curves = {
curve_id: curve for curve_id, curve in enumerate(self.generate_plot_curves())
}
self.graph.generate_plot_curves()
self.graph.update_plot(**self._get_plot_options())

self.commit.now()

@Inputs.data_subset
def set_subset_data(self, subset):
    # Handler for the 'Data Subset' input: it only stores the table.
    # handleNewSignals reads ``subset_data`` to compute ``subset_indices``.
    self.subset_data = subset

def handleNewSignals(self):
    """Recompute the subset ids from the current inputs and redraw the plot.

    Both subset warnings are cleared first. When either the data or the
    subset is missing, ``subset_indices`` becomes empty. Otherwise it holds
    the ids of the subset rows, and a warning is raised when none of them
    (``subset_independent``) or only some of them (``subset_not_subset``)
    occur in the input data.

    The symbol-size slider is shown only while ``subset_indices`` is
    non-empty.
    """
    self.Warning.subset_independent.clear()
    self.Warning.subset_not_subset.clear()

    if self.data is None or self.subset_data is None:
        self.subset_indices = set()
    else:
        self.subset_indices = set(self.subset_data.ids)
        data_ids = set(self.data.ids)
        if self.subset_indices.isdisjoint(data_ids):
            self.Warning.subset_independent()
        elif not self.subset_indices <= data_ids:
            self.Warning.subset_not_subset()

    self.slider_box.setVisible(bool(self.subset_indices))
    self.graph.update_plot(**self._get_plot_options())

def _get_plot_options(self):
return {
'confidence_interval': self.show_confidence_interval,
Expand All @@ -616,41 +763,11 @@ def on_group_changed(self):
if not self.data:
return

self.graph.curves = {
curve_id: curve for curve_id, curve in enumerate(self.generate_plot_curves())
}
self.graph.generate_plot_curves()
self.graph.clear_selection()
self.graph.update_plot(**self._get_plot_options())
self.commit.now()

def _get_discrete_var_color(self, index: Optional[int]):
if self.group_var is not None and index is not None:
return list(self.group_var.colors[index])

def generate_plot_curves(self) -> List[EstimatedFunctionCurve]:
if self.time_var is None or self.event_var is None:
return []

data = self.data
time, _ = data.get_column_view(self.time_var)
events, _ = data.get_column_view(self.event_var)

if self.group_var:
groups, _ = data.get_column_view(self.group_var.name)
group_indexes = [index for index, _ in enumerate(self.group_var.values)]
colors = [self._get_discrete_var_color(index) for index in group_indexes]
masks = groups == np.reshape(group_indexes, (-1, 1))
self.log_rank_test = multivariate_logrank_test(time, groups, events)

return [
EstimatedFunctionCurve(time[mask], events[mask], color=color, label=label)
for mask, color, label in zip(masks, colors, self.group_var.values)
if mask.any()
]

else:
return [EstimatedFunctionCurve(time, events)]

@gui.deferred
def commit(self):
if not self.graph.selection:
Expand Down Expand Up @@ -693,5 +810,5 @@ def sizeHint(self):
if __name__ == "__main__":
from orangewidget.utils.widgetpreview import WidgetPreview

table = Table('http://datasets.biolab.si/core/melanoma.tab')
WidgetPreview(OWKaplanMeier).run(input_data=table)
table = Table('https://datasets.biolab.si/core/gbsg2.tab')
WidgetPreview(OWKaplanMeier).run(set_data=table, set_subset_data=table[:100])
Original file line number Diff line number Diff line change
Expand Up @@ -266,5 +266,5 @@ def test_display_options(self):
]

self.assertEqual(5, len(plot_items))
self.assertEqual(6, len(scatter_items))
self.assertEqual(7, len(scatter_items))
self.assertEqual(1, len(infinite_line))