Minor Touches for ScoreGradELBO #99
base: master
Conversation
Benchmark Results
| Benchmark suite | Current: 43e8581 | Previous: 045dd6d | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 15189643204 ns | 15232742752 ns | 1.00 |
| normal/RepGradELBO + STL/meanfield/ForwardDiff | 3680697154 ns | 3521333144 ns | 1.05 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 3402309335 ns | 3351996550 ns | 1.02 |
| normal/RepGradELBO + STL/fullrank/Zygote | 15122164918 ns | 14890644152 ns | 1.02 |
| normal/RepGradELBO + STL/fullrank/ForwardDiff | 4025298069 ns | 4021148178 ns | 1.00 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 6064492144 ns | 5984971083 ns | 1.01 |
| normal/RepGradELBO/meanfield/Zygote | 7366785984 ns | 7307252380 ns | 1.01 |
| normal/RepGradELBO/meanfield/ForwardDiff | 2717171989 ns | 2713184833 ns | 1.00 |
| normal/RepGradELBO/meanfield/ReverseDiff | 1576556708 ns | 1543720475 ns | 1.02 |
| normal/RepGradELBO/fullrank/Zygote | 7402608183 ns | 7281289173 ns | 1.02 |
| normal/RepGradELBO/fullrank/ForwardDiff | 2903606866 ns | 2793902508 ns | 1.04 |
| normal/RepGradELBO/fullrank/ReverseDiff | 2706291324 ns | 2670070163 ns | 1.01 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 23628310873 ns | 23687476333 ns | 1.00 |
| normal + bijector/RepGradELBO + STL/meanfield/ForwardDiff | 9036872886 ns | 9547772508 ns | 0.95 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 4842175602 ns | 4832381410 ns | 1.00 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 23114862714 ns | 23159256116 ns | 1.00 |
| normal + bijector/RepGradELBO + STL/fullrank/ForwardDiff | 10118363781 ns | 9758670064 ns | 1.04 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 7937031363 ns | 8015117017 ns | 0.99 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 14708504769 ns | 14583422835 ns | 1.01 |
| normal + bijector/RepGradELBO/meanfield/ForwardDiff | 8356263940 ns | 7683963206 ns | 1.09 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 2697208733 ns | 2844739654 ns | 0.95 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 14836373708 ns | 14727679032 ns | 1.01 |
| normal + bijector/RepGradELBO/fullrank/ForwardDiff | 7963593412 ns | 8213197170 ns | 0.97 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 4278527667 ns | 4256870182 ns | 1.01 |
This comment was automatically generated by workflow using github-action-benchmark.
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master      #99      +/-   ##
==========================================
- Coverage   93.26%   92.04%   -1.22%
==========================================
  Files          12       12
  Lines         371      377       +6
==========================================
+ Hits          346      347       +1
- Misses         25       30       +5
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
@willtebbutt Hi, it seems only Mooncake is failing. Could you look into the error messages? If they don't quite make sense, I'll try to put together an MWE.
Aha! I've been waiting for an example where this happens -- I've been aware of this […]
This being said, if you're able to make an MWE that I can run easily on my machine, that would be great. Just whatever the function is that you're differentiating, because I'm not exactly sure how your tests are structured, and therefore where to find the correct function.
Sampling is now done out of the AD path.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
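(For context on "out of the AD path": the idea is that the Monte Carlo samples are drawn before the gradient call, so the AD backend only ever differentiates a deterministic function of the variational parameters with the samples held fixed. A toy sketch, not the AdvancedVI implementation; all names below are made up for illustration.)

```julia
using Zygote, Statistics

# Toy score-style objective; everything here is illustrative, not AdvancedVI API.
logp(z)    = -0.5 * (z - 2.0)^2      # unnormalized target log-density
logq(μ, z) = -0.5 * (z - μ)^2        # variational log-density (unit variance)

# The samples `zs` enter as plain numbers; only `μ` is differentiated.
objective(μ, zs) = mean(logq.(μ, zs) .* (logp.(zs) .- logq.(μ, zs)))

μ  = 0.0
zs = μ .+ randn(16)                              # sampling happens *before* the AD call,
g, = Zygote.gradient(m -> objective(m, zs), μ)   # so Zygote never sees `randn`
```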
@arnauqb Hi, I shifted things around a bit. Could you take a look and see if you have any comments? Also, the default value for the baseline, which you set to 1, is somewhat curious. Shouldn't this be 0 for the control variate to have no effect? Or maybe I'm misunderstanding something.
Hi @Red-Portal, thanks for tidying this up; looks good to me. Yes, you are right that the baseline should be 0 to have no effect, not sure why I set it to 1 😅
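(For anyone following along: with a score-gradient estimator, a constant baseline `b` is subtracted from the Monte Carlo weight; since E_q[∇ log q] = 0 the estimator stays unbiased for any `b`, but only `b = 0` removes the control variate term entirely. A toy sketch under those assumptions, not the package's estimator.)

```julia
using Statistics

logp(z)    = -0.5 * (z - 2.0)^2      # toy target log-density
logq(μ, z) = -0.5 * (z - μ)^2        # toy variational log-density (unit variance)

# Score-gradient estimate of ∇_μ ELBO with a constant baseline b:
#   (1/N) Σᵢ (logp(zᵢ) - logq(μ, zᵢ) - b) * ∇_μ logq(μ, zᵢ)
function score_grad(μ, zs, b)
    dlogq = zs .- μ                  # ∇_μ logq(μ, z) in closed form
    mean((logp.(zs) .- logq.(μ, zs) .- b) .* dlogq)
end

zs = randn(100_000)                  # samples from q with μ = 0
score_grad(0.0, zs, 0.0)             # baseline 0: plain score gradient
score_grad(0.0, zs, 1.0)             # same expectation, different variance
```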
@Red-Portal After a bit of testing, I ran into an error when sampling from two priors like this:

```julia
@model function make_model(data)
    p1 ~ Dirichlet(3, 1.0)
    p2 ~ Dirichlet(2, 1.0)
    data ~ CustomDistribution(vcat(p1, p2))
end
```

and then transforming […]. When using the score estimator I get an input mismatch error. To fix this, we can sample […]
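(Side note for readers: one way such an input mismatch can arise is that simplex bijectors change dimensionality, so a constrained draw and its unconstrained image need not have the same length. A minimal sketch; whether the lengths actually differ depends on the Bijectors.jl version, so treat it as a check rather than a statement of fact.)

```julia
using Bijectors, Distributions

p1, p2 = Dirichlet(3, 1.0), Dirichlet(2, 1.0)
b = Bijectors.Stacked([bijector(p1), bijector(p2)], [1:3, 4:5])

x = vcat(rand(p1), rand(p2))   # length-5 constrained draw
y = b(x)                       # its unconstrained representation
length(x), length(y)           # may differ if the simplex bijector drops a component
```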
@arnauqb I think that's how it is done in this PR now. However, that error is quite peculiar, and it appears that there might be something wrong with Bijectors. If you have time, could you try to come up with an MWE and take it to Bijectors?
The problem may be that I am using this: https://github.com/TuringLang/Turing.jl/blob/40a0d84b76e8e262e32618f83e6b895b34177d95/src/variational/advi.jl#L23

```julia
using AdvancedVI
using ADTypes
using DynamicPPL
using DistributionsAD
using Distributions
using Bijectors
using Optimisers
using LinearAlgebra
using Zygote

function wrap_in_vec_reshape(f, in_size)
    vec_in_length = prod(in_size)
    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
    out_size = Bijectors.output_size(f, in_size)
    vec_out_length = prod(out_size)
    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
    return reshape_outer ∘ f ∘ reshape_inner
end

function Bijectors.bijector(
    model::DynamicPPL.Model, ::Val{sym2ranges} = Val(false);
    varinfo = DynamicPPL.VarInfo(model)
) where {sym2ranges}
    num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)])
    dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)
    num_ranges = sum([length(varinfo.metadata[sym].ranges)
                      for sym in keys(varinfo.metadata)])
    ranges = Vector{UnitRange{Int}}(undef, num_ranges)
    idx = 0
    range_idx = 1

    # ranges might be discontinuous => values are vectors of ranges rather than just ranges
    sym_lookup = Dict{Symbol, Vector{UnitRange{Int}}}()
    for sym in keys(varinfo.metadata)
        sym_lookup[sym] = Vector{UnitRange{Int}}()
        for r in varinfo.metadata[sym].ranges
            ranges[range_idx] = idx .+ r
            push!(sym_lookup[sym], ranges[range_idx])
            range_idx += 1
        end
        idx += varinfo.metadata[sym].ranges[end][end]
    end

    bs = map(tuple(dists...)) do d
        b = Bijectors.bijector(d)
        if d isa Distributions.UnivariateDistribution
            b
        else
            wrap_in_vec_reshape(b, size(d))
        end
    end

    if sym2ranges
        return (
            Bijectors.Stacked(bs, ranges),
            (; collect(zip(keys(sym_lookup), values(sym_lookup)))...)
        )
    else
        return Bijectors.Stacked(bs, ranges)
    end
end

##

function double_normal()
    return MvNormal([2.0, 3.0, 4.0], Diagonal(ones(3)))
end

@model function normal_model(data)
    p1 ~ filldist(Normal(0.0, 1.0), 2)
    p2 ~ Normal(0.0, 1.0)
    ps = vcat(p1, p2)
    for i in 1:size(data, 2)
        data[:, i] ~ MvNormal(ps, Diagonal(ones(3)))
    end
end

data = rand(double_normal(), 100)
model = normal_model(data)

##

d = 3
μ = zeros(d)
L = Diagonal(ones(d));
q = AdvancedVI.MeanFieldGaussian(μ, L)
optimizer = Optimisers.Adam(1e-3)

bijector_transf = inverse(bijector(model))
q_transformed = transformed(q, bijector_transf)
ℓπ = DynamicPPL.LogDensityFunction(model)

elbo = AdvancedVI.ScoreGradELBO(10, entropy = StickingTheLandingEntropy()) # this doesn't work
#elbo = AdvancedVI.RepGradELBO(10, entropy = StickingTheLandingEntropy()) # this works

q, _, stats, _ = AdvancedVI.optimize(
    ℓπ,
    elbo,
    q_transformed,
    10;
    adtype = AutoZygote(),
    optimizer = optimizer,
)
```

and the stacktrace: […]
Hi @arnauqb, sorry for the late reply. It seems like this happens only on Zygote, right? Edit: Aha! This is because Zygote attempts to differentiate through […]
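(If the culprit is indeed Zygote pulling the sampling call into the pullback, one generic workaround — a sketch, not necessarily what this PR does — is to mark that call as non-differentiable via ChainRulesCore, which Zygote respects.)

```julia
using Zygote
using ChainRulesCore: @ignore_derivatives

logp(z)    = -0.5 * sum(abs2, z .- 2.0)   # toy target log-density
logq(μ, z) = -0.5 * sum(abs2, z .- μ)     # toy variational log-density

function objective(μ)
    # The sampling step is hidden from AD; only the log-densities are differentiated.
    z = @ignore_derivatives μ .+ randn(length(μ))
    logp(z) - logq(μ, z)
end

Zygote.gradient(objective, zeros(3))
```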
```julia
#if @isdefined(Tapir)
#    AD_scoregradelbo_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false)
#end
if @isdefined(Mooncake)
```
@Red-Portal, this conditional loading of Mooncake is no longer needed now that we have dropped support for Julia < 1.10.
Everything else is passing other than the Enzyme tests.
@Red-Portal I think this line should be moved outside the gradient calculation (and passed through […]).
Hi! Hmm, it wouldn't be using the path gradient here since […]
Something similar. I have a custom Distribution for which I implement a […]. Now, you are right that samples don't carry a gradient, so ultimately this won't affect the results, I think, but it may be a performance boost to take it out.
I see. Ugh, it is annoying that Zygote can't be smarter. Anyway, thanks for mentioning this; I'll try to improve it.
….jl into tidy_scoregradelbo
I modified the implementation so that the objective is now the VarGrad objective. This should also vastly simplify the interface, since the entropy estimator is part of the VarGrad objective.
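(For reference: the VarGrad objective — Richter et al., "VarGrad: A Low-Variance Gradient Estimator for Variational Inference", NeurIPS 2020 — is the empirical variance of the log-weights log p(x, z) − log q(z) over a batch, so the entropy contribution is baked into the objective rather than being a separate estimator. A toy sketch under that reading, not the AdvancedVI implementation.)

```julia
using Statistics

logjoint(z) = -0.5 * (z - 2.0)^2     # toy log p(x, z)
logq(μ, z)  = -0.5 * (z - μ)^2       # toy log q(z; μ)

# VarGrad-style surrogate: sample variance of the log-weights. Differentiating
# this w.r.t. μ (with the samples held fixed) gives a score-type ELBO gradient
# with a built-in, baseline-like variance reduction.
vargrad_objective(μ, zs) = var(logjoint.(zs) .- logq.(μ, zs))

zs = randn(32)                       # samples from q with μ = 0
vargrad_objective(0.0, zs)
```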