Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,32 @@ jobs:
- uses: julia-actions/julia-uploadcodecov@latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
# Extra CI job: runs the per-extension test scripts (currently only the
# Mooncake one) once per Julia version / architecture combination below.
extra:
name: ${{matrix.test_group.test_type}}-${{ matrix.test_group.label }}-${{ matrix.version }}-${{ matrix.arch }}
runs-on: ubuntu-latest
strategy:
# Let the remaining matrix entries keep running when one of them fails.
fail-fast: false
matrix:
# Each entry selects the script test/<test_type>/<label>/<label>.jl.
test_group: [
{test_type: 'ext', label: 'mooncake'},
]
version:
- '1.10'
- '1'
arch:
- x64
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v3
- uses: julia-actions/julia-buildpkg@v1
# Include the selected test script with user-code coverage tracking enabled.
- run: julia --code-coverage=user --eval 'include("test/${{ matrix.test_group.test_type }}/${{ matrix.test_group.label }}/${{ matrix.test_group.label }}.jl")'
- uses: julia-actions/julia-uploadcodecov@latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
docs:
name: Documentation
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
/benchmark/*.json
/benchmark/Manifest.toml
/docs/Manifest.toml
/test/ext/**/Manifest.toml
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
StaticArraysChainRulesCoreExt = "ChainRulesCore"
StaticArraysMooncakeExt = "Mooncake"
StaticArraysStatisticsExt = "Statistics"

[compat]
Expand All @@ -25,6 +27,7 @@ ChainRulesTestUtils = "1"
InteractiveUtils = "1"
JLArrays = "0.1"
LinearAlgebra = "1.6"
Mooncake = "0.5.27"
OffsetArrays = "1"
PrecompileTools = "1"
Random = "1.6"
Expand Down
158 changes: 158 additions & 0 deletions ext/StaticArraysMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
module StaticArraysMooncakeExt

using Mooncake
using Random: AbstractRNG
using Base: IEEEFloat
using StaticArrays: StaticArrays, SArray

using Mooncake: @foldable

import Mooncake:
MaybeCache,
IncCache,
SetToZeroCache,
NoFData,
NoRData,
CoDual,
Dual,
MinimalCtx,
primal,
tangent,
zero_fcodual,
tangent_type,
fdata_type,
rdata_type,
zero_tangent_internal,
randn_tangent_internal,
set_to_zero_internal!!,
increment_internal!!,
_add_to_primal_internal,
tangent_to_primal_internal!!,
primal_to_tangent_internal!!,
_dot_internal,
_scale_internal,
zero_rdata,
zero_rdata_from_type,
can_produce_zero_rdata_from_type,
_verify_rdata_value,
__verify_fdata_value,
_new_

# Element types treated as differentiable: real IEEE floats and complex numbers
# whose components are IEEE floats. The methods in this module are restricted
# to `SArray`s whose element type is a subtype of this union.
const _SElt = Union{IEEEFloat,Complex{<:IEEEFloat}}

# An SArray with a supported element type uses *itself* as its tangent type.
# It is immutable and stores only by-value scalar data, so its fdata is empty
# and its rdata carries the full tangent. This mirrors how Mooncake handles
# `Complex{<:IEEEFloat}` in `src/rules/complex.jl`.

# A supported `SArray` type is its own tangent type.
@foldable tangent_type(::Type{SArray{S,T,N,L}}) where {S,T<:_SElt,N,L} = SArray{S,T,N,L}

# Recombining fdata/rdata types: empty fdata (`NoFData`) together with an
# `SArray` rdata type yields that same `SArray` type as the tangent type.
@foldable tangent_type(
    ::Type{NoFData}, ::Type{SArray{S,T,N,L}}
) where {S,T<:_SElt,N,L} = SArray{S,T,N,L}

# Alias covering every `SArray` whose element type is a subtype of `_SElt`, with
# all other type parameters left free. It takes no parameters itself and is used
# as a dispatch constraint (`::_SAFloat`, `P<:_SAFloat`), analogous to `CF` in
# complex.jl.
const _SAFloat = SArray{S,T,N,L} where {S,T<:_SElt,N,L}

