Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #988 +/- ##
==========================================
- Coverage 98.21% 97.28% -0.94%
==========================================
Files 135 131 -4
Lines 8000 7984 -16
==========================================
- Hits 7857 7767 -90
- Misses 143 217 +74
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
I think I understand the CI error, there is something I need to patch on the Mooncake side, will come back to this |
|
chalk-lab/Mooncake.jl#1129 should unblock this PR, we'll release it as soon as it's available |
|
Thank you for taking a crack at this! I'll wait until the tests pass before reviewing if that's okay |
|
totally fine! |
Mooncake returns raw Tangent objects instead of friendly arrays for StaticArrays on Julia 1.11. This is an upstream bug — skip the test until it is fixed.
On Julia 1.11, Mooncake may return raw Tangent objects instead of friendly arrays for StaticArrays even with friendly_tangents=true. Add _maybe_to_primal dispatch as a safety net that converts leaked Tangent objects to primal-shaped values, no-op otherwise.
Also convert leaked Mooncake.MutableTangent (e.g. MVector tangents) and apply _maybe_to_primal in forward mode (pushforward) paths.
|
@gdalle the Mooncake CIs are passing (it probably requires Mooncake v0.5.26). Could you take over? I also won't be offended if you want to start a new PR. |
|
I'll take a look when I can, thanks a bunch! DI's tests are failing on main too because of Mooncake's breaking release so this is a priority for me. Do you know why coverage is not complete? |
|
Thanks a lot.
I am not certain. A bad guess is that some code changes are more defensive than necessary. Sorry! |
gdalle
left a comment
There was a problem hiding this comment.
Thank you for trying to fix what others broke! I added a few remarks to understand the task a bit better, I'll wait for your answers
| @inline maybe_getfield(mod, name::Symbol) = | ||
| isdefined(mod, name) ? getfield(mod, name) : nothing | ||
|
|
||
| const mooncake_tangent_to_friendly = maybe_getfield( | ||
| Mooncake, Symbol("tangent_to_friendly!!") | ||
| ) | ||
| const mooncake_friendly_tangent_cache = maybe_getfield(Mooncake, :FriendlyTangentCache) | ||
| const mooncake_as_primal = maybe_getfield(Mooncake, :AsPrimal) | ||
| const mooncake_no_cache = maybe_getfield(Mooncake, :NoCache) |
There was a problem hiding this comment.
This doesn't seem to be very robust? I'd rather impose a lower bound for Mooncake at v0.5.25 in Project.toml (that way we're sure we can use all of these symbols)
| ) | ||
| y = first(y_and_dy) | ||
| dy = _copy_output(last(y_and_dy)) | ||
| dy = _maybe_to_primal(last(y_and_dy), y) |
There was a problem hiding this comment.
Why do we need to ensure that primal conversion happens here? If friendly_tangents is set to true, won't Mooncake's pushforward and pullback already return a primal-like object?
| backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)) | ||
| inputs = ( | ||
| Symmetric([2.0 1.0; 1.0 3.0]), | ||
| Hermitian(ComplexF64[2 1 + im; 1 - im 3]), |
There was a problem hiding this comment.
Do you know which convention Mooncake uses for gradients of functions with complex inputs and real outputs? There are two possible choices, see e.g. https://arxiv.org/abs/2409.06752
| @testset "$(typeof(x))" for x in inputs | ||
| grad = gradient(f, backend, x) | ||
| y, grad2 = value_and_gradient(f, backend, x) | ||
| pb = only(pullback(identity, backend, x, (x,))) |
There was a problem hiding this comment.
This is not a strong enough test, the function is too simple
| @test grad isa Matrix | ||
| @test grad2 isa Matrix | ||
| @test pb isa Matrix | ||
| @test grad == grad2 |
There was a problem hiding this comment.
grad and grad2 are never compared against the ground truth
| !isnothing(mooncake_as_primal) && | ||
| !isnothing(mooncake_no_cache) | ||
| dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) | ||
| cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}() |
There was a problem hiding this comment.
Why is a type-unstable dictionary needed here?
Does this make every tangent-to-primal conversion outside of bitstypes slow?
There was a problem hiding this comment.
Regardless of why, we may want to allocate this dictionary in the preparation phase
There was a problem hiding this comment.
yup this will regress performance on hot paths. This should be allocated once during the prepare phase and stored in the extras cache.
|
@AstitvaAggarwal @Technici4n could you maybe take a look too? |
AstitvaAggarwal
left a comment
There was a problem hiding this comment.
also we might want to keep track of future possible tangent_types: _maybe_to_primal(x, _) = x will silently pass through any tangent type not yet accounted for (e.g. a future Mooncake.SparseTangent), making failures invisible.
| !isnothing(mooncake_as_primal) && | ||
| !isnothing(mooncake_no_cache) | ||
| dest = mooncake_friendly_tangent_cache{mooncake_as_primal}(_copy_output(x)) | ||
| cache = isbitstype(typeof(x)) ? mooncake_no_cache() : IdDict{Any, Any}() |
There was a problem hiding this comment.
yup this will regress performance on hot paths. This should be allocated once during the prepare phase and stored in the extras cache.
An attempt at addressing #986.
Feel free to make any edits or take over!