Skip to content

Commit

Permalink
learner adequacy check refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
JakaKokosar committed Mar 18, 2022
1 parent 19c1be7 commit e1e7419
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 26 deletions.
32 changes: 28 additions & 4 deletions Orange/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from collections.abc import Iterable
import re
import warnings
from typing import Callable, Dict
from typing import Callable, Dict, Optional

import numpy as np
import scipy

from Orange.data import Table, Storage, Instance, Value
from Orange.data import Table, Storage, Instance, Value, Domain
from Orange.data.filter import HasClass
from Orange.data.table import DomainTransformationError
from Orange.data.util import one_hot
Expand Down Expand Up @@ -86,6 +86,9 @@ class Learner(ReprableWithPreprocessors):
#: A sequence of data preprocessors to apply on data prior to
#: fitting the model
preprocessors = ()

# Note: Do not use this class attribute.
# It remains here for compatibility reasons.
learner_adequacy_err_msg = ''

def __init__(self, preprocessors=None):
Expand All @@ -95,6 +98,7 @@ def __init__(self, preprocessors=None):
elif preprocessors:
self.preprocessors = (preprocessors,)

# pylint: disable=R0201
def fit(self, X, Y, W=None):
    """Placeholder that concrete learners are expected to override.

    The base implementation never fits anything: it only signals that a
    subclass has to provide its own ``fit`` or ``fit_storage``.

    Raises:
        RuntimeError: always.
    """
    message = ("Descendants of Learner must overload method "
               "fit or fit_storage")
    raise RuntimeError(message)
Expand All @@ -106,8 +110,23 @@ def fit_storage(self, data):
return self.fit(X, Y, W)

def __call__(self, data, progress_callback=None):
if not self.check_learner_adequacy(data.domain):
raise ValueError(self.learner_adequacy_err_msg)

for cls in type(self).mro():
if 'incompatibility_reason' in cls.__dict__:
incompatibility_reason = \
self.incompatibility_reason(data.domain) # pylint: disable=assignment-from-none
if incompatibility_reason is not None:
raise ValueError(incompatibility_reason)
break
if 'check_learner_adequacy' in cls.__dict__:
warnings.warn(
"check_learner_adequacy is deprecated and will be removed "
"in upcoming releases. Learners should instead implement "
"the incompatibility_reason method.",
OrangeDeprecationWarning)
if not self.check_learner_adequacy(data.domain):
raise ValueError(self.learner_adequacy_err_msg)
break

origdomain = data.domain

Expand Down Expand Up @@ -176,6 +195,11 @@ def active_preprocessors(self):
def check_learner_adequacy(self, _):
    """Deprecated adequacy hook; the base learner accepts any domain.

    Kept for backward compatibility only. Calling a learner whose class
    overrides this method emits an ``OrangeDeprecationWarning``; new
    learners should implement ``incompatibility_reason`` instead.

    Returns:
        bool: always ``True`` in the base implementation.
    """
    return True

# pylint: disable=no-self-use
def incompatibility_reason(self, _: Domain) -> Optional[str]:
    """Explain why this learner cannot be fitted on the given domain.

    The base implementation places no restrictions on the domain.
    Subclasses override it to describe what they do not support.

    Returns:
        ``None`` if the learner can fit the domain, otherwise a string
        describing why it cannot.
    """
    return None

@property
def name(self):
"""Return a short name derived from Learner type name"""
Expand Down
14 changes: 6 additions & 8 deletions Orange/classification/base_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@

class LearnerClassification(Learner):
    """Learner base that only accepts domains with a categorical target."""

    def incompatibility_reason(self, domain):
        """Return why `domain` cannot be used, or None if it can.

        Args:
            domain: the domain of the data the learner is about to fit;
                its ``class_vars`` and ``has_discrete_class`` are inspected.

        Returns:
            ``None`` when the domain is acceptable, otherwise one of
            "Too many target variables." (more than one class variable
            and the learner does not set ``supports_multiclass``) or
            "Categorical class variable expected.".
        """
        reason = None
        if len(domain.class_vars) > 1 and not self.supports_multiclass:
            reason = "Too many target variables."
        elif not domain.has_discrete_class:
            reason = "Categorical class variable expected."
        return reason


