Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,36 @@ jobs:
- uses: julia-actions/julia-uploadcodecov@latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
test-downstream:
name: Julia ${{ matrix.version }} downstream - ${{ matrix.package }} [${{ matrix.env }}]
runs-on: ubuntu-latest
strategy:
matrix:
version:
- '1.10'
- '1'
package:
- 'Mooncake'
env:
- 'old'
- 'new'
env:
STATICARRAYS_DOWNSTREAM_TEST_PACKAGE: ${{ matrix.package }}
STATICARRAYS_DOWNSTREAM_TEST_ENV: ${{ matrix.env }}
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v3
with:
version: ${{ matrix.version }}
- uses: julia-actions/cache@v3
- uses: julia-actions/julia-buildpkg@v1
- run: julia -O1 --code-coverage=user --color=yes test/downstream/harness.jl
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v6
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
docs:
name: Documentation
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
/benchmark/*.json
/benchmark/Manifest.toml
/docs/Manifest.toml
/test/downstream/**/Manifest.toml
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.18"
version = "1.9.19"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -11,10 +11,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
StaticArraysChainRulesCoreExt = "ChainRulesCore"
StaticArraysMooncakeExt = "Mooncake"
StaticArraysStatisticsExt = "Statistics"

[compat]
Expand All @@ -25,6 +27,7 @@ ChainRulesTestUtils = "1"
InteractiveUtils = "1"
JLArrays = "0.1"
LinearAlgebra = "1.6"
Mooncake = "0.5"
OffsetArrays = "1"
PrecompileTools = "1"
Random = "1.6"
Expand All @@ -41,6 +44,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
13 changes: 13 additions & 0 deletions ext/StaticArraysMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module StaticArraysMooncakeExt

using StaticArrays: SArray,MArray
using Mooncake: Mooncake

@static if isdefined(Mooncake, :FriendlyTangentCache) # checks Mooncake >= v0.5.25
# see https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/998
function Mooncake.friendly_tangent_cache(x::Union{SArray,MArray})
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this actually work with SArray where elements are Symmetric SArrays?

What exactly is Mooncake doing to the returned value to get a mutable type? copy on SArray returns an immutable again.

Copy link
Copy Markdown
Author

@gdalle gdalle Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this actually work with SArray where elements are Symmetric SArrays?

Probably not, one might need some kind of deep copy. Or we could restrict it to Union{SArray{<:IEEEFloat}, MArray{<:IEEEFloat}} to be safe, albeit incomplete.

What exactly is Mooncake doing to the returned value to get a mutable type? copy on SArray returns an immutable again.

Contrary to what the name "cache" suggests, I don't think the output of friendly_tangent_cache has to be mutable. From what I understand reading Claude's comments in chalk-lab/Mooncake.jl#1103, it seems that this function is meant to output a kind of template for reconstructing the gradient type we want.


To clarify, I had nothing to do with the breaking changes in Mooncake v0.5.25, I actually disapprove of them and I don't fully understand them (especially since they were LLM-generated). I'm just trying to keep the bare minimum working (Mooncake's gradient returning a SArray in simple cases)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know enough about Mooncake to have an opinion yet. I just think a few more complex examples regarding how this is all supposed to work for nested types would be really helpful.

I did take a look at the Julia AD ecosystem a couple of years ago so I know a bit about how it works, and I still use it sometimes, though I don't see good areas to contribute to there sadly. Too many conflicting goals and points of view on how things should work.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @yebai is the expert on this new Mooncake API. Perhaps it would be better as a Mooncake extension, since it is a very modest amount of code and StaticArrays is a very old and stable package?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StaticArrays.jl already has a ChainRulesCore extension so adding Mooncake wouldn't be unprecedented. My main worry is that this friendly tangent thing doesn't seem particularly stable and well-tested, even in comparison with ChainRules. So it's unclear many iterations on the idea are still needed. I'd suggest figuring out more complex examples outside of StaticArrays.jl first, and make an extension when it's clear that the API works well across multiple nested array types from different packages.

return Mooncake.FriendlyTangentCache{Mooncake.AsPrimal}(copy(x))
end
end

end
10 changes: 10 additions & 0 deletions test/downstream/Mooncake/new/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
StaticArrays = {path = "../../../.."}

[compat]
Mooncake = ">=0.5.25"
10 changes: 10 additions & 0 deletions test/downstream/Mooncake/old/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
StaticArrays = {path = "../../../.."}

[compat]
Mooncake = "<0.5.25"
20 changes: 20 additions & 0 deletions test/downstream/Mooncake/test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using StaticArrays
using Mooncake
using Test

@testset verbose=true "Mooncake integration" begin
f(x) = sum(abs2, x)
config = Mooncake.Config(; friendly_tangents=true)
@testset "$(typeof(x))" for x in [
SVector(1.0, 2.0),
MVector(1.0, 2.0) ,
SMatrix{2,2}(1.0, 2.0, 3.0, 4.0),
MMatrix{2,2}(1.0, 2.0, 3.0, 4.0)
]
cache = prepare_gradient_cache(f, zero(x); config)
val, grads = value_and_gradient!!(cache, f, x)
g = grads[2]
@test g isa typeof(x)
@test g == 2x
end
end
6 changes: 6 additions & 0 deletions test/downstream/harness.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using Pkg
package = ENV["STATICARRAYS_DOWNSTREAM_TEST_PACKAGE"];
env = ENV["STATICARRAYS_DOWNSTREAM_TEST_ENV"];
Pkg.activate(joinpath(@__DIR__, package, env))
Pkg.develop(PackageSpec(path=joinpath(@__DIR__, "..", "..")))
include(joinpath(package, "test.jl"))
Loading