From 03563eab6792647f17d36d8961b1e7f4463f4212 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 3 Aug 2024 00:33:33 -0400 Subject: [PATCH 01/37] rename location scale source file --- src/families/{location_scale.jl => locscale.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/families/{location_scale.jl => locscale.jl} (100%) diff --git a/src/families/location_scale.jl b/src/families/locscale.jl similarity index 100% rename from src/families/location_scale.jl rename to src/families/locscale.jl From 5ab7286531ac8d551a8f8b96b9bd3b918c140e97 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 3 Aug 2024 00:49:18 -0400 Subject: [PATCH 02/37] revert renaming of location_scale file --- src/AdvancedVI.jl | 1 + src/families/{locscale.jl => location_scale.jl} | 0 2 files changed, 1 insertion(+) rename src/families/{locscale.jl => location_scale.jl} (100%) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 7a09030b..c8811163 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -194,6 +194,7 @@ export FullRankGaussian include("families/location_scale.jl") +include("families/location_lowrank_scale.jl") # Optimization Routine diff --git a/src/families/locscale.jl b/src/families/location_scale.jl similarity index 100% rename from src/families/locscale.jl rename to src/families/location_scale.jl From 3e0bf3d6f7c2f9ef50a0df82a27bf13a73274cbc Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 3 Aug 2024 00:49:31 -0400 Subject: [PATCH 03/37] add location-low-rank-scale family (except `entropy` and `logpdf`) --- src/families/location_lowrank_scale.jl | 124 +++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 src/families/location_lowrank_scale.jl diff --git a/src/families/location_lowrank_scale.jl b/src/families/location_lowrank_scale.jl new file mode 100644 index 00000000..fc0f19d5 --- /dev/null +++ b/src/families/location_lowrank_scale.jl @@ -0,0 +1,124 @@ + +""" + MvLocationLowRankScale(location, scale_diag, scale_factors, dist) <: ContinuousMultivariateDistribution + +Variational family with a covariance in the form of a diagonal matrix plus a squared low-rank matrix. + +It generally represents any distribution for which the sampling path can be +represented as follows: +```julia + d = length(location) + r = size(scale_factors, 2) + u_d = rand(dist, d) + u_f = rand(dist, r) + z = scale_diag.*u_d + scale_factors*u_f + location +``` +""" +struct MvLowRankLocationScale{ + L, + SD <: AbstractVector, + SF <: AbstractMatrix, + D <: ContinuousDistribution, + E <: Real +} <: ContinuousMultivariateDistribution + location ::L + scale_diag ::SD + scale_factors::SF + dist ::D + scale_eps ::E +end + +function MvLowRankLocationScale( + location ::AbstractVector{T}, + scale_diag ::AbstractVector{T}, + scale_factors::AbstractMatrix{T}, + dist ::ContinuousDistribution; + scale_eps ::T = sqrt(eps(T)) +) where {T <: Real} + MvLowRankLocationScale(location, scale_diag, scale_factors, dist, scale_eps) +end + +Functors.@functor MvLowRankLocationScale (location, scale_diag, scale_factors) + +Base.length(q::MvLowRankLocationScale) = length(q.location) + +Base.size(q::MvLowRankLocationScale) = size(q.location) + +Base.eltype(::Type{<:MvLowRankLocationScale{S, D, L}}) where {S, D, L} = eltype(D) + +function StatsBase.entropy(q::MvLowRankLocationScale) + #@unpack location, scale, dist = q + #n_dims = length(location) + # `convert` is necessary because `entropy` is not type stable upstream + #n_dims*convert(eltype(location), entropy(dist)) + logdet(scale) +end + +function Distributions.logpdf(q::MvLowRankLocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + #sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) +end + +function Distributions.rand(q::MvLowRankLocationScale) + @unpack location, scale_diag, scale_factors, dist = q + n_dims = length(location) + n_factors = size(scale_factors, 2) + u_diag = rand(dist, n_dims) + u_fact = rand(dist, n_factors) + scale_diag.*u_diag + scale_factors*u_fact + location +end + +function Distributions.rand( + rng::AbstractRNG, q::MvLowRankLocationScale{S, D, L}, num_samples::Int +) where {S, D, L} + @unpack location, scale_diag, scale_factors, dist = q + n_dims = length(location) + n_factors = size(scale_factors, 2) + u_diag = rand(rng, dist, n_dims, num_samples) + u_fact = rand(rng, dist, n_factors, num_samples) + scale_diag.*u_diag + scale_factors*u_fact .+ location +end + +function Distributions._rand!( + rng::AbstractRNG, + q ::MvLowRankLocationScale, + x ::AbstractVecOrMat{<:Real} +) + @unpack location, scale_diag, scale_factors, dist = q + + n_factors = size(scale_factors, 2) + + rand!(rng, dist, x) + x[:] = scale_diag.*x + + u_fact = rand(dist, n_factors, size(x,2)) + x .+= scale_factors*u_fact + + return x .+= location +end + +Distributions.mean(q::MvLowRankLocationScale) = q.location + +function Distributions.var(q::MvLowRankLocationScale) + @unpack scale_diag, scale_factors = q + Diagonal(scale_diag + diag(scale_factors*scale_factors')) +end + +function Distributions.cov(q::MvLowRankLocationScale) + @unpack scale_diag, scale_factors = q + Diagonal(scale_diag) + scale_factors*scale_factors' +end + +function update_variational_params!( + ::Type{<:MvLowRankLocationScale}, opt_st, params, restructure, grad +) + opt_st, params = Optimisers.update!(opt_st, params, grad) + q = restructure(params) + ϵ = q.scale_eps + + # Project the scale matrix to the set of positive definite triangular matrices + @. q.scale_diag = max(q.scale_diag, ϵ) + + params, _ = Optimisers.destructure(q) + + opt_st, params +end From 0bd6e5cad103e28fe14039df4fc7fb47dcd7bfd4 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 5 Aug 2024 00:51:02 -0400 Subject: [PATCH 04/37] add feature complete `MvLocationScaleLowRank` with tests --- src/AdvancedVI.jl | 5 +- ...nk_scale.jl => location_scale_low_rank.jl} | 62 ++++--- test/interface/location_scale.jl | 166 ------------------ test/runtests.jl | 5 +- 4 files changed, 45 insertions(+), 193 deletions(-) rename src/families/{location_lowrank_scale.jl => location_scale_low_rank.jl} (57%) delete mode 100644 test/interface/location_scale.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index c8811163..c96032e6 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -191,10 +191,11 @@ include("objectives/elbo/repgradelbo.jl") export MvLocationScale, MeanFieldGaussian, - FullRankGaussian + FullRankGaussian, + MvLocationScaleLowRank include("families/location_scale.jl") -include("families/location_lowrank_scale.jl") +include("families/location_scale_low_rank.jl") # Optimization Routine diff --git a/src/families/location_lowrank_scale.jl b/src/families/location_scale_low_rank.jl similarity index 57% rename from src/families/location_lowrank_scale.jl rename to src/families/location_scale_low_rank.jl index fc0f19d5..a8cb3df4 100644 --- a/src/families/location_lowrank_scale.jl +++ b/src/families/location_scale_low_rank.jl @@ -14,7 +14,7 @@ represented as follows: z = scale_diag.*u_d + scale_factors*u_f + location ``` """ -struct MvLowRankLocationScale{ +struct MvLocationScaleLowRank{ L, SD <: AbstractVector, SF <: AbstractMatrix, @@ -28,37 +28,51 @@ struct MvLowRankLocationScale{ scale_eps ::E end -function MvLowRankLocationScale( +function MvLocationScaleLowRank( location ::AbstractVector{T}, scale_diag ::AbstractVector{T}, scale_factors::AbstractMatrix{T}, dist ::ContinuousDistribution; scale_eps ::T = sqrt(eps(T)) ) where {T <: Real} - MvLowRankLocationScale(location, scale_diag, scale_factors, dist, scale_eps) + @assert size(scale_factors,1) == length(scale_diag) + MvLocationScaleLowRank(location, scale_diag, scale_factors, dist, scale_eps) end -Functors.@functor MvLowRankLocationScale (location, scale_diag, scale_factors) +Functors.@functor MvLocationScaleLowRank (location, scale_diag, scale_factors) -Base.length(q::MvLowRankLocationScale) = length(q.location) +Base.length(q::MvLocationScaleLowRank) = length(q.location) -Base.size(q::MvLowRankLocationScale) = size(q.location) +Base.size(q::MvLocationScaleLowRank) = size(q.location) -Base.eltype(::Type{<:MvLowRankLocationScale{S, D, L}}) where {S, D, L} = eltype(D) +Base.eltype(::Type{<:MvLocationScaleLowRank{S, D, L}}) where {S, D, L} = eltype(D) -function StatsBase.entropy(q::MvLowRankLocationScale) - #@unpack location, scale, dist = q - #n_dims = length(location) - # `convert` is necessary because `entropy` is not type stable upstream - #n_dims*convert(eltype(location), entropy(dist)) + logdet(scale) +function StatsBase.entropy(q::MvLocationScaleLowRank) + @unpack location, scale_diag, scale_factors, dist = q + n_dims = length(location) + UtDinvU = Hermitian(scale_factors'*(scale_factors./scale_diag)) + logdetΣ = (sum(log.(scale_diag)) + logdet(I + UtDinvU))/2 + n_dims*convert(eltype(location), entropy(dist)) + logdetΣ end -function Distributions.logpdf(q::MvLowRankLocationScale, z::AbstractVector{<:Real}) - @unpack location, scale, dist = q - #sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale) +function Distributions.logpdf(q::MvLocationScaleLowRank, z::AbstractVector{<:Real}) + @unpack location, scale_diag, scale_factors, dist = q + # + ## More efficient O(kd^2) but non-differentiable version: + # + # Σchol = Cholesky(LowerTriangular(diagm(sqrt.(scale_diag)))) + # n_factors = size(scale_factors, 2) + # for k in 1:n_factors + # factor = scale_factors[:,k] + # lowrankupdate!(Σchol, factor) + # end + + Σ = Diagonal(scale_diag) + scale_factors*scale_factors' + Σchol = cholesky(Σ) + sum(Base.Fix1(logpdf, dist), Σchol.L \ (z - location)) - logdet(Σchol.L) end -function Distributions.rand(q::MvLowRankLocationScale) +function Distributions.rand(q::MvLocationScaleLowRank) @unpack location, scale_diag, scale_factors, dist = q n_dims = length(location) n_factors = size(scale_factors, 2) @@ -68,7 +82,7 @@ function Distributions.rand(q::MvLowRankLocationScale) end function Distributions.rand( - rng::AbstractRNG, q::MvLowRankLocationScale{S, D, L}, num_samples::Int + rng::AbstractRNG, q::MvLocationScaleLowRank{S, D, L}, num_samples::Int ) where {S, D, L} @unpack location, scale_diag, scale_factors, dist = q n_dims = length(location) @@ -80,7 +94,7 @@ end function Distributions._rand!( rng::AbstractRNG, - q ::MvLowRankLocationScale, + q ::MvLocationScaleLowRank, x ::AbstractVecOrMat{<:Real} ) @unpack location, scale_diag, scale_factors, dist = q @@ -90,26 +104,26 @@ function Distributions._rand!( rand!(rng, dist, x) x[:] = scale_diag.*x - u_fact = rand(dist, n_factors, size(x,2)) + u_fact = rand(rng, dist, n_factors, size(x,2)) x .+= scale_factors*u_fact return x .+= location end -Distributions.mean(q::MvLowRankLocationScale) = q.location +Distributions.mean(q::MvLocationScaleLowRank) = q.location -function Distributions.var(q::MvLowRankLocationScale) +function Distributions.var(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - Diagonal(scale_diag + diag(scale_factors*scale_factors')) + Diagonal(scale_diag + sum(scale_factors.^2, dims=2)[:,1]) end -function Distributions.cov(q::MvLowRankLocationScale) +function Distributions.cov(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q Diagonal(scale_diag) + scale_factors*scale_factors' end function update_variational_params!( - ::Type{<:MvLowRankLocationScale}, opt_st, params, restructure, grad + ::Type{<:MvLocationScaleLowRank}, opt_st, params, restructure, grad ) opt_st, params = Optimisers.update!(opt_st, params, grad) q = restructure(params) diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl deleted file mode 100644 index 7a129018..00000000 --- a/test/interface/location_scale.jl +++ /dev/null @@ -1,166 +0,0 @@ - -@testset "interface LocationScale" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for - basedist = [:gaussian], - covtype = [:meanfield, :fullrank], - realtype = [Float32, Float64] - - n_dims = 10 - n_montecarlo = 1000_000 - - μ = randn(realtype, n_dims) - L = if covtype == :fullrank - tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular - else - Diagonal(ones(realtype, n_dims)) - end - Σ = L*L' - - q = if covtype == :fullrank && basedist == :gaussian - FullRankGaussian(μ, L) - elseif covtype == :meanfield && basedist == :gaussian - MeanFieldGaussian(μ, L) - end - q_true = if basedist == :gaussian - MvNormal(μ, Σ) - end - - @testset "eltype" begin - @test eltype(q) == realtype - end - - @testset "logpdf" begin - z = rand(q) - @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) - @test eltype(logpdf(q, z)) == realtype - end - - @testset "entropy" begin - @test eltype(entropy(q)) == realtype - @test entropy(q) ≈ entropy(q_true) - end - - @testset "length" begin - @test length(q) == n_dims - end - - @testset "statistics" begin - @testset "mean" begin - @test eltype(mean(q)) == realtype - @test mean(q) == μ - end - @testset "var" begin - @test eltype(var(q)) == realtype - @test var(q) ≈ Diagonal(Σ) - end - @testset "cov" begin - @test eltype(cov(q)) == realtype - @test cov(q) ≈ Σ - end - end - - @testset "sampling" begin - @testset "rand" begin - z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) - @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) - - z_sample_ref = rand(StableRNG(1), q) - @test z_sample_ref == rand(StableRNG(1), q) - end - - @testset "rand batch" begin - z_samples = rand(q, n_montecarlo) - @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) - - samples_ref = rand(StableRNG(1), q, n_montecarlo) - @test samples_ref == rand(StableRNG(1), q, n_montecarlo) - end - - @testset "rand! AbstractVector" begin - res = map(1:n_montecarlo) do _ - z_sample = Array{realtype}(undef, n_dims) - z_sample_ret = rand!(q, z_sample) - (z_sample, z_sample_ret) - end - z_samples = mapreduce(first, hcat, res) - z_samples_ret = mapreduce(last, hcat, res) - @test z_samples == z_samples_ret - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) - - z_sample_ref = Array{realtype}(undef, n_dims) - rand!(StableRNG(1), q, z_sample_ref) - - z_sample = Array{realtype}(undef, n_dims) - rand!(StableRNG(1), q, z_sample) - @test z_sample_ref == z_sample - end - - @testset "rand! AbstractMatrix" begin - z_samples = Array{realtype}(undef, n_dims, n_montecarlo) - z_samples_ret = rand!(q, z_samples) - @test z_samples == z_samples_ret - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) - - z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) - rand!(StableRNG(1), q, z_samples_ref) - - z_samples = Array{realtype}(undef, n_dims, n_montecarlo) - rand!(StableRNG(1), q, z_samples) - @test z_samples_ref == z_samples - end - end - end - - @testset "Diagonal destructure" begin - n_dims = 10 - μ = zeros(n_dims) - L = ones(n_dims) - q = MeanFieldGaussian(μ, L |> Diagonal) - λ, re = Optimisers.destructure(q) - - @test length(λ) == 2*n_dims - @test q == re(λ) - end -end - -@testset "scale positive definite projection" begin - @testset "$(string(covtype)) $(realtype) $(bijector)" for - covtype = [:meanfield, :fullrank], - realtype = [Float32, Float64], - bijector = [nothing, :identity] - - d = 5 - μ = zeros(realtype, d) - ϵ = sqrt(realtype(0.5)) - q = if covtype == :fullrank - L = LowerTriangular(Matrix{realtype}(I,d,d)) - FullRankGaussian(μ, L; scale_eps=ϵ) - elseif covtype == :meanfield - L = Diagonal(ones(realtype, d)) - MeanFieldGaussian(μ, L; scale_eps=ϵ) - end - q_trans = if isnothing(bijector) - q - else - Bijectors.TransformedDistribution(q, identity) - end - g = deepcopy(q) - - λ, re = Optimisers.destructure(q) - grad, _ = Optimisers.destructure(g) - opt_st = Optimisers.setup(Descent(one(realtype)), λ) - _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) - q′ = re(λ′) - @test all(diag(var(q′)) .≥ ϵ^2) - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 3bd13144..4523aaf4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,9 +41,12 @@ if GROUP == "All" || GROUP == "Interface" include("interface/ad.jl") include("interface/optimize.jl") include("interface/repgradelbo.jl") - include("interface/location_scale.jl") end +if GROUP == "All" || GROUP == "Families" + include("families/location_scale.jl") + include("families/location_scale_low_rank.jl") +end const PROGRESS = haskey(ENV, "PROGRESS") From 34546e148c8cc4c5294dcc8c48024abf5aacb5dc Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 5 Aug 2024 00:54:18 -0400 Subject: [PATCH 05/37] fix remove misleading comment --- src/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index a8cb3df4..6ab732dc 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -129,7 +129,7 @@ function update_variational_params!( q = restructure(params) ϵ = q.scale_eps - # Project the scale matrix to the set of positive definite triangular matrices + # Clip diagonal to guarantee positive definite covariance @. q.scale_diag = max(q.scale_diag, ϵ) params, _ = Optimisers.destructure(q) From e030f2d0f03128be3b66b65027b8017302d1343a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 5 Aug 2024 01:08:06 -0400 Subject: [PATCH 06/37] fix add missing test files --- test/families/location_scale.jl | 167 +++++++++++++++++++++++ test/families/location_scale_low_rank.jl | 150 ++++++++++++++++++++ 2 files changed, 317 insertions(+) create mode 100644 test/families/location_scale.jl create mode 100644 test/families/location_scale_low_rank.jl diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl new file mode 100644 index 00000000..22cffd89 --- /dev/null +++ b/test/families/location_scale.jl @@ -0,0 +1,167 @@ + +@testset "interface LocationScale" begin + @testset "$(string(covtype)) $(basedist) $(realtype)" for + basedist = [:gaussian], + covtype = [:meanfield, :fullrank], + realtype = [Float32, Float64] + + n_dims = 10 + n_montecarlo = 1000_000 + + μ = randn(realtype, n_dims) + L = if covtype == :fullrank + tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular + else + Diagonal(ones(realtype, n_dims)) + end + Σ = L*L' + + q = if covtype == :fullrank && basedist == :gaussian + FullRankGaussian(μ, L) + elseif covtype == :meanfield && basedist == :gaussian + MeanFieldGaussian(μ, L) + end + q_true = if basedist == :gaussian + MvNormal(μ, Σ) + end + + @testset "eltype" begin + @test eltype(q) == realtype + end + + @testset "logpdf" begin + z = rand(q) + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype + end + + @testset "entropy" begin + @test eltype(entropy(q)) == realtype + @test entropy(q) ≈ entropy(q_true) + end + + @testset "length" begin + @test length(q) == n_dims + end + + @testset "statistics" begin + @testset "mean" begin + @test eltype(mean(q)) == realtype + @test mean(q) == μ + end + @testset "var" begin + @test eltype(var(q)) == realtype + @test var(q) ≈ Diagonal(Σ) + end + @testset "cov" begin + @test eltype(cov(q)) == realtype + @test cov(q) ≈ Σ + end + end + + @testset "sampling" begin + @testset "rand" begin + z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_sample_ref = rand(StableRNG(1), q) + @test z_sample_ref == rand(StableRNG(1), q) + end + + @testset "rand batch" begin + z_samples = rand(q, n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + samples_ref = rand(StableRNG(1), q, n_montecarlo) + @test samples_ref == rand(StableRNG(1), q, n_montecarlo) + end + + @testset "rand! AbstractVector" begin + res = map(1:n_montecarlo) do _ + z_sample = Array{realtype}(undef, n_dims) + z_sample_ret = rand!(q, z_sample) + (z_sample, z_sample_ret) + end + z_samples = mapreduce(first, hcat, res) + z_samples_ret = mapreduce(last, hcat, res) + @test z_samples == z_samples_ret + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_sample_ref = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample_ref) + + z_sample = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample) + @test z_sample_ref == z_sample + end + + @testset "rand! AbstractMatrix" begin + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + z_samples_ret = rand!(q, z_samples) + @test z_samples == z_samples_ret + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples_ref) + + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples) + @test z_samples_ref == z_samples + end + end + end + + @testset "scale positive definite projection" begin + @testset "$(string(covtype)) $(realtype) $(bijector)" for + covtype = [:meanfield, :fullrank], + realtype = [Float32, Float64], + bijector = [nothing, :identity] + + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + q = if covtype == :fullrank + L = LowerTriangular(Matrix{realtype}(I,d,d)) + FullRankGaussian(μ, L; scale_eps=ϵ) + elseif covtype == :meanfield + L = Diagonal(ones(realtype, d)) + MeanFieldGaussian(μ, L; scale_eps=ϵ) + end + q_trans = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, identity) + end + g = deepcopy(q) + + λ, re = Optimisers.destructure(q) + grad, _ = Optimisers.destructure(g) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) + @test all(diag(var(q′)) .≥ ϵ^2) + end + end + + @testset "Diagonal destructure" begin + n_dims = 10 + μ = zeros(n_dims) + L = ones(n_dims) + q = MeanFieldGaussian(μ, L |> Diagonal) + λ, re = Optimisers.destructure(q) + + @test length(λ) == 2*n_dims + @test q == re(λ) + end +end + diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl new file mode 100644 index 00000000..0e12e315 --- /dev/null +++ b/test/families/location_scale_low_rank.jl @@ -0,0 +1,150 @@ + +@testset "interface LocationScaleLowRank" begin + @testset "$(basedist) rank=$(rank) $(realtype)" for + basedist = [:gaussian], + rank = [1, 2], + realtype = [Float32, Float64] + + n_dims = 10 + n_montecarlo = 1000_000 + + μ = randn(realtype, n_dims) + D = ones(realtype, n_dims) + U = randn(realtype, n_dims, rank) + Σ = Diagonal(D) + U*U' + + q = if basedist == :gaussian + MvLocationScaleLowRank( + μ, D, U, Normal{realtype}(zero(realtype), one(realtype)) + ) + end + q_true = if basedist == :gaussian + MvNormal(μ, Σ) + end + + @testset "eltype" begin + @test eltype(q) == realtype + end + + @testset "logpdf" begin + z = rand(q) + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype + end + + @testset "entropy" begin + @test eltype(entropy(q)) == realtype + @test entropy(q) ≈ entropy(q_true) + end + + @testset "length" begin + @test length(q) == n_dims + end + + @testset "statistics" begin + @testset "mean" begin + @test eltype(mean(q)) == realtype + @test mean(q) == μ + end + @testset "var" begin + @test eltype(var(q)) == realtype + @test var(q) ≈ Diagonal(Σ) + end + @testset "cov" begin + @test eltype(cov(q)) == realtype + @test cov(q) ≈ Σ + end + end + + @testset "sampling" begin + @testset "rand" begin + z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_sample_ref = rand(StableRNG(1), q) + @test z_sample_ref == rand(StableRNG(1), q) + end + + @testset "rand batch" begin + z_samples = rand(q, n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + samples_ref = rand(StableRNG(1), q, n_montecarlo) + @test samples_ref == rand(StableRNG(1), q, n_montecarlo) + end + + @testset "rand! AbstractVector" begin + res = map(1:n_montecarlo) do _ + z_sample = Array{realtype}(undef, n_dims) + z_sample_ret = rand!(q, z_sample) + (z_sample, z_sample_ret) + end + z_samples = mapreduce(first, hcat, res) + z_samples_ret = mapreduce(last, hcat, res) + @test z_samples == z_samples_ret + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_sample_ref = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample_ref) + + z_sample = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample) + @test z_sample_ref == z_sample + end + + @testset "rand! AbstractMatrix" begin + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + z_samples_ret = rand!(q, z_samples) + @test z_samples == z_samples_ret + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples_ref) + + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples) + @test z_samples_ref == z_samples + end + end + end + + @testset "diagonal positive definite projection" begin + @testset "$(realtype) $(bijector)" for + realtype = [Float32, Float64], + bijector = [nothing, :identity] + + rank = 2 + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + D = ones(realtype, d) + U = randn(realtype, d, rank) + q = MvLocationScaleLowRank( + μ, D, U, Normal{realtype}(zero(realtype), one(realtype)); scale_eps=ϵ + ) + q_trans = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, identity) + end + g = deepcopy(q) + + λ, re = Optimisers.destructure(q) + grad, _ = Optimisers.destructure(g) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) + @test all(diag(var(q′)) .≥ ϵ^2) + end + end +end From c7f36d65329d2c509b11bea690e12618bd8256b2 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 5 Aug 2024 01:59:29 -0400 Subject: [PATCH 07/37] fix broadcasting error on Julia 1.6 --- src/families/location_scale_low_rank.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index 6ab732dc..81a28154 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -99,13 +99,11 @@ function Distributions._rand!( ) @unpack location, scale_diag, scale_factors, dist = q - n_factors = size(scale_factors, 2) - rand!(rng, dist, x) x[:] = scale_diag.*x - u_fact = rand(rng, dist, n_factors, size(x,2)) - x .+= scale_factors*u_fact + u_fact = rand(rng, dist, size(scale_factors, 2), size(x,2)) + x[:,:] += scale_factors*u_fact return x .+= location end From 1bb3e3eae30ab917692cc7d32a0c31ae3b3d679f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 7 Aug 2024 02:19:51 -0400 Subject: [PATCH 08/37] fix bug in sampling from `LocationScaleLowRank` --- src/families/location_scale_low_rank.jl | 13 +++++++------ test/families/location_scale_low_rank.jl | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index 81a28154..93192f74 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -49,10 +49,11 @@ Base.eltype(::Type{<:MvLocationScaleLowRank{S, D, L}}) where {S, D, L} = eltype( function StatsBase.entropy(q::MvLocationScaleLowRank) @unpack location, scale_diag, scale_factors, dist = q - n_dims = length(location) - UtDinvU = Hermitian(scale_factors'*(scale_factors./scale_diag)) - logdetΣ = (sum(log.(scale_diag)) + logdet(I + UtDinvU))/2 - n_dims*convert(eltype(location), entropy(dist)) + logdetΣ + n_dims = length(location) + scale_diag2 = scale_diag.*scale_diag + UtDinvU = Hermitian(scale_factors'*(scale_factors./scale_diag2)) + logdetΣ = 2*sum(log.(scale_diag)) + logdet(I + UtDinvU) + n_dims*convert(eltype(location), entropy(dist)) + logdetΣ/2 end function Distributions.logpdf(q::MvLocationScaleLowRank, z::AbstractVector{<:Real}) @@ -112,12 +113,12 @@ Distributions.mean(q::MvLocationScaleLowRank) = q.location function Distributions.var(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - Diagonal(scale_diag + sum(scale_factors.^2, dims=2)[:,1]) + Diagonal(scale_diag.^2 + sum(scale_factors.^2, dims=2)[:,1]) end function Distributions.cov(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - Diagonal(scale_diag) + scale_factors*scale_factors' + Diagonal(scale_diag.^2) + scale_factors*scale_factors' end function update_variational_params!( diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 0e12e315..591524e1 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -11,7 +11,7 @@ μ = randn(realtype, n_dims) D = ones(realtype, n_dims) U = randn(realtype, n_dims, rank) - Σ = Diagonal(D) + U*U' + Σ = Diagonal(D.^2) + U*U' q = if basedist == :gaussian MvLocationScaleLowRank( From ddd212268edfb4babf30cb3c3800f54bd7080eda Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 7 Aug 2024 02:24:31 -0400 Subject: [PATCH 09/37] fix missing squared bug in `LocationScaleLowRank` --- src/families/location_scale_low_rank.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index 93192f74..e60a419b 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -68,7 +68,7 @@ function Distributions.logpdf(q::MvLocationScaleLowRank, z::AbstractVector{<:Rea # lowrankupdate!(Σchol, factor) # end - Σ = Diagonal(scale_diag) + scale_factors*scale_factors' + Σ = Diagonal(scale_diag.*scale_diag) + scale_factors*scale_factors' Σchol = cholesky(Σ) sum(Base.Fix1(logpdf, dist), Σchol.L \ (z - location)) - logdet(Σchol.L) end @@ -113,12 +113,12 @@ Distributions.mean(q::MvLocationScaleLowRank) = q.location function Distributions.var(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - Diagonal(scale_diag.^2 + sum(scale_factors.^2, dims=2)[:,1]) + Diagonal(scale_diag.*scale_diag + sum(scale_factors.*scale_factors, dims=2)[:,1]) end function Distributions.cov(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - Diagonal(scale_diag.^2) + scale_factors*scale_factors' + Diagonal(scale_diag.*scale_diag) + scale_factors*scale_factors' end function update_variational_params!( From b24737f92bed37ed2b97c4f79aad5c659a5b550f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 9 Aug 2024 01:09:40 -0400 Subject: [PATCH 10/37] add documentation for low-rank families --- docs/make.jl | 6 +- docs/src/elbo/families.md | 122 ++++++++++++++++++++++++++++++++++++++ docs/src/locscale.md | 80 ------------------------- 3 files changed, 125 insertions(+), 83 deletions(-) create mode 100644 docs/src/elbo/families.md delete mode 100644 docs/src/locscale.md diff --git a/docs/make.jl b/docs/make.jl index 7ae3bc62..1abe8f22 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -15,9 +15,9 @@ makedocs(; "General Usage" => "general.md", "Examples" => "examples.md", "ELBO Maximization" => [ - "Overview" => "elbo/overview.md", - "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md", - "Location-Scale Variational Family" => "locscale.md", + "Overview" => "elbo/overview.md", + "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md", + "Variational Families" => "elbo/families.md", ]], ) diff --git a/docs/src/elbo/families.md b/docs/src/elbo/families.md new file mode 100644 index 00000000..e827a0d0 --- /dev/null +++ b/docs/src/elbo/families.md @@ -0,0 +1,122 @@ + +# [Reparameterizable Variational Families](@id families) +The [RepGradELBO](@ref repgradelbo) objective assumes that the members of the variational family have a differentiable sampling path. +We provide multiple pre-packaged variational families that can be readily used. + +## The `LocationScale` Family +The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as +```math +z \sim q_{\lambda} \qquad\Leftrightarrow\qquad +z \stackrel{d}{=} C u + m;\quad u \sim \varphi +``` +where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. +``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. +The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. + +The probability density is given by +```math + q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)), +``` +the covariance is given as +```math + \mathrm{Var}\left(q_{\lambda}\right) = C \mathrm{Var}(q_{\lambda}) C^{\top} +``` +and the entropy is given as +```math + \mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |C|, +``` +where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. +Notice the ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``. +The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. + +!!! note + For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. + Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities. + +```@docs +MvLocationScale +``` + +The following are specialized constructors for convenience: +```@docs +FullRankGaussian +MeanFieldGaussian +``` + +### Gaussian Variational Families +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); + +L = diagm(ones(2)) |> LowerTriangular; +q = FullRankGaussian(μ, L) + +L = ones(2) |> Diagonal; +q = MeanFieldGaussian(μ, L) +``` + +### Sudent-$$t$$ Variational Families +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); +ν = 3; + +# Full-Rank +L = diagm(ones(2)) |> LowerTriangular; +q = MvLocationScale(μ, L, TDist(ν)) + +# Mean-Field +L = ones(2) |> Diagonal; +q = MvLocationScale(μ, L, TDist(ν)) +``` + +### Laplace Variational families +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); + +# Full-Rank +L = diagm(ones(2)) |> LowerTriangular; +q = MvLocationScale(μ, L, Laplace()) + +# Mean-Field +L = ones(2) |> Diagonal; +q = MvLocationScale(μ, L, Laplace()) +``` + +## The `LocationScaleLowRank` Family +In practice, `LocationScale` families with full-rank scale matrices are known to converge slowly as they require a small SGD stepsize. +Low-rank variational families can be an effective alternative[^ONS2018]. +`LocationScaleLowRank` generally represent any ``d``-dimensional distribution which its sampling path can be represented as +```math +z \sim q_{\lambda} \qquad\Leftrightarrow\qquad +z \stackrel{d}{=} D u_1 + U u_2 + m;\quad u_1, u_2 \sim \varphi +``` +where ``D \in \mathbb{R}^{d \times d}`` is a diagonal matrix, ``U \in \mathbb{R}^{d \times r}`` is a dense low-rank matrix for the rank ``r > 0``, ``m \in \mathbb{R}^d`` is the location, and ``\varphi`` is the *base distribution*. +``m``, ``D``, and ``U`` form the variational parameters ``\lambda = (m, D, U)``. + +The covariance of this distribution is given as +```math + \mathrm{Var}\left(q_{\lambda}\right) = D \mathrm{Var}(\varphi) D + U \mathrm{Var}(\varphi) U^{\top} +``` +and the entropy is given by the matrix determinant lemma as +```math + \mathbb{H}(q_{\lambda}) + = \mathbb{H}(\varphi) + \log |\Sigma| + = \mathbb{H}(\varphi) + 2 \log |D| + \log |I + U^{\top} D^{-2} U|, +``` +where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. + +!!! note + `logpdf` for `LocationScaleLowRank` is unfortunately not computationally efficient and has the same time complexity as `LocationScale` with a full-rank scale. + +```@docs +MvLocationScaleLowRank +``` + +The following is a specialized constructor for convenience: +```@docs +LowRankGaussian +``` + +[^ONS2018]: Ong, V. M. H., Nott, D. J., & Smith, M. S. (2018). Gaussian variational approximation with a factor covariance structure. Journal of Computational and Graphical Statistics, 27(3), 465-478. diff --git a/docs/src/locscale.md b/docs/src/locscale.md deleted file mode 100644 index 643c3a98..00000000 --- a/docs/src/locscale.md +++ /dev/null @@ -1,80 +0,0 @@ - -# [Location-Scale Variational Family](@id locscale) - -## Introduction -The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as -```math -z \sim q_{\lambda} \qquad\Leftrightarrow\qquad -z \stackrel{d}{=} C u + m;\quad u \sim \varphi -``` -where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. -``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. -The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. - -The probability density is given by -```math - q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)) -``` -and the entropy is given as -```math - \mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |C|, -``` -where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. -Notice the ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``. -The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. - -## Constructors - -!!! note - For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. - Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities. - -```@docs -MvLocationScale -``` - -```@docs -FullRankGaussian -MeanFieldGaussian -``` - -## Gaussian Variational Families -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); - -L = diagm(ones(2)) |> LowerTriangular; -q = FullRankGaussian(μ, L) - -L = ones(2) |> Diagonal; -q = MeanFieldGaussian(μ, L) -``` - -## Sudent-$$t$$ Variational Families -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); -ν = 3; - -# Full-Rank -L = diagm(ones(2)) |> LowerTriangular; -q = MvLocationScale(μ, L, TDist(ν)) - -# Mean-Field -L = ones(2) |> Diagonal; -q = MvLocationScale(μ, L, TDist(ν)) -``` - -## Laplace Variational families -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); - -# Full-Rank -L = diagm(ones(2)) |> LowerTriangular; -q = MvLocationScale(μ, L, Laplace()) - -# Mean-Field -L = ones(2) |> Diagonal; -q = MvLocationScale(μ, L, Laplace()) -``` From 1d56953b1d7523d7504516701de167a3b0915265 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 9 Aug 2024 01:09:51 -0400 Subject: [PATCH 11/37] add convenience constructors for `LocationScaleLowRank` --- src/AdvancedVI.jl | 8 ++++++-- src/families/location_scale.jl | 12 ++++++------ src/families/location_scale_low_rank.jl | 25 ++++++++++++++++++++++++ test/families/location_scale_low_rank.jl | 4 +--- 4 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index c96032e6..36cc9410 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -191,10 +191,14 @@ include("objectives/elbo/repgradelbo.jl") export MvLocationScale, MeanFieldGaussian, - FullRankGaussian, - MvLocationScaleLowRank + FullRankGaussian include("families/location_scale.jl") + +export + MvLocationScaleLowRank, + LowRankGaussian + include("families/location_scale_low_rank.jl") diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 66ea5cdb..92dd2bf6 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -120,13 +120,13 @@ function Distributions.cov(q::MvLocationScale) end """ - FullRankGaussian(location, scale; check_args = true) + FullRankGaussian(μ, L; check_args = true) Construct a Gaussian variational approximation with a dense covariance matrix. # Arguments -- `location::AbstractVector{T}`: Mean of the Gaussian. -- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian. +- `μ::AbstractVector{T}`: Mean of the Gaussian. +- `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian. # Keyword Arguments - `check_args`: Check the conditioning of the initial scale (default: `true`). @@ -142,13 +142,13 @@ function FullRankGaussian( end """ - MeanFieldGaussian(location, scale; check_args = true) + MeanFieldGaussian(μ, L; check_args = true) Construct a Gaussian variational approximation with a diagonal covariance matrix. # Arguments -- `location::AbstractVector{T}`: Mean of the Gaussian. -- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian. +- `μ::AbstractVector{T}`: Mean of the Gaussian. +- `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian. # Keyword Arguments - `check_args`: Check the conditioning of the initial scale (default: `true`). diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index e60a419b..08c5aa25 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -3,6 +3,7 @@ MvLocationLowRankScale(location, scale_diag, scale_factors, dist) <: ContinuousMultivariateDistribution Variational family with a covariance in the form of a diagonal matrix plus a squared low-rank matrix. +The rank is given by `size(scale_factors, 2)`. It generally represents any distribution for which the sampling path can be represented as follows: @@ -135,3 +136,27 @@ function update_variational_params!( opt_st, params end + +""" + LowRankGaussian(location, scale_diag, scale_factors; check_args = true) + +Construct a Gaussian variational approximation with a diagonal plus low-rank covariance matrix. + +# Arguments +- `μ::AbstractVector{T}`: Mean of the Gaussian. +- `D::Vector{T}`: Diagonal of the scale. +- `U::Matrix{T}`: Low-rank factors of the scale, where `size(U,2)` is the rank. + +# Keyword Arguments +- `check_args`: Check the conditioning of the initial scale (default: `true`). +""" +function LowRankGaussian( + μ::AbstractVector{T}, + D::Vector{T}, + U::Matrix{T}; + scale_eps::T = sqrt(eps(T)) +) where {T <: Real} + @assert minimum(D) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal scale value is $(minimum(D)). This might result in unstable optimization behavior." + q_base = Normal{T}(zero(T), one(T)) + MvLocationScaleLowRank(μ, D, U, q_base, scale_eps) +end diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 591524e1..dc9fc2d1 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -14,9 +14,7 @@ Σ = Diagonal(D.^2) + U*U' q = if basedist == :gaussian - MvLocationScaleLowRank( - μ, D, U, Normal{realtype}(zero(realtype), one(realtype)) - ) + LowRankGaussian(μ, D, U) end q_true = if basedist == :gaussian MvNormal(μ, Σ) From 52568b58e43cc8e1697bbced54ce2ae78461d99a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 10 Aug 2024 02:41:38 -0400 Subject: [PATCH 12/37] fix mhauru's suggestions and run formatter --- src/families/location_scale.jl | 19 +++-- src/families/location_scale_low_rank.jl | 93 +++++++++++------------- test/families/location_scale_low_rank.jl | 8 +- 3 files changed, 56 insertions(+), 64 deletions(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 5a5767e1..d73aee9f 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -27,6 +27,7 @@ function MvLocationScale( dist::ContinuousDistribution; scale_eps::T=sqrt(eps(T)), ) where {T<:Real} + @assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." return MvLocationScale(location, scale, dist, scale_eps) end @@ -37,8 +38,8 @@ Functors.@functor MvLocationScale (location, scale) # `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD # is very inefficient. # begin -struct RestructureMeanField{S<:Diagonal,D,L} - q::MvLocationScale{S,D,L} +struct RestructureMeanField{S<:Diagonal,D,L,E} + q::MvLocationScale{S,D,L,E} end function (re::RestructureMeanField)(flat::AbstractVector) @@ -48,7 +49,7 @@ function (re::RestructureMeanField)(flat::AbstractVector) return MvLocationScale(location, scale, re.q.dist, re.q.scale_eps) end -function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L} +function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E} @unpack location, scale, dist = q flat = vcat(location, diag(scale)) return flat, RestructureMeanField(q) @@ -59,7 +60,7 @@ Base.length(q::MvLocationScale) = length(q.location) Base.size(q::MvLocationScale) = size(q.location) -Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D) +Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D) function StatsBase.entropy(q::MvLocationScale) @unpack location, scale, dist = q @@ -119,7 +120,7 @@ function Distributions.cov(q::MvLocationScale) end """ - FullRankGaussian(μ, L; check_args = true) + FullRankGaussian(μ, L; scale_eps) Construct a Gaussian variational approximation with a dense covariance matrix. @@ -128,18 +129,17 @@ Construct a Gaussian variational approximation with a dense covariance matrix. - `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian. # Keyword Arguments -- `check_args`: Check the conditioning of the initial scale (default: `true`). +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). """ function FullRankGaussian( μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=sqrt(eps(T)) ) where {T<:Real} - @assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) return MvLocationScale(μ, L, q_base, scale_eps) end """ - MeanFieldGaussian(μ, L; check_args = true) + MeanFieldGaussian(μ, L; scale_eps) Construct a Gaussian variational approximation with a diagonal covariance matrix. @@ -148,12 +148,11 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix - `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian. # Keyword Arguments -- `check_args`: Check the conditioning of the initial scale (default: `true`). +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). """ function MeanFieldGaussian( μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=sqrt(eps(T)) ) where {T<:Real} - @assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) return MvLocationScale(μ, L, q_base, scale_eps) end diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index 08c5aa25..d72a4e2c 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -16,28 +16,24 @@ represented as follows: ``` """ struct MvLocationScaleLowRank{ - L, - SD <: AbstractVector, - SF <: AbstractMatrix, - D <: ContinuousDistribution, - E <: Real + L,SD<:AbstractVector,SF<:AbstractMatrix,D<:ContinuousDistribution,E<:Real } <: ContinuousMultivariateDistribution - location ::L - scale_diag ::SD + location::L + scale_diag::SD scale_factors::SF - dist ::D - scale_eps ::E + dist::D + scale_eps::E end function MvLocationScaleLowRank( - location ::AbstractVector{T}, - scale_diag ::AbstractVector{T}, + location::AbstractVector{T}, + scale_diag::AbstractVector{T}, scale_factors::AbstractMatrix{T}, - dist ::ContinuousDistribution; - scale_eps ::T = sqrt(eps(T)) -) where {T <: Real} - @assert size(scale_factors,1) == length(scale_diag) - MvLocationScaleLowRank(location, scale_diag, scale_factors, dist, scale_eps) + dist::ContinuousDistribution; + scale_eps::T=sqrt(eps(T)), +) where {T<:Real} + @assert size(scale_factors, 1) == length(scale_diag) + return MvLocationScaleLowRank(location, scale_diag, scale_factors, dist, scale_eps) end Functors.@functor MvLocationScaleLowRank (location, scale_diag, scale_factors) @@ -46,15 +42,15 @@ Base.length(q::MvLocationScaleLowRank) = length(q.location) Base.size(q::MvLocationScaleLowRank) = size(q.location) -Base.eltype(::Type{<:MvLocationScaleLowRank{S, D, L}}) where {S, D, L} = eltype(D) +Base.eltype(::Type{<:MvLocationScaleLowRank{L,SD,SF,D,E}}) where {L,SD,SF,D,E} = eltype(L) function StatsBase.entropy(q::MvLocationScaleLowRank) @unpack location, scale_diag, scale_factors, dist = q - n_dims = length(location) - scale_diag2 = scale_diag.*scale_diag - UtDinvU = Hermitian(scale_factors'*(scale_factors./scale_diag2)) - logdetΣ = 2*sum(log.(scale_diag)) + logdet(I + UtDinvU) - n_dims*convert(eltype(location), entropy(dist)) + logdetΣ/2 + n_dims = length(location) + scale_diag2 = scale_diag .* scale_diag + UtDinvU = Hermitian(scale_factors' * (scale_factors ./ scale_diag2)) + logdetΣ = 2 * sum(log.(scale_diag)) + logdet(I + UtDinvU) + return n_dims * convert(eltype(location), entropy(dist)) + logdetΣ / 2 end function Distributions.logpdf(q::MvLocationScaleLowRank, z::AbstractVector{<:Real}) @@ -69,57 +65,57 @@ function Distributions.logpdf(q::MvLocationScaleLowRank, z::AbstractVector{<:Rea # lowrankupdate!(Σchol, factor) # end - Σ = Diagonal(scale_diag.*scale_diag) + scale_factors*scale_factors' + Σ = Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors' Σchol = cholesky(Σ) - sum(Base.Fix1(logpdf, dist), Σchol.L \ (z - location)) - logdet(Σchol.L) + return sum(Base.Fix1(logpdf, dist), Σchol.L \ (z - location)) - logdet(Σchol.L) end function Distributions.rand(q::MvLocationScaleLowRank) @unpack location, scale_diag, scale_factors, dist = q - n_dims = length(location) + n_dims = length(location) n_factors = size(scale_factors, 2) - u_diag = rand(dist, n_dims) - u_fact = rand(dist, n_factors) - scale_diag.*u_diag + scale_factors*u_fact + location + u_diag = rand(dist, n_dims) + u_fact = rand(dist, n_factors) + return scale_diag .* u_diag + scale_factors * u_fact + location end function Distributions.rand( - rng::AbstractRNG, q::MvLocationScaleLowRank{S, D, L}, num_samples::Int -) where {S, D, L} + rng::AbstractRNG, q::MvLocationScaleLowRank{S,D,L}, num_samples::Int +) where {S,D,L} @unpack location, scale_diag, scale_factors, dist = q - n_dims = length(location) + n_dims = length(location) n_factors = size(scale_factors, 2) - u_diag = rand(rng, dist, n_dims, num_samples) - u_fact = rand(rng, dist, n_factors, num_samples) - scale_diag.*u_diag + scale_factors*u_fact .+ location + u_diag = rand(rng, dist, n_dims, num_samples) + u_fact = rand(rng, dist, n_factors, num_samples) + return scale_diag .* u_diag + scale_factors * u_fact .+ location end function Distributions._rand!( - rng::AbstractRNG, - q ::MvLocationScaleLowRank, - x ::AbstractVecOrMat{<:Real} + rng::AbstractRNG, q::MvLocationScaleLowRank, x::AbstractVecOrMat{<:Real} ) @unpack location, scale_diag, scale_factors, dist = q rand!(rng, dist, x) - x[:] = scale_diag.*x + x[:] = scale_diag .* x - u_fact = rand(rng, dist, size(scale_factors, 2), size(x,2)) - x[:,:] += scale_factors*u_fact + u_fact = rand(rng, dist, size(scale_factors, 2), size(x, 2)) + x[:, :] += scale_factors * u_fact return x .+= location end Distributions.mean(q::MvLocationScaleLowRank) = q.location -function Distributions.var(q::MvLocationScaleLowRank) +function Distributions.var(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - Diagonal(scale_diag.*scale_diag + sum(scale_factors.*scale_factors, dims=2)[:,1]) + return Diagonal( + scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1] + ) end function Distributions.cov(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - Diagonal(scale_diag.*scale_diag) + scale_factors*scale_factors' + return Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors' end function update_variational_params!( @@ -134,7 +130,7 @@ function update_variational_params!( params, _ = Optimisers.destructure(q) - opt_st, params + return opt_st, params end """ @@ -151,12 +147,9 @@ Construct a Gaussian variational approximation with a diagonal plus low-rank cov - `check_args`: Check the conditioning of the initial scale (default: `true`). """ function LowRankGaussian( - μ::AbstractVector{T}, - D::Vector{T}, - U::Matrix{T}; - scale_eps::T = sqrt(eps(T)) -) where {T <: Real} + μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=sqrt(eps(T)) +) where {T<:Real} @assert minimum(D) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal scale value is $(minimum(D)). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - MvLocationScaleLowRank(μ, D, U, q_base, scale_eps) + return MvLocationScaleLowRank(μ, D, U, q_base, scale_eps) end diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index dc9fc2d1..057bfe18 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -1,9 +1,9 @@ @testset "interface LocationScaleLowRank" begin @testset "$(basedist) rank=$(rank) $(realtype)" for - basedist = [:gaussian], - rank = [1, 2], - realtype = [Float32, Float64] + basedist in [:gaussian], + rank in [1, 2], + realtype in [Float32, Float64] n_dims = 10 n_montecarlo = 1000_000 @@ -133,7 +133,7 @@ q_trans = if isnothing(bijector) q else - Bijectors.TransformedDistribution(q, identity) + Bijectors.TransformedDistribution(q, bijector) end g = deepcopy(q) From 96eae86df0128105907caaa797d4d13a4e292cb5 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 10 Aug 2024 02:44:10 -0400 Subject: [PATCH 13/37] run formatter --- docs/make.jl | 30 ++++----- test/families/location_scale_low_rank.jl | 84 +++++++++++++----------- 2 files changed, 60 insertions(+), 54 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 1abe8f22..de755d13 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,23 +2,23 @@ using AdvancedVI using Documenter -DocMeta.setdocmeta!( - AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true -) +DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true) makedocs(; - modules = [AdvancedVI], - sitename = "AdvancedVI.jl", - repo = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}", - format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), - pages = ["AdvancedVI" => "index.md", - "General Usage" => "general.md", - "Examples" => "examples.md", - "ELBO Maximization" => [ - "Overview" => "elbo/overview.md", - "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md", - "Variational Families" => "elbo/families.md", - ]], + modules=[AdvancedVI], + sitename="AdvancedVI.jl", + repo="https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}", + format=Documenter.HTML(; prettyurls=get(ENV, "CI", nothing) == "true"), + pages=[ + "AdvancedVI" => "index.md", + "General Usage" => "general.md", + "Examples" => "examples.md", + "ELBO Maximization" => [ + "Overview" => "elbo/overview.md", + "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md", + "Variational Families" => "elbo/families.md", + ], + ], ) deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 057bfe18..094963be 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -1,17 +1,16 @@ @testset "interface LocationScaleLowRank" begin - @testset "$(basedist) rank=$(rank) $(realtype)" for - basedist in [:gaussian], - rank in [1, 2], + @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in [:gaussian], + rank in [1, 2], realtype in [Float32, Float64] - n_dims = 10 + n_dims = 10 n_montecarlo = 1000_000 μ = randn(realtype, n_dims) D = ones(realtype, n_dims) U = randn(realtype, n_dims, rank) - Σ = Diagonal(D.^2) + U*U' + Σ = Diagonal(D .^ 2) + U * U' q = if basedist == :gaussian LowRankGaussian(μ, D, U) @@ -26,13 +25,13 @@ @testset "logpdf" begin z = rand(q) - @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) - @test eltype(logpdf(q, z)) == realtype + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol = realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype end @testset "entropy" begin @test eltype(entropy(q)) == realtype - @test entropy(q) ≈ entropy(q_true) + @test entropy(q) ≈ entropy(q_true) end @testset "length" begin @@ -41,37 +40,41 @@ @testset "statistics" begin @testset "mean" begin - @test eltype(mean(q)) == realtype - @test mean(q) == μ + @test eltype(mean(q)) == realtype + @test mean(q) == μ end @testset "var" begin - @test eltype(var(q)) == realtype - @test var(q) ≈ Diagonal(Σ) + @test eltype(var(q)) == realtype + @test var(q) ≈ Diagonal(Σ) end @testset "cov" begin - @test eltype(cov(q)) == realtype - @test cov(q) ≈ Σ + @test eltype(cov(q)) == realtype + @test cov(q) ≈ Σ end end @testset "sampling" begin @testset "rand" begin - z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) + z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_sample_ref = rand(StableRNG(1), q) @test z_sample_ref == rand(StableRNG(1), q) end @testset "rand batch" begin - z_samples = rand(q, n_montecarlo) + z_samples = rand(q, n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) samples_ref = rand(StableRNG(1), q, n_montecarlo) @test samples_ref == rand(StableRNG(1), q, n_montecarlo) @@ -79,16 +82,18 @@ @testset "rand! AbstractVector" begin res = map(1:n_montecarlo) do _ - z_sample = Array{realtype}(undef, n_dims) + z_sample = Array{realtype}(undef, n_dims) z_sample_ret = rand!(q, z_sample) (z_sample, z_sample_ret) end - z_samples = mapreduce(first, hcat, res) + z_samples = mapreduce(first, hcat, res) z_samples_ret = mapreduce(last, hcat, res) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_sample_ref = Array{realtype}(undef, n_dims) rand!(StableRNG(1), q, z_sample_ref) @@ -99,12 +104,14 @@ end @testset "rand! AbstractMatrix" begin - z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) z_samples_ret = rand!(q, z_samples) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) rand!(StableRNG(1), q, z_samples_ref) @@ -117,9 +124,8 @@ end @testset "diagonal positive definite projection" begin - @testset "$(realtype) $(bijector)" for - realtype = [Float32, Float64], - bijector = [nothing, :identity] + @testset "$(realtype) $(bijector)" for realtype in [Float32, Float64], + bijector in [nothing, :identity] rank = 2 d = 5 @@ -130,18 +136,18 @@ q = MvLocationScaleLowRank( μ, D, U, Normal{realtype}(zero(realtype), one(realtype)); scale_eps=ϵ ) - q_trans = if isnothing(bijector) + q_trans = if isnothing(bijector) q else Bijectors.TransformedDistribution(q, bijector) end g = deepcopy(q) - λ, re = Optimisers.destructure(q) + λ, re = Optimisers.destructure(q) grad, _ = Optimisers.destructure(g) - opt_st = Optimisers.setup(Descent(one(realtype)), λ) - _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) - q′ = re(λ′) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) @test all(diag(var(q′)) .≥ ϵ^2) end end From 15556da234235ca8966fd571d3b0f058b8815314 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 10 Aug 2024 02:46:27 -0400 Subject: [PATCH 14/37] run formatter --- docs/src/elbo/families.md | 43 +++++++++++++++++++++++++-------- src/AdvancedVI.jl | 5 +--- test/families/location_scale.jl | 5 ++-- 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/docs/src/elbo/families.md b/docs/src/elbo/families.md index e827a0d0..26dd828a 100644 --- a/docs/src/elbo/families.md +++ b/docs/src/elbo/families.md @@ -1,36 +1,46 @@ - # [Reparameterizable Variational Families](@id families) + The [RepGradELBO](@ref repgradelbo) objective assumes that the members of the variational family have a differentiable sampling path. We provide multiple pre-packaged variational families that can be readily used. ## The `LocationScale` Family + The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as + ```math z \sim q_{\lambda} \qquad\Leftrightarrow\qquad z \stackrel{d}{=} C u + m;\quad u \sim \varphi ``` + where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. -``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. +``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. The probability density is given by + ```math q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)), ``` + the covariance is given as + ```math \mathrm{Var}\left(q_{\lambda}\right) = C \mathrm{Var}(q_{\lambda}) C^{\top} ``` + and the entropy is given as + ```math \mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |C|, ``` + where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. Notice the ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``. The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. !!! note - For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. + + For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities. ```@docs @@ -38,76 +48,88 @@ MvLocationScale ``` The following are specialized constructors for convenience: + ```@docs FullRankGaussian MeanFieldGaussian ``` ### Gaussian Variational Families + ```julia using AdvancedVI, LinearAlgebra, Distributions; μ = zeros(2); -L = diagm(ones(2)) |> LowerTriangular; +L = LowerTriangular(diagm(ones(2))); q = FullRankGaussian(μ, L) -L = ones(2) |> Diagonal; +L = Diagonal(ones(2)); q = MeanFieldGaussian(μ, L) ``` ### Sudent-$$t$$ Variational Families + ```julia using AdvancedVI, LinearAlgebra, Distributions; μ = zeros(2); ν = 3; # Full-Rank -L = diagm(ones(2)) |> LowerTriangular; +L = LowerTriangular(diagm(ones(2))); q = MvLocationScale(μ, L, TDist(ν)) # Mean-Field -L = ones(2) |> Diagonal; +L = Diagonal(ones(2)); q = MvLocationScale(μ, L, TDist(ν)) ``` ### Laplace Variational families + ```julia using AdvancedVI, LinearAlgebra, Distributions; μ = zeros(2); # Full-Rank -L = diagm(ones(2)) |> LowerTriangular; +L = LowerTriangular(diagm(ones(2))); q = MvLocationScale(μ, L, Laplace()) # Mean-Field -L = ones(2) |> Diagonal; +L = Diagonal(ones(2)); q = MvLocationScale(μ, L, Laplace()) ``` ## The `LocationScaleLowRank` Family + In practice, `LocationScale` families with full-rank scale matrices are known to converge slowly as they require a small SGD stepsize. Low-rank variational families can be an effective alternative[^ONS2018]. `LocationScaleLowRank` generally represent any ``d``-dimensional distribution which its sampling path can be represented as + ```math z \sim q_{\lambda} \qquad\Leftrightarrow\qquad z \stackrel{d}{=} D u_1 + U u_2 + m;\quad u_1, u_2 \sim \varphi ``` + where ``D \in \mathbb{R}^{d \times d}`` is a diagonal matrix, ``U \in \mathbb{R}^{d \times r}`` is a dense low-rank matrix for the rank ``r > 0``, ``m \in \mathbb{R}^d`` is the location, and ``\varphi`` is the *base distribution*. -``m``, ``D``, and ``U`` form the variational parameters ``\lambda = (m, D, U)``. +``m``, ``D``, and ``U`` form the variational parameters ``\lambda = (m, D, U)``. The covariance of this distribution is given as + ```math \mathrm{Var}\left(q_{\lambda}\right) = D \mathrm{Var}(\varphi) D + U \mathrm{Var}(\varphi) U^{\top} ``` + and the entropy is given by the matrix determinant lemma as + ```math \mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |\Sigma| = \mathbb{H}(\varphi) + 2 \log |D| + \log |I + U^{\top} D^{-2} U|, ``` + where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. !!! note + `logpdf` for `LocationScaleLowRank` is unfortunately not computationally efficient and has the same time complexity as `LocationScale` with a full-rank scale. ```@docs @@ -115,6 +137,7 @@ MvLocationScaleLowRank ``` The following is a specialized constructor for convenience: + ```@docs LowRankGaussian ``` diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 0de9a068..adf8d6eb 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -181,13 +181,10 @@ export MvLocationScale, MeanFieldGaussian, FullRankGaussian include("families/location_scale.jl") -export - MvLocationScaleLowRank, - LowRankGaussian +export MvLocationScaleLowRank, LowRankGaussian include("families/location_scale_low_rank.jl") - # Optimization Routine function optimize end diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index 2c0c3dfb..3f384d3f 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -129,9 +129,8 @@ end @testset "scale positive definite projection" begin - @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in [ - :meanfield, :fullrank - ], + @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in + [:meanfield, :fullrank], realtype in [Float32, Float64], bijector in [nothing, :identity] From f79615405536639090cd6df57927d43558a8f670 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 11 Aug 2024 15:12:35 -0400 Subject: [PATCH 15/37] fix bugs and improve comments in `MvLocationScale` and lowrank --- src/families/location_scale.jl | 26 +++++++++++-------- src/families/location_scale_low_rank.jl | 34 +++++++++++++++---------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index d73aee9f..051bae27 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -1,6 +1,14 @@ +struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <: + ContinuousMultivariateDistribution + location::L + scale::S + dist::D + scale_eps::E +end + """ - MvLocationScale(location, scale, dist) <: ContinuousMultivariateDistribution + MvLocationScale(location, scale, dist; scale_eps) The location scale variational family broadly represents various variational families using `location` and `scale` variational parameters. @@ -12,22 +20,20 @@ represented as follows: u = rand(dist, d) z = scale*u + location ``` -""" -struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <: - ContinuousMultivariateDistribution - location::L - scale::S - dist::D - scale_eps::E -end +`scale_eps` sets a constraint on the smallest value of `scale` to be enforced during optimization. +This is necessary to guarantee stable convergence. + +# Keyword Arguments +- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `sqrt(eps(T))`). +""" function MvLocationScale( location::AbstractVector{T}, scale::AbstractMatrix{T}, dist::ContinuousDistribution; scale_eps::T=sqrt(eps(T)), ) where {T<:Real} - @assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." + @assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior." return MvLocationScale(location, scale, dist, scale_eps) end diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index d72a4e2c..54f16563 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -1,6 +1,16 @@ +struct MvLocationScaleLowRank{ + L,SD<:AbstractVector,SF<:AbstractMatrix,D<:ContinuousDistribution,E<:Real +} <: ContinuousMultivariateDistribution + location::L + scale_diag::SD + scale_factors::SF + dist::D + scale_eps::E +end + """ - MvLocationLowRankScale(location, scale_diag, scale_factors, dist) <: ContinuousMultivariateDistribution + MvLocationLowRankScale(location, scale_diag, scale_factors, dist; scale_eps) Variational family with a covariance in the form of a diagonal matrix plus a squared low-rank matrix. The rank is given by `size(scale_factors, 2)`. @@ -14,17 +24,13 @@ represented as follows: u_f = rand(dist, r) z = scale_diag.*u_d + scale_factors*u_f + location ``` -""" -struct MvLocationScaleLowRank{ - L,SD<:AbstractVector,SF<:AbstractMatrix,D<:ContinuousDistribution,E<:Real -} <: ContinuousMultivariateDistribution - location::L - scale_diag::SD - scale_factors::SF - dist::D - scale_eps::E -end +`scale_eps` sets a constraint on the smallest value of `scale_diag` to be enforced during optimization. +This is necessary to guarantee stable convergence. + +# Keyword Arguments +- `scale_eps`: Lower bound constraint for the values of scale_diag. (default: `sqrt(eps(T))`). +""" function MvLocationScaleLowRank( location::AbstractVector{T}, scale_diag::AbstractVector{T}, @@ -32,6 +38,7 @@ function MvLocationScaleLowRank( dist::ContinuousDistribution; scale_eps::T=sqrt(eps(T)), ) where {T<:Real} + @assert minimum(scale_diag) ≥ scale_eps "Initial scale is too small (smallest diagonal scale value is $(minimum(scale_diag)). This might result in unstable optimization behavior." @assert size(scale_factors, 1) == length(scale_diag) return MvLocationScaleLowRank(location, scale_diag, scale_factors, dist, scale_eps) end @@ -144,12 +151,11 @@ Construct a Gaussian variational approximation with a diagonal plus low-rank cov - `U::Matrix{T}`: Low-rank factors of the scale, where `size(U,2)` is the rank. # Keyword Arguments -- `check_args`: Check the conditioning of the initial scale (default: `true`). +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). """ function LowRankGaussian( μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=sqrt(eps(T)) ) where {T<:Real} - @assert minimum(D) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal scale value is $(minimum(D)). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) - return MvLocationScaleLowRank(μ, D, U, q_base, scale_eps) + return MvLocationScaleLowRank(μ, D, U, q_base; scale_eps) end From 6b1699cd1b3b09394e630bd2d8692e46787da480 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 11 Aug 2024 15:16:18 -0400 Subject: [PATCH 16/37] promote families.md into a higher category --- docs/make.jl | 2 +- docs/src/{elbo => }/families.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename docs/src/{elbo => }/families.md (99%) diff --git a/docs/make.jl b/docs/make.jl index de755d13..825ed6a2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,8 +16,8 @@ makedocs(; "ELBO Maximization" => [ "Overview" => "elbo/overview.md", "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md", - "Variational Families" => "elbo/families.md", ], + "Variational Families" => "families.md", ], ) diff --git a/docs/src/elbo/families.md b/docs/src/families.md similarity index 99% rename from docs/src/elbo/families.md rename to docs/src/families.md index 26dd828a..7fcd400a 100644 --- a/docs/src/elbo/families.md +++ b/docs/src/families.md @@ -67,7 +67,7 @@ L = Diagonal(ones(2)); q = MeanFieldGaussian(μ, L) ``` -### Sudent-$$t$$ Variational Families +### Student-$$t$$ Variational Families ```julia using AdvancedVI, LinearAlgebra, Distributions; From 5187d76f4fc801272e63249433edf34612deda93 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 14 Aug 2024 01:32:38 -0400 Subject: [PATCH 17/37] add test for `MVLocationScale` with non-Gaussian --- test/families/location_scale.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index 3f384d3f..3e651a82 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -1,6 +1,6 @@ @testset "interface LocationScale" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian], + @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian, :studentt], covtype in [:meanfield, :fullrank], realtype in [Float32, Float64] @@ -19,11 +19,19 @@ FullRankGaussian(μ, L) elseif covtype == :meanfield && basedist == :gaussian MeanFieldGaussian(μ, L) + elseif covtype == :fullrank && basedist == :studentt + MvLocationScale(μ, L, TDist(realtype(10.0))) + elseif covtype == :meanfield && basedist == :studentt + MvLocationScale(μ, L, TDist(realtype(10.0))) end + q_true = if basedist == :gaussian MvNormal(μ, Σ) + elseif basedist == :studentt + MvTDist(realtype(10.0), μ, Matrix(Σ)) end + println(q) @testset "eltype" begin @test eltype(q) == realtype end From 6dfc9192c716fdd6beaa6be595ba26e60f98f4a7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 26 Aug 2024 23:47:55 -0400 Subject: [PATCH 18/37] tighten compat bound for `Distributions` --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fff721f2..29564aaf 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ Accessors = "0.1" Bijectors = "0.13" ChainRulesCore = "1.16" DiffResults = "1" -Distributions = "0.25.87" +Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" Enzyme = "0.12.32" FillArrays = "1.3" diff --git a/test/Project.toml b/test/Project.toml index 251869e7..018198d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -27,7 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ADTypes = "0.2.1, 1" Bijectors = "0.13" DiffResults = "1.0" -Distributions = "0.25.100" +Distributions = "0.25.111" DistributionsAD = "0.6.45" Enzyme = "0.12.32" FillArrays = "1.6.1" From ba293e52a56895144d5f4471e1f8add909ee4b5a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 5 Sep 2024 12:46:37 -0700 Subject: [PATCH 19/37] fix base distribution standardization bug in `LocationScale` --- src/families/location_scale.jl | 11 ++++-- test/families/location_scale.jl | 67 ++++++++++++++++++--------------- 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index f4fa9d03..a8f3ef2c 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -113,16 +113,21 @@ function Distributions._rand!( return x .+= location end -Distributions.mean(q::MvLocationScale) = q.location +function Distributions.mean(q::MvLocationScale) + @unpack location, scale = q + return location + scale * Fill(mean(q.dist), length(location)) +end function Distributions.var(q::MvLocationScale) C = q.scale - return Diagonal(C * C') + σ2 = var(q.dist) + return σ2 * diag(C * C') end function Distributions.cov(q::MvLocationScale) C = q.scale - return Hermitian(C * C') + σ2 = var(q.dist) + return σ2 * Hermitian(C * C') end """ diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index 3e651a82..bd45458d 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -1,37 +1,36 @@ @testset "interface LocationScale" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian, :studentt], + @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in + [:gaussian, :gaussian_nonstd], covtype in [:meanfield, :fullrank], realtype in [Float32, Float64] n_dims = 10 n_montecarlo = 1000_000 - μ = randn(realtype, n_dims) - L = if covtype == :fullrank + location = randn(realtype, n_dims) + scale = if covtype == :fullrank LowerTriangular(tril(I + ones(realtype, n_dims, n_dims) / 2)) else Diagonal(ones(realtype, n_dims)) end - Σ = L * L' q = if covtype == :fullrank && basedist == :gaussian - FullRankGaussian(μ, L) + FullRankGaussian(location, scale) elseif covtype == :meanfield && basedist == :gaussian - MeanFieldGaussian(μ, L) - elseif covtype == :fullrank && basedist == :studentt - MvLocationScale(μ, L, TDist(realtype(10.0))) - elseif covtype == :meanfield && basedist == :studentt - MvLocationScale(μ, L, TDist(realtype(10.0))) + MeanFieldGaussian(location, scale) + elseif covtype == :fullrank && basedist == :gaussian_nonstd + MvLocationScale(location, scale, Normal(realtype(3), realtype(3))) + elseif covtype == :meanfield && basedist == :gaussian_nonstd + MvLocationScale(location, scale, Normal(realtype(3), realtype(3))) end q_true = if basedist == :gaussian - MvNormal(μ, Σ) - elseif basedist == :studentt - MvTDist(realtype(10.0), μ, Matrix(Σ)) + MvNormal(location, scale * scale') + elseif basedist == :gaussian_nonstd + MvNormal(location + scale * fill(3, n_dims), 9 * scale * scale') end - println(q) @testset "eltype" begin @test eltype(q) == realtype end @@ -54,15 +53,15 @@ @testset "statistics" begin @testset "mean" begin @test eltype(mean(q)) == realtype - @test mean(q) == μ + @test mean(q) ≈ mean(q_true) end @testset "var" begin @test eltype(var(q)) == realtype - @test var(q) ≈ Diagonal(Σ) + @test var(q) ≈ var(q_true) end @testset "cov" begin @test eltype(cov(q)) == realtype - @test cov(q) ≈ Σ + @test cov(q) ≈ cov(q_true) end end @@ -70,11 +69,13 @@ @testset "rand" begin z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_sample_ref = rand(StableRNG(1), q) @test z_sample_ref == rand(StableRNG(1), q) @@ -83,11 +84,13 @@ @testset "rand batch" begin z_samples = rand(q, n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) samples_ref = rand(StableRNG(1), q, n_montecarlo) @test samples_ref == rand(StableRNG(1), q, n_montecarlo) @@ -102,11 +105,13 @@ z_samples = mapreduce(first, hcat, res) z_samples_ret = mapreduce(last, hcat, res) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_sample_ref = Array{realtype}(undef, n_dims) rand!(StableRNG(1), q, z_sample_ref) @@ -120,11 +125,13 @@ z_samples = Array{realtype}(undef, n_dims, n_montecarlo) z_samples_ret = rand!(q, z_samples) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) rand!(StableRNG(1), q, z_samples_ref) @@ -164,7 +171,7 @@ opt_st = Optimisers.setup(Descent(one(realtype)), λ) _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) q′ = re(λ′) - @test all(diag(var(q′)) .≥ ϵ^2) + @test all(var(q′) .≥ ϵ^2) end end From 426d94371cfc8d0ddab6e074299c302ea740500a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 5 Sep 2024 13:42:25 -0700 Subject: [PATCH 20/37] fix base distribution standardization bug in `LocationScaleLowRank` --- src/families/location_scale_low_rank.jl | 53 +++++++++++------ test/families/location_scale_low_rank.jl | 74 ++++++++++++++++-------- 2 files changed, 83 insertions(+), 44 deletions(-) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index 54f16563..1359dc33 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -60,21 +60,29 @@ function StatsBase.entropy(q::MvLocationScaleLowRank) return n_dims * convert(eltype(location), entropy(dist)) + logdetΣ / 2 end -function Distributions.logpdf(q::MvLocationScaleLowRank, z::AbstractVector{<:Real}) +function Distributions.logpdf( + q::MvLocationScaleLowRank, z::AbstractVector{<:Real}; non_differntiable::Bool=false +) @unpack location, scale_diag, scale_factors, dist = q - # - ## More efficient O(kd^2) but non-differentiable version: - # - # Σchol = Cholesky(LowerTriangular(diagm(sqrt.(scale_diag)))) - # n_factors = size(scale_factors, 2) - # for k in 1:n_factors - # factor = scale_factors[:,k] - # lowrankupdate!(Σchol, factor) - # end - - Σ = Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors' - Σchol = cholesky(Σ) - return sum(Base.Fix1(logpdf, dist), Σchol.L \ (z - location)) - logdet(Σchol.L) + μ_base = mean(dist) + n_dims = length(location) + + scale2chol = if non_differntiable + # Fast O(kd^2) path (not supported by most current AD frameworks): + scale2chol = Cholesky(LowerTriangular(diagm(sqrt.(scale_diag)))) + n_factors = size(scale_factors, 2) + for k in 1:n_factors + factor = scale_factors[:, k] # copy necessary due to in-place mutation + lowrankupdate!(scale2chol, factor) + end + scale2chol + else + # Slow but differentiable O(d^3) path + scale2 = Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors' + cholesky(scale2) + end + z_std = z - mean(q) + scale2chol.L * Fill(μ_base, n_dims) + return sum(Base.Fix1(logpdf, dist), scale2chol.L \ z_std) - logdet(scale2chol.L) end function Distributions.rand(q::MvLocationScaleLowRank) @@ -111,18 +119,25 @@ function Distributions._rand!( return x .+= location end -Distributions.mean(q::MvLocationScaleLowRank) = q.location +function Distributions.mean(q::MvLocationScaleLowRank) + @unpack location, scale_diag, scale_factors = q + μ = mean(q.dist) + return location + + scale_diag .* Fill(μ, length(scale_diag)) + + scale_factors * Fill(μ, size(scale_factors, 2)) +end function Distributions.var(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - return Diagonal( - scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1] - ) + σ2 = var(q.dist) + return σ2 * + (scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1]) end function Distributions.cov(q::MvLocationScaleLowRank) @unpack scale_diag, scale_factors = q - return Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors' + σ2 = var(q.dist) + return σ2 * (Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors') end function update_variational_params!( diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 094963be..5b0b53cb 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -1,21 +1,32 @@ @testset "interface LocationScaleLowRank" begin - @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in [:gaussian], - rank in [1, 2], + @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in + [:gaussian, :gaussian_nonstd], + n_rank in [1, 2], realtype in [Float32, Float64] n_dims = 10 n_montecarlo = 1000_000 - μ = randn(realtype, n_dims) - D = ones(realtype, n_dims) - U = randn(realtype, n_dims, rank) - Σ = Diagonal(D .^ 2) + U * U' + location = randn(realtype, n_dims) + scale_diag = ones(realtype, n_dims) + scale_factors = randn(realtype, n_dims, n_rank) q = if basedist == :gaussian - LowRankGaussian(μ, D, U) + LowRankGaussian(location, scale_diag, scale_factors) + elseif basedist == :gaussian_nonstd + MvLocationScaleLowRank( + location, scale_diag, scale_factors, Normal(realtype(3), realtype(3)) + ) end + q_true = if basedist == :gaussian + μ = location + Σ = Diagonal(scale_diag .^ 2) + scale_factors * scale_factors' + MvNormal(location, Σ) + elseif basedist == :gaussian_nonstd + μ = location + scale_diag .* fill(3, n_dims) + scale_factors * fill(3, n_rank) + Σ = 3^2 * (Diagonal(scale_diag .^ 2) + scale_factors * scale_factors') MvNormal(μ, Σ) end @@ -27,6 +38,11 @@ z = rand(q) @test logpdf(q, z) ≈ logpdf(q_true, z) rtol = realtype(1e-2) @test eltype(logpdf(q, z)) == realtype + + @test logpdf(q, z; non_differntiable=true) ≈ logpdf(q_true, z) rtol = realtype( + 1e-2 + ) + @test eltype(logpdf(q, z; non_differntiable=true)) == realtype end @testset "entropy" begin @@ -41,15 +57,15 @@ @testset "statistics" begin @testset "mean" begin @test eltype(mean(q)) == realtype - @test mean(q) == μ + @test mean(q) == mean(q_true) end @testset "var" begin @test eltype(var(q)) == realtype - @test var(q) ≈ Diagonal(Σ) + @test var(q) ≈ var(q_true) end @testset "cov" begin @test eltype(cov(q)) == realtype - @test cov(q) ≈ Σ + @test cov(q) ≈ cov(q_true) end end @@ -57,11 +73,13 @@ @testset "rand" begin z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_sample_ref = rand(StableRNG(1), q) @test z_sample_ref == rand(StableRNG(1), q) @@ -70,11 +88,13 @@ @testset "rand batch" begin z_samples = rand(q, n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) samples_ref = rand(StableRNG(1), q, n_montecarlo) @test samples_ref == rand(StableRNG(1), q, n_montecarlo) @@ -89,11 +109,13 @@ z_samples = mapreduce(first, hcat, res) z_samples_ret = mapreduce(last, hcat, res) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_sample_ref = Array{realtype}(undef, n_dims) rand!(StableRNG(1), q, z_sample_ref) @@ -107,11 +129,13 @@ z_samples = Array{realtype}(undef, n_dims, n_montecarlo) z_samples_ret = rand!(q, z_samples) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) rand!(StableRNG(1), q, z_samples_ref) @@ -127,12 +151,12 @@ @testset "$(realtype) $(bijector)" for realtype in [Float32, Float64], bijector in [nothing, :identity] - rank = 2 + n_rank = 2 d = 5 μ = zeros(realtype, d) ϵ = sqrt(realtype(0.5)) D = ones(realtype, d) - U = randn(realtype, d, rank) + U = randn(realtype, d, n_rank) q = MvLocationScaleLowRank( μ, D, U, Normal{realtype}(zero(realtype), one(realtype)); scale_eps=ϵ ) @@ -148,7 +172,7 @@ opt_st = Optimisers.setup(Descent(one(realtype)), λ) _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) q′ = re(λ′) - @test all(diag(var(q′)) .≥ ϵ^2) + @test all(var(q′) .≥ ϵ^2) end end end From 3cc9e8013958464a25e9b7a5cf9eba5cfff37cf1 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 5 Sep 2024 13:44:34 -0700 Subject: [PATCH 21/37] format weird indentation in test `for` loops --- test/families/location_scale.jl | 6 ++---- test/families/location_scale_low_rank.jl | 3 +-- test/inference/repgradelbo_distributionsad.jl | 3 +-- test/inference/repgradelbo_locationscale.jl | 3 +-- test/inference/repgradelbo_locationscale_bijectors.jl | 3 +-- 5 files changed, 6 insertions(+), 12 deletions(-) diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index bd45458d..6e2992a4 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -1,7 +1,6 @@ @testset "interface LocationScale" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in - [:gaussian, :gaussian_nonstd], + @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian, :gaussian_nonstd], covtype in [:meanfield, :fullrank], realtype in [Float32, Float64] @@ -144,8 +143,7 @@ end @testset "scale positive definite projection" begin - @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in - [:meanfield, :fullrank], + @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in [:meanfield, :fullrank], realtype in [Float32, Float64], bijector in [nothing, :identity] diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 5b0b53cb..44f542c4 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -1,7 +1,6 @@ @testset "interface LocationScaleLowRank" begin - @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in - [:gaussian, :gaussian_nonstd], + @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in [:gaussian, :gaussian_nonstd], n_rank in [1, 2], realtype in [Float32, Float64] diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 815981c7..bfce495b 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -19,8 +19,7 @@ if @isdefined(Tapir) end @testset "inference RepGradELBO DistributionsAD" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield), n_montecarlo in [1, 10], (objname, objective) in Dict( diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index dc643c74..d0e7b6d4 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -20,8 +20,7 @@ if @isdefined(Tapir) end @testset "inference RepGradELBO VILocationScale" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), n_montecarlo in [1, 10], diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 35355478..ff37b82a 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -20,8 +20,7 @@ if @isdefined(Tapir) end @testset "inference RepGradELBO VILocationScale Bijectors" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], (modelname, modelconstr) in Dict(:NormalLogNormalMeanField => normallognormal_meanfield), n_montecarlo in [1, 10], From 0481ddac0947d28a6943c9d848454b31add20c46 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 5 Sep 2024 16:27:01 -0700 Subject: [PATCH 22/37] update docs add example for `LocationScaleLowRank` --- docs/src/families.md | 124 +++++++++++++++++++++++- src/families/location_scale_low_rank.jl | 6 +- 2 files changed, 123 insertions(+), 7 deletions(-) diff --git a/docs/src/families.md b/docs/src/families.md index 7fcd400a..dbd71fa4 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -3,7 +3,7 @@ The [RepGradELBO](@ref repgradelbo) objective assumes that the members of the variational family have a differentiable sampling path. We provide multiple pre-packaged variational families that can be readily used. -## The `LocationScale` Family +## [The `LocationScale` Family](@id locscale) The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as @@ -38,6 +38,8 @@ where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. Notice the ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``. The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. +### API + !!! note For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. @@ -128,14 +130,128 @@ and the entropy is given by the matrix determinant lemma as where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. -!!! note - - `logpdf` for `LocationScaleLowRank` is unfortunately not computationally efficient and has the same time complexity as `LocationScale` with a full-rank scale. +```@setup lowrank +using ADTypes +using AdvancedVI +using Distributions +using ReverseDiff +using LinearAlgebra +using LogDensityProblems +using Plots + +struct Target{D} + dist::D +end + +function LogDensityProblems.logdensity(model::Target, θ) + logpdf(model.dist, θ) +end + +function LogDensityProblems.dimension(model::Target) + return length(model.dist) +end + +function LogDensityProblems.capabilities(::Type{<:Target}) + return LogDensityProblems.LogDensityOrder{0}() +end + +n_dims = 30 +U_true = randn(n_dims, 3) +D_true = Diagonal(log.(1 .+ exp.(randn(n_dims)))) +Σ_true = D_true + U_true*U_true' +Σsqrt_true = sqrt(Σ_true) +μ_true = randn(n_dims) +model = Target(MvNormal(μ_true, Σ_true)); + +d = LogDensityProblems.dimension(model); +μ = zeros(d); + +L = Diagonal(ones(d)); +q0_mf = MeanFieldGaussian(μ, L) + +L = LowerTriangular(diagm(ones(d))); +q0_fr = FullRankGaussian(μ, L) + +D = ones(n_dims) +U = zeros(n_dims, 3) +q0_lr = LowRankGaussian(μ, D, U) + +obj = RepGradELBO(1); + +max_iter = 10^4 + +function callback(; params, averaged_params, restructure, stat, kwargs...) + q = restructure(averaged_params) + μ, Σ = mean(q), cov(q) + (dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2*sqrt(Σsqrt_true*Σ*Σsqrt_true)),) +end + +_, _, stats_fr, _ = AdvancedVI.optimize( + model, + obj, + q0_fr, + max_iter; + show_progress = false, + adtype = AutoReverseDiff(), + optimizer = DoG(), + averager = PolynomialAveraging(), + callback = callback, +); + +_, _, stats_mf, _ = AdvancedVI.optimize( + model, + obj, + q0_mf, + max_iter; + show_progress = false, + adtype = AutoReverseDiff(), + optimizer = DoG(), + averager = PolynomialAveraging(), + callback = callback, +); + +_, _, stats_lr, _ = AdvancedVI.optimize( + model, + obj, + q0_lr, + max_iter; + show_progress = false, + adtype = AutoReverseDiff(), + optimizer = DoG(), + averager = PolynomialAveraging(), + callback = callback, +); + +t = [stat.iteration for stat in stats_fr] +dist_fr = [sqrt(stat.dist2) for stat in stats_fr] +dist_mf = [sqrt(stat.dist2) for stat in stats_mf] +dist_lr = [sqrt(stat.dist2) for stat in stats_lr] +plot( t, dist_mf , label="Mean-Field Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") +plot!(t, dist_fr, label="Full-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") +plot!(t, dist_lr, label="Low-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") +savefig("lowrank_family_wasserstein.svg") +nothing +``` + +Consider a 30-dimensional Gaussian with a diagonal plus low-rank covariance structure, where the true rank is 3. +Then, we can compare the convergence speed of `LowRankGaussian` versus `FullRankGaussian`: + +![](lowrank_family_wasserstein.svg) + +As we can see, `LowRankGaussian` converges faster than `FullRankGaussian`. +While `FullRankGaussian` can converge to the true solution since it is a more expressive variational family, `LowRankGaussian` gets there faster. + +### API ```@docs MvLocationScaleLowRank ``` +The `logpdf` of `MvLocationScaleLowRank` has an optional argument `non_differentiable::Bool` (default: `false`). +If set as `true`, a more efficient ``O\left(r d^2\right)`` implementation is used to evaluate the density. +This, however, is not differentiable under most AD frameworks due to the use of Cholesky `lowrankupdate`. +The default value is `false`, which uses a ``O\left(d^3\right)`` implementation, is differentiable and therefore compatible with the `StickingTheLandingEntropy` estimator. + The following is a specialized constructor for convenience: ```@docs diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index 1359dc33..f9b671d0 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -20,9 +20,9 @@ represented as follows: ```julia d = length(location) r = size(scale_factors, 2) - u_d = rand(dist, d) - u_f = rand(dist, r) - z = scale_diag.*u_d + scale_factors*u_f + location + u_diag = rand(dist, d) + u_factors = rand(dist, r) + z = scale_diag.*u_diag + scale_factors*u_factors + location ``` `scale_eps` sets a constraint on the smallest value of `scale_diag` to be enforced during optimization. From 8449402024cf76ea9d2ce20b16f669434017161e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 5 Sep 2024 19:30:08 -0700 Subject: [PATCH 23/37] fix docs warn about divergence when using `MvLocationScaleLowRank` --- docs/src/families.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/src/families.md b/docs/src/families.md index dbd71fa4..a618a595 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -134,10 +134,11 @@ where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. using ADTypes using AdvancedVI using Distributions -using ReverseDiff using LinearAlgebra using LogDensityProblems +using Optimisers using Plots +using ReverseDiff struct Target{D} dist::D @@ -193,7 +194,7 @@ _, _, stats_fr, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoReverseDiff(), - optimizer = DoG(), + optimizer = Adam(0.01), averager = PolynomialAveraging(), callback = callback, ); @@ -205,7 +206,7 @@ _, _, stats_mf, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoReverseDiff(), - optimizer = DoG(), + optimizer = Adam(0.01), averager = PolynomialAveraging(), callback = callback, ); @@ -217,7 +218,7 @@ _, _, stats_lr, _ = AdvancedVI.optimize( max_iter; show_progress = false, adtype = AutoReverseDiff(), - optimizer = DoG(), + optimizer = Adam(0.01), averager = PolynomialAveraging(), callback = callback, ); @@ -241,6 +242,11 @@ Then, we can compare the convergence speed of `LowRankGaussian` versus `FullRank As we can see, `LowRankGaussian` converges faster than `FullRankGaussian`. While `FullRankGaussian` can converge to the true solution since it is a more expressive variational family, `LowRankGaussian` gets there faster. +!!! info + `MvLocationScaleLowRank` tend to work better with the `Optimisers.Adam` optimizer due to non-smoothness. + Other optimisers may experience divergences. + + ### API ```@docs From e196da6e6cd7651f9165420e3c841df9d1832f03 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 10 Sep 2024 14:06:06 +0100 Subject: [PATCH 24/37] Update Benchmark.yml --- .github/workflows/Benchmark.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/Benchmark.yml b/.github/workflows/Benchmark.yml index 27f4091b..3af335de 100644 --- a/.github/workflows/Benchmark.yml +++ b/.github/workflows/Benchmark.yml @@ -13,6 +13,7 @@ concurrency: permissions: contents: write pull-requests: write + issues: write jobs: benchmark: From e4bff6723ca1fc07a15a83f7881461589e47ec84 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 10 Sep 2024 14:21:26 +0100 Subject: [PATCH 25/37] disable more features for PRs from forks --- .github/workflows/Benchmark.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/Benchmark.yml b/.github/workflows/Benchmark.yml index 3af335de..b161782d 100644 --- a/.github/workflows/Benchmark.yml +++ b/.github/workflows/Benchmark.yml @@ -48,10 +48,10 @@ jobs: name: Benchmark Results tool: 'julia' output-file-path: bench/benchmark_results.json - summary-always: true + summary-always: ${{ !github.event.pull_request.head.repo.fork }} # Disable summary for PRs from forks github-token: ${{ secrets.GITHUB_TOKEN }} - comment-always: true alert-threshold: "200%" fail-on-alert: true benchmark-data-dir-path: benchmarks + comment-always: ${{ !github.event.pull_request.head.repo.fork }} # Disable comments for PRs from forks auto-push: ${{ !github.event.pull_request.head.repo.fork }} # Disable push for PRs from forks From 894a84959f712c2d7f731528278c23a0ec6cb28c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 10:04:27 -0700 Subject: [PATCH 26/37] fix `LocationScale` interfaces to only allow univariate base dist --- src/families/location_scale.jl | 8 ++++---- src/families/location_scale_low_rank.jl | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index a8f3ef2c..f125032d 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -30,8 +30,8 @@ This is necessary to guarantee stable convergence. function MvLocationScale( location::AbstractVector{T}, scale::AbstractMatrix{T}, - dist::ContinuousDistribution; - scale_eps::T=sqrt(eps(T)), + dist::ContinuousUnivariateDistribution; + scale_eps::T=eps(T)^(1//4), ) where {T<:Real} @assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior." return MvLocationScale(location, scale, dist, scale_eps) @@ -143,7 +143,7 @@ Construct a Gaussian variational approximation with a dense covariance matrix. - `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). """ function FullRankGaussian( - μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=sqrt(eps(T)) + μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=eps(T)^(1//4) ) where {T<:Real} q_base = Normal{T}(zero(T), one(T)) return MvLocationScale(μ, L, q_base, scale_eps) @@ -162,7 +162,7 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix - `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). """ function MeanFieldGaussian( - μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=sqrt(eps(T)) + μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=eps(T)^(1//4) ) where {T<:Real} q_base = Normal{T}(zero(T), one(T)) return MvLocationScale(μ, L, q_base, scale_eps) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index f9b671d0..3eb810ed 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -35,8 +35,8 @@ function MvLocationScaleLowRank( location::AbstractVector{T}, scale_diag::AbstractVector{T}, scale_factors::AbstractMatrix{T}, - dist::ContinuousDistribution; - scale_eps::T=sqrt(eps(T)), + dist::ContinuousUnivariateDistribution; + scale_eps::T=eps(T)^(1//4), ) where {T<:Real} @assert minimum(scale_diag) ≥ scale_eps "Initial scale is too small (smallest diagonal scale value is $(minimum(scale_diag)). This might result in unstable optimization behavior." @assert size(scale_factors, 1) == length(scale_diag) @@ -156,7 +156,7 @@ function update_variational_params!( end """ - LowRankGaussian(location, scale_diag, scale_factors; check_args = true) + LowRankGaussian(μ, D, U; scale_eps) Construct a Gaussian variational approximation with a diagonal plus low-rank covariance matrix. @@ -169,7 +169,7 @@ Construct a Gaussian variational approximation with a diagonal plus low-rank cov - `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). """ function LowRankGaussian( - μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=sqrt(eps(T)) + μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=eps(T)^(1//4) ) where {T<:Real} q_base = Normal{T}(zero(T), one(T)) return MvLocationScaleLowRank(μ, D, U, q_base; scale_eps) From ce6793c400937ebb1ba94e18101236648226f32b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 10:34:11 -0700 Subject: [PATCH 27/37] fix test comparison operator for families Co-authored-by: Markus Hauru --- test/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 44f542c4..3b6afdeb 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -56,7 +56,7 @@ @testset "statistics" begin @testset "mean" begin @test eltype(mean(q)) == realtype - @test mean(q) == mean(q_true) + @test mean(q) ≈ mean(q_true) end @testset "var" begin @test eltype(var(q)) == realtype From 71aeb5a49dd01c2d46485003f60045ff8aaa566c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 10:34:25 -0700 Subject: [PATCH 28/37] fix test comparison operator for families Co-authored-by: Markus Hauru --- test/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 3b6afdeb..538854cf 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -81,7 +81,7 @@ @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_sample_ref = rand(StableRNG(1), q) - @test z_sample_ref == rand(StableRNG(1), q) + @test z_sample_ref ≈ rand(StableRNG(1), q) end @testset "rand batch" begin From 77ace2bc1dda64694485aa223ca56713220d2617 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 10:34:40 -0700 Subject: [PATCH 29/37] fix test comparison operator for families Co-authored-by: Markus Hauru --- test/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 538854cf..07605e01 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -107,7 +107,7 @@ end z_samples = mapreduce(first, hcat, res) z_samples_ret = mapreduce(last, hcat, res) - @test z_samples == z_samples_ret + @test z_samples ≈ z_samples_ret @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) From 641de3996e3b9666f3a1791d7207e8f916281576 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 10:34:53 -0700 Subject: [PATCH 30/37] fix test comparison operator for families Co-authored-by: Markus Hauru --- test/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 07605e01..3d144148 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -96,7 +96,7 @@ @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) samples_ref = rand(StableRNG(1), q, n_montecarlo) - @test samples_ref == rand(StableRNG(1), q, n_montecarlo) + @test samples_ref ≈ rand(StableRNG(1), q, n_montecarlo) end @testset "rand! AbstractVector" begin From a58f209d14a17780d16778dff7ec1b285c6d3a4c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 10:35:05 -0700 Subject: [PATCH 31/37] fix test comparison operator for families Co-authored-by: Markus Hauru --- test/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 3d144148..2d1e4940 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -121,7 +121,7 @@ z_sample = Array{realtype}(undef, n_dims) rand!(StableRNG(1), q, z_sample) - @test z_sample_ref == z_sample + @test z_sample_ref ≈ z_sample end @testset "rand! AbstractMatrix" begin From 846b259948cb5bc70ed540aea73eb8e17fcbb62d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 10:35:16 -0700 Subject: [PATCH 32/37] fix test comparison operator for families Co-authored-by: Markus Hauru --- test/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 2d1e4940..3ed33efd 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -127,7 +127,7 @@ @testset "rand! AbstractMatrix" begin z_samples = Array{realtype}(undef, n_dims, n_montecarlo) z_samples_ret = rand!(q, z_samples) - @test z_samples == z_samples_ret + @test z_samples ≈ z_samples_ret @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) From 1116f68f51078133fc1688550ec0faef0253a6be Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 10:35:27 -0700 Subject: [PATCH 33/37] fix test comparison operator for families Co-authored-by: Markus Hauru --- test/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index 3ed33efd..e98192ac 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -141,7 +141,7 @@ z_samples = Array{realtype}(undef, n_dims, n_montecarlo) rand!(StableRNG(1), q, z_samples) - @test z_samples_ref == z_samples + @test z_samples_ref ≈ z_samples end end end From 42d730d367565ab2e4d09dea425ae5e1b11e89fb Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 11:23:34 -0700 Subject: [PATCH 34/37] fix formatting --- docs/src/families.md | 2 +- test/inference/repgradelbo_distributionsad.jl | 3 ++- test/inference/repgradelbo_locationscale.jl | 3 ++- test/inference/repgradelbo_locationscale_bijectors.jl | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/src/families.md b/docs/src/families.md index a618a595..e270acad 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -243,9 +243,9 @@ As we can see, `LowRankGaussian` converges faster than `FullRankGaussian`. While `FullRankGaussian` can converge to the true solution since it is a more expressive variational family, `LowRankGaussian` gets there faster. !!! info + `MvLocationScaleLowRank` tend to work better with the `Optimisers.Adam` optimizer due to non-smoothness. Other optimisers may experience divergences. - ### API diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index e8bebf9d..94da09bc 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -14,7 +14,8 @@ if @isdefined(Enzyme) end @testset "inference RepGradELBO DistributionsAD" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield), n_montecarlo in [1, 10], (objname, objective) in Dict( diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 5ce92809..9e254b6a 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -14,7 +14,8 @@ if @isdefined(Enzyme) end @testset "inference RepGradELBO VILocationScale" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], (modelname, modelconstr) in Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), n_montecarlo in [1, 10], diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 0fbe5ab7..731326f3 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -14,7 +14,8 @@ if @isdefined(Enzyme) end @testset "inference RepGradELBO VILocationScale Bijectors" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], (modelname, modelconstr) in Dict(:NormalLogNormalMeanField => normallognormal_meanfield), n_montecarlo in [1, 10], From 99d08c5f0ca01b32c6ae17d0baac1b33d884599e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 11 Sep 2024 13:33:07 -0700 Subject: [PATCH 35/37] fix formatting --- test/families/location_scale.jl | 6 ++++-- test/families/location_scale_low_rank.jl | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index 6e2992a4..bd45458d 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -1,6 +1,7 @@ @testset "interface LocationScale" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian, :gaussian_nonstd], + @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in + [:gaussian, :gaussian_nonstd], covtype in [:meanfield, :fullrank], realtype in [Float32, Float64] @@ -143,7 +144,8 @@ end @testset "scale positive definite projection" begin - @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in [:meanfield, :fullrank], + @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in + [:meanfield, :fullrank], realtype in [Float32, Float64], bijector in [nothing, :identity] diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl index e98192ac..2accb971 100644 --- a/test/families/location_scale_low_rank.jl +++ b/test/families/location_scale_low_rank.jl @@ -1,6 +1,7 @@ @testset "interface LocationScaleLowRank" begin - @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in [:gaussian, :gaussian_nonstd], + @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in + [:gaussian, :gaussian_nonstd], n_rank in [1, 2], realtype in [Float32, Float64] From 4a90c5d37ac377a089255f427f2eca64887fbc11 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 12 Sep 2024 12:14:51 -0700 Subject: [PATCH 36/37] fix scale lower bound to `1e-4` --- src/families/location_scale.jl | 12 ++++++------ src/families/location_scale_low_rank.jl | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index f125032d..22af4b4a 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -25,13 +25,13 @@ represented as follows: This is necessary to guarantee stable convergence. # Keyword Arguments -- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `sqrt(eps(T))`). +- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `1e-4`). """ function MvLocationScale( location::AbstractVector{T}, scale::AbstractMatrix{T}, dist::ContinuousUnivariateDistribution; - scale_eps::T=eps(T)^(1//4), + scale_eps::T=T(1e-4), ) where {T<:Real} @assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior." return MvLocationScale(location, scale, dist, scale_eps) @@ -140,10 +140,10 @@ Construct a Gaussian variational approximation with a dense covariance matrix. - `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian. # Keyword Arguments -- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). """ function FullRankGaussian( - μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=eps(T)^(1//4) + μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=T(1e-4) ) where {T<:Real} q_base = Normal{T}(zero(T), one(T)) return MvLocationScale(μ, L, q_base, scale_eps) @@ -159,10 +159,10 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix - `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian. # Keyword Arguments -- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). """ function MeanFieldGaussian( - μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=eps(T)^(1//4) + μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=T(1e-4) ) where {T<:Real} q_base = Normal{T}(zero(T), one(T)) return MvLocationScale(μ, L, q_base, scale_eps) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index 3eb810ed..760b38fe 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -36,7 +36,7 @@ function MvLocationScaleLowRank( scale_diag::AbstractVector{T}, scale_factors::AbstractMatrix{T}, dist::ContinuousUnivariateDistribution; - scale_eps::T=eps(T)^(1//4), + scale_eps::T=T(1e-4), ) where {T<:Real} @assert minimum(scale_diag) ≥ scale_eps "Initial scale is too small (smallest diagonal scale value is $(minimum(scale_diag)). This might result in unstable optimization behavior." @assert size(scale_factors, 1) == length(scale_diag) @@ -169,7 +169,7 @@ Construct a Gaussian variational approximation with a diagonal plus low-rank cov - `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). """ function LowRankGaussian( - μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=eps(T)^(1//4) + μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=T(1e-4) ) where {T<:Real} q_base = Normal{T}(zero(T), one(T)) return MvLocationScaleLowRank(μ, D, U, q_base; scale_eps) From c41709befa5e32d0b2f66ada8210ed8510e168fc Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 12 Sep 2024 12:15:14 -0700 Subject: [PATCH 37/37] fix docstring for `LowRankGaussian` --- src/families/location_scale_low_rank.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl index 760b38fe..e2044142 100644 --- a/src/families/location_scale_low_rank.jl +++ b/src/families/location_scale_low_rank.jl @@ -166,7 +166,7 @@ Construct a Gaussian variational approximation with a diagonal plus low-rank cov - `U::Matrix{T}`: Low-rank factors of the scale, where `size(U,2)` is the rank. # Keyword Arguments -- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `sqrt(eps(T))`). +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). """ function LowRankGaussian( μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=T(1e-4)