Mooncake ext: handle ComponentVector cotangents in increment_and_get_rdata! #351
Closed
ChrisRackauckas-Claude wants to merge 1 commit into SciML:main from
Conversation
…rdata!
The existing `increment_and_get_rdata!` method only matched
`t::Array{<:IEEEFloat}`. When a chain rule (or
`@from_chainrules`/`@from_rrule`-generated rule) emits a cotangent for
a `ComponentVector` *as a `ComponentVector`* (not as the underlying
backing `Vector`), the dispatch falls through to the generic
`@from_rrule` path that errors:
```
ArgumentError: The fdata type Mooncake.FData{@NamedTuple{data::Vector{Float32},
axes::Mooncake.NoFData}}, rdata type Mooncake.NoRData, and tangent type
ComponentArrays.ComponentVector{...} combination is not supported with
@from_chainrules or @from_rrule. This is because Mooncake.jl does not currently
have a method of `increment_and_get_rdata!` to handle this type combination.
```
This blocks every Lux-based UDE / neural-ODE training loop that uses
SciMLSensitivity's `GaussAdjoint(autojacvec=ZygoteVJP)` (or similar
ChainRules-based adjoint) under `AutoMooncake`, because the SciMLSensitivity
adjoint backpass returns the parameter gradient as a `ComponentVector`
but the upstream FData is the raw underlying `Vector`.
Add the missing dispatch: when `t` is a `ComponentVector` whose
underlying data type matches the FData's payload, strip the wrapper
with `getdata(t)` and forward to the existing accumulator.
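A minimal sketch of the added method, assuming the `increment_and_get_rdata!(fdata, rdata, tangent)` shape implied by the error message above; treat the signature as illustrative rather than the exact diff:

```julia
# Sketch only: assumes Mooncake's internal increment_and_get_rdata! argument
# order (fdata, rdata, tangent) as implied by the error message above, and
# ComponentArrays.getdata for stripping the wrapper. Not the exact PR diff.
using ComponentArrays: ComponentVector, getdata
import Mooncake

# When the cotangent `t` arrives as a ComponentVector but the upstream FData
# carries the raw backing Vector, strip the wrapper and forward to the
# existing Array{<:IEEEFloat} accumulator.
function Mooncake.increment_and_get_rdata!(
    f::Mooncake.FData, r::Mooncake.NoRData, t::ComponentVector
)
    return Mooncake.increment_and_get_rdata!(f, r, getdata(t))
end
```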
I verified the fix end-to-end on the SciMLSensitivity tutorials that
were previously blocked: the hybrid pharmacometric ODE with default
sensealg, Lux + ComponentArrays neural ODE training, and the brusselator
UDE with an FBDF stiff PDE solve all now train with
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.
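For reference, the opt-in used in those runs looks roughly like this; `loss` and `p0` are hypothetical stand-ins, and `OPT` is the `Optimization` alias used throughout the tutorials:

```julia
# Sketch: wiring AutoMooncake with friendly tangents into an
# Optimization.jl problem, as in the tutorial runs above.
# `loss` and `p0` are hypothetical stand-ins for the tutorial's
# loss function and initial ComponentVector parameters.
import Optimization as OPT
import Mooncake

adtype = OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
optf = OPT.OptimizationFunction((p, _) -> loss(p), adtype)
prob = OPT.OptimizationProblem(optf, p0)
# OPT.solve(prob, opt; maxiters = ...) then trains through Mooncake.
```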
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Member
I was going to come back to this, as this was the only remaining to-do for ComponentArrays.jl from the original SciMLSensitivity issue 😅, fyi
ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/SciMLSensitivity.jl that referenced this pull request on Apr 11, 2026:
…er notes

After investigating the four blocked tutorials more carefully and adding the missing `increment_and_get_rdata!` dispatch for `ComponentVector` cotangents in ComponentArrays' Mooncake extension (SciML/ComponentArrays.jl#351), three more tutorials are now Mooncake-compatible end-to-end:

## hybrid_diffeq.md (un-reverted)

The original file pinned `sensealg = SMS.ReverseDiffAdjoint()` explicitly. The continuous adjoints (`BacksolveAdjoint`, `InterpolatingAdjoint`, `GaussAdjoint`, `QuadratureAdjoint`) are now compatible with callbacks for ODEs, so the explicit `ReverseDiffAdjoint` choice is no longer necessary. Drop it and let the default sensealg auto-pick. Combined with CA SciML#351 (which fixes the `increment_and_get_rdata!` mismatch on the parameter cotangent path), the example now trains under `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.

## delay_diffeq.md (un-reverted)

Same story: the original file pinned `sensealg = SMS.ReverseDiffAdjoint()` explicitly, which Mooncake hits with a `StackOverflowError` during rule compilation. The default sensealg for DDEs is `ForwardDiffSensitivity()` for problems with fewer than 100 parameters (SciMLSensitivity.jl `concrete_solve.jl:434-454`), and that path uses ForwardDiff dual numbers inside the rrule, which Mooncake handles fine. Drop the explicit `ReverseDiffAdjoint` and let the default pick. Replace the narrative line that explained the explicit choice with a note about the automatic ForwardDiff/ReverseDiff fallback for DDEs (continuous adjoints are not yet defined for DDEs, so the discretize-then-optimize methods are the only option here).

## brusselator.md (un-reverted)

CA SciML#351 also unblocks this: the FBDF stiff-PDE adjoint with Lux+CV parameters was the same `increment_and_get_rdata!` mismatch, and once that dispatch lands the default `GaussAdjoint(ZygoteVJP)` flows through Mooncake without further changes.

## simplechains.md (note expanded)

I tested the full matrix (default, `QuadratureAdjoint(ZygoteVJP)`, `QuadratureAdjoint(MooncakeVJP)`, `InterpolatingAdjoint(ReverseDiffVJP)`, `GaussAdjoint(MooncakeVJP)`) and **none of them work** with the SimpleChains + `StaticArrays` out-of-place flow. Each fails for a different reason; the new note enumerates all four with the exact upstream symptom so future contributors know which layer needs to grow the missing dispatch. Notable findings:

- The default sensealg picks `GaussAdjoint`, which trips an `@assert sensealg isa QuadratureAdjoint` in `adjoint_common.jl:747` because `u::SVector` is immutable and only `QuadratureAdjoint` is wired up for the immutable-state path.
- `QuadratureAdjoint(autojacvec=ZygoteVJP())` (the explicit choice in the file) emits a `ChainRulesCore.Tangent` cotangent that SciMLSensitivity's `df_iip`/`df_oop` adjoint backpass can't unwrap; Mooncake's pullback fails with a `BoundsError` accessing the nested `Tangent` fields. Zygote produces a different cotangent shape that flows through cleanly, which is why the tutorial works on `AutoZygote` but not `AutoMooncake`.
- `QuadratureAdjoint(autojacvec=MooncakeVJP())` and `GaussAdjoint(autojacvec=MooncakeVJP())` both fail with `setindex!(::SVector, …)`: `MooncakeVJP` mutates the cotangent buffer in place, which has no method for static arrays.
- `InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))` fails with `conversion to pointer not defined for ReverseDiff.TrackedArray`; SimpleChains reaches into raw pointer storage that's incompatible with ReverseDiff-tracked types.

## second_order_neural.md (note refined)

This is **not a missing rule**. `SecondOrderODEProblem` constructs an `ODEProblem{…, SciMLBase.SecondOrderODEProblem{false}}` wrapping a `DynamicalODEFunction`, so the existing `_concrete_solve_adjoint(::AbstractODEProblem, …)` methods dispatch fine.

The actual blocker is a `df_iip`/`df_oop` bug in `SciMLSensitivity/src/concrete_solve.jl`: when the state is an `ArrayPartition{Tuple{Vector,Vector}}` (which is what `SecondOrderODEProblem` uses internally), the Mooncake-originated cotangent comes back shaped as `ChainRulesCore.Tangent{NamedTuple{x::Tangent{Tuple{Vector,Vector}}}}`, and the adjoint backpass calls `vec(x)` on this nested `Tangent` and raises `MethodError: no method matching vec(::ChainRulesCore.Tangent)`. Zygote happens to produce a different (recursively-array-shaped) cotangent that flows through, which is why the tutorial works on Zygote but not Mooncake. The note now records this precisely so the fix path is clear: add a `Tangent` → `ArrayPartition` unwrap inside `df_iip`/`df_oop`.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
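That suggested unwrap could look roughly like the following. This is a hypothetical helper, not code from SciMLSensitivity: the `unwrap_cotangent` name and the exact `Tangent` field access are assumptions based on the cotangent shape described above, and the real `df_iip`/`df_oop` integration would differ.

```julia
# Hypothetical sketch of the suggested fix: rebuild an ArrayPartition from a
# nested ChainRulesCore.Tangent before the adjoint backpass calls vec() on it.
# `unwrap_cotangent` and the field layout (t.x holding a Tuple tangent, per the
# shape described above) are assumptions, not SciMLSensitivity API.
using ChainRulesCore: Tangent, backing
using RecursiveArrayTools: ArrayPartition

unwrap_cotangent(x) = x  # plain arrays pass through untouched

function unwrap_cotangent(t::Tangent)
    # The Mooncake-originated cotangent wraps the ArrayPartition's pieces as
    # Tangent{NamedTuple{x::Tangent{Tuple{Vector,Vector}}}}; pull out the
    # per-component arrays and rebuild the ArrayPartition.
    parts = map(unwrap_cotangent, backing(t.x))
    return ArrayPartition(parts...)
end
```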
Contributor
Author
Superseded by #352, which adds the same dispatch for the more-general …
Summary
Follow-up to #350. Adds the missing `increment_and_get_rdata!` dispatch for the case where a chain rule (or `@from_chainrules`/`@from_rrule`-generated rule) emits a cotangent for a `ComponentVector` as a `ComponentVector` (not as the underlying backing `Vector`).

The existing dispatch only matches `t::Array{<:IEEEFloat}`. When `t` is a `ComponentVector` (which is what SciMLSensitivity's adjoint backpass produces for the parameter cotangent), dispatch falls through to the generic `@from_rrule` path that errors.

This blocks every Lux-based UDE / neural-ODE training loop that uses SciMLSensitivity's `GaussAdjoint(autojacvec=ZygoteVJP)` (or any of the other ChainRules-based adjoints) under `AutoMooncake`. The SciMLSensitivity adjoint backpass returns the parameter gradient as a `ComponentVector`, but the upstream FData is the raw underlying `Vector`, so there's a type mismatch the existing method can't bridge.

Fix

Add the missing dispatch: when `t` is a `ComponentVector` whose underlying data type matches the FData's payload, strip the wrapper with `getdata(t)` and forward to the existing accumulator.

Test plan
- `GaussAdjoint(ZygoteVJP)`: now trains
- `second_order_adjoints.md` Adam phase: now trains
- `!!! note` callouts from several of those tutorials
- `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))` in all cases (the path enabled by "Add friendly_tangent_cache function to Mooncake" #350)

🤖 Generated with Claude Code