Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add AEDCNNClusterer #1911

Merged
merged 80 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
72e32fb
Add DCNN
aadya940 May 17, 2024
d643b26
remove triplet loss and move utils.py to utils/networks
aadya940 May 17, 2024
c1e24ed
Add docstring and minor changes dcnn network
aadya940 May 17, 2024
02f387b
minor fixes
aadya940 May 18, 2024
79fe9fb
Update DCNNEncoderNetwork
aadya940 May 20, 2024
529e7b2
add activation kwarg
aadya940 May 20, 2024
5ed6aeb
minor
aadya940 May 20, 2024
9c3f0b2
minor
aadya940 May 27, 2024
6c7ec5a
minor fixes
aadya940 May 27, 2024
780a775
update class name
aadya940 May 27, 2024
d24b0ed
minor
aadya940 May 31, 2024
14f41c4
minor
aadya940 May 31, 2024
d93dbd3
Add temporal_latent_space kwarg
aadya940 Jun 1, 2024
8e98bd5
minor
aadya940 Jun 2, 2024
f66b3ea
minor
aadya940 Jun 2, 2024
dd18581
Add test for DCNNNetwork
aadya940 Jun 3, 2024
2ef7816
minor
aadya940 Jun 3, 2024
f3b6a57
refactor test
aadya940 Jun 4, 2024
77b19bf
add AEDCNN Network
aadya940 Jun 19, 2024
56d9dbc
Add tag
aadya940 Jun 19, 2024
7daa4ce
bug fix
aadya940 Jun 20, 2024
4aab94d
bug fix
aadya940 Jun 20, 2024
ff30811
bug fixes and add tests
aadya940 Jun 20, 2024
7f71641
add pytest.skipif
aadya940 Jun 20, 2024
5cadb8c
Merge branch 'main' into aedcnn
aadya940 Jun 21, 2024
8f85219
update base
aadya940 Jun 21, 2024
77fe309
Update _ae_dcnn.py
aadya940 Jun 28, 2024
13e5cb7
pre-commit
aadya940 Jun 28, 2024
02903ec
minor
aadya940 Jun 28, 2024
2bb5b8f
minor
aadya940 Jun 28, 2024
8bb3303
minor
aadya940 Jun 29, 2024
487153b
use flatten instead of GMP
aadya940 Jul 10, 2024
aec23da
minor fix
aadya940 Jul 11, 2024
e176dba
merge main
aadya940 Jul 12, 2024
761c922
typo fix
aadya940 Jul 12, 2024
975e152
Merge branch 'aeon-toolkit:main' into aedcnn
aadya940 Aug 4, 2024
c3b7644
Replace Conv1D with Conv1DTranspose in the decoder
aadya940 Aug 4, 2024
b783e8c
Merge branch 'aedcnn' of https://github.com/aadya940/aeon into aedcnn
aadya940 Aug 4, 2024
6ade78d
Add AEDCNNClusterer
aadya940 Aug 5, 2024
2c5e165
Merge branch 'aeon-toolkit:main' into aedcnn-clusterer
aadya940 Aug 6, 2024
b300fc8
add to __init__
aadya940 Aug 7, 2024
a8cf147
Merge branch 'aedcnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Aug 7, 2024
38f6b32
typo fix
aadya940 Aug 7, 2024
72bfa94
bug fix
aadya940 Aug 9, 2024
347a800
num => n
aadya940 Aug 12, 2024
671868b
fix clusterer
aadya940 Aug 12, 2024
5cdb447
fix tests
aadya940 Aug 12, 2024
441fedf
make symmetric only network
aadya940 Aug 12, 2024
09f293b
fix tests
aadya940 Aug 12, 2024
303f820
fix clusterer
aadya940 Aug 12, 2024
5f775f6
Fix bugs
aadya940 Aug 12, 2024
188dc53
Add estimator kwarg
aadya940 Aug 16, 2024
f464298
Add notebook
aadya940 Aug 16, 2024
236b5e2
Merge branch 'main' into aedcnn-clusterer
aadya940 Aug 16, 2024
9bd48dd
Add handling on None in kernel_size
aadya940 Aug 16, 2024
19f1587
Merge branch 'aedcnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Aug 16, 2024
408796c
Add padding kwargs
aadya940 Aug 20, 2024
6ef5436
Fix network
aadya940 Aug 22, 2024
8e6b99e
Merge branch 'main' into aedcnn-clusterer
aadya940 Aug 22, 2024
bff9efe
minor fixes
aadya940 Aug 22, 2024
8c5f9f3
Merge branch 'aedcnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Aug 22, 2024
57033e6
Warn if dilation rate > 1
aadya940 Aug 28, 2024
9e999ce
dilation-rate issues
aadya940 Aug 28, 2024
59fab04
Automatic `pre-commit` fixes
aadya940 Aug 28, 2024
c299c4e
fix tests
aadya940 Aug 28, 2024
c045bf6
Merge branch 'aedcnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Aug 28, 2024
fe67b02
fix
aadya940 Aug 28, 2024
206c0cb
Delete examples/clustering/deep_clustering.ipynb
aadya940 Aug 30, 2024
746bf9a
Add user warning
aadya940 Aug 31, 2024
63f0478
Merge branch 'aedcnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Aug 31, 2024
a804445
Add metrics kwarg to clusterer
aadya940 Aug 31, 2024
7285d6a
Merge branch 'main' into aedcnn-clusterer
aadya940 Oct 28, 2024
c3c11fe
remove return_X_y
aadya940 Oct 31, 2024
86be404
Merge branch 'aedcnn-clusterer' of https://github.com/aadya940/aeon i…
aadya940 Oct 31, 2024
6b410ae
Update _ae_dcnn.py
aadya940 Nov 3, 2024
8093235
Merge branch 'main' into aedcnn-clusterer
aadya940 Nov 4, 2024
7119249
Update _ae_dcnn.py
aadya940 Nov 5, 2024
d8a835b
Update _ae_dcnn.py
aadya940 Nov 5, 2024
9ce2e7e
minor
aadya940 Nov 9, 2024
688c129
Merge branch 'main' into aedcnn-clusterer
aadya940 Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aeon/clustering/deep_learning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
"BaseDeepClusterer",
"AEFCNClusterer",
"AEResNetClusterer",
"AEDCNNClusterer",
"AEDRNNClusterer",
"AEAttentionBiGRUClusterer",
"AEBiGRUClusterer",
]
from aeon.clustering.deep_learning._ae_abgru import AEAttentionBiGRUClusterer
from aeon.clustering.deep_learning._ae_bgru import AEBiGRUClusterer
from aeon.clustering.deep_learning._ae_dcnn import AEDCNNClusterer
from aeon.clustering.deep_learning._ae_drnn import AEDRNNClusterer
from aeon.clustering.deep_learning._ae_fcn import AEFCNClusterer
from aeon.clustering.deep_learning._ae_resnet import AEResNetClusterer
Expand Down
350 changes: 350 additions & 0 deletions aeon/clustering/deep_learning/_ae_dcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,350 @@
"""Deep Learning Auto-Encoder using DCNN Network."""

