Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scale precision to gaussian family #218

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ExponentialFamily.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ include("distributions/normal_family/normal_weighted_mean_precision.jl")
include("distributions/normal_family/mv_normal_mean_covariance.jl")
include("distributions/normal_family/mv_normal_mean_precision.jl")
include("distributions/normal_family/mv_normal_weighted_mean_precision.jl")
include("distributions/normal_family/normal_family.jl")
include("distributions/normal_family/mv_normal_mean_scale_precision.jl")
include("distributions/normal_family/normal_family.jl")
include("distributions/gamma_inverse.jl")
include("distributions/geometric.jl")
include("distributions/matrix_dirichlet.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/matrix_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end
isproper(::NaturalParametersSpace, ::Type{MatrixDirichlet}, η, conditioner) =
isnothing(conditioner) && length(η) > 1 && all(isless.(-1, η)) && all(!isinf, η) && all(!isnan, η)
isproper(::MeanParametersSpace, ::Type{MatrixDirichlet}, θ, conditioner) =
isnothing(conditioner) && length(θ) > 1 &&all(>(0), θ) && all(!isinf, θ) && all(!isnan, θ)
isnothing(conditioner) && length(θ) > 1 && all(>(0), θ) && all(!isinf, θ) && all(!isnan, θ)

function (::MeanToNatural{MatrixDirichlet})(tuple_of_θ::Tuple{Any})
(α,) = tuple_of_θ
Expand Down
47 changes: 16 additions & 31 deletions src/distributions/normal_family/mv_normal_mean_scale_precision.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision
export MvNormalMeanScalePrecision

import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal
import LinearAlgebra: diag, Diagonal, dot
Expand Down Expand Up @@ -26,8 +26,6 @@ struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: Abstract
γ::T
end

const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision

function MvNormalMeanScalePrecision(μ::AbstractVector{<:Real}, γ::Real)
T = promote_type(eltype(μ), eltype(γ))
return MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ))
Expand Down Expand Up @@ -66,7 +64,7 @@ end

function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any})
(μ, γ) = tuple_of_θ
return (γ * μ, - γ / 2)
return (γ * μ, -γ / 2)
end

function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any})
Expand Down Expand Up @@ -170,25 +168,13 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScal
return MvNormalMeanScalePrecision(m, w)
end

BayesBase.default_prod_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution)

function BayesBase.prod(
::PreserveTypeProd{Distribution},
left::L,
right::R
) where {L <: MultivariateNormalDistributionsFamily, R <: MvNormalMeanScalePrecision}
wleft = convert(MvNormalWeightedMeanPrecision, left)
wright = convert(MvNormalWeightedMeanPrecision, right)
return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright)
end

function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T}
function BayesBase.rand(rng::AbstractRNG, dist::MvNormalMeanScalePrecision{T}) where {T}
μ, γ = params(dist)
d = length(μ)
return rand!(rng, dist, Vector{T}(undef, d))
end

function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T}
function BayesBase.rand(rng::AbstractRNG, dist::MvNormalMeanScalePrecision{T}, size::Int64) where {T}
container = Matrix{T}(undef, length(dist), size)
return rand!(rng, dist, container)
end
Expand All @@ -197,7 +183,7 @@ end
# it needs to work with scale method, not with std
function BayesBase.rand!(
rng::AbstractRNG,
dist::MvGaussianMeanScalePrecision,
dist::MvNormalMeanScalePrecision,
container::AbstractArray{T}
) where {T <: Real}
preallocated = similar(container)
Expand Down Expand Up @@ -227,42 +213,41 @@ getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
η2 = η[end]
k = length(η1)
Cinv = inv(η2)
return -dot(η1, 1/4*Cinv, η1) - (k / 2)*log(-2*η2)
return -dot(η1, 1 / 4 * Cinv, η1) - (k / 2) * log(-2 * η2)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(η) -> begin
η1 = @view η[1:end-1]
η2 = η[end]
inv2 = inv(η2)
k = length(η1)
return pack_parameters(MvNormalMeanCovariance, (-1/(2*η2) * η1, dot(η1,η1) / 4*inv2^2 - k/2 * inv2))
return pack_parameters(MvNormalMeanCovariance, (-1 / (2 * η2) * η1, dot(η1, η1) / 4 * inv2^2 - k / 2 * inv2))
end

getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(η) -> begin
η1 = @view η[1:end-1]
η2 = η[end]
k = length(η1)

η1_part = -inv(2*η2)* I(length(η1))
η1_part = -inv(2 * η2) * I(length(η1))
η1η2 = zeros(k, 1)
η1η2 .= η1*inv(2*η2^2)
η2_part = k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3)
η1η2 .= η1 * inv(2 * η2^2)

η2_part = k * inv(2abs2(η2)) - dot(η1, η1) / (2 * η2^3)

