Skip to content

Commit

Permalink
adding factory for dates
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed May 15, 2024
1 parent 233ad57 commit f6e8315
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 2 deletions.
52 changes: 50 additions & 2 deletions src/anemoi/utils/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,15 @@ def __init__(self, reference_dates, years=20):
"""

self.reference_dates = reference_dates
self.years = (1, years + 1)

if isinstance(years, list):
self.years = years
else:
self.years = range(1, years + 1)

def __iter__(self):
for reference_date in self.reference_dates:
for year in range(*self.years):
for year in self.years:
if reference_date.month == 2 and reference_date.day == 29:
date = datetime.datetime(reference_date.year - year, 2, 28)
else:
Expand Down Expand Up @@ -246,3 +250,47 @@ def __init__(self, year, **kwargs):
_description_
"""
super().__init__(datetime.datetime(year, 9, 1), datetime.datetime(year, 11, 30), **kwargs)


class ConcatDateTimes:
def __init__(self, *dates):
if len(dates) == 1 and isinstance(dates[0], list):
dates = dates[0]

self.dates = dates

def __iter__(self):
for date in self.dates:
yield from date


class EnumDateTimes:
def __init__(self, dates):
self.dates = dates

def __iter__(self):
for date in self.dates:
yield as_datetime(date)


def datetimes_factory(args):
if isinstance(args, dict):
name = args.get("name")

if name == "hindcast":
reference_dates = args["reference_dates"]
reference_dates = datetimes_factory(reference_dates)
years = args["years"]
return HindcastDatesTimes(reference_dates=reference_dates, years=years)

args = args.copy()
frequency = args.pop("frequency", 24)
args["increment"] = frequency
return DateTimes(**args)

if isinstance(args, list):
if all(isinstance(arg, dict) for arg in args):
return ConcatDateTimes(*[datetimes_factory(arg) for arg in args])
else:
return EnumDateTimes(args)
raise ValueError(f"Invalid dates provided : {args}")
92 changes: 92 additions & 0 deletions tests/test_dates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from textwrap import dedent

import yaml

from anemoi.utils.dates import datetimes_factory


def _(txt):
txt = dedent(txt)
config = yaml.safe_load(txt)
config = config["dates"]
return datetimes_factory(config)


def test_date_1():
d = _(
"""
dates:
- 2023-01-01
- 2023-01-02
- 2023-01-03
"""
)
assert len(list(d)) == 3


def test_date_2():
d = _(
"""
dates:
start: 2023-01-01
end: 2023-01-07
frequency: 12
day_of_week: [monday, friday]
"""
)
assert len(list(d)) == 4


def test_date_3():
d = _(
"""
dates:
- start: 2023-01-01
end: 2023-01-03
frequency: 24
- start: 2024-01-01T06:00:00
end: 2024-01-03T18:00:00
frequency: 6
"""
)
assert len(list(d)) == 14


def test_date_hindcast_1():
d = _(
"""
dates:
- name: hindcast
reference_dates:
start: 2023-01-01
end: 2023-01-03
frequency: 24
years: 20
"""
)
assert len(list(d)) == 60


def test_date_hindcast_2():
d = _(
"""
dates:
- name: hindcast
reference_dates:
start: 2023-01-01
end: 2023-01-03
frequency: 24
years: [2018, 2019, 2020, 2021]
"""
)
assert len(list(d)) == 12


if __name__ == "__main__":
test_functions = [
obj for name, obj in globals().items() if name.startswith("test_") and isinstance(obj, type(lambda: 0))
]
for test_func in test_functions:
print(f"Running test: {test_func.__name__}")
test_func()
print("All tests passed!")

0 comments on commit f6e8315

Please sign in to comment.