Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ jobs:
version: 'nightly'
os: windows-latest
group: 'group-B'
- arch: x86
version: '1'
os: ubuntu-latest
group: 'group-M'
exclude:
# Remove some configurations from the build matrix to reduce CI time.
# See https://github.com/marketplace/actions/setup-julia-environment
Expand Down
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.18"
version = "1.9.19"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -11,10 +11,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
StaticArraysChainRulesCoreExt = "ChainRulesCore"
StaticArraysMooncakeExt = "Mooncake"
StaticArraysStatisticsExt = "Statistics"

[compat]
Expand All @@ -25,6 +27,7 @@ ChainRulesTestUtils = "1"
InteractiveUtils = "1"
JLArrays = "0.1"
LinearAlgebra = "1.6"
Mooncake = "0.5"
OffsetArrays = "1"
PrecompileTools = "1"
Random = "1.6"
Expand All @@ -42,9 +45,10 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "JLArrays"]
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "JLArrays", "Pkg"]
12 changes: 12 additions & 0 deletions ext/StaticArraysMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module StaticArraysMooncakeExt

using StaticArrays: StaticArray
using Mooncake: Mooncake

@static if isdefined(Mooncake, :FriendlyTangentCache) # checks Mooncake >= v0.5.25
function Mooncake.friendly_tangent_cache(x::StaticArray)
return Mooncake.FriendlyTangentCache{Mooncake.AsPrimal}(copy(x))
end
Comment thread
gdalle marked this conversation as resolved.
Outdated
end

end
23 changes: 23 additions & 0 deletions test/mooncake.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Pkg
Pkg.add("Mooncake")

using StaticArrays
using Mooncake
using Test

@testset verbose=true "Mooncake integration" begin
f(x) = sum(abs2, x)
config = Mooncake.Config(; friendly_tangents=true)
@testset "$(typeof(x))" for x in [
SVector(1.0, 2.0),
MVector(1.0, 2.0) ,
SMatrix{2,2}(1.0, 2.0, 3.0, 4.0),
MMatrix{2,2}(1.0, 2.0, 3.0, 4.0)
]
cache = prepare_gradient_cache(f, zero(x); config)
val, grads = value_and_gradient!!(cache, f, x)
g = grads[2]
@test g isa typeof(x)
@test g == 2x
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,8 @@ if TEST_GROUP ∈ ["", "all", "group-B"]
addtests("chainrules.jl")
end
end

if TEST_GROUP ∈ ["", "all", "group-M"] && VERSION >= v"1.10-"
# warning: changes the test environment by adding Mooncake
addtests("mooncake.jl")
end
Loading