-
Notifications
You must be signed in to change notification settings - Fork 155
Add Mooncake.jl extension #1343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,3 +7,4 @@ | |
| /benchmark/*.json | ||
| /benchmark/Manifest.toml | ||
| /docs/Manifest.toml | ||
| /test/ext/**/Manifest.toml | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| module StaticArraysMooncakeExt | ||
|
|
||
| using Mooncake | ||
| using Random: AbstractRNG | ||
| using Base: IEEEFloat | ||
| using StaticArrays: StaticArrays, SArray | ||
|
|
||
| using Mooncake: @foldable | ||
|
|
||
| import Mooncake: | ||
| MaybeCache, | ||
| IncCache, | ||
| SetToZeroCache, | ||
| NoFData, | ||
| NoRData, | ||
| CoDual, | ||
| Dual, | ||
| MinimalCtx, | ||
| primal, | ||
| tangent, | ||
| zero_fcodual, | ||
| tangent_type, | ||
| fdata_type, | ||
| rdata_type, | ||
| zero_tangent_internal, | ||
| randn_tangent_internal, | ||
| set_to_zero_internal!!, | ||
| increment_internal!!, | ||
| _add_to_primal_internal, | ||
| tangent_to_primal_internal!!, | ||
| primal_to_tangent_internal!!, | ||
| _dot_internal, | ||
| _scale_internal, | ||
| zero_rdata, | ||
| zero_rdata_from_type, | ||
| can_produce_zero_rdata_from_type, | ||
| _verify_rdata_value, | ||
| __verify_fdata_value, | ||
| _new_ | ||
|
|
||
| # Element types treated as differentiable: real and complex IEEE floats. | ||
| const _SElt = Union{IEEEFloat,Complex{<:IEEEFloat}} | ||
|
|
||
| # An SArray with a supported element type uses *itself* as its tangent type. | ||
| # It is immutable and stores only by-value scalar data, so its fdata is empty | ||
| # and its rdata carries the full tangent. This mirrors how Mooncake handles | ||
| # `Complex{<:IEEEFloat}` in `src/rules/complex.jl`. | ||
|
|
||
| @foldable function tangent_type(::Type{SArray{S,T,N,L}}) where {S,T<:_SElt,N,L} | ||
| return SArray{S,T,N,L} | ||
| end | ||
|
|
||
| @foldable function tangent_type( | ||
| ::Type{NoFData}, ::Type{SArray{S,T,N,L}} | ||
| ) where {S,T<:_SElt,N,L} | ||
| return SArray{S,T,N,L} | ||
| end | ||
|
|
||
| # Non-parametric alias used as a constraint, analogous to `CF` in complex.jl. | ||
| const _SAFloat = SArray{S,T,N,L} where {S,T<:_SElt,N,L} | ||
|
|
||
| @foldable fdata_type(::Type{T}) where {T<:_SAFloat} = NoFData | ||
| @foldable rdata_type(::Type{T}) where {T<:_SAFloat} = T | ||
|
|
||
| tangent(::NoFData, t::_SAFloat) = t | ||
|
|
||
| # Core tangent operations. | ||
|
|
||
| zero_tangent_internal(p::_SAFloat, ::MaybeCache) = zero(p) | ||
|
|
||
| zero_rdata(p::_SAFloat) = zero(p) | ||
| zero_rdata_from_type(::Type{P}) where {P<:_SAFloat} = zero(P) | ||
| @foldable can_produce_zero_rdata_from_type(::Type{<:_SAFloat}) = true | ||
|
|
||
| set_to_zero_internal!!(::SetToZeroCache, p::_SAFloat) = zero(p) | ||
|
|
||
| function randn_tangent_internal( | ||
| rng::AbstractRNG, ::SArray{S,T,N,L}, ::MaybeCache | ||
| ) where {S,T<:_SElt,N,L} | ||
| return SArray{S,T,N,L}(ntuple(_ -> randn(rng, T), Val(L))) | ||
| end | ||
|
|
||
| increment_internal!!(::IncCache, t::T, s::T) where {T<:_SAFloat} = t + s | ||
|
|
||
| _add_to_primal_internal(::MaybeCache, x::T, t::T, ::Bool) where {T<:_SAFloat} = x + t | ||
|
|
||
| tangent_to_primal_internal!!(::T, t::T, ::MaybeCache) where {T<:_SAFloat} = t | ||
| primal_to_tangent_internal!!(::T, x::T, ::MaybeCache) where {T<:_SAFloat} = x | ||
|
|
||
| # By-value type, so there is no primal address to record. | ||
| function Mooncake.TestUtils.populate_address_map_internal( | ||
| m::Mooncake.TestUtils.AddressMap, ::P, ::P | ||
| ) where {P<:_SAFloat} | ||
| return m | ||
| end | ||
|
|
||
| # rdata/fdata are validated structurally for non-primitive aggregates; by-value | ||
| # SArrays are leaves like `Complex`, so we short-circuit verification. | ||
| _verify_rdata_value(::P, ::P) where {P<:_SAFloat} = nothing | ||
| __verify_fdata_value(::IdDict{Any,Nothing}, ::P, ::P) where {P<:_SAFloat} = nothing | ||
|
|
||
| # Delegate element-wise reductions to the existing tuple handlers, which in | ||
| # turn dispatch to per-element `_dot_internal` / `_scale_internal` (correct for | ||
| # both `IEEEFloat` and `Complex{<:IEEEFloat}` element types). | ||
| function _dot_internal(c::MaybeCache, t::T, s::T) where {T<:_SAFloat} | ||
| return _dot_internal(c, Tuple(t), Tuple(s)) | ||
| end | ||
|
|
||
| function _scale_internal(c::MaybeCache, a::Float64, t::T) where {T<:_SAFloat} | ||
| return T(_scale_internal(c, a, Tuple(t))) | ||
| end | ||
|
|
||
| # Rules. `_new_` is already declared a primitive globally | ||
| # (`Tuple{typeof(_new_),Vararg}` in `src/rules/new.jl`), so we only need to | ||
| # add more-specific `frule!!` / `rrule!!` methods for SArray construction. | ||
| # Mooncake's IR normalisation rewrites `SArray(...)` constructor calls to | ||
| # `_new_(SArray{S,T,N,L}, data::NTuple{L,T})`. | ||
|
|
||
| function Mooncake.frule!!( | ||
| ::Dual{typeof(_new_)}, ::Dual{Type{P}}, data::Dual{NTuple{L,T}} | ||
| ) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}} | ||
| y = _new_(P, primal(data)) | ||
| dy = _new_(P, tangent(data)) | ||
| return Dual(y, dy) | ||
| end | ||
|
|
||
| function Mooncake.rrule!!( | ||
| ::CoDual{typeof(_new_)}, ::CoDual{Type{P}}, data::CoDual{NTuple{L,T}} | ||
| ) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}} | ||
| y = _new_(P, primal(data)) | ||
| _new_SArray_pb(dy::P) = NoRData(), NoRData(), Tuple(dy) | ||
| return zero_fcodual(y), _new_SArray_pb | ||
| end | ||
|
|
||
| Mooncake.@is_primitive MinimalCtx Tuple{ | ||
| typeof(getindex),SArray{S,T,N,L},Int | ||
| } where {S,T<:_SElt,N,L} | ||
|
|
||
| function Mooncake.frule!!( | ||
| ::Dual{typeof(getindex)}, x::Dual{P}, i::Dual{Int} | ||
| ) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}} | ||
| idx = primal(i) | ||
| return Dual(primal(x)[idx], tangent(x)[idx]) | ||
| end | ||
|
|
||
| function Mooncake.rrule!!( | ||
| ::CoDual{typeof(getindex)}, x::CoDual{P,NoFData}, i::CoDual{Int} | ||
| ) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}} | ||
| idx = primal(i) | ||
| y = primal(x)[idx] | ||
| function getindex_SArray_pb(dy::T) | ||
| dx = P(ntuple(j -> j == idx ? dy : zero(T), Val(L))) | ||
| return NoRData(), dx, NoRData() | ||
| end | ||
| return zero_fcodual(y), getindex_SArray_pb | ||
| end | ||
|
|
||
| end # module StaticArraysMooncakeExt | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| [deps] | ||
| AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" | ||
| JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" | ||
| Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
| StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | ||
| StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" | ||
| Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
|
||
| [compat] | ||
| Mooncake = ">=0.5.27" | ||
| StaticArrays = "1" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| # Run from the repo root with: | ||
| # julia --project=test/ext/mooncake test/ext/mooncake/mooncake.jl | ||
|
|
||
| using Pkg | ||
| Pkg.activate(@__DIR__) | ||
| Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) | ||
|
|
||
| using AllocCheck, JET, Mooncake, StableRNGs, StaticArrays, Test | ||
| using Mooncake.TestUtils: test_rule, test_tangent_interface, test_tangent_splitting | ||
|
|
||
| # The extension covers `SArray` only by design (it's the by-value, leaf-tangent | ||
| # case); `MArray` falls through to Mooncake's generic mutable-array handling. | ||
| @testset verbose=true "Mooncake integration" begin | ||
| cases = Any[ | ||
| SVector{3,Float64}(1.0, 2.0, 3.0), | ||
| SVector{2,Float32}(1.0f0, -2.0f0), | ||
| SMatrix{2,2,Float64}(1.0, 2.0, 3.0, 4.0), | ||
| SVector{2,ComplexF64}(1.0 + 2.0im, -3.0 + 1.0im), | ||
| SVector{1,ComplexF32}(0.5f0 + 0.25f0im), | ||
| ] | ||
|
|
||
| @testset "tangent interface for $(typeof(p))" for p in cases | ||
| rng = StableRNG(123) | ||
| test_tangent_interface(rng, p) | ||
| test_tangent_splitting(rng, p) | ||
| end | ||
|
|
||
| @testset "rrule!! getindex $(typeof(p))" for p in cases | ||
| for i in eachindex(p) | ||
| test_rule(StableRNG(123), getindex, p, i; is_primitive=true) | ||
| end | ||
| end | ||
|
|
||
| @testset "rrule!! _new_ construction" begin | ||
| # `_new_` is the primitive that IR normalisation lowers `SArray(...)` | ||
| # construction calls into; test the SArray-specific method directly. | ||
| new_cases = Any[ | ||
| (SVector{3,Float64}, (1.0, 2.0, 3.0)), | ||
| (SMatrix{2,2,Float64,4}, (1.0, 2.0, 3.0, 4.0)), | ||
| (SVector{2,ComplexF64}, (1.0 + 2.0im, -3.0 + 1.0im)), | ||
| ] | ||
| for (P, data) in new_cases | ||
| test_rule(StableRNG(123), Mooncake._new_, P, data; is_primitive=true) | ||
| end | ||
| end | ||
|
|
||
| @testset "end-to-end gradient" begin | ||
| f(x) = x[1]^2 + 2 * x[2] * x[3] | ||
| x = SVector{3,Float64}(1.5, -2.0, 0.5) | ||
| cache = Mooncake.prepare_gradient_cache(f, x) | ||
| val, (_, dx) = Mooncake.value_and_gradient!!(cache, f, x) | ||
| @test val ≈ f(x) | ||
| @test dx ≈ SVector{3,Float64}(2 * x[1], 2 * x[3], 2 * x[2]) | ||
| end | ||
| end |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aside from whether this should live in Mooncake or StaticArrays, was this even reviewed by a human before submission? These definitions look like they should be exactly identical to defaults.
Is this really the kind of pull requests you would welcome in your projects?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is actually needed -- the default tangent for Mooncake is not StaticArrays. Mooncake's automatically derived tangents are usually a NamedTuple of all field members of structs.
The tests for this PR run successfully locally, and I reviewed the code myself before creating it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's all necessary, then it's not much better. Nearly duplicating so much code for each array-like type is going to cause huge maintenance issues. Maybe it could all be generated by a macro that lives in Mooncake.jl?
Anyway, if you don't have the time to address issues raised in a review, making PRs is just a waste of time.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will Tebbutt primarily designed Moocanke’s tangent-type system while working with me. Automating this via metaprogramming is difficult, as subtypes of AbstractArray can vary considerably. This reflects lessons learned from Zygote, ChainRules, and other autograd systems. Will has nearly a decade of experience with automatic differentiation in Julia, so I trust his judgment on this.
The level of disagreement has made it difficult to sustain focused technical discussion.
Related: https://chalk-lab.github.io/Mooncake.jl/stable/developer_documentation/custom_tangent_type/#Writing-Custom-Tangent-Types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have about 8 years of experience writing and maintaining Julia code, and in my opinion automation via metaprogramming would be the easier approach in the long run. This extension is its current form breaks major good engineering practices, for example not accessing internal functions of other packages and only using public API. Unless "internal" in function name is only decorative and they are not internal at all?