class ModelClassification(Model):
Expand Down
5 changes: 3 additions & 2 deletions Orange/preprocess/impute.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def __call__(self, data, variable):
variable = data.domain[variable]
domain = domain_with_class_var(data.domain, variable)

if self.learner.check_learner_adequacy(domain):
incompatibility_reason = self.learner.incompatibility_reason(domain)
if incompatibility_reason is None:
data = data.transform(domain)
model = self.learner(data)
assert model.domain.class_var == variable
Expand All @@ -239,7 +240,7 @@ def copy(self):

def supports_variable(self, variable):
    """Tell whether the wrapped learner can model `variable` as a target.

    A domain with no attributes and `variable` as its class variable is
    built, and the learner is asked for an incompatibility reason.

    Returns:
        bool: ``True`` when the learner reports no incompatibility.
    """
    domain = Orange.data.Domain([], class_vars=variable)
    # Use the non-deprecated hook; `None` means "no objection".
    return self.learner.incompatibility_reason(domain) is None


def domain_with_class_var(domain, class_var):
Expand Down
14 changes: 6 additions & 8 deletions Orange/regression/base_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@

class LearnerRegression(Learner):
    """Learner base that only accepts domains with a numeric target."""

    def incompatibility_reason(self, domain):
        """Return why `domain` cannot be used, or None if it can.

        Args:
            domain: the domain of the data the learner is about to fit;
                its ``class_vars`` and ``has_continuous_class`` are
                inspected.

        Returns:
            ``None`` when the domain is acceptable, otherwise one of
            "Too many target variables." (more than one class variable
            and the learner does not set ``supports_multiclass``) or
            "Numeric class variable expected.".
        """
        reason = None
        if len(domain.class_vars) > 1 and not self.supports_multiclass:
            reason = "Too many target variables."
        elif not domain.has_continuous_class:
            reason = "Numeric class variable expected."
        return reason


class ModelRegression(Model):
Expand Down
5 changes: 3 additions & 2 deletions Orange/tests/dummy_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def __init__(self, value, prob):
class DummyMulticlassLearner(SklLearner):
supports_multiclass = True

def incompatibility_reason(self, domain):
    """Accept a domain only if every class variable is discrete.

    Returns:
        ``None`` when all of ``domain.class_vars`` are discrete (also
        when there are none), otherwise an explanatory message.
    """
    reason = 'Not all class variables are discrete'
    return None if all(c.is_discrete for c in domain.class_vars) else reason

def fit(self, X, Y, W):
rows, class_vars = Y.shape
Expand Down
31 changes: 31 additions & 0 deletions Orange/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,59 @@
# pylint: disable=missing-docstring
import pickle
import unittest
from distutils.version import LooseVersion

import Orange
from Orange.base import SklLearner, Learner, Model
from Orange.data import Domain, Table
from Orange.preprocess import Discretize, Randomize, Continuize
from Orange.regression import LinearRegressionLearner
from Orange.util import OrangeDeprecationWarning


class DummyLearnerDeprecated(Learner):
    """Test learner that still overrides ``check_learner_adequacy``.

    Calling it is expected to trigger the deprecation warning for the
    old adequacy hook.
    """

    def fit(self, *_, **__):
        """Ignore all arguments and return a mock in place of a model."""
        # `import unittest` alone does not load the `unittest.mock`
        # submodule; import it explicitly rather than relying on another
        # module having imported it first.
        from unittest import mock  # pylint: disable=import-outside-toplevel
        return mock.Mock()

    def check_learner_adequacy(self, _):
        """Deprecated hook, overridden on purpose; accepts any domain."""
        return True


class DummyLearner(Learner):
    """Minimal test learner whose ``fit`` returns a mock model."""

    def fit(self, *_, **__):
        """Ignore all arguments and return a mock in place of a model."""
        # `import unittest` alone does not load the `unittest.mock`
        # submodule; import it explicitly rather than relying on another
        # module having imported it first.
        from unittest import mock  # pylint: disable=import-outside-toplevel
        return mock.Mock()


