
Commit 8e98346

Mooncake ext: unwrap struct tangents back to primal type
Mooncake's `value_and_gradient!!` (and its pullback/pushforward analogues) returns the differential as a `Mooncake.Tangent` / `MutableTangent` whenever the primal is a struct-backed array such as a `ComponentArray` or an `MVector`. Downstream callers expect a value with the same layout as the primal and call `copyto!`/`iterate` on it; most notably, `OptimizationBase` preallocates a `ComponentVector` buffer and passes it to `gradient!`. The mismatch raised `MethodError: no method matching iterate(::Mooncake.Tangent)` and broke every Optimization.jl loop that used `AutoMooncake` with ComponentArrays parameters, including all SciMLSensitivity neural ODE training tutorials.

Fix: convert the tangent back to the primal type at the boundary of the DI extension via `Mooncake.tangent_to_primal!!`. This is the same unwrap path SciMLSensitivity already adopted internally (`SciML/SciMLSensitivity.jl@4205d49`). The helper lives in `utils.jl` and is reused by the gradient, pullback, and pushforward code paths in `onearg.jl`/`twoarg.jl` and their forward-mode counterparts, so the conversion happens consistently.

`gradient!` accepts `grad` buffers whose type differs from the primal (e.g. an `MVector` buffer for an `SVector` primal), so the in-place helper allocates the unwrap target with `_copy_output(x)` and then `copyto!`s the result into `grad`. When `grad` itself is immutable (e.g. an `SVector`), no in-place update is possible, so `gradient!` forwards the freshly built primal-shaped value rather than the unchanged buffer; this matches what callers compare against and is the only sensible interpretation of `gradient!` for an immutable destination.

`tangent_to_primal!!` is a deprecated Mooncake API; the future replacement is `tangent_to_friendly!!`, but that currently returns the raw `Tangent` for `ComponentArray` (no `friendly_tangent_cache` override exists) and so is not yet a viable substitute. A comment in `utils.jl` notes the migration path.
Add a regression test for the previously broken path: the new `component_scenarios()` block in `test/Back/Mooncake/test.jl` exercises `gradient`/`gradient!`/`pullback`/`pushforward` (out-of-place and in-place) on a `ComponentVector` against the three Mooncake backends that do not hit the unrelated forward-mode-without-friendly-tangents input bug, and an explicit `gradient!` testset captures the exact preallocated-buffer call shape that OptimizationBase uses.

Full Mooncake test suite: 37077 passing, 0 failures, 0 errors (baseline DI 0.7.16: 35841 passing, 184 errors).

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
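For context, a minimal sketch of the call shape this commit repairs (assumes DifferentiationInterface, Mooncake, ADTypes, and ComponentArrays are installed; `myloss` is an illustrative function, not from the commit):

```julia
# Illustrative reproduction of the OptimizationBase-style call pattern.
using DifferentiationInterface
using ComponentArrays: ComponentVector
using ADTypes: AutoMooncake
import Mooncake

ps = ComponentVector(a = 1.0, b = [2.0, 3.0])
myloss(p) = p.a^2 + sum(abs2, p.b)

# Preallocated primal-shaped buffer, exactly as OptimizationBase does:
grad = similar(ps)
gradient!(myloss, grad, AutoMooncake(; config = nothing), ps)
# Before this commit: MethodError: no method matching iterate(::Mooncake.Tangent)
# After: `grad` holds the primal-shaped gradient (2a and 2b component-wise)
```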
1 parent a5ecbe0 commit 8e98346

8 files changed

Lines changed: 140 additions & 14 deletions


DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ using Mooncake:
     rdata,
     tangent_type,
     NoTangent,
+    Tangent,
+    MutableTangent,
     @is_primitive,
     zero_fcodual,
     MinimalCtx,

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 3 additions & 2 deletions
@@ -41,7 +41,8 @@ function DI.value_and_pushforward(
             map(first_unwrap, contexts, prep.context_tangents)...,
         )
         y = first(y_and_dy)
-        dy = _copy_output(last(y_and_dy))
+        dy_raw = last(y_and_dy)
+        dy = _to_primal_alloc(y, dy_raw)
         return y, dy
     end
     y = _copy_output(first(ys_and_ty[1]))
@@ -72,7 +73,7 @@ function DI.value_and_pushforward!(
 ) where {F, C}
     DI.check_prep(f, prep, backend, x, tx, contexts...)
     y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
-    foreach(copyto!, ty, new_ty)
+    foreach(_to_primal!, ty, new_ty)
     return y, ty
 end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 2 additions & 2 deletions
@@ -55,7 +55,7 @@ function DI.value_and_pushforward(
             (x, dx),
             map(first_unwrap, contexts, prep.context_tangents)...,
         )
-        return _copy_output(new_dy)
+        return _to_primal_alloc(y, new_dy)
     end
     return y, ty
 end
@@ -93,7 +93,7 @@ function DI.value_and_pushforward!(
             (x, dx),
             map(first_unwrap, contexts, prep.context_tangents)...,
         )
-        copyto!(dy, new_dy)
+        _to_primal!(dy, new_dy)
     end
     return y, ty
 end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 11 additions & 7 deletions
@@ -35,7 +35,7 @@ function DI.value_and_pullback(
     new_y, (_, new_dx) = value_and_pullback!!(
         prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
     )
-    return new_y, (_copy_output(new_dx),)
+    return new_y, (_to_primal_alloc(x, new_dx),)
 end

 function DI.value_and_pullback(
@@ -51,7 +51,7 @@ function DI.value_and_pullback(
         y, (_, new_dx) = value_and_pullback!!(
             prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
         )
-        y, _copy_output(new_dx)
+        y, _to_primal_alloc(x, new_dx)
     end
     y = first(ys_and_tx[1])
     tx = map(last, ys_and_tx)
@@ -69,7 +69,7 @@ function DI.value_and_pullback!(
 ) where {F, C}
     DI.check_prep(f, prep, backend, x, ty, contexts...)
     y, new_tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
-    foreach(copyto!, tx, new_tx)
+    foreach(_to_primal!, tx, new_tx)
     return y, tx
 end

@@ -134,7 +134,7 @@ function DI.value_and_gradient(
         prep.cache, f, x, map(DI.unwrap, contexts)...;
         prep.args_to_zero
     )
-    return y, _copy_output(new_grad)
+    return y, _to_primal_alloc(x, new_grad)
 end

 function DI.value_and_gradient!(
@@ -150,7 +150,7 @@ function DI.value_and_gradient!(
         prep.cache, f, x, map(DI.unwrap, contexts)...;
         prep.args_to_zero
     )
-    copyto!(grad, new_grad)
+    grad = _to_primal_into!(grad, x, new_grad)
     return y, grad
 end

@@ -175,6 +175,10 @@ function DI.gradient!(
     contexts::Vararg{DI.Context, C},
 ) where {F, C}
     DI.check_prep(f, prep, backend, x, contexts...)
-    DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
-    return grad
+    # Note: when `grad` is immutable (e.g. an `SVector`), `value_and_gradient!`
+    # returns a freshly built primal-shaped value rather than the original
+    # buffer (no in-place update is possible). Forward that value to the
+    # caller instead of returning the unchanged `grad`.
+    _, new_grad = DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
+    return new_grad
 end
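The in-place-or-replace convention that `gradient!` adopts here can be illustrated with a plain-Julia sketch (no Mooncake involved; `fill_or_replace!` is a hypothetical stand-in, not a DI function):

```julia
# Hypothetical helper showing the convention: mutate the destination when its
# type allows it, otherwise build and return a fresh value of the right shape.
fill_or_replace!(buf::AbstractVector, result) = (copyto!(buf, result); buf)
fill_or_replace!(buf::Tuple, result) = Tuple(result)  # immutable: build fresh

v = zeros(3)
fill_or_replace!(v, [1.0, 2.0, 3.0])      # mutates `v` in place and returns it
fill_or_replace!((0.0, 0.0), [4.0, 5.0])  # returns (4.0, 5.0); input untouched
```

Callers that care about the result must therefore use the return value, not the buffer they passed in, which is exactly why the new `gradient!` forwards `new_grad`.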

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
@@ -64,7 +64,7 @@ function DI.value_and_pullback(
         prep.args_to_zero
     )
     copyto!(y, y_after)
-    return y, (_copy_output(dx),)
+    return y, (_to_primal_alloc(x, dx),)
 end

 function DI.value_and_pullback(
@@ -90,7 +90,7 @@ function DI.value_and_pullback(
             prep.args_to_zero
         )
         copyto!(y, y_after)
-        _copy_output(dx)
+        _to_primal_alloc(x, dx)
     end
     return y, tx
 end
@@ -107,7 +107,7 @@ function DI.value_and_pullback!(
 ) where {F, C}
     DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
     _, new_tx = DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)
-    foreach(copyto!, tx, new_tx)
+    foreach(_to_primal!, tx, new_tx)
     return y, tx
 end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 68 additions & 0 deletions
@@ -17,3 +17,71 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
         return zero_tangent(x)
     end
 end
+
+# When the primal is a struct-backed array (e.g. `ComponentArray`, `MVector`)
+# or a struct whose `tangent_type` is `Tangent` / `MutableTangent`,
+# `value_and_gradient!!` and friends return the differential as the tangent
+# wrapper rather than something whose layout matches the primal. Downstream
+# code (`copyto!`, iteration, OptimizationBase, `≈` against the expected
+# primal-shaped result) expects a value with the same shape as the primal,
+# so we unwrap here.
+#
+# `tangent_to_primal!!` is a deprecated Mooncake API but is the only stable
+# entry point that converts a `Tangent` / `MutableTangent` back to its primal
+# type. `tangent_to_friendly!!` is the future replacement, but it does not
+# yet perform the conversion for `ComponentArray` (it falls through to
+# `AsRaw` and returns the raw `Tangent`). Once `friendly_tangent_cache` is
+# defined for the relevant types upstream and Mooncake removes
+# `tangent_to_primal!!`, this helper should switch over.
+const _MooncakeStructTangent = Union{Tangent, MutableTangent}
+
+@inline _to_primal_alloc(primal, dx) = _copy_output(dx)
+@inline function _to_primal_alloc(primal::P, dx::_MooncakeStructTangent) where {P}
+    return tangent_to_primal!!(_copy_output(primal), dx)::P
+end
+
+@inline function _to_primal_into!(grad, primal, new_grad)
+    copyto!(grad, new_grad)
+    return grad
+end
+@inline function _to_primal_into!(
+        grad, primal::P, new_grad::_MooncakeStructTangent
+    ) where {P}
+    # Build the unwrapped gradient at the *primal* type — DI allows the caller
+    # to pass a `grad` buffer whose type differs from the primal (e.g. a
+    # mutable `MVector` buffer for an immutable `SVector` primal), and
+    # `tangent_to_primal!!` requires the destination type to match the
+    # tangent's primal type. We allocate a fresh primal-shaped buffer with
+    # `_copy_output(primal)`, fill it via `tangent_to_primal!!`, then copy
+    # the result into `grad`. When `grad` itself is immutable (e.g. an
+    # `SVector` buffer), no in-place update is possible — DI's `gradient!`
+    # API contract cannot be honored for an immutable buffer anyway, so we
+    # return the freshly built primal-shaped value, which higher-level
+    # callers compare by value rather than identity.
+    result = tangent_to_primal!!(_copy_output(primal), new_grad)::P
+    if _can_setindex(grad)
+        copyto!(grad, result)
+        return grad
+    else
+        return result
+    end
+end
+
+# Convenience used in the pullback / pushforward `foreach(_to_primal!, …)`
+# call sites where there is no separate primal buffer to pass through — the
+# buffer `grad` *is* the primal-shaped destination.
+@inline function _to_primal!(grad, new_grad)
+    copyto!(grad, new_grad)
+    return grad
+end
+@inline function _to_primal!(grad::P, new_grad::_MooncakeStructTangent) where {P}
+    return _to_primal_into!(grad, grad, new_grad)
+end
+
+# Whether `copyto!(grad, ...)` can update `grad`'s elements in place.
+# `ComponentVector` is itself an immutable struct (`ismutabletype` returns
+# false) but wraps a mutable `Vector`, so `copyto!` works on it; conversely,
+# `SVector` wraps a `Tuple` and `copyto!` errors. Walking down to the array
+# parent and checking *its* type captures both cases correctly.
+@inline _can_setindex(grad::AbstractArray) = ismutabletype(typeof(parent(grad)))
+@inline _can_setindex(grad) = ismutabletype(typeof(grad))
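The parent-walking mutability check can be exercised with Base types alone. A sketch (re-stating the heuristic under a hypothetical name; the `_can_setindex` in the diff above is what actually ships):

```julia
# The mutability of the *parent* array decides whether `copyto!` can update
# the wrapper in place; checking the wrapper's own type would get views and
# ComponentVector-style wrappers wrong.
can_setindex(x::AbstractArray) = ismutabletype(typeof(parent(x)))
can_setindex(x) = ismutabletype(typeof(x))

can_setindex([1.0, 2.0])            # true: `Vector` is a mutable type
can_setindex(1:3)                   # false: `UnitRange` is an immutable struct
can_setindex(view([1, 2, 3], 1:2))  # true: the view's parent is a `Vector`
```

`ComponentVector` behaves like the `view` case: the wrapper struct is immutable, but `parent` reaches the mutable `Vector` underneath, so the check returns true.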

DifferentiationInterface/test/Back/Mooncake/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
34
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
45
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
56
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 50 additions & 0 deletions
@@ -74,6 +74,56 @@ test_differentiation(
     @test grad.B == ps.A
 end

+# Regression test for AutoMooncake gradient/gradient!/pullback/pushforward on
+# a struct-backed AbstractArray (ComponentArray). Before, the Mooncake
+# extension returned the differential as a `Mooncake.Tangent` and DI tried to
+# `copyto!` it into the preallocated `ComponentVector` buffer downstream
+# callers (e.g. OptimizationBase) pass in, raising a `MethodError` on
+# `iterate(::Mooncake.Tangent)`. This blocked any Optimization.jl loop that
+# used ComponentArrays parameters with `AutoMooncake`.
+#
+# The high-level scenario suite from DifferentiationInterfaceTest exercises
+# the out-of-place and in-place versions of `gradient`, `pullback`, and
+# `pushforward` for both `f(x)` and the `dy * f(x)` accumulation pattern,
+# which together cover every code path the fix touches.
+#
+# `AutoMooncakeForward()` (without `friendly_tangents`) is excluded from this
+# scenario because its forward-mode pushforward path has a separate,
+# pre-existing bug at the *input* (Dual construction) side: it raises
+# `ArgumentError: Tangent types do not match primal types` when given a
+# `ComponentVector` `dx`, because Mooncake forward mode expects the tangent
+# to already be a `Mooncake.Tangent` rather than a primal-shaped value.
+# That input-side conversion is independent of the output-side fix in this
+# PR; the friendly-tangents forward backend below covers the fixed code paths.
+using ComponentArrays: ComponentArrays, ComponentVector
+component_backends = [
+    backends[1], # AutoMooncake() — reverse, the path OptimizationBase uses
+    backends[3], # AutoMooncake(friendly_tangents=true) — reverse + friendly
+    backends[4], # AutoMooncakeForward(friendly_tangents=true) — forward + friendly
+]
+test_differentiation(
+    component_backends,
+    component_scenarios();
+    excluded = SECOND_ORDER,
+    logging = LOGGING,
+)
+
+# Direct gradient! sanity check on a small ComponentVector — this is the
+# specific call shape OptimizationBase uses, kept as an explicit assertion in
+# case `component_scenarios()` is ever pared down.
+@testset "ComponentArrays gradient! into preallocated buffer" begin
+    ps = ComponentVector(a = 1.0, b = [2.0, 3.0])
+    myfun(p) = p.a^2 + sum(p.b .^ 2)
+    for backend in component_backends
+        gbuf = similar(ps)
+        fill!(gbuf, 0)
+        gradient!(myfun, gbuf, backend, ps)
+        @test gbuf isa ComponentVector
+        @test gbuf.a ≈ 2 * ps.a
+        @test gbuf.b ≈ 2 .* ps.b
+    end
+end
+
 test_differentiation(
     backends[3:4],
     nomatrix(static_scenarios());
