[DOC] Update the ElasticEnsemble documentation #1544

Merged · 3 commits · May 21, 2024
30 changes: 24 additions & 6 deletions aeon/classification/distance_based/_elastic_ensemble.py
@@ -40,7 +40,8 @@ class ElasticEnsemble(BaseClassifier):
----------
distance_measures : str or list of str, default="all"
A list of strings identifying which distance measures to include. Valid values
are one or more of: euclidean, dtw, wdtw, ddtw, dwdtw, lcss, erp, msm, twe, all
are one or more of: ``euclidean``, ``dtw``, ``wdtw``, ``ddtw``, ``wddtw``,
``lcss``, ``erp``, ``msm``, ``twe``, ``all``
proportion_of_param_options : float, default=1
The proportion of the parameter grid space to search (optional).
proportion_train_in_param_finding : float, default=1
@@ -66,11 +67,11 @@ class ElasticEnsemble(BaseClassifier):
constituent_build_times_ : array of float
Build time for each member of the ensemble.

Notes
-----
References
----------
.. [1] Jason Lines and Anthony Bagnall,
"Time Series Classification with Ensembles of Elastic Distance Measures",
Data Mining and Knowledge Discovery, 29(3), 2015.
Data Mining and Knowledge Discovery, 29(3), 2015.
https://link.springer.com/article/10.1007/s10618-014-0361-2

Examples
@@ -132,7 +133,7 @@ def _fit(self, X, y):
or list of [n_cases] np.ndarray shape (n_channels, n_timepoints_i)
The training input samples.

y : array-like, shape = (n_cases) The class labels.
y : array-like, shape = (n_cases,) The class labels.

Returns
-------
@@ -398,14 +399,31 @@ def _predict(self, X) -> np.ndarray:
return preds

def get_metric_params(self) -> dict:
"""Return the parameters for the distance metrics used."""
"""Return the parameters for the distance metrics used.

Returns
-------
params : dict
The distance measures and the list of their parameter values.
"""
return {
self._distance_measures[dm]: str(self.estimators_[dm]._distance_params)
for dm in range(len(self.estimators_))
}
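The dict comprehension in `get_metric_params` maps each measure name to the stringified parameters of the corresponding fitted estimator. The mapping pattern can be illustrated with stand-in objects (the `_Stub` class and its parameter values below are hypothetical, not part of aeon):

```python
class _Stub:
    """Hypothetical stand-in for a fitted 1-NN distance estimator."""

    def __init__(self, distance_params):
        self._distance_params = distance_params


distance_measures = ["dtw", "msm"]
estimators_ = [_Stub({"window": 0.1}), _Stub({"c": 1.0})]

# Same pattern as get_metric_params: measure name -> stringified params.
params = {
    distance_measures[dm]: str(estimators_[dm]._distance_params)
    for dm in range(len(estimators_))
}
print(params)  # {'dtw': "{'window': 0.1}", 'msm': "{'c': 1.0}"}
```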

@staticmethod
def _get_100_param_options(distance_measure: str, train_x=None):
"""Generate 100 parameter values for each classifier.

Parameters
----------
distance_measure : str
    The name of the distance measure.
train_x : np.ndarray of shape = (n_cases, n_channels, n_timepoints)
or list of [n_cases] np.ndarray shape (n_channels, n_timepoints_i)
The training input samples.
"""

def get_inclusive(min_val: float, max_val: float, num_vals: float):
inc = (max_val - min_val) / (num_vals - 1)
return np.arange(min_val, max_val + inc / 2, inc)
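The `get_inclusive` helper produces `num_vals` evenly spaced values including both endpoints; padding the stop value by half a step guards against `np.arange` dropping the endpoint to floating-point rounding. A quick standalone check of the same logic:

```python
import numpy as np


def get_inclusive(min_val: float, max_val: float, num_vals: float):
    # Step size so that num_vals points span [min_val, max_val] inclusively.
    inc = (max_val - min_val) / (num_vals - 1)
    # Pad the stop by half a step so the endpoint survives float rounding.
    return np.arange(min_val, max_val + inc / 2, inc)


vals = get_inclusive(0.0, 1.0, 5)  # five values: 0.0, 0.25, 0.5, 0.75, 1.0
```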