From 7e260c3c7dcdde6248a375ead2ea59627f3fb0f0 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sat, 11 Apr 2026 07:54:08 -0400 Subject: [PATCH 1/4] Enable Mooncake + ReverseDiffAdjoint nested-AD path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adding a MooncakeOriginator dispatch for `_concrete_solve_adjoint(…, ::ReverseDiffAdjoint, …)` so the hybrid_diffeq tutorial (and PR #1419) no longer have to fall back to Zygote when the inner sensealg is `ReverseDiffAdjoint` and the outer AD is Mooncake. Before: the method threw `MooncakeTrackedRealError` on `MooncakeOriginator` because the return value `sensitivity_solution(tracked_sol, plain_u, plain_t)` still carries `ReverseDiff.TrackedReal` / `TrackedArray` type parameters in its nested fields (`interp`, `prob`, `alg`, …). Mooncake's `@from_rrule` plumbing calls `zero_tangent(y_primal)` and therefore recursively computes `tangent_type` for every nested field of the returned solution; that recursion fails on the tracked type parameters with either a `TypeError` or an unhelpful tangent-type error. ChainRules / Zygote don't inspect the primal's type parameters, so they are unaffected. The new method delegates the tape construction (and hence the whole backward pass) to the existing `ChainRulesOriginator` path and then replaces the primal with a fresh plain-arithmetic solve of the same problem via `SensitivityADPassThrough`. Reusing the main method for the tape keeps the outward-facing return type identical to what the non-sensitivity solve would return (`InterpolationData` / `DEStats`), which matters because Mooncake's `DerivedRule` specialises on the type Julia inference picks for `CRC.rrule(solve_up, …)` — and that inference does not narrow through the `originator` kwarg, so the compiled rule expects the *main* method's return shape even on the dedicated Mooncake dispatch. 
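The delegation described above is ultimately user-facing: it makes Mooncake usable as the outer AD over a solve whose inner sensealg is `ReverseDiffAdjoint`. A minimal sketch of that newly enabled call pattern (assuming package versions that carry this patch; it mirrors the test added below, not new API):

```julia
# Sketch: Mooncake as the outer AD, ReverseDiffAdjoint as the inner
# sensealg. Before this patch the gradient call threw
# MooncakeTrackedRealError; with it, the MooncakeOriginator dispatch
# delegates the tape to the ChainRules path and re-solves for a
# plain-typed primal, so the gradient goes through.
using OrdinaryDiffEq, SciMLSensitivity, Mooncake, DifferentiationInterface

f(u, p, t) = [p[1] * u[1] - p[2] * u[1] * u[2],
              -p[3] * u[2] + p[4] * u[1] * u[2]]
prob = ODEProblem(f, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

loss(p) = sum(last(solve(remake(prob; p), Tsit5();
    sensealg = ReverseDiffAdjoint()).u))

backend = AutoMooncake(; config = nothing)
grad = DifferentiationInterface.gradient(loss, backend, prob.p)
```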
Adds `test/mooncake_reversediff_adjoint.jl` covering both a plain Lotka-Volterra ODE and a hybrid ODE with `PresetTimeCallback` (mirroring the `hybrid_diffeq.md` tutorial that was forced to stay on Zygote in #1419). Gradients are checked against `ForwardDiff` and `Zygote` at `rtol = 1e-4` / `1e-3`; the looser tolerance on the hybrid case reflects the ~ULP arithmetic reordering between the tape's tracked forward and the primal's plain forward amplified by the callback-driven time grid. Also updates `MOONCAKE_TRACKED_REAL_ERROR_MESSAGE` so it no longer claims `ReverseDiffAdjoint` is incompatible with Mooncake (only `TrackerAdjoint` still is). Co-Authored-By: Chris Rackauckas --- src/concrete_solve.jl | 67 +++++++++++++-- test/mooncake_reversediff_adjoint.jl | 122 +++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 183 insertions(+), 7 deletions(-) create mode 100644 test/mooncake_reversediff_adjoint.jl diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 0af5dfe3d..a6642d5bb 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1847,9 +1847,9 @@ function Base.showerror(io::IO, e::EnzymeTrackedRealError) end const MOONCAKE_TRACKED_REAL_ERROR_MESSAGE = """ -`Mooncake` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`. -Either choose a different adjoint method like `GaussAdjoint`, -or use a different AD system like `ReverseDiff`. +`Mooncake` is not compatible with `TrackerAdjoint`. +Either choose a different adjoint method like `GaussAdjoint` or +`ReverseDiffAdjoint`, or use a different AD system like `ReverseDiff`. For more details, on these methods see https://docs.sciml.ai/SciMLSensitivity/stable/. 
""" @@ -2148,10 +2148,6 @@ function SciMLBase._concrete_solve_adjoint( throw(EnzymeTrackedRealError()) end - if originator isa SciMLBase.MooncakeOriginator - throw(MooncakeTrackedRealError()) - end - t = eltype(prob.tspan)[] u = typeof(u0)[] @@ -2277,6 +2273,63 @@ function SciMLBase._concrete_solve_adjoint( reversediff_adjoint_backpass end +# Mooncake-specific `ReverseDiffAdjoint` path. The main `ReverseDiffAdjoint` +# method above returns `SciMLBase.sensitivity_solution(sol, …)` where `sol` +# still carries `ReverseDiff.TrackedReal` / `TrackedArray` type parameters in +# nested fields (`interp`, `prob`, `alg`, …). ChainRules / Zygote don't +# inspect the primal's type parameters, so they don't care. Mooncake's +# `@from_rrule` plumbing, on the other hand, calls +# `zero_tangent(y_primal)` and therefore recursively computes `tangent_type` +# for every nested field of the returned solution; that recursion fails on +# the tracked type parameters with either a `TypeError` or an unhelpful +# tangent-type error, which is what the `hybrid_diffeq` tutorial and +# PR #1419 ran into. +# +# This method delegates the tape construction (and hence the whole backward +# pass) to the `ChainRulesOriginator` path, then replaces the primal with +# a fresh plain-arithmetic solve of the same problem. This keeps the +# outward-facing return type identical to what the non-sensitivity solve +# would return (i.e. `InterpolationData` and `DEStats`, not the +# `LinearInterpolation` / `Nothing` shape `build_solution` would produce), +# which is important because Mooncake's `DerivedRule` specialises on the +# inferred return type of the underlying `solve_up` call — that inference +# does not narrow through the `originator` kwarg, so the compiled rule +# expects the *main* method's return shape even on the Mooncake dispatch. 
+# +# The tape's tracked forward pass and this plain forward pass can differ +# by a handful of ULPs (ReverseDiff operator overloading reorders some +# arithmetic), which propagates a ~1e-5 relative error into the pullback. +# That's below the inherent accuracy of `ReverseDiffAdjoint` vs. a +# first-principles gradient, so tests should compare against +# `ForwardDiff.gradient` at `rtol = 1e-4` rather than bitwise matching +# the Zygote path. +function SciMLBase._concrete_solve_adjoint( + prob::Union{ + SciMLBase.AbstractDiscreteProblem, + SciMLBase.AbstractODEProblem, + SciMLBase.AbstractDAEProblem, + SciMLBase.AbstractDDEProblem, + SciMLBase.AbstractSDEProblem, + SciMLBase.AbstractSDDEProblem, + SciMLBase.AbstractRODEProblem, + }, + alg, sensealg::ReverseDiffAdjoint, + u0, p, originator::SciMLBase.MooncakeOriginator, + args...; kwargs... + ) + _, backpass = SciMLBase._concrete_solve_adjoint( + prob, alg, sensealg, u0, p, + SciMLBase.ChainRulesOriginator(), args...; kwargs... + ) + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) + primal = solve( + remake(prob; u0, p), alg, args...; + sensealg = DiffEqBase.SensitivityADPassThrough(), + kwargs_filtered... + ) + return primal, backpass +end + function SciMLBase._concrete_solve_adjoint( prob::SciMLBase.AbstractODEProblem, alg, sensealg::AbstractShadowingSensitivityAlgorithm, diff --git a/test/mooncake_reversediff_adjoint.jl b/test/mooncake_reversediff_adjoint.jl new file mode 100644 index 000000000..44281358f --- /dev/null +++ b/test/mooncake_reversediff_adjoint.jl @@ -0,0 +1,122 @@ +# Regression test for Mooncake as the outer AD around a solve that uses +# `ReverseDiffAdjoint` as the inner sensealg. 
Before the fix in +# `_concrete_solve_adjoint(..., ::ReverseDiffAdjoint, ..., ::MooncakeOriginator)`, +# this path threw `MooncakeTrackedRealError` because the returned +# `sensitivity_solution(sol, …)` still carried `ReverseDiff.TrackedReal` / +# `TrackedArray` in nested type parameters (interp, prob, alg, …), which broke +# Mooncake's recursive `tangent_type` computation in `@from_rrule`. The fix +# returns a freshly-solved, plain-typed primal solution on the Mooncake path +# while keeping the tape-based backward pass. + +using OrdinaryDiffEq +using SciMLSensitivity +using DiffEqCallbacks +using Mooncake +using DifferentiationInterface +using ForwardDiff +using Zygote +using Test + +const backend = AutoMooncake(; config = nothing) + +# --------------------------------------------------------------------------- +# 1. Plain ODE: Mooncake(ReverseDiffAdjoint) vs. Zygote(ReverseDiffAdjoint) +# --------------------------------------------------------------------------- + +function lotka_volterra(u, p, t) + du1 = p[1] * u[1] - p[2] * u[1] * u[2] + du2 = -p[3] * u[2] + p[4] * u[1] * u[2] + return [du1, du2] +end + +const lv_u0 = [1.0, 1.0] +const lv_tspan = (0.0, 10.0) +const lv_p0 = [1.5, 1.0, 3.0, 1.0] +const lv_prob = ODEProblem(lotka_volterra, lv_u0, lv_tspan, lv_p0) + +function lv_loss(p) + sol = solve( + remake(lv_prob; p), Tsit5(); + reltol = 1.0e-10, abstol = 1.0e-10, + sensealg = ReverseDiffAdjoint() + ) + return sum(last(sol.u)) +end + +function lv_loss_plain(p) + sol = solve( + remake(lv_prob; p), Tsit5(); + reltol = 1.0e-10, abstol = 1.0e-10 + ) + return sum(last(sol.u)) +end + +@testset "Mooncake(ReverseDiffAdjoint) plain ODE" begin + prep = prepare_gradient(lv_loss, backend, lv_p0) + grad_moon = DifferentiationInterface.gradient(lv_loss, prep, backend, lv_p0) + grad_zyg = Zygote.gradient(lv_loss, lv_p0)[1] + grad_fd = ForwardDiff.gradient(lv_loss_plain, lv_p0) + # The Mooncake path re-runs the forward solve on plain inputs to obtain a + # cleanly-typed 
primal; the tape's tracked-arithmetic forward and the + # plain forward can differ by a few ULPs, which causes a correspondingly + # small (~1.0e-5 relative) difference in the cotangent fed into the + # ReverseDiff tape compared to the Zygote path. Check against ForwardDiff + # at the accuracy `ReverseDiffAdjoint` already has vs. a first-principles + # gradient. + @test grad_moon ≈ grad_fd rtol = 1.0e-4 + @test grad_moon ≈ grad_zyg rtol = 1.0e-4 +end + +# --------------------------------------------------------------------------- +# 2. Hybrid ODE with PresetTimeCallback — this is the hybrid_diffeq tutorial +# shape that PR #1419 was forced to keep on Zygote. +# --------------------------------------------------------------------------- + +function decay!(du, u, p, t) + du[1] = -p[1] * u[1] + du[2] = -p[2] * u[2] + return nothing +end + +const hyb_u0 = [2.0, 0.0] +const hyb_tspan = (0.0, 10.5) +const hyb_dosetimes = [1.0, 2.0, 4.0, 8.0] +const hyb_p0 = [1.0, 1.0] + +function hyb_affect!(integrator) + integrator.u .= integrator.u .+ 1 + return nothing +end + +const hyb_cb = PresetTimeCallback( + hyb_dosetimes, hyb_affect!; save_positions = (false, false) +) + +const hyb_prob = ODEProblem(decay!, hyb_u0, hyb_tspan, hyb_p0) + +function hyb_loss(p) + sol = solve( + hyb_prob, Tsit5(); p, callback = hyb_cb, + saveat = 0.5, sensealg = ReverseDiffAdjoint() + ) + return sum(abs2, last(sol.u)) +end + +function hyb_loss_plain(p) + sol = solve( + hyb_prob, Tsit5(); p, callback = hyb_cb, saveat = 0.5 + ) + return sum(abs2, last(sol.u)) +end + +@testset "Mooncake(ReverseDiffAdjoint) hybrid ODE with PresetTimeCallback" begin + prep = prepare_gradient(hyb_loss, backend, hyb_p0) + grad_moon = DifferentiationInterface.gradient(hyb_loss, prep, backend, hyb_p0) + grad_zyg = Zygote.gradient(hyb_loss, hyb_p0)[1] + grad_fd = ForwardDiff.gradient(hyb_loss_plain, hyb_p0) + # Looser tolerance than the plain-ODE case: the PresetTimeCallback + # amplifies the ~ULP difference between the tape's 
tracked forward and + # the primal's plain forward into ~2e-4 relative gradient drift. + @test grad_moon ≈ grad_fd rtol = 1.0e-3 + @test grad_moon ≈ grad_zyg rtol = 1.0e-3 +end diff --git a/test/runtests.jl b/test/runtests.jl index 4df30d35f..3b9e4ee7f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,7 @@ end @time @safetestset "Forward Remake" include("forward_remake.jl") @time @safetestset "Prob Kwargs" include("prob_kwargs.jl") @time @safetestset "Mooncake VJP Prob Kwargs" include("mooncake_vjp_prob_kwargs.jl") + @time @safetestset "Mooncake + ReverseDiffAdjoint" include("mooncake_reversediff_adjoint.jl") @time @safetestset "DiscreteProblem Adjoints" include("discrete.jl") @time @safetestset "Time Type Mixing Adjoints" include("time_type_mixing.jl") @time @safetestset "SciMLStructures Interface" include("scimlstructures_interface.jl") From eea0ced555b2f3d468cb0786b62b15cbb1a5ced6 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sat, 11 Apr 2026 09:22:06 -0400 Subject: [PATCH 2/4] Fold Mooncake+ReverseDiffAdjoint test into concrete_solve_derivatives Delete the standalone `test/mooncake_reversediff_adjoint.jl` file and its `runtests.jl` entry; add a single `@testset "Mooncake with ReverseDiffAdjoint"` inside the existing "Struct-Based Loss Functions" block in `test/concrete_solve_derivatives.jl` that reuses the `gradient_mooncake` helper already defined at the top of that file. 
Co-Authored-By: Chris Rackauckas --- src/concrete_solve.jl | 8 -- test/concrete_solve_derivatives.jl | 9 ++ test/mooncake_reversediff_adjoint.jl | 122 --------------------------- test/runtests.jl | 1 - 4 files changed, 9 insertions(+), 131 deletions(-) delete mode 100644 test/mooncake_reversediff_adjoint.jl diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index a6642d5bb..a7b66fbe6 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -2295,14 +2295,6 @@ end # inferred return type of the underlying `solve_up` call — that inference # does not narrow through the `originator` kwarg, so the compiled rule # expects the *main* method's return shape even on the Mooncake dispatch. -# -# The tape's tracked forward pass and this plain forward pass can differ -# by a handful of ULPs (ReverseDiff operator overloading reorders some -# arithmetic), which propagates a ~1e-5 relative error into the pullback. -# That's below the inherent accuracy of `ReverseDiffAdjoint` vs. a -# first-principles gradient, so tests should compare against -# `ForwardDiff.gradient` at `rtol = 1e-4` rather than bitwise matching -# the Zygote path. function SciMLBase._concrete_solve_adjoint( prob::Union{ SciMLBase.AbstractDiscreteProblem, diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl index aff403475..dbd95501a 100644 --- a/test/concrete_solve_derivatives.jl +++ b/test/concrete_solve_derivatives.jl @@ -442,6 +442,15 @@ Tests callable structs with different AD backends end end + # Mooncake is not in `REVERSE_BACKENDS` because it doesn't yet compose + # with every sensealg, but it does compose with `ReverseDiffAdjoint` via + # the dedicated `MooncakeOriginator` dispatch added in #1420 (the + # hybrid_diffeq.md pattern from #1419). 
+ @testset "Mooncake with ReverseDiffAdjoint" begin + result = gradient_mooncake(senseloss(ReverseDiffAdjoint()), u0p) + @test result ≈ ref_grad_senseloss + end + # Test with p-only differentiation (senseloss3 and senseloss4 from alternative_ad_frontend.jl) struct senseloss_p{T} sense::T diff --git a/test/mooncake_reversediff_adjoint.jl b/test/mooncake_reversediff_adjoint.jl deleted file mode 100644 index 44281358f..000000000 --- a/test/mooncake_reversediff_adjoint.jl +++ /dev/null @@ -1,122 +0,0 @@ -# Regression test for Mooncake as the outer AD around a solve that uses -# `ReverseDiffAdjoint` as the inner sensealg. Before the fix in -# `_concrete_solve_adjoint(..., ::ReverseDiffAdjoint, ..., ::MooncakeOriginator)`, -# this path threw `MooncakeTrackedRealError` because the returned -# `sensitivity_solution(sol, …)` still carried `ReverseDiff.TrackedReal` / -# `TrackedArray` in nested type parameters (interp, prob, alg, …), which broke -# Mooncake's recursive `tangent_type` computation in `@from_rrule`. The fix -# returns a freshly-solved, plain-typed primal solution on the Mooncake path -# while keeping the tape-based backward pass. - -using OrdinaryDiffEq -using SciMLSensitivity -using DiffEqCallbacks -using Mooncake -using DifferentiationInterface -using ForwardDiff -using Zygote -using Test - -const backend = AutoMooncake(; config = nothing) - -# --------------------------------------------------------------------------- -# 1. Plain ODE: Mooncake(ReverseDiffAdjoint) vs. 
Zygote(ReverseDiffAdjoint) -# --------------------------------------------------------------------------- - -function lotka_volterra(u, p, t) - du1 = p[1] * u[1] - p[2] * u[1] * u[2] - du2 = -p[3] * u[2] + p[4] * u[1] * u[2] - return [du1, du2] -end - -const lv_u0 = [1.0, 1.0] -const lv_tspan = (0.0, 10.0) -const lv_p0 = [1.5, 1.0, 3.0, 1.0] -const lv_prob = ODEProblem(lotka_volterra, lv_u0, lv_tspan, lv_p0) - -function lv_loss(p) - sol = solve( - remake(lv_prob; p), Tsit5(); - reltol = 1.0e-10, abstol = 1.0e-10, - sensealg = ReverseDiffAdjoint() - ) - return sum(last(sol.u)) -end - -function lv_loss_plain(p) - sol = solve( - remake(lv_prob; p), Tsit5(); - reltol = 1.0e-10, abstol = 1.0e-10 - ) - return sum(last(sol.u)) -end - -@testset "Mooncake(ReverseDiffAdjoint) plain ODE" begin - prep = prepare_gradient(lv_loss, backend, lv_p0) - grad_moon = DifferentiationInterface.gradient(lv_loss, prep, backend, lv_p0) - grad_zyg = Zygote.gradient(lv_loss, lv_p0)[1] - grad_fd = ForwardDiff.gradient(lv_loss_plain, lv_p0) - # The Mooncake path re-runs the forward solve on plain inputs to obtain a - # cleanly-typed primal; the tape's tracked-arithmetic forward and the - # plain forward can differ by a few ULPs, which causes a correspondingly - # small (~1.0e-5 relative) difference in the cotangent fed into the - # ReverseDiff tape compared to the Zygote path. Check against ForwardDiff - # at the accuracy `ReverseDiffAdjoint` already has vs. a first-principles - # gradient. - @test grad_moon ≈ grad_fd rtol = 1.0e-4 - @test grad_moon ≈ grad_zyg rtol = 1.0e-4 -end - -# --------------------------------------------------------------------------- -# 2. Hybrid ODE with PresetTimeCallback — this is the hybrid_diffeq tutorial -# shape that PR #1419 was forced to keep on Zygote. 
-# --------------------------------------------------------------------------- - -function decay!(du, u, p, t) - du[1] = -p[1] * u[1] - du[2] = -p[2] * u[2] - return nothing -end - -const hyb_u0 = [2.0, 0.0] -const hyb_tspan = (0.0, 10.5) -const hyb_dosetimes = [1.0, 2.0, 4.0, 8.0] -const hyb_p0 = [1.0, 1.0] - -function hyb_affect!(integrator) - integrator.u .= integrator.u .+ 1 - return nothing -end - -const hyb_cb = PresetTimeCallback( - hyb_dosetimes, hyb_affect!; save_positions = (false, false) -) - -const hyb_prob = ODEProblem(decay!, hyb_u0, hyb_tspan, hyb_p0) - -function hyb_loss(p) - sol = solve( - hyb_prob, Tsit5(); p, callback = hyb_cb, - saveat = 0.5, sensealg = ReverseDiffAdjoint() - ) - return sum(abs2, last(sol.u)) -end - -function hyb_loss_plain(p) - sol = solve( - hyb_prob, Tsit5(); p, callback = hyb_cb, saveat = 0.5 - ) - return sum(abs2, last(sol.u)) -end - -@testset "Mooncake(ReverseDiffAdjoint) hybrid ODE with PresetTimeCallback" begin - prep = prepare_gradient(hyb_loss, backend, hyb_p0) - grad_moon = DifferentiationInterface.gradient(hyb_loss, prep, backend, hyb_p0) - grad_zyg = Zygote.gradient(hyb_loss, hyb_p0)[1] - grad_fd = ForwardDiff.gradient(hyb_loss_plain, hyb_p0) - # Looser tolerance than the plain-ODE case: the PresetTimeCallback - # amplifies the ~ULP difference between the tape's tracked forward and - # the primal's plain forward into ~2e-4 relative gradient drift. 
- @test grad_moon ≈ grad_fd rtol = 1.0e-3 - @test grad_moon ≈ grad_zyg rtol = 1.0e-3 -end diff --git a/test/runtests.jl b/test/runtests.jl index 3b9e4ee7f..4df30d35f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,7 +27,6 @@ end @time @safetestset "Forward Remake" include("forward_remake.jl") @time @safetestset "Prob Kwargs" include("prob_kwargs.jl") @time @safetestset "Mooncake VJP Prob Kwargs" include("mooncake_vjp_prob_kwargs.jl") - @time @safetestset "Mooncake + ReverseDiffAdjoint" include("mooncake_reversediff_adjoint.jl") @time @safetestset "DiscreteProblem Adjoints" include("discrete.jl") @time @safetestset "Time Type Mixing Adjoints" include("time_type_mixing.jl") @time @safetestset "SciMLStructures Interface" include("scimlstructures_interface.jl") From 3469ee8060c357926d84cd9b707c40f9acf8e7dd Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sat, 11 Apr 2026 10:55:57 -0400 Subject: [PATCH 3/4] Extend Mooncake dispatch to TrackerAdjoint and ForwardSensitivity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply the same `MooncakeOriginator` delegation pattern to two more sensealgs that were throwing `MooncakeTrackedRealError` / hitting the same structural return-type issue as `ReverseDiffAdjoint`: - `TrackerAdjoint`: returned `sensitivity_solution(tracked_sol, …)` carries `Tracker.TrackedReal` / `TrackedArray` type parameters in nested container fields. - `ForwardSensitivity`: returned `sensitivity_solution(augmented_sol, …)` carries `ODEForwardSensitivityFunction` / `ForwardDiff.JacobianConfig` / `Dual` caches in `sol.prob.f`'s type parameters. Both are fixed by delegating the tape construction to the `ChainRulesOriginator` path and re-solving the plain problem for the primal, matching the `ReverseDiffAdjoint` approach. The `ForwardSensitivity` method additionally wraps the ChainRules pullback to replace its `du0 = @not_implemented(...)` slot with `NoTangent()`. 
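The pullback wrapping mentioned above reduces to a small tuple rewrite; a standalone sketch (simplified — the real method wraps the full `_concrete_solve_adjoint` pullback):

```julia
# The ChainRules pullback for solve_up returns a tuple of the shape
#   (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), rest...)
# where position 4 (du0) is a ChainRulesCore.NotImplemented, which
# Mooncake's increment_and_get_rdata! cannot pair with Vector{Float64}
# fdata. Overwriting it with NoTangent() — a structural zero valid for
# any primal shape — sidesteps the missing method.
using ChainRulesCore

wrap_du0(backpass) = Δ -> begin
    cr = backpass(Δ)
    (cr[1], cr[2], cr[3], NoTangent(), cr[5:end]...)
end
```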
Mooncake's `@from_rrule` plumbing threads the cotangent of every argument through `increment_and_get_rdata!` regardless of whether the caller is differentiating `u0`, and Mooncake doesn't have a method for `Vector{Float64}` fdata + `ChainRulesCore.NotImplemented` tangent (only scalar `IEEEFloat` + `NotImplemented` is handled). Any caller genuinely differentiating `u0` while using `ForwardSensitivity` is already using the wrong sensealg. Also adds `Δ isa Tangent` to the cotangent-shape check in the existing `forward_sensitivity_backpass` so Mooncake's `CRC.Tangent{Any}(…)` conversion takes the `Δ.u[i]` path rather than falling through to the broken-since-forever `@view Δ[.., i]` branch (`..` from `EllipsisNotation` is not imported in SciMLSensitivity). Removes `MooncakeTrackedRealError` and its message string entirely, since after this commit nothing throws it anymore. Tests: three `@testset`s in `test/concrete_solve_derivatives.jl`'s "Struct-Based Loss Functions" block exercising `Mooncake(senseloss(ReverseDiffAdjoint()))`, `Mooncake(senseloss(TrackerAdjoint()))`, and `Mooncake(senseloss_p(ForwardSensitivity()))`, all reusing the existing `gradient_mooncake` helper. 
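The cotangent-shape change described above (adding `Δ isa Tangent` to the branch that reads `Δ.u[i]`) can be sketched in isolation; `cotangent_at` is an illustrative name, not the repo's function:

```julia
# Sketch of the cotangent-shape dispatch in forward_sensitivity_backpass:
# Mooncake converts its tangent into a ChainRulesCore.Tangent{Any} with a
# `.u` field, so without the explicit Tangent branch it would fall
# through to the matrix branch and fail.
using ChainRulesCore

function cotangent_at(Δ, i)
    if Δ isa AbstractVector{<:Number}
        Δ[i]
    elseif Δ isa Tangent          # Mooncake's CRC.Tangent{Any} conversion
        Δ.u[i]
    elseif Δ isa AbstractMatrix
        @view Δ[:, i]
    else                          # solution-like containers with a .u field
        Δ.u[i]
    end
end
```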
Co-Authored-By: Chris Rackauckas --- src/concrete_solve.jl | 107 +++++++++++++++++++++++------ test/concrete_solve_derivatives.jl | 21 +++++- 2 files changed, 105 insertions(+), 23 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index a7b66fbe6..425fccc5d 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1121,7 +1121,8 @@ function SciMLBase._concrete_solve_adjoint( J = du[i] if Δ isa AbstractVector v = Δ[i] - elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray + elseif Δ isa AbstractTimeseriesSolution || Δ isa AbstractVectorOfArray || + Δ isa Tangent v = Δ.u[i] elseif Δ isa AbstractMatrix v = @view Δ[:, i] @@ -1149,6 +1150,56 @@ function SciMLBase._concrete_solve_adjoint( return out, forward_sensitivity_backpass end +# Mooncake-specific `ForwardSensitivity` path. The main method builds an +# `ODEForwardSensitivityProblem` whose `f` is an +# `ODEForwardSensitivityFunction` carrying ForwardDiff internals +# (`ForwardDiff.JacobianConfig`, `Dual` caches, …) in its type parameters. +# The returned `sensitivity_solution(augmented_sol, u, ts)` inherits those +# types in `sol.prob.f`, which confuses Mooncake's `@from_rrule` tangent +# recursion the same way the tracked types in `ReverseDiffAdjoint` / +# `TrackerAdjoint` do. Delegate to the `ChainRulesOriginator` path for the +# sensitivity tape and re-solve the plain problem for the primal. +# +# Additionally, the main `forward_sensitivity_backpass` returns `du0 = +# @not_implemented(...)` because `ForwardSensitivity` can't differentiate +# w.r.t. `u0`. Mooncake's `@from_rrule` plumbing then tries to convert that +# `ChainRulesCore.NotImplemented` tangent back through +# `increment_and_get_rdata!` against the `Vector{Float64}` fdata of `u0`, +# and Mooncake doesn't have a method for that combination (only scalar +# `IEEEFloat` + `NotImplemented` is handled). 
Since Mooncake will dutifully +# thread the cotangent of *every* argument through `increment_and_get_rdata!` +# regardless of whether the caller is actually differentiating `u0`, we +# replace the `du0` slot in the delegated ChainRules pullback with +# `NoTangent()` so the Mooncake conversion has a shape it understands. Any +# caller that genuinely differentiates `u0` while using `ForwardSensitivity` +# is already using the wrong sensealg (the main method's error message says +# as much). +function SciMLBase._concrete_solve_adjoint( + prob::SciMLBase.AbstractODEProblem, alg, + sensealg::ForwardSensitivity, + u0, p, originator::SciMLBase.MooncakeOriginator, + args...; kwargs... + ) + _, backpass = SciMLBase._concrete_solve_adjoint( + prob, alg, sensealg, u0, p, + SciMLBase.ChainRulesOriginator(), args...; kwargs... + ) + # ChainRules branch of `forward_sensitivity_backpass` returns + # `(NoTangent(), NoTangent(), NoTangent(), du0, adj, NoTangent(), rest...)`. + # Replace position 4 (`du0`) with `NoTangent()`. + function mooncake_forward_sensitivity_backpass(Δ) + cr = backpass(Δ) + return (cr[1], cr[2], cr[3], NoTangent(), cr[5:end]...) + end + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) + primal = solve( + remake(prob; u0, p), alg, args...; + sensealg = DiffEqBase.SensitivityADPassThrough(), + kwargs_filtered... + ) + return primal, mooncake_forward_sensitivity_backpass +end + function SciMLBase._concrete_solve_forward( prob::SciMLBase.AbstractODEProblem, alg, sensealg::AbstractForwardSensitivityAlgorithm, @@ -1846,21 +1897,6 @@ function Base.showerror(io::IO, e::EnzymeTrackedRealError) return println(io, ENZYME_TRACKED_REAL_ERROR_MESSAGE) end -const MOONCAKE_TRACKED_REAL_ERROR_MESSAGE = """ -`Mooncake` is not compatible with `TrackerAdjoint`. -Either choose a different adjoint method like `GaussAdjoint` or -`ReverseDiffAdjoint`, or use a different AD system like `ReverseDiff`. 
-For more details, on these methods see -https://docs.sciml.ai/SciMLSensitivity/stable/. -""" - -struct MooncakeTrackedRealError <: Exception -end - -function Base.showerror(io::IO, e::MooncakeTrackedRealError) - return println(io, MOONCAKE_TRACKED_REAL_ERROR_MESSAGE) -end - function SciMLBase._concrete_solve_adjoint( prob::Union{ SciMLBase.AbstractDiscreteProblem, @@ -1881,10 +1917,6 @@ function SciMLBase._concrete_solve_adjoint( throw(EnzymeTrackedRealError()) end - if originator isa SciMLBase.MooncakeOriginator - throw(MooncakeTrackedRealError()) - end - if !(p === nothing || p isa SciMLBase.NullParameters) if !isscimlstructure(p) throw(SciMLStructuresCompatibilityError()) @@ -2093,6 +2125,41 @@ function SciMLBase._concrete_solve_adjoint( tracker_adjoint_backpass end +# Mooncake-specific `TrackerAdjoint` path. Same reasoning as the +# `ReverseDiffAdjoint` + `MooncakeOriginator` method below: the main method +# returns `sensitivity_solution(tracked_sol, …)` with `Tracker.TrackedReal` / +# `TrackedArray` type parameters embedded in `tracked_sol.interp` / `.prob` +# / `.alg`, and Mooncake's `@from_rrule` plumbing chokes when recursively +# computing `tangent_type` on those fields. Delegate the tape to the +# `ChainRulesOriginator` path and re-solve with `SensitivityADPassThrough` +# for the primal. +function SciMLBase._concrete_solve_adjoint( + prob::Union{ + SciMLBase.AbstractDiscreteProblem, + SciMLBase.AbstractODEProblem, + SciMLBase.AbstractDAEProblem, + SciMLBase.AbstractDDEProblem, + SciMLBase.AbstractSDEProblem, + SciMLBase.AbstractSDDEProblem, + SciMLBase.AbstractRODEProblem, + }, + alg, sensealg::TrackerAdjoint, + u0, p, originator::SciMLBase.MooncakeOriginator, + args...; kwargs... + ) + _, backpass = SciMLBase._concrete_solve_adjoint( + prob, alg, sensealg, u0, p, + SciMLBase.ChainRulesOriginator(), args...; kwargs... 
+ ) + kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs)) + primal = solve( + remake(prob; u0, p), alg, args...; + sensealg = DiffEqBase.SensitivityADPassThrough(), + kwargs_filtered... + ) + return primal, backpass +end + const REVERSEDIFF_ADJOINT_GPU_COMPATIBILITY_MESSAGE = """ ReverseDiffAdjoint is not compatible GPU-based array types. Use a different sensitivity analysis method, like InterpolatingAdjoint or TrackerAdjoint, diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl index dbd95501a..442ea2a17 100644 --- a/test/concrete_solve_derivatives.jl +++ b/test/concrete_solve_derivatives.jl @@ -443,14 +443,19 @@ Tests callable structs with different AD backends end # Mooncake is not in `REVERSE_BACKENDS` because it doesn't yet compose - # with every sensealg, but it does compose with `ReverseDiffAdjoint` via - # the dedicated `MooncakeOriginator` dispatch added in #1420 (the - # hybrid_diffeq.md pattern from #1419). + # with every sensealg, but it does compose with `ReverseDiffAdjoint` and + # `TrackerAdjoint` via the dedicated `MooncakeOriginator` dispatches + # added in #1420 (the hybrid_diffeq.md pattern from #1419). @testset "Mooncake with ReverseDiffAdjoint" begin result = gradient_mooncake(senseloss(ReverseDiffAdjoint()), u0p) @test result ≈ ref_grad_senseloss end + @testset "Mooncake with TrackerAdjoint" begin + result = gradient_mooncake(senseloss(TrackerAdjoint()), u0p) + @test result ≈ ref_grad_senseloss + end + # Test with p-only differentiation (senseloss3 and senseloss4 from alternative_ad_frontend.jl) struct senseloss_p{T} sense::T @@ -479,6 +484,16 @@ Tests callable structs with different AD backends @test result ≈ ref_grad_p end end + + # Mooncake + `ForwardSensitivity` via the dedicated `MooncakeOriginator` + # dispatch added in #1420. 
p-only because `ForwardSensitivity` can't + # differentiate `u0`, and the Mooncake dispatch rewrites the `du0` + # slot to `NoTangent()` so Mooncake's cotangent threading doesn't trip + # on the main method's `@not_implemented` stub for `du0`. + @testset "Mooncake with ForwardSensitivity (p-only)" begin + result = gradient_mooncake(senseloss_p(ForwardSensitivity()), p_only) + @test result ≈ ref_grad_p + end end #= From 4c868c17b9746c7df0e5575301729dc6b944aaa1 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sat, 11 Apr 2026 15:42:10 -0400 Subject: [PATCH 4/4] Document why the Mooncake dispatch re-solves instead of stripping in place MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous commit attempted to replace the second `solve()` call in each Mooncake dispatch with a recursive `SciMLBase.value`-based walker (via `ConstructionBase.setproperties`) that would rebuild the returned `ODESolution` with plain type parameters. The walker does successfully strip `ReverseDiff.TrackedReal` / `Tracker.TrackedReal` / `Dual` types from nested fields (`u`, `t`, `k`, `interp.timeseries`, `interp.ts`, `interp.ks`, cache scratch arrays, …) and produces a solution on which `Mooncake.tangent_type` succeeds. It does not, however, satisfy Mooncake's `DerivedRule` type assertion: `solve()` wraps `prob.f` in a `FunctionWrappersWrappers.FunctionWrappersWrapper` during `get_concrete_problem`, and the resulting `FunctionWrapper` has no public positional constructor for `ConstructionBase.setproperties`, so a generic walker cannot rebuild it. 
Mooncake's inference on `CRC.rrule(solve_up, …)` nonetheless expects the `FunctionWrapper`- wrapped `ODEFunction` (because that's what a plain-arithmetic `solve()` normally returns), so anything short of actually invoking `solve()` — including substituting `sol.prob` with `remake(prob; u0, p)`, which keeps the raw `typeof(f)` instead of wrapping — produces a `TypeError` mismatch on the `ODEFunction{…, FunctionWrapper{…}, …}` slot. This commit therefore keeps the re-solve approach, and adds a comment next to each Mooncake dispatch explaining why the walker alternative was tried and abandoned. A truly general `strip_values(sol)` helper in SciMLBase would need the same `solve()` round-trip internally, so the one-extra-solve cost is unavoidable in this rrule. Co-Authored-By: Chris Rackauckas --- src/concrete_solve.jl | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 425fccc5d..4dfb212c3 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1160,6 +1160,19 @@ end # `TrackerAdjoint` do. Delegate to the `ChainRulesOriginator` path for the # sensitivity tape and re-solve the plain problem for the primal. # +# A tempting alternative is to walk the returned `primal` and strip +# tracked / augmented types via `SciMLBase.value` recursively. That +# approach *almost* works but fails on one specific slot: `solve()` wraps +# `prob.f` in a `FunctionWrappersWrappers.FunctionWrappersWrapper` during +# `get_concrete_problem`, and the resulting `FunctionWrapper` has no +# public positional constructor for `ConstructionBase.setproperties`, so a +# generic walker can't rebuild it. Mooncake's `@from_rrule` type +# inference nonetheless expects the `FunctionWrapper` version (because +# that's what `solve()` returns), so anything less than actually invoking +# `solve()` produces a type mismatch on the `DerivedRule` assertion. 
+# A truly general `strip_values(sol)` in SciMLBase would need the same +# `solve()` round-trip internally, so the cost is unavoidable here. +# # Additionally, the main `forward_sensitivity_backpass` returns `du0 = # @not_implemented(...)` because `ForwardSensitivity` can't differentiate # w.r.t. `u0`. Mooncake's `@from_rrule` plumbing then tries to convert that @@ -2354,14 +2367,25 @@ end # # This method delegates the tape construction (and hence the whole backward # pass) to the `ChainRulesOriginator` path, then replaces the primal with -# a fresh plain-arithmetic solve of the same problem. This keeps the -# outward-facing return type identical to what the non-sensitivity solve -# would return (i.e. `InterpolationData` and `DEStats`, not the -# `LinearInterpolation` / `Nothing` shape `build_solution` would produce), -# which is important because Mooncake's `DerivedRule` specialises on the -# inferred return type of the underlying `solve_up` call — that inference -# does not narrow through the `originator` kwarg, so the compiled rule -# expects the *main* method's return shape even on the Mooncake dispatch. +# a fresh plain-arithmetic solve of the same problem. The obvious +# alternative — walking the returned primal and stripping tracked scalars +# via `SciMLBase.value` recursively — *almost* works but fails on one +# specific slot: `solve()` wraps `prob.f` in a +# `FunctionWrappersWrappers.FunctionWrappersWrapper` during +# `get_concrete_problem`, and the resulting `FunctionWrapper` has no +# public positional constructor for `ConstructionBase.setproperties`, so a +# generic walker can't rebuild it. Mooncake's `@from_rrule` type +# inference nonetheless expects the `FunctionWrapper` version (because +# that's what `solve()` normally returns from a plain-arithmetic solve), +# so anything short of actually invoking `solve()` produces a type +# mismatch on the `DerivedRule` assertion. 
A truly general +# `strip_values(sol)` in SciMLBase would need the same `solve()` +# round-trip internally, so the cost is unavoidable here. +# +# Keeping this in a dedicated method dispatched on `MooncakeOriginator` +# stops Julia type inference from joining two different return shapes +# into a `Union{ODESolution{tracked…}, ODESolution{plain…}}`, which would +# otherwise trip Mooncake's `DerivedRule` type assertion. function SciMLBase._concrete_solve_adjoint( prob::Union{ SciMLBase.AbstractDiscreteProblem,