class DummySklLearner(SklLearner):
    """Minimal ``SklLearner`` subclass whose ``fit`` returns a mock model."""

    def fit(self, *_, **__):
        """Ignore all arguments and return a mock in place of a model."""
        # `import unittest` alone does not load the `unittest.mock`
        # submodule; import it explicitly rather than relying on another
        # module having imported it first.
        from unittest import mock  # pylint: disable=import-outside-toplevel
        return mock.Mock()


class DummyLearnerPP(Learner):
    """Test learner whose class-level default preprocessor is Randomize."""

    # Class-level default: a one-element tuple holding a Randomize instance.
    preprocessors = (Randomize(),)


class TestLearner(unittest.TestCase):

def test_if_deprecation_warning_is_raised(self):
    """A learner overriding the old adequacy hook must warn when called."""
    # Everything stays inside the context manager so that exactly the same
    # statements are watched for the warning as before.
    with self.assertWarns(OrangeDeprecationWarning):
        learner = DummyLearnerDeprecated()
        iris = Table('iris')
        learner(iris)

def test_check_learner_adequacy_deprecated(self):
    """Reminder to drop the deprecated ``check_learner_adequacy`` hook.

    The hook was deprecated in 3.32. Once Orange reaches version 3.34
    this test fails on purpose, as a prompt to remove both the method and
    this test.
    """
    # NOTE(review): `distutils` was removed from the standard library in
    # Python 3.12 (PEP 632); `LooseVersion` will need a replacement.
    installed = LooseVersion(Orange.__version__)
    removal_due = LooseVersion("3.34")
    if installed >= removal_due:
        self.fail(
            "`Orange.base.Learner.check_learner_adequacy` was deprecated in "
            "version 3.32, and there have been two minor versions in "
            "between. Please remove the deprecated method."
        )

def test_uses_default_preprocessors_unless_custom_pps_specified(self):
"""Learners should use their default preprocessors unless custom
preprocessors were passed in to the constructor"""
Expand Down
25 changes: 23 additions & 2 deletions Orange/widgets/utils/owlearnerwidget.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
import warnings

from AnyQt.QtCore import QTimer, Qt

Expand All @@ -12,6 +13,7 @@
from Orange.widgets.utils.signals import Output, Input
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.widget import OWWidget, WidgetMetaClass, Msg
from Orange.util import OrangeDeprecationWarning


class OWBaseLearnerMeta(WidgetMetaClass):
Expand Down Expand Up @@ -246,8 +248,26 @@ def check_data(self):
self.Error.sparse_not_supported.clear()
if self.data is not None and self.learner is not None:
self.Error.data_error.clear()
if not self.learner.check_learner_adequacy(self.data.domain):
self.Error.data_error(self.learner.learner_adequacy_err_msg)

incompatibility_reason = None
for cls in type(self.learner).mro():
if 'incompatibility_reason' in cls.__dict__:
# pylint: disable=assignment-from-none
incompatibility_reason = \
self.learner.incompatibility_reason(self.data.domain)
break
if 'check_learner_adequacy' in cls.__dict__:
warnings.warn(
"check_learner_adequacy is deprecated and will be removed "
"in upcoming releases. Learners should instead implement "
"the incompatibility_reason method.",
OrangeDeprecationWarning)
if not self.learner.check_learner_adequacy(self.data.domain):
incompatibility_reason = self.learner.learner_adequacy_err_msg
break

if incompatibility_reason is not None:
self.Error.data_error(incompatibility_reason)
elif not len(self.data):
self.Error.data_error("Dataset is empty.")
elif len(ut.unique(self.data.Y)) < 2:
Expand All @@ -258,6 +278,7 @@ def check_data(self):
self.Error.sparse_not_supported()
else:
self.valid_data = True

return self.valid_data

def settings_changed(self, *args, **kwargs):
Expand Down

0 comments on commit e1e7419

Please sign in to comment.