From 2e98d33d1f01e28bba139a1b5bb8a85a8a7117c4 Mon Sep 17 00:00:00 2001 From: Ferdinand Rewicki <3592978+ferewi@users.noreply.github.com> Date: Thu, 14 Nov 2024 22:38:48 +0100 Subject: [PATCH] [BUG] Fixed subsampling in highly imbalances datasets giving subsamples with only a single class (#2305) * added argument 'max_subsamples' to subsample multiple times in case of unbalanced datasets giving subsamples with only one class. * removed parameter max_subsamples and resample until valid subsample is found --------- Co-authored-by: Ferdinand Rewicki --- aeon/classification/dictionary_based/_tde.py | 11 ++++++++--- .../dictionary_based/tests/test_tde.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/aeon/classification/dictionary_based/_tde.py b/aeon/classification/dictionary_based/_tde.py index f8af016b54..ba3eae07a5 100644 --- a/aeon/classification/dictionary_based/_tde.py +++ b/aeon/classification/dictionary_based/_tde.py @@ -320,9 +320,14 @@ def _fit(self, X, y, keep_train_preds=False): rng.choice(np.flatnonzero(preds == preds.max())) ) - subsample = rng.choice(self.n_cases_, size=subsample_size, replace=False) - X_subsample = X[subsample] - y_subsample = y[subsample] + while True: + subsample = rng.choice( + self.n_cases_, size=subsample_size, replace=False + ) + X_subsample = X[subsample] + y_subsample = y[subsample] + if len(np.unique(y_subsample)) > 1: + break tde = IndividualTDE( *parameters, diff --git a/aeon/classification/dictionary_based/tests/test_tde.py b/aeon/classification/dictionary_based/tests/test_tde.py index e49f2d9015..a069e1ba45 100644 --- a/aeon/classification/dictionary_based/tests/test_tde.py +++ b/aeon/classification/dictionary_based/tests/test_tde.py @@ -103,3 +103,18 @@ def test_histogram_intersection(): res = histogram_intersection(numba_first, numba_second) assert res == 2 + + +def test_subsampling_in_highly_imbalanced_datasets(): + """Test the subsampling during fit for highly imbalanced datasets. + + This test case tests the fix for bug #1726. + https://github.com/aeon-toolkit/aeon/issues/1726 + """ + X = np.random.rand(10, 1, 20) + y_sc = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0]) + + tde = TemporalDictionaryEnsemble(random_state=42) + tde.fit(X, y_sc) + + assert tde.is_fitted