Basic rewrite of the package 2023 edition #45
Conversation
This is to avoid having to reconstruct transformed distributions all the time. The direct use of bijectors also avoids going through lots of abstraction layers that could break. Instead, transformed distributions could be constructed only once when returning the VI result.
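For illustration, here is a minimal hedged sketch of that idea in plain Julia (the names are stand-ins, not this package's API): during optimization only the bijector's inverse is applied to samples, and the transformed distribution is built a single time when the result is returned.

```julia
using Distributions, Bijectors

# Stand-ins: an unconstrained Normal approximation and the inverse bijector
# of a positive-support target.
q_unconstrained = Normal(0.0, 1.0)
binv = inverse(bijector(LogNormal()))   # maps unconstrained values back to (0, ∞)

# During optimization, samples are pushed through the bijector's inverse directly.
z = binv(rand(q_unconstrained))

# Only when returning the VI result is the transformed distribution constructed.
q_result = Bijectors.transformed(q_unconstrained, binv)
logpdf(q_result, z)                     # density on the original (constrained) support
```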
I'll have a look at the PR itself later, but for now:
Maybe we should make this into a discussion. I feel like there are several different approaches we can take here.
For this one in particular, we have an implementation in DynamicPPL that could potentially be moved to its own package if we really want to: https://github.com/TuringLang/DynamicPPL.jl/blob/b23acff013a9111c8ce2c89dbf5339e76234d120/src/utils.jl#L434-L473

But this has a couple of issues:
(1) can be addressed by instead taking a closure approach à la Functors.jl:

function flatten(d::MvNormal{<:AbstractVector,<:Diagonal})
    dim = length(d)
    function MvNormal_unflatten(x)
        return MvNormal(x[1:dim], Diagonal(x[dim+1:end]))
    end
    return vcat(d.μ, diag(d.Σ)), MvNormal_unflatten
end

For (2), we have a couple of immediate options. For (a) we'd have something like:

abstract type WrapperDistribution{D<:Distribution{V,F}} <: Distribution{V,F} end
# HACK: Probably shouldn't do this.
inner_dist(x::WrapperDistribution) = x.inner
# TODO: Specialize further on `x` to avoid hitting default implementations?
Distributions.logpdf(d::WrapperDistribution, x) = logpdf(inner_dist(d), x)
# Etc.
struct MeanParameterized{D} <: WrapperDistribution{D}
    inner::D
end

function flatten(d::MeanParameterized{<:MvNormal})
    μ = mean(d.inner)
    function MeanParameterized_MvNormal_unflatten(x)
        return MeanParameterized(MvNormal(x, d.inner.Σ))
    end
    return μ, MeanParameterized_MvNormal_unflatten
end

Pros:
For (b) we'd have something like:

struct MeanOnly end

function flatten(::MeanOnly, d::MvNormal)
    μ = mean(d)
    function MvNormal_meanonly_unflatten(x)
        return MvNormal(x, d.Σ)
    end
    return μ, MvNormal_meanonly_unflatten
end

Pros:
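As a usage sketch (not part of the PR; it simply assumes the option (b) definitions above), the call site would look roughly like this:

```julia
using Distributions, LinearAlgebra

d = MvNormal([0.0, 1.0], Diagonal([1.0, 2.0]))

θ, unflatten = flatten(MeanOnly(), d)   # θ == mean(d); the covariance stays inside the closure
d′ = unflatten([1.0, -1.0])             # same covariance, new mean
mean(d′)                                # returns [1.0, -1.0]
```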
Hi @torfjelde
Should we proceed here or create a separate issue? Whatever approach we take, I think the key would be to avoid inverting or even computing the covariance matrix, provided that we operate with a Cholesky factor. None of the steps of ADVI require any of these, except for the STL estimator, where we do need to invert the Cholesky factor.
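To make the Cholesky point concrete, here is a minimal hedged sketch in plain Julia (not this package's API; the numbers are illustrative): with a location-scale family q = μ + L u, the reparameterized sample and the entropy touch only L, and the score term needed by the STL estimator reduces to triangular solves with L rather than forming or inverting Σ = L L'.

```julia
using LinearAlgebra

μ = [0.0, 1.0, -1.0]
L = LowerTriangular([1.0 0.0 0.0;
                     0.3 1.2 0.0;
                    -0.1 0.4 0.8])      # illustrative Cholesky factor

u = randn(3)
z = μ + L * u                           # reparameterized sample; no covariance needed

# Entropy of q up to an additive constant: (1/2) logdet(Σ) = sum(log.(diag(L))).
half_logdet_Σ = sum(log.(diag(L)))

# Score term used by the STL estimator: ∇_z log q(z) = -Σ⁻¹ (z - μ),
# computed as two triangular solves with L, never by inverting Σ.
score = -(L' \ (L \ (z - μ)))
```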
Created a discussion: https://github.com/TuringLang/AdvancedVI.jl/discussions/46
@torfjelde Hi, I have significantly changed the sketch for the project structure.
Any comments on the new structure? Also, do you approve the use of
@devmotion what are your current thoughts on
I've now added the pre-packaged location-scale family. Overall, to the user, the basic interface looks like the following:

μ_y, σ_y = 1.0, 1.0
μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0]
Turing.@model function normallognormal()
    y ~ LogNormal(μ_y, σ_y)
    z ~ MvNormal(μ_z, Σ_z)
end
model = normallognormal()
b = Bijectors.bijector(model)
b⁻¹ = inverse(b)
prob = DynamicPPL.LogDensityFunction(model)
d = LogDensityProblems.dimension(prob)
μ = randn(d)
L = Diagonal(ones(d))
q = AVI.MeanFieldGaussian(μ, L)
λ₀, restructure = Flux.destructure(q)
function rebuild(λ′)
    restructure(λ′)
end
λ = AVI.optimize(
    AVI.ADVI(prob, b⁻¹, 10),
    rebuild,
    10000,
    λ₀;
    optimizer = Flux.ADAM(1e-3),
    adbackend = AutoForwardDiff()
)
q = restructure(λ)
μ = q.transform.outer.a
L = q.transform.inner.a
Σ = L*L'
μ_true = vcat(μ_y, μ_z)
Σ_diag_true = vcat(σ_y, diag(Σ_z))
@info("VI Estimation Error",
    norm(μ - μ_true),
    norm(diag(Σ) - Σ_diag_true),)

Some additional notes to the comments above:
@Red-Portal is there anything in this PR not yet merged by #49 and #50?
@yebai Yes, we still have the documentation left. I'm currently working on it.
Hi, this is the initial pull request for the rewrite of AdvancedVI as a successor to #25. The following panel will be updated in real-time, reflecting the discussions happening below.

Roadmap

- LogDensityProblems interface.
- Migrate to AbstractDifferentiation. (Not mature enough yet.)
- ADTypes interface.
- Functors.jl for flattening/unflattening variational parameters.
- optimize (see Missing API method #32).
- Reduce memory usage of the full-rank parameterization. (Seems like there's an unfavorable compute-memory trade-off; see this thread.)
- Optimisers.jl.
- Implement minibatch subsampling (probably requires changes upstream, e.g., DynamicPPL, too). (Separate issue.)
- callback option (Callback function during training #5).
- Add BBVI (score gradient). (Not urgent.)
- Support GPU computation (although Bijectors will be a bottleneck for this). (Separate issue.)

Topics to Discuss

- Should we use AbstractDifferentiation? Not now.
- Should we use Optimisers? (Probably yes.)
- Should we call restructure inside of optimize such that the flattening/unflattening is completely abstracted out to the user? Then, in the current state of things, Flux will have to be added as a dependency; otherwise we'll have to roll our own implementation of destructure. destructure is now part of Optimisers, which is much more lightweight.
- Should we keep TruncatedADAGrad and DecayedADAGrad? I think these are quite outdated and would advise people against using these. So how about deprecating them? Planning to deprecate.

Demo