Skip to content

Commit

Permalink
_OutlierModel: Add callback to model
Browse files Browse the repository at this point in the history
  • Loading branch information
VesnaT committed Feb 12, 2020
1 parent e9c068e commit c5ef939
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
37 changes: 27 additions & 10 deletions Orange/classification/outlier_detection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# pylint: disable=unused-argument
from typing import Callable

import numpy as np

from Orange.data.table import DomainTransformationError
from Orange.data.util import get_unique_names
from sklearn.covariance import EllipticEnvelope
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor
Expand All @@ -11,6 +11,9 @@
from Orange.base import SklLearner, SklModel
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable, \
Variable
from Orange.data.table import DomainTransformationError
from Orange.data.util import get_unique_names, progress_callback, \
dummy_callback
from Orange.preprocess import AdaptiveNormalize
from Orange.statistics.util import all_nan

Expand All @@ -29,29 +32,44 @@ def predict(self, X: np.ndarray) -> np.ndarray:
pred[pred == -1] = 0
return pred[:, None]

def __call__(self, data: Table) -> Table:
def __call__(self, data: Table, callback: Callable = None) -> Table:
assert isinstance(data, Table)
assert self.outlier_var is not None

domain = Domain(data.domain.attributes, data.domain.class_vars,
data.domain.metas + (self.outlier_var,))
self._cached_data = self.data_to_model_domain(data)
if callback is None:
callback = dummy_callback
callback(0, "Preprocessing...")
self._cached_data = self.data_to_model_domain(
data, progress_callback(callback, end=0.1))
callback(0.1, "Predicting...")
metas = np.hstack((data.metas, self.predict(self._cached_data.X)))
callback(1)
return Table.from_numpy(domain, data.X, data.Y, metas)

def data_to_model_domain(self, data: Table) -> Table:
def data_to_model_domain(self, data: Table, callback: Callable) -> Table:
if data.domain == self.domain:
return data

callback(0)
if self.original_domain.attributes != data.domain.attributes \
and data.X.size \
and not all_nan(data.X):
callback(0.5)
new_data = data.transform(self.original_domain)
if all_nan(new_data.X):
raise DomainTransformationError(
"domain transformation produced no defined values")
return new_data.transform(self.domain)
return data.transform(self.domain)
callback(0.75)
data = new_data.transform(self.domain)
callback(1)
return data

callback(0.5)
data = data.transform(self.domain)
callback(1)
return data


class _OutlierLearner(SklLearner):
Expand Down Expand Up @@ -148,8 +166,8 @@ def mahalanobis(self, observations: np.ndarray) -> np.ndarray:
"""
return self.skl_model.mahalanobis(observations)[:, None]

def __call__(self, data: Table) -> Table:
pred = super().__call__(data)
def __call__(self, data: Table, callback: Callable = None) -> Table:
pred = super().__call__(data, callback)
domain = Domain(pred.domain.attributes, pred.domain.class_vars,
pred.domain.metas + (self.mahal_var,))
metas = np.hstack((pred.metas, self.mahalanobis(self._cached_data.X)))
Expand Down Expand Up @@ -181,4 +199,3 @@ def _fit_model(self, data: Table) -> EllipticEnvelopeClassifier:
transformer.variable = variable
model.mahal_var = variable
return model

30 changes: 30 additions & 0 deletions Orange/classification/tests/test_outlier_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import pickle
import tempfile
import unittest
from unittest.mock import Mock

import numpy as np

from Orange.classification import EllipticEnvelopeLearner, \
IsolationForestLearner, LocalOutlierFactorLearner, OneClassSVMLearner
from Orange.data import Table, Domain, ContinuousVariable
from Orange.data.table import DomainTransformationError


class _TestDetector(unittest.TestCase):
Expand Down Expand Up @@ -207,6 +209,17 @@ def test_unique_name(self):
pred = detect(table)
self.assertEqual(pred.domain.metas[0].name, "Outlier (1)")

def test_predict(self):
detect = self.detector(self.iris)
subset = self.iris[:, :3]
pred = detect(subset)
self.assert_table_appended_outlier(subset, pred)

def test_predict_all_nan(self):
detect = self.detector(self.iris[:, :2])
subset = self.iris[:, 2:]
self.assertRaises(DomainTransformationError, detect, subset)

def test_transform(self):
detect = self.detector(self.iris)
pred = detect(self.iris)
Expand Down Expand Up @@ -235,6 +248,23 @@ def test_pickle_prediction(self):
pickle.dump(pred, f)
f.close()

def test_fit_callback(self):
callback = Mock()
self.detector(self.iris, callback)
args = [x[0][0] for x in callback.call_args_list]
self.assertEqual(min(args), 0)
self.assertEqual(max(args), 1)
self.assertListEqual(args, sorted(args))

def test_predict_callback(self):
callback = Mock()
detect = self.detector(self.iris)
detect(self.iris, callback)
args = [x[0][0] for x in callback.call_args_list]
self.assertEqual(min(args), 0)
self.assertEqual(max(args), 1)
self.assertListEqual(args, sorted(args))


if __name__ == "__main__":
unittest.main()

0 comments on commit c5ef939

Please sign in to comment.