diff --git a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl index 8bb0cbe90f9..bca939c332e 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl @@ -36,6 +36,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici OrdinaryDiffEqImplicitAlgorithm, CompositeAlgorithm, OrdinaryDiffEqExponentialAlgorithm, OrdinaryDiffEqAdaptiveExponentialAlgorithm, + OrdinaryDiffEqLinearExponentialAlgorithm, StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm, StochasticDiffEqJumpNewtonAdaptiveAlgorithm, StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm, diff --git a/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl index 75b724ab495..073e88b4c5c 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl @@ -32,6 +32,12 @@ function _alg_autodiff( ) where {CS, AD, FDT, ST, CJ, Controller} return Val{AD}() end +# OrdinaryDiffEqLinearExponentialAlgorithm subtypes (Magnus integrators, LieEuler, +# CG methods, etc.) have NO autodiff field β€” their only fields are krylov, m, iop. +# They must be excluded before the generic ExponentialAlgorithm dispatch below. +function _alg_autodiff(::OrdinaryDiffEqLinearExponentialAlgorithm) + return Val{false}() +end function _alg_autodiff( alg::Union{ OrdinaryDiffEqExponentialAlgorithm{CS, AD}, diff --git a/lib/OrdinaryDiffEqDifferentiation/test/differentiation_traits_tests.jl b/lib/OrdinaryDiffEqDifferentiation/test/differentiation_traits_tests.jl index b00c3c8e3bb..c5a9b72d714 100644 --- a/lib/OrdinaryDiffEqDifferentiation/test/differentiation_traits_tests.jl +++ b/lib/OrdinaryDiffEqDifferentiation/test/differentiation_traits_tests.jl @@ -43,3 +43,26 @@ sol = solve(prob2, Rosenbrock23(autodiff = AutoForwardDiff(chunksize = 1))) sol = solve(prob2, Rosenbrock23(autodiff = AutoFiniteDiff())) @test β‰ˆ(good_sol[:, end], sol[:, end], rtol = 1.0e-2) + +# Regression test for issue #3232: +# MagnusGL6 (and all OrdinaryDiffEqLinearExponentialAlgorithm subtypes) +# have no `autodiff` field. When OrdinaryDiffEqDifferentiation is loaded, +# _alg_autodiff must not crash by trying to access alg.autodiff. +using OrdinaryDiffEqLinear +using SciMLOperators: MatrixOperator + +@testset "MagnusGL6 solve with Differentiation loaded (issue #3232)" begin + function update_func!(A, u, p, t) + A[1, 1] = cos(t) + A[2, 1] = sin(t) + A[1, 2] = -sin(t) + A[2, 2] = cos(t) + end + A = MatrixOperator(ones(2, 2), update_func! = update_func!) + prob = ODEProblem(A, ones(2), (1.0, 6.0)) + + # This would crash with FieldError before the fix + sol = solve(prob, MagnusGL6(), dt = 1 / 10) + @test sol.retcode == ReturnCode.Success + @test length(sol.t) > 1 +end diff --git a/lib/OrdinaryDiffEqLinear/test/linear_method_tests.jl b/lib/OrdinaryDiffEqLinear/test/linear_method_tests.jl index b49d19b0c1e..c5a62556ce9 100644 --- a/lib/OrdinaryDiffEqLinear/test/linear_method_tests.jl +++ b/lib/OrdinaryDiffEqLinear/test/linear_method_tests.jl @@ -267,3 +267,21 @@ test_setup = Dict(:alg => Vern9(), :reltol => 1.0e-14, :abstol => 1.0e-14) sim = analyticless_test_convergence(dts, prob, CayleyEuler(), test_setup) @test sim.π’ͺest[:l2] β‰ˆ 1 atol = 0.2 + +# Regression test for https://github.com/SciML/OrdinaryDiffEq.jl/issues/3232 +# Magnus/Linear integrators must not FieldError when OrdinaryDiffEqDifferentiation +# is loaded (which happens transitively via DifferentialEquations.jl). +@testset "Regression #3232: non-autonomous Magnus solve does not FieldError" begin + function update_func_3232!(A, u, p, t) + A[1, 1] = cos(t) + A[2, 1] = sin(t) + A[1, 2] = -sin(t) + A[2, 2] = cos(t) + end + A_3232 = MatrixOperator(ones(2, 2), update_func! = update_func_3232!) + prob_3232 = ODEProblem(A_3232, ones(2), (1.0, 6.0)) + for alg in (MagnusGL6(), MagnusGL4(), MagnusGL8(), LieEuler(), CG2(), CG3(), CG4a()) + sol = solve(prob_3232, alg, dt = 1 / 10) + @test sol.retcode == ReturnCode.Success + end +end