Remove dependency DistributionsAD #30

Open · wants to merge 2 commits into base: main
3 changes: 1 addition & 2 deletions Project.toml
@@ -6,14 +6,14 @@ version = "0.2.3"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"
@@ -22,7 +22,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ChainRulesCore = "0.9.44"
Distributions = "0.23, 0.24, 0.25"
DistributionsAD = "0.6"
ForwardDiff = "0.10"
IterativeSolvers = "0.8, 0.9"
LinearMaps = "3"
4 changes: 3 additions & 1 deletion README.md
@@ -8,9 +8,11 @@

*This is an implementation of the [Metric Gaussian Variational Inference](https://arxiv.org/abs/1901.11033) (MGVI) algorithm in Julia*


MGVI is an iterative method that performs a series of Gaussian approximations to the posterior. It alternates between approximating the covariance with the inverse Fisher information metric evaluated at an intermediate mean estimate and optimizing the KL-divergence for the given covariance with respect to the mean. This procedure is iterated until the uncertainty estimate is self-consistent with the mean parameter. We achieve linear scaling by never storing the covariance explicitly. Instead, we draw samples from the approximating distribution using an implicit representation and numerical schemes that approximately solve the required linear equations. These samples are used to approximate the KL-divergence and its gradient. The use of natural gradient descent allows for rapid convergence. Formulating the Bayesian model in standardized coordinates makes MGVI applicable to any inference problem with continuous parameters.

Depending on the distributions used in your application, you may also need to load the package [DistributionsAD](https://github.com/TuringLang/DistributionsAD.jl).
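
A minimal sketch of what this means in practice (the toy distribution below is illustrative and not part of the MGVI API):

```julia
using LinearAlgebra
using DistributionsAD   # optional: only needed for distribution types defined there
using MGVI

μ = [0.0, 1.0]
Σ = Symmetric([2.0 0.5; 0.5 1.0])

# A Turing-style dense multivariate normal from DistributionsAD; MGVI's
# support for this type is loaded only when DistributionsAD is present.
d = DistributionsAD.TuringDenseMvNormal(μ, cholesky(Σ))
rand(d)  # behaves like any other Distributions.jl distribution
```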


## Documentation
* [Documentation for stable version](https://bat.github.io/MGVI.jl/stable)
* [Documentation for development version](https://bat.github.io/MGVI.jl/dev)
2 changes: 2 additions & 0 deletions docs/src/index.md
@@ -2,6 +2,8 @@

MGVI is an iterative method that performs a series of Gaussian approximations to the posterior. We alternate between approximating the covariance with the inverse Fisher information metric evaluated at an intermediate mean estimate and optimizing the KL-divergence for the given covariance with respect to the mean. This procedure is iterated until the uncertainty estimate is self-consistent with the mean parameter. We achieve linear scaling by never storing the covariance explicitly. Instead, we draw samples from the approximating distribution using an implicit representation and numerical schemes that approximately solve the required linear equations. These samples are used to approximate the KL-divergence and its gradient. The use of natural gradient descent allows for rapid convergence. Formulating the Bayesian model in standardized coordinates makes MGVI applicable to any inference problem with continuous parameters.
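
Schematically, with ``\bar{\xi}_k`` the current mean estimate and ``M(\bar{\xi}_k)`` the Fisher information metric evaluated there (the notation here only summarizes the description above), one MGVI iteration reads

```math
\Sigma_k = M(\bar{\xi}_k)^{-1}, \qquad
\bar{\xi}_{k+1} = \operatorname{arg\,min}_{\bar{\xi}} \,
\widehat{\mathrm{KL}}\big(\mathcal{N}(\bar{\xi}, \Sigma_k) \,\|\, p(\xi \mid d)\big),
```

where the KL estimate ``\widehat{\mathrm{KL}}`` is formed from samples ``\bar{\xi} + r_i`` with ``r_i \sim \mathcal{N}(0, \Sigma_k)``, drawn via matrix-free linear solves so that ``\Sigma_k`` never has to be stored explicitly.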

Depending on the distributions used in your application, you may also need to load the package [DistributionsAD](https://github.com/TuringLang/DistributionsAD.jl).


## Citing MGVI.jl

8 changes: 7 additions & 1 deletion src/MGVI.jl
@@ -16,23 +16,29 @@ using Random
using SparseArrays
using Base.Iterators
using Distributions
using DistributionsAD
import ForwardDiff
using LinearMaps
using IterativeSolvers
using Optim
using PositiveFactorizations
using Requires
import SparseArrays: blockdiag
using SparseArrays
using StaticArrays
using ValueShapes
import Zygote

import Requires

include("custom_linear_maps.jl")
include("shapes.jl")
include("jacobian_maps.jl")
include("information.jl")
include("residual_samplers.jl")
include("mgvi_impl.jl")

function __init__()
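    # Load the DistributionsAD-specific methods only if the user's environment
    # also loads DistributionsAD (optional dependency via a Requires.jl hook).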
    @require DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" include("distributionsad_support.jl")
end

end # module
8 changes: 8 additions & 0 deletions src/distributionsad_support.jl
@@ -0,0 +1,8 @@
# This file is a part of MGVI.jl, licensed under the MIT License (MIT).


function unshaped_params(d::DistributionsAD.TuringDenseMvNormal)
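    # Flatten the distribution into a plain parameter vector: the mean entries,
    # followed by the upper-triangular part of the dense covariance matrix.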
    μ = d.m
    σ = convert(AbstractMatrix, d.C)
    vcat(μ, _uppertriang_to_vec(σ))
end
6 changes: 0 additions & 6 deletions src/shapes.jl
@@ -52,9 +52,3 @@ function unshaped_params(d::MvNormal)
    μ, σ = params(d)
    vcat(μ, _uppertriang_to_vec(σ))
end

function unshaped_params(d::TuringDenseMvNormal)
    μ = d.m
    σ = d.C.L*d.C.U
    vcat(μ, _uppertriang_to_vec(σ))
end
1 change: 0 additions & 1 deletion test/Project.toml
@@ -1,7 +1,6 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
1 change: 0 additions & 1 deletion test/information/information_utils.jl
@@ -1,7 +1,6 @@
# This file is a part of MGVI.jl, licensed under the MIT License (MIT).

using Distributions
using DistributionsAD
using LinearAlgebra
import Zygote
