Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] Tidy up classification and regression tests #2314

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions aeon/base/estimators/interval_based/base_interval_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,10 @@ def temporal_importance_curves(
curves : list of np.ndarray
The temporal importance curves for each feature.
"""
if is_regressor(self):
raise NotImplementedError(
"Temporal importance curves are not available for regression."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this covered in the tests? It would be good to increase the coverage on this class; it's currently at 69% with 190 missed lines.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is really the first step toward doing that, now that all the regressors are properly covered in general testing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
if not isinstance(self._base_estimator, ContinuousIntervalTree):
raise ValueError(
"base_estimator for temporal importance curves must"
Expand Down
7 changes: 2 additions & 5 deletions aeon/classification/interval_based/_drcif.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,7 @@ def __init__(
n_jobs=1,
parallel_backend=None,
):
d = []
self.use_pycatch22 = use_pycatch22
if use_pycatch22:
d.append("pycatch22")

if isinstance(base_estimator, ContinuousIntervalTree):
replace_nan = "nan"
Expand Down Expand Up @@ -241,8 +238,8 @@ def __init__(
parallel_backend=parallel_backend,
)

if d:
self.set_tags(**{"python_dependencies": d})
if use_pycatch22:
self.set_tags(**{"python_dependencies": "pycatch22"})

def _fit(self, X, y):
return super()._fit(X, y)
Expand Down
8 changes: 4 additions & 4 deletions aeon/classification/interval_based/_interval_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def _fit(self, X, y):
),
self.random_state,
)
m = hasattr(self._estimator, "n_jobs")
if m:

if hasattr(self._estimator, "n_jobs"):
self._estimator.n_jobs = self._n_jobs

X_t = self._transformer.fit_transform(X, y)
Expand Down Expand Up @@ -401,8 +401,8 @@ def _fit(self, X, y):
),
self.random_state,
)
m = hasattr(self._estimator, "n_jobs")
if m:

if hasattr(self._estimator, "n_jobs"):
self._estimator.n_jobs = self._n_jobs

X_t = self._transformer.fit_transform(X, y)
Expand Down
8 changes: 8 additions & 0 deletions aeon/classification/interval_based/_rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ def _fit_predict(self, X, y) -> np.ndarray:
def _fit_predict_proba(self, X, y) -> np.ndarray:
return super()._fit_predict_proba(X, y)

def temporal_importance_curves(
    self, return_dict=False, normalise_time_points=False
):
    """Raise, as this estimator has no temporal importance curves.

    RandomIntervalSpectralEnsemble does not support temporal importance
    curves, so this override unconditionally raises.

    Raises
    ------
    NotImplementedError
        Always raised; curves are unavailable for this estimator.
    """
    msg = (
        "No temporal importance curves available for "
        "RandomIntervalSpectralEnsemble."
    )
    raise NotImplementedError(msg)

@classmethod
def _get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Expand Down
2 changes: 1 addition & 1 deletion aeon/classification/interval_based/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Tests for interval based classifiers."""
"""Tests for interval-based classifiers."""
14 changes: 0 additions & 14 deletions aeon/classification/interval_based/tests/test_cif.py

This file was deleted.

12 changes: 0 additions & 12 deletions aeon/classification/interval_based/tests/test_dr_cif.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Test interval forest classifiers."""

import pytest

from aeon.classification.interval_based import (
CanonicalIntervalForestClassifier,
DrCIFClassifier,
RandomIntervalSpectralEnsembleClassifier,
SupervisedTimeSeriesForest,
TimeSeriesForestClassifier,
)
from aeon.classification.sklearn import ContinuousIntervalTree
from aeon.testing.testing_data import EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION
from aeon.testing.utils.estimator_checks import _assert_predict_probabilities
from aeon.utils.validation._dependencies import _check_soft_dependencies
from aeon.visualisation import plot_temporal_importance_curves


@pytest.mark.skipif(
    not _check_soft_dependencies(["matplotlib", "seaborn"], severity="none"),
    reason="skip test if required soft dependency not available",
)
@pytest.mark.parametrize(
    "cls",
    [
        CanonicalIntervalForestClassifier,
        DrCIFClassifier,
        SupervisedTimeSeriesForest,
        TimeSeriesForestClassifier,
    ],
)
def test_tic_curves(cls):
    """Test whether temporal_importance_curves runs without error."""
    import matplotlib

    # Headless backend so the plotting call does not need a display.
    matplotlib.use("Agg")

    X_train, y_train = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"]

    param_set = cls._get_test_params()
    if isinstance(param_set, list):
        param_set = param_set[0]
    # Curves require a ContinuousIntervalTree base estimator.
    param_set["base_estimator"] = ContinuousIntervalTree()

    estimator = cls(**param_set)
    estimator.fit(X_train, y_train)

    names, curves = estimator.temporal_importance_curves()
    plot_temporal_importance_curves(curves, names)


@pytest.mark.parametrize("cls", [RandomIntervalSpectralEnsembleClassifier])
def test_tic_curves_invalid(cls):
    """Test that temporal_importance_curves raises NotImplementedError.

    Estimators without temporal importance curve support must raise with a
    clear message rather than fail obscurely.
    """
    clf = cls()
    # pytest.raises(match=...) applies re.search; the previous pattern ended
    # with an unescaped "." which matched ANY character (so it silently
    # matched "available f" in the real message). Use a metacharacter-free
    # literal substring of the raised message instead.
    with pytest.raises(
        NotImplementedError, match="No temporal importance curves available"
    ):
        clf.temporal_importance_curves()


@pytest.mark.skipif(
    not _check_soft_dependencies(["pycatch22"], severity="none"),
    reason="skip test if required soft dependency not available",
)
@pytest.mark.parametrize("cls", [CanonicalIntervalForestClassifier, DrCIFClassifier])
def test_forest_pycatch22(cls):
    """Test whether the forest classifiers with pycatch22 run without error."""
    data = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]
    X_train, y_train = data["train"]
    X_test, _ = data["test"]

    param_set = cls._get_test_params()
    if isinstance(param_set, list):
        param_set = param_set[0]
    # Enable the optional pycatch22 feature transform.
    param_set["use_pycatch22"] = True

    estimator = cls(**param_set)
    estimator.fit(X_train, y_train)
    probas = estimator.predict_proba(X_test)
    _assert_predict_probabilities(probas, X_test, n_classes=2)
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,24 @@
RandomIntervalClassifier,
SupervisedIntervalClassifier,
)
from aeon.testing.data_generation import make_example_3d_numpy
from aeon.testing.testing_data import EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION
from aeon.testing.utils.estimator_checks import _assert_predict_probabilities


@pytest.mark.parametrize(
"cls", [SupervisedIntervalClassifier, RandomIntervalClassifier]
)
def test_random_interval_classifier(cls):
def test_interval_pipeline_classifiers(cls):
"""Test the random interval classifiers."""
X, y = make_example_3d_numpy(n_cases=5, n_channels=1, n_timepoints=12)
r = cls(estimator=SVC())
r.fit(X, y)
p = r.predict_proba(X)
assert p.shape == (5, 2)
r = cls(n_jobs=2)
r.fit(X, y)
assert r._estimator.n_jobs == 2
X_train, y_train = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"]
X_test, y_test = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["test"]

params = cls._get_test_params()
if isinstance(params, list):
params = params[0]
params.update({"estimator": SVC()})

def test_parameter_sets():
"""Test results comparison parameter sets."""
paras = SupervisedIntervalClassifier._get_test_params(
parameter_set="results_comparison"
)
assert paras["n_intervals"] == 2
paras = RandomIntervalClassifier._get_test_params(
parameter_set="results_comparison"
)
assert paras["n_intervals"] == 3
clf = cls(**params)
clf.fit(X_train, y_train)
prob = clf.predict_proba(X_test)
_assert_predict_probabilities(prob, X_test, n_classes=2)
31 changes: 8 additions & 23 deletions aeon/classification/interval_based/tests/test_quant.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Tests for the QUANTClassifier class."""

import numpy as np
import pytest
from sklearn.linear_model import RidgeClassifierCV
from sklearn.svm import SVC

from aeon.classification.interval_based import QUANTClassifier
from aeon.testing.data_generation import make_example_3d_numpy
from aeon.testing.testing_data import EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION
from aeon.testing.utils.estimator_checks import _assert_predict_probabilities
from aeon.utils.validation._dependencies import _check_soft_dependencies


Expand All @@ -16,13 +15,12 @@
)
def test_alternative_estimator():
"""Test QUANTClassifier with an alternative estimator."""
X, y = make_example_3d_numpy()
clf = QUANTClassifier(estimator=RidgeClassifierCV())
clf.fit(X, y)
pred = clf.predict(X)
X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"]

