diff --git a/deepr/workflow.py b/deepr/workflow.py index 5d3ed99..390571e 100644 --- a/deepr/workflow.py +++ b/deepr/workflow.py @@ -421,6 +421,7 @@ def generate_predictions(self): hour_embed_type=self.inference_config.get("hour_embed_type", "class"), hour_embed_dim=self.inference_config.get("hour_embed_dim", 64), instance_norm=self.inference_config.get("instance_norm", False), + learn_residuals=self.inference_config.get("learn_residuals", False), ) pipe.to(self.inference_config["device"]) generate_data.generate_validation_dataset(