diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 3513d548c..f5ce24b3a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -23,6 +23,8 @@ using Mooncake: rdata, tangent_type, NoTangent, + Tangent, + MutableTangent, @is_primitive, zero_fcodual, MinimalCtx, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index c470b6473..a66195395 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -41,7 +41,8 @@ function DI.value_and_pushforward( map(first_unwrap, contexts, prep.context_tangents)..., ) y = first(y_and_dy) - dy = _copy_output(last(y_and_dy)) + dy_raw = last(y_and_dy) + dy = _to_primal_alloc(y, dy_raw) return y, dy end y = _copy_output(first(ys_and_ty[1])) @@ -72,7 +73,7 @@ function DI.value_and_pushforward!( ) where {F, C} DI.check_prep(f, prep, backend, x, tx, contexts...) y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...) 
- foreach(copyto!, ty, new_ty) + foreach(_to_primal!, ty, new_ty) return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 3c75f530b..b39d20a7f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -55,7 +55,7 @@ function DI.value_and_pushforward( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - return _copy_output(new_dy) + return _to_primal_alloc(y, new_dy) end return y, ty end @@ -93,7 +93,7 @@ function DI.value_and_pushforward!( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - copyto!(dy, new_dy) + _to_primal!(dy, new_dy) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 2514cdc40..5f57c4251 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -35,7 +35,7 @@ function DI.value_and_pullback( new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return new_y, (_copy_output(new_dx),) + return new_y, (_to_primal_alloc(x, new_dx),) end function DI.value_and_pullback( @@ -51,7 +51,7 @@ function DI.value_and_pullback( y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - y, _copy_output(new_dx) + y, _to_primal_alloc(x, new_dx) end y = first(ys_and_tx[1]) tx = map(last, ys_and_tx) @@ -69,7 +69,7 @@ function DI.value_and_pullback!( ) where {F, C} DI.check_prep(f, prep, backend, x, ty, contexts...) y, new_tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...) 
- foreach(copyto!, tx, new_tx) + foreach(_to_primal!, tx, new_tx) return y, tx end @@ -134,7 +134,7 @@ function DI.value_and_gradient( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return y, _copy_output(new_grad) + return y, _to_primal_alloc(x, new_grad) end function DI.value_and_gradient!( @@ -150,7 +150,7 @@ function DI.value_and_gradient!( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - copyto!(grad, new_grad) + grad = _to_primal_into!(grad, x, new_grad) return y, grad end @@ -175,6 +175,10 @@ function DI.gradient!( contexts::Vararg{DI.Context, C}, ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) - DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) - return grad + # Note: when `grad` is immutable (e.g. an `SVector`), `value_and_gradient!` + # returns a freshly built primal-shaped value rather than the original + # buffer (no in-place update is possible). Forward that value to the + # caller instead of returning the unchanged `grad`. + _, new_grad = DI.value_and_gradient!(f, grad, prep, backend, x, contexts...) + return new_grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 2b55131b9..923831c06 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -64,7 +64,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - return y, (_copy_output(dx),) + return y, (_to_primal_alloc(x, dx),) end function DI.value_and_pullback( @@ -90,7 +90,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - _copy_output(dx) + _to_primal_alloc(x, dx) end return y, tx end @@ -107,7 +107,7 @@ function DI.value_and_pullback!( ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) 
_, new_tx = DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) - foreach(copyto!, tx, new_tx) + foreach(_to_primal!, tx, new_tx) return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b22d8d49b..c754c1536 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -17,3 +17,71 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) return zero_tangent(x) end end + +# When the primal is a struct-backed array (e.g. `ComponentArray`, `MVector`) +# or a struct whose `tangent_type` is `Tangent` / `MutableTangent`, +# `value_and_gradient!!` and friends return the differential as the tangent +# wrapper rather than something whose layout matches the primal. Downstream +# code (`copyto!`, iteration, OptimizationBase, `≈` against the expected +# primal-shaped result) expects a value with the same shape as the primal, +# so we unwrap here. +# +# `tangent_to_primal!!` is a deprecated Mooncake API but is the only stable +# entry point that converts a `Tangent` / `MutableTangent` back to its primal +# type. `tangent_to_friendly!!` is the future replacement, but it does not +# yet perform the conversion for `ComponentArray` (it falls through to +# `AsRaw` and returns the raw `Tangent`). Once `friendly_tangent_cache` is +# defined for the relevant types upstream and Mooncake removes +# `tangent_to_primal!!`, this helper should switch over. 
+const _MooncakeStructTangent = Union{Tangent, MutableTangent}
+
+# Out-of-place conversion. A differential that is not a struct tangent is only
+# passed through `_copy_output`; a `Tangent` / `MutableTangent` is written into
+# a copy of the primal by `tangent_to_primal!!`, and the result is asserted to
+# have the primal's type `P`.
+@inline _to_primal_alloc(primal, dx) = _copy_output(dx)
+@inline function _to_primal_alloc(primal::P, dx::_MooncakeStructTangent) where {P}
+    buffer = _copy_output(primal)
+    return tangent_to_primal!!(buffer, dx)::P
+end
+
+# In-place conversion. Without a struct tangent, `new_grad` is copied straight
+# into `grad` and `primal` is not used.
+@inline _to_primal_into!(grad, primal, new_grad) = (copyto!(grad, new_grad); grad)
+@inline function _to_primal_into!(
+        grad, primal::P, new_grad::_MooncakeStructTangent
+    ) where {P}
+    # The unwrapped gradient is built at the *primal* type: DI lets the caller
+    # hand in a `grad` buffer of another type (e.g. a mutable `MVector` for an
+    # immutable `SVector` primal), while `tangent_to_primal!!` needs a
+    # destination matching the tangent's primal type. So a fresh buffer from
+    # `_copy_output(primal)` is filled first and then copied into `grad`.
+    #
+    # An immutable `grad` (e.g. an `SVector`) cannot be updated in place, so
+    # DI's `gradient!` contract cannot be honored for it anyway; the freshly
+    # built primal-shaped value is returned instead, which higher-level
+    # callers compare by value rather than identity.
+    unwrapped = tangent_to_primal!!(_copy_output(primal), new_grad)::P
+    _can_setindex(grad) || return unwrapped
+    copyto!(grad, unwrapped)
+    return grad
+end
+
+# Two-argument form for the pullback / pushforward
+# `foreach(_to_primal!, …)` call sites, which have no separate primal buffer
+# to pass through: the buffer `grad` *is* the primal-shaped destination. The
+# struct-tangent method therefore forwards `grad` to `_to_primal_into!` as
+# both the destination and the primal.
+@inline _to_primal!(grad, new_grad) = (copyto!(grad, new_grad); grad)
+@inline function _to_primal!(grad::P, new_grad::_MooncakeStructTangent) where {P}
+    return _to_primal_into!(grad, grad, new_grad)
+end
+
+# Whether `copyto!(grad, ...)` can update `grad`'s elements in place.
+# `ComponentVector` is an immutable struct (`ismutabletype` is false) that wraps a
+# mutable `Vector`, so `copyto!` works on it; `SVector` wraps a `Tuple`, so it errors.
+# Wrappers can nest, hence `parent` is followed until it returns its own argument.
+@inline _can_setindex(grad) = ismutabletype(typeof(grad))
+@inline _can_setindex(grad::AbstractArray) = _parent_settable(grad, parent(grad))
+@inline _parent_settable(a, p) = p === a ? ismutabletype(typeof(a)) : _can_setindex(p)
diff --git a/DifferentiationInterface/test/Back/Mooncake/Project.toml b/DifferentiationInterface/test/Back/Mooncake/Project.toml
index f13f37c5b..215d4b38d 100644
--- a/DifferentiationInterface/test/Back/Mooncake/Project.toml
+++ b/DifferentiationInterface/test/Back/Mooncake/Project.toml
@@ -1,5 +1,6 @@
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl
index d531e542a..7d94e77da 100644
--- a/DifferentiationInterface/test/Back/Mooncake/test.jl
+++ b/DifferentiationInterface/test/Back/Mooncake/test.jl
@@ -74,6 +74,56 @@ test_differentiation(
     @test grad.B == ps.A
 end
 
+# Regression test for AutoMooncake gradient/gradient!/pullback/pushforward on
+# a struct-backed AbstractArray (ComponentArray). Before, the Mooncake
+# extension returned the differential as a `Mooncake.Tangent` and DI tried to
+# `copyto!` it into the preallocated `ComponentVector` buffer downstream
+# callers (e.g. OptimizationBase) pass in, raising a `MethodError` on
+# `iterate(::Mooncake.Tangent)`. This blocked any Optimization.jl loop that
+# used ComponentArrays parameters with `AutoMooncake`.
+# +# The high-level scenario suite from DifferentiationInterfaceTest exercises +# the out-of-place and in-place versions of `gradient`, `pullback`, and +# `pushforward` for both `f(x)` and the `dy * f(x)` accumulation pattern, +# which together cover every code path the fix touches. +# +# `AutoMooncakeForward()` (without `friendly_tangents`) is excluded from this +# scenario because its forward-mode pushforward path has a separate, +# pre-existing bug at the *input* (Dual construction) side: it raises +# `ArgumentError: Tangent types do not match primal types` when given a +# `ComponentVector` `dx`, because Mooncake forward mode expects the tangent +# to already be a `Mooncake.Tangent` rather than a primal-shaped value. +# That input-side conversion is independent of the output-side fix in this +# PR; the friendly-tangents forward backend below covers the fixed code paths. +using ComponentArrays: ComponentArrays, ComponentVector +component_backends = [ + backends[1], # AutoMooncake() — reverse, the path OptimizationBase uses + backends[3], # AutoMooncake(friendly_tangents=true) — reverse + friendly + backends[4], # AutoMooncakeForward(friendly_tangents=true) — forward + friendly +] +test_differentiation( + component_backends, + component_scenarios(); + excluded = SECOND_ORDER, + logging = LOGGING, +) + +# Direct gradient! sanity check on a small ComponentVector — this is the +# specific call shape OptimizationBase uses, kept as an explicit assertion in +# case `component_scenarios()` is ever pared down. +@testset "ComponentArrays gradient! into preallocated buffer" begin + ps = ComponentVector(a = 1.0, b = [2.0, 3.0]) + myfun(p) = p.a^2 + sum(p.b .^ 2) + for backend in component_backends + gbuf = similar(ps) + fill!(gbuf, 0) + gradient!(myfun, gbuf, backend, ps) + @test gbuf isa ComponentVector + @test gbuf.a ≈ 2 * ps.a + @test gbuf.b ≈ 2 .* ps.b + end +end + test_differentiation( backends[3:4], nomatrix(static_scenarios());