forked from SciML/OrdinaryDiffEq.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDiffEqBaseDynamicQuantitiesExt.jl
More file actions
69 lines (59 loc) · 1.85 KB
/
DiffEqBaseDynamicQuantitiesExt.jl
File metadata and controls
69 lines (59 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
module DiffEqBaseDynamicQuantitiesExt
using DiffEqBase
using DynamicQuantities
using LinearAlgebra
import DiffEqBase: default_factorize
# Norm of a scalar quantity for DiffEqBase: the absolute value of its unit-stripped
# magnitude. `t` is part of the hook's signature and is not used here.
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::UnionAbstractQuantity, t)
    magnitude = ustrip(u)
    return abs(magnitude)
end
# Squared magnitude of a quantity after stripping its units, returned as a real number.
@inline DiffEqBase.UNITLESS_ABS2(x::UnionAbstractQuantity) = real(abs2(ustrip(x)))
# Drop the units of a quantity and return its numeric value (via `ustrip`).
function DiffEqBase.stripunits(x::UnionAbstractQuantity)
    return ustrip(x)
end
# Rate prototype for quantity-valued time: `u` divided by one unit of `t`.
# Only the units of `t` matter (its value is discarded by `oneunit`); `onet` is unused.
function DiffEqBase._rate_prototype(u, t::UnionAbstractQuantity, onet)
    time_unit = oneunit(t)
    return u / time_unit
end
# Minimum step size carried in the units of `t`: take the unit-stripped ratio of
# `dtmin` to one unit of `t`, re-attach that unit, and return the absolute value.
function DiffEqBase.timedepentdtmin(t::UnionAbstractQuantity, dtmin)
    time_unit = oneunit(t)
    ratio = ustrip(dtmin / time_unit)
    return abs(ratio * time_unit)
end
"""
    DQUnitlessLU{Fact, Unit}

Factorization wrapper for matrices whose entries are DynamicQuantities quantities, such
as the W/J matrices that Rosenbrock/SDIRK solvers form with Quantity eltype.

The factorization `F` is computed on the unit-stripped values of the matrix, so the
linear solve itself runs on plain numbers. `ut` is a unit factor re-attached to the
solution: each solution entry carries the units of the matching right-hand-side entry
multiplied by `ut`.
"""
struct DQUnitlessLU{Fact, Unit}
    F::Fact   # factorization of the unit-stripped matrix
    ut::Unit  # unit factor multiplied onto each right-hand-side entry's unit
end
# Unit factor for solutions of systems with matrix `A`: one unit of the inverse of the
# first entry (in iteration order) whose stripped value is nonzero. Entries with a zero
# value are skipped. If every entry is zero, fall back to the dimensionless `1.0`.
# NOTE(review): a single factor taken from one entry is only correct when all entries of
# `A` share the same units — confirm this holds for the matrices passed in.
@inline function _infer_ut(A::AbstractMatrix{<:UnionAbstractQuantity})
    for entry in A
        iszero(ustrip(entry)) && continue
        return oneunit(inv(entry))
    end
    return oneunit(1.0)
end
# Factorize a quantity-valued matrix: LU-factorize its unit-stripped values
# (`check = false`, so a singular matrix does not throw here) and record the unit
# factor from `_infer_ut`. An empty matrix gets an empty Float64 LU and the
# dimensionless factor `1.0`.
function default_factorize(A::AbstractMatrix{<:UnionAbstractQuantity})
    if isempty(A)
        empty_factorization = lu(Matrix{Float64}(undef, 0, 0); check = false)
        return DQUnitlessLU(empty_factorization, oneunit(1.0))
    end
    unit_factor = _infer_ut(A)
    stripped = ustrip.(A)
    return DQUnitlessLU(lu(stripped; check = false), unit_factor)
end
# In-place solve of `W * x = b` for quantity-valued vectors.
#
# The solve runs on the unit-stripped values of `b` using the stored factorization
# `W.F`. Each entry of the result is then written to `x[i]` with the units of `b[i]`
# multiplied by `W.ut`. Returns `x`.
#
# Throws `DimensionMismatch` if `x` and `b` do not share the same indices.
function LinearAlgebra.ldiv!(
    x::AbstractVector{<:UnionAbstractQuantity},
    W::DQUnitlessLU,
    b::AbstractVector{<:UnionAbstractQuantity},
)
    vb = ustrip.(b)
    # `W.F \ vb` promotes to the factorization's element type. A work buffer cloned
    # from `vb` would inherit `b`'s value type (e.g. Int), and the float-valued solve
    # would then fail with InexactError.
    vx = W.F \ vb
    # `eachindex(x, vx, b)` throws DimensionMismatch unless all three share indices,
    # which is what makes the unchecked accesses below safe.
    @inbounds for i in eachindex(x, vx, b)
        x[i] = vx[i] * (oneunit(b[i]) * W.ut)
    end
    return x
end
# Allocating solve of `W * x = b` for a quantity-valued right-hand side.
#
# The solve runs on the unit-stripped values of `b` using `W.F`. Entry `i` of the
# returned vector carries the units of `b[i]` multiplied by `W.ut`.
function Base.:(\)(W::DQUnitlessLU, b::AbstractVector{<:UnionAbstractQuantity})
    vx = W.F \ ustrip.(b)
    # Build the result with `map` instead of filling `similar(b)`. The solution's
    # values are promoted to the factorization's element type, and its units are
    # `b`'s units times `W.ut`, so a container cloned from `b` may be unable to hold
    # it (e.g. integer-valued quantities). `map` also checks that `vx` and `b` agree
    # in shape instead of indexing them unchecked.
    return map((v, bi) -> v * (oneunit(bi) * W.ut), vx, b)
end
end