Skip to content

Commit f76a5b6

Browse files
Harsh SinghHarsh Singh
authored andcommitted
WIP: Initial IMEX tableau abstraction with generic perform_step! (ARS343 prototype)
1 parent 2147e11 commit f76a5b6

5 files changed

Lines changed: 556 additions & 1 deletion

File tree

lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,16 @@ include("kencarp_kvaerno_caches.jl")
3737
include("sdirk_perform_step.jl")
3838
include("kencarp_kvaerno_perform_step.jl")
3939
include("sdirk_tableaus.jl")
40+
include("imex_tableaus.jl")
41+
include("generic_imex_perform_step.jl")
4042

4143
export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22,
4244
Kvaerno3, KenCarp3, Cash4, Hairer4, Hairer42, SSPSDIRK2, Kvaerno4,
4345
Kvaerno5, KenCarp4, KenCarp47, KenCarp5, KenCarp58, ESDIRK54I8L2SA, SFSDIRK4,
4446
SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, Kvaerno5, KenCarp4, KenCarp5,
4547
SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6,
46-
SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA
48+
SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA,
49+
ARS343
4750

4851
import PrecompileTools
4952
import Preferences

lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,6 @@ issplit(alg::KenCarp47) = true
5757
issplit(alg::KenCarp5) = true
5858
issplit(alg::KenCarp58) = true
5959
issplit(alg::CFNLIRK3) = true
60+
issplit(alg::ARS343) = true
61+
alg_order(alg::ARS343) = 3
62+
isesdirk(alg::ARS343) = true

