Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function DI.value_and_pushforward(
map(first_unwrap, contexts, prep.context_tangents)...,
)
y = first(y_and_dy)
dy = _copy_output(last(y_and_dy))
dy = _maybe_to_primal(last(y_and_dy), y)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to ensure that primal conversion happens here? If friendly_tangents is set to true, won't Mooncake's pushforward and pullback already return a primal-like object?

return y, dy
end
y = _copy_output(first(ys_and_ty[1]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function DI.value_and_pushforward(
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
return _copy_output(new_dy)
return _maybe_to_primal(new_dy, y)
end
return y, ty
end
Expand Down Expand Up @@ -93,7 +93,7 @@ function DI.value_and_pushforward!(
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
copyto!(dy, new_dy)
copyto!(dy, _maybe_to_primal(new_dy, y))
end
return y, ty
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function DI.value_and_pullback(
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
return new_y, (_copy_output(new_dx),)
return new_y, (_maybe_to_primal(new_dx, x),)
end

function DI.value_and_pullback(
Expand All @@ -51,7 +51,7 @@ function DI.value_and_pullback(
y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
y, _copy_output(new_dx)
y, _maybe_to_primal(new_dx, x)
end
y = first(ys_and_tx[1])
tx = map(last, ys_and_tx)
Expand Down Expand Up @@ -134,7 +134,7 @@ function DI.value_and_gradient(
prep.cache, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
return y, _copy_output(new_grad)
return y, _maybe_to_primal(new_grad, x)
end

function DI.value_and_gradient!(
Expand All @@ -150,7 +150,7 @@ function DI.value_and_gradient!(
prep.cache, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
copyto!(grad, new_grad)
copyto!(grad, _maybe_to_primal(new_grad, x))
return y, grad
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function DI.value_and_pullback(
prep.args_to_zero
)
copyto!(y, y_after)
return y, (_copy_output(dx),)
return y, (_maybe_to_primal(dx, x),)
end

function DI.value_and_pullback(
Expand All @@ -90,7 +90,7 @@ function DI.value_and_pullback(
prep.args_to_zero
)
copyto!(y, y_after)
_copy_output(dx)
_maybe_to_primal(dx, x)
end
return y, tx
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,41 @@ end

function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
if get_config(backend).friendly_tangents
# zero(x) but safer
return tangent_to_primal!!(_copy_output(x), zero_tangent(x))
# Mooncake 0.5.25+ replaced `tangent_to_primal!!` with the
# `tangent_to_friendly!!` framework. For this internal backup we still
# need a primal-shaped value, so use the `AsPrimal` path when
# available and fall back for older Mooncake releases.
return tangent_to_user_primal(zero_tangent(x), x)
else
return zero_tangent(x)
end
end

# Safety net: if Mooncake returns a raw Tangent (e.g. Julia 1.11 + StaticArrays),
# convert it to a primal-shaped value. No-op for already-converted results.
_maybe_to_primal(tx, x) = _copy_output(tx)
_maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x)
_maybe_to_primal(tx::Mooncake.MutableTangent, x) = tangent_to_user_primal(tx, x)

@inline maybe_getfield(mod, name::Symbol) =
isdefined(mod, name) ? getfield(mod, name) : nothing

const mooncake_tangent_to_friendly = maybe_getfield(
Mooncake, Symbol("tangent_to_friendly!!")
)
const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache)
const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal)
const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache)
Comment on lines +30 to +38
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem to be very robust? I'd rather impose a lower bound for Mooncake at v0.5.25 in Project.toml (that way we're sure we can use all of these symbols)


function tangent_to_user_primal(tx, x)
if !isnothing(mooncake_tangent_to_friendly) &&
!isnothing(mooncake_friendly_tangent_cache) &&
!isnothing(mooncake_as_primal) &&
!isnothing(mooncake_no_cache)
dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x))
cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is a type-unstable dictionary needed here?
Does this make every tangent-to-primal conversion outside of bitstypes slow?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regardless of why, we may want to allocate this dictionary in the preparation phase

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup this will regress performance on hot paths. This should be allocated once during the prepare phase and stored in the extras cache.

return mooncake_tangent_to_friendly(dest, x, tx, cache)
else
return tangent_to_primal!!(_copy_output(x), tx)
end
end
34 changes: 33 additions & 1 deletion DifferentiationInterface/test/Back/Mooncake/test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
include("../../testutils.jl")

using DifferentiationInterface, DifferentiationInterfaceTest
using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric
using Mooncake: Mooncake
using Test

Expand Down Expand Up @@ -78,5 +79,36 @@ test_differentiation(
backends[3:4],
nomatrix(static_scenarios());
logging = LOGGING,
excluded = SECOND_ORDER
excluded = SECOND_ORDER,
)

@testset "Friendly tangents structured matrices" begin
backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))
inputs = (
Symmetric([2.0 1.0; 1.0 3.0]),
Hermitian(ComplexF64[2 1 + im; 1 - im 3]),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know which convention Mooncake uses for gradients of functions with complex inputs and real outputs? There are two possible choices, see e.g. https://arxiv.org/abs/2409.06752

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]),
)
f(x) = real(sum(abs2, x))

@testset "$(typeof(x))" for x in inputs
grad = gradient(f, backend, x)
y, grad2 = value_and_gradient(f, backend, x)
pb = only(pullback(identity, backend, x, (x,)))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a strong enough test, the function is too simple


@test grad isa Matrix
@test grad2 isa Matrix
@test pb isa Matrix
@test grad == grad2
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grad and grad2 are never compared against the ground truth

@test y == f(x)
@test pb == Matrix(x)

grad_dense = zero(Matrix(x))
@test gradient!(f, grad_dense, backend, x) === grad_dense
@test grad_dense == grad

tx_dense = (zero(Matrix(x)),)
@test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1]
@test tx_dense[1] == pb
end
end
Loading