Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ using Mooncake:
rdata,
tangent_type,
NoTangent,
Tangent,
MutableTangent,
@is_primitive,
zero_fcodual,
MinimalCtx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ function DI.value_and_pushforward(
map(first_unwrap, contexts, prep.context_tangents)...,
)
y = first(y_and_dy)
dy = _copy_output(last(y_and_dy))
dy_raw = last(y_and_dy)
dy = _to_primal_alloc(y, dy_raw)
return y, dy
end
y = _copy_output(first(ys_and_ty[1]))
Expand Down Expand Up @@ -72,7 +73,7 @@ function DI.value_and_pushforward!(
) where {F, C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
foreach(copyto!, ty, new_ty)
foreach(_to_primal!, ty, new_ty)
return y, ty
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function DI.value_and_pushforward(
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
return _copy_output(new_dy)
return _to_primal_alloc(y, new_dy)
end
return y, ty
end
Expand Down Expand Up @@ -93,7 +93,7 @@ function DI.value_and_pushforward!(
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
copyto!(dy, new_dy)
_to_primal!(dy, new_dy)
end
return y, ty
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function DI.value_and_pullback(
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
return new_y, (_copy_output(new_dx),)
return new_y, (_to_primal_alloc(x, new_dx),)
end

function DI.value_and_pullback(
Expand All @@ -51,7 +51,7 @@ function DI.value_and_pullback(
y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
y, _copy_output(new_dx)
y, _to_primal_alloc(x, new_dx)
end
y = first(ys_and_tx[1])
tx = map(last, ys_and_tx)
Expand All @@ -69,7 +69,7 @@ function DI.value_and_pullback!(
) where {F, C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
y, new_tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
foreach(copyto!, tx, new_tx)
foreach(_to_primal!, tx, new_tx)
return y, tx
end

Expand Down Expand Up @@ -134,7 +134,7 @@ function DI.value_and_gradient(
prep.cache, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
return y, _copy_output(new_grad)
return y, _to_primal_alloc(x, new_grad)
end

function DI.value_and_gradient!(
Expand All @@ -150,7 +150,7 @@ function DI.value_and_gradient!(
prep.cache, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
copyto!(grad, new_grad)
grad = _to_primal_into!(grad, x, new_grad)
return y, grad
end

Expand All @@ -175,6 +175,10 @@ function DI.gradient!(
contexts::Vararg{DI.Context, C},
) where {F, C}
DI.check_prep(f, prep, backend, x, contexts...)
DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
return grad
# Note: when `grad` is immutable (e.g. an `SVector`), `value_and_gradient!`
# returns a freshly built primal-shaped value rather than the original
# buffer (no in-place update is possible). Forward that value to the
# caller instead of returning the unchanged `grad`.
_, new_grad = DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
return new_grad
end
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function DI.value_and_pullback(
prep.args_to_zero
)
copyto!(y, y_after)
return y, (_copy_output(dx),)
return y, (_to_primal_alloc(x, dx),)
end

function DI.value_and_pullback(
Expand All @@ -90,7 +90,7 @@ function DI.value_and_pullback(
prep.args_to_zero
)
copyto!(y, y_after)
_copy_output(dx)
_to_primal_alloc(x, dx)
end
return y, tx
end
Expand All @@ -107,7 +107,7 @@ function DI.value_and_pullback!(
) where {F, C}
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
_, new_tx = DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...)
foreach(copyto!, tx, new_tx)
foreach(_to_primal!, tx, new_tx)
return y, tx
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,71 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
return zero_tangent(x)
end
end

# When the primal is a struct-backed array (e.g. `ComponentArray`, `MVector`)
# or a struct whose `tangent_type` is `Tangent` / `MutableTangent`,
# `value_and_gradient!!` and friends return the differential as the tangent
# wrapper rather than something whose layout matches the primal. Downstream
# code (`copyto!`, iteration, OptimizationBase, `≈` against the expected
# primal-shaped result) expects a value with the same shape as the primal,
# so we unwrap here.
#
# `tangent_to_primal!!` is a deprecated Mooncake API but is the only stable
# entry point that converts a `Tangent` / `MutableTangent` back to its primal
# type. `tangent_to_friendly!!` is the future replacement, but it does not
# yet perform the conversion for `ComponentArray` (it falls through to
# `AsRaw` and returns the raw `Tangent`). Once `friendly_tangent_cache` is
# defined for the relevant types upstream and Mooncake removes
# `tangent_to_primal!!`, this helper should switch over.
# Union of the two struct-tangent wrappers imported from Mooncake; the
# `_to_primal*` helpers dispatch on it to pick out tangents that need unwrapping.
const _MooncakeStructTangent = Union{Tangent, MutableTangent}

# Out-of-place conversion of a differential `dx` into a value the caller can
# keep. Anything that is not a struct tangent is simply passed through
# `_copy_output`; `primal` is unused on this path.
@inline function _to_primal_alloc(primal, dx)
    return _copy_output(dx)
end

# Struct tangents are unwrapped into a `_copy_output` copy of `primal` by
# `tangent_to_primal!!`; the result is asserted to have the primal's type `P`.
@inline function _to_primal_alloc(primal::P, dx::_MooncakeStructTangent) where {P}
    buffer = _copy_output(primal)
    return tangent_to_primal!!(buffer, dx)::P
end

# In-place counterpart of `_to_primal_alloc`: store `new_grad` in the caller's
# buffer `grad` and hand the buffer back. `primal` is unused on this path.
@inline function _to_primal_into!(grad, primal, new_grad)
    copyto!(grad, new_grad)
    return grad
end

# Struct-tangent path: unwrap `new_grad` first, then try to store it in `grad`.
@inline function _to_primal_into!(
    grad, primal::P, new_grad::_MooncakeStructTangent
) where {P}
    # The unwrapped gradient is built at the *primal* type, not at the type of
    # `grad`: DI lets the caller supply a buffer whose type differs from the
    # primal (e.g. a mutable `MVector` buffer for an immutable `SVector`
    # primal), while `tangent_to_primal!!` needs a destination whose type
    # matches the tangent's primal type. Hence a fresh primal-shaped buffer from
    # `_copy_output(primal)` is filled by `tangent_to_primal!!`.
    unwrapped = tangent_to_primal!!(_copy_output(primal), new_grad)::P

    # An immutable `grad` (e.g. an `SVector` buffer) cannot be updated in
    # place, so DI's `gradient!` contract cannot be honored for it anyway. In
    # that case return the freshly built primal-shaped value, which
    # higher-level callers compare by value rather than by identity.
    _can_setindex(grad) || return unwrapped

    copyto!(grad, unwrapped)
    return grad
end

# Two-argument shorthand for the `foreach(_to_primal!, …)` call sites in the
# pullback / pushforward code. Those sites have no separate primal to pass
# along: the destination buffer `grad` already *is* the primal-shaped value.
#
# Plain differentials are copied straight into `grad`.
@inline _to_primal!(grad, new_grad) = (copyto!(grad, new_grad); grad)

# Struct tangents go through `_to_primal_into!`, with `grad` doubling as the
# primal.
@inline _to_primal!(grad::P, new_grad::_MooncakeStructTangent) where {P} =
    _to_primal_into!(grad, grad, new_grad)

# Whether `copyto!(grad, ...)` can update `grad`'s elements in place.
#
# The mutability of the array type itself is not the right thing to look at:
# `ComponentVector` is an immutable struct (`ismutabletype` returns false) but
# wraps a mutable `Vector`, so `copyto!` works on it; conversely, `SVector`
# wraps a `Tuple` and `copyto!` errors. What matters is the storage at the
# bottom of the wrapper chain, so we follow `parent` until it stops changing
# (`parent` returns its argument for arrays that wrap nothing) and check the
# mutability of that innermost type.
#
# Following the whole chain rather than a single `parent` step matters for
# nested wrappers: `reshape(view(v, 1:4), 2, 2)` is a `ReshapedArray` around a
# `SubArray` around a `Vector`. Its immediate parent (`SubArray`) is an
# immutable struct, yet the array is perfectly writable.
@inline function _can_setindex(grad::AbstractArray)
    inner = parent(grad)
    # Fixed point of `parent`: `grad` is its own storage, so judge it directly.
    inner === grad && return ismutabletype(typeof(grad))
    return _can_setindex(inner)
end
# Non-array buffers have no `parent` chain to follow; use the type's own
# mutability.
@inline _can_setindex(grad) = ismutabletype(typeof(grad))
1 change: 1 addition & 0 deletions DifferentiationInterface/test/Back/Mooncake/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Expand Down
50 changes: 50 additions & 0 deletions DifferentiationInterface/test/Back/Mooncake/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,56 @@ test_differentiation(
@test grad.B == ps.A
end

# Regression test for AutoMooncake gradient/gradient!/pullback/pushforward on
# a struct-backed AbstractArray (ComponentArray). Before, the Mooncake
# extension returned the differential as a `Mooncake.Tangent` and DI tried to
# `copyto!` it into the preallocated `ComponentVector` buffer downstream
# callers (e.g. OptimizationBase) pass in, raising a `MethodError` on
# `iterate(::Mooncake.Tangent)`. This blocked any Optimization.jl loop that
# used ComponentArrays parameters with `AutoMooncake`.
#
# The high-level scenario suite from DifferentiationInterfaceTest exercises
# the out-of-place and in-place versions of `gradient`, `pullback`, and
# `pushforward` for both `f(x)` and the `dy * f(x)` accumulation pattern,
# which together cover every code path the fix touches.
#
# `AutoMooncakeForward()` (without `friendly_tangents`) is excluded from this
# scenario because its forward-mode pushforward path has a separate,
# pre-existing bug at the *input* (Dual construction) side: it raises
# `ArgumentError: Tangent types do not match primal types` when given a
# `ComponentVector` `dx`, because Mooncake forward mode expects the tangent
# to already be a `Mooncake.Tangent` rather than a primal-shaped value.
# That input-side conversion is independent of the output-side fix in this
# PR; the friendly-tangents forward backend below covers the fixed code paths.
using ComponentArrays: ComponentArrays, ComponentVector
# Backends are selected by their position in `backends`; the trailing comments
# record which configuration each index is expected to hold.
component_backends = [
backends[1], # AutoMooncake() — reverse, the path OptimizationBase uses
backends[3], # AutoMooncake(friendly_tangents=true) — reverse + friendly
backends[4], # AutoMooncakeForward(friendly_tangents=true) — forward + friendly
]
# Run `component_scenarios()` against those backends, leaving out whatever
# `SECOND_ORDER` lists.
test_differentiation(
component_backends,
component_scenarios();
excluded = SECOND_ORDER,
logging = LOGGING,
)

# Direct gradient! sanity check on a small ComponentVector — this is the
# specific call shape OptimizationBase uses, kept as an explicit assertion in
# case `component_scenarios()` is ever pared down.
@testset "ComponentArrays gradient! into preallocated buffer" begin
    # Objective is a^2 + sum(b .^ 2), so the expected gradient is 2a and 2b.
    params = ComponentVector(a = 1.0, b = [2.0, 3.0])
    function sumsq(p)
        return p.a^2 + sum(p.b .^ 2)
    end
    for backend in component_backends
        # Fresh zeroed buffer with the same layout as `params` for each backend.
        buffer = similar(params)
        fill!(buffer, 0)
        gradient!(sumsq, buffer, backend, params)
        # The buffer must keep its type and hold the analytic gradient.
        @test buffer isa ComponentVector
        @test buffer.a ≈ 2 * params.a
        @test buffer.b ≈ 2 .* params.b
    end
end

test_differentiation(
backends[3:4],
nomatrix(static_scenarios());
Expand Down
Loading