Skip to content

Commit 0a53fae

Browse files
Harsh SinghHarsh Singh
authored andcommitted
Add IMEX RK solvers implementation (tableaus, perform_step, caches)
1 parent bebc563 commit 0a53fae

File tree

8 files changed

+501
-1274
lines changed

8 files changed

+501
-1274
lines changed

lib/OrdinaryDiffEqCore/src/algorithms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ abstract type OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ} <:
4545
OrdinaryDiffEqAlgorithm end
4646
abstract type OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
4747
OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ} end
48+
abstract type OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ} <:
49+
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} end
4850
abstract type OrdinaryDiffEqRosenbrockAlgorithm{CS, AD, FDT, ST, CJ} <:
4951
OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ} end
5052
const NewtonAlgorithm = Union{

lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
77
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
88
OrdinaryDiffEqNewtonAdaptiveAlgorithm,
99
OrdinaryDiffEqNewtonAlgorithm,
10+
OrdinaryDiffEqNewtonESDIRKAlgorithm,
1011
DEFAULT_PRECS,
1112
OrdinaryDiffEqAdaptiveAlgorithm, CompiledFloats, uses_uprev,
1213
alg_cache, _vec, _reshape, @cache, isfsal, full_cache,
@@ -37,15 +38,16 @@ include("kencarp_kvaerno_caches.jl")
3738
include("sdirk_perform_step.jl")
3839
include("kencarp_kvaerno_perform_step.jl")
3940
include("sdirk_tableaus.jl")
41+
include("imex_tableaus.jl")
42+
include("generic_imex_perform_step.jl")
4043

4144
export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22,
4245
Kvaerno3, KenCarp3, Cash4, Hairer4, Hairer42, SSPSDIRK2, Kvaerno4,
4346
Kvaerno5, KenCarp4, KenCarp47, KenCarp5, KenCarp58, ESDIRK54I8L2SA, SFSDIRK4,
4447
SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, Kvaerno5, KenCarp4, KenCarp5,
4548
SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6,
4649
SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA,
47-
IMEXSSP222, IMEXSSP2322, IMEXSSP3332, IMEXSSP3433,
48-
ARS222, ARS232, ARS443, BHR553
50+
IMEXSSP222, IMEXSSP2322, IMEXSSP3332, IMEXSSP3433
4951

5052
import PrecompileTools
5153
import Preferences

lib/OrdinaryDiffEqSDIRK/src/algorithms.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,7 +1798,7 @@ const _ARS_BHR_REFERENCE = "@article{ascher1997implicit,
17981798
"""
17991799
)
18001800
struct ARS222{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <:
1801-
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ}
1801+
OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ}
18021802
linsolve::F
18031803
nlsolve::F2
18041804
precs::P
@@ -1835,7 +1835,7 @@ end
18351835
"""
18361836
)
18371837
struct ARS232{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <:
1838-
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ}
1838+
OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ}
18391839
linsolve::F
18401840
nlsolve::F2
18411841
precs::P
@@ -1872,7 +1872,7 @@ end
18721872
"""
18731873
)
18741874
struct ARS443{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <:
1875-
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ}
1875+
OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ}
18761876
linsolve::F
18771877
nlsolve::F2
18781878
precs::P
@@ -1917,7 +1917,7 @@ end
19171917
"""
19181918
)
19191919
struct BHR553{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <:
1920-
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ}
1920+
OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ}
19211921
linsolve::F
19221922
nlsolve::F2
19231923
precs::P
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
mutable struct ESDIRKIMEXConstantCache{Tab, N} <: SDIRKConstantCache
2+
nlsolver::N
3+
tab::Tab
4+
end
5+
6+
mutable struct ESDIRKIMEXCache{uType, rateType, N, Tab, kType, StepLimiter} <:
7+
SDIRKMutableCache
8+
u::uType
9+
uprev::uType
10+
fsalfirst::rateType
11+
zs::Vector{uType}
12+
ks::Vector{kType}
13+
nlsolver::N
14+
tab::Tab
15+
step_limiter!::StepLimiter
16+
end
17+
18+
function full_cache(c::ESDIRKIMEXCache)
19+
base = (c.u, c.uprev, c.fsalfirst, c.zs...)
20+
if eltype(c.ks) !== Nothing
21+
return tuple(base..., c.ks...)
22+
end
23+
return base
24+
end
25+
26+
const ESDIRKIMEXAlgorithm = Union{ARS222, ARS232, ARS443, BHR553}
27+
28+
function alg_cache(
29+
alg::ESDIRKIMEXAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
30+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits},
31+
uprev, uprev2, f, t, dt, reltol, p, calck,
32+
::Val{false}, verbose
33+
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
34+
tab = ESDIRKIMEXTableau(alg, constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
35+
γ = tab.Ai[2, 2]
36+
c = tab.c[2]
37+
nlsolver = build_nlsolver(
38+
alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
39+
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose
40+
)
41+
return ESDIRKIMEXConstantCache(nlsolver, tab)
42+
end
43+
44+
function alg_cache(
45+
alg::ESDIRKIMEXAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
46+
::Type{uBottomEltypeNoUnits},
47+
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
48+
::Val{true}, verbose
49+
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
50+
tab = ESDIRKIMEXTableau(alg, constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
51+
γ = tab.Ai[2, 2]
52+
c = tab.c[2]
53+
nlsolver = build_nlsolver(
54+
alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
55+
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose
56+
)
57+
fsalfirst = zero(rate_prototype)
58+
59+
s = tab.s
60+
if f isa SplitFunction
61+
ks = [zero(u) for _ in 1:s]
62+
else
63+
ks = Vector{Nothing}(nothing, s)
64+
end
65+
66+
zs = [zero(u) for _ in 1:(s - 1)]
67+
push!(zs, nlsolver.z)
68+
69+
return ESDIRKIMEXCache(
70+
u, uprev, fsalfirst, zs, ks, nlsolver, tab, alg.step_limiter!
71+
)
72+
end
73+
74+
function initialize!(integrator, cache::ESDIRKIMEXConstantCache)
75+
integrator.kshortsize = 2
76+
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
77+
integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t)
78+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
79+
integrator.fsallast = zero(integrator.fsalfirst)
80+
integrator.k[1] = integrator.fsalfirst
81+
integrator.k[2] = integrator.fsallast
82+
return nothing
83+
end
84+
85+
function initialize!(integrator, cache::ESDIRKIMEXCache)
86+
integrator.kshortsize = 2
87+
resize!(integrator.k, integrator.kshortsize)
88+
integrator.k[1] = integrator.fsalfirst
89+
integrator.k[2] = integrator.fsallast
90+
integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t)
91+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
92+
return nothing
93+
end
94+
95+
@muladd function perform_step!(
96+
integrator, cache::ESDIRKIMEXConstantCache, repeat_step = false
97+
)
98+
(; t, dt, uprev, u, p) = integrator
99+
nlsolver = cache.nlsolver
100+
tab = cache.tab
101+
(; Ai, bi, Ae, be, c, s) = tab
102+
γ = Ai[2, 2]
103+
104+
f2 = nothing
105+
k = Vector{typeof(u)}(undef, s)
106+
if integrator.f isa SplitFunction
107+
f_impl = integrator.f.f1
108+
f2 = integrator.f.f2
109+
else
110+
f_impl = integrator.f
111+
end
112+
113+
markfirststage!(nlsolver)
114+
115+
z = Vector{typeof(u)}(undef, s)
116+
117+
# Stage 1: explicit (ESDIRK: a₁₁ = 0)
118+
if integrator.f isa SplitFunction
119+
z[1] = dt * f_impl(uprev, p, t)
120+
else
121+
z[1] = dt * integrator.fsalfirst
122+
end
123+
124+
if integrator.f isa SplitFunction
125+
k[1] = dt * integrator.fsalfirst - z[1]
126+
end
127+
128+
# Stages 2..s
129+
for i in 2:s
130+
tmp = uprev
131+
for j in 1:(i - 1)
132+
tmp = tmp + Ai[i, j] * z[j]
133+
end
134+
135+
if integrator.f isa SplitFunction
136+
for j in 1:(i - 1)
137+
tmp = tmp + Ae[i, j] * k[j]
138+
end
139+
end
140+
141+
if integrator.f isa SplitFunction
142+
z_guess = z[1]
143+
else
144+
z_guess = zero(u)
145+
end
146+
147+
nlsolver.z = z_guess
148+
nlsolver.tmp = tmp
149+
nlsolver.c = c[i]
150+
nlsolver.γ = γ
151+
z[i] = nlsolve!(nlsolver, integrator, cache, repeat_step)
152+
nlsolvefail(nlsolver) && return
153+
154+
if integrator.f isa SplitFunction && i < s
155+
u_stage = tmp + γ * z[i]
156+
k[i] = dt * f2(u_stage, p, t + c[i] * dt)
157+
integrator.stats.nf2 += 1
158+
end
159+
end
160+
161+
# Compute solution
162+
u = nlsolver.tmp + γ * z[s]
163+
if integrator.f isa SplitFunction
164+
k[s] = dt * f2(u, p, t + dt)
165+
integrator.stats.nf2 += 1
166+
u = uprev
167+
for i in 1:s
168+
u = u + bi[i] * z[i] + be[i] * k[i]
169+
end
170+
end
171+
172+
if integrator.f isa SplitFunction
173+
integrator.k[1] = integrator.fsalfirst
174+
integrator.fsallast = integrator.f(u, p, t + dt)
175+
integrator.k[2] = integrator.fsallast
176+
else
177+
integrator.fsallast = z[s] ./ dt
178+
integrator.k[1] = integrator.fsalfirst
179+
integrator.k[2] = integrator.fsallast
180+
end
181+
integrator.u = u
182+
end
183+
184+
@muladd function perform_step!(integrator, cache::ESDIRKIMEXCache, repeat_step = false)
185+
(; t, dt, uprev, u, p) = integrator
186+
(; zs, ks, nlsolver, step_limiter!) = cache
187+
(; tmp) = nlsolver
188+
tab = cache.tab
189+
(; Ai, bi, Ae, be, c, s) = tab
190+
γ = Ai[2, 2]
191+
192+
f2 = nothing
193+
if integrator.f isa SplitFunction
194+
f_impl = integrator.f.f1
195+
f2 = integrator.f.f2
196+
else
197+
f_impl = integrator.f
198+
end
199+
200+
markfirststage!(nlsolver)
201+
202+
# Stage 1: explicit (ESDIRK: a₁₁ = 0)
203+
if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail
204+
f_impl(zs[1], integrator.uprev, p, integrator.t)
205+
zs[1] .*= dt
206+
else
207+
@.. broadcast=false zs[1] = dt * integrator.fsalfirst
208+
end
209+
210+
if integrator.f isa SplitFunction
211+
@.. broadcast=false ks[1] = dt * integrator.fsalfirst - zs[1]
212+
end
213+
214+
# Stages 2..s
215+
for i in 2:s
216+
@.. broadcast=false tmp = uprev
217+
for j in 1:(i - 1)
218+
@.. broadcast=false tmp += Ai[i, j] * zs[j]
219+
end
220+
221+
if integrator.f isa SplitFunction
222+
for j in 1:(i - 1)
223+
@.. broadcast=false tmp += Ae[i, j] * ks[j]
224+
end
225+
end
226+
227+
if integrator.f isa SplitFunction
228+
copyto!(zs[i], zs[1])
229+
else
230+
fill!(zs[i], zero(eltype(u)))
231+
end
232+
233+
nlsolver.z = zs[i]
234+
nlsolver.c = c[i]
235+
nlsolver.γ = γ
236+
zs[i] = nlsolve!(nlsolver, integrator, cache, repeat_step)
237+
nlsolvefail(nlsolver) && return
238+
if i > 2
239+
isnewton(nlsolver) && set_new_W!(nlsolver, false)
240+
end
241+
242+
if integrator.f isa SplitFunction && i < s
243+
@.. broadcast=false u = tmp + γ * zs[i]
244+
f2(ks[i], u, p, t + c[i] * dt)
245+
ks[i] .*= dt
246+
integrator.stats.nf2 += 1
247+
end
248+
end
249+
250+
# Compute solution
251+
@.. broadcast=false u = tmp + γ * zs[s]
252+
if integrator.f isa SplitFunction
253+
f2(ks[s], u, p, t + dt)
254+
ks[s] .*= dt
255+
integrator.stats.nf2 += 1
256+
@.. broadcast=false u = uprev
257+
for i in 1:s
258+
@.. broadcast=false u += bi[i] * zs[i] + be[i] * ks[i]
259+
end
260+
end
261+
262+
step_limiter!(u, integrator, p, t + dt)
263+
264+
if integrator.f isa SplitFunction
265+
integrator.f(integrator.fsallast, u, p, t + dt)
266+
else
267+
@.. broadcast=false integrator.fsallast = zs[s] / dt
268+
end
269+
end

0 commit comments

Comments
 (0)