From 4bf1db9772ee1672cc181d1e4f2f624070a48cb8 Mon Sep 17 00:00:00 2001 From: Ajda Date: Fri, 2 Feb 2024 13:36:33 +0100 Subject: [PATCH] Timeseries: test Stack learner --- .../timeseries/tests/test_timeseries.py | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/orangecontrib/timeseries/tests/test_timeseries.py b/orangecontrib/timeseries/tests/test_timeseries.py index 926c9b68..d4905a6f 100644 --- a/orangecontrib/timeseries/tests/test_timeseries.py +++ b/orangecontrib/timeseries/tests/test_timeseries.py @@ -4,7 +4,11 @@ import platform from unittest.mock import patch -from Orange.data import Table +from Orange.data import Table, ContinuousVariable, Domain, TimeVariable, \ + DiscreteVariable +from Orange.classification import TreeLearner, KNNLearner +from Orange.ensembles import StackedFitter +from Orange.evaluation import CrossValidation, MSE from orangecontrib.timeseries import Timeseries from orangecontrib.timeseries.functions import timestamp, fromtimestamp @@ -19,7 +23,8 @@ def test_create_time_variable(self): self.assertNotEqual(id_1, id(time_series.attributes)) def test_make_timeseries_from_continuous_var(self): - table = Table.from_url("http://file.biolab.si/datasets/slovenian-national-assembly-eng.tab") + table = Table.from_url( + "http://file.biolab.si/datasets/slovenian-national-assembly-eng.tab") time_series = Timeseries.make_timeseries_from_continuous_var(table, 'date of birth') self.assertEqual(time_series.time_variable.name, 'date of birth') @@ -29,7 +34,7 @@ def test_time_var_removed(self): ts_with_tv = Timeseries.from_file('airpassengers') # select columns without time variable ts_without_tv = Timeseries.from_data_table(ts_with_tv[:, - ts_with_tv.domain.class_var]) + ts_with_tv.domain.class_var]) self.assertTrue(ts_with_tv.time_variable) # make sure the Timeseries without time variable in domain has # time_variable set to None @@ -66,6 +71,7 @@ def test_timestamp_windows(self): with hardcoded correct timestamps. It can be only tested with UTC since otherwise timestamp would be machine local time dependent """ + class T(datetime): def timestamp(self): nonlocal was_hit @@ -102,6 +108,28 @@ def fromtimestamp(cls, *args, **kwargs): self.assertEqual(fromtimestamp(TS, tz=timezone.utc), expected) self.assertTrue(was_hit) + def test_stacking(self): + domain = Domain([TimeVariable("a"), ContinuousVariable("b"), + ContinuousVariable("c")], + class_vars=DiscreteVariable("cls", values=["1", "0"])) + ts = Table.from_numpy(domain, [(585167768, 0.88224325, 0.87219962), + (1402820096, 0.30777631, 0.44907067), + (899898806, 0.79237373, 0.42574664), + (1270258083, 0.99060523, 0.05312487), + (393049941, 0.41741731, 0.36743904), + (701125886, 0.77428077, 0.12416262), + (18750041, 0.91150353, 0.80794697), + (815292377, 0.99979289, 0.82930241), + (161861359, 0.06679552, 0.70782449), + (408412451, 0.37535755, 0.94807882)], + [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) + sf = StackedFitter([TreeLearner(), KNNLearner()]) + cv = CrossValidation(k=3, random_state=0) + results = cv(ts, [sf, KNNLearner(), TreeLearner()]) + mse = MSE()(results) + self.assertLess(mse[0], mse[1]) + self.assertLess(mse[0], mse[2]) + if __name__ == "__main__": unittest.main()