# Supported SArrays carry no forward data ...
@foldable function fdata_type(::Type{P}) where {P<:_SAFloat}
    return NoFData
end

# ... so the reverse data type is the whole tangent type.
@foldable function rdata_type(::Type{P}) where {P<:_SAFloat}
    return P
end

# Rebuilding a tangent from empty fdata and SArray rdata returns the rdata as is.
function tangent(::NoFData, rdata::_SAFloat)
    return rdata
end
Comment on lines +49 to +65
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside from whether this should live in Mooncake or StaticArrays, was this even reviewed by a human before submission? These definitions look like they should be exactly identical to defaults.

Is this really the kind of pull request you would welcome in your projects?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is actually needed -- the default tangent for Mooncake is not StaticArrays. Mooncake's automatically derived tangents are usually a NamedTuple of all field members of structs.

The tests for this PR run successfully locally, and I reviewed the code myself before creating it.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's all necessary, then it's not much better. Nearly duplicating so much code for each array-like type is going to cause huge maintenance issues. Maybe it could all be generated by a macro that lives in Mooncake.jl?

Anyway, if you don't have the time to address issues raised in a review, making PRs is just a waste of time.

Copy link
Copy Markdown
Author

@yebai yebai Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's all necessary, then it's not much better. Nearly duplicating so much code for each array-like type is going to cause huge maintenance issues. Maybe it could all be generated by a macro that lives in Mooncake.jl?

Will Tebbutt primarily designed Mooncake’s tangent-type system while working with me. Automating this via metaprogramming is difficult, as subtypes of AbstractArray can vary considerably. This reflects lessons learned from Zygote, ChainRules, and other autograd systems. Will has nearly a decade of experience with automatic differentiation in Julia, so I trust his judgment on this.

if you don't have the time to address issues raised in a review, making PRs is just a waste of time.

The level of disagreement has made it difficult to sustain focused technical discussion.

Related: https://chalk-lab.github.io/Mooncake.jl/stable/developer_documentation/custom_tangent_type/#Writing-Custom-Tangent-Types

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Automating this via metaprogramming is difficult, as subtypes of AbstractArray can vary considerably.

I have about 8 years of experience writing and maintaining Julia code, and in my opinion automation via metaprogramming would be the easier approach in the long run. This extension in its current form breaks major good engineering practices, for example using only public API and not accessing internal functions of other packages. Unless "internal" in the function names is only decorative and they are not internal at all?


# Core tangent operations.

# The zero tangent is obtained by calling `zero` on the primal; the cache
# argument is not consulted.
function zero_tangent_internal(primal_value::_SAFloat, ::MaybeCache)
    return zero(primal_value)
end

# Zero rdata built from a primal value.
function zero_rdata(primal_value::_SAFloat)
    return zero(primal_value)
end

# Zero rdata built from the primal type alone.
function zero_rdata_from_type(::Type{P}) where {P<:_SAFloat}
    return zero(P)
end

# Signals that `zero_rdata_from_type` is available for supported SArray types.
@foldable function can_produce_zero_rdata_from_type(::Type{<:_SAFloat})
    return true
end

# Returns a fresh zero value instead of mutating the argument; the cache is
# not consulted.
function set_to_zero_internal!!(::SetToZeroCache, tangent_value::_SAFloat)
    return zero(tangent_value)
end

# Draw a random tangent: `L` independent `randn(rng, T)` samples packed into an
# SArray of the same type as the primal. Only the primal's type is used.
function randn_tangent_internal(
    rng::AbstractRNG, ::SArray{S,T,N,L}, ::MaybeCache
) where {S,T<:_SElt,N,L}
    samples = ntuple(Val(L)) do _
        randn(rng, T)
    end
    return SArray{S,T,N,L}(samples)
end

# Accumulate two tangents of the same SArray type by ordinary addition.
function increment_internal!!(::IncCache, acc::P, delta::P) where {P<:_SAFloat}
    return acc + delta
