diff --git a/Project.toml b/Project.toml index 9d30e740..04647727 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.3.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -39,7 +38,6 @@ ADTypes = "0.1, 0.2, 1" Accessors = "0.1" Bijectors = "0.13" ChainRulesCore = "1.16" -DiffResults = "1" Distributions = "0.25.87" DocStringExtensions = "0.8, 0.9" Enzyme = "0.12" diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl index 8333299f..02da1c28 100644 --- a/ext/AdvancedVIEnzymeExt.jl +++ b/ext/AdvancedVIEnzymeExt.jl @@ -4,23 +4,22 @@ module AdvancedVIEnzymeExt if isdefined(Base, :get_extension) using Enzyme using AdvancedVI - using AdvancedVI: ADTypes, DiffResults + using AdvancedVI: ADTypes else using ..Enzyme using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults + using ..AdvancedVI: ADTypes end # Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +function AdvancedVI.value_and_gradient( + ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T} ) where {T<:Real} y = f(θ) - DiffResults.value!(out, y) - ∇θ = DiffResults.gradient(out) + ∇θ = similar(θ) fill!(∇θ, zero(T)) Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) - return out + ∇θ, y end end diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index 5949bdf8..80be06cc 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -4,26 +4,26 @@ module AdvancedVIForwardDiffExt if isdefined(Base, :get_extension) using ForwardDiff using AdvancedVI - using AdvancedVI: ADTypes, DiffResults + using AdvancedVI: ADTypes else using ..ForwardDiff using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults + using ..AdvancedVI: ADTypes end getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} +function AdvancedVI.value_and_gradient( + ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{<:Real} +) chunk_size = getchunksize(ad) config = if isnothing(chunk_size) ForwardDiff.GradientConfig(f, θ) else ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) end - ForwardDiff.gradient!(out, f, θ, config) - return out + g = ForwardDiff.gradient(f, θ, config) + g, f(θ) end end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index 520cd9ff..a899d36d 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -3,21 +3,21 @@ module AdvancedVIReverseDiffExt if isdefined(Base, :get_extension) using AdvancedVI - using AdvancedVI: ADTypes, DiffResults + using AdvancedVI: ADTypes using ReverseDiff else using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults + using ..AdvancedVI: ADTypes using ..ReverseDiff end # ReverseDiff without compiled tape -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +function 
AdvancedVI.value_and_gradient( + ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real} ) tp = ReverseDiff.GradientTape(f, θ) - ReverseDiff.gradient!(out, tp, θ) - return out + g = ReverseDiff.gradient!(tp, θ) + g, f(θ) end end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index 7b8f8817..c94c983f 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -3,22 +3,18 @@ module AdvancedVIZygoteExt if isdefined(Base, :get_extension) using AdvancedVI - using AdvancedVI: ADTypes, DiffResults + using AdvancedVI: ADTypes using Zygote else using ..AdvancedVI - using ..AdvancedVI: ADTypes, DiffResults + using ..AdvancedVI: ADTypes using ..Zygote end -function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult -) +function AdvancedVI.value_and_gradient(ad::ADTypes.AutoZygote, f, θ) y, back = Zygote.pullback(f, θ) ∇θ = back(one(y)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, only(∇θ)) - return out + only(∇θ), y end end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9fc986d3..4f804a33 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -16,7 +16,7 @@ using LinearAlgebra using LogDensityProblems -using ADTypes, DiffResults +using ADTypes using ChainRulesCore using FillArrays @@ -25,17 +25,43 @@ using StatsBase # derivatives """ - value_and_gradient!(ad, f, θ, out) + value_and_gradient(ad, f, θ) -Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`. +Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`. # Arguments - `ad::ADTypes.AbstractADType`: Automatic differentiation backend. - `f`: Function subject to differentiation. - `θ`: The point to evaluate the gradient. - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. + +# Returns +- `grad`: Gradient of `f` evaluated on `θ` +- `fval`: Function value `f` evaluated on `θ` +""" +function value_and_gradient end + +maybe_destructure(::ADTypes.AutoZygote, q) = (q, identity) + +maybe_destructure(::ADTypes.AbstractADType, q) = Optimisers.destructure(q) + """ -function value_and_gradient! end + update_variational_params!(family_type, opt_st, params, re, grad) + +Update variational family according to +Essentially an indirection for `Optimisers.update!`. + +# Arguments +- `family_type::Type`: + +# Returns +- `opt_st`: Updated optimizer state. +- `params`: Updated params. +""" +function update_variational_params! end + +update_variational_params!(::Type, opt_st, params, re, grad) = + Optimisers.update!(opt_st, params, grad) # estimators """ @@ -51,7 +77,7 @@ If the estimator is stateful, it can implement `init` to initialize the state. abstract type AbstractVariationalObjective end """ - init(rng, obj, λ, restructure) + init(rng, obj, params, q_init; kwargs...) Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. This function needs to be implemented only if `obj` is stateful. @@ -59,15 +85,10 @@ This function needs to be implemented only if `obj` is stateful. # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `obj::AbstractVariationalObjective`: Variational objective. -- `λ`: Initial variational parameters. -- `restructure`: Function that reconstructs the variational approximation from `λ`. +- `params`: Initial variational parameters. 
+- `q_init`: Initial variational distribution. """ -init( - ::Random.AbstractRNG, - ::AbstractVariationalObjective, - ::AbstractVector, - ::Any -) = nothing +init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any; kwargs...) = nothing """ estimate_objective([rng,] obj, q, prob; kwargs...) @@ -91,9 +112,8 @@ function estimate_objective end export estimate_objective - """ - estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state) + estimate_gradient(rng, obj, adtype, prob, params, restructure, obj_state; kwargs...) Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` @@ -101,18 +121,17 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ - `rng::Random.AbstractRNG`: Random number generator. - `obj::AbstractVariationalObjective`: Variational objective. - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. -- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. -- `λ`: Variational parameters to evaluate the gradient on. +- `params`: Variational parameters to evaluate the gradient on. - `restructure`: Function that reconstructs the variational approximation from `λ`. - `obj_state`: Previous state of the objective. # Returns -- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `grad`: Gradient estimate. - `obj_state`: The updated state of the objective. - `stat::NamedTuple`: Statistics and logs generated during estimation. """ -function estimate_gradient! end +function estimate_gradient end # ELBO-specific interfaces abstract type AbstractEntropyEstimator end diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 152cd15d..e60538a1 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -57,17 +57,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D) function StatsBase.entropy(q::MvLocationScale) @unpack location, scale, dist = q n_dims = length(location) - n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale)) + n_dims*convert(eltype(location), entropy(dist)) + first(logdet(scale)) end function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale)) end function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale)) end function Distributions.rand(q::MvLocationScale) diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 2d95d076..b1be3148 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -91,31 +91,27 @@ function estimate_objective( energy + entropy end -estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = +estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples, kwargs...) 
= estimate_objective(Random.default_rng(), obj, q, prob; n_samples) -function estimate_gradient!( +function estimate_gradient( rng ::Random.AbstractRNG, obj ::RepGradELBO, adtype::ADTypes.AbstractADType, - out ::DiffResults.MutableDiffResult, prob, - λ, + params, restructure, - state, + state; + kwargs... ) - q_stop = restructure(λ) - function f(λ′) - q = restructure(λ′) + q_stop = restructure(params) + function f(params′) + q = restructure(params′) samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) elbo = energy + entropy -elbo end - value_and_gradient!(adtype, f, λ, out) - - nelbo = DiffResults.value(out) - stat = (elbo=-nelbo,) - - out, nothing, stat + grad, nelbo = value_and_gradient(adtype, f, params) + grad, nothing, (elbo=-nelbo,) end diff --git a/src/optimize.jl b/src/optimize.jl index acb455d2..8d3bd2a7 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -1,19 +1,13 @@ """ - optimize(problem, objective, restructure, param_init, max_iter, objargs...; kwargs...) - optimize(problem, objective, variational_dist_init, max_iter, objargs...; kwargs...) + optimize(problem, objective, variational_dist_init, max_iter; kwargs...) Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients. -The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`. - # Arguments - `objective::AbstractVariationalObjective`: Variational Objective. -- `param_init`: Initial value of the variational parameters. -- `restruct`: Function that reconstructs the variational approximation from the flattened parameters. - `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`. - `max_iter::Int`: Maximum number of iterations. -- `objargs...`: Arguments to be passed to `objective`. # Keyword Arguments - `adtype::ADtypes.AbstractADType`: Automatic differentiation backend. @@ -24,8 +18,10 @@ The variational approximation can be constructed by passing the variational para - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) - `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) +Additional keyword arguments may apply depending on `objective`. + # Returns -- `params`: Variational parameters optimizing the variational objective. +- `variational_dist`: Variational distribution optimizing the variational objective. - `stats`: Statistics gathered during optimization. - `state`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. @@ -45,15 +41,20 @@ The arguments are as follows: This will be appended to the statistic of the current corresponding iteration. Otherwise, just return `nothing`. +!!! info + Some AD backends may only operator on "flattened" vectors. + In this case, `AdvancedVI` will leverage `Optimisers.destructure` to flatten the variational distribution. + (This is determined according to the value of `adtype`.) + For this to automatically work however, `variational_dist_init` must be marked as a functor through `Functors.@functor`. 
+ Variational families provided by `AdvancedVI` will all be marked as functors already. + Otherwise, one can simply use an AD backend that supported structured gradients such as `Zygote`. """ function optimize( rng ::Random.AbstractRNG, problem, objective ::AbstractVariationalObjective, - restructure, - params_init, - max_iter ::Int, - objargs...; + q_init, + max_iter ::Int; adtype ::ADTypes.AbstractADType, optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), show_progress::Bool = true, @@ -65,29 +66,37 @@ function optimize( barlen = 31, showspeed = true, enabled = show_progress - ) + ), + kwargs... ) - params = copy(params_init) - opt_st = maybe_init_optimizer(state_init, optimizer, params) - obj_st = maybe_init_objective(state_init, rng, objective, params, restructure) - grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - stats = NamedTuple[] + q = deepcopy(q_init) + params, re = maybe_destructure(adtype, q) + opt_st = maybe_init_optimizer(state_init, optimizer, params) + obj_st = maybe_init_objective(state_init, rng, objective, params, q; kwargs...) + stats = NamedTuple[] for t = 1:max_iter stat = (iteration=t,) - grad_buf, obj_st, stat′ = estimate_gradient!( - rng, objective, adtype, grad_buf, problem, - params, restructure, obj_st, objargs... + grad, obj_st, stat′ = estimate_gradient( + rng, + objective, + adtype, + problem, + params, + re, + obj_st; + kwargs... ) stat = merge(stat, stat′) - grad = DiffResults.gradient(grad_buf) - opt_st, params = Optimisers.update!(opt_st, params, grad) + opt_st, params = update_variational_params!( + typeof(q), opt_st, params, re, grad + ) if !isnothing(callback) stat′ = callback( - ; stat, restructure, params=params, gradient=grad, + ; stat, restructure=re, params=params, gradient=grad, state=(optimizer=opt_st, objective=obj_st) ) stat = !isnothing(stat′) ? merge(stat′, stat) : stat @@ -98,47 +107,11 @@ function optimize( pm_next!(prog, stat) push!(stats, stat) end - state = (optimizer=opt_st, objective=obj_st) - stats = map(identity, stats) - params, stats, state -end - -function optimize( - problem, - objective ::AbstractVariationalObjective, - restructure, - params_init, - max_iter ::Int, - objargs...; - kwargs... -) - optimize( - Random.default_rng(), - problem, - objective, - restructure, - params_init, - max_iter, - objargs...; - kwargs... - ) -end - -function optimize(rng ::Random.AbstractRNG, - problem, - objective ::AbstractVariationalObjective, - variational_dist_init, - n_max_iter ::Int, - objargs...; - kwargs...) - λ, restructure = Optimisers.destructure(variational_dist_init) - λ, logstats, state = optimize( - rng, problem, objective, restructure, λ, n_max_iter, objargs...; kwargs... - ) - restructure(λ), logstats, state + state = (optimizer=opt_st, objective=obj_st) + stats = map(identity, stats) + re(params), stats, state end - function optimize( problem, objective ::AbstractVariationalObjective, diff --git a/src/utils.jl b/src/utils.jl index 92b5686f..a5d021aa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,19 +6,24 @@ end function maybe_init_optimizer( state_init::NamedTuple, optimizer ::Optimisers.AbstractRule, - λ ::AbstractVector + params, ) - haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, λ) + haskey(state_init, :optimizer) ? 
state_init.optimizer : Optimisers.setup(optimizer, params) end function maybe_init_objective( state_init::NamedTuple, rng ::Random.AbstractRNG, objective ::AbstractVariationalObjective, - λ ::AbstractVector, - restructure + params, + q; + kwargs... ) - haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure) + if haskey(state_init, :objective) + state_init.objective + else + init(rng, objective, params, q; kwargs...) + end end eachsample(samples::AbstractMatrix) = eachcol(samples) diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index f4b94235..9b0d5248 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -29,8 +29,8 @@ using Test T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) - μ0 = Zeros(realtype, n_dims) - L0 = Diagonal(Ones(realtype, n_dims)) + μ0 = zeros(realtype, n_dims) + L0 = Diagonal(ones(realtype, n_dims)) q0 = TuringDiagMvNormal(μ0, diag(L0)) @testset "convergence" begin diff --git a/test/interface/ad.jl b/test/interface/ad.jl index be4ca34e..a7970b43 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -11,11 +11,8 @@ using Test D = 10 A = randn(D, D) λ = randn(D) - grad_buf = DiffResults.GradientResult(λ) f(λ′) = λ′'*A*λ′ / 2 - AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) - ∇ = DiffResults.gradient(grad_buf) - f = DiffResults.value(grad_buf) + ∇, f = AdvancedVI.value_and_gradient(adsymbol, f, λ) @test ∇ ≈ (A + A')*λ/2 @test f ≈ λ'*A*λ / 2 end diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl index 9666893b..2606851c 100644 --- a/test/interface/optimize.jl +++ b/test/interface/optimize.jl @@ -33,28 +33,6 @@ using Test show_progress = false, adtype, ) - - λ₀, re = Optimisers.destructure(q0) - optimize( - model, obj, re, λ₀, T; - optimizer, - show_progress = false, - adtype, - ) - end - - @testset "restructure" begin - λ₀, re = Optimisers.destructure(q0) - - rng = StableRNG(seed) - λ, stats, _ = optimize( - rng, model, obj, re, λ₀, T; - optimizer, - show_progress = false, - adtype, - ) - @test λ == λ_ref - @test stats == stats_ref end @testset "callback" begin
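Taken together, the refactor replaces the DiffResults-based `value_and_gradient!(ad, f, θ, out)` with `value_and_gradient(ad, f, θ)` returning a `(gradient, value)` tuple, and lets `optimize` accept the initial variational distribution directly and return the optimized distribution. A minimal sketch of the resulting usage follows; the first snippet mirrors `test/interface/ad.jl`, while `model`, `q0`, and the `RepGradELBO(10)` constructor in the second snippet are assumptions for illustration rather than definitions taken from this diff.

# Gradient contract after this change: a (gradient, value) tuple, no DiffResults buffer.
# Mirrors test/interface/ad.jl.
using AdvancedVI, ADTypes, ForwardDiff

A = randn(10, 10)
f(x) = x' * A * x / 2
λ = randn(10)
∇, y = AdvancedVI.value_and_gradient(ADTypes.AutoForwardDiff(), f, λ)
@assert ∇ ≈ (A + A') * λ / 2
@assert y ≈ f(λ)

# New `optimize` entry point: pass the initial variational distribution and get the
# optimized distribution back, instead of passing flattened parameters plus a restructure
# function. `model` (a LogDensityProblems target) and `q0` (a variational family whose
# parameters Optimisers.destructure can extract, or any family when using AutoZygote)
# are assumed to be defined already; `RepGradELBO(10)` assumes the usual
# n_samples-based constructor.
using Optimisers

q, stats, state = optimize(
    model, RepGradELBO(10), q0, 10_000;
    adtype        = ADTypes.AutoForwardDiff(),
    optimizer     = Optimisers.Adam(1e-3),
    show_progress = false,
)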