Skip to content

Enable Mooncake + ReverseDiffAdjoint nested-AD path#1420

Merged
ChrisRackauckas merged 4 commits intoSciML:masterfrom
ChrisRackauckas-Claude:fix-mooncake-reversediffadjoint
Apr 11, 2026
Merged

Enable Mooncake + ReverseDiffAdjoint nested-AD path#1420
ChrisRackauckas merged 4 commits intoSciML:masterfrom
ChrisRackauckas-Claude:fix-mooncake-reversediffadjoint

Conversation

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor

Summary

Adds a MooncakeOriginator dispatch for
_concrete_solve_adjoint(…, ::ReverseDiffAdjoint, …) so tutorials like
docs/src/examples/hybrid_jump/hybrid_diffeq.md (and PR #1419) no longer have
to fall back to Zygote when the inner sensealg is ReverseDiffAdjoint and
the outer AD is Mooncake.

The problem

Before this PR, the ReverseDiffAdjoint method threw MooncakeTrackedRealError
on MooncakeOriginator. It wasn't just a pessimistic guard — removing the
throw naively doesn't work, because the primal the method returns is
`sensitivity_solution(tracked_sol, plain_u, plain_t)`, and while its
`.u` / `.t` fields are stripped via `ReverseDiff.value`, the tracked type
parameters still leak through the container (`tracked_sol.interp`,
`tracked_sol.prob`, `tracked_sol.alg`, `tracked_sol.k`, …).

Mooncake's `@from_rrule` plumbing calls `zero_tangent(y_primal)` and therefore
recursively computes `tangent_type` for every nested field of the returned
solution — that recursion fails on `ReverseDiff.TrackedReal` /
`ReverseDiff.TrackedArray` with either a `TypeError` or an unhelpful
tangent-type error. ChainRules / Zygote don't inspect the primal's type
parameters, so they're unaffected.

This is the same failure mode PR #1419's
`docs/src/examples/hybrid_jump/hybrid_diffeq.md` call-out was describing:

This example still uses Zygote because the inner `ReverseDiffAdjoint`
sensealg returns ReverseDiff-tracked types that confuse Mooncake's
`CoDual` type expectations (`TypeError` during pullback compilation).

The fix

A new `_concrete_solve_adjoint` method dispatched on
`sensealg::ReverseDiffAdjoint` / `originator::MooncakeOriginator` that:

  1. Delegates the tape construction (and therefore the whole backward pass)
    to the existing `ChainRulesOriginator` path — no tape duplication.
  2. Replaces the returned primal with a fresh plain-arithmetic
    `SensitivityADPassThrough` solve of `remake(prob; u0, p)`.

Reusing the main method for the tape keeps the outward-facing return type
identical to what the non-sensitivity `solve` would return
(`InterpolationData` / `DEStats`, not the `LinearInterpolation` / `Nothing`
shape `build_solution(prob, alg, t, u)` would produce). That matters because
Mooncake's `DerivedRule` specialises on the type Julia inference picks for
`CRC.rrule(solve_up, …)`, and that inference does not narrow through the
`originator` kwarg — so the compiled rule expects the main method's return
shape even on the dedicated Mooncake dispatch. I tried the obvious
`build_solution(prob, alg, t, u)` shortcut first; it trips Mooncake's
`DerivedRule` type assertion with a `LinearInterpolation` / `InterpolationData`
mismatch. The two-solve approach avoids that entirely.

The tape's tracked forward and the primal's plain forward can differ by a
handful of ULPs (ReverseDiff operator overloading reorders some arithmetic),
which propagates into a ~`1e-5` (plain ODE) to ~`2e-4` (hybrid ODE with
callbacks) relative drift in the gradient compared to the Zygote path. That's
already below the inherent accuracy of `ReverseDiffAdjoint` vs. `ForwardDiff`,
so the test tolerances are `rtol = 1e-4` / `1e-3` against both `ForwardDiff`
and `Zygote`.

Also updates `MOONCAKE_TRACKED_REAL_ERROR_MESSAGE` so it no longer claims
`ReverseDiffAdjoint` is incompatible with Mooncake (only `TrackerAdjoint`
still is).

Files

  • `src/concrete_solve.jl` — new dispatched method + removed the
    `MooncakeTrackedRealError` throw from the `ReverseDiffAdjoint` main
    method; updated the error message text.
  • `test/mooncake_reversediff_adjoint.jl` — new regression test.
  • `test/runtests.jl` — wires the new test into the Core 1 group.

Test plan

  • `test/mooncake_reversediff_adjoint.jl` — `Mooncake(ReverseDiffAdjoint)`
    plain Lotka-Volterra ODE
  • `test/mooncake_reversediff_adjoint.jl` — `Mooncake(ReverseDiffAdjoint)`
    hybrid ODE with `PresetTimeCallback` (the `hybrid_diffeq` tutorial
    shape from PR [WIP] docs: prefer Mooncake over Zygote where it works end-to-end #1419)
  • Full CI — `ReverseDiffAdjoint Output Type` and
    `Mooncake VJP Prob Kwargs` should still pass; the new test runs in
    the Core 1 group.

Unblocks the `hybrid_diffeq.md` migration in #1419 — once this merges, that
tutorial can drop its `!!! note` and switch the adtype back to
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.

Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com

Adding a MooncakeOriginator dispatch for `_concrete_solve_adjoint(…,
::ReverseDiffAdjoint, …)` so the hybrid_diffeq tutorial (and PR SciML#1419)
no longer have to fall back to Zygote when the inner sensealg is
`ReverseDiffAdjoint` and the outer AD is Mooncake.

Before: the method threw `MooncakeTrackedRealError` on
`MooncakeOriginator` because the return value
`sensitivity_solution(tracked_sol, plain_u, plain_t)` still carries
`ReverseDiff.TrackedReal` / `TrackedArray` type parameters in its
nested fields (`interp`, `prob`, `alg`, …). Mooncake's `@from_rrule`
plumbing calls `zero_tangent(y_primal)` and therefore recursively
computes `tangent_type` for every nested field of the returned
solution; that recursion fails on the tracked type parameters with
either a `TypeError` or an unhelpful tangent-type error. ChainRules /
Zygote don't inspect the primal's type parameters, so they are
unaffected.

The new method delegates the tape construction (and hence the whole
backward pass) to the existing `ChainRulesOriginator` path and then
replaces the primal with a fresh plain-arithmetic solve of the same
problem via `SensitivityADPassThrough`. Reusing the main method for
the tape keeps the outward-facing return type identical to what the
non-sensitivity solve would return (`InterpolationData` / `DEStats`),
which matters because Mooncake's `DerivedRule` specialises on the
type Julia inference picks for `CRC.rrule(solve_up, …)` — and that
inference does not narrow through the `originator` kwarg, so the
compiled rule expects the *main* method's return shape even on the
dedicated Mooncake dispatch.

Adds `test/mooncake_reversediff_adjoint.jl` covering both a plain
Lotka-Volterra ODE and a hybrid ODE with `PresetTimeCallback`
(mirroring the `hybrid_diffeq.md` tutorial that was forced to stay on
Zygote in SciML#1419). Gradients are checked against `ForwardDiff` and
`Zygote` at `rtol = 1e-4` / `1e-3`; the looser tolerance on the
hybrid case reflects the ~ULP arithmetic reordering between the
tape's tracked forward and the primal's plain forward amplified by
the callback-driven time grid.

Also updates `MOONCAKE_TRACKED_REAL_ERROR_MESSAGE` so it no longer
claims `ReverseDiffAdjoint` is incompatible with Mooncake (only
`TrackerAdjoint` still is).

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Delete the standalone `test/mooncake_reversediff_adjoint.jl` file and
its `runtests.jl` entry; add a single `@testset "Mooncake with
ReverseDiffAdjoint"` inside the existing "Struct-Based Loss Functions"
block in `test/concrete_solve_derivatives.jl` that reuses the
`gradient_mooncake` helper already defined at the top of that file.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Apply the same `MooncakeOriginator` delegation pattern to two more
sensealgs that were throwing `MooncakeTrackedRealError` / hitting the
same structural return-type issue as `ReverseDiffAdjoint`:

- `TrackerAdjoint`: returned `sensitivity_solution(tracked_sol, …)`
  carries `Tracker.TrackedReal` / `TrackedArray` type parameters in
  nested container fields.
- `ForwardSensitivity`: returned `sensitivity_solution(augmented_sol, …)`
  carries `ODEForwardSensitivityFunction` / `ForwardDiff.JacobianConfig`
  / `Dual` caches in `sol.prob.f`'s type parameters.

Both are fixed by delegating the tape construction to the
`ChainRulesOriginator` path and re-solving the plain problem for the
primal, matching the `ReverseDiffAdjoint` approach.

The `ForwardSensitivity` method additionally wraps the ChainRules
pullback to replace its `du0 = @not_implemented(...)` slot with
`NoTangent()`. Mooncake's `@from_rrule` plumbing threads the cotangent
of every argument through `increment_and_get_rdata!` regardless of
whether the caller is differentiating `u0`, and Mooncake doesn't have
a method for `Vector{Float64}` fdata + `ChainRulesCore.NotImplemented`
tangent (only scalar `IEEEFloat` + `NotImplemented` is handled).
Any caller genuinely differentiating `u0` while using `ForwardSensitivity`
is already using the wrong sensealg.

Also adds `Δ isa Tangent` to the cotangent-shape check in the existing
`forward_sensitivity_backpass` so Mooncake's `CRC.Tangent{Any}(…)`
conversion takes the `Δ.u[i]` path rather than falling through to the
broken-since-forever `@view Δ[.., i]` branch (`..` from
`EllipsisNotation` is not imported in SciMLSensitivity).

Removes `MooncakeTrackedRealError` and its message string entirely,
since after this commit nothing throws it anymore.

Tests: three `@testset`s in
`test/concrete_solve_derivatives.jl`'s "Struct-Based Loss Functions"
block exercising `Mooncake(senseloss(ReverseDiffAdjoint()))`,
`Mooncake(senseloss(TrackerAdjoint()))`, and
`Mooncake(senseloss_p(ForwardSensitivity()))`, all reusing the existing
`gradient_mooncake` helper.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
…place

The previous commit attempted to replace the second `solve()` call in
each Mooncake dispatch with a recursive `SciMLBase.value`-based walker
(via `ConstructionBase.setproperties`) that would rebuild the returned
`ODESolution` with plain type parameters. The walker does successfully
strip `ReverseDiff.TrackedReal` / `Tracker.TrackedReal` / `Dual` types
from nested fields (`u`, `t`, `k`, `interp.timeseries`, `interp.ts`,
`interp.ks`, cache scratch arrays, …) and produces a solution on which
`Mooncake.tangent_type` succeeds.

It does not, however, satisfy Mooncake's `DerivedRule` type assertion:
`solve()` wraps `prob.f` in a
`FunctionWrappersWrappers.FunctionWrappersWrapper` during
`get_concrete_problem`, and the resulting `FunctionWrapper` has no
public positional constructor for `ConstructionBase.setproperties`, so
a generic walker cannot rebuild it. Mooncake's inference on
`CRC.rrule(solve_up, …)` nonetheless expects the `FunctionWrapper`-
wrapped `ODEFunction` (because that's what a plain-arithmetic
`solve()` normally returns), so anything short of actually invoking
`solve()` — including substituting `sol.prob` with
`remake(prob; u0, p)`, which keeps the raw `typeof(f)` instead of
wrapping — produces a `TypeError` mismatch on the
`ODEFunction{…, FunctionWrapper{…}, …}` slot.

This commit therefore keeps the re-solve approach, and adds a comment
next to each Mooncake dispatch explaining why the walker alternative
was tried and abandoned. A truly general `strip_values(sol)` helper in
SciMLBase would need the same `solve()` round-trip internally, so the
one-extra-solve cost is unavoidable in this rrule.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@ChrisRackauckas-Claude ChrisRackauckas-Claude force-pushed the fix-mooncake-reversediffadjoint branch from 36cfcc0 to 4c868c1 Compare April 11, 2026 20:51
@ChrisRackauckas ChrisRackauckas merged commit c0a4647 into SciML:master Apr 11, 2026
36 of 100 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants