Skip to content

Commit

Permalink
Adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Dec 22, 2023
1 parent a611244 commit 349d282
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 133 deletions.
1 change: 1 addition & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ defmodule Bumblebee do

@diffusers_class_to_scheduler %{
"DDIMScheduler" => Bumblebee.Diffusion.DdimScheduler,
"LCMScheduler" => Bumblebee.Diffusion.LcmScheduler,
"PNDMScheduler" => Bumblebee.Diffusion.PndmScheduler
}

Expand Down
19 changes: 11 additions & 8 deletions lib/bumblebee/diffusion/ddim_scheduler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ defmodule Bumblebee.Diffusion.DdimScheduler do
default: true,
doc: """
whether to clip the predicted denoised sample ($x_0$ in Equation (12)) into $[-1, 1]$
for numerical stability.
for numerical stability
"""
],
rederive_noise: [
Expand Down Expand Up @@ -188,19 +188,22 @@ defmodule Bumblebee.Diffusion.DdimScheduler do
end
end

beta_bar_t = 1 - alpha_bar_t
beta_bar_t_prev = 1 - alpha_bar_t_prev

{pred_denoised_sample, noise} =
case scheduler.prediction_type do
:noise ->
pred_denoised_sample =
(sample - Nx.sqrt(1 - alpha_bar_t) * prediction) / Nx.sqrt(alpha_bar_t)
(sample - Nx.sqrt(beta_bar_t) * prediction) / Nx.sqrt(alpha_bar_t)

{pred_denoised_sample, prediction}

:angular_velocity ->
pred_denoised_sample =
Nx.sqrt(alpha_bar_t) * sample - Nx.sqrt(1 - alpha_bar_t) * prediction
Nx.sqrt(alpha_bar_t) * sample - Nx.sqrt(beta_bar_t) * prediction

noise = Nx.sqrt(alpha_bar_t) * prediction + Nx.sqrt(1 - alpha_bar_t) * sample
noise = Nx.sqrt(alpha_bar_t) * prediction + Nx.sqrt(beta_bar_t) * sample
{pred_denoised_sample, noise}
end

Expand All @@ -214,23 +217,23 @@ defmodule Bumblebee.Diffusion.DdimScheduler do
# See Equation (16)
sigma_t =
scheduler.eta *
Nx.sqrt((1 - alpha_bar_t_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_t_prev))
Nx.sqrt((beta_bar_t_prev) / (beta_bar_t) * (1 - alpha_bar_t / alpha_bar_t_prev))

noise =
if scheduler.rederive_noise do
# Re-derive the noise as in GLIDE
(sample - Nx.sqrt(alpha_bar_t) * pred_denoised_sample) / Nx.sqrt(1 - alpha_bar_t)
(sample - Nx.sqrt(alpha_bar_t) * pred_denoised_sample) / Nx.sqrt(beta_bar_t)
else
noise
end

pred_sample_direction = Nx.sqrt(1 - alpha_bar_t_prev - Nx.pow(sigma_t, 2)) * noise
pred_sample_direction = Nx.sqrt(beta_bar_t_prev - Nx.pow(sigma_t, 2)) * noise

prev_sample = Nx.sqrt(alpha_bar_t_prev) * pred_denoised_sample + pred_sample_direction

{prev_sample, next_key} =
if scheduler.eta > 0 do
{rand, next_key} = Nx.Random.normal(state.prng_key, prev_sample)
{rand, next_key} = Nx.Random.normal(state.prng_key, shape: Nx.shape(prev_sample))
out = prev_sample + sigma_t * rand
{out, next_key}
else
Expand Down
Loading

0 comments on commit 349d282

Please sign in to comment.