From 9a524d3681aceece0f01955ddf171d2de7e0e113 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 8 Nov 2024 19:32:16 +0100 Subject: [PATCH] Mixed-mode sparse Jacobians (#554) --- DifferentiationInterface/Project.toml | 6 +- DifferentiationInterface/docs/src/api.md | 1 + ...iationInterfaceSparseMatrixColoringsExt.jl | 21 ++ .../hessian.jl | 1 + .../jacobian.jl | 16 -- .../jacobian_mixed.jl | 231 ++++++++++++++++++ .../src/DifferentiationInterface.jl | 3 +- .../src/first_order/mixed_mode.jl | 27 ++ .../src/utils/batchsize.jl | 6 + DifferentiationInterface/src/utils/check.jl | 5 + DifferentiationInterface/src/utils/traits.jl | 9 + .../test/Misc/FromPrimitive/test.jl | 6 +- .../test/Misc/Internals/backends.jl | 25 +- .../test/Misc/Internals/batchsize.jl | 3 + DifferentiationInterfaceTest/Project.toml | 4 +- .../src/DifferentiationInterfaceTest.jl | 1 + .../src/scenarios/scenario.jl | 15 +- 17 files changed, 352 insertions(+), 28 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl create mode 100644 DifferentiationInterface/src/first_order/mixed_mode.jl diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index b2c9b89ea..9dc73a4b8 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.21" +version = "0.6.22" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -61,7 +61,7 @@ ReverseDiff = "1.15.1" SparseArrays = "<0.0.1,1" SparseConnectivityTracer = "0.5.0,0.6" StaticArrays = "1.9.7" -SparseMatrixColorings = "0.4.5" +SparseMatrixColorings = "0.4.9" Symbolics = "5.27.1, 6" Tracker = "0.2.33" Zygote = "0.6.69" @@ -99,4 +99,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"] +test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"] diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 33f1ce8f3..7a76d54c5 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -68,6 +68,7 @@ jacobian jacobian! value_and_jacobian value_and_jacobian! +MixedMode ``` ## Second order diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl index 089c4c7aa..b8523c897 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl @@ -21,6 +21,8 @@ using DifferentiationInterface: PushforwardPerformance, inner, outer, + forward_backend, + reverse_backend, multibasis, pick_batchsize, pushforward_performance, @@ -33,13 +35,32 @@ using SparseMatrixColorings: coloring, column_colors, row_colors, + ncolors, column_groups, row_groups, sparsity_pattern, decompress! import SparseMatrixColorings as SMC +function fy_with_contexts(f, contexts::Vararg{Context,C}) where {C} + return (with_contexts(f, contexts...),) +end + +function fy_with_contexts(f!, y, contexts::Vararg{Context,C}) where {C} + return (with_contexts(f!, contexts...), y) +end + +abstract type SparseJacobianPrep <: JacobianPrep end + +SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result) +SMC.column_colors(prep::SparseJacobianPrep) = column_colors(prep.coloring_result) +SMC.column_groups(prep::SparseJacobianPrep) = column_groups(prep.coloring_result) +SMC.row_colors(prep::SparseJacobianPrep) = row_colors(prep.coloring_result) +SMC.row_groups(prep::SparseJacobianPrep) = row_groups(prep.coloring_result) +SMC.ncolors(prep::SparseJacobianPrep) = ncolors(prep.coloring_result) + include("jacobian.jl") +include("jacobian_mixed.jl") include("hessian.jl") end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index f76544066..4ba0e373c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -19,6 +19,7 @@ end SMC.sparsity_pattern(prep::SparseHessianPrep) = sparsity_pattern(prep.coloring_result) SMC.column_colors(prep::SparseHessianPrep) = column_colors(prep.coloring_result) SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result) +SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result) ## Hessian, one argument diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index 8e642fe73..80aba9651 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -1,21 +1,5 @@ -function fy_with_contexts(f, contexts::Vararg{Context,C}) where {C} - return (with_contexts(f, contexts...),) -end - -function fy_with_contexts(f!, y, contexts::Vararg{Context,C}) where {C} - return (with_contexts(f!, contexts...), y) -end - ## Preparation -abstract type SparseJacobianPrep <: JacobianPrep end - -SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result) -SMC.column_colors(prep::SparseJacobianPrep) = column_colors(prep.coloring_result) -SMC.column_groups(prep::SparseJacobianPrep) = column_groups(prep.coloring_result) -SMC.row_colors(prep::SparseJacobianPrep) = row_colors(prep.coloring_result) -SMC.row_groups(prep::SparseJacobianPrep) = row_groups(prep.coloring_result) - struct PushforwardSparseJacobianPrep{ BS<:BatchSizeSettings, C<:AbstractColoringResult{:nonsymmetric,:column}, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl new file mode 100644 index 000000000..24cd33b21 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl @@ -0,0 +1,231 @@ +## Preparation + +struct MixedModeSparseJacobianPrep{ + BSf<:BatchSizeSettings, + BSr<:BatchSizeSettings, + C<:AbstractColoringResult{:nonsymmetric,:bidirectional}, + M<:AbstractMatrix{<:Real}, + Sf<:Vector{<:NTuple}, + Sr<:Vector{<:NTuple}, + Rf<:Vector{<:NTuple}, + Rr<:Vector{<:NTuple}, + Ef<:PushforwardPrep, + Er<:PullbackPrep, +} <: SparseJacobianPrep + batch_size_settings_forward::BSf + batch_size_settings_reverse::BSr + coloring_result::C + compressed_matrix_forward::M + compressed_matrix_reverse::M + batched_seeds_forward::Sf + batched_seeds_reverse::Sr + batched_results_forward::Rf + batched_results_reverse::Rr + pushforward_prep::Ef + pullback_prep::Er +end + +function DI.prepare_jacobian( + f::F, backend::AutoSparse{<:MixedMode}, x, contexts::Vararg{Context,C} +) where {F,C} + y = f(x, map(unwrap, contexts)...) + return _prepare_mixed_sparse_jacobian_aux(y, (f,), backend, x, contexts...) +end + +function DI.prepare_jacobian( + f!::F, y, backend::AutoSparse{<:MixedMode}, x, contexts::Vararg{Context,C} +) where {F,C} + return _prepare_mixed_sparse_jacobian_aux(y, (f!, y), backend, x, contexts...) +end + +function _prepare_mixed_sparse_jacobian_aux( + y, f_or_f!y::FY, backend::AutoSparse{<:MixedMode}, x, contexts::Vararg{Context,C} +) where {FY,C} + dense_backend = dense_ad(backend) + sparsity = jacobian_sparsity( + fy_with_contexts(f_or_f!y..., contexts...)..., x, sparsity_detector(backend) + ) + problem = ColoringProblem{:nonsymmetric,:bidirectional}() + coloring_result = coloring( + sparsity, + problem, + coloring_algorithm(backend); + decompression_eltype=promote_type(eltype(x), eltype(y)), + ) + + Nf = length(column_groups(coloring_result)) + Nr = length(row_groups(coloring_result)) + batch_size_settings_forward = pick_batchsize(forward_backend(dense_backend), Nf) + batch_size_settings_reverse = pick_batchsize(reverse_backend(dense_backend), Nr) + + return _prepare_mixed_sparse_jacobian_aux_aux( + batch_size_settings_forward, + batch_size_settings_reverse, + coloring_result, + y, + f_or_f!y, + backend, + x, + contexts..., + ) +end + +function _prepare_mixed_sparse_jacobian_aux_aux( + batch_size_settings_forward::BatchSizeSettings{Bf}, + batch_size_settings_reverse::BatchSizeSettings{Br}, + coloring_result::AbstractColoringResult{:nonsymmetric,:bidirectional}, + y, + f_or_f!y::FY, + backend::AutoSparse{<:MixedMode}, + x, + contexts::Vararg{Context,C}, +) where {Bf,Br,FY,C} + Nf, Af = batch_size_settings_forward.N, batch_size_settings_forward.A + Nr, Ar = batch_size_settings_reverse.N, batch_size_settings_reverse.A + + dense_backend = dense_ad(backend) + + groups_forward = column_groups(coloring_result) + groups_reverse = row_groups(coloring_result) + + seeds_forward = [ + multibasis(backend, x, eachindex(x)[group]) for group in groups_forward + ] + seeds_reverse = [ + multibasis(backend, y, eachindex(y)[group]) for group in groups_reverse + ] + + compressed_matrix_forward = stack(_ -> vec(similar(y)), groups_forward; dims=2) + compressed_matrix_reverse = stack(_ -> vec(similar(x)), groups_reverse; dims=1) + + batched_seeds_forward = [ + ntuple(b -> seeds_forward[1 + ((a - 1) * Bf + (b - 1)) % Nf], Val(Bf)) for a in 1:Af + ] + batched_seeds_reverse = [ + ntuple(b -> seeds_reverse[1 + ((a - 1) * Br + (b - 1)) % Nr], Val(Br)) for a in 1:Ar + ] + + batched_results_forward = [ + ntuple(b -> similar(y), Val(Bf)) for _ in batched_seeds_forward + ] + batched_results_reverse = [ + ntuple(b -> similar(x), Val(Br)) for _ in batched_seeds_reverse + ] + + pushforward_prep = prepare_pushforward( + f_or_f!y..., + forward_backend(dense_backend), + x, + batched_seeds_forward[1], + contexts..., + ) + pullback_prep = prepare_pullback( + f_or_f!y..., + reverse_backend(dense_backend), + x, + batched_seeds_reverse[1], + contexts..., + ) + + return MixedModeSparseJacobianPrep( + batch_size_settings_forward, + batch_size_settings_reverse, + coloring_result, + compressed_matrix_forward, + compressed_matrix_reverse, + batched_seeds_forward, + batched_seeds_reverse, + batched_results_forward, + batched_results_reverse, + pushforward_prep, + pullback_prep, + ) +end + +## Common auxiliaries + +function _sparse_jacobian_aux!( + f_or_f!y::FY, + jac, + prep::MixedModeSparseJacobianPrep{<:BatchSizeSettings{Bf},<:BatchSizeSettings{Br}}, + backend::AutoSparse, + x, + contexts::Vararg{Context,C}, +) where {FY,Bf,Br,C} + (; + batch_size_settings_forward, + batch_size_settings_reverse, + coloring_result, + compressed_matrix_forward, + compressed_matrix_reverse, + batched_seeds_forward, + batched_seeds_reverse, + batched_results_forward, + batched_results_reverse, + pushforward_prep, + pullback_prep, + ) = prep + + dense_backend = dense_ad(backend) + Nf = batch_size_settings_forward.N + Nr = batch_size_settings_reverse.N + + pushforward_prep_same = prepare_pushforward_same_point( + f_or_f!y..., + pushforward_prep, + forward_backend(dense_backend), + x, + batched_seeds_forward[1], + contexts..., + ) + pullback_prep_same = prepare_pullback_same_point( + f_or_f!y..., + pullback_prep, + reverse_backend(dense_backend), + x, + batched_seeds_reverse[1], + contexts..., + ) + + for a in eachindex(batched_seeds_forward, batched_results_forward) + pushforward!( + f_or_f!y..., + batched_results_forward[a], + pushforward_prep_same, + forward_backend(dense_backend), + x, + batched_seeds_forward[a], + contexts..., + ) + + for b in eachindex(batched_results_forward[a]) + copyto!( + view(compressed_matrix_forward, :, 1 + ((a - 1) * Bf + (b - 1)) % Nf), + vec(batched_results_forward[a][b]), + ) + end + end + + for a in eachindex(batched_seeds_reverse, batched_results_reverse) + pullback!( + f_or_f!y..., + batched_results_reverse[a], + pullback_prep_same, + reverse_backend(dense_backend), + x, + batched_seeds_reverse[a], + contexts..., + ) + + for b in eachindex(batched_results_reverse[a]) + copyto!( + view(compressed_matrix_reverse, 1 + ((a - 1) * Br + (b - 1)) % Nr, :), + vec(batched_results_reverse[a][b]), + ) + end + end + + decompress!(jac, compressed_matrix_reverse, compressed_matrix_forward, coloring_result) + + return jac +end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index a059428bf..fcfc25dd0 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -33,6 +33,7 @@ using LinearAlgebra: dot include("compat.jl") +include("first_order/mixed_mode.jl") include("second_order/second_order.jl") include("utils/prep.jl") @@ -66,7 +67,7 @@ include("misc/zero_backends.jl") ## Exported export Context, Constant, Cache -export SecondOrder +export MixedMode, SecondOrder export value_and_pushforward!, value_and_pushforward export value_and_pullback!, value_and_pullback diff --git a/DifferentiationInterface/src/first_order/mixed_mode.jl b/DifferentiationInterface/src/first_order/mixed_mode.jl new file mode 100644 index 000000000..20dcf3504 --- /dev/null +++ b/DifferentiationInterface/src/first_order/mixed_mode.jl @@ -0,0 +1,27 @@ +""" + MixedMode + +Combination of a forward and a reverse mode backend for mixed-mode Jacobian computation. + +!!! danger + `MixedMode` backends only support [`jacobian`](@ref) and its variants. + +# Constructor + + MixedMode(forward_backend, reverse_backend) +""" +struct MixedMode{F<:AbstractADType,R<:AbstractADType} <: AbstractADType + forward::F + reverse::R + function MixedMode(forward::AbstractADType, reverse::AbstractADType) + @assert pushforward_performance(forward) isa PushforwardFast + @assert pullback_performance(reverse) isa PullbackFast + return new{typeof(forward),typeof(reverse)}(forward, reverse) + end +end + +forward_backend(m::MixedMode) = m.forward +reverse_backend(m::MixedMode) = m.reverse + +struct ForwardAndReverseMode <: ADTypes.AbstractMode end +ADTypes.mode(::MixedMode) = ForwardAndReverseMode() diff --git a/DifferentiationInterface/src/utils/batchsize.jl b/DifferentiationInterface/src/utils/batchsize.jl index 2e652b774..91639bc2c 100644 --- a/DifferentiationInterface/src/utils/batchsize.jl +++ b/DifferentiationInterface/src/utils/batchsize.jl @@ -52,6 +52,12 @@ function pick_batchsize(backend::AbstractADType, x_or_N::Union{AbstractArray,Int "You should select the batch size for the dense backend of $backend" ), ) + elseif backend isa MixedMode + throw( + ArgumentError( + "You should select the batch size for the forward or reverse backend of $backend", + ), + ) else return BatchSizeSettings(backend, x_or_N) end diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl index 54bac8b1e..6e2d947a3 100644 --- a/DifferentiationInterface/src/utils/check.jl +++ b/DifferentiationInterface/src/utils/check.jl @@ -11,6 +11,11 @@ end check_available(backend::AutoSparse) = check_available(dense_ad(backend)) +function check_available(backend::MixedMode) + return check_available(forward_backend(backend)) && + check_available(reverse_backend(backend)) +end + """ check_inplace(backend) diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index ff920ec6f..055d32fd8 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -34,6 +34,15 @@ end inplace_support(backend::AutoSparse) = inplace_support(dense_ad(backend)) +function inplace_support(backend::MixedMode) + if Bool(inplace_support(forward_backend(backend))) && + Bool(inplace_support(reverse_backend(backend))) + return InPlaceSupported() + else + return InPlaceNotSupported() + end +end + ## Pushforward abstract type PushforwardPerformance end diff --git a/DifferentiationInterface/test/Misc/FromPrimitive/test.jl b/DifferentiationInterface/test/Misc/FromPrimitive/test.jl index 54a169286..26907aa9c 100644 --- a/DifferentiationInterface/test/Misc/FromPrimitive/test.jl +++ b/DifferentiationInterface/test/Misc/FromPrimitive/test.jl @@ -57,7 +57,9 @@ test_differentiation( ); test_differentiation( - MyAutoSparse.(adaptive_backends), + MyAutoSparse.( + vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2])) + ), sparse_scenarios(; include_constantified=true); sparsity=true, logging=LOGGING, @@ -75,6 +77,8 @@ test_differentiation( @test all(==(1), column_colors(jac_for_prep)) @test all(==(1), row_colors(jac_rev_prep)) @test all(==(1), column_colors(hess_prep)) + @test ncolors(jac_for_prep) == 1 + @test ncolors(hess_prep) == 1 @test only(column_groups(jac_for_prep)) == 1:10 @test only(row_groups(jac_rev_prep)) == 1:10 @test only(column_groups(hess_prep)) == 1:10 diff --git a/DifferentiationInterface/test/Misc/Internals/backends.jl b/DifferentiationInterface/test/Misc/Internals/backends.jl index 156155882..7e334c15b 100644 --- a/DifferentiationInterface/test/Misc/Internals/backends.jl +++ b/DifferentiationInterface/test/Misc/Internals/backends.jl @@ -2,9 +2,17 @@ using ADTypes using ADTypes: mode using DifferentiationInterface using DifferentiationInterface: - inner, outer, inplace_support, pushforward_performance, pullback_performance, hvp_mode + inner, + outer, + forward_backend, + reverse_backend, + inplace_support, + pushforward_performance, + pullback_performance, + hvp_mode import DifferentiationInterface as DI using ForwardDiff: ForwardDiff +using Zygote: Zygote using Test @testset "SecondOrder" begin @@ -13,10 +21,21 @@ using Test @test outer(backend) isa AutoForwardDiff @test inner(backend) isa AutoZygote @test mode(backend) isa ADTypes.ForwardMode - @test Bool(inplace_support(backend)) == - (Bool(inplace_support(inner(backend))) && Bool(inplace_support(outer(backend)))) + @test !Bool(inplace_support(backend)) @test_throws ArgumentError pushforward_performance(backend) @test_throws ArgumentError pullback_performance(backend) + @test check_available(backend) +end + +@testset "MixedMode" begin + backend = MixedMode(AutoForwardDiff(), AutoZygote()) + @test ADTypes.mode(backend) isa DifferentiationInterface.ForwardAndReverseMode + @test forward_backend(backend) isa AutoForwardDiff + @test reverse_backend(backend) isa AutoZygote + @test !Bool(inplace_support(backend)) + @test_throws MethodError pushforward_performance(backend) + @test_throws MethodError pullback_performance(backend) + @test check_available(backend) end @testset "Sparse" begin diff --git a/DifferentiationInterface/test/Misc/Internals/batchsize.jl b/DifferentiationInterface/test/Misc/Internals/batchsize.jl index 71929f3a3..16bd0d221 100644 --- a/DifferentiationInterface/test/Misc/Internals/batchsize.jl +++ b/DifferentiationInterface/test/Misc/Internals/batchsize.jl @@ -16,6 +16,9 @@ BSS = BatchSizeSettings @test_throws ArgumentError pick_batchsize( SecondOrder(AutoZygote(), AutoZygote()), zeros(2) ) + @test_throws ArgumentError pick_batchsize( + MixedMode(AutoForwardDiff(), AutoZygote()), zeros(2) + ) end @testset "ForwardDiff (adaptive)" begin diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 4bd8966de..1a19140c5 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterfaceTest" uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.8.4" +version = "0.8.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -59,7 +59,7 @@ ProgressMeter = "1" Random = "<0.0.1,1" SparseArrays = "<0.0.1,1" SparseConnectivityTracer = "0.5.0,0.6" -SparseMatrixColorings = "0.4.4" +SparseMatrixColorings = "0.4.9" StaticArrays = "1.9" Test = "<0.0.1,1" Zygote = "0.6" diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index c914de603..50f56c69a 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -48,6 +48,7 @@ using DifferentiationInterface: PullbackPrep, PushforwardPrep, SecondDerivativePrep, + MixedMode, SecondOrder, Rewrap import DifferentiationInterface as DI diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index 8c32d06f0..4ea5da6bb 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -100,7 +100,12 @@ function compatible(backend::AbstractADType, scen::Scenario) sparse_compatible = operator(scen) in (:jacobian, :hessian) || !isa(backend, AutoSparse) secondorder_compatible = order(scen) == 2 || !isa(backend, Union{SecondOrder,AutoSparse{<:SecondOrder}}) - return place_compatible && secondorder_compatible && sparse_compatible + mixedmode_compatible = + operator(scen) == :jacobian || !isa(backend, AutoSparse{<:MixedMode}) + return place_compatible && + secondorder_compatible && + sparse_compatible && + mixedmode_compatible end function group_by_operator(scenarios::AbstractVector{<:Scenario}) @@ -127,8 +132,14 @@ function adapt_batchsize(backend::AbstractADType, scen::Scenario) if operator(scen) == :jacobian if ADTypes.mode(backend) isa Union{ADTypes.ForwardMode,ADTypes.ForwardOrReverseMode} return DI.threshold_batchsize(backend, length(scen.x)) - else + elseif ADTypes.mode(backend) isa ADTypes.ReverseMode return DI.threshold_batchsize(backend, length(scen.y)) + elseif ADTypes.mode(backend) isa DifferentiationInterface.ForwardAndReverseMode + return DI.threshold_batchsize(backend, min(length(scen.x), length(scen.y))) + elseif ADTypes.mode(backend) isa ADTypes.SymbolicMode + return backend + else + error("Unknown mode") end elseif operator(scen) == :hessian return DI.threshold_batchsize(backend, length(scen.x))