Mooncake ext: unwrap struct tangents back to primal type#989
Mooncake ext: unwrap struct tangents back to primal type#989ChrisRackauckas-Claude wants to merge 1 commit intoJuliaDiff:mainfrom
Conversation
Mooncake's `value_and_gradient!!` (and the pullback / pushforward analogues) returns the differential as a `Mooncake.Tangent` / `MutableTangent` whenever the primal is a struct-backed array such as `ComponentArray` or an `MVector`. Downstream callers — most notably `OptimizationBase`, which preallocates a `ComponentVector` buffer and passes it to `gradient!` — expect a value with the same layout as the primal and call `copyto!`/`iterate` on it, which raised `MethodError: no method matching iterate(::Mooncake.Tangent)` and broke every Optimization.jl loop that used `AutoMooncake` with ComponentArrays parameters (including all SciMLSensitivity neural ODE training tutorials). Convert the tangent back to the primal type at the boundary of the DI extension via `Mooncake.tangent_to_primal!!`. This is the same unwrap path SciMLSensitivity already adopted internally (`SciML/SciMLSensitivity.jl@4205d49`); the helper lives in `utils.jl` and is reused by the gradient, pullback, and pushforward code paths in both `onearg.jl`/`twoarg.jl` and the forward counterparts so the conversion happens consistently. `gradient!` accepts `grad` buffers whose type differs from the primal (e.g. an `MVector` buffer for an `SVector` primal), so the in-place helper allocates the unwrap target with `_copy_output(x)` and then `copyto!`s into `grad`. When `grad` itself is immutable (SVector), no in-place update is possible, so `gradient!` forwards the freshly built primal-shaped value rather than the unchanged buffer — this matches what callers compare against and is the only sensible interpretation of `gradient!` for an immutable destination. `tangent_to_primal!!` is a deprecated Mooncake API; the future replacement is `tangent_to_friendly!!`, but it currently returns the raw `Tangent` for `ComponentArray` (no `friendly_tangent_cache` override exists) and so is not yet a viable substitute. A comment in `utils.jl` notes the migration path. Add a regression test for the previously broken path: the new `component_scenarios()` block in `test/Back/Mooncake/test.jl` exercises gradient/gradient!/pullback/pushforward (out- and in- place) on `ComponentVector` against the three Mooncake backends that don't hit the unrelated forward-mode-without-friendly-tangents input bug, and an explicit `gradient!` testset captures the exact preallocated-buffer call shape that OptimizationBase uses. Full Mooncake test suite: 37077 passing, 0 failures, 0 errors (baseline DI 0.7.16: 35841 passing, 184 errors). Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #989 +/- ##
==========================================
- Coverage 98.21% 98.19% -0.03%
==========================================
Files 135 135
Lines 8000 8021 +21
==========================================
+ Hits 7857 7876 +19
- Misses 143 145 +2
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
WIP — depends on JuliaDiff/DifferentiationInterface.jl#989 With the recent SciMLSensitivity-side Mooncake fixes (1376/1397/1412) and the upstream DI Mooncake fix in JuliaDiff/DifferentiationInterface.jl#989 (which unwraps `Mooncake.Tangent`/`MutableTangent` cotangents back to the primal type so `OptimizationBase.gradient!` no longer dies on `copyto!(::ComponentVector, ::Mooncake.Tangent)`), Mooncake now works end-to-end for the majority of the SciMLSensitivity tutorials. This PR migrates every doc/tutorial that I could verify runs cleanly under `OPT.AutoMooncake(; config = nothing)` (or the direct Mooncake / DifferentiationInterface API for non-Optimization examples). Each migrated example was executed locally on Julia 1.11 against SciMLSensitivity master + the patched DI to confirm the gradient flows and the optimizer makes progress. ## Migrated to Mooncake - `getting_started.md` — `Zygote.gradient(loss, u0, p)` → `DI.gradient(closure, AutoMooncake, p)` (also exercises `GaussAdjoint`) - `manual/differential_equation_sensitivities.md` — same DI rewrite, reordered the AD list to put Mooncake first - `tutorials/parameter_estimation_ode.md` — `OPT.AutoZygote()` → `OPT.AutoMooncake(; config = nothing)`, PolyOpt converges to ≈2e-6 in 100 steps - `tutorials/chaotic_ode.md` — `Zygote.gradient(p -> G(p), p)` → `DI.gradient(p -> G(p), AutoMooncake, p)` for `ForwardLSS` - `tutorials/training_tips/divergence.md` — Lotka-Volterra retcode pattern, `AutoMooncake` swap - `tutorials/training_tips/local_minima.md` — Lux + ComponentArrays neural ODE, two `AutoZygote → AutoMooncake` swaps - `tutorials/training_tips/multiple_nn.md` — Lux + multi-NN + ComponentArrays + `InterpolatingAdjoint(ReverseDiffVJP)`, `AutoZygote → AutoMooncake` - `examples/ode/exogenous_input.md` — Hammerstein system + Lux UDE - `examples/hybrid_jump/bouncing_ball.md` — `OPT.AutoMooncake(...)` swap. Also replaced `sol[end][1]` with `last(sol.u)[1]` to dodge a pre-existing `BoundsError` in `SciMLBaseMooncakeExt._scatter_pullback` for `getindex(::ODESolution, end)`; the underlying `Vector{Vector{Float64}}` access takes the same value with no rrule bug - `examples/optimal_control/optimal_control.md` — drops the now-unused `import Zygote` (the example uses `OPT.AutoForwardDiff()`) - `examples/pde/pde_constrained.md` — 1D heat-equation parameter fit, `AutoZygote → AutoMooncake`, both `@example pde` and `@example pde2` blocks - `examples/sde/optimization_sde.md` (Example 3 only) — SDE control with `ForwardDiffSensitivity()`, `AutoZygote → AutoMooncake`. Example 1 keeps Zygote because it relies on `EnsembleProblem` (see below) - `Benchmark.md` — `Zygote.gradient($loss_neuralode, $u0, $ps, $st)` → `DI.gradient($loss_ps, $backend, $ps)` with a closure over `u0`/`st`. This block is `julia` not `@example`, so it isn't executed by Documenter, but the rewrite still demonstrates the recommended user pattern - `faq.md` — out-of-place RHS isolation snippet rewritten from `Zygote.pullback` to `Mooncake.prepare_pullback_cache` / `Mooncake.value_and_pullback!!`, verified locally on a Lotka-Volterra closure - `index.md` — list reorder to put Mooncake (and Enzyme) above Zygote in the AD compatibility table - `docs/Project.toml` — adds `Mooncake` and `DifferentiationInterface` with appropriate compat bounds ## Left on Zygote with an explanatory `!!! note` The remaining tutorials hit one of three independent upstream blockers in Mooncake itself, all of which are out of scope for a docs PR. I left them on `OPT.AutoZygote()` and added a callout pointing at the specific failure mode so future contributors know what to monitor: - **`EnsembleProblem` rule compilation fails** (`StackOverflowError` inside Mooncake's rule compiler when it tries to differentiate `__solve(::AbstractEnsembleProblem, …)`): - `tutorials/data_parallel.md` - `examples/sde/optimization_sde.md` (Example 1, the quasi-likelihood fit) - `examples/sde/SDE_control.md` (also has `Zygote.@Nograd CreateGrid` which would translate to `Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(CreateGrid), Any, Any}` once the EnsembleProblem blocker is resolved — I left both lines in the tutorial commentary so it's a one-line fix later) - **`MethodOfSteps` DDE adjoint fails** (`StackOverflowError` during rule compilation of the `DDEProblem` solve): - `examples/dde/delay_diffeq.md` - **ComponentArrays cotangent / SciMLBase Mooncake-extension gaps** on the more exotic adjoint paths (missing `increment_and_get_rdata!` method, `ReverseDiffAdjoint`-tracked values that don't match Mooncake's `CoDual` type expectations, nested `ComponentVector` cotangents, or `SecondOrder(AutoMooncake, AutoMooncake)` fallback): - `examples/ode/second_order_adjoints.md` (NewtonTrustRegion needs a Hessian, the Adam-only first half does work with Mooncake but the point of the tutorial is the second-order optimization) - `examples/ode/second_order_neural.md` (`SecondOrderODEProblem` + Lux + CV) - `examples/optimal_control/feedback_control.md` (nested `ComponentArray(; u0, p_all)`) - `examples/hybrid_jump/hybrid_diffeq.md` (`ReverseDiffAdjoint` inner) - `examples/neural_ode/simplechains.md` (`QuadratureAdjoint(ZygoteVJP)` + `StaticArrays`) - `examples/pde/brusselator.md` (FBDF stiff PDE + Lux + CV via the auto-selected adjoint) Each note records the exact error so it's clear which Mooncake/SciMLBase upstream patch unblocks the migration. When that lands, switching the remaining files is mechanical. ## Verified locally Every migrated `@example` block above was either run directly or matched a pattern that I ran end-to-end (Lux+CA neural ODE training, Optimization+CA loop, etc.) against the patched DI from JuliaDiff/DifferentiationInterface.jl#989. The tutorials that are blocked are exactly the ones I could not get past compilation. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
After SciML/ComponentArrays.jl#350 (released as ComponentArrays v0.15.34) registers a `friendly_tangent_cache` override for `ComponentArray`, the `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))` form now uses the friendly-tangent unwrap path inside Mooncake itself, which solves the same `copyto!(::ComponentVector, ::Mooncake.Tangent)` crash that JuliaDiff/DifferentiationInterface.jl#989 fixed at the DI layer for the `config = nothing` default. I re-tested the migration with **stock DI 0.7.16** plus **ComponentArrays from main (0.15.34)** and confirmed the migrated tutorials still pass end-to-end (LV+CA BFGS, multiple_nn Lux+CA Adam, local_minima Lux+CA Adam, parameter_estimation_ode PolyOpt, getting_started + GaussAdjoint, bouncing_ball with the `last(sol.u)[1]` workaround, divergence, exogenous_input, etc.). The reverted tutorials (\`EnsembleProblem\`, \`MethodOfSteps\` DDE, \`SecondOrderODEProblem\`, nested CV, \`ReverseDiffAdjoint\` inner, \`SimpleChains\`+\`StaticArrays\`, FBDF stiff PDE) are still blocked on independent upstream issues that CA SciML#350 does not address — I reverified each one with friendly_tangents+CA-main and they still fail with the same errors recorded in the !!! note callouts. This commit: 1. Switches every migrated `OPT.AutoMooncake(; config = nothing)` / `SMS.AutoMooncake(...)` / `DI.AutoMooncake(...)` / `ADTypes.AutoMooncake(...)` to `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))` (and the equivalent for the other prefixes). 2. Updates the recommended pattern shown in every `!!! note` callout on the still-Zygote tutorials to match. 3. Bumps the `ComponentArrays` compat in `docs/Project.toml` from `0.15` to `0.15.34` so the docs build picks up the friendly-tangent support. With this change the SMS docs PR no longer hard-depends on JuliaDiff/DifferentiationInterface.jl#989. That DI patch is still an independently useful improvement (it makes the default `config = nothing` form work without the user having to know about the flag, and also fixes the `MVector`/`SVector` cases), but it is no longer a blocker for this migration. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
|
Heads up that this is no longer a hard dependency for the SciML/SciMLSensitivity.jl#1419 docs migration. After SciML/ComponentArrays.jl#350 was merged into ComponentArrays main (registry PR JuliaRegistries/General#152661, releasing as v0.15.34), the migration also works with stock DI 0.7.16 if the user opts in to This patch is still independently useful though:
So I'd argue it's still worth merging — it just becomes a quality-of-life improvement rather than a hard blocker. Happy to rebase / split / scope down if you'd prefer a smaller change. |
|
@ChrisRackauckas given the discussion in SciML/SciMLSensitivity.jl#1419 and the existence of #988, is my review necessary on this one? |
Summary
Mooncake.value_and_gradient!!(and the pullback / pushforward analogues) returns the differential as aMooncake.Tangent/Mooncake.MutableTangentwhenever the primal is a struct-backed array such asComponentArrayor anMVector. Downstream callers — most notablyOptimizationBase, which preallocates aComponentVectorbuffer and passes it togradient!— expect a value with the same layout as the primal and callcopyto!/iterateon it, which raisedand broke every Optimization.jl loop that used
AutoMooncakewith ComponentArrays parameters (including all SciMLSensitivity neural ODE training tutorials).This PR converts the tangent back to the primal type at the boundary of the DI extension via
Mooncake.tangent_to_primal!!. This is the same unwrap path SciMLSensitivity already adopted internally (SciML/SciMLSensitivity.jl@4205d49); the helpers live inutils.jland are reused by the gradient, pullback, and pushforward code paths in bothonearg.jl/twoarg.jland the forward counterparts so the conversion happens consistently at one place.gradient!with mismatched buffer typesDI's
gradient!acceptsgradbuffers whose type differs from the primal (e.g. anMVectorbuffer for anSVectorprimal —default_scenarios()exercises this). The in-place helper allocates the unwrap target with_copy_output(x)and thencopyto!s intograd. Whengraditself is immutable (SVector), no in-place update is possible, sogradient!forwards the freshly built primal-shaped value rather than the unchanged buffer. That matches what callers compare against and is the only sensible interpretation ofgradient!for an immutable destination.The mutable-vs-immutable check uses
ismutabletype(typeof(parent(grad)))becauseComponentVectoris itself an immutable struct (ismutabletypereturnsfalse) but wraps a mutableVectorand supportscopyto!— walking down to the array parent captures bothComponentVectorandSVectorcorrectly.tangent_to_primal!!deprecationtangent_to_primal!!is a deprecated Mooncake API; the future replacement istangent_to_friendly!!, but it currently returns the rawTangentforComponentArray(nofriendly_tangent_cacheoverride exists upstream) and so is not yet a viable substitute. The DI extension already usestangent_to_primal!!inzero_tangent_or_primal, so there is no new dependency on a deprecated symbol — only a new call site. A comment inutils.jlnotes the migration path.Test results
Full Mooncake backend test suite on Julia 1.11:
The 184 baseline errors were almost entirely Mooncake returning
Tangent/MutableTangentforMVector/SVector/ComponentVectorprimals indefault_scenarios()andstatic_scenarios(). They all go away with this fix, plus the newcomponent_scenarios()block adds 468 ComponentArrays-specific assertions that are also passing.Test plan
JULIA_DI_TEST_GROUP=Mooncake julia run_backend.jl— 37,077 passed, 0 failed, 0 erroredcomponent_scenarios()— all passOptimization.OptimizationFunction(loss, AutoMooncake(...))) — converges in 5 Adam steps from 169.5 → 12.9, previously raised theMethodErrorshown aboveExcluded scenario
AutoMooncakeForward()(withoutfriendly_tangents) is excluded from the newcomponent_scenarios()block because its forward-mode pushforward path has a separate, pre-existing bug at the input (Dual construction) side: it raisesArgumentError: Tangent types do not match primal typeswhen given aComponentVectordx, because Mooncake forward mode expects the tangent to already be aMooncake.Tangentrather than a primal-shaped value. That input-side conversion is independent of the output-side fix in this PR; the friendly-tangents forward backend is exercised and passes.🤖 Generated with Claude Code