Minor Touches for ScoreGradELBO #99

Status: Open. Wants to merge 17 commits into base: master.
Showing changes from all commits.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.7'
- '1.10'
- '1'
os:
- ubuntu-latest
- macOS-latest
Expand Down
7 changes: 0 additions & 7 deletions src/objectives/elbo/entropy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,3 @@ function estimate_entropy(
-logpdf(q, mc_sample)
end
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end
7 changes: 7 additions & 0 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ function Base.show(io::IO, obj::RepGradELBO)
return print(io, ")")
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end

function estimate_energy_with_samples(prob, samples)
return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
end
Expand Down
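The moved function supports the sticking-the-landing (STL) estimator: when STL is selected, the entropy is evaluated through a copy of the variational distribution whose parameters are detached from the AD graph (`q_stop`), which drops the high-variance score term from the entropy gradient. A minimal self-contained sketch of the dispatch pattern behind `maybe_stop_entropy_score`, with locally defined stand-in types (an illustrative assumption, not the package's verbatim code):

```julia
# Sketch: STL evaluates the entropy with the parameter-detached q_stop, so no
# gradient flows through the entropy term's score function; every other
# estimator differentiates through q as usual.
abstract type AbstractEntropyEstimator end
struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
struct MonteCarloEntropy <: AbstractEntropyEstimator end

maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
```

Dispatching on the estimator type keeps `estimate_entropy_maybe_stl` agnostic about which estimator is in use.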
108 changes: 22 additions & 86 deletions src/objectives/elbo/scoregradelbo.jl
Original file line number Diff line number Diff line change
@@ -1,113 +1,47 @@

"""
ScoreGradELBO(n_samples; kwargs...)

Evidence lower-bound objective computed with score function gradients.
```math
\\begin{aligned}
\\nabla_{\\lambda} \\mathrm{ELBO}\\left(\\lambda\\right)
&\\=
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
\\log \\pi\\left(z\\right) \\nabla_{\\lambda} \\log q_{\\lambda}(z)
\\right]
+ \\mathbb{H}\\left(q_{\\lambda}\\right),
\\end{aligned}
```

To reduce the variance of the gradient estimator, we use a baseline computed from a running average of the previous ELBO values and subtract it from the objective.

```math
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
\\nabla_{\\lambda} \\log q_{\\lambda}(z) \\left(\\pi\\left(z\\right) - \\beta\\right)
\\right]
```
Evidence lower-bound objective computed with score function gradient with the VarGrad objective, also known as the leave-one-out control variate.

# Arguments
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.

# Keyword Arguments
- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`)
- `baseline_window_size::Int`: The window size to use to compute the baseline. (Default: `10`)
- `baseline_history::Vector{Float64}`: The history of the baseline. (Default: `Float64[]`)
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the VarGrad objective.

# Requirements
- The variational approximation ``q_{\\lambda}`` implements `rand` and `logpdf`.
- `logpdf(q, x)` must be differentiable with respect to `q` by the selected AD backend.
- The target distribution and the variational approximation have the same support.

Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
"""
struct ScoreGradELBO{EntropyEst<:AbstractEntropyEstimator} <:
AdvancedVI.AbstractVariationalObjective
entropy::EntropyEst
struct ScoreGradELBO <: AbstractVariationalObjective
n_samples::Int
baseline_window_size::Int
baseline_history::Vector{Float64}
end

function ScoreGradELBO(
n_samples::Int;
entropy::AbstractEntropyEstimator=ClosedFormEntropy(),
baseline_window_size::Int=10,
baseline_history::Vector{Float64}=Float64[],
)
return ScoreGradELBO(entropy, n_samples, baseline_window_size, baseline_history)
end

function Base.show(io::IO, obj::ScoreGradELBO)
print(io, "ScoreGradELBO(entropy=")
print(io, obj.entropy)
print(io, ", n_samples=")
print(io, "ScoreGradELBO(n_samples=")
print(io, obj.n_samples)
print(io, ", baseline_window_size=")
print(io, obj.baseline_window_size)
return print(io, ")")
end

function compute_control_variate_baseline(history, window_size)
if length(history) == 0
return 1.0
end
min_index = max(1, length(history) - window_size)
return mean(history[min_index:end])
end

function estimate_energy_with_samples(
prob, samples_stop, samples_logprob, samples_logprob_stop, baseline
)
fv = Base.Fix1(LogDensityProblems.logdensity, prob).(eachsample(samples_stop))
fv_mean = mean(fv)
score_grad = mean(@. samples_logprob * (fv - baseline))
score_grad_stop = mean(@. samples_logprob_stop * (fv - baseline))
return fv_mean + (score_grad - score_grad_stop)
end

function estimate_objective(
rng::Random.AbstractRNG, obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples
)
samples, entropy = reparam_with_entropy(rng, q, q, obj.n_samples, obj.entropy)
energy = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
return mean(energy) + entropy
samples = rand(rng, q, n_samples)
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
return mean(ℓπ - ℓq)
end

function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_samples)
return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end

function estimate_scoregradelbo_ad_forward(params′, aux)
@unpack rng, obj, problem, adtype, restructure, q_stop = aux
baseline = compute_control_variate_baseline(
obj.baseline_history, obj.baseline_window_size
)
@unpack rng, obj, logprob, adtype, restructure, samples = aux
q = restructure_ad_forward(adtype, restructure, params′)
samples_stop = rand(rng, q_stop, obj.n_samples)
entropy = estimate_entropy_maybe_stl(obj.entropy, samples_stop, q, q_stop)
samples_logprob = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
samples_logprob_stop = logpdf.(Ref(q_stop), AdvancedVI.eachsample(samples_stop))
energy = estimate_energy_with_samples(
problem, samples_stop, samples_logprob, samples_logprob_stop, baseline
)
elbo = energy + entropy
return -elbo
ℓπ = logprob
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
f = ℓq - ℓπ
return var(f) / 2
end

function AdvancedVI.estimate_gradient!(
Expand All @@ -120,20 +54,22 @@ function AdvancedVI.estimate_gradient!(
restructure,
state,
)
q_stop = restructure(params)
q = restructure(params)
samples = rand(rng, q, obj.n_samples)
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
aux = (
rng=rng,
adtype=adtype,
obj=obj,
problem=prob,
logprob=ℓπ,
restructure=restructure,
q_stop=q_stop,
samples=samples,
)
AdvancedVI.value_and_gradient!(
adtype, estimate_scoregradelbo_ad_forward, params, aux, out
)
nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)
push!(obj.baseline_history, -nelbo)
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
elbo = mean(ℓπ - ℓq)
stat = (elbo=elbo,)
return out, nothing, stat
end
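The `var(f) / 2` forward pass above is the VarGrad objective (Richter et al., 2020), which replaces the explicit baseline machinery of the old implementation. Writing ``f(z) = \log q_{\lambda}(z) - \log \pi(z)``, with the samples ``z_1, \dots, z_N \sim q_{\lambda}`` drawn before differentiation and passed in through `aux` (so only `ℓq` carries parameter dependence), differentiating the sample variance reproduces the score gradient with the leave-one-out sample mean as the baseline:

```math
\nabla_{\lambda}\, \frac{1}{2} \widehat{\mathrm{Var}}(f)
= \frac{1}{N-1} \sum_{i=1}^{N} \left(f(z_i) - \bar{f}\right) \nabla_{\lambda} \log q_{\lambda}(z_i)
= \frac{1}{N} \sum_{i=1}^{N} \left(f(z_i) - \bar{f}_{-i}\right) \nabla_{\lambda} \log q_{\lambda}(z_i),
```

where ``\bar{f}`` is the sample mean and ``\bar{f}_{-i}`` is the mean over all samples except ``z_i``. This is why the removed `baseline_history` machinery is no longer needed: the leave-one-out mean plays the role of the baseline ``\beta``. A standalone toy check of this identity (the target, parameterization, and sample size below are illustrative assumptions, not part of the PR):

```julia
using Distributions, ForwardDiff, Random, Statistics

logπ(z) = logpdf(Normal(1.0, 2.0), z)         # toy target density

function vargrad(params, samples, ℓπ)
    q = Normal(params[1], exp(params[2]))     # location-scale parameterization
    ℓq = logpdf.(Ref(q), samples)             # only ℓq depends on params
    f = ℓq - ℓπ
    return var(f) / 2                         # the VarGrad objective
end

rng     = Random.MersenneTwister(1)
params  = [0.0, 0.0]
samples = rand(rng, Normal(params[1], exp(params[2])), 16)
ℓπs     = logπ.(samples)
g = ForwardDiff.gradient(p -> vargrad(p, samples, ℓπs), params)

# Manual leave-one-out score-gradient estimator; matches g up to floating point.
q   = Normal(params[1], exp(params[2]))
f   = logpdf.(Ref(q), samples) - ℓπs
∇ℓq = [ForwardDiff.gradient(p -> logpdf(Normal(p[1], exp(p[2])), z), params) for z in samples]
g_loo = sum((f .- mean(f)) .* ∇ℓq) / (length(f) - 1)
```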
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ if @isdefined(Enzyme)
)
end

@testset "inference ScoreGradELBO VILocationScale" begin
@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in
Expand Down
12 changes: 4 additions & 8 deletions test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,16 @@ if @isdefined(Mooncake)
AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

#if @isdefined(Enzyme)
# AD_scoregradelbo_distributionsad[:Enzyme] = AutoEnzyme()
#end
if @isdefined(Enzyme)
AD_scoregradelbo_distributionsad[:Enzyme] = AutoEnzyme()
end

@testset "inference ScoreGradELBO DistributionsAD" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
(modelname, modelconstr) in Dict(:Normal => normal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(n_montecarlo)),
(adbackname, adtype) in AD_scoregradelbo_distributionsad

seed = (0x38bef07cf9cc549d)
Expand Down
6 changes: 1 addition & 5 deletions test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ end
(modelname, modelconstr) in
Dict(:Normal => normal_meanfield, :Normal => normal_fullrank),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo),
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(n_montecarlo)),
(adbackname, adtype) in AD_locationscale

seed = (0x38bef07cf9cc549d)
Expand Down
12 changes: 4 additions & 8 deletions test/inference/scoregradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ AD_scoregradelbo_locationscale_bijectors = Dict(
#:Zygote => AutoZygote(),
)

#if @isdefined(Tapir)
# AD_scoregradelbo_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false)
#end
if @isdefined(Mooncake)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Red-Portal, these conditional loading of Mooncake is no longer needed now that we have dropped support for Julia < 1.10

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @yebai ! Yes that part of the code is outdated. But it will be fixed by merging #129.

AD_scoregradelbo_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=nothing)
end

if @isdefined(Enzyme)
AD_scoregradelbo_locationscale_bijectors[:Enzyme] = AutoEnzyme()
Expand All @@ -19,11 +19,7 @@ end
(modelname, modelconstr) in
Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
#:ScoreGradELBOClosedFormEntropy => ScoreGradELBO(n_montecarlo), # not supported yet.
:ScoreGradELBOStickingTheLanding =>
ScoreGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
),
(objname, objective) in Dict(:ScoreGradELBO => ScoreGradELBO(n_montecarlo)),
(adbackname, adtype) in AD_scoregradelbo_locationscale_bijectors

seed = (0x38bef07cf9cc549d)
Expand Down
17 changes: 17 additions & 0 deletions test/interface/scoregradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,21 @@ using Test
elbo = estimate_objective(obj, q0, model; n_samples=10^4)
@test elbo ≈ elbo_ref rtol = 0.2
end

@testset "baseline_window" begin
T = 100
adtype = AutoForwardDiff()

obj = ScoreGradELBO(10)
_, _, stats, _ = optimize(rng, model, obj, q0, T; show_progress=false, adtype)
@test isfinite(last(stats).elbo)

obj = ScoreGradELBO(10; baseline_window_size=0)
_, _, stats, _ = optimize(rng, model, obj, q0, T; show_progress=false, adtype)
@test isfinite(last(stats).elbo)

obj = ScoreGradELBO(10; baseline_window_size=1)
_, _, stats, _ = optimize(rng, model, obj, q0, T; show_progress=false, adtype)
@test isfinite(last(stats).elbo)
end
end
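For orientation, `estimate_objective` with the simplified `ScoreGradELBO` only needs `rand` and `logpdf` on `q` plus a `LogDensityProblems.logdensity` method on the target, as stated in the revised docstring. A self-contained sketch (the toy target and dimension are illustrative assumptions, not part of the test suite):

```julia
using AdvancedVI, Distributions, LinearAlgebra, LogDensityProblems

# Toy target: a 2D standard normal exposed through the LogDensityProblems API.
struct ToyTarget end
LogDensityProblems.logdensity(::ToyTarget, z) = -sum(abs2, z) / 2 - log(2π)
LogDensityProblems.dimension(::ToyTarget) = 2

obj = ScoreGradELBO(10)
q0  = MvNormal(zeros(2), I)  # identical to the target, so the ELBO should be ≈ 0
elbo = estimate_objective(obj, q0, ToyTarget(); n_samples=10^4)
```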