lib/OrdinaryDiffEqSDIRK/src/algorithms.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,3 +1593,34 @@ function ESDIRK659L2SA(;
15931593
controller, AD_choice
15941594
)
15951595
end
1596+
1597+
struct ARS343{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <:
1598+
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
1599+
linsolve::F
1600+
nlsolve::F2
1601+
precs::P
1602+
smooth_est::Bool
1603+
extrapolant::Symbol
1604+
controller::Symbol
1605+
step_limiter!::StepLimiter
1606+
autodiff::AD
1607+
end
1608+
function ARS343(;
1609+
chunk_size = Val{0}(), autodiff = AutoForwardDiff(),
1610+
standardtag = Val{true}(), concrete_jac = nothing,
1611+
diff_type = Val{:forward}(),
1612+
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
1613+
smooth_est = true, extrapolant = :linear,
1614+
controller = :PI, step_limiter! = trivial_limiter!
1615+
)
1616+
AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type)
1617+
1618+
return ARS343{
1619+
_unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve),
1620+
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
1621+
_unwrap_val(concrete_jac), typeof(step_limiter!),
1622+
}(
1623+
linsolve, nlsolve, precs,
1624+
smooth_est, extrapolant, controller, step_limiter!, AD_choice
1625+
)
1626+
end
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
mutable struct IMEXConstantCache{Tab, N} <: OrdinaryDiffEqConstantCache
2+
nlsolver::N
3+
tab::Tab
4+
end
5+
6+
mutable struct IMEXCache{uType, rateType, uNoUnitsType, N, Tab, kType, StepLimiter} <:
7+
SDIRKMutableCache
8+
u::uType
9+
uprev::uType
10+
fsalfirst::rateType
11+
zs::Vector{uType}
12+
ks::Vector{kType}
13+
atmp::uNoUnitsType
14+
nlsolver::N
15+
tab::Tab
16+
step_limiter!::StepLimiter
17+
end
18+
19+
function full_cache(c::IMEXCache)
20+
base = (c.u, c.uprev, c.fsalfirst, c.zs..., c.atmp)
21+
if eltype(c.ks) !== Nothing
22+
return tuple(base..., c.ks...)
23+
end
24+
return base
25+
end
26+
27+
function alg_cache(
28+
alg::ARS343, u, rate_prototype, ::Type{uEltypeNoUnits},
29+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits},
30+
uprev, uprev2, f, t, dt, reltol, p, calck,
31+
::Val{false}, verbose
32+
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
33+
tab = ARS343Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
34+
γ = tab.Ai[2, 2]
35+
c = tab.c[2]
36+
nlsolver = build_nlsolver(
37+
alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
38+
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose
39+
)
40+
return IMEXConstantCache(nlsolver, tab)
41+
end
42+
43+
function alg_cache(
44+
alg::ARS343, u, rate_prototype, ::Type{uEltypeNoUnits},
45+
::Type{uBottomEltypeNoUnits},
46+
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
47+
::Val{true}, verbose
48+
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
49+
tab = ARS343Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
50+
γ = tab.Ai[2, 2]
51+
c = tab.c[2]
52+
nlsolver = build_nlsolver(
53+
alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
54+
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose
55+
)
56+
fsalfirst = zero(rate_prototype)
57+
58+
s = tab.s
59+
if f isa SplitFunction
60+
ks = [zero(u) for _ in 1:s]
61+
else
62+
ks = Vector{Nothing}(nothing, s)
63+
end
64+
65+
zs = [zero(u) for _ in 1:(s - 1)]
66+
push!(zs, nlsolver.z)
67+
atmp = similar(u, uEltypeNoUnits)
68+
recursivefill!(atmp, false)
69+
70+
return IMEXCache(
71+
u, uprev, fsalfirst, zs, ks, atmp, nlsolver, tab, alg.step_limiter!
72+
)
73+
end
74+
75+
function initialize!(integrator, cache::IMEXConstantCache)
76+
integrator.kshortsize = 2
77+
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
78+
integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t)
79+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
80+
integrator.fsallast = zero(integrator.fsalfirst)
81+
integrator.k[1] = integrator.fsalfirst
82+
integrator.k[2] = integrator.fsallast
83+
end
84+
85+
function initialize!(integrator, cache::IMEXCache)
86+
integrator.kshortsize = 2
87+
resize!(integrator.k, integrator.kshortsize)
88+
integrator.k[1] = integrator.fsalfirst
89+
integrator.k[2] = integrator.fsallast
90+
integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t)
91+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
92+
end
93+
94+
@muladd function perform_step!(
95+
integrator, cache::IMEXConstantCache, repeat_step = false
96+
)
97+
(; t, dt, uprev, u, p) = integrator
98+
nlsolver = cache.nlsolver
99+
tab = cache.tab
100+
(; Ai, bi, Ae, be, c, btilde, ebtilde, α, s) = tab
101+
alg = unwrap_alg(integrator, true)
102+
γ = Ai[2, 2]
103+
104+
f2 = nothing
105+
k = Vector{typeof(u)}(undef, s)
106+
if integrator.f isa SplitFunction
107+
f_impl = integrator.f.f1
108+
f2 = integrator.f.f2
109+
else
110+
f_impl = integrator.f
111+
end
112+
113+
markfirststage!(nlsolver)
114+
115+
z = Vector{typeof(u)}(undef, s)
116+
117+
if integrator.f isa SplitFunction
118+
z[1] = dt * f_impl(uprev, p, t)
119+
else
120+
z[1] = dt * integrator.fsalfirst
121+
end
122+
123+
if integrator.f isa SplitFunction
124+
k[1] = dt * integrator.fsalfirst - z[1]
125+
end
126+
127+
for i in 2:s
128+
tmp = uprev
129+
for j in 1:(i - 1)
130+
tmp = tmp + Ai[i, j] * z[j]
131+
end
132+
133+
if integrator.f isa SplitFunction
134+
for j in 1:(i - 1)
135+
tmp = tmp + Ae[i, j] * k[j]
136+
end
137+
end
138+
139+
if integrator.f isa SplitFunction
140+
z_guess = z[1]
141+
elseif α !== nothing && !iszero(α[i, 1])
142+
z_guess = zero(u)
143+
for j in 1:(i - 1)
144+
z_guess = z_guess + α[i, j] * z[j]
145+
end
146+
else
147+
z_guess = zero(u)
148+
end
149+
150+
nlsolver.z = z_guess
151+
nlsolver.tmp = tmp
152+
nlsolver.c = c[i]
153+
nlsolver.γ = γ
154+
z[i] = nlsolve!(nlsolver, integrator, cache, repeat_step)
155+
nlsolvefail(nlsolver) && return
156+
157+
if integrator.f isa SplitFunction && i < s
158+
u_stage = tmp + γ * z[i]
159+
k[i] = dt * f2(u_stage, p, t + c[i] * dt)
160+
integrator.stats.nf2 += 1
161+
end
162+
end
163+
164+
u = nlsolver.tmp + γ * z[s]
165+
if integrator.f isa SplitFunction
166+
k[s] = dt * f2(u, p, t + dt)
167+
integrator.stats.nf2 += 1
168+
u = uprev
169+
for i in 1:s
170+
u = u + bi[i] * z[i] + be[i] * k[i]
171+
end
172+
end
173+
174+
if integrator.opts.adaptive
175+
tmp = zero(u)
176+
for i in 1:s
177+
tmp = tmp + btilde[i] * z[i]
178+
end
179+
if integrator.f isa SplitFunction && ebtilde !== nothing
180+
for i in 1:s
181+
tmp = tmp + ebtilde[i] * k[i]
182+
end
183+
end
184+
if isnewton(nlsolver) && alg.smooth_est
185+
integrator.stats.nsolve += 1
186+
est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp))
187+
else
188+
est = tmp
189+
end
190+
atmp = calculate_residuals(
191+
est, uprev, u, integrator.opts.abstol,
192+
integrator.opts.reltol, integrator.opts.internalnorm, t
193+
)
194+
integrator.EEst = integrator.opts.internalnorm(atmp, t)
195+
end
196+
197+
if integrator.f isa SplitFunction
198+
integrator.k[1] = integrator.fsalfirst
199+
integrator.fsallast = integrator.f(u, p, t + dt)
200+
integrator.k[2] = integrator.fsallast
201+
else
202+
integrator.fsallast = z[s] ./ dt
203+
integrator.k[1] = integrator.fsalfirst
204+
integrator.k[2] = integrator.fsallast
205+
end
206+
integrator.u = u
207+
end
208+
209+
@muladd function perform_step!(integrator, cache::IMEXCache, repeat_step = false)
210+
(; t, dt, uprev, u, p) = integrator
211+
(; zs, ks, atmp, nlsolver, step_limiter!) = cache
212+
(; tmp) = nlsolver
213+
tab = cache.tab
214+
(; Ai, bi, Ae, be, c, btilde, ebtilde, α, s) = tab
215+
alg = unwrap_alg(integrator, true)
216+
γ = Ai[2, 2]
217+
218+
f2 = nothing
219+
if integrator.f isa SplitFunction
220+
f_impl = integrator.f.f1
221+
f2 = integrator.f.f2
222+
else
223+
f_impl = integrator.f
224+
end
225+
226+
markfirststage!(nlsolver)
227+
228+
if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail
229+
f_impl(zs[1], integrator.uprev, p, integrator.t)
230+
zs[1] .*= dt
231+
else
232+
@.. broadcast = false zs[1] = dt * integrator.fsalfirst
233+
end
234+
235+
if integrator.f isa SplitFunction
236+
@.. broadcast = false ks[1] = dt * integrator.fsalfirst - zs[1]
237+
end
238+
239+
for i in 2:s
240+
@.. broadcast = false tmp = uprev
241+
for j in 1:(i - 1)
242+
@.. broadcast = false tmp += Ai[i, j] * zs[j]
243+
end
244+
245+
if integrator.f isa SplitFunction
246+
for j in 1:(i - 1)
247+
@.. broadcast = false tmp += Ae[i, j] * ks[j]
248+
end
249+
end
250+
251+
if integrator.f isa SplitFunction
252+
copyto!(zs[i], zs[1])
253+
elseif α !== nothing && !iszero(α[i, 1])
254+
fill!(zs[i], zero(eltype(u)))
255+
for j in 1:(i - 1)
256+
@.. broadcast = false zs[i] += α[i, j] * zs[j]
257+
end
258+
else
259+
fill!(zs[i], zero(eltype(u)))
260+
end
261+
262+
nlsolver.z = zs[i]
263+
nlsolver.c = c[i]
264+
nlsolver.γ = γ
265+
zs[i] = nlsolve!(nlsolver, integrator, cache, repeat_step)
266+
nlsolvefail(nlsolver) && return
267+
if i > 2
268+
isnewton(nlsolver) && set_new_W!(nlsolver, false)
269+
end
270+
271+
if integrator.f isa SplitFunction && i < s
272+
@.. broadcast = false u = tmp + γ * zs[i]
273+
f2(ks[i], u, p, t + c[i] * dt)
274+
ks[i] .*= dt
275+
integrator.stats.nf2 += 1
276+
end
277+
end
278+
279+
@.. broadcast = false u = tmp + γ * zs[s]
280+
if integrator.f isa SplitFunction
281+
f2(ks[s], u, p, t + dt)
282+
ks[s] .*= dt
283+
integrator.stats.nf2 += 1
284+
@.. broadcast = false u = uprev
285+
for i in 1:s
286+
@.. broadcast = false u += bi[i] * zs[i] + be[i] * ks[i]
287+
end
288+
end
289+
290+
step_limiter!(u, integrator, p, t + dt)
291+
292+
if integrator.opts.adaptive
293+
@.. broadcast = false tmp = zero(eltype(u))
294+
for i in 1:s
295+
@.. broadcast = false tmp += btilde[i] * zs[i]
296+
end
297+
if integrator.f isa SplitFunction && ebtilde !== nothing
298+
for i in 1:s
299+
@.. broadcast = false tmp += ebtilde[i] * ks[i]
300+
end
301+
end
302+
if isnewton(nlsolver) && alg.smooth_est
303+
est = nlsolver.cache.dz
304+
linres = dolinsolve(
305+
integrator, nlsolver.cache.linsolve; b = _vec(tmp),
306+
linu = _vec(est)
307+
)
308+
integrator.stats.nsolve += 1
309+
else
310+
est = tmp
311+
end
312+
calculate_residuals!(
313+
atmp, est, uprev, u, integrator.opts.abstol,
314+
integrator.opts.reltol, integrator.opts.internalnorm, t
315+
)
316+
integrator.EEst = integrator.opts.internalnorm(atmp, t)
317+
end
318+
319+
if integrator.f isa SplitFunction
320+
integrator.f(integrator.fsallast, u, p, t + dt)
321+
else
322+
@.. broadcast = false integrator.fsallast = zs[s] / dt
323+
end
324+
end

0 commit comments

Comments
 (0)