Add "low-rank" variational families (#76)
* add feature complete `MvLocationScaleLowRank` with tests
* fix bugs and improve comments in `MvLocationScale` and lowrank
* promote families.md into a higher category
* add test for `MVLocationScale` with non-Gaussian
* tighten compat bound for `Distributions`
* fix base distribution standardization bug in `LocationScale` and `LocationScaleLowRank`
* fix `LocationScale` interfaces to only allow univariate base dist
* fix test comparison operator for families
* fix scale lower bound to `1e-4`
---------

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
3 people authored Sep 13, 2024
1 parent 9e4c1cc commit 57c9e58
Showing 12 changed files with 745 additions and 170 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/Benchmark.yml
@@ -13,6 +13,7 @@ concurrency:
permissions:
contents: write
pull-requests: write
issues: write

jobs:
benchmark:
@@ -47,10 +48,10 @@ jobs:
name: Benchmark Results
tool: 'julia'
output-file-path: bench/benchmark_results.json
summary-always: true
summary-always: ${{ !github.event.pull_request.head.repo.fork }} # Disable summary for PRs from forks
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-always: true
alert-threshold: "200%"
fail-on-alert: true
benchmark-data-dir-path: benchmarks
comment-always: ${{ !github.event.pull_request.head.repo.fork }} # Disable comments for PRs from forks
auto-push: ${{ !github.event.pull_request.head.repo.fork }} # Disable push for PRs from forks
2 changes: 1 addition & 1 deletion Project.toml
@@ -42,7 +42,7 @@ Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
Distributions = "0.25.87"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.12.32"
FillArrays = "1.3"
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -16,8 +16,8 @@ makedocs(;
"ELBO Maximization" => [
"Overview" => "elbo/overview.md",
"Reparameterization Gradient Estimator" => "elbo/repgradelbo.md",
"Location-Scale Variational Family" => "locscale.md",
],
"Variational Families" => "families.md",
"Optimization" => "optimization.md",
],
)
267 changes: 267 additions & 0 deletions docs/src/families.md
@@ -0,0 +1,267 @@
# [Reparameterizable Variational Families](@id families)

The [RepGradELBO](@ref repgradelbo) objective assumes that the members of the variational family have a differentiable sampling path.
We provide multiple pre-packaged variational families that can be readily used.

## [The `LocationScale` Family](@id locscale)

The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions whose sampling process can be represented as

```math
z \sim q_{\lambda} \qquad\Leftrightarrow\qquad
z \stackrel{d}{=} C u + m;\quad u \sim \varphi
```

where ``C`` is the *scale*, ``m`` is the *location*, and ``\varphi`` is the *base distribution*.
``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``.
The location-scale family encompasses many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``.
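
For intuition, the sampling path above can be written out directly.
The following standalone sketch assumes a two-dimensional standard Gaussian base distribution and arbitrary values for ``m`` and ``C``:

```julia
using Distributions, LinearAlgebra

m = [1.0, -1.0]                          # location
C = LowerTriangular([1.0 0.0; 0.5 2.0])  # scale
u = rand(Normal(), 2)                    # u ~ φ, drawn componentwise from the base
z = C * u + m                            # a single reparameterized draw from q_λ
```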

The probability density is given by

```math
q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)),
```

the covariance is given as

```math
\mathrm{Var}\left(q_{\lambda}\right) = C \mathrm{Var}(\varphi) C^{\top}
```

and the entropy is given as

```math
\mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |C|,
```

where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution.
Notice that ``\mathbb{H}(\varphi)`` does not depend on ``\lambda``; only the ``\log |C|`` term does.
The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution.
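
For example, with a standard Gaussian base, the entropy identity can be checked numerically using only Distributions.jl (a standalone sketch; the scale values below are arbitrary):

```julia
using Distributions, LinearAlgebra

m = zeros(2)
C = LowerTriangular([1.0 0.0; 0.3 2.0])

H_base = 2 * entropy(Normal())   # H(φ): two i.i.d. standard normal components
H_q    = H_base + logdet(C)      # H(q_λ) = H(φ) + log|C|
@assert H_q ≈ entropy(MvNormal(m, Symmetric(C * C')))
```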

### API

!!! note

    For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned.
    Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities.

```@docs
MvLocationScale
```

The following are specialized constructors for convenience:

```@docs
FullRankGaussian
MeanFieldGaussian
```

### Gaussian Variational Families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);

# Full-Rank
L = LowerTriangular(diagm(ones(2)));
q = FullRankGaussian(μ, L)

# Mean-Field
L = Diagonal(ones(2));
q = MeanFieldGaussian(μ, L)
```

### Student-``t`` Variational Families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);
ν = 3;

# Full-Rank
L = LowerTriangular(diagm(ones(2)));
q = MvLocationScale(μ, L, TDist(ν))

# Mean-Field
L = Diagonal(ones(2));
q = MvLocationScale(μ, L, TDist(ν))
```

### Laplace Variational Families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);

# Full-Rank
L = LowerTriangular(diagm(ones(2)));
q = MvLocationScale(μ, L, Laplace())

# Mean-Field
L = Diagonal(ones(2));
q = MvLocationScale(μ, L, Laplace())
```

## The `LocationScaleLowRank` Family

In practice, `LocationScale` families with full-rank scale matrices are known to converge slowly, since they require a small SGD stepsize.
Low-rank variational families can be an effective alternative[^ONS2018].
`LocationScaleLowRank` represents any ``d``-dimensional distribution whose sampling path can be written as

```math
z \sim q_{\lambda} \qquad\Leftrightarrow\qquad
z \stackrel{d}{=} D u_1 + U u_2 + m;\quad u_1, u_2 \sim \varphi
```

where ``D \in \mathbb{R}^{d \times d}`` is a diagonal matrix, ``U \in \mathbb{R}^{d \times r}`` is a dense low-rank factor with rank ``r > 0``, ``m \in \mathbb{R}^d`` is the location, and ``\varphi`` is the *base distribution*.
``m``, ``D``, and ``U`` form the variational parameters ``\lambda = (m, D, U)``.

The covariance of this distribution is given as

```math
\mathrm{Var}\left(q_{\lambda}\right) = D \mathrm{Var}(\varphi) D + U \mathrm{Var}(\varphi) U^{\top}
```

and the entropy is given by the matrix determinant lemma as

```math
\mathbb{H}(q_{\lambda})
= \mathbb{H}(\varphi) + \frac{1}{2} \log |\Sigma|
= \mathbb{H}(\varphi) + \log |D| + \frac{1}{2} \log |I + U^{\top} D^{-2} U|,
```

where ``\Sigma = D^2 + U U^{\top}`` and ``\mathbb{H}(\varphi)`` is the entropy of the base distribution.
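
Both identities can be checked numerically.
The sketch below assumes a Gaussian base (so the base variance is one and ``\Sigma = D^2 + U U^{\top}``), uses arbitrary dimensions, and relies on `cov` being available through the Distributions.jl interface, as in the example further down; the variable names are only for illustration.

```julia
using AdvancedVI, Distributions, LinearAlgebra

d, r = 5, 2
μ, D, U = zeros(d), 1 .+ rand(d), randn(d, r)
q = LowRankGaussian(μ, D, U)

Dmat = Diagonal(D)
Σ = Dmat^2 + U * U'     # D Var(φ) D + U Var(φ) Uᵀ with unit base variance
@assert cov(q) ≈ Σ

# Matrix determinant lemma: log|Σ| = 2 log|D| + log|I + Uᵀ D⁻² U|
@assert logdet(Σ) ≈ 2 * logdet(Dmat) + logdet(I + U' * inv(Dmat)^2 * U)
```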

```@setup lowrank
using ADTypes
using AdvancedVI
using Distributions
using LinearAlgebra
using LogDensityProblems
using Optimisers
using Plots
using ReverseDiff
struct Target{D}
dist::D
end
function LogDensityProblems.logdensity(model::Target, θ)
logpdf(model.dist, θ)
end
function LogDensityProblems.dimension(model::Target)
return length(model.dist)
end
function LogDensityProblems.capabilities(::Type{<:Target})
return LogDensityProblems.LogDensityOrder{0}()
end
n_dims = 30
U_true = randn(n_dims, 3)
D_true = Diagonal(log.(1 .+ exp.(randn(n_dims))))
Σ_true = D_true + U_true*U_true'
Σsqrt_true = sqrt(Σ_true)
μ_true = randn(n_dims)
model = Target(MvNormal(μ_true, Σ_true));
d = LogDensityProblems.dimension(model);
μ = zeros(d);
L = Diagonal(ones(d));
q0_mf = MeanFieldGaussian(μ, L)
L = LowerTriangular(diagm(ones(d)));
q0_fr = FullRankGaussian(μ, L)
D = ones(n_dims)
U = zeros(n_dims, 3)
q0_lr = LowRankGaussian(μ, D, U)
obj = RepGradELBO(1);
max_iter = 10^4
function callback(; params, averaged_params, restructure, stat, kwargs...)
q = restructure(averaged_params)
μ, Σ = mean(q), cov(q)
(dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2*sqrt(Σsqrt_true*Σ*Σsqrt_true)),)
end
_, _, stats_fr, _ = AdvancedVI.optimize(
model,
obj,
q0_fr,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);
_, _, stats_mf, _ = AdvancedVI.optimize(
model,
obj,
q0_mf,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);
_, _, stats_lr, _ = AdvancedVI.optimize(
model,
obj,
q0_lr,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);
t = [stat.iteration for stat in stats_fr]
dist_fr = [sqrt(stat.dist2) for stat in stats_fr]
dist_mf = [sqrt(stat.dist2) for stat in stats_mf]
dist_lr = [sqrt(stat.dist2) for stat in stats_lr]
plot( t, dist_mf , label="Mean-Field Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
plot!(t, dist_fr, label="Full-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
plot!(t, dist_lr, label="Low-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
savefig("lowrank_family_wasserstein.svg")
nothing
```

Consider a 30-dimensional Gaussian with a diagonal-plus-low-rank covariance structure, where the true rank is 3.
On this target, we can compare the convergence speed of `MeanFieldGaussian`, `FullRankGaussian`, and `LowRankGaussian`:

![](lowrank_family_wasserstein.svg)

As we can see, `LowRankGaussian` converges faster than `FullRankGaussian`.
While `FullRankGaussian` is expressive enough to eventually converge to the true solution, `LowRankGaussian` gets there faster.

!!! info

    `MvLocationScaleLowRank` tends to work better with the `Optimisers.Adam` optimizer due to non-smoothness.
    Other optimisers may experience divergences.

### API

```@docs
MvLocationScaleLowRank
```

The `logpdf` of `MvLocationScaleLowRank` has an optional argument `non_differentiable::Bool` (default: `false`).
If set to `true`, a more efficient ``O\left(r d^2\right)`` implementation is used to evaluate the density.
This, however, is not differentiable under most AD frameworks due to the use of Cholesky `lowrankupdate`.
The default, `false`, uses an ``O\left(d^3\right)`` implementation that is differentiable and therefore compatible with the `StickingTheLandingEntropy` estimator.
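
A brief usage sketch follows; the dimension and rank are arbitrary, and the exact way to pass `non_differentiable` is left to the docstring above, so it appears only in a comment here:

```julia
using AdvancedVI, Distributions

d, r = 10, 2
q = LowRankGaussian(zeros(d), ones(d), zeros(d, r))  # location, diagonal scale (vector), low-rank factor

z = rand(q)
logpdf(q, z)  # default O(d^3) path: differentiable, compatible with StickingTheLandingEntropy
# Enabling `non_differentiable` (see the docstring above) switches to the O(r d^2) path,
# which is faster but not differentiable under most AD frameworks.
```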

The following is a specialized constructor for convenience:

```@docs
LowRankGaussian
```

[^ONS2018]: Ong, V. M. H., Nott, D. J., & Smith, M. S. (2018). Gaussian variational approximation with a factor covariance structure. Journal of Computational and Graphical Statistics, 27(3), 465-478.
80 changes: 0 additions & 80 deletions docs/src/locscale.md

This file was deleted.

4 changes: 4 additions & 0 deletions src/AdvancedVI.jl
Expand Up @@ -180,6 +180,10 @@ export MvLocationScale, MeanFieldGaussian, FullRankGaussian

include("families/location_scale.jl")

export MvLocationScaleLowRank, LowRankGaussian

include("families/location_scale_low_rank.jl")

# Optimization Rules

include("optimization/rules.jl")