assert isinstance(pred, np.ndarray)
assert pred.shape[0] == X.shape[0]
clf = QUANTClassifier(estimator=SVC())
clf.fit(X, y)
prob = clf.predict_proba(X)
_assert_predict_probabilities(prob, X, n_classes=2)


@pytest.mark.skipif(
Expand All @@ -31,7 +29,7 @@ def test_alternative_estimator():
)
def test_invalid_inputs():
"""Test handling of invalid inputs by QUANTClassifier."""
X, y = make_example_3d_numpy()
X, y = EQUAL_LENGTH_UNIVARIATE_CLASSIFICATION["numpy3D"]["train"]

with pytest.raises(ValueError, match="quantile_divisor must be >= 1"):
quant = QUANTClassifier(quantile_divisor=0)
Expand All @@ -40,16 +38,3 @@ def test_invalid_inputs():
with pytest.raises(ValueError, match="interval_depth must be >= 1"):
quant = QUANTClassifier(interval_depth=0)
quant.fit(X, y)


@pytest.mark.skipif(
not _check_soft_dependencies("torch", severity="none"),
reason="skip test if required soft dependency tsfresh not available",
)
def test_predict_proba():
"""Test predict proba with a sklearn classifier without predict proba."""
X, y = make_example_3d_numpy(n_cases=5, n_channels=1, n_timepoints=12)
r = QUANTClassifier(estimator=SVC())
r.fit(X, y)
p = r.predict_proba(X)
assert p.shape == (5, 2)
12 changes: 0 additions & 12 deletions aeon/classification/interval_based/tests/test_rise.py

