add Tapir support #71

Closed · wants to merge 20 commits
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.7'
#- '1.7'
- '1.10'
@yebai (Member) commented on Aug 21, 2024:

@willtebbutt, can you adapt the Bijectors setup so we don't need to comment out 1.7?

os:
- ubuntu-latest
4 changes: 4 additions & 0 deletions Project.toml
@@ -25,13 +25,15 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIBijectorsExt = "Bijectors"
AdvancedVIEnzymeExt = "Enzyme"
AdvancedVIForwardDiffExt = "ForwardDiff"
AdvancedVIReverseDiffExt = "ReverseDiff"
AdvancedVITapirExt = "Tapir"
AdvancedVIZygoteExt = "Zygote"

[compat]
@@ -55,6 +57,7 @@ Requires = "1.0"
ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
Tapir = "0.2.34"
Zygote = "0.6.63"
julia = "1.7"

@@ -64,6 +67,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

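For context, the Tapir entry above sits alongside the other weak dependencies, so the extension only activates when Tapir is loaded in the same environment. A minimal sketch of how a user would opt in (assuming the Julia ≥ 1.9 package-extension mechanism; on 1.7–1.8 the Requires fallback appears to apply):

```julia
using AdvancedVI          # loads the core package only
using ADTypes: AutoTapir
using Tapir               # loading Tapir triggers the AdvancedVITapirExt extension

adtype = AutoTapir(false) # safe_mode disabled, as in the PR's test suite
```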
6 changes: 4 additions & 2 deletions ext/AdvancedVIForwardDiffExt.jl
@@ -15,8 +15,9 @@ getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff,
::Any,
f,
x::AbstractVector{<:Real},
x::AbstractVector,
out::DiffResults.MutableDiffResult,
)
chunk_size = getchunksize(ad)
@@ -31,12 +32,13 @@ end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff,
st_ad,
f,
x::AbstractVector,
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
end

end
6 changes: 4 additions & 2 deletions ext/AdvancedVIReverseDiffExt.jl
@@ -13,7 +13,8 @@ end

# ReverseDiff without compiled tape
function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
::ADTypes.AutoReverseDiff,
::Any,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
@@ -25,12 +26,13 @@ end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
st_ad,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
end

end
47 changes: 47 additions & 0 deletions ext/AdvancedVITapirExt.jl
@@ -0,0 +1,47 @@

module AdvancedVITapirExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using Tapir
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..Tapir
end

AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x) = Tapir.build_rrule(f, x)

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoTapir,
st_ad,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
rule = st_ad
y, g = Tapir.value_and_gradient!!(rule, f, x)
DiffResults.value!(out, y)
DiffResults.gradient!(out, last(g))
yebai marked this conversation as resolved.

A maintainer (Member) suggested a change:

Suggested change
- DiffResults.gradient!(out, last(g))
+ DiffResults.gradient!(out, g[2])

A maintainer (Member) commented:

@willtebbutt, to clarify, we don't need this change. Is that correct?
return out
end

AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x, aux) = Tapir.build_rrule(f, x, aux)

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoTapir,
st_ad,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
rule = st_ad
y, g = Tapir.value_and_gradient!!(rule, f, x, aux)
DiffResults.value!(out, y)
DiffResults.gradient!(out, g[2])
return out
end

end
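A minimal usage sketch of the two-step flow this extension implements (the toy function and values are illustrative, not part of the PR): `init_adbackend` builds the Tapir rule once, and `value_and_gradient!` then reuses it on every call.

```julia
using AdvancedVI, DiffResults, Tapir
using ADTypes: AutoTapir

f(x) = sum(abs2, x) / 2                 # toy objective with ∇f(x) = x
x    = randn(5)
out  = DiffResults.GradientResult(x)

ad_st = AdvancedVI.init_adbackend(AutoTapir(false), f, x)   # builds the rrule once
AdvancedVI.value_and_gradient!(AutoTapir(false), ad_st, f, x, out)

DiffResults.value(out)      # ≈ f(x)
DiffResults.gradient(out)   # ≈ x
```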
9 changes: 7 additions & 2 deletions ext/AdvancedVIZygoteExt.jl
@@ -14,7 +14,11 @@ else
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
::ADTypes.AutoZygote,
::Any,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
y, back = Zygote.pullback(f, x)
∇x = back(one(y))
@@ -25,12 +29,13 @@ end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoZygote,
st_ad,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
end

end
34 changes: 27 additions & 7 deletions src/AdvancedVI.jl
@@ -25,21 +25,37 @@ using StatsBase

# derivatives
"""
value_and_gradient!(ad, f, x, out)
value_and_gradient!(ad, f, x, aux, out)
value_and_gradient!(adtype, ad_st, f, x, out)
value_and_gradient!(adtype, ad_st, f, x, aux, out)

Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation (AD) backend `adtype` and store the result in `out`.
`f` may receive auxiliary input as `f(x,aux)`.

# Arguments
- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
- `adtype::ADTypes.AbstractADType`: AD backend.
- `ad_st`: State used by the AD backend. (This will often be pre-compiled tapes/caches.)
- `f`: Function subject to differentiation.
- `x`: The point at which to evaluate the gradient.
- `aux`: Auxiliary input passed to `f`.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
"""
function value_and_gradient! end

"""
init_adbackend(adtype, f, x)
init_adbackend(adtype, f, x, aux)

Initialize the AD backend and set up any necessary state.

# Arguments
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `f`: Function subject to differentiation.
- `x`: The point at which to evaluate the gradient.
- `aux`: Auxiliary input passed to `f`.

# Returns
- `ad_st`: State of the AD backend. (This will often be pre-compiled tapes/caches.)
"""
init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any) = nothing
init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any, ::Any) = nothing

"""
restructure_ad_forward(adtype, restructure, params)

@@ -95,18 +111,22 @@ If the estimator is stateful, it can implement `init` to initialize the state.
abstract type AbstractVariationalObjective end

"""
init(rng, obj, prob, params, restructure)
init(rng, obj, adtype, prob, params, restructure)

Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
Initialize a state of the variational objective `obj`.
This function needs to be implemented only if `obj` is stateful.
The state of the AD backend `adtype` should also be initialized here.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `params`: Initial variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
"""
init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any) = nothing
init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any, ::Any) =
nothing

"""
estimate_objective([rng,] obj, q, prob; kwargs...)
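To make the two-call protocol documented above concrete, here is a small sketch exercising the `aux`-carrying variant with a toy function (illustrative only; for backends like ForwardDiff the default `init_adbackend` simply returns `nothing`):

```julia
using AdvancedVI, DiffResults, ForwardDiff
using ADTypes: AutoForwardDiff

g(x, aux) = aux.scale * sum(abs2, x)    # `f` receiving auxiliary input as f(x, aux)
x    = randn(3)
aux  = (scale=0.5,)
out  = DiffResults.GradientResult(x)

ad_st = AdvancedVI.init_adbackend(AutoForwardDiff(), g, x, aux)   # `nothing` for ForwardDiff
AdvancedVI.value_and_gradient!(AutoForwardDiff(), ad_st, g, x, aux, out)

DiffResults.value(out)      # ≈ 0.5 * sum(abs2, x)
DiffResults.gradient(out)   # ≈ x
```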
16 changes: 15 additions & 1 deletion src/objectives/elbo/repgradelbo.jl
@@ -101,6 +101,20 @@ function estimate_repgradelbo_ad_forward(params′, aux)
return -elbo
end

function init(
rng::Random.AbstractRNG,
obj::RepGradELBO,
adtype::ADTypes.AbstractADType,
prob,
params,
restructure,
)
q_stop = restructure(params)
aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
ad_st = init_adbackend(adtype, estimate_repgradelbo_ad_forward, params, aux)
return (ad_st=ad_st,)
end

function estimate_gradient!(
rng::Random.AbstractRNG,
obj::RepGradELBO,
@@ -123,5 +137,5 @@
value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out)
nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)
return out, nothing, stat
return out, state, stat
end
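For a user-defined objective, the same pattern would look roughly as follows (hypothetical `MyObjective` and forward function, shown only to illustrate where the AD state is created; not part of the PR):

```julia
using AdvancedVI, Random
using AdvancedVI: ADTypes

struct MyObjective <: AdvancedVI.AbstractVariationalObjective end

# Forward map differentiated during optimization; `aux` carries fixed inputs.
my_forward(params, aux) = sum(abs2, params) + aux.penalty

function AdvancedVI.init(
    rng::Random.AbstractRNG,
    obj::MyObjective,
    adtype::ADTypes.AbstractADType,
    prob,
    params,
    restructure,
)
    aux   = (penalty=1.0,)
    ad_st = AdvancedVI.init_adbackend(adtype, my_forward, params, aux)
    return (ad_st=ad_st,)
end
```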
4 changes: 3 additions & 1 deletion src/optimize.jl
@@ -66,7 +66,9 @@ function optimize(
)
params, restructure = Optimisers.destructure(deepcopy(q_init))
opt_st = maybe_init_optimizer(state_init, optimizer, params)
obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
obj_st = maybe_init_objective(
state_init, rng, adtype, objective, problem, params, restructure
)
avg_st = maybe_init_averager(state_init, averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
stats = NamedTuple[]
3 changes: 2 additions & 1 deletion src/utils.jl
@@ -24,6 +24,7 @@ end
function maybe_init_objective(
state_init::NamedTuple,
rng::Random.AbstractRNG,
adtype::ADTypes.AbstractADType,
objective::AbstractVariationalObjective,
problem,
params,
@@ -32,7 +33,7 @@ function maybe_init_objective(
if haskey(state_init, :objective)
state_init.objective
else
init(rng, objective, problem, params, restructure)
init(rng, objective, adtype, problem, params, restructure)
end
end

4 changes: 3 additions & 1 deletion test/Project.toml
@@ -18,6 +18,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -42,7 +43,8 @@ ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StableRNGs = "1.0.0"
Statistics = "1"
Tapir = "0.2.23"
Test = "1"
Tracker = "0.2.20"
Zygote = "0.6.63"
julia = "1.6"
julia = "1.7"
1 change: 1 addition & 0 deletions test/inference/repgradelbo_distributionsad.jl
@@ -5,6 +5,7 @@ AD_distributionsad = if VERSION >= v"1.10"
#:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
:Tapir => AutoTapir(false),
)
else
Dict(
1 change: 1 addition & 0 deletions test/inference/repgradelbo_locationscale.jl
@@ -5,6 +5,7 @@ AD_locationscale = if VERSION >= v"1.10"
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
:Tapir => AutoTapir(false),
)
else
Dict(
1 change: 1 addition & 0 deletions test/inference/repgradelbo_locationscale_bijectors.jl
@@ -5,6 +5,7 @@ AD_locationscale_bijectors = if VERSION >= v"1.10"
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
:Tapir => AutoTapir(false),
)
else
Dict(
11 changes: 7 additions & 4 deletions test/interface/ad.jl
@@ -2,18 +2,21 @@
using Test

@testset "ad" begin
@testset "$(adname)" for (adname, adsymbol) in Dict(
@testset "$(adname)" for (adname, adtype) in Dict(
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
:Tapir => AutoTapir(; safe_mode=false),
#:Enzyme => AutoEnzyme()
)
D = 10
A = randn(D, D)
λ = randn(D)
grad_buf = DiffResults.GradientResult(λ)
f(λ′) = λ′' * A * λ′ / 2
AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf)

ad_st = AdvancedVI.init_adbackend(adtype, f, λ)
grad_buf = DiffResults.GradientResult(λ)
AdvancedVI.value_and_gradient!(adtype, ad_st, f, λ, grad_buf)
∇ = DiffResults.gradient(grad_buf)
f = DiffResults.value(grad_buf)
@test ∇ ≈ (A + A') * λ / 2
3 changes: 2 additions & 1 deletion test/interface/repgradelbo.jl
@@ -39,6 +39,7 @@ end
ADTypes.AutoReverseDiff(),
ADTypes.AutoZygote(),
ADTypes.AutoEnzyme(),
ADTypes.AutoTapir(false),
]
q_true = MeanFieldGaussian(
Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true)))
Expand All @@ -49,7 +50,7 @@ end

aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad)
AdvancedVI.value_and_gradient!(
ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out
adtype, ad_st, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out
)
grad = DiffResults.gradient(out)
@test norm(grad) ≈ 0 atol = 1e-5
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -20,7 +20,7 @@ using DistributionsAD
using LogDensityProblems
using Optimisers
using ADTypes
using ForwardDiff, ReverseDiff, Zygote, Enzyme
using ForwardDiff, ReverseDiff, Zygote, Enzyme, Tapir

using AdvancedVI
