diff --git a/deepr/validation/generate_data.py b/deepr/validation/generate_data.py index 856ae8b..b635915 100644 --- a/deepr/validation/generate_data.py +++ b/deepr/validation/generate_data.py @@ -142,10 +142,9 @@ def predict_xr( "sample) at low-res & high-res orography.", } elif isinstance(model, cDDPMPipeline): - hour_emb = get_hour_embedding(times[:, :1], "class", 24).to(model.device) prediction = model( images=era5, - class_labels=hour_emb, + class_labels=times[:, :1].to(model.device), eta=config["eta"], num_inference_steps=config["inference_steps"], generator=torch.manual_seed(config.get("seed", 2023)),