
Mooncake ext: unwrap struct tangents back to primal type#989

Closed
ChrisRackauckas-Claude wants to merge 1 commit into JuliaDiff:main from ChrisRackauckas-Claude:mooncake-componentarrays-fix

Conversation

@ChrisRackauckas-Claude

Summary

Mooncake.value_and_gradient!! (and the pullback / pushforward analogues) returns the differential as a Mooncake.Tangent / Mooncake.MutableTangent whenever the primal is a struct-backed array such as ComponentArray or an MVector. Downstream callers — most notably OptimizationBase, which preallocates a ComponentVector buffer and passes it to gradient! — expect a value with the same layout as the primal and call copyto! / iterate on it, which raised

MethodError: no method matching iterate(::Mooncake.Tangent{@NamedTuple{data::Vector{Float64}, axes::Mooncake.NoTangent}})

and broke every Optimization.jl loop that used AutoMooncake with ComponentArrays parameters (including all SciMLSensitivity neural ODE training tutorials).

This PR converts the tangent back to the primal type at the boundary of the DI extension via Mooncake.tangent_to_primal!!. This is the same unwrap path SciMLSensitivity already adopted internally (SciML/SciMLSensitivity.jl@4205d49); the helpers live in utils.jl and are reused by the gradient, pullback, and pushforward code paths in both onearg.jl / twoarg.jl and their forward counterparts, so the conversion happens consistently in one place.
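The boundary conversion can be sketched as follows. This is a self-contained illustration with a stand-in tangent type and a stand-in conversion function, not the actual utils.jl helper; the real code dispatches on `Mooncake.Tangent` / `MutableTangent` and calls `Mooncake.tangent_to_primal!!`:

```julia
# Stand-in for Mooncake.Tangent: a struct-of-fields differential that
# downstream callers cannot iterate or copyto! like an array.
struct StructTangent{NT<:NamedTuple}
    fields::NT
end

# Stand-in for tangent_to_primal!!: rebuild a value with the primal's
# layout from the tangent's backing data.
to_primal(x::Vector{Float64}, dx::StructTangent) = copy(dx.fields.data)
to_primal(x, dx) = dx  # already primal-shaped: pass through unchanged

x  = [1.0, 2.0]
dx = StructTangent((data = [0.1, 0.2], axes = nothing))
to_primal(x, dx)  # primal-shaped Vector, safe to copyto!/iterate
```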

gradient! with mismatched buffer types

DI's gradient! accepts grad buffers whose type differs from the primal (e.g. an MVector buffer for an SVector primal — default_scenarios() exercises this). The in-place helper allocates the unwrap target with _copy_output(x) and then copyto!s into grad. When grad itself is immutable (SVector), no in-place update is possible, so gradient! forwards the freshly built primal-shaped value rather than the unchanged buffer. That matches what callers compare against and is the only sensible interpretation of gradient! for an immutable destination.

The mutable-vs-immutable check uses ismutabletype(typeof(parent(grad))) because ComponentVector is itself an immutable struct (ismutabletype returns false) but wraps a mutable Vector and supports copyto! — walking down to the array parent captures both ComponentVector and SVector correctly.
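A minimal sketch of that check, using the `_can_setindex` name mentioned later in this thread (illustrative, not the exact extension code):

```julia
# Walk down to the backing storage before asking about mutability:
# ComponentVector is an immutable wrapper struct, but parent(cv) is a
# mutable Vector, so copyto! into it works; parent(::SVector) is the
# SVector itself, which is immutable, so the check correctly says no.
_can_setindex(grad::AbstractArray) = ismutabletype(typeof(parent(grad)))

_can_setindex([1.0, 2.0, 3.0])  # a Vector is its own (mutable) parent
```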

tangent_to_primal!! deprecation

tangent_to_primal!! is a deprecated Mooncake API; the future replacement is tangent_to_friendly!!, but it currently returns the raw Tangent for ComponentArray (no friendly_tangent_cache override exists upstream) and so is not yet a viable substitute. The DI extension already uses tangent_to_primal!! in zero_tangent_or_primal, so there is no new dependency on a deprecated symbol — only a new call site. A comment in utils.jl notes the migration path.

Test results

Full Mooncake backend test suite on Julia 1.11:

|         | Baseline (DI 0.7.16) | This PR |
| ------- | -------------------- | ------- |
| Passed  | 35,841               | 37,077  |
| Failed  | 0                    | 0       |
| Errored | 184                  | 0       |

The 184 baseline errors were almost entirely Mooncake returning Tangent / MutableTangent for MVector / SVector / ComponentVector primals in default_scenarios() and static_scenarios(). All of them go away with this fix, and the new component_scenarios() block adds 468 ComponentArrays-specific assertions, all of which pass.

Test plan

  • JULIA_DI_TEST_GROUP=Mooncake julia run_backend.jl — 37,077 passed, 0 failed, 0 errored
  • Targeted ComponentArrays gradient/gradient!/pullback/pushforward via component_scenarios() — all pass
  • End-to-end SciMLSensitivity neural ODE training (Lux + ComponentArrays + Optimization.OptimizationFunction(loss, AutoMooncake(...))) — converges in 5 Adam steps from 169.5 → 12.9, previously raised the MethodError shown above
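For reference, the preallocated-buffer call shape that previously raised the MethodError looks roughly like this (a sketch assuming the standard DifferentiationInterface preparation API and ADTypes' `AutoMooncake`; not copied from the test suite):

```julia
using ADTypes: AutoMooncake
using ComponentArrays: ComponentVector
import DifferentiationInterface as DI
import Mooncake  # loads the DI Mooncake extension

backend = AutoMooncake(; config = nothing)
f(x) = sum(abs2, x)

x    = ComponentVector(a = [1.0, 2.0], b = 3.0)
grad = similar(x)  # preallocated ComponentVector buffer, as OptimizationBase does
prep = DI.prepare_gradient(f, backend, x)
DI.gradient!(f, grad, prep, backend, x)  # must fill grad in place,
                                         # not hand back a Mooncake.Tangent
```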

Excluded scenario

AutoMooncakeForward() (without friendly_tangents) is excluded from the new component_scenarios() block because its forward-mode pushforward path has a separate, pre-existing bug at the input (Dual construction) side: it raises ArgumentError: Tangent types do not match primal types when given a ComponentVector dx, because Mooncake forward mode expects the tangent to already be a Mooncake.Tangent rather than a primal-shaped value. That input-side conversion is independent of the output-side fix in this PR; the friendly-tangents forward backend is exercised and passes.

🤖 Generated with Claude Code

Mooncake's `value_and_gradient!!` (and the pullback / pushforward
analogues) returns the differential as a `Mooncake.Tangent` /
`MutableTangent` whenever the primal is a struct-backed array such as
`ComponentArray` or an `MVector`.  Downstream callers — most notably
`OptimizationBase`, which preallocates a `ComponentVector` buffer and
passes it to `gradient!` — expect a value with the same layout as the
primal and call `copyto!`/`iterate` on it, which raised
`MethodError: no method matching iterate(::Mooncake.Tangent)` and
broke every Optimization.jl loop that used `AutoMooncake` with
ComponentArrays parameters (including all SciMLSensitivity neural
ODE training tutorials).

Convert the tangent back to the primal type at the boundary of the
DI extension via `Mooncake.tangent_to_primal!!`.  This is the same
unwrap path SciMLSensitivity already adopted internally
(`SciML/SciMLSensitivity.jl@4205d49`); the helper lives in
`utils.jl` and is reused by the gradient, pullback, and pushforward
code paths in both `onearg.jl`/`twoarg.jl` and the forward
counterparts so the conversion happens consistently.

`gradient!` accepts `grad` buffers whose type differs from the
primal (e.g. an `MVector` buffer for an `SVector` primal), so the
in-place helper allocates the unwrap target with `_copy_output(x)`
and then `copyto!`s into `grad`.  When `grad` itself is immutable
(SVector), no in-place update is possible, so `gradient!` forwards
the freshly built primal-shaped value rather than the unchanged
buffer — this matches what callers compare against and is the only
sensible interpretation of `gradient!` for an immutable destination.

`tangent_to_primal!!` is a deprecated Mooncake API; the future
replacement is `tangent_to_friendly!!`, but it currently returns
the raw `Tangent` for `ComponentArray` (no `friendly_tangent_cache`
override exists) and so is not yet a viable substitute.  A comment
in `utils.jl` notes the migration path.

Add a regression test for the previously broken path: the new
`component_scenarios()` block in `test/Back/Mooncake/test.jl`
exercises gradient/gradient!/pullback/pushforward (out- and in-
place) on `ComponentVector` against the three Mooncake backends
that don't hit the unrelated forward-mode-without-friendly-tangents
input bug, and an explicit `gradient!` testset captures the exact
preallocated-buffer call shape that OptimizationBase uses.

Full Mooncake test suite: 37077 passing, 0 failures, 0 errors
(baseline DI 0.7.16: 35841 passing, 184 errors).

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@codecov

codecov bot commented Apr 10, 2026

Codecov Report

❌ Patch coverage is 94.28571% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 98.19%. Comparing base (a5ecbe0) to head (8e98346).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...e/ext/DifferentiationInterfaceMooncakeExt/utils.jl | 89.47% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #989      +/-   ##
==========================================
- Coverage   98.21%   98.19%   -0.03%     
==========================================
  Files         135      135              
  Lines        8000     8021      +21     
==========================================
+ Hits         7857     7876      +19     
- Misses        143      145       +2     
| Flag | Coverage Δ |
| --- | --- |
| DI | 98.94% <94.28%> (-0.04%) ⬇️ |
| DIT | 96.22% <ø> (ø) |




ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/SciMLSensitivity.jl that referenced this pull request Apr 11, 2026
WIP — depends on JuliaDiff/DifferentiationInterface.jl#989

With the recent SciMLSensitivity-side Mooncake fixes (1376/1397/1412)
and the upstream DI Mooncake fix in
JuliaDiff/DifferentiationInterface.jl#989 (which unwraps
`Mooncake.Tangent`/`MutableTangent` cotangents back to the primal type
so `OptimizationBase.gradient!` no longer dies on
`copyto!(::ComponentVector, ::Mooncake.Tangent)`), Mooncake now works
end-to-end for the majority of the SciMLSensitivity tutorials.

This PR migrates every doc/tutorial that I could verify runs cleanly
under `OPT.AutoMooncake(; config = nothing)` (or the direct Mooncake /
DifferentiationInterface API for non-Optimization examples).  Each
migrated example was executed locally on Julia 1.11 against
SciMLSensitivity master + the patched DI to confirm the gradient flows
and the optimizer makes progress.

## Migrated to Mooncake

- `getting_started.md` — `Zygote.gradient(loss, u0, p)` →
  `DI.gradient(closure, AutoMooncake, p)` (also exercises
  `GaussAdjoint`)
- `manual/differential_equation_sensitivities.md` — same DI rewrite,
  reordered the AD list to put Mooncake first
- `tutorials/parameter_estimation_ode.md` —
  `OPT.AutoZygote()` → `OPT.AutoMooncake(; config = nothing)`,
  PolyOpt converges to ≈2e-6 in 100 steps
- `tutorials/chaotic_ode.md` — `Zygote.gradient(p -> G(p), p)` →
  `DI.gradient(p -> G(p), AutoMooncake, p)` for `ForwardLSS`
- `tutorials/training_tips/divergence.md` — Lotka-Volterra retcode
  pattern, `AutoMooncake` swap
- `tutorials/training_tips/local_minima.md` — Lux + ComponentArrays
  neural ODE, two `AutoZygote → AutoMooncake` swaps
- `tutorials/training_tips/multiple_nn.md` — Lux + multi-NN +
  ComponentArrays + `InterpolatingAdjoint(ReverseDiffVJP)`,
  `AutoZygote → AutoMooncake`
- `examples/ode/exogenous_input.md` — Hammerstein system + Lux UDE
- `examples/hybrid_jump/bouncing_ball.md` — `OPT.AutoMooncake(...)`
  swap. Also replaced `sol[end][1]` with `last(sol.u)[1]` to dodge a
  pre-existing `BoundsError` in `SciMLBaseMooncakeExt._scatter_pullback`
  for `getindex(::ODESolution, end)`; the underlying
  `Vector{Vector{Float64}}` access takes the same value with no rrule
  bug
- `examples/optimal_control/optimal_control.md` — drops the now-unused
  `import Zygote` (the example uses `OPT.AutoForwardDiff()`)
- `examples/pde/pde_constrained.md` — 1D heat-equation parameter fit,
  `AutoZygote → AutoMooncake`, both `@example pde` and `@example pde2`
  blocks
- `examples/sde/optimization_sde.md` (Example 3 only) — SDE control
  with `ForwardDiffSensitivity()`, `AutoZygote → AutoMooncake`. Example
  1 keeps Zygote because it relies on `EnsembleProblem` (see below)
- `Benchmark.md` — `Zygote.gradient($loss_neuralode, $u0, $ps, $st)` →
  `DI.gradient($loss_ps, $backend, $ps)` with a closure over `u0`/`st`.
  This block is `julia` not `@example`, so it isn't executed by
  Documenter, but the rewrite still demonstrates the recommended user
  pattern
- `faq.md` — out-of-place RHS isolation snippet rewritten from
  `Zygote.pullback` to `Mooncake.prepare_pullback_cache` /
  `Mooncake.value_and_pullback!!`, verified locally on a Lotka-Volterra
  closure
- `index.md` — list reorder to put Mooncake (and Enzyme) above Zygote
  in the AD compatibility table
- `docs/Project.toml` — adds `Mooncake` and `DifferentiationInterface`
  with appropriate compat bounds

## Left on Zygote with an explanatory `!!! note`

The remaining tutorials hit one of three independent upstream blockers
in Mooncake itself, all of which are out of scope for a docs PR.  I
left them on `OPT.AutoZygote()` and added a callout pointing at the
specific failure mode so future contributors know what to monitor:

- **`EnsembleProblem` rule compilation fails** (`StackOverflowError`
  inside Mooncake's rule compiler when it tries to differentiate
  `__solve(::AbstractEnsembleProblem, …)`):
  - `tutorials/data_parallel.md`
  - `examples/sde/optimization_sde.md` (Example 1, the quasi-likelihood
    fit)
  - `examples/sde/SDE_control.md` (also has `Zygote.@nograd CreateGrid`
    which would translate to
    `Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(CreateGrid), Any, Any}`
    once the EnsembleProblem blocker is resolved — I left both lines in
    the tutorial commentary so it's a one-line fix later)
- **`MethodOfSteps` DDE adjoint fails** (`StackOverflowError` during
  rule compilation of the `DDEProblem` solve):
  - `examples/dde/delay_diffeq.md`
- **ComponentArrays cotangent / SciMLBase Mooncake-extension gaps**
  on the more exotic adjoint paths (missing
  `increment_and_get_rdata!` method, `ReverseDiffAdjoint`-tracked
  values that don't match Mooncake's `CoDual` type expectations,
  nested `ComponentVector` cotangents, or
  `SecondOrder(AutoMooncake, AutoMooncake)` fallback):
  - `examples/ode/second_order_adjoints.md` (NewtonTrustRegion needs a
    Hessian, the Adam-only first half does work with Mooncake but the
    point of the tutorial is the second-order optimization)
  - `examples/ode/second_order_neural.md` (`SecondOrderODEProblem` +
    Lux + CV)
  - `examples/optimal_control/feedback_control.md` (nested
    `ComponentArray(; u0, p_all)`)
  - `examples/hybrid_jump/hybrid_diffeq.md` (`ReverseDiffAdjoint`
    inner)
  - `examples/neural_ode/simplechains.md`
    (`QuadratureAdjoint(ZygoteVJP)` + `StaticArrays`)
  - `examples/pde/brusselator.md` (FBDF stiff PDE + Lux + CV via the
    auto-selected adjoint)

Each note records the exact error so it's clear which Mooncake/SciMLBase
upstream patch unblocks the migration.  When that lands, switching the
remaining files is mechanical.

## Verified locally

Every migrated `@example` block above was either run directly or
matched a pattern that I ran end-to-end (Lux+CA neural ODE training,
Optimization+CA loop, etc.) against the patched DI from
JuliaDiff/DifferentiationInterface.jl#989.  The tutorials that are
blocked are exactly the ones I could not get past compilation.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/SciMLSensitivity.jl that referenced this pull request Apr 11, 2026
After SciML/ComponentArrays.jl#350 (released as ComponentArrays
v0.15.34) registers a `friendly_tangent_cache` override for
`ComponentArray`, the `OPT.AutoMooncake(; config = Mooncake.Config(;
friendly_tangents = true))` form now uses the friendly-tangent
unwrap path inside Mooncake itself, which solves the same
`copyto!(::ComponentVector, ::Mooncake.Tangent)` crash that
JuliaDiff/DifferentiationInterface.jl#989 fixed at the DI layer for
the `config = nothing` default.

I re-tested the migration with **stock DI 0.7.16** plus
**ComponentArrays from main (0.15.34)** and confirmed the migrated
tutorials still pass end-to-end (LV+CA BFGS, multiple_nn Lux+CA Adam,
local_minima Lux+CA Adam, parameter_estimation_ode PolyOpt,
getting_started + GaussAdjoint, bouncing_ball with the
`last(sol.u)[1]` workaround, divergence, exogenous_input, etc.).

The reverted tutorials (`EnsembleProblem`, `MethodOfSteps` DDE,
`SecondOrderODEProblem`, nested CV, `ReverseDiffAdjoint` inner,
`SimpleChains`+`StaticArrays`, FBDF stiff PDE) are still blocked
on independent upstream issues that SciML/ComponentArrays.jl#350 does
not address — I reverified each one with friendly_tangents + CA main
and they still fail with the same errors recorded in the `!!! note`
callouts.

This commit:

1. Switches every migrated `OPT.AutoMooncake(; config = nothing)` /
   `SMS.AutoMooncake(...)` / `DI.AutoMooncake(...)` / `ADTypes.AutoMooncake(...)`
   to `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`
   (and the equivalent for the other prefixes).
2. Updates the recommended pattern shown in every `!!! note` callout
   on the still-Zygote tutorials to match.
3. Bumps the `ComponentArrays` compat in `docs/Project.toml` from
   `0.15` to `0.15.34` so the docs build picks up the friendly-tangent
   support.

With this change the SMS docs PR no longer hard-depends on
JuliaDiff/DifferentiationInterface.jl#989. That DI patch is still an
independently useful improvement (it makes the default
`config = nothing` form work without the user having to know about the
flag, and also fixes the `MVector`/`SVector` cases), but it is no
longer a blocker for this migration.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@ChrisRackauckas-Claude
Author

Heads up that this is no longer a hard dependency for the SciML/SciMLSensitivity.jl#1419 docs migration. After SciML/ComponentArrays.jl#350 was merged into ComponentArrays main (registry PR JuliaRegistries/General#152661, releasing as v0.15.34), the migration also works with stock DI 0.7.16 if the user opts in to Mooncake.Config(; friendly_tangents = true). I switched the SMS docs PR over to that pattern so it can ship as soon as CA 0.15.34 lands in the registry.

This patch is still independently useful though:

  1. It makes the default AutoMooncake(; config = nothing) form work without the user having to know about friendly_tangents. With CA 0.15.34 alone, the default still crashes — only friendly_tangents = true is fixed (because Mooncake's value_and_gradient!! only invokes friendly_tangent_cache when the flag is on).

  2. It also fixes the MVector / SVector cases (the 184 → 0 errors in the Mooncake test suite), which ComponentArrays#350 doesn't touch — those are independent of ComponentArrays.

  3. The _can_setindex(grad) fallback for immutable buffers (returning the freshly built primal-shaped value when gradient! is given an SVector destination) is novel and not duplicated by ComponentArrays#350.

So I'd argue it's still worth merging — it just becomes a quality-of-life improvement rather than a hard blocker. Happy to rebase / split / scope down if you'd prefer a smaller change.

@gdalle
Member

gdalle commented Apr 11, 2026

Is this the same fix @sunxd3 is working on in #988?

@gdalle
Member

gdalle commented Apr 11, 2026

@ChrisRackauckas given the discussion in SciML/SciMLSensitivity.jl#1419 and the existence of #988, is my review necessary on this one?

@gdalle gdalle marked this pull request as draft April 11, 2026 13:54
@gdalle gdalle self-assigned this Apr 11, 2026
