diff --git a/src/anemoi/datasets/compute/perturbations.py b/src/anemoi/datasets/compute/perturbations.py index c09041f3..4fa964f9 100644 --- a/src/anemoi/datasets/compute/perturbations.py +++ b/src/anemoi/datasets/compute/perturbations.py @@ -7,59 +7,95 @@ # nor does it submit to any jurisdiction. # -import warnings +import logging import numpy as np from climetlab.core.temporary import temp_file from climetlab.readers.grib.output import new_grib_output -from anemoi.datasets.create.check import check_data_values from anemoi.datasets.create.functions import assert_is_fieldset +LOG = logging.getLogger(__name__) + +CLIP_VARIABLES = ( + "q", + "cp", + "lsp", + "tp", + "sf", + "swl4", + "swl3", + "swl2", + "swl1", +) + +SKIP = ("class", "stream", "type", "number", "expver", "_leg_number") + + +def check_compatible(f1, f2, center_field_as_mars, ensemble_field_as_mars): + assert f1.mars_grid == f2.mars_grid, (f1.mars_grid, f2.mars_grid) + assert f1.mars_area == f2.mars_area, (f1.mars_area, f2.mars_area) + assert f1.shape == f2.shape, (f1.shape, f2.shape) + + # Not in *_as_mars + assert f1.metadata("valid_datetime") == f2.metadata("valid_datetime"), ( + f1.metadata("valid_datetime"), + f2.metadata("valid_datetime"), + ) + + for k in set(center_field_as_mars.keys()) | set(ensemble_field_as_mars.keys()): + if k in SKIP: + continue + assert center_field_as_mars[k] == ensemble_field_as_mars[k], ( + k, + center_field_as_mars[k], + ensemble_field_as_mars[k], + ) + def perturbations( + *, members, center, - positive_clipping_variables=[ - "q", - "cp", - "lsp", - "tp", - ], # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ? + clip_variables=CLIP_VARIABLES, + output=None, ): keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"] - def check_compatible(f1, f2, ignore=["number"]): - for k in keys + ["grid", "shape"]: - if k in ignore: - continue - assert f1.metadata(k) == f2.metadata(k), (k, f1.metadata(k), f2.metadata(k)) + number_list = members.unique_values("number")["number"] + n_numbers = len(number_list) - print(f"Retrieving ensemble data with {members}") - print(f"Retrieving center data with {center}") + assert None not in number_list + LOG.info("Ordering fields") members = members.order_by(*keys) center = center.order_by(*keys) - - number_list = members.unique_values("number")["number"] - n_numbers = len(number_list) + LOG.info("Done") if len(center) * n_numbers != len(members): - print(len(center), n_numbers, len(members)) + LOG.error("%s %s %s", len(center), n_numbers, len(members)) for f in members: - print("Member: ", f) + LOG.error("Member: %r", f) for f in center: - print("Center: ", f) + LOG.error("Center: %r", f) raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}") - # prepare output tmp file so we can read it back - tmp = temp_file() - path = tmp.path + if output is None: + # prepare output tmp file so we can read it back + tmp = temp_file() + path = tmp.path + else: + tmp = None + path = output + out = new_grib_output(path) + seen = set() + for i, center_field in enumerate(center): param = center_field.metadata("param") + center_field_as_mars = center_field.as_mars() # load the center field center_np = center_field.to_numpy() @@ -69,9 +105,14 @@ def check_compatible(f1, f2, ignore=["number"]): for j in range(n_numbers): ensemble_field = members[i * n_numbers + j] - check_compatible(center_field, ensemble_field) + ensemble_field_as_mars = ensemble_field.as_mars() + check_compatible(center_field, ensemble_field, center_field_as_mars, ensemble_field_as_mars) members_np[j] = ensemble_field.to_numpy() + ensemble_field_as_mars = tuple(sorted(ensemble_field_as_mars.items())) + assert ensemble_field_as_mars not in seen, ensemble_field_as_mars + seen.add(ensemble_field_as_mars) + mean_np = members_np.mean(axis=0) for j in range(n_numbers): @@ -84,18 +125,22 @@ def check_compatible(f1, f2, ignore=["number"]): x = c - m + e - if param in positive_clipping_variables: - warnings.warn(f"Clipping {param} to be positive") + if param in clip_variables: + # LOG.warning(f"Clipping {param} to be positive") x = np.maximum(x, 0) assert x.shape == e.shape, (x.shape, e.shape) - check_data_values(x, name=param) out.write(x, template=template) template = None + assert len(seen) == len(members), (len(seen), len(members)) + out.close() + if output is not None: + return path + from climetlab import load_source ds = load_source("file", path)