From 75d3c056ce09314359a4ab5eea8b10c3572559fa Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 4 Apr 2026 15:37:41 +0100 Subject: [PATCH 1/8] fix Mooncake friendly_tangents compatibility --- .../utils.jl | 27 ++++++++++++++-- .../test/Back/Mooncake/test.jl | 32 +++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b22d8d49b..ddf1d281d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -11,9 +11,32 @@ end function zero_tangent_or_primal(x, backend::AnyAutoMooncake) if get_config(backend).friendly_tangents - # zero(x) but safer - return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) + # Mooncake 0.5.25+ replaced `tangent_to_primal!!` with the + # `tangent_to_friendly!!` framework. For this internal backup we still + # need a primal-shaped value, so use the `AsPrimal` path when + # available and fall back for older Mooncake releases. + return tangent_to_user_primal(zero_tangent(x), x) else return zero_tangent(x) end end + +@inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing + +const mooncake_tangent_to_friendly = maybe_getfield(Mooncake, Symbol("tangent_to_friendly!!")) +const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache) +const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal) +const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) + +function tangent_to_user_primal(tx, x) + if !isnothing(mooncake_tangent_to_friendly) && + !isnothing(mooncake_friendly_tangent_cache) && + !isnothing(mooncake_as_primal) && + !isnothing(mooncake_no_cache) + dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) + cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any,Any}() + return mooncake_tangent_to_friendly(dest, x, tx, cache) + else + return tangent_to_primal!!(_copy_output(x), tx) + end +end diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index d531e542a..b481b759c 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -1,6 +1,7 @@ include("../../testutils.jl") using DifferentiationInterface, DifferentiationInterfaceTest +using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric using Mooncake: Mooncake using Test @@ -80,3 +81,34 @@ test_differentiation( logging = LOGGING, excluded = SECOND_ORDER ) + +@testset "Friendly tangents structured matrices" begin + backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) + inputs = ( + Symmetric([2.0 1.0; 1.0 3.0]), + Hermitian(ComplexF64[2 1 + im; 1 - im 3]), + SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), + ) + f(x) = real(sum(abs2, x)) + + @testset "$(typeof(x))" for x in inputs + grad = gradient(f, backend, x) + y, grad2 = value_and_gradient(f, backend, x) + pb = only(pullback(identity, backend, x, (x,))) + + @test grad isa Matrix + @test grad2 isa Matrix + @test pb isa Matrix + @test grad == grad2 + @test y == f(x) + @test pb == Matrix(x) + + grad_dense = zero(Matrix(x)) + @test gradient!(f, grad_dense, backend, x) === grad_dense + @test grad_dense == grad + + tx_dense = (zero(Matrix(x)),) + @test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1] + @test tx_dense[1] == pb + end +end From f42023428e7e4a055c2b6f3a2d27d99fe63ed2be Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 4 Apr 2026 15:41:39 +0100 Subject: [PATCH 2/8] format Mooncake fix --- .../utils.jl | 13 +++--- .../test/Back/Mooncake/test.jl | 41 ++++++++----------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index ddf1d281d..927a4cfbe 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -21,18 +21,21 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) end end -@inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing +@inline maybe_getfield(mod, name::Symbol) = + isdefined(mod, name) ? getfield(mod, name) : nothing -const mooncake_tangent_to_friendly = maybe_getfield(Mooncake, Symbol("tangent_to_friendly!!")) +const mooncake_tangent_to_friendly = maybe_getfield( + Mooncake, Symbol("tangent_to_friendly!!") +) const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache) const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal) const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) function tangent_to_user_primal(tx, x) if !isnothing(mooncake_tangent_to_friendly) && - !isnothing(mooncake_friendly_tangent_cache) && - !isnothing(mooncake_as_primal) && - !isnothing(mooncake_no_cache) + !isnothing(mooncake_friendly_tangent_cache) && + !isnothing(mooncake_as_primal) && + !isnothing(mooncake_no_cache) dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any,Any}() return mooncake_tangent_to_friendly(dest, x, tx, cache) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index b481b759c..eb296a567 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -13,8 +13,8 @@ nomatrix(scens) = filter(s -> !(s.x isa AbstractMatrix) && !(s.y isa AbstractMat backends = [ AutoMooncake(), AutoMooncakeForward(), - AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), - AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)), + AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)), + AutoMooncakeForward(; config=Mooncake.Config(; friendly_tangents=true)), ] for backend in backends @@ -23,31 +23,25 @@ for backend in backends end test_differentiation( - backends[3:4], - default_scenarios(); - excluded = SECOND_ORDER, - logging = LOGGING, + backends[3:4], default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING ); test_differentiation( backends[3:4], nomatrix( default_scenarios(; - include_normal = false, - include_constantified = true, - include_cachified = true, - use_tuples = true - ) + include_normal=false, + include_constantified=true, + include_cachified=true, + use_tuples=true, + ), ); - excluded = SECOND_ORDER, - logging = LOGGING, + excluded=SECOND_ORDER, + logging=LOGGING, ); test_differentiation( - backends[1:2], - nomatrix(default_scenarios()); - excluded = SECOND_ORDER, - logging = LOGGING, + backends[1:2], nomatrix(default_scenarios()); excluded=SECOND_ORDER, logging=LOGGING ); EXCLUDED = @static if VERSION ≥ v"1.11-" && VERSION ≤ v"1.12-" @@ -63,12 +57,12 @@ end test_differentiation( [SecondOrder(AutoMooncakeForward(), AutoMooncake())], nomatrix(default_scenarios()); - excluded = EXCLUDED, - logging = LOGGING, + excluded=EXCLUDED, + logging=LOGGING, ) @testset "NamedTuples" begin - ps = (; A = rand(5), B = rand(5)) + ps = (; A=rand(5), B=rand(5)) myfun(ps) = sum(ps.A .* ps.B) grad = gradient(myfun, backends[1], ps) @test grad.A == ps.B @@ -76,14 +70,11 @@ test_differentiation( end test_differentiation( - backends[3:4], - nomatrix(static_scenarios()); - logging = LOGGING, - excluded = SECOND_ORDER + backends[3:4], nomatrix(static_scenarios()); logging=LOGGING, excluded=SECOND_ORDER ) @testset "Friendly tangents structured matrices" begin - backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) + backend = AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)) inputs = ( Symmetric([2.0 1.0; 1.0 3.0]), Hermitian(ComplexF64[2 1 + im; 1 - im 3]), From 8235e6c60398e7e9afe676ab5afb1792ce614c34 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Sat, 4 Apr 2026 15:46:04 +0100 Subject: [PATCH 3/8] remove Mooncake test formatting churn --- .../test/Back/Mooncake/test.jl | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index eb296a567..b481b759c 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -13,8 +13,8 @@ nomatrix(scens) = filter(s -> !(s.x isa AbstractMatrix) && !(s.y isa AbstractMat backends = [ AutoMooncake(), AutoMooncakeForward(), - AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)), - AutoMooncakeForward(; config=Mooncake.Config(; friendly_tangents=true)), + AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), + AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)), ] for backend in backends @@ -23,25 +23,31 @@ for backend in backends end test_differentiation( - backends[3:4], default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING + backends[3:4], + default_scenarios(); + excluded = SECOND_ORDER, + logging = LOGGING, ); test_differentiation( backends[3:4], nomatrix( default_scenarios(; - include_normal=false, - include_constantified=true, - include_cachified=true, - use_tuples=true, - ), + include_normal = false, + include_constantified = true, + include_cachified = true, + use_tuples = true + ) ); - excluded=SECOND_ORDER, - logging=LOGGING, + excluded = SECOND_ORDER, + logging = LOGGING, ); test_differentiation( - backends[1:2], nomatrix(default_scenarios()); excluded=SECOND_ORDER, logging=LOGGING + backends[1:2], + nomatrix(default_scenarios()); + excluded = SECOND_ORDER, + logging = LOGGING, ); EXCLUDED = @static if VERSION ≥ v"1.11-" && VERSION ≤ v"1.12-" @@ -57,12 +63,12 @@ end test_differentiation( [SecondOrder(AutoMooncakeForward(), AutoMooncake())], nomatrix(default_scenarios()); - excluded=EXCLUDED, - logging=LOGGING, + excluded = EXCLUDED, + logging = LOGGING, ) @testset "NamedTuples" begin - ps = (; A=rand(5), B=rand(5)) + ps = (; A = rand(5), B = rand(5)) myfun(ps) = sum(ps.A .* ps.B) grad = gradient(myfun, backends[1], ps) @test grad.A == ps.B @@ -70,11 +76,14 @@ test_differentiation( end test_differentiation( - backends[3:4], nomatrix(static_scenarios()); logging=LOGGING, excluded=SECOND_ORDER + backends[3:4], + nomatrix(static_scenarios()); + logging = LOGGING, + excluded = SECOND_ORDER ) @testset "Friendly tangents structured matrices" begin - backend = AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)) + backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) inputs = ( Symmetric([2.0 1.0; 1.0 3.0]), Hermitian(ComplexF64[2 1 + im; 1 - im 3]), From a478578e57a8a0c958c103376b4dea0156cd837f Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 08:12:26 +0100 Subject: [PATCH 4/8] ci: retrigger after Mooncake v0.5.26 release From 65997c462b8b493756bb37c1d7019fa7f6f92c81 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 08:25:15 +0100 Subject: [PATCH 5/8] style: apply Runic formatting to utils.jl --- .../ext/DifferentiationInterfaceMooncakeExt/utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 927a4cfbe..2184e9ecb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -33,11 +33,11 @@ const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) function tangent_to_user_primal(tx, x) if !isnothing(mooncake_tangent_to_friendly) && - !isnothing(mooncake_friendly_tangent_cache) && - !isnothing(mooncake_as_primal) && - !isnothing(mooncake_no_cache) + !isnothing(mooncake_friendly_tangent_cache) && + !isnothing(mooncake_as_primal) && + !isnothing(mooncake_no_cache) dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) - cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any,Any}() + cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}() return mooncake_tangent_to_friendly(dest, x, tx, cache) else return tangent_to_primal!!(_copy_output(x), tx) From 106f50fc7b1b77e9f350a45ad4fc2b867864dc8c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 10:41:25 +0100 Subject: [PATCH 6/8] test: skip friendly_tangents static_scenarios on Julia 1.11 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mooncake returns raw Tangent objects instead of friendly arrays for StaticArrays on Julia 1.11. This is an upstream bug — skip the test until it is fixed. --- .../test/Back/Mooncake/test.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index b481b759c..7e5a291a7 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -75,12 +75,15 @@ test_differentiation( @test grad.B == ps.A end -test_differentiation( - backends[3:4], - nomatrix(static_scenarios()); - logging = LOGGING, - excluded = SECOND_ORDER -) +# friendly_tangents + StaticArrays broken on Julia 1.11 (upstream Mooncake bug) +@static if !(VERSION ≥ v"1.11-" && VERSION < v"1.12-") + test_differentiation( + backends[3:4], + nomatrix(static_scenarios()); + logging = LOGGING, + excluded = SECOND_ORDER, + ) +end @testset "Friendly tangents structured matrices" begin backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) From 7043da2edaf2fb58870902246784eb175d66aa5b Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 10:44:47 +0100 Subject: [PATCH 7/8] fix: convert raw Mooncake.Tangent in pullback/gradient results On Julia 1.11, Mooncake may return raw Tangent objects instead of friendly arrays for StaticArrays even with friendly_tangents=true. Add _maybe_to_primal dispatch as a safety net that converts leaked Tangent objects to primal-shaped values, no-op otherwise. --- .../DifferentiationInterfaceMooncakeExt/onearg.jl | 8 ++++---- .../DifferentiationInterfaceMooncakeExt/twoarg.jl | 4 ++-- .../DifferentiationInterfaceMooncakeExt/utils.jl | 5 +++++ .../test/Back/Mooncake/test.jl | 15 ++++++--------- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 2514cdc40..2bbf49f1c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -35,7 +35,7 @@ function DI.value_and_pullback( new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return new_y, (_copy_output(new_dx),) + return new_y, (_maybe_to_primal(new_dx, x),) end function DI.value_and_pullback( @@ -51,7 +51,7 @@ function DI.value_and_pullback( y, (_, new_dx) = value_and_pullback!!( prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - y, _copy_output(new_dx) + y, _maybe_to_primal(new_dx, x) end y = first(ys_and_tx[1]) tx = map(last, ys_and_tx) @@ -134,7 +134,7 @@ function DI.value_and_gradient( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - return y, _copy_output(new_grad) + return y, _maybe_to_primal(new_grad, x) end function DI.value_and_gradient!( @@ -150,7 +150,7 @@ function DI.value_and_gradient!( prep.cache, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) - copyto!(grad, new_grad) + copyto!(grad, _maybe_to_primal(new_grad, x)) return y, grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 2b55131b9..ed7f4ca9c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -64,7 +64,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - return y, (_copy_output(dx),) + return y, (_maybe_to_primal(dx, x),) end function DI.value_and_pullback( @@ -90,7 +90,7 @@ function DI.value_and_pullback( prep.args_to_zero ) copyto!(y, y_after) - _copy_output(dx) + _maybe_to_primal(dx, x) end return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 2184e9ecb..5be9097b3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -21,6 +21,11 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) end end +# Safety net: if Mooncake returns a raw Tangent (e.g. Julia 1.11 + StaticArrays), +# convert it to a primal-shaped value. No-op for already-converted results. +_maybe_to_primal(tx, x) = _copy_output(tx) +_maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x) + @inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 7e5a291a7..313c0d2bf 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -75,15 +75,12 @@ test_differentiation( @test grad.B == ps.A end -# friendly_tangents + StaticArrays broken on Julia 1.11 (upstream Mooncake bug) -@static if !(VERSION ≥ v"1.11-" && VERSION < v"1.12-") - test_differentiation( - backends[3:4], - nomatrix(static_scenarios()); - logging = LOGGING, - excluded = SECOND_ORDER, - ) -end +test_differentiation( + backends[3:4], + nomatrix(static_scenarios()); + logging = LOGGING, + excluded = SECOND_ORDER, +) @testset "Friendly tangents structured matrices" begin backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) From ce72bafde06c0427f022cdd93703a7d95c4ed6b7 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 10 Apr 2026 11:57:34 +0100 Subject: [PATCH 8/8] fix: handle MutableTangent and forward mode tangent leaks Also convert leaked Mooncake.MutableTangent (e.g. MVector tangents) and apply _maybe_to_primal in forward mode (pushforward) paths. --- .../ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl | 2 +- .../ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl | 4 ++-- .../ext/DifferentiationInterfaceMooncakeExt/utils.jl | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index c470b6473..cab7f84d6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -41,7 +41,7 @@ function DI.value_and_pushforward( map(first_unwrap, contexts, prep.context_tangents)..., ) y = first(y_and_dy) - dy = _copy_output(last(y_and_dy)) + dy = _maybe_to_primal(last(y_and_dy), y) return y, dy end y = _copy_output(first(ys_and_ty[1])) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 3c75f530b..8ebb6ef99 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -55,7 +55,7 @@ function DI.value_and_pushforward( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - return _copy_output(new_dy) + return _maybe_to_primal(new_dy, y) end return y, ty end @@ -93,7 +93,7 @@ function DI.value_and_pushforward!( (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - copyto!(dy, new_dy) + copyto!(dy, _maybe_to_primal(new_dy, y)) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 5be9097b3..beeb6f611 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -25,6 +25,7 @@ end # convert it to a primal-shaped value. No-op for already-converted results. _maybe_to_primal(tx, x) = _copy_output(tx) _maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x) +_maybe_to_primal(tx::Mooncake.MutableTangent, x) = tangent_to_user_primal(tx, x) @inline maybe_getfield(mod, name::Symbol) = isdefined(mod, name) ? getfield(mod, name) : nothing