Skip to content

Commit

Permalink
Support DM to learn residuals
Browse files Browse the repository at this point in the history
  • Loading branch information
Mario Santa Cruz Lopez committed Oct 24, 2023
1 parent f335b0c commit 1e67b17
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deepr/model/conditional_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
scheduler,
obs_model=None,
baseline_interpolation_method: Optional[str] = "bicubic",
learn_residuals: Optional[bool] = False,
hour_embed_type: [Optional] = "class",
hour_embed_dim: Optional[int] = 64,
instance_norm: Optional[bool] = False,
Expand All @@ -41,6 +42,7 @@ def __init__(
self.hour_embed_type = hour_embed_type
self.hour_embed_dim = hour_embed_dim
self.instance_norm = instance_norm
self.learn_residuals = learn_residuals
self.register_modules(unet=unet, scheduler=scheduler, obs_model=obs_model)

@torch.no_grad()
Expand Down Expand Up @@ -143,6 +145,9 @@ def __call__(
intermediate_images.append(latents.cpu())
intermediate_images = torch.cat(intermediate_images, dim=1)

if self.learn_residuals:
latents = latents + up_images

if self.instance_norm:
latents = latents * s + m

Expand Down
1 change: 1 addition & 0 deletions deepr/model/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class TrainingConfig:
save_image_epochs: Optional[int] = None
save_model_epochs: Optional[int] = None
instance_norm: Optional[bool] = False
learn_residuals: Optional[bool] = False
hour_embed_type: str = "none"
hour_embed_size: int = 64
device: str = "cuda"
Expand Down
7 changes: 7 additions & 0 deletions deepr/model/diffusion_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def train_diffusion(
era5 = (era5 - m) / s
cerra = (cerra - m) / s

if config.learn_residuals:
cerra = cerra - era5

# Add noise to the clean images according to the noise magnitude at each t
noisy_images = noise_scheduler.add_noise(cerra, noise, timesteps)

Expand Down Expand Up @@ -252,6 +255,9 @@ def train_diffusion(
era5 = (era5 - m) / s
cerra = (cerra - m) / s

if config.learn_residuals:
cerra = cerra - era5

noisy_images = noise_scheduler.add_noise(cerra, noise, timesteps)

# Predict the noise residual
Expand Down Expand Up @@ -299,6 +305,7 @@ def train_diffusion(
output_dir=config.output_dir,
epoch=epoch + 1,
obs_model=obs_model,
learn_residuals=config.learn_residuals,
instance_norm=config.instance_norm,
hour_embed_type=config.hour_embed_type,
hour_embed_dim=config.hour_embed_dim,
Expand Down

0 comments on commit 1e67b17

Please sign in to comment.