Skip to content

Commit

Permalink
use new interface for testing and scoring learners
Browse files Browse the repository at this point in the history
  • Loading branch information
JakaKokosar committed Mar 21, 2022
1 parent 0974718 commit 332dc79
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 18 deletions.
10 changes: 6 additions & 4 deletions orangecontrib/survival_analysis/evaluation/scoring.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from lifelines.utils import concordance_index
from Orange.data import DiscreteVariable, ContinuousVariable
from Orange.data import DiscreteVariable, ContinuousVariable, Domain
from Orange.evaluation.scoring import Score
from orangecontrib.survival_analysis.widgets.data import get_survival_endpoints
from orangecontrib.survival_analysis.widgets.data import get_survival_endpoints, contains_survival_endpoints

__all__ = ['ConcordanceIndex']

Expand All @@ -11,8 +11,10 @@ class SurvivalScorer(Score, abstract=True):
ContinuousVariable,
DiscreteVariable,
)
is_built_in = False
problem_type = 'time_to_event'

@staticmethod
def is_compatible(domain: Domain) -> bool:
return contains_survival_endpoints(domain)


class ConcordanceIndex(SurvivalScorer):
Expand Down
14 changes: 9 additions & 5 deletions orangecontrib/survival_analysis/modeling/cox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from Orange.data.pandas_compat import table_to_frame
from Orange.base import Learner, Model

from orangecontrib.survival_analysis.widgets.data import contains_survival_endpoints, get_survival_endpoints
from orangecontrib.survival_analysis.widgets.data import (
contains_survival_endpoints,
get_survival_endpoints,
MISSING_SURVIVAL_DATA,
)


class CoxRegressionModel(Model):
Expand Down Expand Up @@ -36,14 +40,14 @@ def __call__(self, data, ret=Model.Value):
class CoxRegressionLearner(Learner):
__returns__ = CoxRegressionModel
supports_multiclass = True
learner_adequacy_err_msg = 'Survival variables expected. Use As Survival Data widget.'

def __init__(self, preprocessors=None, **kwargs):
self.params = vars()
super().__init__(preprocessors=preprocessors)

def check_learner_adequacy(self, domain):
return len(domain.class_vars) == 2
def incompatibility_reason(self, domain):
if not contains_survival_endpoints(domain):
return MISSING_SURVIVAL_DATA

def fit_storage(self, data):
return self.fit(data)
Expand All @@ -56,7 +60,7 @@ def _fit_model(self, data):

def fit(self, data):
if not contains_survival_endpoints(data.domain):
raise ValueError(self.learner_adequacy_err_msg)
raise ValueError(MISSING_SURVIVAL_DATA)
time_var, event_var = get_survival_endpoints(data.domain)

df = table_to_frame(data, include_metas=False)
Expand Down
18 changes: 11 additions & 7 deletions orangecontrib/survival_analysis/widgets/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
from Orange.data import Table, Domain, Variable


TIME_VAR = 'time'
EVENT_VAR = 'event'
TIME_TO_EVENT_VAR = '_time_to_event_var'
TIME_VAR: str = 'time'
EVENT_VAR: str = 'event'
TIME_TO_EVENT_VAR: str = '_time_to_event_var'

# Error/Warning messages related to survival data tables.
MISSING_ROWS: str = 'Rows with missing values detected. They will be omitted.'
MISSING_SURVIVAL_DATA: str = (
'No survival data detected. ' 'Use the "As Survival Data" widget or consult the documentation.'
)


def contains_survival_endpoints(domain: Domain):
Expand Down Expand Up @@ -35,15 +41,13 @@ def get_survival_endpoints(domain: Domain) -> Tuple[Optional[Variable], Optional

def check_survival_data(f):
"""Check for survival data."""
error_msg = 'No survival data detected. Use the "As Survival Data" widget or consult the documentation.'
warning_msg = 'Rows with missing values detected. They will be omitted.'

@wraps(f)
def wrapper(widget, data: Table, *args, **kwargs):
widget.Error.add_message('missing_survival_data', UnboundMsg(error_msg))
widget.Error.add_message('missing_survival_data', UnboundMsg(MISSING_SURVIVAL_DATA))
widget.Error.missing_survival_data.clear()

widget.Warning.add_message('missing_values_detected', UnboundMsg(warning_msg))
widget.Warning.add_message('missing_values_detected', UnboundMsg(MISSING_ROWS))
widget.Warning.missing_values_detected.clear()

if data is not None and isinstance(data, Table):
Expand Down
5 changes: 3 additions & 2 deletions orangecontrib/survival_analysis/widgets/owcoxregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ def check_data(self):
self.Error.sparse_not_supported.clear()
if self.data is not None and self.learner is not None:
self.Error.data_error.clear()
if not self.learner.check_learner_adequacy(self.data.domain):
self.Error.data_error(self.learner.learner_adequacy_err_msg)
incompatibility_reason = self.learner.incompatibility_reason(self.data.domain)
if incompatibility_reason is not None:
self.Error.data_error(incompatibility_reason)
elif not len(self.data):
self.Error.data_error("Dataset is empty.")
elif self.data.X.size == 0:
Expand Down

0 comments on commit 332dc79

Please sign in to comment.