end

# Move a primal along a tangent of the same type; the trailing `Bool` flag is
# ignored by this method.
function _add_to_primal_internal(
    ::MaybeCache, x::P, dx::P, ::Bool
) where {P<:_SAFloat}
    return x + dx
end

# Primal and tangent share one type, so each conversion returns its second
# argument unchanged and ignores the first.
function tangent_to_primal_internal!!(::P, t::P, ::MaybeCache) where {P<:_SAFloat}
    return t
end

function primal_to_tangent_internal!!(::P, x::P, ::MaybeCache) where {P<:_SAFloat}
    return x
end

# Nothing to register for a by-value SArray: hand the address map back untouched.
Mooncake.TestUtils.populate_address_map_internal(
    address_map::Mooncake.TestUtils.AddressMap, ::P, ::P
) where {P<:_SAFloat} = address_map

# Supported SArrays are treated as leaf values (like `Complex`), so both
# verification hooks accept a matching primal/data pair without inspecting it.
function _verify_rdata_value(::P, ::P) where {P<:_SAFloat}
    return nothing
end

function __verify_fdata_value(::IdDict{Any,Nothing}, ::P, ::P) where {P<:_SAFloat}
    return nothing
end

# Element-wise reductions are forwarded to the tuple methods of
# `_dot_internal` / `_scale_internal`, which handle each element in turn (this
# covers both `IEEEFloat` and `Complex{<:IEEEFloat}` element types).
function _dot_internal(c::MaybeCache, t::T, s::T) where {T<:_SAFloat}
    lhs, rhs = Tuple(t), Tuple(s)
    return _dot_internal(c, lhs, rhs)
end

# Scale via the tuple method, then rebuild an SArray of the original type.
function _scale_internal(c::MaybeCache, a::Float64, t::T) where {T<:_SAFloat}
    scaled_elements = _scale_internal(c, a, Tuple(t))
    return T(scaled_elements)
end

# Rules. `_new_` is already declared a primitive globally
# (`Tuple{typeof(_new_),Vararg}` in `src/rules/new.jl`), so we only need to
# add more-specific `frule!!` / `rrule!!` methods for SArray construction.
# Mooncake's IR normalisation rewrites `SArray(...)` constructor calls to
# `_new_(SArray{S,T,N,L}, data::NTuple{L,T})`.

# Forward rule for `_new_(SArray{S,T,N,L}, data)`: build the primal SArray from
# the primal tuple and the tangent SArray from the tangent tuple.
function Mooncake.frule!!(
    ::Dual{typeof(_new_)}, ::Dual{Type{P}}, data::Dual{NTuple{L,T}}
) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}}
    return Dual(_new_(P, primal(data)), _new_(P, tangent(data)))
end

# Reverse rule for `_new_(SArray{S,T,N,L}, data)`. The pullback unpacks the
# incoming SArray cotangent into a tuple for `data`; the function and the type
# argument receive `NoRData()`.
function Mooncake.rrule!!(
    ::CoDual{typeof(_new_)}, ::CoDual{Type{P}}, data::CoDual{NTuple{L,T}}
) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}}
    function _new_SArray_pb(dy::P)
        return NoRData(), NoRData(), Tuple(dy)
    end
    return zero_fcodual(_new_(P, primal(data))), _new_SArray_pb
end

# Declare `getindex(::SArray, ::Int)` (supported element types only) a primitive
# in `MinimalCtx`. NOTE(review): presumably this makes Mooncake use the
# hand-written `frule!!` / `rrule!!` methods for this signature instead of
# differentiating through `getindex` -- confirm against the Mooncake docs.
Mooncake.@is_primitive MinimalCtx Tuple{
typeof(getindex),SArray{S,T,N,L},Int
} where {S,T<:_SElt,N,L}

# Forward rule for integer indexing: read the same position from the primal
# and from the tangent.
function Mooncake.frule!!(
    ::Dual{typeof(getindex)}, x::Dual{P}, i::Dual{Int}
) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}}
    k = primal(i)
    y = primal(x)[k]
    dy = tangent(x)[k]
    return Dual(y, dy)
