
Minor Touches for ScoreGradELBO #99

Open: wants to merge 17 commits into master

Conversation

Red-Portal (Member)

No description provided.

github-actions bot (Contributor) left a comment


Benchmark Results

| Benchmark suite | Current: 43e8581 | Previous: 045dd6d | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 15189643204 ns | 15232742752 ns | 1.00 |
| normal/RepGradELBO + STL/meanfield/ForwardDiff | 3680697154 ns | 3521333144 ns | 1.05 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 3402309335 ns | 3351996550 ns | 1.02 |
| normal/RepGradELBO + STL/fullrank/Zygote | 15122164918 ns | 14890644152 ns | 1.02 |
| normal/RepGradELBO + STL/fullrank/ForwardDiff | 4025298069 ns | 4021148178 ns | 1.00 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 6064492144 ns | 5984971083 ns | 1.01 |
| normal/RepGradELBO/meanfield/Zygote | 7366785984 ns | 7307252380 ns | 1.01 |
| normal/RepGradELBO/meanfield/ForwardDiff | 2717171989 ns | 2713184833 ns | 1.00 |
| normal/RepGradELBO/meanfield/ReverseDiff | 1576556708 ns | 1543720475 ns | 1.02 |
| normal/RepGradELBO/fullrank/Zygote | 7402608183 ns | 7281289173 ns | 1.02 |
| normal/RepGradELBO/fullrank/ForwardDiff | 2903606866 ns | 2793902508 ns | 1.04 |
| normal/RepGradELBO/fullrank/ReverseDiff | 2706291324 ns | 2670070163 ns | 1.01 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 23628310873 ns | 23687476333 ns | 1.00 |
| normal + bijector/RepGradELBO + STL/meanfield/ForwardDiff | 9036872886 ns | 9547772508 ns | 0.95 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 4842175602 ns | 4832381410 ns | 1.00 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 23114862714 ns | 23159256116 ns | 1.00 |
| normal + bijector/RepGradELBO + STL/fullrank/ForwardDiff | 10118363781 ns | 9758670064 ns | 1.04 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 7937031363 ns | 8015117017 ns | 0.99 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 14708504769 ns | 14583422835 ns | 1.01 |
| normal + bijector/RepGradELBO/meanfield/ForwardDiff | 8356263940 ns | 7683963206 ns | 1.09 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 2697208733 ns | 2844739654 ns | 0.95 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 14836373708 ns | 14727679032 ns | 1.01 |
| normal + bijector/RepGradELBO/fullrank/ForwardDiff | 7963593412 ns | 8213197170 ns | 0.97 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 4278527667 ns | 4256870182 ns | 1.01 |

This comment was automatically generated by workflow using github-action-benchmark.


codecov bot commented Oct 1, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.04%. Comparing base (d0efe02) to head (7370248).

Additional details and impacted files
@@            Coverage Diff             @@
##           master      #99      +/-   ##
==========================================
- Coverage   93.26%   92.04%   -1.22%     
==========================================
  Files          12       12              
  Lines         371      377       +6     
==========================================
+ Hits          346      347       +1     
- Misses         25       30       +5     

@Red-Portal (Member, Author)

@willtebbutt Hi, it seems only Mooncake is failing. Could you look into the error messages? If they don't quite make sense, I'll try to put together an MWE.

@willtebbutt (Member)

Aha! I've been waiting for an example where this happens -- I've been aware of this PartialTypeVar thing for a while, but haven't had a chance to dig into it properly. I'll figure out what's going on :)

@willtebbutt (Member)

This being said, if you're able to make an MWE that I can run easily on my machine, that would be great. Just whatever function it is that you're differentiating, because I'm not exactly sure how your tests are structured, and therefore where to find the correct function.

Red-Portal and others added 2 commits October 3, 2024 20:34
Sampling is now done out of the AD path.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@Red-Portal (Member, Author)

@arnauqb Hi, I shifted things around a bit. Could you take a look and see if you have any comments? Also, the default value for the baseline, which you set as 1, is somewhat curious. Shouldn't this be 0 for the control variate to have no effect? Or maybe I'm misunderstanding something.

@arnauqb (Contributor)

arnauqb commented Oct 4, 2024

Hi @Red-Portal, thanks for tidying this up; looks good to me. Yes, you are right that the baseline should be 0 to have no effect, not sure why I set it to 1 😅
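For context, a minimal illustrative sketch (not the package's code) of why a zero baseline is neutral: with a constant baseline b, the score-gradient summand is (f(z) - b) ∇ log q(z), and since E[∇ log q(z)] = 0, any constant b keeps the estimator unbiased, while b = 0 recovers the plain estimator.

using Distributions, ForwardDiff, Statistics

# Score-gradient estimate of d/dμ E_{z~N(μ,1)}[f(z)] with a constant
# baseline b as a control variate. Any constant b leaves the estimator
# unbiased; b = 0 is the plain (no-control-variate) estimator.
function score_grad(f, μ, b; n = 10_000)
    zs = rand(Normal(μ, 1.0), n)
    score(z) = ForwardDiff.derivative(m -> logpdf(Normal(m, 1.0), z), μ)
    return mean((f(z) - b) * score(z) for z in zs)
end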

@Red-Portal Red-Portal added this to the v0.3.0 milestone Oct 4, 2024
@arnauqb (Contributor)

arnauqb commented Oct 9, 2024

@Red-Portal After a bit of testing, I ran into an error when sampling from two priors like this:

@model function make_model(data)
    p1 ~ Dirichlet(3, 1.0)
    p2 ~ Dirichlet(2, 1.0)
    data ~ CustomDistribution(vcat(p1, p2))
end

and then transforming q using Bijectors.

When using the score estimator, I get an input mismatch (3 != 5) error when it tries to differentiate through the ScoreGradELBO objective. I believe something is happening inside the Bijectors extension where the priors somehow get flattened for q, but something is different when q_stop is passed.

To fix this, we can sample samples_stop outside the gradient calculation. That is, do samples_stop = rand(rng, q_stop, obj.n_samples) and then pass samples_stop through aux, which probably makes more sense anyway.
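A minimal sketch of the suggested pattern (illustrative, with hypothetical names; the actual objective code differs): draw the stopped samples before entering the AD path and thread them through aux, so the differentiated closure never calls rand.

using Zygote, Random

# Draw samples from q_stop *outside* the AD path and pass them via `aux`,
# so nothing inside the differentiated closure calls `rand`.
function grad_step_sketch(rng, ad_forward, params, q_stop, n_samples, aux0)
    samples_stop = rand(rng, q_stop, n_samples)       # not traced by AD
    aux = merge(aux0, (samples_stop = samples_stop,))
    return Zygote.withgradient(p -> ad_forward(p, aux), params)
end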

@Red-Portal (Member, Author)

@arnauqb I think that's how it is done in this PR now. However, that error is quite peculiar, and it appears that there might be something wrong with Bijectors. If you have time, could you try to come up with a MWE and take it to Bijectors?

@arnauqb (Contributor)

arnauqb commented Oct 9, 2024

The problem may be that I am using this: https://github.com/TuringLang/Turing.jl/blob/40a0d84b76e8e262e32618f83e6b895b34177d95/src/variational/advi.jl#L23
to do the automatic transformation:

using AdvancedVI
using ADTypes
using DynamicPPL
using DistributionsAD
using Distributions
using Bijectors
using Optimisers
using LinearAlgebra
using Zygote

function wrap_in_vec_reshape(f, in_size)
    vec_in_length = prod(in_size)
    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
    out_size = Bijectors.output_size(f, in_size)
    vec_out_length = prod(out_size)
    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
    return reshape_outer ∘ f ∘ reshape_inner
end

function Bijectors.bijector(
        model::DynamicPPL.Model, ::Val{sym2ranges} = Val(false);
        varinfo = DynamicPPL.VarInfo(model)
) where {sym2ranges}
    num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)])

    dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)

    num_ranges = sum([length(varinfo.metadata[sym].ranges)
                      for sym in keys(varinfo.metadata)])
    ranges = Vector{UnitRange{Int}}(undef, num_ranges)
    idx = 0
    range_idx = 1

    # ranges might be discontinuous => values are vectors of ranges rather than just ranges
    sym_lookup = Dict{Symbol, Vector{UnitRange{Int}}}()
    for sym in keys(varinfo.metadata)
        sym_lookup[sym] = Vector{UnitRange{Int}}()
        for r in varinfo.metadata[sym].ranges
            ranges[range_idx] = idx .+ r
            push!(sym_lookup[sym], ranges[range_idx])
            range_idx += 1
        end

        idx += varinfo.metadata[sym].ranges[end][end]
    end

    bs = map(tuple(dists...)) do d
        b = Bijectors.bijector(d)
        if d isa Distributions.UnivariateDistribution
            b
        else
            wrap_in_vec_reshape(b, size(d))
        end
    end

    if sym2ranges
        return (
            Bijectors.Stacked(bs, ranges),
            (; collect(zip(keys(sym_lookup), values(sym_lookup)))...)
        )
    else
        return Bijectors.Stacked(bs, ranges)
    end
end
##

function double_normal()
    return MvNormal([2.0, 3.0, 4.0], Diagonal(ones(3)))
end

@model function normal_model(data)
    p1 ~ filldist(Normal(0.0, 1.0), 2)
    p2 ~ Normal(0.0, 1.0)
    ps = vcat(p1, p2)
    for i in 1:size(data, 2)
        data[:, i] ~ MvNormal(ps, Diagonal(ones(3)))
    end
end

data = rand(double_normal(), 100)
model = normal_model(data)

##

d = 3
μ = zeros(d)
L = Diagonal(ones(d));
q = AdvancedVI.MeanFieldGaussian(μ, L)
optimizer = Optimisers.Adam(1e-3)

bijector_transf = inverse(bijector(model))
q_transformed = transformed(q, bijector_transf)
ℓπ = DynamicPPL.LogDensityFunction(model)
elbo = AdvancedVI.ScoreGradELBO(10, entropy = StickingTheLandingEntropy()) # this doesn't work
#elbo = AdvancedVI.RepGradELBO(10, entropy = StickingTheLandingEntropy()) # This works

q, _, stats, _ = AdvancedVI.optimize(
	ℓπ,
	elbo,
	q_transformed,
	10;
	adtype = AutoZygote(),
	optimizer = optimizer,
)

and stacktrace:

ERROR: output length mismatch (3 != 2)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] rrule(::typeof(error), ::String)
    @ ChainRules ~/.julia/packages/ChainRules/vdf7M/src/rulesets/Base/nondiff.jl:155
  [3] rrule(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::String)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6Pucz/src/rules.jl:138
  [4] chain_rrule
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:224 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [6] _pullback(ctx::Zygote.Context{false}, f::typeof(error), args::String)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [7] transform
    @ ~/.julia/packages/Bijectors/I8eRc/src/bijectors/stacked.jl:162 [inlined]
  [8] _pullback(::Zygote.Context{…}, ::typeof(transform), ::Stacked{…}, ::SubArray{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [9] Transform
    @ ~/.julia/packages/Bijectors/I8eRc/src/interface.jl:82 [inlined]
 [10] #84
    @ ~/.julia/packages/Bijectors/I8eRc/src/transformed_distribution.jl:221 [inlined]
 [11] #665
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/array.jl:188 [inlined]
 [12] iterate
    @ ./generator.jl:47 [inlined]
 [13] _collect(c::Base.OneTo{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:854
 [14] collect_similar
    @ ./array.jl:763 [inlined]
 [15] map
    @ ./abstractarray.jl:3285 [inlined]
 [16] ∇map(cx::Zygote.Context{…}, f::Bijectors.var"#84#85"{…}, args::Base.OneTo{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/array.jl:188
 [17] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/array.jl:214 [inlined]
 [18] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [19] rand
    @ ~/.julia/packages/Bijectors/I8eRc/src/transformed_distribution.jl:218 [inlined]
 [20] _pullback(::Zygote.Context{…}, ::typeof(rand), ::Random.TaskLocalRNG, ::MultivariateTransformed{…}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [21] estimate_scoregradelbo_ad_forward
    @ ~/code/AdvancedVI.jl/src/objectives/elbo/scoregradelbo.jl:102 [inlined]
 [22] _pullback(::Zygote.Context{…}, ::typeof(AdvancedVI.estimate_scoregradelbo_ad_forward), ::Vector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [23] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90
 [24] pullback(::Function, ::Vector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88
 [25] withgradient(::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:205
 [26] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/qcr6f/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:92 [inlined]
 [27] value_and_gradient!(f::Function, grad::Vector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::Vector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/qcr6f/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:105
 [28] value_and_gradient!
    @ ~/.julia/packages/DifferentiationInterface/qcr6f/src/fallbacks/no_prep.jl:67 [inlined]
 [29] value_and_gradient!(ad::AutoZygote, f::Function, x::Vector{…}, aux::@NamedTuple{…}, out::DiffResults.MutableDiffResult{…})
    @ AdvancedVI ~/code/AdvancedVI.jl/src/AdvancedVI.jl:46
 [30] estimate_gradient!(rng::Random.TaskLocalRNG, obj::ScoreGradELBO{…}, adtype::AutoZygote, out::DiffResults.MutableDiffResult{…}, prob::DynamicPPL.LogDensityFunction{…}, params::Vector{…}, restructure::Optimisers.Restructure{…}, state::Nothing)
    @ AdvancedVI ~/code/AdvancedVI.jl/src/objectives/elbo/scoregradelbo.jl:134
 [31] optimize(::Random.TaskLocalRNG, ::DynamicPPL.LogDensityFunction{…}, ::ScoreGradELBO{…}, ::MultivariateTransformed{…}, ::Int64; adtype::AutoZygote, optimizer::Adam, averager::NoAveraging, show_progress::Bool, state_init::@NamedTuple{}, callback::Nothing, prog::ProgressMeter.Progress)
    @ AdvancedVI ~/code/AdvancedVI.jl/src/optimize.jl:76
 [32] optimize
    @ ~/code/AdvancedVI.jl/src/optimize.jl:50 [inlined]
 [33] #optimize#26
    @ ~/code/AdvancedVI.jl/src/optimize.jl:127 [inlined]
 [34] top-level scope
    @ ~/code/AdvancedVI.jl/score_bug.jl:97

@Red-Portal (Member, Author)

Red-Portal commented Oct 14, 2024

Hi @arnauqb, sorry for the late reply. It seems like this happens only with Zygote, right?

Edit: Aha! This is because Zygote attempts to differentiate through rand. This should be addressed by this PR, since it no longer calls rand in the AD path.

#if @isdefined(Tapir)
# AD_scoregradelbo_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false)
#end
if @isdefined(Mooncake)
@yebai (Member) commented on the snippet above:

@Red-Portal, this conditional loading of Mooncake is no longer needed now that we have dropped support for Julia < 1.10.

@Red-Portal (Member, Author)

Hi @yebai! Yes, that part of the code is outdated, but it will be fixed by merging #129.

@yebai (Member)

yebai commented Oct 23, 2024

Everything is passing other than the Enzyme tests.

@arnauqb (Contributor)

arnauqb commented Oct 28, 2024

@Red-Portal I think this line

https://github.com/TuringLang/AdvancedVI.jl/blob/45b37c111ce05f2fb01f195b8f58237bdb3a66c5/src/objectives/elbo/scoregradelbo.jl#L88C5-L88C85

should be moved outside the gradient calculation (and passed through aux). Otherwise, the gradient estimator differentiates through the "simulator". If the simulator is not differentiable, then this is just the standard score estimator, but if it is, I think it would also be using pathwise gradients.
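For reference, one standard form of the score-function (REINFORCE) ELBO gradient treats the samples as constants, so only the $\nabla_\lambda \log q_\lambda$ factor carries a gradient:

$$
\nabla_\lambda \mathrm{ELBO}(\lambda)
= \mathbb{E}_{z \sim q_\lambda}\!\left[\left(\log \pi(z) - \log q_\lambda(z)\right)\nabla_\lambda \log q_\lambda(z)\right]
\approx \frac{1}{M} \sum_{m=1}^{M} \left(\log \pi(z_m) - \log q_\lambda(z_m)\right)\nabla_\lambda \log q_\lambda(z_m),
$$

with the $z_m$ drawn from $q_\lambda$ beforehand and held fixed during differentiation; if AD is instead allowed to trace through the sampler or the model, extra pathwise terms can leak into the estimate.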

@Red-Portal (Member, Author)

Red-Portal commented Oct 28, 2024

@arnauqb

Hi! Hmm, it wouldn't be using the path gradient here, since samples is calculated outside. My only concern would be the case where Zygote tries to differentiate through logpi anyway and throws an error if it can't, just like it did with rand(q). Did you see this happen?

@arnauqb (Contributor)

arnauqb commented Oct 28, 2024

Something similar. I have a custom Distribution for which I implement a logpdf that is calculated by sampling from the distribution and measuring a loss with respect to the value passed (think of it as doing generalized variational inference, if you are familiar with that). I have a custom Zygote rule for differentiating the distribution's rand, so when calling logdensity it goes through there.

Now, you are right that the samples don't carry a gradient, so ultimately this won't affect the results, I think, but taking it out may give a performance boost.
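To illustrate the kind of setup described above (a hedged sketch with hypothetical names, not @arnauqb's actual code): a distribution whose rand is location-shifted noise, with a hand-written rrule so Zygote can differentiate through sampling with respect to the parameters.

using Distributions, Random, ChainRulesCore

# A toy "simulator" distribution: z = θ + ε with ε ~ N(0, I).
struct SimDist{T<:Real} <: ContinuousMultivariateDistribution
    θ::Vector{T}
end
Base.length(d::SimDist) = length(d.θ)
function Distributions._rand!(rng::AbstractRNG, d::SimDist, x::AbstractVector{<:Real})
    x .= d.θ .+ randn(rng, length(d.θ))
    return x
end

# Custom reverse rule: since z = θ + ε, the pullback w.r.t. θ is the identity.
# With this rule in scope, Zygote will differentiate through `rand(rng, d)`.
function ChainRulesCore.rrule(::typeof(rand), rng::AbstractRNG, d::SimDist)
    z = rand(rng, d)
    rand_pullback(z̄) = (NoTangent(), NoTangent(), Tangent{typeof(d)}(; θ = z̄))
    return z, rand_pullback
end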

@Red-Portal (Member, Author)

I see. Ugh, it is annoying that Zygote can't be smarter. Anyway, thanks for mentioning this; I'll try to improve it.

@Red-Portal (Member, Author)

Red-Portal commented Nov 5, 2024

I modified the implementation so that the ScoreGradELBO objective targets the VarGrad objective¹. It seems to do much better already on the README example:

[figure "vargrad": performance on the README example]

This should also vastly simplify the interface, since the entropy estimator is part of the VarGrad objective. (A sketch of the VarGrad objective follows the footnote below.)

Footnotes

  1. Richter, L., Boustati, A., Nüsken, N., Ruiz, F., & Akyildiz, O. D. (2020). VarGrad: a low-variance gradient estimator for variational inference. Advances in Neural Information Processing Systems, 33, 13481-13492.
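For readers unfamiliar with VarGrad, a minimal illustrative sketch (assumed function names, not the package's implementation): the objective is the sample variance of the log-weights log q(z) - log π(z) over a batch; differentiating it with the samples held fixed yields a low-variance score-type gradient estimator (Richter et al., 2020).

using Statistics

# VarGrad-style objective: sample variance of the log-weights over a batch
# of samples `zs` drawn from q outside the AD path. `logq` and `logpi` are
# assumed log-density functions.
function vargrad_objective(logq, logpi, zs)
    ws = [logq(z) - logpi(z) for z in zs]  # log-weight per sample
    # The gradient of this sample variance w.r.t. q's parameters, with `zs`
    # held fixed, is a low-variance estimator of the ELBO gradient.
    return var(ws)
end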