__all__ = ["AEDCNNClusterer"]

import gc
import os
import time
from copy import deepcopy

from sklearn.utils import check_random_state

from aeon.clustering import DummyClusterer
from aeon.clustering.deep_learning.base import BaseDeepClusterer
from aeon.networks import AEDCNNNetwork


class AEDCNNClusterer(BaseDeepClusterer):
"""Auto-Encoder based Dilated Convolutional Networks (DCNN), as described in [1]_.

Parameters
----------
n_clusters : int, default=None
Number of clusters for the deep learnign model.
clustering_algorithm : str, default="deprecated"
Use 'estimator' parameter instead.
clustering_params : dict, default=None
Use 'estimator' parameter instead.
estimator : aeon clusterer, default=None
An aeon estimator to be built using the transformed data.
Defaults to aeon TimeSeriesKMeans() with euclidean distance
and mean averaging method and n_clusters set to 2.
latent_space_dim : int, default=128
Dimension of the latent space of the auto-encoder.
temporal_latent_space : bool, default = False
Flag to choose whether the latent space is an MTS or Euclidean space.
n_layers : int, default = 3
Number of convolution layers in the encoder.
n_filters : int or list of int, default = None
Number of filters used in convolution layers in the encoder.
kernel_size : int or list of int, default = 3
Size of convolution kernel in the encoder.
dilation_rate : int or list of int, default = 1
The dilation rate for convolution in the encoder.
`dilation_rate` greater than `1` is not supported on
`Conv1DTranspose` for some devices/OS.
activation : str or list of str, default = "relu"
Activation used after the convolution in the encoder.
padding_encoder : str or list of str, default = "causal"
Keras compatible Padding string for the encoder. Defaults to a list
of "causal" paddings.
padding_decoder : str or list of str, default = "same"
Keras compatible Padding string for the decoder. Defaults to a list
of "same" paddings.
use_bias : bool or list of bool, default = True
Whether or not ot use bias in convolution.
n_epochs : int, default = 2000
The number of epochs to train the model.
batch_size : int, default = 16
The number of samples per gradient update.
use_mini_batch_size : bool, default = True,
Whether or not to use the mini batch size formula.
random_state : int, RandomState instance or None, default=None
If `int`, random_state is the seed used by the random number generator;
If `RandomState` instance, random_state is the random number generator;
If `None`, the random number generator is the `RandomState` instance used
by `np.random`.
Seeded random number generation can only be guaranteed on CPU processing,
GPU processing will be non-deterministic.
verbose : boolean, default = False
Whether to output extra information.
loss : string, default="mean_squared_error"
Fit parameter for the keras model.
metrics : List[str], default=["mean_squared_error"]
Metrics to evaluate the performance of the deep learning network.
optimizer : keras.optimizers object, default = Adam(lr=0.01)
Specify the optimizer and the learning rate to be used.
file_path : str, default = "./"
File path to save best model.
save_best_model : bool, default = False
Whether or not to save the best model, if the
modelcheckpoint callback is used by default,
this condition, if True, will prevent the
automatic deletion of the best saved model from
file and the user can choose the file name.
save_last_model : bool, default = False
Whether or not to save the last model, last
epoch trained, using the base class method
save_last_model_to_file.
best_file_name : str, default = "best_model"
The name of the file of the best model, if
save_best_model is set to False, this parameter
is discarded.
last_file_name : str, default = "last_model"
The name of the file of the last model, if
save_last_model is set to False, this parameter
is discarded.
callbacks : keras.callbacks, default = None
List of keras callbacks.

Examples
--------
>>> from aeon.clustering.deep_learning import AEDCNNClusterer
>>> from aeon.datasets import load_unit_test
>>> from aeon.clustering import DummyClusterer
>>> X_train, y_train = load_unit_test(split="train")
>>> X_test, y_test = load_unit_test(split="test")
>>> _clst = DummyClusterer(n_clusters=2)
>>> aedcnn=AEDCNNClusterer(estimator=_clst, n_epochs=20,
... batch_size=4) # doctest: +SKIP
>>> aedcnn.fit(X_train) # doctest: +SKIP
AEDCNNClusterer(...)
"""

