Skip to content

Commit

Permalink
Fix Enzyme extension (#79)
Browse files Browse the repository at this point in the history
* Fix Enzyme extension

* Enable Enzyme tests

* Fix format

* Fix format

* Do not test on Julia nightly
  • Loading branch information
devmotion authored Aug 7, 2024
1 parent 94cdd07 commit 1b36c6e
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 45 deletions.
33 changes: 0 additions & 33 deletions .github/workflows/JuliaNightly.yml

This file was deleted.

1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://turinglang.org/AdvancedVI.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://turinglang.org/AdvancedVI.jl/dev/)
[![Build Status](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/CI.yml?query=branch%3Amaster)
[![JuliaNightly](https://github.com/TuringLang/AdvancedVI.jl/workflows/JuliaNightly/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedVI.jl/actions?query=workflow%3AJuliaNightly+branch%3Amaster)
[![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl)

# AdvancedVI.jl
Expand Down
16 changes: 11 additions & 5 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ else
using ..AdvancedVI: ADTypes, DiffResults
end

# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)
function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
ad::ADTypes.AutoEnzyme,
f,
θ::AbstractVector{T},
out::DiffResults.MutableDiffResult,
) where {T<:Real}
y = f(θ)
DiffResults.value!(out, y)
∇θ = DiffResults.gradient(out)
fill!(∇θ, zero(T))
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
_, y = Enzyme.autodiff(
Enzyme.ReverseWithPrimal,
f,
Enzyme.Active,
Enzyme.Duplicated(θ, ∇θ),
)
DiffResults.value!(out, y)
return out
end

Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -26,6 +27,7 @@ ADTypes = "0.2.1, 1"
Bijectors = "0.13"
Distributions = "0.25.100"
DistributionsAD = "0.6.45"
Enzyme = "0.12"
FillArrays = "1.6.1"
ForwardDiff = "0.10.36"
Functors = "0.4.5"
Expand Down
10 changes: 5 additions & 5 deletions test/interface/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ using Test

@testset "ad" begin
@testset "$(adname)" for (adname, adsymbol) Dict(
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
# :Enzyme => AutoEnzyme() # Currently not tested against
)
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
)
D = 10
A = randn(D, D)
λ = randn(D)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using DistributionsAD
using LogDensityProblems
using Optimisers
using ADTypes
using ForwardDiff, ReverseDiff, Zygote
using Enzyme, ForwardDiff, ReverseDiff, Zygote

using AdvancedVI

Expand Down

0 comments on commit 1b36c6e

Please sign in to comment.