end

# Reverse rule for integer indexing. The pullback returns an SArray cotangent
# that holds `dy` at the indexed position and `zero(T)` everywhere else; the
# function and the index receive `NoRData()`.
function Mooncake.rrule!!(
    ::CoDual{typeof(getindex)}, x::CoDual{P,NoFData}, i::CoDual{Int}
) where {S,T<:_SElt,N,L,P<:SArray{S,T,N,L}}
    k = primal(i)
    function getindex_SArray_pb(dy::T)
        entries = ntuple(Val(L)) do j
            j == k ? dy : zero(T)
        end
        return NoRData(), P(entries), NoRData()
    end
    return zero_fcodual(primal(x)[k]), getindex_SArray_pb
end

end # module StaticArraysMooncakeExt
11 changes: 11 additions & 0 deletions test/ext/mooncake/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[deps]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Mooncake = ">=0.5.27"
StaticArrays = "1"
55 changes: 55 additions & 0 deletions test/ext/mooncake/mooncake.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Run from the repo root with:
# julia --project=test/ext/mooncake test/ext/mooncake/mooncake.jl

using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))

using AllocCheck, JET, Mooncake, StableRNGs, StaticArrays, Test
using Mooncake.TestUtils: test_rule, test_tangent_interface, test_tangent_splitting

# The extension covers `SArray` only by design (it's the by-value, leaf-tangent
# case); `MArray` falls through to Mooncake's generic mutable-array handling.
@testset verbose=true "Mooncake integration" begin
    # Primal values spanning real/complex element types, both float widths,
    # and vector/matrix shapes.
    primals = Any[
        SVector{3,Float64}(1.0, 2.0, 3.0),
        SVector{2,Float32}(1.0f0, -2.0f0),
        SMatrix{2,2,Float64}(1.0, 2.0, 3.0, 4.0),
        SVector{2,ComplexF64}(1.0 + 2.0im, -3.0 + 1.0im),
        SVector{1,ComplexF32}(0.5f0 + 0.25f0im),
    ]

    @testset "tangent interface for $(typeof(p))" for p in primals
        # One RNG is shared by both checks, so the second continues the stream.
        rng = StableRNG(123)
        test_tangent_interface(rng, p)
        test_tangent_splitting(rng, p)
    end

    @testset "rrule!! getindex $(typeof(p))" for p in primals
        # Exercise the rule at every valid index, with a fresh RNG each time.
        foreach(eachindex(p)) do i
            test_rule(StableRNG(123), getindex, p, i; is_primitive=true)
        end
    end

    @testset "rrule!! _new_ construction" begin
        # `SArray(...)` construction is lowered to the `_new_` primitive by IR
        # normalisation, so the SArray-specific method is called directly here.
        constructor_inputs = Any[
            (SVector{3,Float64}, (1.0, 2.0, 3.0)),
            (SMatrix{2,2,Float64,4}, (1.0, 2.0, 3.0, 4.0)),
            (SVector{2,ComplexF64}, (1.0 + 2.0im, -3.0 + 1.0im)),
        ]
        for (A, elements) in constructor_inputs
            test_rule(StableRNG(123), Mooncake._new_, A, elements; is_primitive=true)
        end
    end

    @testset "end-to-end gradient" begin
        objective(v) = v[1]^2 + 2 * v[2] * v[3]
        x0 = SVector{3,Float64}(1.5, -2.0, 0.5)
        grad_cache = Mooncake.prepare_gradient_cache(objective, x0)
        result = Mooncake.value_and_gradient!!(grad_cache, objective, x0)
        value = result[1]
        grad_x = result[2][2]
        @test value ≈ objective(x0)
        # Analytic gradient of x1^2 + 2*x2*x3 is (2*x1, 2*x3, 2*x2).
        @test grad_x ≈ SVector{3,Float64}(2 * x0[1], 2 * x0[3], 2 * x0[2])
    end
end