-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathFillArraysSparseArraysExt.jl
More file actions
78 lines (61 loc) · 3.44 KB
/
FillArraysSparseArraysExt.jl
File metadata and controls
78 lines (61 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
module FillArraysSparseArraysExt
using SparseArrays
using SparseArrays: SparseVectorUnion
import Base: convert, kron
using FillArrays
using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value, AbstractFillVector, _fill_dot, OneElementVector, OneElementMatrix
# Specifying the full namespace is necessary because of https://github.com/JuliaLang/julia/issues/48533
# See https://github.com/JuliaStats/LogExpFunctions.jl/pull/63
using FillArrays.LinearAlgebra
import LinearAlgebra: dot, kron, I
##################
## Sparse arrays
##################
SparseVector{T}(Z::ZerosVector) where T = spzeros(T, length(Z))
SparseVector{Tv,Ti}(Z::ZerosVector) where {Tv,Ti} = spzeros(Tv, Ti, length(Z))
convert(::Type{AbstractSparseVector}, Z::ZerosVector{T}) where T = spzeros(T, length(Z))
convert(::Type{AbstractSparseVector{T}}, Z::ZerosVector) where T= spzeros(T, length(Z))
SparseMatrixCSC{T}(Z::ZerosMatrix) where T = spzeros(T, size(Z)...)
SparseMatrixCSC{Tv,Ti}(Z::Zeros{T,2,Axes}) where {Tv,Ti<:Integer,T,Axes} = spzeros(Tv, Ti, size(Z)...)
convert(::Type{AbstractSparseMatrix}, Z::ZerosMatrix{T}) where T = spzeros(T, size(Z)...)
convert(::Type{AbstractSparseMatrix{T}}, Z::ZerosMatrix) where T = spzeros(T, size(Z)...)
convert(::Type{AbstractSparseArray}, Z::Zeros{T}) where T = spzeros(T, size(Z)...)
convert(::Type{AbstractSparseArray{Tv}}, Z::Zeros{T}) where {T,Tv} = spzeros(Tv, size(Z)...)
convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Zeros{T}) where {T,Tv,Ti} = spzeros(Tv, Ti, size(Z)...)
convert(::Type{AbstractSparseArray{Tv,Ti,N}}, Z::Zeros{T,N}) where {T,Tv,Ti,N} = spzeros(Tv, Ti, size(Z)...)
SparseMatrixCSC{Tv}(Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)
# works around missing `speye`:
SparseMatrixCSC{Tv,Ti}(Z::Eye{T}) where {T,Tv,Ti<:Integer} =
convert(SparseMatrixCSC{Tv,Ti}, SparseMatrixCSC{Tv}(I, size(Z)...))
convert(::Type{AbstractSparseMatrix}, Z::Eye{T}) where {T} = SparseMatrixCSC{T}(I, size(Z)...)
convert(::Type{AbstractSparseMatrix{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)
convert(::Type{AbstractSparseArray}, Z::Eye{T}) where T = SparseMatrixCSC{T}(I, size(Z)...)
convert(::Type{AbstractSparseArray{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)
convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti} =
convert(SparseMatrixCSC{Tv,Ti}, Z)
convert(::Type{AbstractSparseArray{Tv,Ti,2}}, Z::Eye{T}) where {T,Tv,Ti} =
convert(SparseMatrixCSC{Tv,Ti}, Z)
function SparseMatrixCSC{Tv}(R::RectOrDiagonalFill) where {Tv}
SparseMatrixCSC{Tv,eltype(axes(R,1))}(R)
end
function SparseMatrixCSC{Tv,Ti}(R::RectOrDiagonalFill) where {Tv,Ti}
Base.require_one_based_indexing(R)
v = parent(R)
J = getindex_value(v)*I
SparseMatrixCSC{Tv,Ti}(J, size(R))
end
# TODO: remove in v2.0
@deprecate kron(E1::RectDiagonalFill, E2::RectDiagonalFill) kron(sparse(E1), sparse(E2))
# Ambiguity. see #178
dot(x::AbstractFillVector, y::SparseVectorUnion) = _fill_dot(x, y)
# OneElements with different indices should return
# SparseArrays under addition
function FillArrays.oneelement_addsub(a::OneElementVector, b::OneElementVector, aval, bval)
return sparsevec([a.ind[1], b.ind[1]], [aval, bval], length(a))
end
function FillArrays.oneelement_addsub(a::OneElementMatrix, b::OneElementMatrix, aval, bval)
nzval = [aval, bval]
nzind = ntuple(i -> [a.ind[i], b.ind[i]], ndims(a))
return sparse(nzind..., nzval, size(a)...)
end
end # module