def __init__(
self,
n_clusters=None,
estimator=None,
clustering_algorithm="deprecated",
clustering_params=None,
latent_space_dim=128,
temporal_latent_space=False,
n_layers=3,
n_filters=None,
kernel_size=3,
dilation_rate=1,
activation="relu",
padding_encoder="same",
padding_decoder="same",
n_epochs=2000,
batch_size=32,
use_mini_batch_size=False,
random_state=None,
verbose=False,
loss="mse",
metrics=None,
optimizer="Adam",
file_path="./",
save_best_model=False,
save_last_model=False,
best_file_name="best_model",
last_file_name="last_file",
callbacks=None,
):
self.latent_space_dim = latent_space_dim
self.temporal_latent_space = temporal_latent_space
self.n_layers = n_layers
self.n_filters = n_filters
self.kernel_size = kernel_size
self.activation = activation
self.padding_encoder = padding_encoder
self.padding_decoder = padding_decoder
self.dilation_rate = dilation_rate
self.optimizer = optimizer
self.loss = loss
self.metrics = metrics
self.verbose = verbose
self.use_mini_batch_size = use_mini_batch_size
self.callbacks = callbacks
self.file_path = file_path
self.n_epochs = n_epochs
self.save_best_model = save_best_model
self.save_last_model = save_last_model
self.best_file_name = best_file_name
self.random_state = random_state

super().__init__(
n_clusters=n_clusters,
clustering_params=clustering_params,
clustering_algorithm=clustering_algorithm,
estimator=estimator,
batch_size=batch_size,
last_file_name=last_file_name,
)

