Enable Mooncake + ReverseDiffAdjoint nested-AD path#1420
Merged
ChrisRackauckas merged 4 commits intoSciML:masterfrom Apr 11, 2026
Merged
Conversation
Adding a MooncakeOriginator dispatch for `_concrete_solve_adjoint(…, ::ReverseDiffAdjoint, …)` so the hybrid_diffeq tutorial (and PR SciML#1419) no longer have to fall back to Zygote when the inner sensealg is `ReverseDiffAdjoint` and the outer AD is Mooncake. Before: the method threw `MooncakeTrackedRealError` on `MooncakeOriginator` because the return value `sensitivity_solution(tracked_sol, plain_u, plain_t)` still carries `ReverseDiff.TrackedReal` / `TrackedArray` type parameters in its nested fields (`interp`, `prob`, `alg`, …). Mooncake's `@from_rrule` plumbing calls `zero_tangent(y_primal)` and therefore recursively computes `tangent_type` for every nested field of the returned solution; that recursion fails on the tracked type parameters with either a `TypeError` or an unhelpful tangent-type error. ChainRules / Zygote don't inspect the primal's type parameters, so they are unaffected. The new method delegates the tape construction (and hence the whole backward pass) to the existing `ChainRulesOriginator` path and then replaces the primal with a fresh plain-arithmetic solve of the same problem via `SensitivityADPassThrough`. Reusing the main method for the tape keeps the outward-facing return type identical to what the non-sensitivity solve would return (`InterpolationData` / `DEStats`), which matters because Mooncake's `DerivedRule` specialises on the type Julia inference picks for `CRC.rrule(solve_up, …)` — and that inference does not narrow through the `originator` kwarg, so the compiled rule expects the *main* method's return shape even on the dedicated Mooncake dispatch. Adds `test/mooncake_reversediff_adjoint.jl` covering both a plain Lotka-Volterra ODE and a hybrid ODE with `PresetTimeCallback` (mirroring the `hybrid_diffeq.md` tutorial that was forced to stay on Zygote in SciML#1419). Gradients are checked against `ForwardDiff` and `Zygote` at `rtol = 1e-4` / `1e-3`; the looser tolerance on the hybrid case reflects the ~ULP arithmetic reordering between the tape's tracked forward and the primal's plain forward amplified by the callback-driven time grid. Also updates `MOONCAKE_TRACKED_REAL_ERROR_MESSAGE` so it no longer claims `ReverseDiffAdjoint` is incompatible with Mooncake (only `TrackerAdjoint` still is). Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Delete the standalone `test/mooncake_reversediff_adjoint.jl` file and its `runtests.jl` entry; add a single `@testset "Mooncake with ReverseDiffAdjoint"` inside the existing "Struct-Based Loss Functions" block in `test/concrete_solve_derivatives.jl` that reuses the `gradient_mooncake` helper already defined at the top of that file. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Apply the same `MooncakeOriginator` delegation pattern to two more
sensealgs that were throwing `MooncakeTrackedRealError` / hitting the
same structural return-type issue as `ReverseDiffAdjoint`:
- `TrackerAdjoint`: returned `sensitivity_solution(tracked_sol, …)`
carries `Tracker.TrackedReal` / `TrackedArray` type parameters in
nested container fields.
- `ForwardSensitivity`: returned `sensitivity_solution(augmented_sol, …)`
carries `ODEForwardSensitivityFunction` / `ForwardDiff.JacobianConfig`
/ `Dual` caches in `sol.prob.f`'s type parameters.
Both are fixed by delegating the tape construction to the
`ChainRulesOriginator` path and re-solving the plain problem for the
primal, matching the `ReverseDiffAdjoint` approach.
The `ForwardSensitivity` method additionally wraps the ChainRules
pullback to replace its `du0 = @not_implemented(...)` slot with
`NoTangent()`. Mooncake's `@from_rrule` plumbing threads the cotangent
of every argument through `increment_and_get_rdata!` regardless of
whether the caller is differentiating `u0`, and Mooncake doesn't have
a method for `Vector{Float64}` fdata + `ChainRulesCore.NotImplemented`
tangent (only scalar `IEEEFloat` + `NotImplemented` is handled).
Any caller genuinely differentiating `u0` while using `ForwardSensitivity`
is already using the wrong sensealg.
Also adds `Δ isa Tangent` to the cotangent-shape check in the existing
`forward_sensitivity_backpass` so Mooncake's `CRC.Tangent{Any}(…)`
conversion takes the `Δ.u[i]` path rather than falling through to the
broken-since-forever `@view Δ[.., i]` branch (`..` from
`EllipsisNotation` is not imported in SciMLSensitivity).
Removes `MooncakeTrackedRealError` and its message string entirely,
since after this commit nothing throws it anymore.
Tests: three `@testset`s in
`test/concrete_solve_derivatives.jl`'s "Struct-Based Loss Functions"
block exercising `Mooncake(senseloss(ReverseDiffAdjoint()))`,
`Mooncake(senseloss(TrackerAdjoint()))`, and
`Mooncake(senseloss_p(ForwardSensitivity()))`, all reusing the existing
`gradient_mooncake` helper.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
…place
The previous commit attempted to replace the second `solve()` call in
each Mooncake dispatch with a recursive `SciMLBase.value`-based walker
(via `ConstructionBase.setproperties`) that would rebuild the returned
`ODESolution` with plain type parameters. The walker does successfully
strip `ReverseDiff.TrackedReal` / `Tracker.TrackedReal` / `Dual` types
from nested fields (`u`, `t`, `k`, `interp.timeseries`, `interp.ts`,
`interp.ks`, cache scratch arrays, …) and produces a solution on which
`Mooncake.tangent_type` succeeds.
It does not, however, satisfy Mooncake's `DerivedRule` type assertion:
`solve()` wraps `prob.f` in a
`FunctionWrappersWrappers.FunctionWrappersWrapper` during
`get_concrete_problem`, and the resulting `FunctionWrapper` has no
public positional constructor for `ConstructionBase.setproperties`, so
a generic walker cannot rebuild it. Mooncake's inference on
`CRC.rrule(solve_up, …)` nonetheless expects the `FunctionWrapper`-
wrapped `ODEFunction` (because that's what a plain-arithmetic
`solve()` normally returns), so anything short of actually invoking
`solve()` — including substituting `sol.prob` with
`remake(prob; u0, p)`, which keeps the raw `typeof(f)` instead of
wrapping — produces a `TypeError` mismatch on the
`ODEFunction{…, FunctionWrapper{…}, …}` slot.
This commit therefore keeps the re-solve approach, and adds a comment
next to each Mooncake dispatch explaining why the walker alternative
was tried and abandoned. A truly general `strip_values(sol)` helper in
SciMLBase would need the same `solve()` round-trip internally, so the
one-extra-solve cost is unavoidable in this rrule.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
36cfcc0 to
4c868c1
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a
MooncakeOriginatordispatch for_concrete_solve_adjoint(…, ::ReverseDiffAdjoint, …)so tutorials likedocs/src/examples/hybrid_jump/hybrid_diffeq.md(and PR #1419) no longer haveto fall back to Zygote when the inner sensealg is
ReverseDiffAdjointandthe outer AD is Mooncake.
The problem
Before this PR, the
ReverseDiffAdjointmethod threwMooncakeTrackedRealErroron
MooncakeOriginator. It wasn't just a pessimistic guard — removing thethrow naively doesn't work, because the primal the method returns is
`sensitivity_solution(tracked_sol, plain_u, plain_t)`, and while its
`.u` / `.t` fields are stripped via `ReverseDiff.value`, the tracked type
parameters still leak through the container (`tracked_sol.interp`,
`tracked_sol.prob`, `tracked_sol.alg`, `tracked_sol.k`, …).
Mooncake's `@from_rrule` plumbing calls `zero_tangent(y_primal)` and therefore
recursively computes `tangent_type` for every nested field of the returned
solution — that recursion fails on `ReverseDiff.TrackedReal` /
`ReverseDiff.TrackedArray` with either a `TypeError` or an unhelpful
tangent-type error. ChainRules / Zygote don't inspect the primal's type
parameters, so they're unaffected.
This is the same failure mode PR #1419's
`docs/src/examples/hybrid_jump/hybrid_diffeq.md` call-out was describing:
The fix
A new `_concrete_solve_adjoint` method dispatched on
`sensealg::ReverseDiffAdjoint` / `originator::MooncakeOriginator` that:
to the existing `ChainRulesOriginator` path — no tape duplication.
`SensitivityADPassThrough` solve of `remake(prob; u0, p)`.
Reusing the main method for the tape keeps the outward-facing return type
identical to what the non-sensitivity `solve` would return
(`InterpolationData` / `DEStats`, not the `LinearInterpolation` / `Nothing`
shape `build_solution(prob, alg, t, u)` would produce). That matters because
Mooncake's `DerivedRule` specialises on the type Julia inference picks for
`CRC.rrule(solve_up, …)`, and that inference does not narrow through the
`originator` kwarg — so the compiled rule expects the main method's return
shape even on the dedicated Mooncake dispatch. I tried the obvious
`build_solution(prob, alg, t, u)` shortcut first; it trips Mooncake's
`DerivedRule` type assertion with a `LinearInterpolation` / `InterpolationData`
mismatch. The two-solve approach avoids that entirely.
The tape's tracked forward and the primal's plain forward can differ by a
handful of ULPs (ReverseDiff operator overloading reorders some arithmetic),
which propagates into a ~`1e-5` (plain ODE) to ~`2e-4` (hybrid ODE with
callbacks) relative drift in the gradient compared to the Zygote path. That's
already below the inherent accuracy of `ReverseDiffAdjoint` vs. `ForwardDiff`,
so the test tolerances are `rtol = 1e-4` / `1e-3` against both `ForwardDiff`
and `Zygote`.
Also updates `MOONCAKE_TRACKED_REAL_ERROR_MESSAGE` so it no longer claims
`ReverseDiffAdjoint` is incompatible with Mooncake (only `TrackerAdjoint`
still is).
Files
`MooncakeTrackedRealError` throw from the `ReverseDiffAdjoint` main
method; updated the error message text.
Test plan
plain Lotka-Volterra ODE
hybrid ODE with `PresetTimeCallback` (the `hybrid_diffeq` tutorial
shape from PR [WIP] docs: prefer Mooncake over Zygote where it works end-to-end #1419)
`Mooncake VJP Prob Kwargs` should still pass; the new test runs in
the Core 1 group.
Unblocks the `hybrid_diffeq.md` migration in #1419 — once this merges, that
tutorial can drop its `!!! note` and switch the adtype back to
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.
Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com