diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 92fd1320..699664ae 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -89,6 +89,32 @@ jobs:
       - uses: julia-actions/julia-uploadcodecov@latest
         env:
           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
+  extra:
+    name: ${{matrix.test_group.test_type}}-${{ matrix.test_group.label }}-${{ matrix.version }}-${{ matrix.arch }}
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false  # keep the other matrix entries running when one fails
+      matrix:
+        test_group: [  # one entry per test/<test_type>/<label>/<label>.jl script
+          {test_type: 'ext', label: 'mooncake'},
+        ]
+        version:
+          - '1.10'
+          - '1'
+        arch:
+          - x64
+    steps:
+      - uses: actions/checkout@v6
+      - uses: julia-actions/setup-julia@latest
+        with:
+          version: ${{ matrix.version }}
+          arch: ${{ matrix.arch }}
+      - uses: julia-actions/cache@v3
+      - uses: julia-actions/julia-buildpkg@v1
+      - run: julia --code-coverage=user --eval 'include("test/${{ matrix.test_group.test_type }}/${{ matrix.test_group.label }}/${{ matrix.test_group.label }}.jl")'
+      - uses: julia-actions/julia-uploadcodecov@latest
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
   docs:
     name: Documentation
     runs-on: ubuntu-latest
diff --git a/.gitignore b/.gitignore
index 94cabef0..0982401f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@
 /benchmark/*.json
 /benchmark/Manifest.toml
 /docs/Manifest.toml
+/test/ext/**/Manifest.toml
diff --git a/Project.toml b/Project.toml
index ba8ee58d..0c76a8e7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,10 +11,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [weakdeps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [extensions]
 StaticArraysChainRulesCoreExt = "ChainRulesCore"
+StaticArraysMooncakeExt = "Mooncake"
 StaticArraysStatisticsExt = "Statistics"
 
 [compat]
@@ -25,6 +27,7 @@ ChainRulesTestUtils = "1"
 InteractiveUtils = "1"
 JLArrays = "0.1"
 LinearAlgebra = "1.6"
+Mooncake = "0.5.27"
 OffsetArrays = "1"
 PrecompileTools = "1"
 Random = "1.6"
diff --git a/ext/StaticArraysMooncakeExt.jl b/ext/StaticArraysMooncakeExt.jl
new file mode 100644
index 00000000..7be613b3
--- /dev/null
+++ b/ext/StaticArraysMooncakeExt.jl
@@ -0,0 +1,158 @@
+module StaticArraysMooncakeExt
+
+using Mooncake
+using Random: AbstractRNG
+using Base: IEEEFloat
+using StaticArrays: StaticArrays, SArray
+
+using Mooncake: @foldable
+
+import Mooncake:
+    MaybeCache,
+    IncCache,
+    SetToZeroCache,
+    NoFData,
+    NoRData,
+    CoDual,
+    Dual,
+    MinimalCtx,
+    primal,
+    tangent,
+    zero_fcodual,
+    tangent_type,
+    fdata_type,
+    rdata_type,
+    zero_tangent_internal,
+    randn_tangent_internal,
+    set_to_zero_internal!!,
+    increment_internal!!,
+    _add_to_primal_internal,
+    tangent_to_primal_internal!!,
+    primal_to_tangent_internal!!,
+    _dot_internal,
+    _scale_internal,
+    zero_rdata,
+    zero_rdata_from_type,
+    can_produce_zero_rdata_from_type,
+    _verify_rdata_value,
+    __verify_fdata_value,
+    _new_
+
+# Element types treated as differentiable: real and complex IEEE floats.
+const _SElt = Union{IEEEFloat,Complex{<:IEEEFloat}}
+
+# An SArray with a supported element type uses *itself* as its tangent type.
+# It is immutable and stores only by-value scalar data, so its fdata is `NoFData`
+# and its rdata is the full tangent. The author notes this mirrors how Mooncake
+# handles `Complex{<:IEEEFloat}` in `src/rules/complex.jl`.
+
+@foldable function tangent_type(::Type{SArray{S,T,N,L}}) where {S,T<:_SElt,N,L}
+    return SArray{S,T,N,L}
+end
+
+@foldable function tangent_type(
+    ::Type{NoFData}, ::Type{SArray{S,T,N,L}}  # (fdata type, rdata type) pair
+) where {S,T<:_SElt,N,L}
+    return SArray{S,T,N,L}
+end
+
+# Alias for every supported SArray, used as a dispatch bound (like `CF` in complex.jl).
+const _SAFloat = SArray{S,T,N,L} where {S,T<:_SElt,N,L}
+
+@foldable fdata_type(::Type{T}) where {T<:_SAFloat} = NoFData
+@foldable rdata_type(::Type{T}) where {T<:_SAFloat} = T
+
+tangent(::NoFData, t::_SAFloat) = t  # fdata is empty, so the rdata is the whole tangent
+
+# Core tangent operations.
+
+zero_tangent_internal(p::_SAFloat, ::MaybeCache) = zero(p)
+
+zero_rdata(p::_SAFloat) = zero(p)
+zero_rdata_from_type(::Type{P}) where {P<:_SAFloat} = zero(P)
+@foldable can_produce_zero_rdata_from_type(::Type{<:_SAFloat}) = true
+
+set_to_zero_internal!!(::SetToZeroCache, p::_SAFloat) = zero(p)  # returns a new value
+
+function randn_tangent_internal(
+    rng::AbstractRNG, ::SArray{S,T,N,L}, ::MaybeCache
+) where {S,T<:_SElt,N,L}
+    return SArray{S,T,N,L}(ntuple(_ -> randn(rng, T), Val(L)))  # one draw per element
+end
+
+increment_internal!!(::IncCache, t::T, s::T) where {T<:_SAFloat} = t + s
+
+_add_to_primal_internal(::MaybeCache, x::T, t::T, ::Bool) where {T<:_SAFloat} = x + t
+
+tangent_to_primal_internal!!(::T, t::T, ::MaybeCache) where {T<:_SAFloat} = t
+primal_to_tangent_internal!!(::T, x::T, ::MaybeCache) where {T<:_SAFloat} = x
+
+# Returns the address map untouched: a by-value SArray has no address to record.
+function Mooncake.TestUtils.populate_address_map_internal(
+    m::Mooncake.TestUtils.AddressMap, ::P, ::P
+) where {P<:_SAFloat}
+    return m
+end
+
+# Verification is a no-op here: both methods return `nothing` without inspecting
+# the value, treating a by-value SArray as a leaf (like `Complex`).
+_verify_rdata_value(::P, ::P) where {P<:_SAFloat} = nothing
+__verify_fdata_value(::IdDict{Any,Nothing}, ::P, ::P) where {P<:_SAFloat} = nothing
+
+# Inner product and scaling are delegated to the tuple methods of the same
+# functions via `Tuple(t)`; those are expected to recurse per element, covering
+# both `IEEEFloat` and `Complex{<:IEEEFloat}` element types.
+function _dot_internal(c::MaybeCache, t::T, s::T) where {T<:_SAFloat}
+    return _dot_internal(c, Tuple(t), Tuple(s))
+end
+
+function _scale_internal(c::MaybeCache, a::Float64, t::T) where {T<:_SAFloat}
+    return T(_scale_internal(c, a, Tuple(t)))  # rebuild the SArray from the scaled tuple
+end
+
+# Rules. `_new_` is already declared a primitive globally
+# (`Tuple{typeof(_new_),Vararg}` in `src/rules/new.jl`), so we only need to
+# add more-specific `frule!!` / `rrule!!` methods for SArray construction.
+# Mooncake's IR normalisation rewrites `SArray(...)` constructor calls to
+# `_new_(SArray{S,T,N,L}, data::NTuple{L,T})`.
+
+function Mooncake.frule!!(
+    ::Dual{typeof(_new_)}, ::Dual{Type{P}}, data::Dual{NTuple{L,T}}
+) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}}
+    y = _new_(P, primal(data))
+    dy = _new_(P, tangent(data))  # the tangent tuple is wrapped the same way
+    return Dual(y, dy)
+end
+
+function Mooncake.rrule!!(
+    ::CoDual{typeof(_new_)}, ::CoDual{Type{P}}, data::CoDual{NTuple{L,T}}
+) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}}
+    y = _new_(P, primal(data))
+    _new_SArray_pb(dy::P) = NoRData(), NoRData(), Tuple(dy)  # rdata for (f, P, data)
+    return zero_fcodual(y), _new_SArray_pb
+end
+
+Mooncake.@is_primitive MinimalCtx Tuple{
+    typeof(getindex),SArray{S,T,N,L},Int
+} where {S,T<:_SElt,N,L}
+
+function Mooncake.frule!!(
+    ::Dual{typeof(getindex)}, x::Dual{P}, i::Dual{Int}
+) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}}
+    idx = primal(i)
+    return Dual(primal(x)[idx], tangent(x)[idx])
+end
+
+function Mooncake.rrule!!(
+    ::CoDual{typeof(getindex)}, x::CoDual{P,NoFData}, i::CoDual{Int}
+) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}}
+    idx = primal(i)
+    y = primal(x)[idx]
+    function getindex_SArray_pb(dy::T)
+        dx = P(ntuple(j -> j == idx ? dy : zero(T), Val(L)))  # dy at idx, zero elsewhere
+        return NoRData(), dx, NoRData()  # rdata for (getindex, x, i)
+    end
+    return zero_fcodual(y), getindex_SArray_pb
+end
+
+end # module StaticArraysMooncakeExt
diff --git a/test/ext/mooncake/Project.toml b/test/ext/mooncake/Project.toml
new file mode 100644
index 00000000..68d11901
--- /dev/null
+++ b/test/ext/mooncake/Project.toml
@@ -0,0 +1,11 @@
+[deps]
+AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
+JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[compat]
+Mooncake = "0.5.27"  # keep in sync with the [compat] bound in the root Project.toml
+StaticArrays = "1"
diff --git a/test/ext/mooncake/mooncake.jl b/test/ext/mooncake/mooncake.jl
new file mode 100644
index 00000000..131edb00
--- /dev/null
+++ b/test/ext/mooncake/mooncake.jl
@@ -0,0 +1,55 @@
+# Run from the repo root with:
+#     julia --project=test/ext/mooncake test/ext/mooncake/mooncake.jl
+
+using Pkg
+Pkg.activate(@__DIR__)
+Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))  # the repo root, i.e. StaticArrays
+
+using AllocCheck, JET, Mooncake, StableRNGs, StaticArrays, Test
+using Mooncake.TestUtils: test_rule, test_tangent_interface, test_tangent_splitting
+
+# The extension covers `SArray` only by design (it's the by-value, leaf-tangent
+# case); `MArray` falls through to Mooncake's generic mutable-array handling.
+@testset verbose=true "Mooncake integration" begin
+    cases = Any[
+        SVector{3,Float64}(1.0, 2.0, 3.0),
+        SVector{2,Float32}(1.0f0, -2.0f0),
+        SMatrix{2,2,Float64}(1.0, 2.0, 3.0, 4.0),
+        SVector{2,ComplexF64}(1.0 + 2.0im, -3.0 + 1.0im),
+        SVector{1,ComplexF32}(0.5f0 + 0.25f0im),
+    ]
+
+    @testset "tangent interface for $(typeof(p))" for p in cases
+        rng = StableRNG(123)
+        test_tangent_interface(rng, p)
+        test_tangent_splitting(rng, p)
+    end
+
+    @testset "rrule!! getindex $(typeof(p))" for p in cases
+        for i in eachindex(p)
+            test_rule(StableRNG(123), getindex, p, i; is_primitive=true)
+        end
+    end
+
+    @testset "rrule!! _new_ construction" begin
+        # `_new_` is the primitive that IR normalisation lowers `SArray(...)`
+        # construction calls into; test the SArray-specific method directly.
+        new_cases = Any[
+            (SVector{3,Float64}, (1.0, 2.0, 3.0)),
+            (SMatrix{2,2,Float64,4}, (1.0, 2.0, 3.0, 4.0)),
+            (SVector{2,ComplexF64}, (1.0 + 2.0im, -3.0 + 1.0im)),
+        ]
+        for (P, data) in new_cases
+            test_rule(StableRNG(123), Mooncake._new_, P, data; is_primitive=true)
+        end
+    end
+
+    @testset "end-to-end gradient" begin
+        f(x) = x[1]^2 + 2 * x[2] * x[3]  # gradient: (2x[1], 2x[3], 2x[2])
+        x = SVector{3,Float64}(1.5, -2.0, 0.5)
+        cache = Mooncake.prepare_gradient_cache(f, x)
+        val, (_, dx) = Mooncake.value_and_gradient!!(cache, f, x)
+        @test val ≈ f(x)
+        @test dx ≈ SVector{3,Float64}(2 * x[1], 2 * x[3], 2 * x[2])
+    end
+end