Skip to content

Commit

Permalink
perturbations
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 6, 2024
1 parent 7cc1eea commit 7d49ac1
Showing 1 changed file with 73 additions and 28 deletions.
101 changes: 73 additions & 28 deletions src/anemoi/datasets/compute/perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,59 +7,95 @@
# nor does it submit to any jurisdiction.
#

import warnings
import logging

import numpy as np
from climetlab.core.temporary import temp_file
from climetlab.readers.grib.output import new_grib_output

from anemoi.datasets.create.check import check_data_values
from anemoi.datasets.create.functions import assert_is_fieldset

LOG = logging.getLogger(__name__)

# Default value of the ``clip_variables`` argument of ``perturbations()``:
# any field whose ``param`` is listed here has its recentred values
# clipped at zero (``np.maximum(x, 0)``) before being written out.
CLIP_VARIABLES = (
    "q", "cp", "lsp", "tp", "sf",
    "swl4", "swl3", "swl2", "swl1",
)

# MARS request keys that are excluded from the center/ensemble comparison
# performed by ``check_compatible``.
SKIP = ("class", "stream", "type", "number", "expver", "_leg_number")


def check_compatible(f1, f2, center_field_as_mars, ensemble_field_as_mars):
    """Assert that a center field and an ensemble field describe the same data.

    Parameters
    ----------
    f1, f2
        The two fields to compare. Each must expose ``mars_grid``,
        ``mars_area`` and ``shape`` attributes and a ``metadata(key)`` method.
    center_field_as_mars, ensemble_field_as_mars
        Mappings of MARS request keys to values for ``f1`` and ``f2``
        respectively (the caller passes the result of ``field.as_mars()``).

    Raises
    ------
    AssertionError
        If grid, area, shape or ``valid_datetime`` differ, or if any MARS key
        not listed in ``SKIP`` is missing from one mapping or has a different
        value in the two mappings. The assertion payload carries the
        offending key and/or the two values.
    """
    assert f1.mars_grid == f2.mars_grid, (f1.mars_grid, f2.mars_grid)
    assert f1.mars_area == f2.mars_area, (f1.mars_area, f2.mars_area)
    assert f1.shape == f2.shape, (f1.shape, f2.shape)

    # "valid_datetime" is not part of the *_as_mars mappings, so it has to be
    # compared through the field metadata.
    valid_1 = f1.metadata("valid_datetime")
    valid_2 = f2.metadata("valid_datetime")
    assert valid_1 == valid_2, (valid_1, valid_2)

    # Walk the union of keys so that a key present on one side only is
    # detected. Sorting makes the first reported mismatch deterministic.
    all_keys = set(center_field_as_mars) | set(ensemble_field_as_mars)
    for k in sorted(all_keys, key=str):
        if k in SKIP:
            continue

        # A key missing from one mapping is an incompatibility; report it as
        # an AssertionError like every other mismatch rather than letting a
        # bare KeyError escape from the indexing below.
        assert k in center_field_as_mars and k in ensemble_field_as_mars, (
            k,
            center_field_as_mars.get(k, "<missing>"),
            ensemble_field_as_mars.get(k, "<missing>"),
        )

        assert center_field_as_mars[k] == ensemble_field_as_mars[k], (
            k,
            center_field_as_mars[k],
            ensemble_field_as_mars[k],
        )


def perturbations(
*,
members,
center,
positive_clipping_variables=[
"q",
"cp",
"lsp",
"tp",
], # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ?
clip_variables=CLIP_VARIABLES,
output=None,
):

keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]

def check_compatible(f1, f2, ignore=["number"]):
for k in keys + ["grid", "shape"]:
if k in ignore:
continue
assert f1.metadata(k) == f2.metadata(k), (k, f1.metadata(k), f2.metadata(k))
number_list = members.unique_values("number")["number"]
n_numbers = len(number_list)

print(f"Retrieving ensemble data with {members}")
print(f"Retrieving center data with {center}")
assert None not in number_list

LOG.info("Ordering fields")
members = members.order_by(*keys)
center = center.order_by(*keys)

number_list = members.unique_values("number")["number"]
n_numbers = len(number_list)
LOG.info("Done")

if len(center) * n_numbers != len(members):
print(len(center), n_numbers, len(members))
LOG.error("%s %s %s", len(center), n_numbers, len(members))
for f in members:
print("Member: ", f)
LOG.error("Member: %r", f)
for f in center:
print("Center: ", f)
LOG.error("Center: %r", f)
raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}")

# prepare output tmp file so we can read it back
tmp = temp_file()
path = tmp.path
if output is None:
# prepare output tmp file so we can read it back
tmp = temp_file()
path = tmp.path
else:
tmp = None
path = output

out = new_grib_output(path)

seen = set()

for i, center_field in enumerate(center):
param = center_field.metadata("param")
center_field_as_mars = center_field.as_mars()

# load the center field
center_np = center_field.to_numpy()
Expand All @@ -69,9 +105,14 @@ def check_compatible(f1, f2, ignore=["number"]):

for j in range(n_numbers):
ensemble_field = members[i * n_numbers + j]
check_compatible(center_field, ensemble_field)
ensemble_field_as_mars = ensemble_field.as_mars()
check_compatible(center_field, ensemble_field, center_field_as_mars, ensemble_field_as_mars)
members_np[j] = ensemble_field.to_numpy()

ensemble_field_as_mars = tuple(sorted(ensemble_field_as_mars.items()))
assert ensemble_field_as_mars not in seen, ensemble_field_as_mars
seen.add(ensemble_field_as_mars)

mean_np = members_np.mean(axis=0)

for j in range(n_numbers):
Expand All @@ -84,18 +125,22 @@ def check_compatible(f1, f2, ignore=["number"]):

x = c - m + e

if param in positive_clipping_variables:
warnings.warn(f"Clipping {param} to be positive")
if param in clip_variables:
# LOG.warning(f"Clipping {param} to be positive")
x = np.maximum(x, 0)

assert x.shape == e.shape, (x.shape, e.shape)

check_data_values(x, name=param)
out.write(x, template=template)
template = None

assert len(seen) == len(members), (len(seen), len(members))

out.close()

if output is not None:
return path

from climetlab import load_source

ds = load_source("file", path)
Expand Down

0 comments on commit 7d49ac1

Please sign in to comment.