diff --git a/orangecontrib/survival_analysis/evaluation/scoring.py b/orangecontrib/survival_analysis/evaluation/scoring.py index 34efd7c..de677c4 100644 --- a/orangecontrib/survival_analysis/evaluation/scoring.py +++ b/orangecontrib/survival_analysis/evaluation/scoring.py @@ -1,7 +1,7 @@ from lifelines.utils import concordance_index -from Orange.data import DiscreteVariable, ContinuousVariable +from Orange.data import DiscreteVariable, ContinuousVariable, Domain from Orange.evaluation.scoring import Score -from orangecontrib.survival_analysis.widgets.data import get_survival_endpoints +from orangecontrib.survival_analysis.widgets.data import get_survival_endpoints, contains_survival_endpoints __all__ = ['ConcordanceIndex'] @@ -11,8 +11,10 @@ class SurvivalScorer(Score, abstract=True): ContinuousVariable, DiscreteVariable, ) - is_built_in = False - problem_type = 'time_to_event' + + @staticmethod + def is_compatible(domain: Domain) -> bool: + return contains_survival_endpoints(domain) class ConcordanceIndex(SurvivalScorer): diff --git a/orangecontrib/survival_analysis/modeling/cox.py b/orangecontrib/survival_analysis/modeling/cox.py index 04a7b83..66fbdcf 100644 --- a/orangecontrib/survival_analysis/modeling/cox.py +++ b/orangecontrib/survival_analysis/modeling/cox.py @@ -4,7 +4,11 @@ from Orange.data.pandas_compat import table_to_frame from Orange.base import Learner, Model -from orangecontrib.survival_analysis.widgets.data import contains_survival_endpoints, get_survival_endpoints +from orangecontrib.survival_analysis.widgets.data import ( + contains_survival_endpoints, + get_survival_endpoints, + MISSING_SURVIVAL_DATA, +) class CoxRegressionModel(Model): @@ -36,14 +40,14 @@ def __call__(self, data, ret=Model.Value): class CoxRegressionLearner(Learner): __returns__ = CoxRegressionModel supports_multiclass = True - learner_adequacy_err_msg = 'Survival variables expected. Use As Survival Data widget.' def __init__(self, preprocessors=None, **kwargs): self.params = vars() super().__init__(preprocessors=preprocessors) - def check_learner_adequacy(self, domain): - return len(domain.class_vars) == 2 + def incompatibility_reason(self, domain): + if not contains_survival_endpoints(domain): + return MISSING_SURVIVAL_DATA def fit_storage(self, data): return self.fit(data) @@ -56,7 +60,7 @@ def _fit_model(self, data): def fit(self, data): if not contains_survival_endpoints(data.domain): - raise ValueError(self.learner_adequacy_err_msg) + raise ValueError(MISSING_SURVIVAL_DATA) time_var, event_var = get_survival_endpoints(data.domain) df = table_to_frame(data, include_metas=False) diff --git a/orangecontrib/survival_analysis/widgets/data.py b/orangecontrib/survival_analysis/widgets/data.py index 12af980..0561d61 100644 --- a/orangecontrib/survival_analysis/widgets/data.py +++ b/orangecontrib/survival_analysis/widgets/data.py @@ -5,9 +5,15 @@ from Orange.data import Table, Domain, Variable -TIME_VAR = 'time' -EVENT_VAR = 'event' -TIME_TO_EVENT_VAR = '_time_to_event_var' +TIME_VAR: str = 'time' +EVENT_VAR: str = 'event' +TIME_TO_EVENT_VAR: str = '_time_to_event_var' + +# Error/Warning messages related to survival data tables. +MISSING_ROWS: str = 'Rows with missing values detected. They will be omitted.' +MISSING_SURVIVAL_DATA: str = ( + 'No survival data detected. ' 'Use the "As Survival Data" widget or consult the documentation.' +) def contains_survival_endpoints(domain: Domain): @@ -35,15 +41,13 @@ def get_survival_endpoints(domain: Domain) -> Tuple[Optional[Variable], Optional def check_survival_data(f): """Check for survival data.""" - error_msg = 'No survival data detected. Use the "As Survival Data" widget or consult the documentation.' - warning_msg = 'Rows with missing values detected. They will be omitted.' @wraps(f) def wrapper(widget, data: Table, *args, **kwargs): - widget.Error.add_message('missing_survival_data', UnboundMsg(error_msg)) + widget.Error.add_message('missing_survival_data', UnboundMsg(MISSING_SURVIVAL_DATA)) widget.Error.missing_survival_data.clear() - widget.Warning.add_message('missing_values_detected', UnboundMsg(warning_msg)) + widget.Warning.add_message('missing_values_detected', UnboundMsg(MISSING_ROWS)) widget.Warning.missing_values_detected.clear() if data is not None and isinstance(data, Table): diff --git a/orangecontrib/survival_analysis/widgets/owcoxregression.py b/orangecontrib/survival_analysis/widgets/owcoxregression.py index 89185fd..fc52a9d 100644 --- a/orangecontrib/survival_analysis/widgets/owcoxregression.py +++ b/orangecontrib/survival_analysis/widgets/owcoxregression.py @@ -123,8 +123,9 @@ def check_data(self): self.Error.sparse_not_supported.clear() if self.data is not None and self.learner is not None: self.Error.data_error.clear() - if not self.learner.check_learner_adequacy(self.data.domain): - self.Error.data_error(self.learner.learner_adequacy_err_msg) + incompatibility_reason = self.learner.incompatibility_reason(self.data.domain) + if incompatibility_reason is not None: + self.Error.data_error(incompatibility_reason) elif not len(self.data): self.Error.data_error("Dataset is empty.") elif self.data.X.size == 0: