Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 160 additions & 24 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,8 @@ function SciMLBase._concrete_solve_adjoint(
J = du[i]
if Δ isa AbstractVector
v = Δ[i]
elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray
elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray ||
Δ isa Tangent
v = Δ.u[i]
elseif Δ isa AbstractMatrix
v = @view Δ[:, i]
Expand Down Expand Up @@ -1149,6 +1150,69 @@ function SciMLBase._concrete_solve_adjoint(
return out, forward_sensitivity_backpass
end

# Mooncake-specific `ForwardSensitivity` path. The main method builds an
# `ODEForwardSensitivityProblem` whose `f` is an
# `ODEForwardSensitivityFunction` carrying ForwardDiff internals
# (`ForwardDiff.JacobianConfig`, `Dual` caches, …) in its type parameters.
# The returned `sensitivity_solution(augmented_sol, u, ts)` inherits those
# types in `sol.prob.f`, which confuses Mooncake's `@from_rrule` tangent
# recursion the same way the tracked types in `ReverseDiffAdjoint` /
# `TrackerAdjoint` do. Delegate to the `ChainRulesOriginator` path for the
# sensitivity tape and re-solve the plain problem for the primal.
#
# A tempting alternative is to walk the returned `primal` and strip
# tracked / augmented types via `SciMLBase.value` recursively. That
# approach *almost* works but fails on one specific slot: `solve()` wraps
# `prob.f` in a `FunctionWrappersWrappers.FunctionWrappersWrapper` during
# `get_concrete_problem`, and the resulting `FunctionWrapper` has no
# public positional constructor for `ConstructionBase.setproperties`, so a
# generic walker can't rebuild it. Mooncake's `@from_rrule` type
# inference nonetheless expects the `FunctionWrapper` version (because
# that's what `solve()` returns), so anything less than actually invoking
# `solve()` produces a type mismatch on the `DerivedRule` assertion.
# A truly general `strip_values(sol)` in SciMLBase would need the same
# `solve()` round-trip internally, so the cost is unavoidable here.
#
# Additionally, the main `forward_sensitivity_backpass` returns `du0 =
# @not_implemented(...)` because `ForwardSensitivity` can't differentiate
# w.r.t. `u0`. Mooncake's `@from_rrule` plumbing then tries to convert that
# `ChainRulesCore.NotImplemented` tangent back through
# `increment_and_get_rdata!` against the `Vector{Float64}` fdata of `u0`,
# and Mooncake doesn't have a method for that combination (only scalar
# `IEEEFloat` + `NotImplemented` is handled). Since Mooncake will dutifully
# thread the cotangent of *every* argument through `increment_and_get_rdata!`
# regardless of whether the caller is actually differentiating `u0`, we
# replace the `du0` slot in the delegated ChainRules pullback with
# `NoTangent()` so the Mooncake conversion has a shape it understands. Any
# caller that genuinely differentiates `u0` while using `ForwardSensitivity`
# is already using the wrong sensealg (the main method's error message says
# as much).
function SciMLBase._concrete_solve_adjoint(
prob::SciMLBase.AbstractODEProblem, alg,
sensealg::ForwardSensitivity,
u0, p, originator::SciMLBase.MooncakeOriginator,
args...; kwargs...
)
_, backpass = SciMLBase._concrete_solve_adjoint(
prob, alg, sensealg, u0, p,
SciMLBase.ChainRulesOriginator(), args...; kwargs...
)
# ChainRules branch of `forward_sensitivity_backpass` returns
# `(NoTangent(), NoTangent(), NoTangent(), du0, adj, NoTangent(), rest...)`.
# Replace position 4 (`du0`) with `NoTangent()`.
function mooncake_forward_sensitivity_backpass(Δ)
cr = backpass(Δ)
return (cr[1], cr[2], cr[3], NoTangent(), cr[5:end]...)
end
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
primal = solve(
remake(prob; u0, p), alg, args...;
sensealg = DiffEqBase.SensitivityADPassThrough(),
kwargs_filtered...
)
return primal, mooncake_forward_sensitivity_backpass
end

function SciMLBase._concrete_solve_forward(
prob::SciMLBase.AbstractODEProblem, alg,
sensealg::AbstractForwardSensitivityAlgorithm,
Expand Down Expand Up @@ -1846,21 +1910,6 @@ function Base.showerror(io::IO, e::EnzymeTrackedRealError)
return println(io, ENZYME_TRACKED_REAL_ERROR_MESSAGE)
end

const MOONCAKE_TRACKED_REAL_ERROR_MESSAGE = """
`Mooncake` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`.
Either choose a different adjoint method like `GaussAdjoint`,
or use a different AD system like `ReverseDiff`.
For more details, on these methods see
https://docs.sciml.ai/SciMLSensitivity/stable/.
"""

struct MooncakeTrackedRealError <: Exception
end

function Base.showerror(io::IO, e::MooncakeTrackedRealError)
return println(io, MOONCAKE_TRACKED_REAL_ERROR_MESSAGE)
end

function SciMLBase._concrete_solve_adjoint(
prob::Union{
SciMLBase.AbstractDiscreteProblem,
Expand All @@ -1881,10 +1930,6 @@ function SciMLBase._concrete_solve_adjoint(
throw(EnzymeTrackedRealError())
end

if originator isa SciMLBase.MooncakeOriginator
throw(MooncakeTrackedRealError())
end

if !(p === nothing || p isa SciMLBase.NullParameters)
if !isscimlstructure(p)
throw(SciMLStructuresCompatibilityError())
Expand Down Expand Up @@ -2093,6 +2138,41 @@ function SciMLBase._concrete_solve_adjoint(
tracker_adjoint_backpass
end

# Mooncake-specific `TrackerAdjoint` path. Same reasoning as the
# `ReverseDiffAdjoint` + `MooncakeOriginator` method below: the main method
# returns `sensitivity_solution(tracked_sol, …)` with `Tracker.TrackedReal` /
# `TrackedArray` type parameters embedded in `tracked_sol.interp` / `.prob`
# / `.alg`, and Mooncake's `@from_rrule` plumbing chokes when recursively
# computing `tangent_type` on those fields. Delegate the tape to the
# `ChainRulesOriginator` path and re-solve with `SensitivityADPassThrough`
# for the primal.
function SciMLBase._concrete_solve_adjoint(
prob::Union{
SciMLBase.AbstractDiscreteProblem,
SciMLBase.AbstractODEProblem,
SciMLBase.AbstractDAEProblem,
SciMLBase.AbstractDDEProblem,
SciMLBase.AbstractSDEProblem,
SciMLBase.AbstractSDDEProblem,
SciMLBase.AbstractRODEProblem,
},
alg, sensealg::TrackerAdjoint,
u0, p, originator::SciMLBase.MooncakeOriginator,
args...; kwargs...
)
_, backpass = SciMLBase._concrete_solve_adjoint(
prob, alg, sensealg, u0, p,
SciMLBase.ChainRulesOriginator(), args...; kwargs...
)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
primal = solve(
remake(prob; u0, p), alg, args...;
sensealg = DiffEqBase.SensitivityADPassThrough(),
kwargs_filtered...
)
return primal, backpass
end

const REVERSEDIFF_ADJOINT_GPU_COMPATIBILITY_MESSAGE = """
ReverseDiffAdjoint is not compatible GPU-based array types. Use a different
sensitivity analysis method, like InterpolatingAdjoint or TrackerAdjoint,
Expand Down Expand Up @@ -2148,10 +2228,6 @@ function SciMLBase._concrete_solve_adjoint(
throw(EnzymeTrackedRealError())
end

if originator isa SciMLBase.MooncakeOriginator
throw(MooncakeTrackedRealError())
end

t = eltype(prob.tspan)[]
u = typeof(u0)[]

Expand Down Expand Up @@ -2277,6 +2353,66 @@ function SciMLBase._concrete_solve_adjoint(
reversediff_adjoint_backpass
end

# Mooncake-specific `ReverseDiffAdjoint` path. The main `ReverseDiffAdjoint`
# method above returns `SciMLBase.sensitivity_solution(sol, …)` where `sol`
# still carries `ReverseDiff.TrackedReal` / `TrackedArray` type parameters in
# nested fields (`interp`, `prob`, `alg`, …). ChainRules / Zygote don't
# inspect the primal's type parameters, so they don't care. Mooncake's
# `@from_rrule` plumbing, on the other hand, calls
# `zero_tangent(y_primal)` and therefore recursively computes `tangent_type`
# for every nested field of the returned solution; that recursion fails on
# the tracked type parameters with either a `TypeError` or an unhelpful
# tangent-type error, which is what the `hybrid_diffeq` tutorial and
# PR #1419 ran into.
#
# This method delegates the tape construction (and hence the whole backward
# pass) to the `ChainRulesOriginator` path, then replaces the primal with
# a fresh plain-arithmetic solve of the same problem. The obvious
# alternative — walking the returned primal and stripping tracked scalars
# via `SciMLBase.value` recursively — *almost* works but fails on one
# specific slot: `solve()` wraps `prob.f` in a
# `FunctionWrappersWrappers.FunctionWrappersWrapper` during
# `get_concrete_problem`, and the resulting `FunctionWrapper` has no
# public positional constructor for `ConstructionBase.setproperties`, so a
# generic walker can't rebuild it. Mooncake's `@from_rrule` type
# inference nonetheless expects the `FunctionWrapper` version (because
# that's what `solve()` normally returns from a plain-arithmetic solve),
# so anything short of actually invoking `solve()` produces a type
# mismatch on the `DerivedRule` assertion. A truly general
# `strip_values(sol)` in SciMLBase would need the same `solve()`
# round-trip internally, so the cost is unavoidable here.
#
# Keeping this in a dedicated method dispatched on `MooncakeOriginator`
# stops Julia type inference from joining two different return shapes
# into a `Union{ODESolution{tracked…}, ODESolution{plain…}}`, which would
# otherwise trip Mooncake's `DerivedRule` type assertion.
function SciMLBase._concrete_solve_adjoint(
prob::Union{
SciMLBase.AbstractDiscreteProblem,
SciMLBase.AbstractODEProblem,
SciMLBase.AbstractDAEProblem,
SciMLBase.AbstractDDEProblem,
SciMLBase.AbstractSDEProblem,
SciMLBase.AbstractSDDEProblem,
SciMLBase.AbstractRODEProblem,
},
alg, sensealg::ReverseDiffAdjoint,
u0, p, originator::SciMLBase.MooncakeOriginator,
args...; kwargs...
)
_, backpass = SciMLBase._concrete_solve_adjoint(
prob, alg, sensealg, u0, p,
SciMLBase.ChainRulesOriginator(), args...; kwargs...
)
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
primal = solve(
remake(prob; u0, p), alg, args...;
sensealg = DiffEqBase.SensitivityADPassThrough(),
kwargs_filtered...
)
return primal, backpass
end

function SciMLBase._concrete_solve_adjoint(
prob::SciMLBase.AbstractODEProblem, alg,
sensealg::AbstractShadowingSensitivityAlgorithm,
Expand Down
24 changes: 24 additions & 0 deletions test/concrete_solve_derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,20 @@ Tests callable structs with different AD backends
end
end

# Mooncake is not in `REVERSE_BACKENDS` because it doesn't yet compose
# with every sensealg, but it does compose with `ReverseDiffAdjoint` and
# `TrackerAdjoint` via the dedicated `MooncakeOriginator` dispatches
# added in #1420 (the hybrid_diffeq.md pattern from #1419).
@testset "Mooncake with ReverseDiffAdjoint" begin
result = gradient_mooncake(senseloss(ReverseDiffAdjoint()), u0p)
@test result ≈ ref_grad_senseloss
end

@testset "Mooncake with TrackerAdjoint" begin
result = gradient_mooncake(senseloss(TrackerAdjoint()), u0p)
@test result ≈ ref_grad_senseloss
end

# Test with p-only differentiation (senseloss3 and senseloss4 from alternative_ad_frontend.jl)
struct senseloss_p{T}
sense::T
Expand Down Expand Up @@ -470,6 +484,16 @@ Tests callable structs with different AD backends
@test result ≈ ref_grad_p
end
end

# Mooncake + `ForwardSensitivity` via the dedicated `MooncakeOriginator`
# dispatch added in #1420. p-only because `ForwardSensitivity` can't
# differentiate `u0`, and the Mooncake dispatch rewrites the `du0`
# slot to `NoTangent()` so Mooncake's cotangent threading doesn't trip
# on the main method's `@not_implemented` stub for `du0`.
@testset "Mooncake with ForwardSensitivity (p-only)" begin
result = gradient_mooncake(senseloss_p(ForwardSensitivity()), p_only)
@test result ≈ ref_grad_p
end
end

#=
Expand Down
Loading