From cf10ce9541eb2be52dfff615e61e6ab1b8599c00 Mon Sep 17 00:00:00 2001 From: Ali El Hadi ISMAIL FAWAZ <54309336+hadifawaz1999@users.noreply.github.com> Date: Mon, 11 Nov 2024 19:27:34 +0100 Subject: [PATCH] [ENH] Add DisjointCNN classifier and Regressor (#2316) * add network * add to init * test input list for kernel init * fix bug test network * fix bug test * adding deep classifier * update api * fix test * add regressor and refactor * no test for mpl * bug copying cls to rgs * add rs as self --- aeon/classification/deep_learning/__init__.py | 2 + aeon/classification/deep_learning/_cnn.py | 35 +- .../deep_learning/_disjoint_cnn.py | 416 +++++++++++++++++ aeon/classification/deep_learning/_encoder.py | 41 +- aeon/classification/deep_learning/_fcn.py | 39 +- .../deep_learning/_inception_time.py | 133 +++--- .../deep_learning/_lite_time.py | 54 ++- aeon/classification/deep_learning/_mlp.py | 46 +- aeon/classification/deep_learning/_resnet.py | 33 +- aeon/networks/__init__.py | 2 + aeon/networks/_disjoint_cnn.py | 305 +++++++++++++ aeon/networks/_mlp.py | 20 +- aeon/networks/tests/test_all_networks.py | 3 + aeon/networks/tests/test_disjoint_cnn.py | 22 + aeon/regression/deep_learning/__init__.py | 2 + aeon/regression/deep_learning/_cnn.py | 60 +-- .../regression/deep_learning/_disjoint_cnn.py | 419 ++++++++++++++++++ aeon/regression/deep_learning/_encoder.py | 27 +- aeon/regression/deep_learning/_fcn.py | 56 +-- .../deep_learning/_inception_time.py | 156 ++++--- aeon/regression/deep_learning/_lite_time.py | 54 ++- aeon/regression/deep_learning/_mlp.py | 39 +- aeon/regression/deep_learning/_resnet.py | 53 +-- docs/api_reference/classification.rst | 1 + docs/api_reference/networks.rst | 1 + docs/api_reference/regression.rst | 1 + 26 files changed, 1665 insertions(+), 355 deletions(-) create mode 100644 aeon/classification/deep_learning/_disjoint_cnn.py create mode 100644 aeon/networks/_disjoint_cnn.py create mode 100644 aeon/networks/tests/test_disjoint_cnn.py create mode 100644 aeon/regression/deep_learning/_disjoint_cnn.py diff --git a/aeon/classification/deep_learning/__init__.py b/aeon/classification/deep_learning/__init__.py index a7c1c73f5c..5fac750595 100644 --- a/aeon/classification/deep_learning/__init__.py +++ b/aeon/classification/deep_learning/__init__.py @@ -12,8 +12,10 @@ "TapNetClassifier", "LITETimeClassifier", "IndividualLITEClassifier", + "DisjointCNNClassifier", ] from aeon.classification.deep_learning._cnn import TimeCNNClassifier +from aeon.classification.deep_learning._disjoint_cnn import DisjointCNNClassifier from aeon.classification.deep_learning._encoder import EncoderClassifier from aeon.classification.deep_learning._fcn import FCNClassifier from aeon.classification.deep_learning._inception_time import ( diff --git a/aeon/classification/deep_learning/_cnn.py b/aeon/classification/deep_learning/_cnn.py index 5b87cab9b8..d099fb04fc 100644 --- a/aeon/classification/deep_learning/_cnn.py +++ b/aeon/classification/deep_learning/_cnn.py @@ -61,12 +61,19 @@ class TimeCNNClassifier(BaseDeepClassifier): The number of samples per gradient update. verbose : boolean, default = False Whether to output extra information. - loss : string, default = "mean_squared_error" - Fit parameter for the keras model. - optimizer : keras.optimizer, default = keras.optimizers.Adam() - metrics : list of strings, default = ["accuracy"] - callbacks : keras.callbacks, default = model_checkpoint - To save best model on training loss. 
+ loss : str, default = "mean_squared_error" + The name of the keras training loss. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint. file_path : file_path for the best model Only used if checkpoint is used as callback. save_best_model : bool, default = False @@ -131,7 +138,7 @@ def __init__( init_file_name="init_model", verbose=False, loss="mean_squared_error", - metrics=None, + metrics="accuracy", random_state=None, use_bias=True, optimizer=None, @@ -201,18 +208,13 @@ def build_model(self, input_shape, n_classes, **kwargs): import numpy as np import tensorflow as tf - if self.metrics is None: - metrics = ["accuracy"] - else: - metrics = self.metrics - rng = check_random_state(self.random_state) self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) tf.keras.utils.set_random_seed(self.random_state_) input_layer, output_layer = self._network.build_network(input_shape, **kwargs) output_layer = tf.keras.layers.Dense( - units=n_classes, activation=self.activation, use_bias=self.use_bias + units=n_classes, activation=self.activation )(output_layer) self.optimizer_ = ( @@ -223,7 +225,7 @@ def build_model(self, input_shape, n_classes, **kwargs): model.compile( loss=self.loss, optimizer=self.optimizer_, - metrics=metrics, + metrics=self._metrics, ) return model @@ -249,6 +251,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) diff --git a/aeon/classification/deep_learning/_disjoint_cnn.py b/aeon/classification/deep_learning/_disjoint_cnn.py new file mode 100644 index 0000000000..dfe857031c --- /dev/null +++ b/aeon/classification/deep_learning/_disjoint_cnn.py @@ -0,0 +1,416 @@ +"""DisjointCNN classifier.""" + +__maintainer__ = ["hadifawaz1999"] +__all__ = ["DisjointCNNClassifier"] + +import gc +import os +import time +from copy import deepcopy + +from sklearn.utils import check_random_state + +from aeon.classification.deep_learning.base import BaseDeepClassifier +from aeon.networks import DisjointCNNNetwork + + +class DisjointCNNClassifier(BaseDeepClassifier): + """Disjoint Convolutional Neural Network classifier. + + Adapted from the implementation used in [1]_. + + Parameters + ---------- + n_layers : int, default = 4 + Number of 1+1D Convolution layers. + n_filters : int or list of int, default = 64 + Number of filters used in convolution layers. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + kernel_size : int or list of int, default = [8, 5, 5, 3] + Size of convolution kernel. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + dilation_rate : int or list of int, default = 1 + The dilation rate for convolution.
If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + strides : int or list of int, default = 1 + The strides of the convolution filter. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + padding : str or list of str, default = "same" + The type of padding used for convolution. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + activation : str or list of str, default = "elu" + Activation used after the convolution. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + use_bias : bool or list of bool, default = True + Whether or not to use bias in convolution. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + kernel_initializer: str or list of str, default = "he_uniform" + The initialization method of convolution layers. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + pool_size: int, default = 5 + The size of the one max pool layer at the end of + the model, default = 5. + pool_strides: int, default = None + The strides used for the one max pool layer at + the end of the model, default = None. + pool_padding: str, default = "valid" + The padding method for the one max pool layer at + the end of the model, default = "valid". + hidden_fc_units: int, default = 128 + The number of fully connected units. + activation_fc: str, default = "relu" + The activation of the fully connected layer. + n_epochs : int, default = 2000 + The number of epochs to train the model. + batch_size : int, default = 16 + The number of samples per gradient update. + use_mini_batch_size : bool, default = False + Whether or not to use the mini batch size formula. + random_state : int, RandomState instance or None, default=None + If `int`, random_state is the seed used by the random number generator; + If `RandomState` instance, random_state is the random number generator; + If `None`, the random number generator is the `RandomState` instance used + by `np.random`. + Seeded random number generation can only be guaranteed on CPU processing, + GPU processing will be non-deterministic. + verbose : boolean, default = False + Whether to output extra information. + loss : str, default = "categorical_crossentropy" + The name of the keras training loss. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + file_path : str, default = "./" + File path to save best model. + save_best_model : bool, default = False + Whether or not to save the best model, if the + modelcheckpoint callback is used by default, + this condition, if True, will prevent the + automatic deletion of the best saved model from + file and the user can choose the file name.
+ save_last_model : bool, default = False + Whether or not to save the last model, last + epoch trained, using the base class method + save_last_model_to_file. + save_init_model : bool, default = False + Whether to save the initialization of the model. + best_file_name : str, default = "best_model" + The name of the file of the best model, if + save_best_model is set to False, this parameter + is discarded. + last_file_name : str, default = "last_model" + The name of the file of the last model, if + save_last_model is set to False, this parameter + is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if + save_init_model is set to False, + this parameter is discarded. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. + + Notes + ----- + Adapted from the implementation from: + https://github.com/Navidfoumani/Disjoint-CNN + + References + ---------- + .. [1] Foumani, Seyed Navid Mohammadi, Chang Wei Tan, and Mahsa Salehi. + "Disjoint-cnn for multivariate time series classification." + 2021 International Conference on Data Mining Workshops + (ICDMW). IEEE, 2021. + + Examples + -------- + >>> from aeon.classification.deep_learning import DisjointCNNClassifier + >>> from aeon.datasets import load_unit_test + >>> X_train, y_train = load_unit_test(split="train") + >>> X_test, y_test = load_unit_test(split="test") + >>> disjoint_cnn = DisjointCNNClassifier(n_epochs=20, + ... batch_size=4) # doctest: +SKIP + >>> disjoint_cnn.fit(X_train, y_train) # doctest: +SKIP + DisjointCNNClassifier(...) + """ + + def __init__( + self, + n_layers=4, + n_filters=64, + kernel_size=None, + dilation_rate=1, + strides=1, + padding="same", + activation="elu", + use_bias=True, + kernel_initializer="he_uniform", + pool_size=5, + pool_strides=None, + pool_padding="valid", + hidden_fc_units=128, + activation_fc="relu", + n_epochs=2000, + batch_size=16, + use_mini_batch_size=False, + random_state=None, + verbose=False, + loss="categorical_crossentropy", + metrics="accuracy", + optimizer=None, + file_path="./", + save_best_model=False, + save_last_model=False, + save_init_model=False, + best_file_name="best_model", + last_file_name="last_model", + init_file_name="init_model", + callbacks=None, + ): + self.n_layers = n_layers + self.n_filters = n_filters + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.strides = strides + self.padding = padding + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.pool_size = pool_size + self.pool_strides = pool_strides + self.pool_padding = pool_padding + self.hidden_fc_units = hidden_fc_units + self.activation_fc = activation_fc + + self.callbacks = callbacks + self.n_epochs = n_epochs + self.use_mini_batch_size = use_mini_batch_size + self.verbose = verbose + self.loss = loss + self.metrics = metrics + self.optimizer = optimizer + + self.file_path = file_path + self.save_best_model = save_best_model + self.save_last_model = save_last_model + self.save_init_model = save_init_model + self.best_file_name = best_file_name + self.init_file_name = init_file_name + + self.history = None + + super().__init__( + batch_size=batch_size, + random_state=random_state, + last_file_name=last_file_name, + ) + + self._network = DisjointCNNNetwork( + n_layers=self.n_layers, + n_filters=self.n_filters, + kernel_size=self.kernel_size, + dilation_rate=self.dilation_rate, + 
strides=self.strides, + padding=self.padding, + activation=self.activation, + use_bias=self.use_bias, + kernel_initializer=self.kernel_initializer, + pool_size=self.pool_size, + pool_strides=self.pool_strides, + pool_padding=self.pool_padding, + hidden_fc_units=self.hidden_fc_units, + activation_fc=self.activation_fc, + ) + + def build_model(self, input_shape, n_classes, **kwargs): + """Construct a compiled, un-trained, keras model that is ready for training. + + In aeon, time series are stored in numpy arrays of shape (d,m), where d + is the number of dimensions, m is the series length. Keras/tensorflow assume + data is in shape (m,d). This method also assumes (m,d). Transpose should + happen in fit. + + Parameters + ---------- + input_shape : tuple + The shape of the data fed into the input layer, should be (m, d). + n_classes : int + The number of classes, which becomes the size of the output layer. + + Returns + ------- + output : a compiled Keras Model + """ + import numpy as np + import tensorflow as tf + + rng = check_random_state(self.random_state) + self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) + tf.keras.utils.set_random_seed(self.random_state_) + input_layer, output_layer = self._network.build_network(input_shape, **kwargs) + + output_layer = tf.keras.layers.Dense(units=n_classes, activation="softmax")( + output_layer + ) + + self.optimizer_ = ( + tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer + ) + + model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer) + model.compile( + loss=self.loss, + optimizer=self.optimizer_, + metrics=self._metrics, + ) + + return model + + def _fit(self, X, y): + """Fit the classifier on the training set (X, y). + + Parameters + ---------- + X : np.ndarray + The training input samples of shape (n_cases, n_channels, n_timepoints) + y : np.ndarray + The training data class labels of shape (n_cases,). + + Returns + ------- + self : object + """ + import tensorflow as tf + + y_onehot = self.convert_y_to_keras(y) + # Transpose to conform to Keras input style. 
+ X = X.transpose(0, 2, 1) + + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + + self.input_shape = X.shape[1:] + self.training_model_ = self.build_model(self.input_shape, self.n_classes_) + + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + + if self.verbose: + self.training_model_.summary() + + if self.use_mini_batch_size: + mini_batch_size = min(self.batch_size, X.shape[0] // 10) + else: + mini_batch_size = self.batch_size + + self.file_name_ = ( + self.best_file_name if self.save_best_model else str(time.time_ns()) + ) + + if self.callbacks is None: + self.callbacks_ = [ + tf.keras.callbacks.ReduceLROnPlateau( + monitor="loss", factor=0.5, patience=50, min_lr=0.0001 + ), + tf.keras.callbacks.ModelCheckpoint( + filepath=self.file_path + self.file_name_ + ".keras", + monitor="loss", + save_best_only=True, + ), + ] + else: + self.callbacks_ = self._get_model_checkpoint_callback( + callbacks=self.callbacks, + file_path=self.file_path, + file_name=self.file_name_, + ) + + self.history = self.training_model_.fit( + X, + y_onehot, + batch_size=mini_batch_size, + epochs=self.n_epochs, + verbose=self.verbose, + callbacks=self.callbacks_, + ) + + try: + self.model_ = tf.keras.models.load_model( + self.file_path + self.file_name_ + ".keras", compile=False + ) + if not self.save_best_model: + os.remove(self.file_path + self.file_name_ + ".keras") + except FileNotFoundError: + self.model_ = deepcopy(self.training_model_) + + if self.save_last_model: + self.save_last_model_to_file(file_path=self.file_path) + + gc.collect() + return self + + @classmethod + def _get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default = "default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return "default" set. + For classifiers, a "default" set of parameters should be provided for + general testing, and a "results_comparison" set for comparing against + previously recorded results if the general set does not produce suitable + probabilities to compare against. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class. + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + """ + param1 = { + "n_epochs": 3, + "batch_size": 4, + "use_bias": False, + "n_layers": 2, + "n_filters": 2, + "kernel_size": [2, 2], + } + param2 = { + "n_epochs": 3, + "batch_size": 4, + "use_bias": False, + "n_layers": 2, + "n_filters": 2, + "kernel_size": [2, 2], + "verbose": True, + "metrics": ["accuracy"], + "use_mini_batch_size": True, + } + + return [param1, param2] diff --git a/aeon/classification/deep_learning/_encoder.py b/aeon/classification/deep_learning/_encoder.py index 1773d1ea1c..cb44688d6b 100644 --- a/aeon/classification/deep_learning/_encoder.py +++ b/aeon/classification/deep_learning/_encoder.py @@ -40,8 +40,29 @@ class EncoderClassifier(BaseDeepClassifier): fc_units : int, default = 256 Specifying the number of units in the hidden fully connected layer used in the EncoderNetwork. + verbose : boolean, default = False + Whether to output extra information. 
+ loss : str, default = "categorical_crossentropy" + The name of the keras training loss. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint. file_path : str, default = "./" File path when saving model_Checkpoint callback. + n_epochs : int, default = 2000 + The number of epochs to train the model. + batch_size : int, default = 16 + The number of samples per gradient update. + use_mini_batch_size : bool, default = False + Whether or not to use the mini batch size formula. save_best_model : bool, default = False Whether or not to save the best model, if the modelcheckpoint callback is used by default, @@ -87,10 +108,6 @@ class EncoderClassifier(BaseDeepClassifier): """ - _tags = { - "python_dependencies": ["tensorflow"], - } - def __init__( self, n_epochs=100, @@ -113,7 +130,7 @@ def __init__( init_file_name="init_model", verbose=False, loss="categorical_crossentropy", - metrics=None, + metrics="accuracy", random_state=None, use_bias=True, optimizer=None, @@ -182,18 +199,13 @@ def build_model(self, input_shape, n_classes, **kwargs): import numpy as np import tensorflow as tf - if self.metrics is None: - metrics = ["accuracy"] - else: - metrics = self.metrics - rng = check_random_state(self.random_state) self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) tf.keras.utils.set_random_seed(self.random_state_) input_layer, output_layer = self._network.build_network(input_shape, **kwargs) output_layer = tf.keras.layers.Dense( - units=n_classes, activation=self.activation, use_bias=self.use_bias + units=n_classes, activation=self.activation )(output_layer) self.optimizer_ = ( @@ -206,7 +218,7 @@ def build_model(self, input_shape, n_classes, **kwargs): model.compile( loss=self.loss, optimizer=self.optimizer_, - metrics=metrics, + metrics=self._metrics, ) return model @@ -231,6 +243,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) diff --git a/aeon/classification/deep_learning/_fcn.py b/aeon/classification/deep_learning/_fcn.py index e79ea65270..ec1e4fc130 100644 --- a/aeon/classification/deep_learning/_fcn.py +++ b/aeon/classification/deep_learning/_fcn.py @@ -52,11 +52,15 @@ class FCNClassifier(BaseDeepClassifier): GPU processing will be non-deterministic. verbose : boolean, default = False Whether to output extra information. - loss : string, default = "mean_squared_error" - Fit parameter for the keras model. - metrics : list of strings, default = ["accuracy"] - optimizer : keras.optimizers object, default = Adam(lr=0.01) - Specify the optimizer and the learning rate to be used. + loss : str, default = "categorical_crossentropy" + The name of the keras training loss. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. 
If a list of metrics are + provided, all will be used for evaluation. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. file_path : str, default = "./" File path to save best model. save_best_model : bool, default = False @@ -83,7 +87,10 @@ class FCNClassifier(BaseDeepClassifier): The name of the file of the init model, if save_init_model is set to False, this parameter is discarded. - callbacks : keras.callbacks, default = None + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. Notes ----- @@ -128,7 +135,7 @@ def __init__( callbacks=None, verbose=False, loss="categorical_crossentropy", - metrics=None, + metrics="accuracy", random_state=None, use_bias=True, optimizer=None, @@ -198,19 +205,14 @@ def build_model(self, input_shape, n_classes, **kwargs): import numpy as np import tensorflow as tf - if self.metrics is None: - metrics = ["accuracy"] - else: - metrics = self.metrics - rng = check_random_state(self.random_state) self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) tf.keras.utils.set_random_seed(self.random_state_) input_layer, output_layer = self._network.build_network(input_shape, **kwargs) - output_layer = tf.keras.layers.Dense( - units=n_classes, activation="softmax", use_bias=self.use_bias - )(output_layer) + output_layer = tf.keras.layers.Dense(units=n_classes, activation="softmax")( + output_layer + ) self.optimizer_ = ( tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer @@ -220,7 +222,7 @@ def build_model(self, input_shape, n_classes, **kwargs): model.compile( loss=self.loss, optimizer=self.optimizer_, - metrics=metrics, + metrics=self._metrics, ) return model @@ -245,6 +247,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) diff --git a/aeon/classification/deep_learning/_inception_time.py b/aeon/classification/deep_learning/_inception_time.py index d56f3659a8..00c4e479f9 100644 --- a/aeon/classification/deep_learning/_inception_time.py +++ b/aeon/classification/deep_learning/_inception_time.py @@ -23,50 +23,50 @@ class InceptionTimeClassifier(BaseClassifier): Parameters ---------- - n_classifiers : int, default = 5, + n_classifiers : int, default = 5, the number of Inception models used for the Ensemble in order to create InceptionTime. 
- depth : int, default = 6, + depth : int, default = 6, the number of inception modules used - n_filters : int or list of int32, default = 32, + n_filters : int or list of int32, default = 32, the number of filters used in one inception module, if not a list, the same number of filters is used in all inception modules - n_conv_per_layer : int or list of int, default = 3, + n_conv_per_layer : int or list of int, default = 3, the number of convolution layers in each inception module, if not a list, the same number of convolution layers is used in all inception modules - kernel_size : int or list of int, default = 40, + kernel_size : int or list of int, default = 40, the head kernel size used for each inception module, if not a list, the same is used in all inception modules - use_max_pooling : bool or list of bool, default = True, + use_max_pooling : bool or list of bool, default = True, conditioning whether or not to use max pooling layer in inception modules, if not a list, the same is used in all inception modules - max_pool_size : int or list of int, default = 3, + max_pool_size : int or list of int, default = 3, the size of the max pooling layer, if not a list, the same is used in all inception modules - strides : int or list of int, default = 1, + strides : int or list of int, default = 1, the strides of kernels in convolution layers for each inception module, if not a list, the same is used in all inception modules - dilation_rate : int or list of int, default = 1, + dilation_rate : int or list of int, default = 1, the dilation rate of convolutions in each inception module, if not a list, the same is used in all inception modules - padding : str or list of str, default = "same", + padding : str or list of str, default = "same", the type of padding used for convolution for each inception module, if not a list, the same is used in all inception modules - activation : str or list of str, default = "relu", + activation : str or list of str, default = "relu", the activation function used in each inception module, if not a list, the same is used in all inception modules - use_bias : bool or list of bool, default = False, + use_bias : bool or list of bool, default = False, conditioning whether or not convolutions should use bias values in each inception module, if not a list, @@ -89,8 +89,10 @@ class InceptionTimeClassifier(BaseClassifier): formula Wang et al. n_epochs : int, default = 1500 the number of epochs to train the model. - callbacks : callable or None, default = ReduceOnPlateau and ModelCheckpoint - list of tf.keras.callbacks.Callback objects. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. file_path : str, default = "./" file_path when saving model_Checkpoint callback save_best_model : bool, default = False @@ -99,17 +101,17 @@ class InceptionTimeClassifier(BaseClassifier): this condition, if True, will prevent the automatic deletion of the best saved model from file and the user can choose the file name - save_last_model : bool, default = False + save_last_model : bool, default = False Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file save_init_model : bool, default = False Whether to save the initialization of the model.
- best_file_name : str, default = "best_model" + best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter is discarded - last_file_name : str, default = "last_model" + last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter is discarded @@ -123,12 +125,17 @@ class InceptionTimeClassifier(BaseClassifier): by `np.random`. Seeded random number generation can only be guaranteed on CPU processing, GPU processing will be non-deterministic. - verbose : boolean, default = False + verbose : boolean, default = False whether to output extra information - optimizer : keras optimizer, default = Adam - loss : keras loss, default = categorical_crossentropy - metrics : keras metrics, default = None, - will be set to accuracy as default if None + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + loss : str, default = "categorical_crossentropy" + The name of the keras training loss. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. Notes ----- @@ -199,7 +206,7 @@ def __init__( random_state=None, verbose=False, loss="categorical_crossentropy", - metrics=None, + metrics="accuracy", optimizer=None, ): self.n_classifiers = n_classifiers @@ -382,80 +389,81 @@ class IndividualInceptionClassifier(BaseDeepClassifier): Parameters ---------- - depth : int, default = 6, + depth : int, default = 6, the number of inception modules used - n_filters : int or list of int32, default = 32, + n_filters : int or list of int32, default = 32, the number of filters used in one inception module, if not a list, the same number of filters is used in all inception modules - n_conv_per_layer : int or list of int, default = 3, + n_conv_per_layer : int or list of int, default = 3, the number of convolution layers in each inception module, if not a list, the same number of convolution layers is used in all inception modules - kernel_size : int or list of int, default = 40, + kernel_size : int or list of int, default = 40, the head kernel size used for each inception module, if not a list, the same is used in all inception modules - use_max_pooling : bool or list of bool, default = True, + use_max_pooling : bool or list of bool, default = True, conditioning whether or not to use max pooling layer in inception modules, if not a list, the same is used in all inception modules - max_pool_size : int or list of int, default = 3, + max_pool_size : int or list of int, default = 3, the size of the max pooling layer, if not a list, the same is used in all inception modules - strides : int or list of int, default = 1, + strides : int or list of int, default = 1, the strides of kernels in convolution layers for each inception module, if not a list, the same is used in all inception modules - dilation_rate : int or list of int, default = 1, + dilation_rate : int or list of int, default = 1, the dilation rate of convolutions in each inception module, if not a list, the same is used in all inception modules - padding : str or list of str, default = "same", + padding : str or list of str, default = "same", the type of padding used for convolution for each inception module, if not a list, the same is used in all inception modules - activation : str
or list of str, default = "relu", + activation : str or list of str, default = "relu", the activation function used in each inception module, if not a list, the same is used in all inception modules - use_bias : bool or list of bool, default = False, + use_bias : bool or list of bool, default = False, conditioning whether or not convolutions should use bias values in each inception module, if not a list, the same is used in all inception modules - use_residual : bool, default = True, + use_residual : bool, default = True, condition whether or not to use residual connections all over Inception - use_bottleneck : bool, default = True, + use_bottleneck : bool, default = True, condition whether or not to use bottlenecks all over Inception - bottleneck_size : int, default = 32, + bottleneck_size : int, default = 32, the bottleneck size in case use_bottleneck = True - use_custom_filters : bool, default = False, + use_custom_filters : bool, default = False, condition on whether or not to use custom filters in the first inception module - batch_size : int, default = 64 + batch_size : int, default = 64 the number of samples per gradient update. use_mini_batch_size : bool, default = False condition on using the mini batch size formula Wang et al. - n_epochs : int, default = 1500 + n_epochs : int, default = 1500 the number of epochs to train the model. - callbacks : callable or None, default - ReduceOnPlateau and ModelCheckpoint - list of tf.keras.callbacks.Callback objects. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. - file_path : str, default = "./" + file_path : str, default = "./" file_path when saving model_Checkpoint callback - save_best_model : bool, default = False + save_best_model : bool, default = False Whether or not to save the best model, if the modelcheckpoint callback is used by default, this condition, if True, will prevent the automatic deletion of the best saved model from file and the user can choose the file name - save_last_model : bool, default = False + save_last_model : bool, default = False Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file save_init_model : bool, default = False Whether to save the initialization of the model. - best_file_name : str, default = "best_model" + best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter is discarded. - last_file_name : str, default = "last_model" + last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter is discarded. @@ -469,12 +477,17 @@ class IndividualInceptionClassifier(BaseDeepClassifier): by `np.random`. Seeded random number generation can only be guaranteed on CPU processing, GPU processing will be non-deterministic. - verbose : boolean, default = False + verbose : boolean, default = False whether to output extra information - optimizer : keras optimizer, default = Adam - loss : keras loss, default = categorical_crossentropy - metrics : keras metrics, default = None, will be set - to accuracy as default if None + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + loss : str, default = "categorical_crossentropy" + The name of the keras training loss. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training.
If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. Notes ----- @@ -534,7 +547,7 @@ def __init__( random_state=None, verbose=False, loss="categorical_crossentropy", - metrics=None, + metrics="accuracy", optimizer=None, ): # predefined @@ -624,11 +637,6 @@ def build_model(self, input_shape, n_classes, **kwargs): model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer) - if self.metrics is None: - metrics = ["accuracy"] - else: - metrics = self.metrics - self.optimizer_ = ( tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer ) @@ -636,7 +644,7 @@ def build_model(self, input_shape, n_classes, **kwargs): model.compile( loss=self.loss, optimizer=self.optimizer_, - metrics=metrics, + metrics=self._metrics, ) return model @@ -665,6 +673,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + # ignore the number of instances, X.shape[0], # just want the shape of each instance self.input_shape = X.shape[1:] diff --git a/aeon/classification/deep_learning/_lite_time.py b/aeon/classification/deep_learning/_lite_time.py index b53397939b..ac1594ecd7 100644 --- a/aeon/classification/deep_learning/_lite_time.py +++ b/aeon/classification/deep_learning/_lite_time.py @@ -54,8 +54,10 @@ class LITETimeClassifier(BaseClassifier): formula Wang et al. n_epochs : int, default = 1500 the number of epochs to train the model. - callbacks : callable or None, default = ReduceOnPlateau and ModelCheckpoint - list of tf.keras.callbacks.Callback objects. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. file_path : str, default = "./" file_path when saving model_Checkpoint callback save_best_model : bool, default = False @@ -90,10 +92,15 @@ class LITETimeClassifier(BaseClassifier): GPU processing will be non-deterministic. verbose : boolean, default = False whether to output extra information - optimizer : keras optimizer, default = Adam - loss : keras loss, default = categorical_crossentropy - metrics : keras metrics, default = None, - will be set to accuracy as default if None + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + loss : str, default = "categorical_crossentropy" + The name of the keras training loss. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. References ---------- @@ -150,7 +157,7 @@ def __init__( random_state=None, verbose=False, loss="categorical_crossentropy", - metrics=None, + metrics="accuracy", optimizer=None, ): self.n_classifiers = n_classifiers @@ -349,8 +356,10 @@ class IndividualLITEClassifier(BaseDeepClassifier): formula Wang et al. n_epochs : int, default = 1500 the number of epochs to train the model. - callbacks : callable or None, default = ReduceOnPlateau and ModelCheckpoint - list of tf.keras.callbacks.Callback objects. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. 
file_path : str, default = "./" file_path when saving model_Checkpoint callback save_best_model : bool, default = False @@ -385,10 +394,15 @@ class IndividualLITEClassifier(BaseDeepClassifier): GPU processing will be non-deterministic. verbose : boolean, default = False whether to output extra information - optimizer : keras optimizer, default = Adam - loss : keras loss, default = categorical_crossentropy - metrics : keras metrics, default = None, - will be set to accuracy as default if None + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + loss : str, default = "categorical_crossentropy" + The name of the keras training loss. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. References ---------- @@ -436,7 +450,7 @@ def __init__( random_state=None, verbose=False, loss="categorical_crossentropy", - metrics=None, + metrics="accuracy", optimizer=None, ): self.use_litemv = use_litemv @@ -506,11 +520,6 @@ def build_model(self, input_shape, n_classes, **kwargs): model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer) - if self.metrics is None: - metrics = ["accuracy"] - else: - metrics = self.metrics - self.optimizer_ = ( tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer ) @@ -518,7 +527,7 @@ def build_model(self, input_shape, n_classes, **kwargs): model.compile( loss=self.loss, optimizer=self.optimizer_, - metrics=metrics, + metrics=self._metrics, ) return model @@ -546,6 +555,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + # ignore the number of instances, X.shape[0], # just want the shape of each instance self.input_shape = X.shape[1:] diff --git a/aeon/classification/deep_learning/_mlp.py b/aeon/classification/deep_learning/_mlp.py index 0f0d538cb9..6768ba48bd 100644 --- a/aeon/classification/deep_learning/_mlp.py +++ b/aeon/classification/deep_learning/_mlp.py @@ -21,13 +21,18 @@ class MLPClassifier(BaseDeepClassifier): Parameters ---------- + use_bias : bool, default = True + Condition on whether or not to use bias values for dense layers. n_epochs : int, default = 2000 the number of epochs to train the model batch_size : int, default = 16 the number of samples per gradient update. use_mini_batch_size : boolean, default = False Condition on using the mini batch size formula - callbacks : callable or None, default + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. random_state : int, RandomState instance or None, default=None If `int`, random_state is the seed used by the random number generator; If `RandomState` instance, random_state is the random number generator; @@ -37,8 +42,8 @@ class MLPClassifier(BaseDeepClassifier): GPU processing will be non-deterministic. verbose : boolean, default = False whether to output extra information - loss : string, default="mean_squared_error" - fit parameter for the keras model + loss : str, default = "categorical_crossentropy" + The name of the keras training loss. 
file_path : str, default = "./" file_path when saving model_Checkpoint callback save_best_model : bool, default = False @@ -64,14 +69,17 @@ class MLPClassifier(BaseDeepClassifier): init_file_name : str, default = "init_model" The name of the file of the init model, if save_init_model is set to False, this parameter is discarded. - optimizer : keras.optimizer, default=keras.optimizers.Adadelta(), - metrics : list of strings, default=["accuracy"], + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. activation : string or a tf callable, default="sigmoid" Activation function used in the output linear layer. List of available activation functions: https://keras.io/api/layers/activations/ - use_bias : boolean, default = True - whether the layer uses a bias vector. Notes ----- @@ -96,13 +104,14 @@ class MLPClassifier(BaseDeepClassifier): def __init__( self, + use_bias=True, n_epochs=2000, batch_size=16, use_mini_batch_size=False, callbacks=None, verbose=False, loss="categorical_crossentropy", - metrics=None, + metrics="accuracy", file_path="./", save_best_model=False, save_last_model=False, @@ -112,7 +121,6 @@ def __init__( init_file_name="init_model", random_state=None, activation="sigmoid", - use_bias=True, optimizer=None, ): self.callbacks = callbacks @@ -139,7 +147,7 @@ def __init__( last_file_name=last_file_name, ) - self._network = MLPNetwork() + self._network = MLPNetwork(use_bias=self.use_bias) def build_model(self, input_shape, n_classes, **kwargs): """Construct a compiled, un-trained, keras model that is ready for training. @@ -164,19 +172,14 @@ def build_model(self, input_shape, n_classes, **kwargs): import tensorflow as tf from tensorflow import keras - if self.metrics is None: - metrics = ["accuracy"] - else: - metrics = self.metrics - rng = check_random_state(self.random_state) self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) tf.keras.utils.set_random_seed(self.random_state_) input_layer, output_layer = self._network.build_network(input_shape, **kwargs) - output_layer = keras.layers.Dense( - units=n_classes, activation="softmax", use_bias=self.use_bias - )(output_layer) + output_layer = keras.layers.Dense(units=n_classes, activation="softmax")( + output_layer + ) self.optimizer_ = ( keras.optimizers.Adadelta() if self.optimizer is None else self.optimizer @@ -186,7 +189,7 @@ def build_model(self, input_shape, n_classes, **kwargs): model.compile( loss=self.loss, optimizer=self.optimizer_, - metrics=metrics, + metrics=self._metrics, ) return model @@ -210,6 +213,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) diff --git a/aeon/classification/deep_learning/_resnet.py b/aeon/classification/deep_learning/_resnet.py index fe754548aa..a0edf757c9 100644 --- a/aeon/classification/deep_learning/_resnet.py +++ b/aeon/classification/deep_learning/_resnet.py @@ -88,10 +88,15 @@ class method save_last_model_to_file. this parameter is discarded. 
verbose : boolean, default = False whether to output extra information - loss : string, default = "mean_squared_error" - fit parameter for the keras model. - optimizer : keras.optimizer, default = keras.optimizers.Adam() - metrics : list of strings, default = ["accuracy"] + loss : str, default = "categorical_crossentropy" + The name of the keras training loss. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + metrics : str or list[str], default="accuracy" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. Notes ----- @@ -129,7 +134,7 @@ def __init__( callbacks=None, verbose=False, loss="categorical_crossentropy", - metrics=None, + metrics="accuracy", batch_size=64, use_mini_batch_size=False, random_state=None, @@ -213,25 +218,20 @@ def build_model(self, input_shape, n_classes, **kwargs): else self.optimizer ) - if self.metrics is None: - metrics = ["accuracy"] - else: - metrics = self.metrics - rng = check_random_state(self.random_state) self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) tf.keras.utils.set_random_seed(self.random_state_) input_layer, output_layer = self._network.build_network(input_shape, **kwargs) - output_layer = tf.keras.layers.Dense( - units=n_classes, activation="softmax", use_bias=self.use_bias - )(output_layer) + output_layer = tf.keras.layers.Dense(units=n_classes, activation="softmax")( + output_layer + ) model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer) model.compile( loss=self.loss, optimizer=self.optimizer_, - metrics=metrics, + metrics=self._metrics, ) return model @@ -256,6 +256,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape, self.n_classes_) diff --git a/aeon/networks/__init__.py b/aeon/networks/__init__.py index 72f2a2bf1f..209837178b 100644 --- a/aeon/networks/__init__.py +++ b/aeon/networks/__init__.py @@ -19,6 +19,7 @@ "AEBiGRUNetwork", "AEDRNNNetwork", "AEBiGRUNetwork", + "DisjointCNNNetwork", ] from aeon.networks._ae_abgru import AEAttentionBiGRUNetwork from aeon.networks._ae_bgru import AEBiGRUNetwork @@ -28,6 +29,7 @@ from aeon.networks._ae_resnet import AEResNetNetwork from aeon.networks._cnn import TimeCNNNetwork from aeon.networks._dcnn import DCNNNetwork +from aeon.networks._disjoint_cnn import DisjointCNNNetwork from aeon.networks._encoder import EncoderNetwork from aeon.networks._fcn import FCNNetwork from aeon.networks._inception import InceptionNetwork diff --git a/aeon/networks/_disjoint_cnn.py b/aeon/networks/_disjoint_cnn.py new file mode 100644 index 0000000000..126f0ad68c --- /dev/null +++ b/aeon/networks/_disjoint_cnn.py @@ -0,0 +1,305 @@ +"""Disjoint Convolutional Neural Network (DisjointCNNNetwork).""" + +__maintainer__ = ["hadifawaz1999"] + + +from aeon.networks.base import BaseDeepLearningNetwork + + +class DisjointCNNNetwork(BaseDeepLearningNetwork): + """Establish the network structure for a DisjointCNN Network. + + The model is proposed in [1]_ to apply convolutions + specifically for multivariate series, in disjoint temporal + and spatial phases, using 1+1D Convolution layers.
+ + Parameters + ---------- + n_layers : int, default = 4 + Number of 1+1D Convolution layers. + n_filters : int or list of int, default = 64 + Number of filters used in convolution layers. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + kernel_size : int or list of int, default = [8, 5, 5, 3] + Size of convolution kernel. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + dilation_rate : int or list of int, default = 1 + The dilation rate for convolution. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + strides : int or list of int, default = 1 + The strides of the convolution filter. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + padding : str or list of str, default = "same" + The type of padding used for convolution. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + activation : str or list of str, default = "elu" + Activation used after the convolution. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + use_bias : bool or list of bool, default = True + Whether or not to use bias in convolution. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + kernel_initializer: str or list of str, default = "he_uniform" + The initialization method of convolution layers. If + input is set to a list, the length should be the same + as `n_layers`, if input is int then a list of the same + element is created of length `n_layers`. + pool_size: int, default = 5 + The size of the one max pool layer at the end of + the model, default = 5. + pool_strides: int, default = None + The strides used for the one max pool layer at + the end of the model, default = None. + pool_padding: str, default = "valid" + The padding method for the one max pool layer at + the end of the model, default = "valid". + hidden_fc_units: int, default = 128 + The number of fully connected units. + activation_fc: str, default = "relu" + The activation of the fully connected layer. + + Notes + ----- + The code is adapted from: + https://github.com/Navidfoumani/Disjoint-CNN + + References + ---------- + .. [1] Foumani, Seyed Navid Mohammadi, Chang Wei Tan, and Mahsa Salehi. + "Disjoint-cnn for multivariate time series classification." + 2021 International Conference on Data Mining Workshops + (ICDMW). IEEE, 2021.
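+ + Examples + -------- + A minimal usage sketch; the input shape (100, 2), read as + (n_timepoints, n_channels), is an assumed example rather than + a requirement of the network. + + >>> from aeon.networks import DisjointCNNNetwork + >>> net = DisjointCNNNetwork(n_layers=2, kernel_size=[2, 2]) # doctest: +SKIP + >>> input_layer, output_layer = net.build_network((100, 2)) # doctest: +SKIP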
+ """ + + def __init__( + self, + n_layers=4, + n_filters=64, + kernel_size=None, + dilation_rate=1, + strides=1, + padding="same", + activation="elu", + use_bias=True, + kernel_initializer="he_uniform", + pool_size=5, + pool_strides=None, + pool_padding="valid", + hidden_fc_units=128, + activation_fc="relu", + ): + self.n_layers = n_layers + self.n_filters = n_filters + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.strides = strides + self.padding = padding + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.pool_size = pool_size + self.pool_strides = pool_strides + self.pool_padding = pool_padding + self.hidden_fc_units = hidden_fc_units + self.activation_fc = activation_fc + + super().__init__() + + def build_network(self, input_shape, **kwargs): + """Construct a network and return its input and output layers. + + Parameters + ---------- + input_shape : tuple + shape = (n_timepoints (m), n_channels (d)), the shape of the data fed + into the input layer. + + Returns + ------- + input_layer : a keras layer + output_layer : a keras layer + """ + import tensorflow as tf + + self._kernel_size_ = ( + [8, 5, 5, 3] if self.kernel_size is None else self.kernel_size + ) + + if isinstance(self._kernel_size_, list): + if len(self._kernel_size_) != self.n_layers: + raise ValueError( + f"Kernel sizes {len(self._kernel_size_)} should be" + f" the same as number of layers but is" + f" not: {self.n_layers}" + ) + self._kernel_size = self._kernel_size_ + else: + self._kernel_size = [self._kernel_size_] * self.n_layers + + if isinstance(self.n_filters, list): + if len(self.n_filters) != self.n_layers: + raise ValueError( + f"Number of filters {len(self.n_filters)} should be" + f" the same as number of layers but is" + f" not: {self.n_layers}" + ) + self._n_filters = self.n_filters + else: + self._n_filters = [self.n_filters] * self.n_layers + + if isinstance(self.dilation_rate, list): + if len(self.dilation_rate) != self.n_layers: + raise ValueError( + f"Number of dilations {len(self.dilation_rate)} should be" + f" the same as number of layers but is" + f" not: {self.n_layers}" + ) + self._dilation_rate = self.dilation_rate + else: + self._dilation_rate = [self.dilation_rate] * self.n_layers + + if isinstance(self.strides, list): + if len(self.strides) != self.n_layers: + raise ValueError( + f"Number of strides {len(self.strides)} should be" + f" the same as number of layers but is" + f" not: {self.n_layers}" + ) + self._strides = self.strides + else: + self._strides = [self.strides] * self.n_layers + + if isinstance(self.padding, list): + if len(self.padding) != self.n_layers: + raise ValueError( + f"Number of paddings {len(self.padding)} should be" + f" the same as number of layers but is" + f" not: {self.n_layers}" + ) + self._padding = self.padding + else: + self._padding = [self.padding] * self.n_layers + + if isinstance(self.activation, list): + if len(self.activation) != self.n_layers: + raise ValueError( + f"Number of activations {len(self.activation)} should be" + f" the same as number of layers but is" + f" not: {self.n_layers}" + ) + self._activation = self.activation + else: + self._activation = [self.activation] * self.n_layers + + if isinstance(self.use_bias, list): + if len(self.use_bias) != self.n_layers: + raise ValueError( + f"Number of biases {len(self.use_bias)} should be" + f" the same as number of layers but is" + f" not: {self.n_layers}" + ) + self._use_bias = self.use_bias + else: + 
self._use_bias = [self.use_bias] * self.n_layers + + if isinstance(self.kernel_initializer, list): + if len(self.kernel_initializer) != self.n_layers: + raise ValueError( + f"Number of Kernel initializers {len(self.kernel_initializer)}" + f" should be" + f" the same as number of layers but is" + f" not: {self.n_layers}" + ) + self._kernel_initializer = self.kernel_initializer + else: + self._kernel_initializer = [self.kernel_initializer] * self.n_layers + + input_layer = tf.keras.layers.Input(input_shape) + reshape_layer = tf.keras.layers.Reshape( + target_shape=(input_shape[0], input_shape[1], 1) + )(input_layer) + + x = reshape_layer + + for i in range(self.n_layers): + x = self._one_plus_one_d_convolution_layer( + input_tensor=x, + n_filters=self._n_filters[i], + kernel_size=self._kernel_size[i], + dilation_rate=self._dilation_rate[i], + strides=self._strides[i], + padding=self._padding[i], + use_bias=self._use_bias[i], + activation=self._activation[i], + kernel_initializer=self._kernel_initializer[i], + ) + + max_pool_layer = tf.keras.layers.MaxPooling2D( + pool_size=(self.pool_size, 1), + strides=self.pool_strides, + padding=self.pool_padding, + )(x) + + gap = tf.keras.layers.GlobalAveragePooling2D()(max_pool_layer) + + projection_head = tf.keras.layers.Dense( + self.hidden_fc_units, activation=self.activation_fc + )(gap) + + return input_layer, projection_head + + def _one_plus_one_d_convolution_layer( + self, + input_tensor, + n_filters, + kernel_size, + dilation_rate, + strides, + padding, + use_bias, + activation, + kernel_initializer, + ): + import tensorflow as tf + + temporal_conv = tf.keras.layers.Conv2D( + n_filters, + (kernel_size, 1), + padding=padding, + kernel_initializer=kernel_initializer, + dilation_rate=dilation_rate, + use_bias=use_bias, + strides=strides, + )(input_tensor) + + temporal_conv = tf.keras.layers.BatchNormalization()(temporal_conv) + temporal_conv = tf.keras.layers.Activation(activation=activation)(temporal_conv) + + temporal_conv_output_channels = int(temporal_conv.shape[2]) + + spatial_conv = tf.keras.layers.Conv2D( + n_filters, + (1, temporal_conv_output_channels), + padding="valid", + kernel_initializer=kernel_initializer, + )(temporal_conv) + + spatial_conv = tf.keras.layers.BatchNormalization()(spatial_conv) + spatial_conv = tf.keras.layers.Activation(activation=activation)(spatial_conv) + + spatial_conv = tf.keras.layers.Permute((1, 3, 2))(spatial_conv) + + return spatial_conv diff --git a/aeon/networks/_mlp.py b/aeon/networks/_mlp.py index 84b1570ff7..bb1c090c7c 100644 --- a/aeon/networks/_mlp.py +++ b/aeon/networks/_mlp.py @@ -11,6 +11,11 @@ class MLPNetwork(BaseDeepLearningNetwork): Adapted from the implementation used in [1]_ + Parameters + ---------- + use_bias : bool, default = True + Condition on whether or not to use bias values for dense layers. 
+
     Notes
     -----
     Adapted from the implementation from source code
@@ -24,7 +29,10 @@ class MLPNetwork(BaseDeepLearningNetwork):

     def __init__(
         self,
+        use_bias=True,
     ):
+        self.use_bias = use_bias
+
         super().__init__()

     def build_network(self, input_shape, **kwargs):
@@ -47,13 +55,19 @@ def build_network(self, input_shape, **kwargs):
         input_layer_flattened = keras.layers.Flatten()(input_layer)

         layer_1 = keras.layers.Dropout(0.1)(input_layer_flattened)
-        layer_1 = keras.layers.Dense(500, activation="relu")(layer_1)
+        layer_1 = keras.layers.Dense(500, activation="relu", use_bias=self.use_bias)(
+            layer_1
+        )

         layer_2 = keras.layers.Dropout(0.2)(layer_1)
-        layer_2 = keras.layers.Dense(500, activation="relu")(layer_2)
+        layer_2 = keras.layers.Dense(500, activation="relu", use_bias=self.use_bias)(
+            layer_2
+        )

         layer_3 = keras.layers.Dropout(0.2)(layer_2)
-        layer_3 = keras.layers.Dense(500, activation="relu")(layer_3)
+        layer_3 = keras.layers.Dense(500, activation="relu", use_bias=self.use_bias)(
+            layer_3
+        )

         output_layer = keras.layers.Dropout(0.3)(layer_3)

diff --git a/aeon/networks/tests/test_all_networks.py b/aeon/networks/tests/test_all_networks.py
index 106a5b8b4f..eeb8c5d676 100644
--- a/aeon/networks/tests/test_all_networks.py
+++ b/aeon/networks/tests/test_all_networks.py
@@ -115,6 +115,9 @@ def test_all_networks_params(network):
         if network.__name__ == "LITENetwork":
             continue

+        if network.__name__ == "MLPNetwork":
+            continue
+
         # Here we use 'None' string as default to differentiate with None values
         attr = getattr(my_network, attrname, "None")
         if attr != "None":
diff --git a/aeon/networks/tests/test_disjoint_cnn.py b/aeon/networks/tests/test_disjoint_cnn.py
new file mode 100644
index 0000000000..c40eb4becb
--- /dev/null
+++ b/aeon/networks/tests/test_disjoint_cnn.py
@@ -0,0 +1,22 @@
+"""Tests for the DisjointCNN Network."""
+
+import pytest
+
+from aeon.networks import DisjointCNNNetwork
+from aeon.utils.validation._dependencies import _check_soft_dependencies
+
+
+@pytest.mark.skipif(
+    not _check_soft_dependencies(["tensorflow"], severity="none"),
+    reason="Tensorflow soft dependency unavailable.",
+)
+def test_disjoint_cnn_network_kernel_initializer():
+    """Test DisjointCNN for a different kernel_initializer per layer."""
+    input_layer, output_layer = DisjointCNNNetwork(
+        n_layers=2,
+        kernel_initializer=["he_uniform", "glorot_uniform"],
+        kernel_size=[2, 2],
+    ).build_network(input_shape=(10, 2))
+
+    assert len(output_layer.shape) == 2
+    assert len(input_layer.shape) == 3
diff --git a/aeon/regression/deep_learning/__init__.py b/aeon/regression/deep_learning/__init__.py
index 0ea28dfecd..f6e7906179 100644
--- a/aeon/regression/deep_learning/__init__.py
+++ b/aeon/regression/deep_learning/__init__.py
@@ -12,9 +12,11 @@
     "LITETimeRegressor",
     "EncoderRegressor",
     "MLPRegressor",
+    "DisjointCNNRegressor",
 ]

 from aeon.regression.deep_learning._cnn import TimeCNNRegressor
+from aeon.regression.deep_learning._disjoint_cnn import DisjointCNNRegressor
 from aeon.regression.deep_learning._encoder import EncoderRegressor
 from aeon.regression.deep_learning._fcn import FCNRegressor
 from aeon.regression.deep_learning._inception_time import (
diff --git a/aeon/regression/deep_learning/_cnn.py b/aeon/regression/deep_learning/_cnn.py
index c636c70087..351e3964d3 100644
--- a/aeon/regression/deep_learning/_cnn.py
+++ b/aeon/regression/deep_learning/_cnn.py
@@ -21,35 +21,35 @@ class TimeCNNRegressor(BaseDeepRegressor):

     Parameters
     ----------
-    n_layers : int, default = 2,
+    n_layers : int, default = 2,
        the number
        of convolution layers in the network
-    kernel_size : int or list of int, default = 7,
+    kernel_size : int or list of int, default = 7,
        kernel size of convolution layers, if not a list, the same kernel size
        is used for all layers, len(list) should be n_layers
-    n_filters : int or list of int, default = [6, 12],
+    n_filters : int or list of int, default = [6, 12],
        number of filters for each convolution layer, if not a list, the same
        n_filters is used in all layers.
-    avg_pool_size : int or list of int, default = 3,
+    avg_pool_size : int or list of int, default = 3,
        the size of the average pooling layer, if not a list, the same average
        pooling size is used for all convolution layers
-    output_activation : str, default = "linear",
+    output_activation : str, default = "linear",
        the output activation for the regressor
-    activation : str or list of str, default = "sigmoid",
+    activation : str or list of str, default = "sigmoid",
        keras activation function used in the model for each layer, if not a
        list, the same activation is used for all layers
-    padding : str or list of str, default = 'valid',
+    padding : str or list of str, default = 'valid',
        the method of padding in convolution layers, if not a list, the same
        padding is used for all convolution layers
-    strides : int or list of int, default = 1,
+    strides : int or list of int, default = 1,
        the strides of kernels in the convolution and max pooling layers, if
        not a list, the same strides are used for all layers
-    dilation_rate : int or list of int, default = 1,
+    dilation_rate : int or list of int, default = 1,
        the dilation rate of the convolution layers, if not a list, the same
        dilation rate is used all over the network
-    use_bias : bool or list of bool, default = True,
+    use_bias : bool or list of bool, default = True,
        condition on whether or not to use bias values for convolution layers,
        if not a list, the same condition is used for all layers
     random_state : int, RandomState instance or None, default=None
@@ -59,40 +59,43 @@
        by `np.random`.
        Seeded random number generation can only be guaranteed on CPU processing,
        GPU processing will be non-deterministic.
-    n_epochs : int, default = 2000
+    n_epochs : int, default = 2000
        the number of epochs to train the model
-    batch_size : int, default = 16
+    batch_size : int, default = 16
        the number of samples per gradient update.
-    verbose : boolean, default = False
+    verbose : boolean, default = False
        whether to output extra information
-    loss : string, default="mean_squared_error"
-        fit parameter for the keras model
-    optimizer : keras.optimizer, default=keras.optimizers.Adam(),
-    metrics : str or list of str, default="mean_squared_error"
+    loss : str, default = "mean_squared_error"
+        The name of the keras training loss.
+    optimizer : keras.optimizer, default = tf.keras.optimizers.Adam()
+        The keras optimizer used for training.
+    metrics : str or list[str], default="mean_squared_error"
        The evaluation metrics to use during training. If
        a single string metric is provided, it will be
        used as the only metric. If a list of metrics are
        provided, all will be used for evaluation.
-    callbacks : keras.callbacks, default=model_checkpoint to save best
-        model on training loss
-    file_path : file_path for the best model (if checkpoint is used as callback)
-    save_best_model : bool, default = False
+    callbacks : keras callback or list of callbacks,
+        default = None
+        The default list of callbacks are set to
+        ModelCheckpoint.
+    file_path : file_path for the best model (if checkpoint is used as callback)
+    save_best_model : bool, default = False
        Whether or not to save the best model, if the
        modelcheckpoint callback is used by default,
        this condition, if True, will prevent the
        automatic deletion of the best saved model from
        file and the user can choose the file name
-    save_last_model : bool, default = False
+    save_last_model : bool, default = False
        Whether or not to save the last model, last
        epoch trained, using the base class method
        save_last_model_to_file
     save_init_model : bool, default = False
        Whether to save the initialization of the model.
-    best_file_name : str, default = "best_model"
+    best_file_name : str, default = "best_model"
        The name of the file of the best model, if
        save_best_model is set to False, this parameter
        is discarded
-    last_file_name : str, default = "last_model"
+    last_file_name : str, default = "last_model"
        The name of the file of the last model, if
        save_last_model is set to False, this parameter
        is discarded
@@ -143,7 +146,7 @@ def __init__(
         last_file_name="last_model",
         init_file_name="init_model",
         verbose=False,
-        loss="mse",
+        loss="mean_squared_error",
         output_activation="linear",
         metrics="mean_squared_error",
         random_state=None,
@@ -255,10 +258,11 @@ def _fit(self, X, y):
         # Transpose to conform to Keras input style.
         X = X.transpose(0, 2, 1)

-        if isinstance(self.metrics, str):
-            self._metrics = [self.metrics]
-        else:
+        if isinstance(self.metrics, list):
             self._metrics = self.metrics
+        elif isinstance(self.metrics, str):
+            self._metrics = [self.metrics]
+
         self.input_shape = X.shape[1:]

         self.training_model_ = self.build_model(self.input_shape)
diff --git a/aeon/regression/deep_learning/_disjoint_cnn.py b/aeon/regression/deep_learning/_disjoint_cnn.py
new file mode 100644
index 0000000000..cc2b0cb321
--- /dev/null
+++ b/aeon/regression/deep_learning/_disjoint_cnn.py
@@ -0,0 +1,419 @@
+"""DisjointCNN regressor."""
+
+__maintainer__ = ["hadifawaz1999"]
+__all__ = ["DisjointCNNRegressor"]
+
+import gc
+import os
+import time
+from copy import deepcopy
+
+from sklearn.utils import check_random_state
+
+from aeon.networks import DisjointCNNNetwork
+from aeon.regression.deep_learning.base import BaseDeepRegressor
+
+
+class DisjointCNNRegressor(BaseDeepRegressor):
+    """Disjoint Convolutional Neural Network regressor.
+
+    Adapted from the implementation used in [1]_.
+
+    Parameters
+    ----------
+    n_layers : int, default = 4
+        Number of 1+1D Convolution layers.
+    n_filters : int or list of int, default = 64
+        Number of filters used in convolution layers. If
+        input is set to a list, the length should be the same
+        as `n_layers`; if input is an int, a list of the same
+        element is created of length `n_layers`.
+    kernel_size : int or list of int, default = [8, 5, 5, 3]
+        Size of convolution kernel. If
+        input is set to a list, the length should be the same
+        as `n_layers`; if input is an int, a list of the same
+        element is created of length `n_layers`.
+    dilation_rate : int or list of int, default = 1
+        The dilation rate for convolution. If
+        input is set to a list, the length should be the same
+        as `n_layers`; if input is an int, a list of the same
+        element is created of length `n_layers`.
+    strides : int or list of int, default = 1
+        The strides of the convolution filter. If
+        input is set to a list, the length should be the same
+        as `n_layers`; if input is an int, a list of the same
+        element is created of length `n_layers`.
+    padding : str or list of str, default = "same"
+        The type of padding used for convolution. If
+        input is set to a list, the length should be the same
+        as `n_layers`; if input is a single string, a list of the
+        same element is created of length `n_layers`.
+    activation : str or list of str, default = "elu"
+        Activation used after the convolution. If
+        input is set to a list, the length should be the same
+        as `n_layers`; if input is a single string, a list of the
+        same element is created of length `n_layers`.
+    use_bias : bool or list of bool, default = True
+        Whether or not to use bias in convolution. If
+        input is set to a list, the length should be the same
+        as `n_layers`; if input is a single bool, a list of the
+        same element is created of length `n_layers`.
+    kernel_initializer : str or list of str, default = "he_uniform"
+        The initialization method of convolution layers. If
+        input is set to a list, the length should be the same
+        as `n_layers`; if input is a single string, a list of the
+        same element is created of length `n_layers`.
+    pool_size : int, default = 5
+        The size of the one max pool layer at the end of
+        the model.
+    pool_strides : int, default = None
+        The strides used for the one max pool layer at
+        the end of the model.
+    pool_padding : str, default = "valid"
+        The padding method for the one max pool layer at
+        the end of the model.
+    hidden_fc_units : int, default = 128
+        The number of fully connected units.
+    activation_fc : str, default = "relu"
+        The activation of the fully connected layer.
+    n_epochs : int, default = 2000
+        The number of epochs to train the model.
+    batch_size : int, default = 16
+        The number of samples per gradient update.
+    use_mini_batch_size : bool, default = False
+        Whether or not to use the mini batch size formula.
+    random_state : int, RandomState instance or None, default=None
+        If `int`, random_state is the seed used by the random number generator;
+        If `RandomState` instance, random_state is the random number generator;
+        If `None`, the random number generator is the `RandomState` instance used
+        by `np.random`.
+        Seeded random number generation can only be guaranteed on CPU processing,
+        GPU processing will be non-deterministic.
+    verbose : boolean, default = False
+        Whether to output extra information.
+    output_activation : str, default = "linear"
+        The output activation of the regressor.
+    loss : str, default = "mean_squared_error"
+        The name of the keras training loss.
+    metrics : str or list[str], default="mean_squared_error"
+        The evaluation metrics to use during training. If
+        a single string metric is provided, it will be
+        used as the only metric. If a list of metrics are
+        provided, all will be used for evaluation.
+    optimizer : keras.optimizer, default = tf.keras.optimizers.Adam()
+        The keras optimizer used for training.
+    file_path : str, default = "./"
+        File path to save best model.
+    save_best_model : bool, default = False
+        Whether or not to save the best model, if the
+        modelcheckpoint callback is used by default,
+        this condition, if True, will prevent the
+        automatic deletion of the best saved model from
+        file and the user can choose the file name.
+    save_last_model : bool, default = False
+        Whether or not to save the last model, last
+        epoch trained, using the base class method
+        save_last_model_to_file.
+    save_init_model : bool, default = False
+        Whether to save the initialization of the model.
+ best_file_name : str, default = "best_model" + The name of the file of the best model, if + save_best_model is set to False, this parameter + is discarded. + last_file_name : str, default = "last_model" + The name of the file of the last model, if + save_last_model is set to False, this parameter + is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if + save_init_model is set to False, + this parameter is discarded. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. + + Notes + ----- + Adapted from the implementation from: + https://github.com/Navidfoumani/Disjoint-CNN + + References + ---------- + .. [1] Foumani, Seyed Navid Mohammadi, Chang Wei Tan, and Mahsa Salehi. + "Disjoint-cnn for multivariate time series classification." + 2021 International Conference on Data Mining Workshops + (ICDMW). IEEE, 2021. + + Examples + -------- + >>> from aeon.regression.deep_learning import DisjointCNNRegressor + >>> from aeon.datasets import load_unit_test + >>> X_train, y_train = load_unit_test(split="train") + >>> X_test, y_test = load_unit_test(split="test") + >>> disjoint_cnn = DisjointCNNRegressor(n_epochs=20, + ... batch_size=4) # doctest: +SKIP + >>> disjoint_cnn.fit(X_train, y_train) # doctest: +SKIP + DisjointCNNRegressor(...) + """ + + def __init__( + self, + n_layers=4, + n_filters=64, + kernel_size=None, + dilation_rate=1, + strides=1, + padding="same", + activation="elu", + use_bias=True, + kernel_initializer="he_uniform", + pool_size=5, + pool_strides=None, + pool_padding="valid", + hidden_fc_units=128, + activation_fc="relu", + n_epochs=2000, + batch_size=16, + use_mini_batch_size=False, + random_state=None, + verbose=False, + output_activation="linear", + loss="mean_squared_error", + metrics="mean_squared_error", + optimizer=None, + file_path="./", + save_best_model=False, + save_last_model=False, + save_init_model=False, + best_file_name="best_model", + last_file_name="last_model", + init_file_name="init_model", + callbacks=None, + ): + self.n_layers = n_layers + self.n_filters = n_filters + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.strides = strides + self.padding = padding + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.pool_size = pool_size + self.pool_strides = pool_strides + self.pool_padding = pool_padding + self.hidden_fc_units = hidden_fc_units + self.activation_fc = activation_fc + + self.random_state = random_state + self.callbacks = callbacks + self.n_epochs = n_epochs + self.use_mini_batch_size = use_mini_batch_size + self.verbose = verbose + self.output_activation = output_activation + self.loss = loss + self.metrics = metrics + self.optimizer = optimizer + + self.file_path = file_path + self.save_best_model = save_best_model + self.save_last_model = save_last_model + self.save_init_model = save_init_model + self.best_file_name = best_file_name + self.init_file_name = init_file_name + + self.history = None + + super().__init__( + batch_size=batch_size, + last_file_name=last_file_name, + ) + + self._network = DisjointCNNNetwork( + n_layers=self.n_layers, + n_filters=self.n_filters, + kernel_size=self.kernel_size, + dilation_rate=self.dilation_rate, + strides=self.strides, + padding=self.padding, + activation=self.activation, + use_bias=self.use_bias, + kernel_initializer=self.kernel_initializer, + pool_size=self.pool_size, + 
pool_strides=self.pool_strides,
+            pool_padding=self.pool_padding,
+            hidden_fc_units=self.hidden_fc_units,
+            activation_fc=self.activation_fc,
+        )
+
+    def build_model(self, input_shape, **kwargs):
+        """Construct a compiled, un-trained, keras model that is ready for training.
+
+        In aeon, time series are stored in numpy arrays of shape (d,m), where d
+        is the number of dimensions, m is the series length. Keras/tensorflow assume
+        data is in shape (m,d). This method also assumes (m,d). Transpose should
+        happen in fit.
+
+        Parameters
+        ----------
+        input_shape : tuple
+            The shape of the data fed into the input layer, should be (m, d).
+
+        Returns
+        -------
+        output : a compiled Keras Model
+        """
+        import numpy as np
+        import tensorflow as tf
+
+        rng = check_random_state(self.random_state)
+        self.random_state_ = rng.randint(0, np.iinfo(np.int32).max)
+        tf.keras.utils.set_random_seed(self.random_state_)
+        input_layer, output_layer = self._network.build_network(input_shape, **kwargs)
+
+        output_layer = tf.keras.layers.Dense(
+            units=1, activation=self.output_activation
+        )(output_layer)
+
+        self.optimizer_ = (
+            tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer
+        )
+
+        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
+        model.compile(
+            loss=self.loss,
+            optimizer=self.optimizer_,
+            metrics=self._metrics,
+        )
+
+        return model
+
+    def _fit(self, X, y):
+        """Fit the regressor on the training set (X, y).
+
+        Parameters
+        ----------
+        X : np.ndarray
+            The training input samples of shape (n_cases, n_channels, n_timepoints)
+        y : np.ndarray
+            The training data target values of shape (n_cases,).
+
+        Returns
+        -------
+        self : object
+        """
+        import tensorflow as tf
+
+        # Transpose to conform to Keras input style.
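+        # aeon passes X as (n_cases, n_channels, n_timepoints); Keras layers
+        # expect channels last, i.e. (n_cases, n_timepoints, n_channels).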
+ X = X.transpose(0, 2, 1) + + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + + self.input_shape = X.shape[1:] + self.training_model_ = self.build_model(self.input_shape) + + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + + if self.verbose: + self.training_model_.summary() + + if self.use_mini_batch_size: + mini_batch_size = min(self.batch_size, X.shape[0] // 10) + else: + mini_batch_size = self.batch_size + + self.file_name_ = ( + self.best_file_name if self.save_best_model else str(time.time_ns()) + ) + + if self.callbacks is None: + self.callbacks_ = [ + tf.keras.callbacks.ReduceLROnPlateau( + monitor="loss", factor=0.5, patience=50, min_lr=0.0001 + ), + tf.keras.callbacks.ModelCheckpoint( + filepath=self.file_path + self.file_name_ + ".keras", + monitor="loss", + save_best_only=True, + ), + ] + else: + self.callbacks_ = self._get_model_checkpoint_callback( + callbacks=self.callbacks, + file_path=self.file_path, + file_name=self.file_name_, + ) + + self.history = self.training_model_.fit( + X, + y, + batch_size=mini_batch_size, + epochs=self.n_epochs, + verbose=self.verbose, + callbacks=self.callbacks_, + ) + + try: + self.model_ = tf.keras.models.load_model( + self.file_path + self.file_name_ + ".keras", compile=False + ) + if not self.save_best_model: + os.remove(self.file_path + self.file_name_ + ".keras") + except FileNotFoundError: + self.model_ = deepcopy(self.training_model_) + + if self.save_last_model: + self.save_last_model_to_file(file_path=self.file_path) + + gc.collect() + return self + + @classmethod + def _get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default = "default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return "default" set. + For classifiers, a "default" set of parameters should be provided for + general testing, and a "results_comparison" set for comparing against + previously recorded results if the general set does not produce suitable + probabilities to compare against. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class. + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + """ + param1 = { + "n_epochs": 3, + "batch_size": 4, + "use_bias": False, + "n_layers": 2, + "n_filters": 2, + "kernel_size": [2, 2], + } + param2 = { + "n_epochs": 3, + "batch_size": 4, + "use_bias": False, + "n_layers": 2, + "n_filters": 2, + "kernel_size": [2, 2], + "verbose": True, + "metrics": ["mse"], + "use_mini_batch_size": True, + } + + return [param1, param2] diff --git a/aeon/regression/deep_learning/_encoder.py b/aeon/regression/deep_learning/_encoder.py index 2183d5ee8a..fd3bf855cb 100644 --- a/aeon/regression/deep_learning/_encoder.py +++ b/aeon/regression/deep_learning/_encoder.py @@ -78,19 +78,23 @@ class EncoderRegressor(BaseDeepRegressor): by `np.random`. Seeded random number generation can only be guaranteed on CPU processing, GPU processing will be non-deterministic. - loss: - The loss function to use for training. - metrics: str or list of str, default="mean_squared_error" + loss : str, default = "mean_squared_error" + The name of the keras training loss. 
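As an aside, the train / restore-best / clean-up flow used in `DisjointCNNRegressor._fit` above can be reproduced in isolation. A minimal sketch, assuming TensorFlow's `.keras` saving format; the tiny model and data below are placeholders, not part of the patch:

import os
import numpy as np
import tensorflow as tf

X = np.random.rand(8, 10, 1)
y = np.random.rand(8)

inp = tf.keras.layers.Input((10, 1))
out = tf.keras.layers.Dense(1)(tf.keras.layers.Flatten()(inp))
model = tf.keras.models.Model(inputs=inp, outputs=out)
model.compile(loss="mean_squared_error", optimizer="adam")

# Keep only the best epoch (lowest training loss) on disk.
ckpt = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_model.keras", monitor="loss", save_best_only=True
)
model.fit(X, y, epochs=2, verbose=0, callbacks=[ckpt])

# Reload the best epoch's model, then delete the file, mirroring the
# save_best_model=False branch of _fit.
best = tf.keras.models.load_model("best_model.keras", compile=False)
os.remove("best_model.keras")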
+    metrics : str or list[str], default="mean_squared_error"
        The evaluation metrics to use during training. If
        a single string metric is provided, it will be
        used as the only metric. If a list of metrics are
        provided, all will be used for evaluation.
     use_bias:
        Whether to use bias in the dense layers.
-    optimizer:
-        The optimizer to use for training.
+    optimizer : keras.optimizer, default = tf.keras.optimizers.Adam()
+        The keras optimizer used for training.
     verbose:
        Whether to print progress messages during training.
+    callbacks : keras callback or list of callbacks,
+        default = None
+        The default list of callbacks are set to
+        ModelCheckpoint.

     Notes
     -----
@@ -105,10 +109,6 @@
     """

-    _tags = {
-        "python_dependencies": ["tensorflow"],
-    }
-
     def __init__(
         self,
         n_epochs=100,
@@ -204,7 +204,7 @@ def build_model(self, input_shape, **kwargs):
         input_layer, output_layer = self._network.build_network(input_shape, **kwargs)

         output_layer = tf.keras.layers.Dense(
-            units=1, activation=self.output_activation, use_bias=self.use_bias
+            units=1, activation=self.output_activation
         )(output_layer)

         self.optimizer_ = (
@@ -241,10 +241,11 @@ def _fit(self, X, y):
         # Transpose X to conform to Keras input style
         X = X.transpose(0, 2, 1)

-        if isinstance(self.metrics, str):
-            self._metrics = [self.metrics]
-        else:
+        if isinstance(self.metrics, list):
             self._metrics = self.metrics
+        elif isinstance(self.metrics, str):
+            self._metrics = [self.metrics]
+
         self.input_shape = X.shape[1:]

         self.training_model_ = self.build_model(self.input_shape)
diff --git a/aeon/regression/deep_learning/_fcn.py b/aeon/regression/deep_learning/_fcn.py
index 361bc2eef0..a6905580ac 100644
--- a/aeon/regression/deep_learning/_fcn.py
+++ b/aeon/regression/deep_learning/_fcn.py
@@ -21,25 +21,25 @@ class FCNRegressor(BaseDeepRegressor):

     Parameters
     ----------
-    n_layers : int, default = 3
+    n_layers : int, default = 3
        number of convolution layers
-    n_filters : int or list of int, default = [128,256,128]
+    n_filters : int or list of int, default = [128,256,128]
        number of filters used in convolution layers
-    kernel_size : int or list of int, default = [8,5,3]
+    kernel_size : int or list of int, default = [8,5,3]
        size of convolution kernel
-    dilation_rate : int or list of int, default = 1
+    dilation_rate : int or list of int, default = 1
        the dilation rate for convolution
-    strides : int or list of int, default = 1
+    strides : int or list of int, default = 1
        the strides of the convolution filter
-    padding : str or list of str, default = "same"
+    padding : str or list of str, default = "same"
        the type of padding used for convolution
-    activation : str or list of str, default = "relu"
+    activation : str or list of str, default = "relu"
        activation used after the convolution
-    use_bias : bool or list of bool, default = True
+    use_bias : bool or list of bool, default = True
        whether or not to use bias in convolution
-    n_epochs : int, default = 2000
+    n_epochs : int, default = 2000
        the number of epochs to train the model
-    batch_size : int, default = 16
+    batch_size : int, default = 16
        the number of samples per gradient update.
     use_mini_batch_size : bool, default = False,
        whether or not to use the mini batch size formula
     random_state : int, RandomState instance or None, default=None
@@ -50,45 +50,48 @@
        by `np.random`.
        Seeded random number generation can only be guaranteed on CPU processing,
        GPU processing will be non-deterministic.
- verbose : boolean, default = False + verbose : boolean, default = False whether to output extra information output_activation : str, default = "linear", the output activation of the regressor - loss : string, default="mean_squared_error" - fit parameter for the keras model - metrics : list of strings, default="mean_squared_error", + loss : str, default = "mean_squared_error" + The name of the keras training loss. + metrics : str or list[str], default="mean_squared_error" The evaluation metrics to use during training. If a single string metric is provided, it will be used as the only metric. If a list of metrics are provided, all will be used for evaluation. - optimizer : keras.optimizers object, default = Adam(lr=0.01) - specify the optimizer and the learning rate to be used. - file_path : str, default = "./" + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + file_path : str, default = "./" file path to save best model - save_best_model : bool, default = False + save_best_model : bool, default = False Whether or not to save the best model, if the modelcheckpoint callback is used by default, this condition, if True, will prevent the automatic deletion of the best saved model from file and the user can choose the file name - save_last_model : bool, default = False + save_last_model : bool, default = False Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file save_init_model : bool, default = False Whether to save the initialization of the model. - best_file_name : str, default = "best_model" + best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter is discarded - last_file_name : str, default = "last_model" + last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter is discarded init_file_name : str, default = "init_model" The name of the file of the init model, if save_init_model is set to False, this parameter is discarded. - callbacks : keras.callbacks, default = None + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. Notes ----- @@ -134,7 +137,7 @@ def __init__( callbacks=None, verbose=False, output_activation="linear", - loss="mse", + loss="mean_squared_error", metrics="mean_squared_error", random_state=None, use_bias=True, @@ -240,10 +243,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) - if isinstance(self.metrics, str): - self._metrics = [self.metrics] - else: + + if isinstance(self.metrics, list): self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape) diff --git a/aeon/regression/deep_learning/_inception_time.py b/aeon/regression/deep_learning/_inception_time.py index 5e6c6a56e9..96e8a38362 100644 --- a/aeon/regression/deep_learning/_inception_time.py +++ b/aeon/regression/deep_learning/_inception_time.py @@ -24,96 +24,97 @@ class InceptionTimeRegressor(BaseRegressor): Parameters ---------- - n_regressors : int, default = 5, + n_regressors : int, default = 5, the number of Inception models used for the Ensemble in order to create InceptionTime. 
-    depth : int, default = 6,
+    depth : int, default = 6,
        the number of inception modules used
-    n_filters : int or list of int32, default = 32,
+    n_filters : int or list of int32, default = 32,
        the number of filters used in one inception module, if not a list,
        the same number of filters is used in all inception modules
-    n_conv_per_layer : int or list of int, default = 3,
+    n_conv_per_layer : int or list of int, default = 3,
        the number of convolution layers in each inception module, if not a list,
        the same number of convolution layers is used in all inception modules
-    kernel_size : int or list of int, default = 40,
+    kernel_size : int or list of int, default = 40,
        the head kernel size used for each inception module, if not a list,
        the same is used in all inception modules
-    use_max_pooling : bool or list of bool, default = True,
+    use_max_pooling : bool or list of bool, default = True,
        condition on whether or not to use the max pooling layer in inception
        modules, if not a list, the same is used in all inception modules
-    max_pool_size : int or list of int, default = 3,
+    max_pool_size : int or list of int, default = 3,
        the size of the max pooling layer, if not a list,
        the same is used in all inception modules
-    strides : int or list of int, default = 1,
+    strides : int or list of int, default = 1,
        the strides of kernels in convolution layers for each inception module,
        if not a list, the same is used in all inception modules
-    dilation_rate : int or list of int, default = 1,
+    dilation_rate : int or list of int, default = 1,
        the dilation rate of convolutions in each inception module, if not a list,
        the same is used in all inception modules
-    padding : str or list of str, default = "same",
+    padding : str or list of str, default = "same",
        the type of padding used for convolution for each inception module,
        if not a list, the same is used in all inception modules
-    activation : str or list of str, default = "relu",
+    activation : str or list of str, default = "relu",
        the activation function used in each inception module, if not a list,
        the same is used in all inception modules
-    use_bias : bool or list of bool, default = False,
+    use_bias : bool or list of bool, default = False,
        condition whether or not convolutions should use bias values in each
        inception module, if not a list, the same is used in all inception modules
-    use_residual : bool, default = True,
+    use_residual : bool, default = True,
        condition whether or not to use residual connections all over Inception
-    use_bottleneck : bool, default = True,
+    use_bottleneck : bool, default = True,
        condition whether or not to use bottlenecks all over Inception
-    bottleneck_size : int, default = 32,
+    bottleneck_size : int, default = 32,
        the bottleneck size in case use_bottleneck = True
-    use_custom_filters : bool, default = False,
+    use_custom_filters : bool, default = False,
        condition on whether or not to use custom filters in the first
        inception module
-    output_activation : str, default = "linear",
+    output_activation : str, default = "linear",
        the output activation for the regressor
-    batch_size : int, default = 64
+    batch_size : int, default = 64
        the number of samples per gradient update.
     use_mini_batch_size : bool, default = False
        condition on using the mini batch size formula Wang et al.
-    n_epochs : int, default = 1500
+    n_epochs : int, default = 1500
        the number of epochs to train the model.
-    callbacks : callable or None, default
-        ReduceOnPlateau and ModelCheckpoint
-        list of tf.keras.callbacks.Callback objects.
- file_path : str, default = './' + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. + file_path : str, default = './' file_path when saving model_Checkpoint callback - save_best_model : bool, default = False + save_best_model : bool, default = False Whether or not to save the best model, if the modelcheckpoint callback is used by default, this condition, if True, will prevent the automatic deletion of the best saved model from file and the user can choose the file name - save_last_model : bool, default = False + save_last_model : bool, default = False Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file save_init_model : bool, default = False Whether to save the initialization of the model. - best_file_name : str, default = "best_model" + best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter is discarded - last_file_name : str, default = "last_model" + last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter is discarded @@ -127,12 +128,17 @@ class InceptionTimeRegressor(BaseRegressor): by `np.random`. Seeded random number generation can only be guaranteed on CPU processing, GPU processing will be non-deterministic. - verbose : boolean, default = False + verbose : boolean, default = False whether to output extra information - optimizer : keras optimizer, default = Adam - loss : keras loss, - default = mean_squared_error - will be set to accuracy as default if None + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + loss : str, default = "mean_squared_error" + The name of the keras training loss. + metrics : str or list[str], default="mean_squared_error" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. 
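The `metrics` handling that this patch standardises across the deep estimators reduces to one normalisation step in `_fit`. A sketch of the pattern; the helper name is illustrative only, the estimators inline this logic:

def normalise_metrics(metrics):
    # A single string becomes a one-element list; a list is used as-is.
    if isinstance(metrics, list):
        return metrics
    if isinstance(metrics, str):
        return [metrics]
    raise ValueError("metrics must be str or list of str")

assert normalise_metrics("mean_squared_error") == ["mean_squared_error"]
assert normalise_metrics(["mse", "mae"]) == ["mse", "mae"]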
     Notes
     -----
@@ -203,7 +209,8 @@ def __init__(
         callbacks=None,
         random_state=None,
         verbose=False,
-        loss="mse",
+        loss="mean_squared_error",
+        metrics="mean_squared_error",
         optimizer=None,
     ):
         self.n_regressors = n_regressors
@@ -241,6 +248,7 @@ def __init__(
         self.verbose = verbose
         self.use_mini_batch_size = use_mini_batch_size
         self.loss = loss
+        self.metrics = metrics
         self.optimizer = optimizer

         self.regressors_ = []
@@ -294,6 +302,7 @@ def _fit(self, X, y):
             n_epochs=self.n_epochs,
             callbacks=self.callbacks,
             loss=self.loss,
+            metrics=self.metrics,
             optimizer=self.optimizer,
             random_state=rng.randint(0, np.iinfo(np.int32).max),
             verbose=self.verbose,
@@ -366,82 +375,83 @@ class IndividualInceptionRegressor(BaseDeepRegressor):

     Parameters
     ----------
-    depth : int, default = 6,
+    depth : int, default = 6,
        the number of inception modules used
-    n_filters : int or list of int32, default = 32,
+    n_filters : int or list of int32, default = 32,
        the number of filters used in one inception module, if not a list,
        the same number of filters is used in all inception modules
-    n_conv_per_layer : int or list of int, default = 3,
+    n_conv_per_layer : int or list of int, default = 3,
        the number of convolution layers in each inception module, if not a list,
        the same number of convolution layers is used in all inception modules
-    kernel_size : int or list of int, default = 40,
+    kernel_size : int or list of int, default = 40,
        the head kernel size used for each inception module, if not a list,
        the same is used in all inception modules
-    use_max_pooling : bool or list of bool, default = True,
+    use_max_pooling : bool or list of bool, default = True,
        condition whether or not to use the max pooling layer in inception
        modules, if not a list, the same is used in all inception modules
-    max_pool_size : int or list of int, default = 3,
+    max_pool_size : int or list of int, default = 3,
        the size of the max pooling layer, if not a list,
        the same is used in all inception modules
-    strides : int or list of int, default = 1,
+    strides : int or list of int, default = 1,
        the strides of kernels in convolution layers for each inception module,
        if not a list, the same is used in all inception modules
-    dilation_rate : int or list of int, default = 1,
+    dilation_rate : int or list of int, default = 1,
        the dilation rate of convolutions in each inception module, if not a list,
        the same is used in all inception modules
-    padding : str or list of str, default = "same",
+    padding : str or list of str, default = "same",
        the type of padding used for convolution for each inception module,
        if not a list, the same is used in all inception modules
-    activation : str or list of str, default = "relu",
+    activation : str or list of str, default = "relu",
        the activation function used in each inception module, if not a list,
        the same is used in all inception modules
-    use_bias : bool or list of bool, default = False,
+    use_bias : bool or list of bool, default = False,
        condition whether or not convolutions should use bias values in each
        inception module, if not a list, the same is used in all inception modules
-    use_residual : bool, default = True,
+    use_residual : bool, default = True,
        condition whether or not to use residual connections all over Inception
-    use_bottleneck : bool, default = True,
+    use_bottleneck : bool, default = True,
        condition whether or not to use bottlenecks all over Inception
-    bottleneck_size : int, default = 32,
+    bottleneck_size : int, default = 32,
        the bottleneck size in case use_bottleneck = True
-    use_custom_filters : bool,
default = False, + use_custom_filters : bool, default = False, condition on whether or not to use custom filters in the first inception module - output_activation : str, default = "linear", + output_activation : str, default = "linear", the output activation of the regressor - batch_size : int, default = 64 + batch_size : int, default = 64 the number of samples per gradient update. use_mini_batch_size : bool, default = False condition on using the mini batch size formula Wang et al. - n_epochs : int, default = 1500 + n_epochs : int, default = 1500 the number of epochs to train the model. - callbacks : callable or None, default - ReduceOnPlateau and ModelCheckpoint - list of tf.keras.callbacks.Callback objects. - file_path : str, default = './' + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. + file_path : str, default = './' file_path when saving model_Checkpoint callback - save_best_model : bool, default = False + save_best_model : bool, default = False Whether or not to save the best model, if the modelcheckpoint callback is used by default, this condition, if True, will prevent the automatic deletion of the best saved model from file and the user can choose the file name - save_last_model : bool, default = False + save_last_model : bool, default = False Whether or not to save the last model, last epoch trained, using the base class method save_last_model_to_file save_init_model : bool, default = False Whether to save the initialization of the model. - best_file_name : str, default = "best_model" + best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter is discarded - last_file_name : str, default = "last_model" + last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter is discarded @@ -455,11 +465,17 @@ class IndividualInceptionRegressor(BaseDeepRegressor): by `np.random`. Seeded random number generation can only be guaranteed on CPU processing, GPU processing will be non-deterministic. - verbose : boolean, default = False + verbose : boolean, default = False whether to output extra information - optimizer : keras optimizer, default = Adam - loss : keras loss, default = mean_squared_error - to accuracy as default if None + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + loss : str, default = "mean_squared_error" + The name of the keras training loss. + metrics : str or list[str], default="mean_squared_error" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. 
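The `random_state` behaviour documented above follows one pattern throughout the patch: derive an integer seed from a scikit-learn RandomState, then seed TensorFlow globally in `build_model`. A minimal sketch of that step; the value 42 is an arbitrary example:

import numpy as np
import tensorflow as tf
from sklearn.utils import check_random_state

random_state = 42                                  # int, RandomState or None
rng = check_random_state(random_state)
seed_ = rng.randint(0, np.iinfo(np.int32).max)
tf.keras.utils.set_random_seed(seed_)              # seeds python, numpy and tf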
Notes ----- @@ -520,7 +536,8 @@ def __init__( callbacks=None, random_state=None, verbose=False, - loss="mse", + loss="mean_squared_error", + metrics="mean_squared_error", optimizer=None, ): # predefined @@ -555,6 +572,7 @@ def __init__( self.verbose = verbose self.use_mini_batch_size = use_mini_batch_size self.loss = loss + self.metrics = metrics self.optimizer = optimizer super().__init__(batch_size=batch_size, last_file_name=last_file_name) @@ -609,10 +627,7 @@ def build_model(self, input_shape, **kwargs): tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer ) - model.compile( - loss=self.loss, - optimizer=self.optimizer_, - ) + model.compile(loss=self.loss, optimizer=self.optimizer_, metrics=self._metrics) return model @@ -638,6 +653,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + # ignore the number of instances, X.shape[0], # just want the shape of each instance self.input_shape_ = X.shape[1:] diff --git a/aeon/regression/deep_learning/_lite_time.py b/aeon/regression/deep_learning/_lite_time.py index 88d88ffcca..4be6bca86a 100644 --- a/aeon/regression/deep_learning/_lite_time.py +++ b/aeon/regression/deep_learning/_lite_time.py @@ -55,8 +55,10 @@ class LITETimeRegressor(BaseRegressor): formula Wang et al. n_epochs : int, default = 1500 the number of epochs to train the model. - callbacks : callable or None, default = ReduceOnPlateau and ModelCheckpoint - list of tf.keras.callbacks.Callback objects. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. file_path : str, default = "./" file_path when saving model_Checkpoint callback save_best_model : bool, default = False @@ -91,10 +93,15 @@ class LITETimeRegressor(BaseRegressor): GPU processing will be non-deterministic. verbose : boolean, default = False whether to output extra information - optimizer : keras optimizer, default = Adam - loss : keras loss, default = "mean_squared_error" - metrics : keras metrics, default = mean_squared_error, - will be set to mean_squared_error as default if None + loss : str, default = "mean_squared_error" + The name of the keras training loss. + metrics : str or list[str], default="mean_squared_error" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. Notes ----- @@ -152,7 +159,7 @@ def __init__( random_state=None, verbose=False, loss="mean_squared_error", - metrics=None, + metrics="mean_squared_error", optimizer=None, ): self.n_regressors = n_regressors @@ -333,8 +340,10 @@ class IndividualLITERegressor(BaseDeepRegressor): formula Wang et al. n_epochs : int, default = 1500 the number of epochs to train the model. - callbacks : callable or None, default = ReduceOnPlateau and ModelCheckpoint - list of tf.keras.callbacks.Callback objects. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. 
file_path : str, default = "./" file_path when saving model_Checkpoint callback save_best_model : bool, default = False @@ -369,10 +378,15 @@ class IndividualLITERegressor(BaseDeepRegressor): GPU processing will be non-deterministic. verbose : boolean, default = False whether to output extra information - optimizer : keras optimizer, default = Adam - loss : keras loss, default = 'mean_squared_error' - metrics : keras metrics, default = mean_squared_error, - will be set to mean_squared_error as default if None + loss : str, default = "mean_squared_error" + The name of the keras training loss. + metrics : str or list[str], default="mean_squared_error" + The evaluation metrics to use during training. If + a single string metric is provided, it will be + used as the only metric. If a list of metrics are + provided, all will be used for evaluation. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. Notes ----- @@ -421,7 +435,7 @@ def __init__( random_state=None, verbose=False, loss="mean_squared_error", - metrics=None, + metrics="mean_squared_error", optimizer=None, ): self.use_litemv = use_litemv @@ -489,11 +503,6 @@ def build_model(self, input_shape, **kwargs): model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer) - if self.metrics is None: - metrics = ["mean_squared_error"] - else: - metrics = self.metrics - self.optimizer_ = ( tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer ) @@ -501,7 +510,7 @@ def build_model(self, input_shape, **kwargs): model.compile( loss=self.loss, optimizer=self.optimizer_, - metrics=metrics, + metrics=self._metrics, ) return model @@ -527,6 +536,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) + if isinstance(self.metrics, list): + self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + # ignore the number of instances, X.shape[0], # just want the shape of each instance self.input_shape = X.shape[1:] diff --git a/aeon/regression/deep_learning/_mlp.py b/aeon/regression/deep_learning/_mlp.py index d593eb7b80..9a616238c6 100644 --- a/aeon/regression/deep_learning/_mlp.py +++ b/aeon/regression/deep_learning/_mlp.py @@ -21,16 +21,21 @@ class MLPRegressor(BaseDeepRegressor): Parameters ---------- + use_bias : bool, default = True + Condition on whether or not to use bias values for dense layers. n_epochs : int, default = 2000 the number of epochs to train the model batch_size : int, default = 16 the number of samples per gradient update. - callbacks : callable or None, default + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks are set to + ModelCheckpoint and ReduceLROnPlateau. verbose : boolean, default = False whether to output extra information - loss : string, default="mean_squared_error" - fit parameter for the keras model - metrics : list of strings, default="mean_squared_error" + loss : str, default = "mean_squared_error" + The name of the keras training loss. + metrics : str or list[str], default="mean_squared_error" The evaluation metrics to use during training. If a single string metric is provided, it will be used as the only metric. 
If a list of metrics are
        provided, all will be used for evaluation.
@@ -72,10 +77,9 @@ class MLPRegressor(BaseDeepRegressor):
         List of available activation functions:
         https://keras.io/api/layers/activations/
     output_activation : str = "linear"
-        Activation for the last layer in a Regressor
-    use_bias : boolean, default = True
-        whether the layer uses a bias vector.
-    optimizer : keras.optimizer, default=keras.optimizers.Adadelta()
+        Activation for the last layer in a Regressor.
+    optimizer : keras.optimizer, default = tf.keras.optimizers.Adadelta()
+        The keras optimizer used for training.


     References
@@ -96,11 +100,12 @@ class MLPRegressor(BaseDeepRegressor):

     def __init__(
         self,
+        use_bias=True,
         n_epochs=2000,
         batch_size=16,
         callbacks=None,
         verbose=False,
-        loss="mse",
+        loss="mean_squared_error",
         metrics="mean_squared_error",
         file_path="./",
         save_best_model=False,
@@ -112,7 +117,6 @@
         random_state=None,
         activation="relu",
         output_activation="linear",
-        use_bias=True,
         optimizer=None,
     ):
         self.callbacks = callbacks
@@ -139,7 +143,7 @@
             last_file_name=last_file_name,
         )

-        self._network = MLPNetwork()
+        self._network = MLPNetwork(use_bias=self.use_bias)

     def build_model(self, input_shape, **kwargs):
         """Construct a compiled, un-trained, keras model that is ready for training.
@@ -168,9 +172,9 @@

         input_layer, output_layer = self._network.build_network(input_shape, **kwargs)

-        output_layer = keras.layers.Dense(
-            units=1, activation=self.output_activation, use_bias=self.use_bias
-        )(output_layer)
+        output_layer = keras.layers.Dense(units=1, activation=self.output_activation)(
+            output_layer
+        )

         self.optimizer_ = (
             keras.optimizers.Adadelta() if self.optimizer is None else self.optimizer
         )
@@ -203,10 +207,11 @@ def _fit(self, X, y):
         # Transpose to conform to Keras input style.
         X = X.transpose(0, 2, 1)

-        if isinstance(self.metrics, str):
-            self._metrics = [self.metrics]
-        else:
+        if isinstance(self.metrics, list):
             self._metrics = self.metrics
+        elif isinstance(self.metrics, str):
+            self._metrics = [self.metrics]
+
         self.input_shape = X.shape[1:]

         self.training_model_ = self.build_model(self.input_shape)
diff --git a/aeon/regression/deep_learning/_resnet.py b/aeon/regression/deep_learning/_resnet.py
index 7592d4e683..7f89a18ade 100644
--- a/aeon/regression/deep_learning/_resnet.py
+++ b/aeon/regression/deep_learning/_resnet.py
@@ -24,31 +24,31 @@ class ResNetRegressor(BaseDeepRegressor):
    ----------
    n_residual_blocks : int, default = 3
        the number of residual blocks of ResNet's model
-    n_conv_per_residual_block : int, default = 3,
+    n_conv_per_residual_block : int, default = 3,
        the number of convolution blocks in each residual block
-    n_filters : int or list of int, default = [128, 64, 64],
+    n_filters : int or list of int, default = [128, 64, 64],
        the number of convolution filters for all the convolution layers in the
        same residual block, if not a list, the same number of filters is used
        in all convolutions of all residual blocks.
-    kernel_sizes : int or list of int, default = [8, 5, 3],
+    kernel_sizes : int or list of int, default = [8, 5, 3],
        the kernel size of all the convolution layers in one
        residual block, if not a list, the same kernel size is used
        in all convolution layers
-    strides : int or list of int, default = 1,
+    strides : int or list of int, default = 1,
        the strides of convolution kernels in each of the
        convolution layers in one residual block, if not a list,
        the same strides are used in all convolution layers
-    dilation_rate : int or list of int, default = 1,
+    dilation_rate : int or list of int, default = 1,
        the dilation rate of the convolution layers in one residual
        block, if not a list, the same dilation rate is used in all
        convolution layers
-    padding : str or list of str, default = 'padding',
+    padding : str or list of str, default = 'same',
        the type of padding used in the convolution layers in one
        residual block, if not a list, the same padding is used in all
        convolution layers
-    activation : str or list of str, default = 'relu',
+    activation : str or list of str, default = 'relu',
        keras activation used in the convolution layers in one
        residual block, if not a list, the same activation is used in all
        convolution layers
-    output_activation : str, default = "linear",
+    output_activation : str, default = "linear",
        the output activation for the regressor
     use_bias : bool or list of bool, default = True,
        condition on whether or not to use bias values in
@@ -60,9 +60,10 @@
        the number of samples per gradient update.
     use_mini_batch_size : bool, default = False
        condition on using the mini batch size formula Wang et al.
-    callbacks : callable or None, default
-        ReduceOnPlateau and ModelCheckpoint
-        list of tf.keras.callbacks.Callback objects.
+    callbacks : keras callback or list of callbacks,
+        default = None
+        The default list of callbacks are set to
+        ModelCheckpoint and ReduceLROnPlateau.
     random_state : int, RandomState instance or None, default=None
        If `int`, random_state is the seed used by the random number generator;
        If `RandomState` instance, random_state is the random number generator;
        If `None`, the random number generator is the `RandomState` instance used
@@ -70,37 +71,38 @@
        by `np.random`.
        Seeded random number generation can only be guaranteed on CPU processing,
        GPU processing will be non-deterministic.
-    file_path : str, default = './'
+    file_path : str, default = './'
        file_path when saving model_Checkpoint callback
-    save_best_model : bool, default = False
+    save_best_model : bool, default = False
        Whether or not to save the best model, if the
        modelcheckpoint callback is used by default,
        this condition, if True, will prevent the
        automatic deletion of the best saved model from
        file and the user can choose the file name
-    save_last_model : bool, default = False
+    save_last_model : bool, default = False
        Whether or not to save the last model, last
        epoch trained, using the base class method
        save_last_model_to_file
     save_init_model : bool, default = False
        Whether to save the initialization of the model.
- best_file_name : str, default = "best_model" + best_file_name : str, default = "best_model" The name of the file of the best model, if save_best_model is set to False, this parameter is discarded - last_file_name : str, default = "last_model" + last_file_name : str, default = "last_model" The name of the file of the last model, if save_last_model is set to False, this parameter is discarded init_file_name : str, default = "init_model" The name of the file of the init model, if save_init_model is set to False, this parameter is discarded. - verbose : boolean, default = False + verbose : boolean, default = False whether to output extra information - loss : string, default="mean_squared_error" - fit parameter for the keras model - optimizer : keras.optimizer, default=keras.optimizers.Adam(), - metrics : list of strings, default="mean_squared_error", + loss : str, default = "mean_squared_error" + The name of the keras training loss. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + metrics : str or list[str], default="mean_squared_error" The evaluation metrics to use during training. If a single string metric is provided, it will be used as the only metric. If a list of metrics are @@ -143,7 +145,7 @@ def __init__( n_epochs=1500, callbacks=None, verbose=False, - loss="mse", + loss="mean_squared_error", output_activation="linear", metrics="mean_squared_error", batch_size=64, @@ -263,10 +265,11 @@ def _fit(self, X, y): # Transpose to conform to Keras input style. X = X.transpose(0, 2, 1) - if isinstance(self.metrics, str): - self._metrics = [self.metrics] - else: + if isinstance(self.metrics, list): self._metrics = self.metrics + elif isinstance(self.metrics, str): + self._metrics = [self.metrics] + self.input_shape = X.shape[1:] self.training_model_ = self.build_model(self.input_shape) diff --git a/docs/api_reference/classification.rst b/docs/api_reference/classification.rst index 9f5482fec0..55865112d6 100644 --- a/docs/api_reference/classification.rst +++ b/docs/api_reference/classification.rst @@ -44,6 +44,7 @@ Deep learning LITETimeClassifier MLPClassifier ResNetClassifier + DisjointCNNClassifier TapNetClassifier Dictionary-based diff --git a/docs/api_reference/networks.rst b/docs/api_reference/networks.rst index 66741bcb98..63ff3b6a13 100644 --- a/docs/api_reference/networks.rst +++ b/docs/api_reference/networks.rst @@ -25,3 +25,4 @@ Deep learning networks AEResNetNetwork LITENetwork AEBiGRUNetwork + DisjointCNNNetwork diff --git a/docs/api_reference/regression.rst b/docs/api_reference/regression.rst index 0f4b682347..4c211d9eee 100644 --- a/docs/api_reference/regression.rst +++ b/docs/api_reference/regression.rst @@ -64,6 +64,7 @@ Deep learning ResNetRegressor TapNetRegressor MLPRegressor + DisjointCNNRegressor Distance-based --------------
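For completeness, an illustrative end-to-end use of the estimators added in this patch, mirroring the docstring examples (which are doctest-skipped because tensorflow is a soft dependency); the epoch and batch settings are arbitrary:

from aeon.classification.deep_learning import DisjointCNNClassifier
from aeon.datasets import load_unit_test

X_train, y_train = load_unit_test(split="train")
X_test, _ = load_unit_test(split="test")

clf = DisjointCNNClassifier(n_epochs=5, batch_size=4)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)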