Mooncake ext: handle ComponentVector cotangents in increment_and_get_rdata!#351

Closed
ChrisRackauckas-Claude wants to merge 1 commit into SciML:main from ChrisRackauckas-Claude:mooncake-componentvector-rdata

Conversation

@ChrisRackauckas-Claude
Contributor

Summary

Follow-up to #350. Adds the missing increment_and_get_rdata! dispatch for the case where a chain rule (or @from_chainrules/@from_rrule-generated rule) emits a cotangent for a ComponentVector as a ComponentVector (not as the underlying backing Vector).

The existing dispatch only matches t::Array{<:IEEEFloat}. When t is a ComponentVector (which is what SciMLSensitivity's adjoint backpass produces for the parameter cotangent), dispatch falls through to the generic @from_rrule path that errors:

ArgumentError: The fdata type Mooncake.FData{@NamedTuple{data::Vector{Float32}, axes::Mooncake.NoFData}}, rdata type Mooncake.NoRData, and tangent type ComponentArrays.ComponentVector{...} combination is not supported with @from_chainrules or @from_rrule. This is because Mooncake.jl does not currently have a method of `increment_and_get_rdata!` to handle this type combination.

This blocks every Lux-based UDE / neural-ODE training loop that uses SciMLSensitivity's GaussAdjoint(autojacvec=ZygoteVJP) (or any of the other ChainRules-based adjoints) under AutoMooncake. The SciMLSensitivity adjoint backpass returns the parameter gradient as a ComponentVector, but the upstream FData is the raw underlying Vector, so there's a type mismatch the existing method can't bridge.

Fix

Add the missing dispatch: when t is a ComponentVector whose underlying data type matches the FData's payload, strip the wrapper with getdata(t) and forward to the existing accumulator.

Test plan

  • End-to-end on the SciMLSensitivity tutorials that were previously blocked: the hybrid pharmacometric ODE with the default sensealg, Lux + ComponentArrays neural ODE training, and the brusselator UDE with an FBDF stiff PDE solve
  • Used OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) in all cases (the path enabled by #350, "Add friendly_tangent_cache function to Mooncake")
  • Happy to add a unit test if you'd like one; let me know what shape

🤖 Generated with Claude Code

…rdata!

The existing `increment_and_get_rdata!` method only matched
`t::Array{<:IEEEFloat}`. When a chain rule (or
`@from_chainrules`/`@from_rrule`-generated rule) emits a cotangent for
a `ComponentVector` *as a `ComponentVector`* (not as the underlying
backing `Vector`), the dispatch falls through to the generic
`@from_rrule` path that errors:

    ArgumentError: The fdata type Mooncake.FData{@NamedTuple{data::Vector{Float32},
    axes::Mooncake.NoFData}}, rdata type Mooncake.NoRData, and tangent type
    ComponentArrays.ComponentVector{...} combination is not supported with
    @from_chainrules or @from_rrule. This is because Mooncake.jl does not currently
    have a method of `increment_and_get_rdata!` to handle this type combination.

This blocks every Lux-based UDE / neural-ODE training loop that uses
SciMLSensitivity's `GaussAdjoint(autojacvec=ZygoteVJP)` (or similar
ChainRules-based adjoint) under `AutoMooncake`, because the SciMLSensitivity
adjoint backpass returns the parameter gradient as a `ComponentVector`
but the upstream FData is the raw underlying `Vector`.

Add the missing dispatch: when `t` is a `ComponentVector` whose
underlying data type matches the FData's payload, strip the wrapper
with `getdata(t)` and forward to the existing accumulator.

I verified the fix end-to-end on the SciMLSensitivity tutorials that
were previously blocked: the hybrid pharmacometric ODE with the default
sensealg, Lux + ComponentArrays neural ODE training, and the brusselator
UDE with an FBDF stiff PDE solve all now train with
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@AstitvaAggarwal
Member

I was going to come back to this, as it was the only remaining to-do for ComponentArrays.jl from the original SciMLSensitivity issue 😅. FYI, DiffEqGPU.jl works fine as is.

ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/SciMLSensitivity.jl that referenced this pull request Apr 11, 2026
…er notes

After investigating the four blocked tutorials more carefully and adding
the missing `increment_and_get_rdata!` dispatch for `ComponentVector`
cotangents in ComponentArrays' Mooncake extension
(SciML/ComponentArrays.jl#351), three more tutorials are now Mooncake-
compatible end-to-end:

## hybrid_diffeq.md (un-reverted)

The original file pinned `sensealg = SMS.ReverseDiffAdjoint()` explicitly.
The continuous adjoints (`BacksolveAdjoint`, `InterpolatingAdjoint`,
`GaussAdjoint`, `QuadratureAdjoint`) are now compatible with callbacks
for ODEs, so the explicit `ReverseDiffAdjoint` choice is no longer
necessary. Drop it and let the default sensealg auto-pick. Combined
with CA SciML#351 (which fixes the `increment_and_get_rdata!` mismatch on
the parameter cotangent path), the example now trains under
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.

## delay_diffeq.md (un-reverted)

Same story — the original file pinned `sensealg = SMS.ReverseDiffAdjoint()`
explicitly, which Mooncake hits with a `StackOverflowError` during rule
compilation.  The default sensealg for DDEs is `ForwardDiffSensitivity()`
for problems with fewer than 100 parameters (SciMLSensitivity.jl
`concrete_solve.jl:434-454`), and that path uses ForwardDiff dual
numbers inside the rrule — which Mooncake handles fine.  Drop the
explicit `ReverseDiffAdjoint` and let the default pick. Replace the
narrative line that explained the explicit choice with a note about
the automatic ForwardDiff/ReverseDiff fallback for DDEs (continuous
adjoints are not yet defined for DDEs, so the discretize-then-optimize
methods are the only option here).

## brusselator.md (un-reverted)

CA SciML#351 also unblocks this — the FBDF stiff-PDE adjoint with Lux+CV
parameters was the same `increment_and_get_rdata!` mismatch, and once
that dispatch lands the default `GaussAdjoint(ZygoteVJP)` flows
through Mooncake without further changes.

## simplechains.md (note expanded)

I tested the full matrix (default, `QuadratureAdjoint(ZygoteVJP)`,
`QuadratureAdjoint(MooncakeVJP)`, `InterpolatingAdjoint(ReverseDiffVJP)`,
`GaussAdjoint(MooncakeVJP)`) and **none of them work** with the
SimpleChains+`StaticArrays` out-of-place flow.  Each fails for a
different reason — the new note enumerates all four with the exact
upstream symptom so future contributors know which layer needs to grow
the missing dispatch.  Notable findings:

  - The default sensealg picks `GaussAdjoint`, which trips an
    `@assert sensealg isa QuadratureAdjoint` in
    `adjoint_common.jl:747` because `u::SVector` is immutable and
    only `QuadratureAdjoint` is wired up for the immutable-state path.
  - `QuadratureAdjoint(autojacvec=ZygoteVJP())` (the explicit choice
    in the file) emits a `ChainRulesCore.Tangent` cotangent that
    SciMLSensitivity's `df_iip`/`df_oop` adjoint backpass can't unwrap
    — Mooncake's pullback fails with a `BoundsError` accessing the
    nested `Tangent` fields.  Zygote produces a different cotangent
    shape that flows through cleanly, which is why the tutorial works
    on `AutoZygote` but not `AutoMooncake`.
  - `QuadratureAdjoint(autojacvec=MooncakeVJP())` and
    `GaussAdjoint(autojacvec=MooncakeVJP())` both fail with
    `setindex!(::SVector, …)` — `MooncakeVJP` mutates the cotangent
    buffer in place, which has no method for static arrays.
  - `InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))` fails with
    `conversion to pointer not defined for ReverseDiff.TrackedArray` —
    SimpleChains reaches into raw pointer storage that's incompatible
    with ReverseDiff-tracked types.

## second_order_neural.md (note refined)

This is **not a missing rule**. `SecondOrderODEProblem` constructs an
`ODEProblem{…, SciMLBase.SecondOrderODEProblem{false}}` wrapping a
`DynamicalODEFunction`, so the existing
`_concrete_solve_adjoint(::AbstractODEProblem, …)` methods dispatch
fine.  The actual blocker is a `df_iip`/`df_oop` bug in
`SciMLSensitivity/src/concrete_solve.jl`: when the state is an
`ArrayPartition{Tuple{Vector,Vector}}` (which is what
`SecondOrderODEProblem` uses internally), the Mooncake-originated
cotangent comes back shaped as
`ChainRulesCore.Tangent{NamedTuple{x::Tangent{Tuple{Vector,Vector}}}}`,
and the adjoint backpass calls `vec(x)` on this nested `Tangent` and
raises `MethodError: no method matching vec(::ChainRulesCore.Tangent)`.
Zygote happens to produce a different (recursively-array-shaped)
cotangent that flows through, which is why the tutorial works on
Zygote but not Mooncake.  The note now records this precisely so the
fix path is clear: add a `Tangent` → `ArrayPartition` unwrap inside
`df_iip`/`df_oop`.
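
The suggested unwrap can be sketched with a toy stand-in (`ToyTangent` mimics the nested `ChainRulesCore.Tangent` shape quoted above; this is illustrative, not actual SciMLSensitivity code):

```julia
# Toy stand-in for ChainRulesCore.Tangent: a wrapper over a Tuple or
# NamedTuple backing, as in the reported cotangent shape
# Tangent{NamedTuple{x::Tangent{Tuple{Vector,Vector}}}}.
struct ToyTangent{B}
    backing::B
end

# Recursive unwrap: plain arrays pass through `vec`; a tangent over a
# tuple/named-tuple of parts becomes the concatenation of its unwrapped
# parts, i.e. the flat vector that `vec(::ArrayPartition)` would give.
unwrap_vec(x::AbstractArray) = vec(x)
unwrap_vec(t::ToyTangent{<:Tuple}) = reduce(vcat, map(unwrap_vec, t.backing))
unwrap_vec(t::ToyTangent{<:NamedTuple}) =
    reduce(vcat, map(unwrap_vec, values(t.backing)))

# The shape from the error report: a tangent whose `x` field is itself a
# tangent over a tuple of two vectors (position/velocity partitions).
cot = ToyTangent((x = ToyTangent(([1.0, 2.0], [3.0, 4.0])),))
unwrap_vec(cot)  # [1.0, 2.0, 3.0, 4.0]
```

Calling this kind of unwrap before the `vec(x)` in `df_iip`/`df_oop` would sidestep the `MethodError`, since the nested `Tangent` never reaches `vec` directly.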

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@ChrisRackauckas-Claude
Contributor Author

Superseded by #352, which adds the same dispatch for the more-general ComponentArray cotangent case (plus the SubArray-backed ComponentVector case my PR didn't touch). Verified end-to-end with the SciMLSensitivity tutorial migration in SciML/SciMLSensitivity.jl#1419 against CA main. Closing.
