diff --git a/src/anemoi/utils/dates.py b/src/anemoi/utils/dates.py index 4614c4c..07c2871 100644 --- a/src/anemoi/utils/dates.py +++ b/src/anemoi/utils/dates.py @@ -162,11 +162,15 @@ def __init__(self, reference_dates, years=20): """ self.reference_dates = reference_dates - self.years = (1, years + 1) + + if isinstance(years, list): + self.years = years + else: + self.years = range(1, years + 1) def __iter__(self): for reference_date in self.reference_dates: - for year in range(*self.years): + for year in self.years: if reference_date.month == 2 and reference_date.day == 29: date = datetime.datetime(reference_date.year - year, 2, 28) else: @@ -246,3 +250,47 @@ def __init__(self, year, **kwargs): _description_ """ super().__init__(datetime.datetime(year, 9, 1), datetime.datetime(year, 11, 30), **kwargs) + + +class ConcatDateTimes: + def __init__(self, *dates): + if len(dates) == 1 and isinstance(dates[0], list): + dates = dates[0] + + self.dates = dates + + def __iter__(self): + for date in self.dates: + yield from date + + +class EnumDateTimes: + def __init__(self, dates): + self.dates = dates + + def __iter__(self): + for date in self.dates: + yield as_datetime(date) + + +def datetimes_factory(args): + if isinstance(args, dict): + name = args.get("name") + + if name == "hindcast": + reference_dates = args["reference_dates"] + reference_dates = datetimes_factory(reference_dates) + years = args["years"] + return HindcastDatesTimes(reference_dates=reference_dates, years=years) + + args = args.copy() + frequency = args.pop("frequency", 24) + args["increment"] = frequency + return DateTimes(**args) + + if isinstance(args, list): + if all(isinstance(arg, dict) for arg in args): + return ConcatDateTimes(*[datetimes_factory(arg) for arg in args]) + else: + return EnumDateTimes(args) + raise ValueError(f"Invalid dates provided : {args}") diff --git a/tests/test_dates.py b/tests/test_dates.py new file mode 100644 index 0000000..61331f7 --- /dev/null +++ b/tests/test_dates.py @@ -0,0 +1,92 @@ +from textwrap import dedent + +import yaml + +from anemoi.utils.dates import datetimes_factory + + +def _(txt): + txt = dedent(txt) + config = yaml.safe_load(txt) + config = config["dates"] + return datetimes_factory(config) + + +def test_date_1(): + d = _( + """ + dates: + - 2023-01-01 + - 2023-01-02 + - 2023-01-03 + """ + ) + assert len(list(d)) == 3 + + +def test_date_2(): + d = _( + """ + dates: + start: 2023-01-01 + end: 2023-01-07 + frequency: 12 + day_of_week: [monday, friday] + """ + ) + assert len(list(d)) == 4 + + +def test_date_3(): + d = _( + """ + dates: + - start: 2023-01-01 + end: 2023-01-03 + frequency: 24 + - start: 2024-01-01T06:00:00 + end: 2024-01-03T18:00:00 + frequency: 6 + """ + ) + assert len(list(d)) == 14 + + +def test_date_hindcast_1(): + d = _( + """ + dates: + - name: hindcast + reference_dates: + start: 2023-01-01 + end: 2023-01-03 + frequency: 24 + years: 20 + """ + ) + assert len(list(d)) == 60 + + +def test_date_hindcast_2(): + d = _( + """ + dates: + - name: hindcast + reference_dates: + start: 2023-01-01 + end: 2023-01-03 + frequency: 24 + years: [2018, 2019, 2020, 2021] + """ + ) + assert len(list(d)) == 12 + + +if __name__ == "__main__": + test_functions = [ + obj for name, obj in globals().items() if name.startswith("test_") and isinstance(obj, type(lambda: 0)) + ] + for test_func in test_functions: + print(f"Running test: {test_func.__name__}") + test_func() + print("All tests passed!")