diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 696f1fcb8..c2c873970 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -1,5 +1,7 @@ using Base: RefValue using Base: ismutabletype +using LinearAlgebra: Hermitian, Symmetric +using SparseArrays: SparseMatrixCSC # Interfaces @@ -15,6 +17,11 @@ accum(x, y, zs...) = accum(accum(x, y), zs...) accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...) accum(x::AbstractArray, ys::AbstractArray...) = Base.broadcast_preserving_zero_d(accum, x, ys...) + +const HermOrSymSparse{T, I} = Union{Hermitian{T, SparseMatrixCSC{T, I}}, Symmetric{T, SparseMatrixCSC{T, I}}} + +accum(x::HermOrSymSparse, y::HermOrSymSparse) = x + y + accum(::Tuple{}, ::NamedTuple{}) = () accum(::NamedTuple{}, ::Tuple{}) = () diff --git a/test/lib/lib.jl b/test/lib/lib.jl index 11e64cba9..66f00b2b4 100644 --- a/test/lib/lib.jl +++ b/test/lib/lib.jl @@ -5,5 +5,17 @@ @test Zygote.accum(t1, t2) == (a = 2, b = 4, c = 3) @test_throws ArgumentError Zygote.accum(t2, t1) @test Zygote.accum(fill(0.0), fill(0.0)) == fill(0.0) + + # HermOrSymSparse accumulation + S = sparse([1, 2, 2], [1, 1, 2], [1.0, 2.0, 3.0], 2, 2) + H1 = Hermitian(S + S') + H2 = Hermitian(2S + 2S') + @test Zygote.accum(H1, H2) == H1 + H2 + @test Zygote.accum(H1, H2) isa Hermitian{Float64, <:SparseMatrixCSC} + + Sym1 = Symmetric(S + S') + Sym2 = Symmetric(2S + 2S') + @test Zygote.accum(Sym1, Sym2) == Sym1 + Sym2 + @test Zygote.accum(Sym1, Sym2) isa Symmetric{Float64, <:SparseMatrixCSC} end end diff --git a/test/lib_tests.jl b/test/lib_tests.jl index 8c788b3d0..0af554b64 100644 --- a/test/lib_tests.jl +++ b/test/lib_tests.jl @@ -3,6 +3,7 @@ using ChainRulesTestUtils using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular, Symmetric using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular +using SparseArrays: sparse, SparseMatrixCSC using Zygote: ZygoteRuleConfig, _pullback, _reverse include("lib/number.jl")