-
Notifications
You must be signed in to change notification settings - Fork 32
Mooncake 0.5.25 compat #988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
75d3c05
f420234
8235e6c
a478578
65997c4
106f50f
7043da2
ce72baf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,9 +11,41 @@ end | |
|
|
||
| function zero_tangent_or_primal(x, backend::AnyAutoMooncake) | ||
| if get_config(backend).friendly_tangents | ||
| # zero(x) but safer | ||
| return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) | ||
| # Mooncake 0.5.25+ replaced `tangent_to_primal!!` with the | ||
| # `tangent_to_friendly!!` framework. For this internal backup we still | ||
| # need a primal-shaped value, so use the `AsPrimal` path when | ||
| # available and fall back for older Mooncake releases. | ||
| return tangent_to_user_primal(zero_tangent(x), x) | ||
| else | ||
| return zero_tangent(x) | ||
| end | ||
| end | ||
|
|
||
| # Safety net: if Mooncake returns a raw Tangent (e.g. Julia 1.11 + StaticArrays), | ||
| # convert it to a primal-shaped value. No-op for already-converted results. | ||
| _maybe_to_primal(tx, x) = _copy_output(tx) | ||
| _maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x) | ||
| _maybe_to_primal(tx::Mooncake.MutableTangent, x) = tangent_to_user_primal(tx, x) | ||
|
|
||
| @inline maybe_getfield(mod, name::Symbol) = | ||
| isdefined(mod, name) ? getfield(mod, name) : nothing | ||
|
|
||
| const mooncake_tangent_to_friendly = maybe_getfield( | ||
| Mooncake, Symbol("tangent_to_friendly!!") | ||
| ) | ||
| const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache) | ||
| const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal) | ||
| const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) | ||
|
Comment on lines
+30
to
+38
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This doesn't seem to be very robust? I'd rather impose a lower bound for Mooncake at v0.5.25 in |
||
|
|
||
| function tangent_to_user_primal(tx, x) | ||
| if !isnothing(mooncake_tangent_to_friendly) && | ||
| !isnothing(mooncake_friendly_tangent_cache) && | ||
| !isnothing(mooncake_as_primal) && | ||
| !isnothing(mooncake_no_cache) | ||
| dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) | ||
| cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is a type-unstable dictionary needed here?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Regardless of why, we may want to allocate this dictionary in the preparation phase.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yup, this will regress performance on hot paths. This should be allocated once during the prepare phase and stored in the extras cache. |
||
| return mooncake_tangent_to_friendly(dest, x, tx, cache) | ||
| else | ||
| return tangent_to_primal!!(_copy_output(x), tx) | ||
| end | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| include("../../testutils.jl") | ||
|
|
||
| using DifferentiationInterface, DifferentiationInterfaceTest | ||
| using LinearAlgebra: Hermitian, SymTridiagonal, Symmetric | ||
| using Mooncake: Mooncake | ||
| using Test | ||
|
|
||
|
|
@@ -78,5 +79,36 @@ test_differentiation( | |
| backends[3:4], | ||
| nomatrix(static_scenarios()); | ||
| logging = LOGGING, | ||
| excluded = SECOND_ORDER | ||
| excluded = SECOND_ORDER, | ||
| ) | ||
|
|
||
| @testset "Friendly tangents structured matrices" begin | ||
| backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) | ||
| inputs = ( | ||
| Symmetric([2.0 1.0; 1.0 3.0]), | ||
| Hermitian(ComplexF64[2 1 + im; 1 - im 3]), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you know which convention Mooncake uses for gradients of functions with complex inputs and real outputs? There are two possible choices; see e.g. https://arxiv.org/abs/2409.06752 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| SymTridiagonal([2.0, 3.0, 4.0], [5.0, 6.0]), | ||
| ) | ||
| f(x) = real(sum(abs2, x)) | ||
|
|
||
| @testset "$(typeof(x))" for x in inputs | ||
| grad = gradient(f, backend, x) | ||
| y, grad2 = value_and_gradient(f, backend, x) | ||
| pb = only(pullback(identity, backend, x, (x,))) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not a strong enough test; the function is too simple. |
||
|
|
||
| @test grad isa Matrix | ||
| @test grad2 isa Matrix | ||
| @test pb isa Matrix | ||
| @test grad == grad2 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| @test y == f(x) | ||
| @test pb == Matrix(x) | ||
|
|
||
| grad_dense = zero(Matrix(x)) | ||
| @test gradient!(f, grad_dense, backend, x) === grad_dense | ||
| @test grad_dense == grad | ||
|
|
||
| tx_dense = (zero(Matrix(x)),) | ||
| @test only(pullback!(identity, tx_dense, backend, x, (x,))) === tx_dense[1] | ||
| @test tx_dense[1] == pb | ||
| end | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to ensure that primal conversion happens here? If
`friendly_tangents` is set to `true`, won't Mooncake's pushforward and pullback already return a primal-like object?