self._network = AEDCNNNetwork(
latent_space_dim=self.latent_space_dim,
temporal_latent_space=self.temporal_latent_space,
n_layers=self.n_layers,
n_filters=self.n_filters,
kernel_size=self.kernel_size,
dilation_rate=self.dilation_rate,
activation=self.activation,
padding_encoder=self.padding_encoder,
padding_decoder=self.padding_decoder,
)

def build_model(self, input_shape, **kwargs):
"""Construct a compiled, un-trained, keras model that is ready for training.

In aeon, time series are stored in numpy arrays of shape
(n_channels,n_timepoints). Keras/tensorflow assume
data is in shape (n_timepoints,n_channels). This method also assumes
(n_timepoints,n_channels). Transpose should happen in fit.

Parameters
----------
input_shape : tuple
The shape of the data fed into the input layer, should be
(n_timepoints,n_channels).

Returns
-------
output : a compiled Keras Model.
"""
import numpy as np
import tensorflow as tf

rng = check_random_state(self.random_state)
self.random_state_ = rng.randint(0, np.iinfo(np.int32).max)
tf.keras.utils.set_random_seed(self.random_state_)
encoder, decoder = self._network.build_network(input_shape, **kwargs)

input_layer = tf.keras.layers.Input(input_shape, name="input layer")
encoder_output = encoder(input_layer)
decoder_output = decoder(encoder_output)
output_layer = tf.keras.layers.Reshape(
target_shape=input_shape, name="outputlayer"
)(decoder_output)

model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)

self.optimizer_ = (
tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer
)

if self.metrics is None:
self._metrics = ["mean_squared_error"]
elif isinstance(self.metrics, list):
self._metrics = self.metrics
elif isinstance(self.metrics, str):
self._metrics = [self.metrics]
else:
raise ValueError("Metrics should be a list, string, or None.")

model.compile(optimizer=self.optimizer_, loss=self.loss, metrics=self._metrics)

return model

def _fit(self, X):
"""Fit the classifier on the training set (X, y).

Parameters
----------
X : np.ndarray of shape = (n_cases (n), n_channels (d), n_timepoints (m))
The training input samples.

Returns
-------
self : object
"""
import tensorflow as tf

# Transpose to conform to Keras input style.
X = X.transpose(0, 2, 1)

self.input_shape = X.shape[1:]
self.training_model_ = self.build_model(self.input_shape)

if self.verbose:
self.training_model_.summary()

if self.use_mini_batch_size:
mini_batch_size = min(self.batch_size, X.shape[0] // 10)
else:
mini_batch_size = self.batch_size

self.file_name_ = (
self.best_file_name if self.save_best_model else str(time.time_ns())
)

if self.callbacks is None:
self.callbacks_ = [
tf.keras.callbacks.ReduceLROnPlateau(
monitor="loss", factor=0.5, patience=50, min_lr=0.0001
),
tf.keras.callbacks.ModelCheckpoint(
filepath=self.file_path + self.file_name_ + ".keras",
monitor="loss",
save_best_only=True,
),
]
else:
self.callbacks_ = self._get_model_checkpoint_callback(
callbacks=self.callbacks,
file_path=self.file_path,
file_name=self.file_name_,
)

self.history = self.training_model_.fit(
X,
X,
batch_size=mini_batch_size,
epochs=self.n_epochs,
verbose=self.verbose,
callbacks=self.callbacks_,
)

try:
self.model_ = tf.keras.models.load_model(
self.file_path + self.file_name_ + ".keras", compile=False
)
if not self.save_best_model:
os.remove(self.file_path + self.file_name_ + ".keras")
except FileNotFoundError:
self.model_ = deepcopy(self.training_model_)

self._fit_clustering(X=X)

gc.collect()

return self

def _score(self, X, y=None):
# Transpose to conform to Keras input style.
X = X.transpose(0, 2, 1)
latent_space = self.model_.layers[1].predict(X)
return self._estimator.score(latent_space)

@classmethod
def _get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.

Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
For classifiers, a "default" set of parameters should be provided for
general testing, and a "results_comparison" set for comparing against
previously recorded results if the general set does not produce suitable
probabilities to compare against.

Returns
-------
params : dict or list of dict, default={}
Parameters to create testing instances of the class.
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`.
"""
param1 = {
"estimator": DummyClusterer(n_clusters=2),
"n_epochs": 1,
"batch_size": 4,
"n_layers": 1,
"n_filters": 1,
"kernel_size": None,
}

return [param1]