Skip to content

Commit

Permalink
fix reduce computational cost of tests, use more sophisticated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Jun 13, 2024
1 parent 47aba44 commit b927e80
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 57 deletions.
19 changes: 13 additions & 6 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,16 @@
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
T = 1000
η = 1e-3
opt = Optimisers.Descent(realtype(η))

# For small enough η, the error of SGD, Δλ, is bounded as
# Δλ ≤ ρ^T Δλ0 + O(η),
# where ρ = 1 - ημ, μ is the strong convexity constant.
contraction_rate = 1 - η*strong_convexity

μ0 = Zeros(realtype, n_dims)
L0 = Diagonal(Ones(realtype, n_dims))
Expand All @@ -33,7 +40,7 @@
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand All @@ -42,7 +49,7 @@
L = sqrt(cov(q))
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ0/T^(1/4)
@test Δλ ≤ contraction_rate^(T/2)*Δλ0
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand All @@ -51,7 +58,7 @@
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand All @@ -61,7 +68,7 @@
rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand Down
19 changes: 13 additions & 6 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,16 @@
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
T = 1000
η = 1e-3
opt = Optimisers.Descent(realtype(η))

# For small enough η, the error of SGD, Δλ, is bounded as
# Δλ ≤ ρ^T Δλ0 + O(η),
# where ρ = 1 - ημ, μ is the strong convexity constant.
contraction_rate = 1 - η*strong_convexity

q0 = if is_meanfield
MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
Expand All @@ -37,7 +44,7 @@
Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand All @@ -46,7 +53,7 @@
L = q.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ0/T^(1/4)
@test Δλ ≤ contraction_rate^(T/2)*Δλ0
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand All @@ -55,7 +62,7 @@
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand All @@ -65,7 +72,7 @@
rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand Down
19 changes: 13 additions & 6 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
T = 1000
η = 1e-3
opt = Optimisers.Descent(realtype(η))

b = Bijectors.bijector(model)
b⁻¹ = inverse(b)
Expand All @@ -38,11 +40,16 @@
end
q0_z = Bijectors.transformed(q0_η, b⁻¹)

# For small enough η, the error of SGD, Δλ, is bounded as
# Δλ ≤ ρ^T Δλ0 + O(η),
# where ρ = 1 - ημ, μ is the strong convexity constant.
contraction_rate = 1 - η*strong_convexity

@testset "convergence" begin
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand All @@ -51,7 +58,7 @@
L = q.dist.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ0/T^(1/4)
@test Δλ ≤ contraction_rate^(T/2)*Δλ0
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand All @@ -60,7 +67,7 @@
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand All @@ -70,7 +77,7 @@
rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
optimizer = opt,
show_progress = PROGRESS,
adtype = adtype,
)
Expand Down
18 changes: 11 additions & 7 deletions test/models/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,28 @@ end
function normal_fullrank(rng::Random.AbstractRNG, realtype::Type)
n_dims = 5

μ = randn(rng, realtype, n_dims)
L = tril(I + ones(realtype, n_dims, n_dims))/2
Σ = L*L' |> Hermitian
σ0 = realtype(0.3)
μ = Fill(realtype(5), n_dims)
L = Matrix(σ0*I, n_dims, n_dims)
Σ = L*L' |> Hermitian

model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)))

TestModel(model, μ, L, n_dims, false)
TestModel(model, μ, LowerTriangular(L), n_dims, 1/σ0^2, false)
end

function normal_meanfield(rng::Random.AbstractRNG, realtype::Type)
n_dims = 5

μ = randn(rng, realtype, n_dims)
σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
σ0 = realtype(0.3)
μ = Fill(realtype(5), n_dims)
#randn(rng, realtype, n_dims)
σ = Fill(σ0, n_dims)
#log.(exp.(randn(rng, realtype, n_dims)) .+ 1)

model = TestNormal(μ, Diagonal(σ.^2))

L = σ |> Diagonal

TestModel(model, μ, L, n_dims, true)
TestModel(model, μ, L, n_dims, 1/σ0^2, true)
end
51 changes: 20 additions & 31 deletions test/models/normallognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,29 @@ function Bijectors.bijector(model::NormalLogNormal)
[1:1, 2:1+length(μ_y)])
end

function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type)
n_dims = 5

μ_x = randn(rng, realtype)
σ_x = ℯ
μ_y = randn(rng, realtype, n_dims)
L_y = tril(I + ones(realtype, n_dims, n_dims))/2
Σ_y = L_y*L_y' |> Hermitian

model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0)))

Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
Σ[1,1] = σ_x^2
Σ[2:end,2:end] = Σ_y
Σ = Σ |> Hermitian

μ = vcat(μ_x, μ_y)
L = cholesky(Σ).L

TestModel(model, μ, L, n_dims+1, false)
function normallognormal_fullrank(::Random.AbstractRNG, realtype::Type)
n_y_dims = 5

σ0 = realtype(0.3)
μ = Fill(realtype(5.0), n_y_dims+1)
L = Matrix(σ0*I, n_y_dims+1, n_y_dims+1)
Σ = L*L' |> Hermitian

model = NormalLogNormal(
μ[1], L[1,1], μ[2:end], PDMat(Σ[2:end,2:end], Cholesky(L[2:end,2:end], 'L', 0))
)
TestModel(model, μ, LowerTriangular(L), n_y_dims+1, 1/σ0^2, false)
end

function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type)
n_dims = 5

μ_x = randn(rng, realtype)
σ_x = ℯ
μ_y = randn(rng, realtype, n_dims)
σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
function normallognormal_meanfield(::Random.AbstractRNG, realtype::Type)
n_y_dims = 5

model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
σ0 = realtype(0.3)
μ = Fill(realtype(5), n_y_dims + 1)
σ = Fill(σ0, n_y_dims + 1)
L = Diagonal(σ)

μ = vcat(μ_x, μ_y)
L = vcat(σ_x, σ_y) |> Diagonal
model = NormalLogNormal(μ[1], σ[1], μ[2:end], Diagonal(σ[2:end].^2))

TestModel(model, μ, L, n_dims+1, true)
TestModel(model, μ, L, n_y_dims+1, 1/σ0^2, true)
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ using AdvancedVI
const GROUP = get(ENV, "GROUP", "All")

# Models for Inference Tests
struct TestModel{M,L,S}
struct TestModel{M,L,S,SC}
model::M
μ_true::L
L_true::S
n_dims::Int
strong_convexity::SC
is_meanfield::Bool
end
include("models/normal.jl")
Expand Down

1 comment on commit b927e80

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: b927e80 Previous: d44e36d Ratio
normal + bijector/meanfield/ForwardDiff 558006515.5 ns 558668726 ns 1.00
normal + bijector/meanfield/ReverseDiff 189274695 ns 190422991 ns 0.99

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.