Add a third length-N scratch buffer (`tmpN3`) to `Woodbury`/`SymWoodbury` so the
allocation-free solve can call the three-argument `ldiv!`, and let callers pass
their own preallocated buffers through a new `allocs` keyword.

Notes on this revision of the patch:
 * test/symwoodbury.jl: `k` is the column count of `B` (`size(B, 2)`); `B[:,1]`
   has length `n`, so `size(B, 1)` was the wrong dimension for the k-buffers.
 * `_allocate_tmp` now rejects a buffer that appears twice in `allocs`, since
   `_ldiv!` reads and writes all five buffers during a single solve.
 * The `Woodbury` docstring describes `allocs` as an indexable collection, which
   is what `_allocate_tmp` actually requires.
 * Line breaks and indentation were reconstructed from a whitespace-collapsed
   copy of the patch (NOTE(review): confirm context whitespace against the
   tree); the post-image `index` hashes of the edited files were not recomputed.

diff --git a/src/WoodburyMatrices.jl b/src/WoodburyMatrices.jl
index 3e665b8..b251888 100644
--- a/src/WoodburyMatrices.jl
+++ b/src/WoodburyMatrices.jl
@@ -88,7 +88,8 @@ function _ldiv!(dest, W::AbstractWoodbury, A::Union{Factorization,Diagonal}, B)
     mul!(W.tmpk1, W.V, W.tmpN1)
     mul!(W.tmpk2, W.Cp, W.tmpk1)
     mul!(W.tmpN2, W.U, W.tmpk2)
-    ldiv!(A, W.tmpN2)
+    W.tmpN3 .= W.tmpN2
+    ldiv!(W.tmpN2, A, W.tmpN3)
     for i in eachindex(W.tmpN2)
         dest[i] = W.tmpN1[i] - W.tmpN2[i]
     end
diff --git a/src/symwoodbury.jl b/src/symwoodbury.jl
index 2d832f8..47cbeb6 100644
--- a/src/symwoodbury.jl
+++ b/src/symwoodbury.jl
@@ -5,11 +5,12 @@ struct SymWoodbury{T,AType,BType,DType,DpType} <: AbstractWoodbury{T}
     Dp::DpType
     tmpN1::Union{Vector{T}, Nothing}
     tmpN2::Union{Vector{T}, Nothing}
+    tmpN3::Union{Vector{T}, Nothing}
     tmpk1::Union{Vector{T}, Nothing}
     tmpk2::Union{Vector{T}, Nothing}
 
-    SymWoodbury{T}(A, B, D, Dp, tmpN1, tmpN2, tmpk1, tmpk2) where {T} =
-        new{T,typeof(A),typeof(B),typeof(D),typeof(Dp)}(A, B, D, Dp, tmpN1, tmpN2, tmpk1, tmpk2)
+    SymWoodbury{T}(A, B, D, Dp, tmpN1, tmpN2, tmpN3, tmpk1, tmpk2) where {T} =
+        new{T,typeof(A),typeof(B),typeof(D),typeof(Dp)}(A, B, D, Dp, tmpN1, tmpN2, tmpN3, tmpk1, tmpk2)
 end
 
 """
@@ -28,7 +29,11 @@ or factorization.
 
 See also [Woodbury](@ref), where `allocatetmp` and `use_pinv` are explained.
 """
-function SymWoodbury(A, B::AbstractVecOrMat, D; allocatetmp::Bool=false, use_pinv::Bool=false)
+function SymWoodbury(A, B::AbstractVecOrMat, D;
+    allocatetmp::Bool=false,
+    use_pinv::Bool=false,
+    allocs=nothing,
+)
     @noinline throwdmm(B, D, A) = throw(DimensionMismatch("Sizes of B ($(size(B))) and/or D ($(size(D))) are inconsistent with A ($(size(A)))"))
 
     n = size(A, 1)
