diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 0af5dfe3d..4dfb212c3 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1121,7 +1121,8 @@ function SciMLBase._concrete_solve_adjoint( J = du[i] if Δ isa AbstractVector v = Δ[i] - elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray + elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray || + Δ isa Tangent v = Δ.u[i] elseif Δ isa AbstractMatrix v = @view Δ[:, i] @@ -1149,6 +1150,69 @@ function SciMLBase._concrete_solve_adjoint( return out, forward_sensitivity_backpass end +# Mooncake-specific `ForwardSensitivity` path. The main method builds an +# `ODEForwardSensitivityProblem` whose `f` is an +# `ODEForwardSensitivityFunction` carrying ForwardDiff internals +# (`ForwardDiff.JacobianConfig`, `Dual` caches, …) in its type parameters. +# The returned `sensitivity_solution(augmented_sol, u, ts)` inherits those +# types in `sol.prob.f`, which confuses Mooncake's `@from_rrule` tangent +# recursion the same way the tracked types in `ReverseDiffAdjoint` / +# `TrackerAdjoint` do. Delegate to the `ChainRulesOriginator` path for the +# sensitivity tape and re-solve the plain problem for the primal. +# +# A tempting alternative is to walk the returned `primal` and strip +# tracked / augmented types via `SciMLBase.value` recursively. That +# approach *almost* works but fails on one specific slot: `solve()` wraps +# `prob.f` in a `FunctionWrappersWrappers.FunctionWrappersWrapper` during +# `get_concrete_problem`, and the resulting `FunctionWrapper` has no +# public positional constructor for `ConstructionBase.setproperties`, so a +# generic walker can't rebuild it. Mooncake's `@from_rrule` type +# inference nonetheless expects the `FunctionWrapper` version (because +# that's what `solve()` returns), so anything less than actually invoking +# `solve()` produces a type mismatch on the `DerivedRule` assertion. +# A truly general `strip_values(sol)` in SciMLBase would need the same +# `solve()` round-trip internally, so the cost is unavoidable here. +# +# Additionally, the main `forward_sensitivity_backpass` returns `du0 = +# @not_implemented(...)` because `ForwardSensitivity` can't differentiate +# w.r.t. `u0`. Mooncake's `@from_rrule` plumbing then tries to convert that +# `ChainRulesCore.NotImplemented` tangent back through +# `increment_and_get_rdata!` against the `Vector{Float64}` fdata of `u0`, +# and Mooncake doesn't have a method for that combination (only scalar +# `IEEEFloat` + `NotImplemented` is handled). Since Mooncake will dutifully +# thread the cotangent of *every* argument through `increment_and_get_rdata!` +# regardless of whether the caller is actually differentiating `u0`, we +# replace the `du0` slot in the delegated ChainRules pullback with +# `NoTangent()` so the Mooncake conversion has a shape it understands. Any +# caller that genuinely differentiates `u0` while using `ForwardSensitivity` +# is already using the wrong sensealg (the main method's error message says +# as much). +function SciMLBase._concrete_solve_adjoint( + prob::SciMLBase.AbstractODEProblem, alg, + sensealg::ForwardSensitivity, + u0, p, originator::SciMLBase.MooncakeOriginator, + args...; kwargs... + ) + _, backpass = SciMLBase._concrete_solve_adjoint( + prob, alg, sensealg, u0, p, + SciMLBase.ChainRulesOriginator(), args...; kwargs... + ) + # ChainRules branch of `forward_sensitivity_backpass` returns + # `(NoTangent(), NoTangent(), NoTangent(), du0, adj, NoTangent(), rest...)`. + # Replace position 4 (`du0`) with `NoTangent()`. + function mooncake_forward_sensitivity_backpass(Δ) + cr = backpass(Δ) + return (cr[1], cr[2], cr[3], NoTangent(), cr[5:end]...) + end + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) + primal = solve( + remake(prob; u0, p), alg, args...; + sensealg = DiffEqBase.SensitivityADPassThrough(), + kwargs_filtered... + ) + return primal, mooncake_forward_sensitivity_backpass +end + function SciMLBase._concrete_solve_forward( prob::SciMLBase.AbstractODEProblem, alg, sensealg::AbstractForwardSensitivityAlgorithm, @@ -1846,21 +1910,6 @@ function Base.showerror(io::IO, e::EnzymeTrackedRealError) return println(io, ENZYME_TRACKED_REAL_ERROR_MESSAGE) end -const MOONCAKE_TRACKED_REAL_ERROR_MESSAGE = """ -`Mooncake` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`. -Either choose a different adjoint method like `GaussAdjoint`, -or use a different AD system like `ReverseDiff`. -For more details, on these methods see -https://docs.sciml.ai/SciMLSensitivity/stable/. -""" - -struct MooncakeTrackedRealError <: Exception -end - -function Base.showerror(io::IO, e::MooncakeTrackedRealError) - return println(io, MOONCAKE_TRACKED_REAL_ERROR_MESSAGE) -end - function SciMLBase._concrete_solve_adjoint( prob::Union{ SciMLBase.AbstractDiscreteProblem, @@ -1881,10 +1930,6 @@ function SciMLBase._concrete_solve_adjoint( throw(EnzymeTrackedRealError()) end - if originator isa SciMLBase.MooncakeOriginator - throw(MooncakeTrackedRealError()) - end - if !(p === nothing || p isa SciMLBase.NullParameters) if !isscimlstructure(p) throw(SciMLStructuresCompatibilityError()) @@ -2093,6 +2138,41 @@ function SciMLBase._concrete_solve_adjoint( tracker_adjoint_backpass end +# Mooncake-specific `TrackerAdjoint` path. Same reasoning as the +# `ReverseDiffAdjoint` + `MooncakeOriginator` method below: the main method +# returns `sensitivity_solution(tracked_sol, …)` with `Tracker.TrackedReal` / +# `TrackedArray` type parameters embedded in `tracked_sol.interp` / `.prob` +# / `.alg`, and Mooncake's `@from_rrule` plumbing chokes when recursively +# computing `tangent_type` on those fields. Delegate the tape to the +# `ChainRulesOriginator` path and re-solve with `SensitivityADPassThrough` +# for the primal. +function SciMLBase._concrete_solve_adjoint( + prob::Union{ + SciMLBase.AbstractDiscreteProblem, + SciMLBase.AbstractODEProblem, + SciMLBase.AbstractDAEProblem, + SciMLBase.AbstractDDEProblem, + SciMLBase.AbstractSDEProblem, + SciMLBase.AbstractSDDEProblem, + SciMLBase.AbstractRODEProblem, + }, + alg, sensealg::TrackerAdjoint, + u0, p, originator::SciMLBase.MooncakeOriginator, + args...; kwargs... + ) + _, backpass = SciMLBase._concrete_solve_adjoint( + prob, alg, sensealg, u0, p, + SciMLBase.ChainRulesOriginator(), args...; kwargs... + ) + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) + primal = solve( + remake(prob; u0, p), alg, args...; + sensealg = DiffEqBase.SensitivityADPassThrough(), + kwargs_filtered... + ) + return primal, backpass +end + const REVERSEDIFF_ADJOINT_GPU_COMPATIBILITY_MESSAGE = """ ReverseDiffAdjoint is not compatible GPU-based array types. Use a different sensitivity analysis method, like InterpolatingAdjoint or TrackerAdjoint, @@ -2148,10 +2228,6 @@ function SciMLBase._concrete_solve_adjoint( throw(EnzymeTrackedRealError()) end - if originator isa SciMLBase.MooncakeOriginator - throw(MooncakeTrackedRealError()) - end - t = eltype(prob.tspan)[] u = typeof(u0)[] @@ -2277,6 +2353,66 @@ function SciMLBase._concrete_solve_adjoint( reversediff_adjoint_backpass end +# Mooncake-specific `ReverseDiffAdjoint` path. The main `ReverseDiffAdjoint` +# method above returns `SciMLBase.sensitivity_solution(sol, …)` where `sol` +# still carries `ReverseDiff.TrackedReal` / `TrackedArray` type parameters in +# nested fields (`interp`, `prob`, `alg`, …). ChainRules / Zygote don't +# inspect the primal's type parameters, so they don't care. Mooncake's +# `@from_rrule` plumbing, on the other hand, calls +# `zero_tangent(y_primal)` and therefore recursively computes `tangent_type` +# for every nested field of the returned solution; that recursion fails on +# the tracked type parameters with either a `TypeError` or an unhelpful +# tangent-type error, which is what the `hybrid_diffeq` tutorial and +# PR #1419 ran into. +# +# This method delegates the tape construction (and hence the whole backward +# pass) to the `ChainRulesOriginator` path, then replaces the primal with +# a fresh plain-arithmetic solve of the same problem. The obvious +# alternative — walking the returned primal and stripping tracked scalars +# via `SciMLBase.value` recursively — *almost* works but fails on one +# specific slot: `solve()` wraps `prob.f` in a +# `FunctionWrappersWrappers.FunctionWrappersWrapper` during +# `get_concrete_problem`, and the resulting `FunctionWrapper` has no +# public positional constructor for `ConstructionBase.setproperties`, so a +# generic walker can't rebuild it. Mooncake's `@from_rrule` type +# inference nonetheless expects the `FunctionWrapper` version (because +# that's what `solve()` normally returns from a plain-arithmetic solve), +# so anything short of actually invoking `solve()` produces a type +# mismatch on the `DerivedRule` assertion. A truly general +# `strip_values(sol)` in SciMLBase would need the same `solve()` +# round-trip internally, so the cost is unavoidable here. +# +# Keeping this in a dedicated method dispatched on `MooncakeOriginator` +# stops Julia type inference from joining two different return shapes +# into a `Union{ODESolution{tracked…}, ODESolution{plain…}}`, which would +# otherwise trip Mooncake's `DerivedRule` type assertion. +function SciMLBase._concrete_solve_adjoint( + prob::Union{ + SciMLBase.AbstractDiscreteProblem, + SciMLBase.AbstractODEProblem, + SciMLBase.AbstractDAEProblem, + SciMLBase.AbstractDDEProblem, + SciMLBase.AbstractSDEProblem, + SciMLBase.AbstractSDDEProblem, + SciMLBase.AbstractRODEProblem, + }, + alg, sensealg::ReverseDiffAdjoint, + u0, p, originator::SciMLBase.MooncakeOriginator, + args...; kwargs... + ) + _, backpass = SciMLBase._concrete_solve_adjoint( + prob, alg, sensealg, u0, p, + SciMLBase.ChainRulesOriginator(), args...; kwargs... + ) + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) + primal = solve( + remake(prob; u0, p), alg, args...; + sensealg = DiffEqBase.SensitivityADPassThrough(), + kwargs_filtered... + ) + return primal, backpass +end + function SciMLBase._concrete_solve_adjoint( prob::SciMLBase.AbstractODEProblem, alg, sensealg::AbstractShadowingSensitivityAlgorithm, diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl index aff403475..442ea2a17 100644 --- a/test/concrete_solve_derivatives.jl +++ b/test/concrete_solve_derivatives.jl @@ -442,6 +442,20 @@ Tests callable structs with different AD backends end end + # Mooncake is not in `REVERSE_BACKENDS` because it doesn't yet compose + # with every sensealg, but it does compose with `ReverseDiffAdjoint` and + # `TrackerAdjoint` via the dedicated `MooncakeOriginator` dispatches + # added in #1420 (the hybrid_diffeq.md pattern from #1419). + @testset "Mooncake with ReverseDiffAdjoint" begin + result = gradient_mooncake(senseloss(ReverseDiffAdjoint()), u0p) + @test result ≈ ref_grad_senseloss + end + + @testset "Mooncake with TrackerAdjoint" begin + result = gradient_mooncake(senseloss(TrackerAdjoint()), u0p) + @test result ≈ ref_grad_senseloss + end + # Test with p-only differentiation (senseloss3 and senseloss4 from alternative_ad_frontend.jl) struct senseloss_p{T} sense::T @@ -470,6 +484,16 @@ Tests callable structs with different AD backends @test result ≈ ref_grad_p end end + + # Mooncake + `ForwardSensitivity` via the dedicated `MooncakeOriginator` + # dispatch added in #1420. p-only because `ForwardSensitivity` can't + # differentiate `u0`, and the Mooncake dispatch rewrites the `du0` + # slot to `NoTangent()` so Mooncake's cotangent threading doesn't trip + # on the main method's `@not_implemented` stub for `du0`. + @testset "Mooncake with ForwardSensitivity (p-only)" begin + result = gradient_mooncake(senseloss_p(ForwardSensitivity()), p_only) + @test result ≈ ref_grad_p + end end #=