diff --git a/aeon/classification/dictionary_based/_tde.py b/aeon/classification/dictionary_based/_tde.py
index f8af016b54..ba3eae07a5 100644
--- a/aeon/classification/dictionary_based/_tde.py
+++ b/aeon/classification/dictionary_based/_tde.py
@@ -320,9 +320,14 @@ def _fit(self, X, y, keep_train_preds=False):
                     rng.choice(np.flatnonzero(preds == preds.max()))
                 )
 
-            subsample = rng.choice(self.n_cases_, size=subsample_size, replace=False)
-            X_subsample = X[subsample]
-            y_subsample = y[subsample]
+            while True:
+                subsample = rng.choice(
+                    self.n_cases_, size=subsample_size, replace=False
+                )
+                X_subsample = X[subsample]
+                y_subsample = y[subsample]
+                if len(np.unique(y_subsample)) > 1:
+                    break
 
             tde = IndividualTDE(
                 *parameters,
diff --git a/aeon/classification/dictionary_based/tests/test_tde.py b/aeon/classification/dictionary_based/tests/test_tde.py
index e49f2d9015..a069e1ba45 100644
--- a/aeon/classification/dictionary_based/tests/test_tde.py
+++ b/aeon/classification/dictionary_based/tests/test_tde.py
@@ -103,3 +103,19 @@ def test_histogram_intersection():
 
     res = histogram_intersection(numba_first, numba_second)
     assert res == 2
+
+
+def test_subsampling_in_highly_imbalanced_datasets():
+    """Test the subsampling during fit for highly imbalanced datasets.
+
+    This test case tests the fix for bug #1726.
+    https://github.com/aeon-toolkit/aeon/issues/1726
+    """
+    # Seeded so the regression test fits the same data on every run.
+    X = np.random.default_rng(0).random((10, 1, 20))
+    y_sc = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
+
+    tde = TemporalDictionaryEnsemble(random_state=42)
+    tde.fit(X, y_sc)
+
+    assert tde.is_fitted