@@ -44,16 +49,9 @@ function SymWoodbury(A, B::AbstractVecOrMat, D; allocatetmp::Bool=false, use_pin
     Dp = use_pinv ? safepinv(Dpinv) : safeinv(Dpinv)
     # temporary space for allocation-free solver (vector RHS only)
     T = typeof(float(zero(eltype(A)) * zero(eltype(B)) * zero(eltype(D))))
-    if allocatetmp
-        tmpN1 = Vector{T}(undef, n)
-        tmpN2 = Vector{T}(undef, n)
-        tmpk1 = Vector{T}(undef, k)
-        tmpk2 = Vector{T}(undef, k)
-    else
-        tmpN1 = tmpN2 = tmpk1 = tmpk2 = nothing
-    end
+    tmpN1, tmpN2, tmpN3, tmpk1, tmpk2 = _allocate_tmp(T, allocs, allocatetmp, n, k)
 
-    SymWoodbury{T}(A, B, D, Dp, tmpN1, tmpN2, tmpk1, tmpk2)
+    SymWoodbury{T}(A, B, D, Dp, tmpN1, tmpN2, tmpN3, tmpk1, tmpk2)
 end
 
 convert(::Type{W}, O::SymWoodbury) where {W<:Woodbury} = Woodbury(O.A, O.B, O.D, O.B')
diff --git a/src/woodbury.jl b/src/woodbury.jl
index ecd152a..605d57e 100644
--- a/src/woodbury.jl
+++ b/src/woodbury.jl
@@ -6,15 +6,16 @@ struct Woodbury{T,AType,UType,VType,CType,CpType} <: AbstractWoodbury{T}
     V::VType
     tmpN1::Union{Vector{T}, Nothing}
     tmpN2::Union{Vector{T}, Nothing}
+    tmpN3::Union{Vector{T}, Nothing}
     tmpk1::Union{Vector{T}, Nothing}
     tmpk2::Union{Vector{T}, Nothing}
 
-    Woodbury{T}(A, U, C, Cp, V, tmpN1, tmpN2, tmpk1, tmpk2) where {T} =
-        new{T,typeof(A),typeof(U),typeof(V),typeof(C),typeof(Cp)}(A, U, C, Cp, V, tmpN1, tmpN2, tmpk1, tmpk2)
+    Woodbury{T}(A, U, C, Cp, V, tmpN1, tmpN2, tmpN3, tmpk1, tmpk2) where {T} =
+        new{T,typeof(A),typeof(U),typeof(V),typeof(C),typeof(Cp)}(A, U, C, Cp, V, tmpN1, tmpN2, tmpN3, tmpk1, tmpk2)
 end
 
 """
-    W = Woodbury(A, U, C, V; allocatetmp::Bool=false, use_pinv::Bool=false)
+    W = Woodbury(A, U, C, V; allocatetmp::Bool=false, use_pinv::Bool=false, allocs=nothing)
 
 Represent a matrix `W = A + UCV`.
 Equations `Wx = b` will be solved using the
@@ -27,7 +28,9 @@ If `W` is rank-deficient or nearly so, setting `use_pinv` to true will use the
 pseudoinverse in the Woodbury formula to improve numerical stability.
 
 If `allocatetmp` is true, temporary storage used for intermediate steps in
-multiplication and division will be allocated.
+multiplication and division will be allocated. Alternatively, supply your own
+buffers via `allocs`: an indexable collection (e.g. a tuple or vector) of five
+distinct `Vector`s, the first three of length `N` and the last two of length `k`.
 
 !!! warning
     If you'll use the same `W` in multiple threads, you should use `allocatetmp=false`
@@ -36,7 +39,11 @@ multiplication and division will be allocated.
 
 See also [SymWoodbury](@ref).
 """
-function Woodbury(A, U::AbstractMatrix, C, V::AbstractMatrix; allocatetmp::Bool=false, use_pinv::Bool=false)
+function Woodbury(A, U::AbstractMatrix, C, V::AbstractMatrix;
+    allocatetmp::Bool=false,
+    use_pinv::Bool=false,
+    allocs=nothing,
+)
     @noinline throwdmm1(U, V, A) = throw(DimensionMismatch("Sizes of U ($(size(U))) and/or V ($(size(V))) are inconsistent with A ($(size(A)))"))
     @noinline throwdmm2(k) = throw(DimensionMismatch("C should be $(k)x$(k)"))
 
@@ -54,16 +61,9 @@ function Woodbury(A, U::AbstractMatrix, C, V::AbstractMatrix; allocatetmp::Bool=
     Cp = use_pinv ? safepinv(Cpinv) : safeinv(Cpinv)
     # temporary space for allocation-free solver (vector RHS only)
     T = typeof(float(zero(eltype(A)) * zero(eltype(U)) * zero(eltype(C)) * zero(eltype(V))))
-    if allocatetmp
-        tmpN1 = Vector{T}(undef, N)
-        tmpN2 = Vector{T}(undef, N)
-        tmpk1 = Vector{T}(undef, k)
-        tmpk2 = Vector{T}(undef, k)
-    else
-        tmpN1 = tmpN2 = tmpk1 = tmpk2 = nothing
-    end
+    tmpN1, tmpN2, tmpN3, tmpk1, tmpk2 = _allocate_tmp(T, allocs, allocatetmp, N, k)
 
-    Woodbury{T}(A, U, C, Cp, V, tmpN1, tmpN2, tmpk1, tmpk2)
+    Woodbury{T}(A, U, C, Cp, V, tmpN1, tmpN2, tmpN3, tmpk1, tmpk2)
 end
 
 Woodbury(A, U::AbstractVector{T}, C, V::AbstractMatrix{T}) where {T} = Woodbury(A, reshape(U, length(U), 1), C, V)
@@ -130,3 +130,27 @@ function issymmetric(W::Woodbury)
     issymmetric(W.A) && issymmetric(W.C) && W.U == W.V' && return true
     return issymmetric(Matrix(W))
 end
+
+# Return the scratch buffers `(tmpN1, tmpN2, tmpN3, tmpk1, tmpk2)`: the caller's
+# `allocs` if supplied (after validation), freshly allocated vectors if
+# `allocatetmp` is true, and five `nothing`s otherwise.
+@inline function _allocate_tmp(::Type{T}, allocs, allocatetmp, N, k) where T
+    if !isnothing(allocs)
+        # Check there are five allocs and they match N and k
+        length(allocs) == 5 || throw(ArgumentError("Must have 5 allocs, got $(length(allocs))"))
+        length(allocs[1]) == length(allocs[2]) == length(allocs[3]) == N || throw(ArgumentError("First three allocs must have length $N"))
+        length(allocs[4]) == length(allocs[5]) == k || throw(ArgumentError("Last two allocs must have length $k"))
+        foreach(allocs) do a
+            a isa Vector{T} || throw(ArgumentError("All allocs must have type Vector{$T}, got $(typeof(a))"))
+        end
+        # The solver uses all five buffers within one solve, so a repeated vector
+        # (e.g. from `fill(v, 3)`) would silently corrupt results.
+        allunique(objectid(a) for a in allocs) || throw(ArgumentError("All allocs must be distinct vectors"))
+        allocs
+    elseif allocatetmp
+        V = Vector{T}
+        V(undef, N), V(undef, N), V(undef, N), V(undef, k), V(undef, k)
+    else
+        nothing, nothing, nothing, nothing, nothing
+    end
+end
diff --git a/test/symwoodbury.jl b/test/symwoodbury.jl
index 4483aeb..70c7857 100644
--- a/test/symwoodbury.jl
+++ b/test/symwoodbury.jl
@@ -26,8 +26,18 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64, Int)
     ε = eps(abs2(float(one(elty))))
     A = Diagonal(a)
 
-    for W in (SymWoodbury(A, B, D), SymWoodbury(A, B, D; allocatetmp=true), SymWoodbury(A, B[:,1][:], 2.), SymWoodbury(A, B, D; use_pinv=true))
-
+    n = size(A, 1)
+    k = size(B, 2)
+    tmp_elty = typeof(float(zero(eltype(A)) * zero(eltype(B)) * zero(eltype(D))))
+    allocs = [(Vector{tmp_elty}(undef, n) for i in 1:3)..., (Vector{tmp_elty}(undef, k) for i in 1:2)...]
+
+    for W in (
+        SymWoodbury(A, B, D),
+        SymWoodbury(A, B, D; allocatetmp=true),
+        SymWoodbury(A, B, D; allocs),
+        SymWoodbury(A, B[:,1][:], 2.),
+        SymWoodbury(A, B, D; use_pinv=true),
+    )
         @test issymmetric(W)
         F = Matrix(W)
         @test (2*W)*v ≈ 2*(W*v)
diff --git a/test/woodbury.jl b/test/woodbury.jl
index 82a6aa0..2a2b082 100644
--- a/test/woodbury.jl
+++ b/test/woodbury.jl
@@ -38,8 +38,16 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64, Int)
     ε = eps(abs2(float(one(elty))))
     T = Tridiagonal(dl, d, du)
 
+    n = size(T, 1)
+    k = size(U, 2)
+    tmp_elty = typeof(float(zero(eltype(T)) * zero(eltype(U)) * zero(eltype(C)) * zero(eltype(V))))
+    allocs = [(Vector{tmp_elty}(undef, n) for i in 1:3)..., (Vector{tmp_elty}(undef, k) for i in 1:2)...]
     # Matrix for A
-    for W in (Woodbury(T, U, C, V), Woodbury(T, U, C, V; allocatetmp=true))
+    for W in (
+        Woodbury(T, U, C, V),
+        Woodbury(T, U, C, V; allocatetmp=true),
+        Woodbury(T, U, C, V; allocs),
+    )
         @test size(W, 1) == n
         @test size(W) == (n, n)
         @test axes(W) === (Base.OneTo(n), Base.OneTo(n))