return ArrowheadMatrix(η2_part, η1η2, diag(η1_part))
end


getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(θ) -> begin
μ = @view θ[1:end-1]
γ = θ[end]
k = length(μ)

matrix = zeros(eltype(μ), (k+1))
matrix = zeros(eltype(μ), (k + 1))
matrix[1:k] .= γ
matrix[k+1] = k*inv(2abs2(γ))
matrix[k+1] = k * inv(2abs2(γ))
return Diagonal(matrix)
end
5 changes: 3 additions & 2 deletions src/distributions/normal_family/normal_family.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export GaussianMeanVariance, GaussianMeanPrecision, GaussianWeighteMeanPrecision
export MvGaussianMeanCovariance, MvGaussianMeanPrecision, MvGaussianWeightedMeanPrecision
export MvGaussianMeanCovariance, MvGaussianMeanPrecision, MvGaussianWeightedMeanPrecision, MvGaussianMeanScalePrecision
export UnivariateNormalDistributionsFamily, MultivariateNormalDistributionsFamily, NormalDistributionsFamily
export UnivariateGaussianDistributionsFamily, MultivariateGaussianDistributionsFamily, GaussianDistributionsFamily
export JointNormal, JointGaussian
Expand All @@ -10,9 +10,10 @@ const GaussianWeighteMeanPrecision = NormalWeightedMeanPrecision
const MvGaussianMeanCovariance = MvNormalMeanCovariance
const MvGaussianMeanPrecision = MvNormalMeanPrecision
const MvGaussianWeightedMeanPrecision = MvNormalWeightedMeanPrecision
const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision

const UnivariateNormalDistributionsFamily{T} = Union{NormalMeanPrecision{T}, NormalMeanVariance{T}, NormalWeightedMeanPrecision{T}, Normal{T}}
const MultivariateNormalDistributionsFamily{T} = Union{MvNormalMeanPrecision{T}, MvNormalMeanCovariance{T}, MvNormalWeightedMeanPrecision{T}, MvNormal{T}}
const MultivariateNormalDistributionsFamily{T} = Union{MvNormalMeanPrecision{T}, MvNormalMeanCovariance{T}, MvNormalWeightedMeanPrecision{T}, MvNormalMeanScalePrecision{T}, MvNormal{T}}
const NormalDistributionsFamily{T} = Union{UnivariateNormalDistributionsFamily{T}, MultivariateNormalDistributionsFamily{T}}

const UnivariateGaussianDistributionsFamily = UnivariateNormalDistributionsFamily
Expand Down
2 changes: 1 addition & 1 deletion test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,4 +557,4 @@ function test_generic_simple_exponentialfamily_product(
end

return true
end
end
4 changes: 2 additions & 2 deletions test/distributions/matrix_dirichlet_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ end
for space in (MeanParametersSpace(), NaturalParametersSpace())
@test !isproper(space, MatrixDirichlet, [Inf Inf; Inf 1.0])
@test !isproper(space, MatrixDirichlet, [1.0], Inf)
@test !isproper(space, MatrixDirichlet, [NaN],)
@test !isproper(space, MatrixDirichlet, [NaN])
@test !isproper(space, MatrixDirichlet, [1.0], NaN)
@test !isproper(space, MatrixDirichlet, [0.5, 0.5], 1.0)
@test isproper(space, MatrixDirichlet, [2.0, 3.0])
@test !isproper(space, MatrixDirichlet, [-1.0, -1.2])
end
end
@test_throws Exception convert(ExponentialFamilyDistribution, MatrixDirichlet([Inf Inf; 2 3]))
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,4 @@ end
@test cholinv_time_small < cholinv_time_full / (C * k)
@test cholinv_alloc_small < cholinv_alloc_full / (C * k)
end
end
end
4 changes: 2 additions & 2 deletions test/distributions/wishart_inverse_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,10 @@ end
# Test commutativity of the product
prod_result1 = prod(PreserveTypeProd(Distribution), left, right_fast)
prod_result2 = prod(PreserveTypeProd(Distribution), right_fast, left)

@test prod_result1.ν ≈ prod_result2.ν
@test prod_result1.S ≈ prod_result2.S

# Test that the product preserves type
@test prod_result1 isa InverseWishartFast
@test prod_result2 isa InverseWishartFast
Expand Down
5 changes: 2 additions & 3 deletions test/distributions/wishart_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ end
# Test commutativity of the product
prod_result1 = prod(PreserveTypeProd(Distribution), left, right)
prod_result2 = prod(PreserveTypeProd(Distribution), right, left)

@test prod_result1.ν ≈ prod_result2.ν
@test prod_result1.invS ≈ prod_result2.invS

# Test that the product preserves type
@test prod_result1 isa WishartFast
@test prod_result2 isa WishartFast
Expand All @@ -160,4 +160,3 @@ end
end
end
end

Loading