diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index c470b6473..cab7f84d6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -41,7 +41,7 @@ function DI.value_and_pushforward( map(first_unwrap, contexts, prep.context_tangents)..., ) y = first(y_and_dy) - dy = _copy_output(last(y_and_dy)) + dy = _maybe_to_primal(last(y_and_dy), y) return y, dy end y = _copy_output(first(ys_and_ty[1])) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 3c75f530b..8ebb6ef99 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -55,7 +55,7 @@ function DI.value_and_pushforward( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - return _copy_output(new_dy) + return _maybe_to_primal(new_dy, y) end return y, ty end @@ -93,7 +93,7 @@ function DI.value_and_pushforward!( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - copyto!(dy, new_dy) + copyto!(dy, _maybe_to_primal(new_dy, y)) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 2514cdc40..2bbf49f1c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -35,7 +35,7 @@ function DI.value_and_pullback( new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return new_y, (_copy_output(new_dx),) + return new_y, (_maybe_to_primal(new_dx, x),) end function DI.value_and_pullback( @@ -51,7 +51,7 @@ function DI.value_and_pullback( y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - y, _copy_output(new_dx) + y, _maybe_to_primal(new_dx, x) end y = first(ys_and_tx[1]) tx = map(last, ys_and_tx) @@ -134,7 +134,7 @@ function DI.value_and_gradient( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return y, _copy_output(new_grad) + return y, _maybe_to_primal(new_grad, x) end function DI.value_and_gradient!( @@ -150,7 +150,7 @@ function DI.value_and_gradient!( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - copyto!(grad, new_grad) + copyto!(grad, _maybe_to_primal(new_grad, x)) return y, grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 2b55131b9..ed7f4ca9c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -64,7 +64,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - return y, (_copy_output(dx),) + return y, (_maybe_to_primal(dx, x),) end function DI.value_and_pullback( @@ -90,7 +90,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - _copy_output(dx) + _maybe_to_primal(dx, x) end return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b22d8d49b..beeb6f611 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -11,9 +11,41 @@ end function zero_tangent_or_primal(x, backend::AnyAutoMooncake) if get_config(backend).friendly_tangents - # zero(x) but safer - return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) + # Mooncake 0.5.25+ replaced `tangent_to_primal!!` with the + # `tangent_to_friendly!!` framework. For this internal backup we still + # need a primal-shaped value, so use the `AsPrimal` path when + # available and fall back for older Mooncake releases. + return tangent_to_user_primal(zero_tangent(x), x) else return zero_tangent(x) end end + +# Safety net: if Mooncake returns a raw Tangent (e.g. Julia 1.11 + StaticArrays), +# convert it to a primal-shaped value. No-op for already-converted results. +_maybe_to_primal(tx, x) = _copy_output(tx) +_maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x) +_maybe_to_primal(tx::Mooncake.MutableTangent, x) = tangent_to_user_primal(tx, x) + +@inline maybe_getfield(mod, name::Symbol) = + isdefined(mod, name) ? getfield(mod, name) : nothing + +const mooncake_tangent_to_friendly = maybe_getfield( + Mooncake, Symbol("tangent_to_friendly!!") +) +const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache) +const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal) +const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) + +function tangent_to_user_primal(tx, x) + if !isnothing(mooncake_tangent_to_friendly) && + !isnothing(mooncake_friendly_tangent_cache) && + !isnothing(mooncake_as_primal) && + !isnothing(mooncake_no_cache) + dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) + cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}() + return mooncake_tangent_to_friendly(dest, x, tx, cache) + else + return tangent_to_primal!!(_copy_output(x), tx) + end +end diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index d531e542a..313c0d2bf 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -1,6 +1,7 @@ include("../../testutils.jl") using DifferentiationInterface, DifferentiationInterfaceTest +using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric using Mooncake: Mooncake using Test @@ -78,5 +79,36 @@ test_differentiation( backends[3:4], nomatrix(static_scenarios()); logging = LOGGING, - excluded = SECOND_ORDER + excluded = SECOND_ORDER, ) + +@testset "Friendly tangents structured matrices" begin + backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) + inputs = ( + Symmetric([2.0 1.0; 1.0 3.0]), + Hermitian(ComplexF64[2 1 + im; 1 - im 3]), + SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), + ) + f(x) = real(sum(abs2, x)) + + @testset "$(typeof(x))" for x in inputs + grad = gradient(f, backend, x) + y, grad2 = value_and_gradient(f, backend, x) + pb = only(pullback(identity, backend, x, (x,))) + + @test grad isa Matrix + @test grad2 isa Matrix + @test pb isa Matrix + @test grad == grad2 + @test y == f(x) + @test pb == Matrix(x) + + grad_dense = zero(Matrix(x)) + @test gradient!(f, grad_dense, backend, x) === grad_dense + @test grad_dense == grad + + tx_dense = (zero(Matrix(x)),) + @test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1] + @test tx_dense[1] == pb + end +end