This file was deleted.

10 changes: 0 additions & 10 deletions aeon/classification/interval_based/tests/test_tsf.py

This file was deleted.

9 changes: 9 additions & 0 deletions aeon/regression/interval_based/_cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ def __init__(
if use_pycatch22:
self.set_tags(**{"python_dependencies": "pycatch22"})

# NOTE(review): explicit pass-through overrides adding no behaviour on top of
# the superclass — presumably kept so this class defines the hooks itself
# (e.g. for documentation or coverage tooling); confirm intent.
def _fit(self, X, y):
    """Fit to training data; defers entirely to ``super()._fit``."""
    return super()._fit(X, y)

def _predict(self, X) -> np.ndarray:
    """Predict for X; defers entirely to ``super()._predict``."""
    return super()._predict(X)

def _fit_predict(self, X, y) -> np.ndarray:
    """Fit and predict in one call; defers to ``super()._fit_predict``."""
    return super()._fit_predict(X, y)

@classmethod
def _get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Expand Down
17 changes: 12 additions & 5 deletions aeon/regression/interval_based/_drcif.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
periodogram and differences representations as well as the base series.
"""

import numpy as np
from sklearn.preprocessing import FunctionTransformer

from aeon.base.estimators.interval_based import BaseIntervalForest
Expand Down Expand Up @@ -176,10 +177,7 @@ def __init__(
n_jobs=1,
parallel_backend=None,
):
d = []
self.use_pycatch22 = use_pycatch22
if use_pycatch22:
d.append("pycatch22")

series_transformers = [
None,
Expand Down Expand Up @@ -216,8 +214,17 @@ def __init__(
parallel_backend=parallel_backend,
)

if d:
self.set_tags(**{"python_dependencies": d})
if use_pycatch22:
self.set_tags(**{"python_dependencies": "pycatch22"})

# NOTE(review): pure delegation overrides — each simply forwards to the
# superclass implementation; presumably present so the hooks are declared on
# this class directly (confirm against the base class contract).
def _fit(self, X, y):
    """Fit to training data; forwards to ``super()._fit``."""
    return super()._fit(X, y)

def _predict(self, X) -> np.ndarray:
    """Predict for X; forwards to ``super()._predict``."""
    return super()._predict(X)

def _fit_predict(self, X, y) -> np.ndarray:
    """Fit then predict; forwards to ``super()._fit_predict``."""
    return super()._fit_predict(X, y)

@classmethod
def _get_test_params(cls, parameter_set="default"):
Expand Down
9 changes: 9 additions & 0 deletions aeon/regression/interval_based/_interval_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,15 @@ def __init__(
parallel_backend=parallel_backend,
)

# NOTE(review): thin wrappers that only call the superclass versions;
# presumably declared here deliberately (documentation/coverage) — confirm.
def _fit(self, X, y):
    """Fit to training data; delegates to ``super()._fit``."""
    return super()._fit(X, y)

def _predict(self, X) -> np.ndarray:
    """Predict for X; delegates to ``super()._predict``."""
    return super()._predict(X)

def _fit_predict(self, X, y) -> np.ndarray:
    """Fit and predict; delegates to ``super()._fit_predict``."""
    return super()._fit_predict(X, y)

@classmethod
def _get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Expand Down
9 changes: 9 additions & 0 deletions aeon/regression/interval_based/_rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ def __init__(
parallel_backend=parallel_backend,
)

# NOTE(review): straight pass-throughs to the superclass with no added
# behaviour; assumed intentional (hook declaration on this class) — confirm.
def _fit(self, X, y):
    """Fit to training data via ``super()._fit``."""
    return super()._fit(X, y)

def _predict(self, X) -> np.ndarray:
    """Predict for X via ``super()._predict``."""
    return super()._predict(X)

def _fit_predict(self, X, y) -> np.ndarray:
    """Fit and predict via ``super()._fit_predict``."""
    return super()._fit_predict(X, y)

@classmethod
def _get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Expand Down
Loading