From d8cc2731388cc9b23cf9fc9f03194d8934ca428d Mon Sep 17 00:00:00 2001 From: Harsh Singh Date: Sat, 4 Apr 2026 22:56:58 +0530 Subject: [PATCH 1/3] Add initial IMEX Runge-Kutta solvers (SSP and BHR schemes) --- .../src/OrdinaryDiffEqSDIRK.jl | 3 +- lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl | 9 + lib/OrdinaryDiffEqSDIRK/src/algorithms.jl | 180 +++++ .../src/kencarp_kvaerno_caches.jl | 258 +++++++ .../src/kencarp_kvaerno_perform_step.jl | 639 ++++++++++++++++++ lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl | 181 +++++ 6 files changed, 1269 insertions(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl index 37eb64be2c2..f4226a5ea5d 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl @@ -43,7 +43,8 @@ export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22, Kvaerno5, KenCarp4, KenCarp47, KenCarp5, KenCarp58, ESDIRK54I8L2SA, SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, Kvaerno5, KenCarp4, KenCarp5, SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6, - SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA + SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA, + IMEXSSP222, IMEXSSP2322, IMEXSSP3332, IMEXSSP3433 import PrecompileTools import Preferences diff --git a/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl b/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl index c6903a76dee..f917c45add8 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl @@ -57,3 +57,12 @@ issplit(alg::KenCarp47) = true issplit(alg::KenCarp5) = true issplit(alg::KenCarp58) = true issplit(alg::CFNLIRK3) = true +issplit(alg::IMEXSSP222) = true +issplit(alg::IMEXSSP2322) = true +issplit(alg::IMEXSSP3332) = true +issplit(alg::IMEXSSP3433) = true + +alg_order(alg::IMEXSSP222) = 2 +alg_order(alg::IMEXSSP2322) = 2 +alg_order(alg::IMEXSSP3332) = 2 
+alg_order(alg::IMEXSSP3433) = 3 diff --git a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl index 56d80a25fb2..b6179675ca9 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl @@ -1593,3 +1593,183 @@ function ESDIRK659L2SA(; controller, AD_choice ) end + +@doc SDIRK_docstring( + "2-stage 2nd-order L-stable SSP IMEX-SDIRK method for split (implicit+explicit) ODEs. From Pareschi & Russo (2005), Table 2.", + "IMEXSSP222"; + references = "@article{pareschi2005implicit, + title={Implicit-explicit Runge-Kutta schemes and applications to hyperbolic systems with relaxation}, + author={Pareschi, Lorenzo and Russo, Giovanni}, + journal={Journal of Scientific Computing}, + volume={25}, + number={1}, + pages={129--155}, + year={2005}, + publisher={Springer}}", + extra_keyword_description = """ + - `extrapolant`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + extrapolant = :linear, + step_limiter! = trivial_limiter!, + """ +) +struct IMEXSSP222{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + extrapolant::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function IMEXSSP222(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + extrapolant = :linear, step_limiter! = trivial_limiter! 
+ ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + return IMEXSSP222{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) +end + +@doc SDIRK_docstring( + "3-stage 2nd-order stiffly-accurate SSP IMEX-SDIRK method for split ODEs. From Pareschi & Russo (2005), Table 3.", + "IMEXSSP2322"; + references = "@article{pareschi2005implicit, + title={Implicit-explicit Runge-Kutta schemes and applications to hyperbolic systems with relaxation}, + author={Pareschi, Lorenzo and Russo, Giovanni}, + journal={Journal of Scientific Computing}, + volume={25}, + number={1}, + pages={129--155}, + year={2005}, + publisher={Springer}}", + extra_keyword_description = """ + - `extrapolant`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + extrapolant = :linear, + step_limiter! = trivial_limiter!, + """ +) +struct IMEXSSP2322{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + extrapolant::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function IMEXSSP2322(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + extrapolant = :linear, step_limiter! = trivial_limiter! 
+ ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + return IMEXSSP2322{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) +end + +@doc SDIRK_docstring( + "3-stage 2nd-order L-stable SSP IMEX-SDIRK method (3rd order SSP explicit part) for split ODEs. From Pareschi & Russo (2005), Table 6.", + "IMEXSSP3332"; + references = "@article{pareschi2005implicit, + title={Implicit-explicit Runge-Kutta schemes and applications to hyperbolic systems with relaxation}, + author={Pareschi, Lorenzo and Russo, Giovanni}, + journal={Journal of Scientific Computing}, + volume={25}, + number={1}, + pages={129--155}, + year={2005}, + publisher={Springer}}", + extra_keyword_description = """ + - `extrapolant`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + extrapolant = :linear, + step_limiter! = trivial_limiter!, + """ +) +struct IMEXSSP3332{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + extrapolant::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function IMEXSSP3332(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + extrapolant = :linear, step_limiter! = trivial_limiter! 
+ ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + return IMEXSSP3332{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) +end + +@doc SDIRK_docstring( + "4-stage 3rd-order L-stable SSP IMEX-SDIRK method for split ODEs. From Pareschi & Russo (2005), Table 7.", + "IMEXSSP3433"; + references = "@article{pareschi2005implicit, + title={Implicit-explicit Runge-Kutta schemes and applications to hyperbolic systems with relaxation}, + author={Pareschi, Lorenzo and Russo, Giovanni}, + journal={Journal of Scientific Computing}, + volume={25}, + number={1}, + pages={129--155}, + year={2005}, + publisher={Springer}}", + extra_keyword_description = """ + - `extrapolant`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + extrapolant = :linear, + step_limiter! = trivial_limiter!, + """ +) +struct IMEXSSP3433{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + extrapolant::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function IMEXSSP3433(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + extrapolant = :linear, step_limiter! = trivial_limiter! 
+ ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + return IMEXSSP3433{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) +end diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl index 5ac57d8ac10..5d9059cc597 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl @@ -728,3 +728,261 @@ function alg_cache( k1, k2, k3, k4, k5, k6, k7, k8, atmp, nlsolver, tab ) end + +# ---- IMEXSSP222 ---- + +@cache mutable struct IMEXSSP222ConstantCache{N, Tab} <: SDIRKConstantCache + nlsolver::N + tab::Tab +end + +function alg_cache( + alg::IMEXSSP222, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = IMEXSSP222Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c2 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return IMEXSSP222ConstantCache(nlsolver, tab) +end + +@cache mutable struct IMEXSSP222Cache{uType, rateType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + z₁::uType + z₂::uType + k1::kType + k2::kType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function alg_cache( + alg::IMEXSSP222, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, 
uBottomEltypeNoUnits, tTypeNoUnits} + tab = IMEXSSP222Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c2 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = zero(rate_prototype) + z₁ = zero(u) + z₂ = nlsolver.z + if f isa SplitFunction + k1 = zero(u) + k2 = zero(u) + else + k1 = nothing + k2 = nothing + end + return IMEXSSP222Cache(u, uprev, fsalfirst, z₁, z₂, k1, k2, nlsolver, tab, alg.step_limiter!) +end + +# ---- IMEXSSP2322 ---- + +@cache mutable struct IMEXSSP2322ConstantCache{N, Tab} <: SDIRKConstantCache + nlsolver::N + tab::Tab +end + +function alg_cache( + alg::IMEXSSP2322, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = IMEXSSP2322Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c3 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return IMEXSSP2322ConstantCache(nlsolver, tab) +end + +@cache mutable struct IMEXSSP2322Cache{uType, rateType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + z₁::uType + z₂::uType + z₃::uType + k2::kType + k3::kType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function alg_cache( + alg::IMEXSSP2322, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = IMEXSSP2322Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c3 + nlsolver = build_nlsolver( + alg, u, 
uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = zero(rate_prototype) + z₁ = zero(u) + z₂ = zero(u) + z₃ = nlsolver.z + if f isa SplitFunction + k2 = zero(u) + k3 = zero(u) + else + k2 = nothing + k3 = nothing + end + return IMEXSSP2322Cache(u, uprev, fsalfirst, z₁, z₂, z₃, k2, k3, nlsolver, tab, alg.step_limiter!) +end + +# ---- IMEXSSP3332 ---- + +@cache mutable struct IMEXSSP3332ConstantCache{N, Tab} <: SDIRKConstantCache + nlsolver::N + tab::Tab +end + +function alg_cache( + alg::IMEXSSP3332, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = IMEXSSP3332Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c3 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return IMEXSSP3332ConstantCache(nlsolver, tab) +end + +@cache mutable struct IMEXSSP3332Cache{uType, rateType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + z₁::uType + z₂::uType + z₃::uType + k1::kType + k2::kType + k3::kType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function alg_cache( + alg::IMEXSSP3332, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = IMEXSSP3332Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c3 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = 
zero(rate_prototype) + z₁ = zero(u) + z₂ = zero(u) + z₃ = nlsolver.z + if f isa SplitFunction + k1 = zero(u) + k2 = zero(u) + k3 = zero(u) + else + k1 = nothing + k2 = nothing + k3 = nothing + end + return IMEXSSP3332Cache(u, uprev, fsalfirst, z₁, z₂, z₃, k1, k2, k3, nlsolver, tab, alg.step_limiter!) +end + +# ---- IMEXSSP3433 ---- + +@cache mutable struct IMEXSSP3433ConstantCache{N, Tab} <: SDIRKConstantCache + nlsolver::N + tab::Tab +end + +function alg_cache( + alg::IMEXSSP3433, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = IMEXSSP3433Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c4 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return IMEXSSP3433ConstantCache(nlsolver, tab) +end + +@cache mutable struct IMEXSSP3433Cache{uType, rateType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + z₁::uType + z₂::uType + z₃::uType + z₄::uType + k2::kType + k3::kType + k4::kType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function alg_cache( + alg::IMEXSSP3433, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = IMEXSSP3433Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c4 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = zero(rate_prototype) + z₁ = zero(u) + z₂ = zero(u) + z₃ = zero(u) + z₄ = nlsolver.z + if f isa 
SplitFunction + k2 = zero(u) + k3 = zero(u) + k4 = zero(u) + else + k2 = nothing + k3 = nothing + k4 = nothing + end + return IMEXSSP3433Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k2, k3, k4, nlsolver, tab, alg.step_limiter!) +end diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl index c4a8110d3e7..1332287abb0 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl @@ -2734,3 +2734,642 @@ end @.. broadcast = false integrator.fsallast = z₈ / dt end end + +# =========================================================================== +# IMEX-SSP methods (Pareschi & Russo 2005) +# =========================================================================== + +# --------------------------------------------------------------------------- +# IMEXSSP222 — 2-stage, 2nd order, L-stable (Table 2) +# --------------------------------------------------------------------------- + +@muladd function perform_step!( + integrator, cache::IMEXSSP222ConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + (; γ, a21, c2, ea21, eb1, eb2) = cache.tab + + f2 = nothing + k1 = nothing + k2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + ##### Stage 1 + nlsolver.z = zero(u) + nlsolver.tmp = uprev + nlsolver.c = γ + z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k1 = dt * f2(nlsolver.tmp + γ * z₁, p, t) + integrator.stats.nf2 += 1 + end + + ##### Stage 2 + nlsolver.z = z₁ + if integrator.f isa SplitFunction + nlsolver.tmp = uprev + a21 * z₁ + ea21 * k1 + else + nlsolver.tmp = uprev + a21 * z₁ + end + nlsolver.c = c2 + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + 
nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k2 = dt * f2(nlsolver.tmp + γ * z₂, p, t + dt) + integrator.stats.nf2 += 1 + u = uprev + eb1 * (z₁ + k1) + eb2 * (z₂ + k2) + else + u = uprev + eb1 * z₁ + eb2 * z₂ + end + + ################################### Finalize + if integrator.f isa SplitFunction + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z₂ / dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!( + integrator, cache::IMEXSSP222Cache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + (; z₁, z₂, k1, k2, nlsolver, step_limiter!) = cache + (; tmp) = nlsolver + (; γ, a21, c2, ea21, eb1, eb2) = cache.tab + + f2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + ##### Stage 1 + z₁ .= zero(eltype(u)) + nlsolver.z = z₁ + @.. broadcast=false tmp = uprev + nlsolver.c = γ + z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + isnewton(nlsolver) && set_new_W!(nlsolver, false) + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₁ + f2(k1, u, p, t) + k1 .*= dt + integrator.stats.nf2 += 1 + end + + ##### Stage 2 + copyto!(z₂, z₁) + nlsolver.z = z₂ + if integrator.f isa SplitFunction + @.. broadcast=false tmp = uprev + a21 * z₁ + ea21 * k1 + else + @.. broadcast=false tmp = uprev + a21 * z₁ + end + nlsolver.c = c2 + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₂ + f2(k2, u, p, t + dt) + k2 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast=false u = uprev + eb1 * (z₁ + k1) + eb2 * (z₂ + k2) + else + @.. 
broadcast=false u = uprev + eb1 * z₁ + eb2 * z₂ + end + + step_limiter!(u, integrator, p, t + dt) + + ################################### Finalize + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. broadcast=false integrator.fsallast = z₂ / dt + end +end + +# --------------------------------------------------------------------------- +# IMEXSSP2322 — 3-stage, 2nd order, stiffly accurate (Table 3) +# --------------------------------------------------------------------------- + +@muladd function perform_step!( + integrator, cache::IMEXSSP2322ConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + (; γ, a21, c2, c3, ea32, eb2, eb3) = cache.tab + + f2 = nothing + k2 = nothing + k3 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + ##### Stage 1 (k₁ never needed: b̃₁=0 and ã_{i1}=0 for all i) + nlsolver.z = zero(u) + nlsolver.tmp = uprev + nlsolver.c = γ + z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + ##### Stage 2 + nlsolver.z = z₁ + nlsolver.tmp = uprev + a21 * z₁ + nlsolver.c = c2 + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k2 = dt * f2(nlsolver.tmp + γ * z₂, p, t) + integrator.stats.nf2 += 1 + end + + ##### Stage 3 + nlsolver.z = z₂ + if integrator.f isa SplitFunction + nlsolver.tmp = uprev + γ * z₂ + ea32 * k2 + else + nlsolver.tmp = uprev + γ * z₂ + end + nlsolver.c = c3 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k3 = dt * f2(nlsolver.tmp + γ * z₃, p, t + dt) + integrator.stats.nf2 += 1 + u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + else + u = nlsolver.tmp + γ * z₃ + end + + ################################### Finalize + if integrator.f 
isa SplitFunction + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z₃ / dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!( + integrator, cache::IMEXSSP2322Cache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + (; z₁, z₂, z₃, k2, k3, nlsolver, step_limiter!) = cache + (; tmp) = nlsolver + (; γ, a21, c2, c3, ea32, eb2, eb3) = cache.tab + + f2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + ##### Stage 1 (k₁ never needed) + z₁ .= zero(eltype(u)) + nlsolver.z = z₁ + @.. broadcast=false tmp = uprev + nlsolver.c = γ + z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + isnewton(nlsolver) && set_new_W!(nlsolver, false) + + ##### Stage 2 + copyto!(z₂, z₁) + nlsolver.z = z₂ + @.. broadcast=false tmp = uprev + a21 * z₁ + nlsolver.c = c2 + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₂ + f2(k2, u, p, t) + k2 .*= dt + integrator.stats.nf2 += 1 + end + + ##### Stage 3 + copyto!(z₃, z₂) + nlsolver.z = z₃ + if integrator.f isa SplitFunction + @.. broadcast=false tmp = uprev + γ * z₂ + ea32 * k2 + else + @.. broadcast=false tmp = uprev + γ * z₂ + end + nlsolver.c = c3 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₃ + f2(k3, u, p, t + dt) + k3 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast=false u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + else + @.. 
broadcast=false u = tmp + γ * z₃ + end + + step_limiter!(u, integrator, p, t + dt) + + ################################### Finalize + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. broadcast=false integrator.fsallast = z₃ / dt + end +end + +# --------------------------------------------------------------------------- +# IMEXSSP3332 — 3-stage, 2nd order IMEX, L-stable (Table 6) +# --------------------------------------------------------------------------- + +@muladd function perform_step!( + integrator, cache::IMEXSSP3332ConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + (; γ, a21, a31, c2, c3, ea21, ea31, ea32, eb1, eb2, eb3) = cache.tab + + f2 = nothing + k1 = nothing + k2 = nothing + k3 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + ##### Stage 1 + nlsolver.z = zero(u) + nlsolver.tmp = uprev + nlsolver.c = γ + z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k1 = dt * f2(nlsolver.tmp + γ * z₁, p, t) + integrator.stats.nf2 += 1 + end + + ##### Stage 2 + nlsolver.z = z₁ + if integrator.f isa SplitFunction + nlsolver.tmp = uprev + a21 * z₁ + ea21 * k1 + else + nlsolver.tmp = uprev + a21 * z₁ + end + nlsolver.c = c2 + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k2 = dt * f2(nlsolver.tmp + γ * z₂, p, t + dt) + integrator.stats.nf2 += 1 + end + + ##### Stage 3 + nlsolver.z = z₂ + if integrator.f isa SplitFunction + nlsolver.tmp = uprev + a31 * z₁ + ea31 * k1 + ea32 * k2 + else + nlsolver.tmp = uprev + a31 * z₁ + end + nlsolver.c = c3 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k3 = dt * f2(nlsolver.tmp + γ 
* z₃, p, t + c3 * dt) + integrator.stats.nf2 += 1 + u = uprev + eb1 * (z₁ + k1) + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + else + u = uprev + eb1 * z₁ + eb2 * z₂ + eb3 * z₃ + end + + ################################### Finalize + if integrator.f isa SplitFunction + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z₃ / dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!( + integrator, cache::IMEXSSP3332Cache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + (; z₁, z₂, z₃, k1, k2, k3, nlsolver, step_limiter!) = cache + (; tmp) = nlsolver + (; γ, a21, a31, c2, c3, ea21, ea31, ea32, eb1, eb2, eb3) = cache.tab + + f2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + ##### Stage 1 + z₁ .= zero(eltype(u)) + nlsolver.z = z₁ + @.. broadcast=false tmp = uprev + nlsolver.c = γ + z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + isnewton(nlsolver) && set_new_W!(nlsolver, false) + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₁ + f2(k1, u, p, t) + k1 .*= dt + integrator.stats.nf2 += 1 + end + + ##### Stage 2 + copyto!(z₂, z₁) + nlsolver.z = z₂ + if integrator.f isa SplitFunction + @.. broadcast=false tmp = uprev + a21 * z₁ + ea21 * k1 + else + @.. broadcast=false tmp = uprev + a21 * z₁ + end + nlsolver.c = c2 + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₂ + f2(k2, u, p, t + dt) + k2 .*= dt + integrator.stats.nf2 += 1 + end + + ##### Stage 3 + copyto!(z₃, z₂) + nlsolver.z = z₃ + if integrator.f isa SplitFunction + @.. 
broadcast=false tmp = uprev + a31 * z₁ + ea31 * k1 + ea32 * k2 + else + @.. broadcast=false tmp = uprev + a31 * z₁ + end + nlsolver.c = c3 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₃ + f2(k3, u, p, t + c3 * dt) + k3 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast=false u = uprev + eb1 * (z₁ + k1) + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + else + @.. broadcast=false u = uprev + eb1 * z₁ + eb2 * z₂ + eb3 * z₃ + end + + step_limiter!(u, integrator, p, t + dt) + + ################################### Finalize + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. broadcast=false integrator.fsallast = z₃ / dt + end +end + +# --------------------------------------------------------------------------- +# IMEXSSP3433 — 4-stage, 3rd order, L-stable (Table 7) +# --------------------------------------------------------------------------- + +@muladd function perform_step!( + integrator, cache::IMEXSSP3433ConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + (; γ, a21, a32, a41, a42, a43, c3, c4, ea32, ea42, ea43, eb2, eb3, eb4) = cache.tab + + f2 = nothing + k2 = nothing + k3 = nothing + k4 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + ##### Stage 1 (k₁ never needed: b̃₁=0 and ã_{i1}=0 for all i) + nlsolver.z = zero(u) + nlsolver.tmp = uprev + nlsolver.c = γ + z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + ##### Stage 2 + nlsolver.z = z₁ + nlsolver.tmp = uprev + a21 * z₁ + nlsolver.c = zero(γ) + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k2 = dt * f2(nlsolver.tmp + γ * z₂, p, t) + integrator.stats.nf2 += 1 + end + + ##### 
Stage 3 + nlsolver.z = z₂ + if integrator.f isa SplitFunction + nlsolver.tmp = uprev + a32 * z₂ + ea32 * k2 + else + nlsolver.tmp = uprev + a32 * z₂ + end + nlsolver.c = c3 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k3 = dt * f2(nlsolver.tmp + γ * z₃, p, t + dt) + integrator.stats.nf2 += 1 + end + + ##### Stage 4 + nlsolver.z = z₃ + if integrator.f isa SplitFunction + nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + ea42 * k2 + ea43 * k3 + else + nlsolver.tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + end + nlsolver.c = c4 + z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + k4 = dt * f2(nlsolver.tmp + γ * z₄, p, t + c4 * dt) + integrator.stats.nf2 += 1 + u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + eb4 * (z₄ + k4) + else + u = uprev + eb2 * z₂ + eb3 * z₃ + eb4 * z₄ + end + + ################################### Finalize + if integrator.f isa SplitFunction + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z₄ / dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!( + integrator, cache::IMEXSSP3433Cache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + (; z₁, z₂, z₃, z₄, k2, k3, k4, nlsolver, step_limiter!) = cache + (; tmp) = nlsolver + (; γ, a21, a32, a41, a42, a43, c3, c4, ea32, ea42, ea43, eb2, eb3, eb4) = cache.tab + + f2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + ##### Stage 1 (k₁ never needed) + z₁ .= zero(eltype(u)) + nlsolver.z = z₁ + @.. 
broadcast=false tmp = uprev + nlsolver.c = γ + z₁ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + isnewton(nlsolver) && set_new_W!(nlsolver, false) + + ##### Stage 2 + copyto!(z₂, z₁) + nlsolver.z = z₂ + @.. broadcast=false tmp = uprev + a21 * z₁ + nlsolver.c = zero(γ) + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₂ + f2(k2, u, p, t) + k2 .*= dt + integrator.stats.nf2 += 1 + end + + ##### Stage 3 + copyto!(z₃, z₂) + nlsolver.z = z₃ + if integrator.f isa SplitFunction + @.. broadcast=false tmp = uprev + a32 * z₂ + ea32 * k2 + else + @.. broadcast=false tmp = uprev + a32 * z₂ + end + nlsolver.c = c3 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₃ + f2(k3, u, p, t + dt) + k3 .*= dt + integrator.stats.nf2 += 1 + end + + ##### Stage 4 + copyto!(z₄, z₃) + nlsolver.z = z₄ + if integrator.f isa SplitFunction + @.. broadcast=false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + + ea42 * k2 + ea43 * k3 + else + @.. broadcast=false tmp = uprev + a41 * z₁ + a42 * z₂ + a43 * z₃ + end + nlsolver.c = c4 + z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + if integrator.f isa SplitFunction + @.. broadcast=false u = tmp + γ * z₄ + f2(k4, u, p, t + c4 * dt) + k4 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast=false u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + eb4 * (z₄ + k4) + else + @.. broadcast=false u = uprev + eb2 * z₂ + eb3 * z₃ + eb4 * z₄ + end + + step_limiter!(u, integrator, p, t + dt) + + ################################### Finalize + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. 
broadcast=false integrator.fsallast = z₄ / dt + end +end diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl index ff7391848fa..31dac2087e8 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl @@ -2828,3 +2828,184 @@ function KenCarp58Tableau(T, T2) ebtilde4, ebtilde5, ebtilde6, ebtilde7, ebtilde8 ) end + +#= +IMEX-SSP2(2,2,2) — Pareschi & Russo (2005), Table 2 +2-stage, 2nd order, L-stable, SSP + +Implicit tableau (γ = 1 - 1/√2): + c = [γ, 1-γ] + A = [γ 0 ] + [1-2γ γ ] + b = [1/2, 1/2] + +Explicit tableau: + c̃ = [0, 1] + à = [0 0] + [1 0] + b̃ = [1/2, 1/2] +=# +struct IMEXSSP222Tableau{T, T2} + γ::T2 + a21::T + c2::T2 + ea21::T + eb1::T + eb2::T +end + +function IMEXSSP222Tableau(T, T2) + γ = convert(T2, 1 - 1 / sqrt(2)) + a21 = convert(T, sqrt(2) - 1) # = 1 - 2γ + c2 = convert(T2, 1 / sqrt(2)) # = 1 - γ + ea21 = convert(T, 1) + eb1 = convert(T, 1 // 2) + eb2 = convert(T, 1 // 2) + return IMEXSSP222Tableau(γ, a21, c2, ea21, eb1, eb2) +end + +#= +IMEX-SSP2(3,2,2) — Pareschi & Russo (2005), Table 3 +3-stage, 2nd order, stiffly accurate, SSP + +Implicit tableau (γ = 1/2): + c = [1/2, 0, 1] + A = [ γ 0 0] + [-γ γ 0] + [ 0 γ γ] + b = [0, 1/2, 1/2] (= last row, stiffly accurate) + +Explicit tableau: + c̃ = [0, 0, 1] + à = [0 0 0] + [0 0 0] + [0 1 0] + b̃ = [0, 1/2, 1/2] +=# +struct IMEXSSP2322Tableau{T, T2} + γ::T2 + a21::T + c2::T2 + c3::T2 + ea32::T + eb2::T + eb3::T +end + +function IMEXSSP2322Tableau(T, T2) + γ = convert(T2, 1 // 2) + a21 = convert(T, -1 // 2) + c2 = convert(T2, 0) + c3 = convert(T2, 1) + ea32 = convert(T, 1) + eb2 = convert(T, 1 // 2) + eb3 = convert(T, 1 // 2) + return IMEXSSP2322Tableau(γ, a21, c2, c3, ea32, eb2, eb3) +end + +#= +IMEX-SSP3(3,3,2) — Pareschi & Russo (2005), Table 6 +3-stage, 2nd order IMEX (3rd order SSP explicit), L-stable + +Implicit tableau (γ = 1 - 1/√2): + c = [γ, 1-γ, 1/2] + A = [γ 0 0 ] + [1-2γ γ 0 ] + [1/2-γ 0 
γ ] + b = [1/6, 1/6, 2/3] + +Explicit tableau: + c̃ = [0, 1, 1/2] + à = [0 0 0] + [1 0 0] + [1/4 1/4 0] + b̃ = [1/6, 1/6, 2/3] +=# +struct IMEXSSP3332Tableau{T, T2} + γ::T2 + a21::T + a31::T + c2::T2 + c3::T2 + ea21::T + ea31::T + ea32::T + eb1::T + eb2::T + eb3::T +end + +function IMEXSSP3332Tableau(T, T2) + γ = convert(T2, 1 - 1 / sqrt(2)) + a21 = convert(T, sqrt(2) - 1) # = 1 - 2γ + a31 = convert(T, 1 / sqrt(2) - 1 // 2) # = 1/2 - γ + c2 = convert(T2, 1 / sqrt(2)) # = 1 - γ + c3 = convert(T2, 1 // 2) + ea21 = convert(T, 1) + ea31 = convert(T, 1 // 4) + ea32 = convert(T, 1 // 4) + eb1 = convert(T, 1 // 6) + eb2 = convert(T, 1 // 6) + eb3 = convert(T, 2 // 3) + return IMEXSSP3332Tableau(γ, a21, a31, c2, c3, ea21, ea31, ea32, eb1, eb2, eb3) +end + +#= +IMEX-SSP3(4,3,3) — Pareschi & Russo (2005), Table 7 +4-stage, 3rd order, L-stable, SSP +α=0.24169426078821, β=0.06042356519705, η=0.12915286960590, γ=α + +Implicit tableau: + c = [α, 0, 1, 1/2] + A = [ α 0 0 0 ] + [-α α 0 0 ] + [ 0 1-α α 0 ] + [ β η 1/2-β-η-α α ] + b = [0, 1/6, 1/6, 2/3] + +Explicit tableau: + c̃ = [0, 0, 1, 1/2] + à = [0 0 0 0] + [0 0 0 0] + [0 1 0 0] + [0 1/4 1/4 0] + b̃ = [0, 1/6, 1/6, 2/3] + (k₁ never used: b̃₁=0 and ã_{i1}=0 for all i) +=# +struct IMEXSSP3433Tableau{T, T2} + γ::T2 + a21::T + a32::T + a41::T + a42::T + a43::T + c3::T2 + c4::T2 + ea32::T + ea42::T + ea43::T + eb2::T + eb3::T + eb4::T +end + +function IMEXSSP3433Tableau(T, T2) + α = 0.24169426078821 + β = 0.06042356519705 + η = 0.12915286960590 + γ = convert(T2, α) + a21 = convert(T, -α) + a32 = convert(T, 1 - α) + a41 = convert(T, β) + a42 = convert(T, η) + a43 = convert(T, 1 // 2 - β - η - α) + c3 = convert(T2, 1) + c4 = convert(T2, 1 // 2) + ea32 = convert(T, 1) + ea42 = convert(T, 1 // 4) + ea43 = convert(T, 1 // 4) + eb2 = convert(T, 1 // 6) + eb3 = convert(T, 1 // 6) + eb4 = convert(T, 2 // 3) + return IMEXSSP3433Tableau(γ, a21, a32, a41, a42, a43, c3, c4, ea32, ea42, ea43, eb2, eb3, eb4) +end From 
bebc56315b0b8627177be4709529dc0a8844269a Mon Sep 17 00:00:00 2001 From: Harsh Singh Date: Sun, 5 Apr 2026 10:18:56 +0530 Subject: [PATCH 2/3] Add IMEX Runge-Kutta solvers (ARS, SSP, BHR) with ESDIRK implicit structure and proper cache handling --- .../src/OrdinaryDiffEqSDIRK.jl | 3 +- lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl | 15 + lib/OrdinaryDiffEqSDIRK/src/algorithms.jl | 166 ++++ .../src/kencarp_kvaerno_caches.jl | 281 +++++++ .../src/kencarp_kvaerno_perform_step.jl | 738 ++++++++++++++++++ lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl | 250 ++++++ 6 files changed, 1452 insertions(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl index f4226a5ea5d..5d021c996b6 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl @@ -44,7 +44,8 @@ export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22, SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, Kvaerno5, KenCarp4, KenCarp5, SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA, - IMEXSSP222, IMEXSSP2322, IMEXSSP3332, IMEXSSP3433 + IMEXSSP222, IMEXSSP2322, IMEXSSP3332, IMEXSSP3433, + ARS222, ARS232, ARS443, BHR553 import PrecompileTools import Preferences diff --git a/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl b/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl index f917c45add8..4b7559cea80 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl @@ -51,6 +51,11 @@ ssp_coefficient(alg::SSPSDIRK2) = 4 isesdirk(alg::TRBDF2) = true +isesdirk(alg::ARS222) = true +isesdirk(alg::ARS232) = true +isesdirk(alg::ARS443) = true +isesdirk(alg::BHR553) = true + issplit(alg::KenCarp3) = true issplit(alg::KenCarp4) = true issplit(alg::KenCarp47) = true @@ -62,7 +67,17 @@ issplit(alg::IMEXSSP2322) = true issplit(alg::IMEXSSP3332) = true issplit(alg::IMEXSSP3433) = true 
+issplit(alg::ARS222) = true +issplit(alg::ARS232) = true +issplit(alg::ARS443) = true +issplit(alg::BHR553) = true + alg_order(alg::IMEXSSP222) = 2 alg_order(alg::IMEXSSP2322) = 2 alg_order(alg::IMEXSSP3332) = 2 alg_order(alg::IMEXSSP3433) = 3 + +alg_order(alg::ARS222) = 2 +alg_order(alg::ARS232) = 2 +alg_order(alg::ARS443) = 3 +alg_order(alg::BHR553) = 3 diff --git a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl index b6179675ca9..ef62010473a 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl @@ -1773,3 +1773,169 @@ function IMEXSSP3433(; _unwrap_val(concrete_jac), typeof(step_limiter!), }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) end + +const _ARS_BHR_REFERENCE = "@article{ascher1997implicit, + title={Implicit-explicit Runge-Kutta methods for time-dependent partial differential equations}, + author={Ascher, Uri M and Ruuth, Steven J and Spiteri, Raymond J}, + journal={Applied Numerical Mathematics}, + volume={25}, + number={2-3}, + pages={151--167}, + year={1997}, + publisher={Elsevier}}" + +@doc SDIRK_docstring( + "3-stage 2nd-order IMEX ESDIRK method (ARS(2,2,2)) for split ODEs. From Ascher, Ruuth & Spiteri (1997), Table II.", + "ARS222"; + references = _ARS_BHR_REFERENCE, + extra_keyword_description = """ + - `extrapolant`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + extrapolant = :linear, + step_limiter! 
= trivial_limiter!, + """ +) +struct ARS222{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + extrapolant::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function ARS222(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + extrapolant = :linear, step_limiter! = trivial_limiter! + ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + return ARS222{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) +end + +@doc SDIRK_docstring( + "3-stage 2nd-order IMEX ESDIRK method (ARS(2,3,2)) for split ODEs. From Ascher, Ruuth & Spiteri (1997).", + "ARS232"; + references = _ARS_BHR_REFERENCE, + extra_keyword_description = """ + - `extrapolant`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + extrapolant = :linear, + step_limiter! = trivial_limiter!, + """ +) +struct ARS232{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + extrapolant::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function ARS232(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + extrapolant = :linear, step_limiter! = trivial_limiter! 
+ ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + return ARS232{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) +end + +@doc SDIRK_docstring( + "5-stage 3rd-order IMEX ESDIRK method (ARS(4,4,3)) for split ODEs. From Ascher, Ruuth & Spiteri (1997), Table IV.", + "ARS443"; + references = _ARS_BHR_REFERENCE, + extra_keyword_description = """ + - `extrapolant`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + extrapolant = :linear, + step_limiter! = trivial_limiter!, + """ +) +struct ARS443{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + extrapolant::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function ARS443(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + extrapolant = :linear, step_limiter! = trivial_limiter! + ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + return ARS443{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) +end + +@doc SDIRK_docstring( + "5-stage 3rd-order IMEX ESDIRK method (BHR(5,5,3)*) for split ODEs. 
From Boscarino & Russo (2009).", + "BHR553"; + references = "@article{boscarino2009class, + title={On a class of uniformly accurate IMEX Runge-Kutta schemes and applications to hyperbolic systems with relaxation}, + author={Boscarino, Sebastiano and Russo, Giovanni}, + journal={SIAM Journal on Scientific Computing}, + volume={31}, + number={3}, + pages={1926--1945}, + year={2009}, + publisher={SIAM}}", + extra_keyword_description = """ + - `extrapolant`: TBD + - `step_limiter!`: function of the form `limiter!(u, integrator, p, t)` + """, + extra_keyword_default = """ + extrapolant = :linear, + step_limiter! = trivial_limiter!, + """ +) +struct BHR553{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: + OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + linsolve::F + nlsolve::F2 + precs::P + extrapolant::Symbol + step_limiter!::StepLimiter + autodiff::AD +end +function BHR553(; + chunk_size = Val{0}(), autodiff = AutoForwardDiff(), + standardtag = Val{true}(), concrete_jac = nothing, + diff_type = Val{:forward}(), + linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), + extrapolant = :linear, step_limiter! = trivial_limiter! + ) + AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) + return BHR553{ + _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), + _unwrap_val(concrete_jac), typeof(step_limiter!), + }(linsolve, nlsolve, precs, extrapolant, step_limiter!, AD_choice) +end diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl index 5d9059cc597..b6f0a39efe2 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl @@ -986,3 +986,284 @@ function alg_cache( end return IMEXSSP3433Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k2, k3, k4, nlsolver, tab, alg.step_limiter!)
end + +# ---- ARS222 ---- + +@cache mutable struct ARS222ConstantCache{N, Tab} <: SDIRKConstantCache + nlsolver::N + tab::Tab +end + +function alg_cache( + alg::ARS222, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = ARS222Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.γ + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return ARS222ConstantCache(nlsolver, tab) +end + +@cache mutable struct ARS222Cache{uType, rateType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + z₁::uType + z₂::uType + z₃::uType + k1::kType + k2::kType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function alg_cache( + alg::ARS222, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = ARS222Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.γ + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = zero(rate_prototype) + z₁ = zero(u) + z₂ = zero(u) + z₃ = nlsolver.z + if f isa SplitFunction + k1 = zero(u) + k2 = zero(u) + else + k1 = nothing + k2 = nothing + end + return ARS222Cache(u, uprev, fsalfirst, z₁, z₂, z₃, k1, k2, nlsolver, tab, alg.step_limiter!) 
+end + +# ---- ARS232 ---- + +@cache mutable struct ARS232ConstantCache{N, Tab} <: SDIRKConstantCache + nlsolver::N + tab::Tab +end + +function alg_cache( + alg::ARS232, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = ARS232Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.γ + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return ARS232ConstantCache(nlsolver, tab) +end + +@cache mutable struct ARS232Cache{uType, rateType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + z₁::uType + z₂::uType + z₃::uType + k1::kType + k2::kType + k3::kType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function alg_cache( + alg::ARS232, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = ARS232Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.γ + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = zero(rate_prototype) + z₁ = zero(u) + z₂ = zero(u) + z₃ = nlsolver.z + if f isa SplitFunction + k1 = zero(u) + k2 = zero(u) + k3 = zero(u) + else + k1 = nothing + k2 = nothing + k3 = nothing + end + return ARS232Cache(u, uprev, fsalfirst, z₁, z₂, z₃, k1, k2, k3, nlsolver, tab, alg.step_limiter!) 
+end + +# ---- ARS443 ---- + +@cache mutable struct ARS443ConstantCache{N, Tab} <: SDIRKConstantCache + nlsolver::N + tab::Tab +end + +function alg_cache( + alg::ARS443, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = ARS443Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.γ + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return ARS443ConstantCache(nlsolver, tab) +end + +@cache mutable struct ARS443Cache{uType, rateType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + z₁::uType + z₂::uType + z₃::uType + z₄::uType + z₅::uType + k1::kType + k2::kType + k3::kType + k4::kType + k5::kType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function alg_cache( + alg::ARS443, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = ARS443Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.γ + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = zero(rate_prototype) + z₁ = zero(u) + z₂ = zero(u) + z₃ = zero(u) + z₄ = zero(u) + z₅ = nlsolver.z + if f isa SplitFunction + k1 = zero(u) + k2 = zero(u) + k3 = zero(u) + k4 = zero(u) + k5 = zero(u) + else + k1 = nothing + k2 = nothing + k3 = nothing + k4 = nothing + k5 = nothing + end + return ARS443Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, k1, k2, k3, k4, k5, nlsolver, tab, alg.step_limiter!) 
+end + +# ---- BHR553 ---- + +@cache mutable struct BHR553ConstantCache{N, Tab} <: SDIRKConstantCache + nlsolver::N + tab::Tab +end + +function alg_cache( + alg::BHR553, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = BHR553Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c2 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose + ) + return BHR553ConstantCache(nlsolver, tab) +end + +@cache mutable struct BHR553Cache{uType, rateType, N, Tab, kType, StepLimiter} <: + SDIRKMutableCache + u::uType + uprev::uType + fsalfirst::rateType + z₁::uType + z₂::uType + z₃::uType + z₄::uType + z₅::uType + k1::kType + k2::kType + k3::kType + k4::kType + k5::kType + nlsolver::N + tab::Tab + step_limiter!::StepLimiter +end + +function alg_cache( + alg::BHR553, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, + uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + tab = BHR553Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + γ, c = tab.γ, tab.c2 + nlsolver = build_nlsolver( + alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, + uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose + ) + fsalfirst = zero(rate_prototype) + z₁ = zero(u) + z₂ = zero(u) + z₃ = zero(u) + z₄ = zero(u) + z₅ = nlsolver.z + if f isa SplitFunction + k1 = zero(u) + k2 = zero(u) + k3 = zero(u) + k4 = zero(u) + k5 = zero(u) + else + k1 = nothing + k2 = nothing + k3 = nothing + k4 = nothing + k5 = nothing + end + return BHR553Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, k1, k2, k3, k4, k5, nlsolver, tab, alg.step_limiter!) 
+end diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl index 1332287abb0..c78d5e29b18 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl @@ -3373,3 +3373,741 @@ end @.. broadcast=false integrator.fsallast = z₄ / dt end end + +# --------------------------------------------------------------------------- +# ARS(2,2,2) — 3-stage, 2nd order, ESDIRK, ISA (Ascher et al. 1997 Table II) +# --------------------------------------------------------------------------- + +@muladd function perform_step!( + integrator, cache::ARS222ConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + (; γ, a32, ea21, ea31, ea32) = cache.tab + + f2 = nothing + k1 = nothing + k2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + γdt = γ * dt + markfirststage!(nlsolver) + + # Stage 1 (ESDIRK: trivial, FSAL) + if integrator.f isa SplitFunction + z₁ = dt .* f(uprev, p, t) + k1 = dt * integrator.fsalfirst - z₁ + else + z₁ = dt * integrator.fsalfirst + end + + # Stage 2 + nlsolver.z = z₂ = z₁ + nlsolver.tmp = uprev # a21=0 + if integrator.f isa SplitFunction + nlsolver.tmp = nlsolver.tmp + ea21 * k1 + end + nlsolver.c = γ + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 3 + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₂ # Y₂ + k2 = dt * f2(u, p, t + γdt) + integrator.stats.nf2 += 1 + tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 + else + tmp = uprev + a32 * z₂ + end + nlsolver.z = z₃ = z₂ + nlsolver.tmp = tmp + nlsolver.c = 1 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Final update: u = tmp + γ*z₃ + # For split: tmp already contains ea31*k1+ea32*k2 = eb1*k1+eb2*k2 (GSA: b̃ = last row 
of Ã) + u = nlsolver.tmp + γ * z₃ + + if integrator.f isa SplitFunction + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z₃ ./ dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!(integrator, cache::ARS222Cache, repeat_step = false) + (; t, dt, uprev, u, p) = integrator + (; z₁, z₂, z₃, k1, k2, nlsolver, step_limiter!) = cache + (; tmp) = nlsolver + (; γ, a32, ea21, ea31, ea32) = cache.tab + + f2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + γdt = γ * dt + markfirststage!(nlsolver) + + # Stage 1 (ESDIRK: trivial, FSAL) + if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail + f(z₁, uprev, p, t) + z₁ .*= dt + else + @.. broadcast = false z₁ = dt * integrator.fsalfirst + end + + # Stage 2 + copyto!(z₂, z₁) + nlsolver.z = z₂ + @.. broadcast = false tmp = uprev # a21=0 + if integrator.f isa SplitFunction + @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ + @.. broadcast = false tmp += ea21 * k1 + end + nlsolver.c = γ + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + isnewton(nlsolver) && set_new_W!(nlsolver, false) + + # Stage 3 + if integrator.f isa SplitFunction + @.. broadcast = false u = tmp + γ * z₂ # Y₂ + f2(k2, u, p, t + γdt) + k2 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 + else + @.. broadcast = false tmp = uprev + a32 * z₂ + end + copyto!(z₃, z₂) + nlsolver.z = z₃ + nlsolver.c = 1 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Final update: u = tmp + γ*z₃ (works for both split and non-split: GSA for split) + @.. 
broadcast = false u = tmp + γ * z₃ + + step_limiter!(u, integrator, p, t + dt) + + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. broadcast = false integrator.fsallast = z₃ / dt + end +end + +# --------------------------------------------------------------------------- +# ARS(2,3,2) — 3-stage, 2nd order, ESDIRK, ISA (Ascher et al. 1997) +# --------------------------------------------------------------------------- + +@muladd function perform_step!( + integrator, cache::ARS232ConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + (; γ, a32, ea21, ea31, ea32, eb2, eb3) = cache.tab + + f2 = nothing + k1 = nothing + k2 = nothing + k3 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + γdt = γ * dt + markfirststage!(nlsolver) + + # Stage 1 (ESDIRK: trivial, FSAL) + if integrator.f isa SplitFunction + z₁ = dt .* f(uprev, p, t) + k1 = dt * integrator.fsalfirst - z₁ + else + z₁ = dt * integrator.fsalfirst + end + + # Stage 2 + nlsolver.z = z₂ = z₁ + nlsolver.tmp = uprev # a21=0 + if integrator.f isa SplitFunction + nlsolver.tmp = nlsolver.tmp + ea21 * k1 + end + nlsolver.c = γ + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 3 + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₂ # Y₂ + k2 = dt * f2(u, p, t + γdt) + integrator.stats.nf2 += 1 + tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 + else + tmp = uprev + a32 * z₂ + end + nlsolver.z = z₃ = z₂ + nlsolver.tmp = tmp + nlsolver.c = 1 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Final update + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₃ # Y₃ (ISA implicit part) + k3 = dt * f2(u, p, t + dt) + integrator.stats.nf2 += 1 + # b̃ = [0, 1-γ, γ] ≠ last row of à → must reconstruct explicitly + u = uprev + eb2 * (z₂ + k2) + 
eb3 * (z₃ + k3) + else + u = nlsolver.tmp + γ * z₃ # ISA + end + + if integrator.f isa SplitFunction + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z₃ ./ dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!(integrator, cache::ARS232Cache, repeat_step = false) + (; t, dt, uprev, u, p) = integrator + (; z₁, z₂, z₃, k1, k2, k3, nlsolver, step_limiter!) = cache + (; tmp) = nlsolver + (; γ, a32, ea21, ea31, ea32, eb2, eb3) = cache.tab + + f2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + γdt = γ * dt + markfirststage!(nlsolver) + + # Stage 1 (ESDIRK: trivial, FSAL) + if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail + f(z₁, uprev, p, t) + z₁ .*= dt + else + @.. broadcast = false z₁ = dt * integrator.fsalfirst + end + + # Stage 2 + copyto!(z₂, z₁) + nlsolver.z = z₂ + @.. broadcast = false tmp = uprev # a21=0 + if integrator.f isa SplitFunction + @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ + @.. broadcast = false tmp += ea21 * k1 + end + nlsolver.c = γ + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + isnewton(nlsolver) && set_new_W!(nlsolver, false) + + # Stage 3 + if integrator.f isa SplitFunction + @.. broadcast = false u = tmp + γ * z₂ # Y₂ + f2(k2, u, p, t + γdt) + k2 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 + else + @.. broadcast = false tmp = uprev + a32 * z₂ + end + copyto!(z₃, z₂) + nlsolver.z = z₃ + nlsolver.c = 1 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Final update + @.. 
broadcast = false u = tmp + γ * z₃ # ISA (or Y₃ for split) + if integrator.f isa SplitFunction + f2(k3, u, p, t + dt) + k3 .*= dt + integrator.stats.nf2 += 1 + # b̃ = [0, 1-γ, γ]; explicit reconstruction needed + @.. broadcast = false u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + end + + step_limiter!(u, integrator, p, t + dt) + + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. broadcast = false integrator.fsallast = z₃ / dt + end +end + +# --------------------------------------------------------------------------- +# ARS(4,4,3) — 5-stage, 3rd order, ESDIRK, ISA (Ascher et al. 1997 Table IV) +# --------------------------------------------------------------------------- + +@muladd function perform_step!( + integrator, cache::ARS443ConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + (; γ, a32, a42, a43, a52, a53, a54, + ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, + eb2, eb3, eb4, eb5, c3, c4) = cache.tab + + f2 = nothing + k1 = nothing + k2 = nothing + k3 = nothing + k4 = nothing + k5 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + γdt = γ * dt + markfirststage!(nlsolver) + + # Stage 1 (ESDIRK: trivial, FSAL) + if integrator.f isa SplitFunction + z₁ = dt .* f(uprev, p, t) + k1 = dt * integrator.fsalfirst - z₁ + else + z₁ = dt * integrator.fsalfirst + end + + # Stage 2 (a21=0, c₂=γ) + nlsolver.z = z₂ = z₁ + nlsolver.tmp = uprev # a21=0 + if integrator.f isa SplitFunction + nlsolver.tmp = nlsolver.tmp + ea21 * k1 + end + nlsolver.c = γ + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 3 (c₃=2/3) + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₂ # Y₂ + k2 = dt * f2(u, p, t + γdt) + integrator.stats.nf2 += 1 + tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 + else + tmp = uprev + a32 * z₂ + end + nlsolver.z 
= z₃ = z₂ + nlsolver.tmp = tmp + nlsolver.c = c3 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 4 (c₄=1/2) + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₃ # Y₃ + k3 = dt * f2(u, p, t + c3 * dt) + integrator.stats.nf2 += 1 + tmp = uprev + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 + else + tmp = uprev + a42 * z₂ + a43 * z₃ + end + nlsolver.z = z₄ = z₃ + nlsolver.tmp = tmp + nlsolver.c = c4 + z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 5 (c₅=1) + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₄ # Y₄ + k4 = dt * f2(u, p, t + c4 * dt) + integrator.stats.nf2 += 1 + tmp = uprev + a52 * z₂ + a53 * z₃ + a54 * z₄ + + ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 + else + tmp = uprev + a52 * z₂ + a53 * z₃ + a54 * z₄ + end + nlsolver.z = z₅ = z₄ + nlsolver.tmp = tmp + nlsolver.c = 1 + z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Final update: ISA → u = tmp + γ*z₅; for split add explicit contributions + u = nlsolver.tmp + γ * z₅ + if integrator.f isa SplitFunction + k5 = dt * f2(u, p, t + dt) + integrator.stats.nf2 += 1 + # b̃ = b_i = [0, eb2, eb3, eb4, eb5] + u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + eb4 * (z₄ + k4) + eb5 * (z₅ + k5) + end + + if integrator.f isa SplitFunction + integrator.k[1] = integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z₅ ./ dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!(integrator, cache::ARS443Cache, repeat_step = false) + (; t, dt, uprev, u, p) = integrator + (; z₁, z₂, z₃, z₄, z₅, k1, k2, k3, k4, k5, nlsolver, step_limiter!) 
= cache + (; tmp) = nlsolver + (; γ, a32, a42, a43, a52, a53, a54, + ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, + eb2, eb3, eb4, eb5, c3, c4) = cache.tab + + f2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + γdt = γ * dt + markfirststage!(nlsolver) + + # Stage 1 (ESDIRK: trivial, FSAL) + if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail + f(z₁, uprev, p, t) + z₁ .*= dt + else + @.. broadcast = false z₁ = dt * integrator.fsalfirst + end + + # Stage 2 (a21=0, c₂=γ) + copyto!(z₂, z₁) + nlsolver.z = z₂ + @.. broadcast = false tmp = uprev # a21=0 + if integrator.f isa SplitFunction + @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ + @.. broadcast = false tmp += ea21 * k1 + end + nlsolver.c = γ + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + isnewton(nlsolver) && set_new_W!(nlsolver, false) + + # Stage 3 (c₃=2/3) + if integrator.f isa SplitFunction + @.. broadcast = false u = tmp + γ * z₂ # Y₂ + f2(k2, u, p, t + γdt) + k2 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 + else + @.. broadcast = false tmp = uprev + a32 * z₂ + end + copyto!(z₃, z₂) + nlsolver.z = z₃ + nlsolver.c = c3 + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 4 (c₄=1/2) + if integrator.f isa SplitFunction + @.. broadcast = false u = tmp + γ * z₃ # Y₃ + f2(k3, u, p, t + c3 * dt) + k3 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false tmp = uprev + a42 * z₂ + a43 * z₃ + ea41 * k1 + + ea42 * k2 + ea43 * k3 + else + @.. broadcast = false tmp = uprev + a42 * z₂ + a43 * z₃ + end + copyto!(z₄, z₃) + nlsolver.z = z₄ + nlsolver.c = c4 + z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 5 (c₅=1) + if integrator.f isa SplitFunction + @.. 
broadcast = false u = tmp + γ * z₄ # Y₄ + f2(k4, u, p, t + c4 * dt) + k4 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false tmp = uprev + a52 * z₂ + a53 * z₃ + a54 * z₄ + + ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 + else + @.. broadcast = false tmp = uprev + a52 * z₂ + a53 * z₃ + a54 * z₄ + end + copyto!(z₅, z₄) + nlsolver.z = z₅ + nlsolver.c = 1 + z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Final update + @.. broadcast = false u = tmp + γ * z₅ # ISA (or Y₅ for split) + if integrator.f isa SplitFunction + f2(k5, u, p, t + dt) + k5 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + + eb4 * (z₄ + k4) + eb5 * (z₅ + k5) + end + + step_limiter!(u, integrator, p, t + dt) + + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. broadcast = false integrator.fsallast = z₅ / dt + end +end + +# --------------------------------------------------------------------------- +# BHR(5,5,3)* — 5-stage, 3rd order, ESDIRK, ISA (Boscarino & Russo 2009) +# --------------------------------------------------------------------------- + +@muladd function perform_step!( + integrator, cache::BHR553ConstantCache, repeat_step = false + ) + (; t, dt, uprev, u, p) = integrator + nlsolver = cache.nlsolver + (; γ, a21, a31, a41, a43, a51, a53, a54, + ea21, ea31, ea32, ea41, ea43, ea51, ea52, ea53, ea54, + eb1, eb3, eb4, eb5, c2, c4) = cache.tab + + f2 = nothing + k1 = nothing + k2 = nothing + k3 = nothing + k4 = nothing + k5 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + # Stage 1 (ESDIRK: trivial, FSAL) + if integrator.f isa SplitFunction + z₁ = dt .* f(uprev, p, t) + k1 = dt * integrator.fsalfirst - z₁ + else + z₁ = dt * integrator.fsalfirst + end + + # Stage 2 (a21=γ, c₂=2γ) + nlsolver.z = z₂ = z₁ + nlsolver.tmp = uprev + a21 
* z₁ + if integrator.f isa SplitFunction + nlsolver.tmp = nlsolver.tmp + ea21 * k1 + end + nlsolver.c = c2 + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 3 (a31=γ, a32=0, c₃=2γ; same c as stage 2) + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₂ # Y₂ + k2 = dt * f2(u, p, t + c2 * dt) + integrator.stats.nf2 += 1 + tmp = uprev + a31 * z₁ + ea31 * k1 + ea32 * k2 # a32=0 implicit + else + tmp = uprev + a31 * z₁ + end + nlsolver.z = z₃ = z₂ + nlsolver.tmp = tmp + nlsolver.c = c2 # c₃ = c₂ = 2γ + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 4 (a41, a42=0, a43, c₄=1.5) + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₃ # Y₃ + k3 = dt * f2(u, p, t + c2 * dt) # c₃ = 2γ + integrator.stats.nf2 += 1 + tmp = uprev + a41 * z₁ + a43 * z₃ + ea41 * k1 + ea43 * k3 # a42=0, ae42=0 + else + tmp = uprev + a41 * z₁ + a43 * z₃ + end + nlsolver.z = z₄ = z₃ + nlsolver.tmp = tmp + nlsolver.c = c4 + z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 5 (a51, a52=0, a53, a54, c₅=1) + if integrator.f isa SplitFunction + u = nlsolver.tmp + γ * z₄ # Y₄ + k4 = dt * f2(u, p, t + c4 * dt) + integrator.stats.nf2 += 1 + tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + + ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 # a52=0 + else + tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + end + nlsolver.z = z₅ = z₄ + nlsolver.tmp = tmp + nlsolver.c = 1 + z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Final update: ISA → u = tmp + γ*z₅; for split use b̃ = b_i + u = nlsolver.tmp + γ * z₅ + if integrator.f isa SplitFunction + k5 = dt * f2(u, p, t + dt) + integrator.stats.nf2 += 1 + # b̃ = [eb1, 0, eb3, eb4, eb5]; b̃₂=0 so no z₂+k2 term + u = uprev + eb1 * (z₁ + k1) + eb3 * (z₃ + k3) + eb4 * (z₄ + k4) + eb5 * (z₅ + k5) + end + + if integrator.f isa SplitFunction + integrator.k[1] 
= integrator.fsalfirst + integrator.fsallast = integrator.f(u, p, t + dt) + integrator.k[2] = integrator.fsallast + else + integrator.fsallast = z₅ ./ dt + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + end + integrator.u = u +end + +@muladd function perform_step!(integrator, cache::BHR553Cache, repeat_step = false) + (; t, dt, uprev, u, p) = integrator + (; z₁, z₂, z₃, z₄, z₅, k1, k2, k3, k4, k5, nlsolver, step_limiter!) = cache + (; tmp) = nlsolver + (; γ, a21, a31, a41, a43, a51, a53, a54, + ea21, ea31, ea32, ea41, ea43, ea51, ea52, ea53, ea54, + eb1, eb3, eb4, eb5, c2, c4) = cache.tab + + f2 = nothing + if integrator.f isa SplitFunction + f = integrator.f.f1 + f2 = integrator.f.f2 + else + f = integrator.f + end + + markfirststage!(nlsolver) + + # Stage 1 (ESDIRK: trivial, FSAL) + if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail + f(z₁, uprev, p, t) + z₁ .*= dt + else + @.. broadcast = false z₁ = dt * integrator.fsalfirst + end + + # Stage 2 (a21=γ, c₂=2γ) + copyto!(z₂, z₁) + nlsolver.z = z₂ + @.. broadcast = false tmp = uprev + a21 * z₁ + if integrator.f isa SplitFunction + @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ + @.. broadcast = false tmp += ea21 * k1 + end + nlsolver.c = c2 + z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + isnewton(nlsolver) && set_new_W!(nlsolver, false) + + # Stage 3 (a31=γ, a32=0, c₃=2γ) + if integrator.f isa SplitFunction + @.. broadcast = false u = tmp + γ * z₂ # Y₂ + f2(k2, u, p, t + c2 * dt) + k2 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false tmp = uprev + a31 * z₁ + ea31 * k1 + ea32 * k2 # a32=0 implicit + else + @.. broadcast = false tmp = uprev + a31 * z₁ + end + copyto!(z₃, z₂) + nlsolver.z = z₃ + nlsolver.c = c2 # c₃ = 2γ + z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 4 (a41, a42=0, a43, c₄=1.5) + if integrator.f isa SplitFunction + @.. 
broadcast = false u = tmp + γ * z₃ # Y₃ + f2(k3, u, p, t + c2 * dt) # c₃ = 2γ + k3 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false tmp = uprev + a41 * z₁ + a43 * z₃ + + ea41 * k1 + ea43 * k3 # a42=0, ae42=0 + else + @.. broadcast = false tmp = uprev + a41 * z₁ + a43 * z₃ + end + copyto!(z₄, z₃) + nlsolver.z = z₄ + nlsolver.c = c4 + z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Stage 5 (a51, a52=0, a53, a54, c₅=1) + if integrator.f isa SplitFunction + @.. broadcast = false u = tmp + γ * z₄ # Y₄ + f2(k4, u, p, t + c4 * dt) + k4 .*= dt + integrator.stats.nf2 += 1 + @.. broadcast = false tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + + ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 # a52=0 + else + @.. broadcast = false tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + end + copyto!(z₅, z₄) + nlsolver.z = z₅ + nlsolver.c = 1 + z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) + nlsolvefail(nlsolver) && return + + # Final update + @.. broadcast = false u = tmp + γ * z₅ # ISA (or Y₅ for split) + if integrator.f isa SplitFunction + f2(k5, u, p, t + dt) + k5 .*= dt + integrator.stats.nf2 += 1 + # b̃ = [eb1, 0, eb3, eb4, eb5]; b̃₂=0 + @.. broadcast = false u = uprev + eb1 * (z₁ + k1) + eb3 * (z₃ + k3) + + eb4 * (z₄ + k4) + eb5 * (z₅ + k5) + end + + step_limiter!(u, integrator, p, t + dt) + + if integrator.f isa SplitFunction + integrator.f(integrator.fsallast, u, p, t + dt) + else + @.. 
broadcast = false integrator.fsallast = z₅ / dt + end +end diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl index 31dac2087e8..c0c75f30f62 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl @@ -3009,3 +3009,253 @@ function IMEXSSP3433Tableau(T, T2) eb4 = convert(T, 2 // 3) return IMEXSSP3433Tableau(γ, a21, a32, a41, a42, a43, c3, c4, ea32, ea42, ea43, eb2, eb3, eb4) end + +#= +ARS(2,2,2) — Ascher, Ruuth & Spiteri (1997), Table II +3-stage, 2nd order, L-stable, ESDIRK, ISA + +γ = 1 - √2/2 + +Implicit: c = [0, γ, 1] + A = [0 0 0] + [0 γ 0] + [0 1-γ γ] + b = [0, 1-γ, γ] (ISA = last row) + +Explicit: + à = [0 0 0] + [γ 0 0] + [δ 1-δ 0] + b̃ = [δ, 1-δ, 0] where δ = 1 - 1/(2γ) +=# +struct ARS222Tableau{T, T2} + γ::T2 + a32::T # = 1-γ + ea21::T # = γ + ea31::T # = δ = 1 - 1/(2γ) + ea32::T # = 1-δ + eb1::T # = δ + eb2::T # = 1-δ +end + +function ARS222Tableau(T, T2) + γ = convert(T2, 1 - sqrt(T2(2)) / 2) + δ = 1 - 1 / (2γ) + return ARS222Tableau{T, T2}( + γ, + convert(T, 1 - γ), # a32 + convert(T, γ), # ea21 + convert(T, δ), # ea31 + convert(T, 1 - δ), # ea32 + convert(T, δ), # eb1 + convert(T, 1 - δ), # eb2 + ) +end + +#= +ARS(2,3,2) — Ascher, Ruuth & Spiteri (1997) +3-stage, 2nd order, ESDIRK, ISA + +Same implicit tableau as ARS(2,2,2), different explicit part. 
+ +Explicit: + à = [0 0 0] + [γ 0 0] + [δ' 1-δ' 0] + b̃ = [0, 1-γ, γ] where δ' = -2√2/3 +=# +struct ARS232Tableau{T, T2} + γ::T2 + a32::T # = 1-γ + ea21::T # = γ + ea31::T # = δ' = -2√2/3 + ea32::T # = 1-δ' + eb2::T # = 1-γ (b̃₁=0) + eb3::T # = γ +end + +function ARS232Tableau(T, T2) + γ = convert(T2, 1 - sqrt(T2(2)) / 2) + δ = convert(T, -2 * sqrt(T(2)) / 3) + return ARS232Tableau{T, T2}( + γ, + convert(T, 1 - γ), # a32 + convert(T, γ), # ea21 + δ, # ea31 = δ' + 1 - δ, # ea32 = 1-δ' + convert(T, 1 - γ), # eb2 + convert(T, γ), # eb3 + ) +end + +#= +ARS(4,4,3) — Ascher, Ruuth & Spiteri (1997), Table IV +5-stage, 3rd order, ESDIRK, ISA, γ = 1/2 + +Implicit: + c = [0, 1/2, 2/3, 1/2, 1] + A = [0 0 0 0 0 ] + [0 1/2 0 0 0 ] + [0 1/6 1/2 0 0 ] + [0 -1/2 1/2 1/2 0 ] + [0 3/2 -3/2 1/2 1/2 ] (= b, ISA) + b = [0, 3/2, -3/2, 1/2, 1/2] + +Explicit: + à = [0 0 0 0 0] + [1/2 0 0 0 0] + [11/18 1/18 0 0 0] + [5/6 -5/6 1/2 0 0] + [1/4 7/4 3/4 -7/4 0] + b̃ = [0, 3/2, -3/2, 1/2, 1/2] (= b_i) +=# +struct ARS443Tableau{T, T2} + γ::T2 # = 1/2 + a32::T # = 1/6 + a42::T # = -1/2 + a43::T # = 1/2 + a52::T # = 3/2 + a53::T # = -3/2 + a54::T # = 1/2 + ea21::T # = 1/2 + ea31::T # = 11/18 + ea32::T # = 1/18 + ea41::T # = 5/6 + ea42::T # = -5/6 + ea43::T # = 1/2 + ea51::T # = 1/4 + ea52::T # = 7/4 + ea53::T # = 3/4 + ea54::T # = -7/4 + eb2::T # = 3/2 (b̃₁=0) + eb3::T # = -3/2 + eb4::T # = 1/2 + eb5::T # = 1/2 + c3::T2 # = 2/3 + c4::T2 # = 1/2 +end + +function ARS443Tableau(T, T2) + γ = convert(T2, 1 // 2) + return ARS443Tableau{T, T2}( + γ, + convert(T, 1 // 6), # a32 + convert(T, -1 // 2), # a42 + convert(T, 1 // 2), # a43 + convert(T, 3 // 2), # a52 + convert(T, -3 // 2), # a53 + convert(T, 1 // 2), # a54 + convert(T, 1 // 2), # ea21 + convert(T, 11 // 18), # ea31 + convert(T, 1 // 18), # ea32 + convert(T, 5 // 6), # ea41 + convert(T, -5 // 6), # ea42 + convert(T, 1 // 2), # ea43 + convert(T, 1 // 4), # ea51 + convert(T, 7 // 4), # ea52 + convert(T, 3 // 4), # ea53 + convert(T, -7 // 4), # 
ea54 + convert(T, 3 // 2), # eb2 + convert(T, -3 // 2), # eb3 + convert(T, 1 // 2), # eb4 + convert(T, 1 // 2), # eb5 + convert(T2, 2 // 3), # c3 + convert(T2, 1 // 2), # c4 + ) +end + +#= +BHR(5,5,3)* — Boscarino & Russo (2009) +5-stage, 3rd order, ESDIRK, ISA, L-stable + +γ = 0.435866521508460 + +Implicit: + c = [0, 2γ, 2γ, c₄=1.5, 1] + A = [0 0 0 0 0] + [γ γ 0 0 0] + [γ 0 γ 0 0] + [a41 0 a43 γ 0] + [a51 0 a53 a54 γ] (= b, ISA) + b = [a51, 0, a53, a54, γ] + where: + a41 = 3c₄/2 - c₄²/(4γ) - γ, a43 = c₄²/(4γ) - c₄/2 + b3 = 0.362863385578740, b4 = -0.168124349878957 + +Explicit: + à = [0 0 0 0 0] + [2γ 0 0 0 0] + [γ γ 0 0 0] + [ea41 0 ea43 0 0] + [ea51 ea52 ã53 ã54 0] + b̃ = b_i + where: + ea41 = c₄ - c₄²/(4γ), ea43 = c₄²/(4γ) + ã53 = 1.195970114894582, ã54 = -0.150831109536248 + ea51 = 1 + b3 - ã53 - ã54, ea52 = -b3 +=# +struct BHR553Tableau{T, T2} + γ::T2 + a21::T # = γ + a31::T # = γ + a41::T # = 3c₄/2 - c₄²/(4γ) - γ + a43::T # = c₄²/(4γ) - c₄/2 + a51::T # = 1 - b3 - b4 - γ + a53::T # = b3 + a54::T # = b4 + ea21::T # = 2γ + ea31::T # = γ + ea32::T # = γ + ea41::T # = c₄ - c₄²/(4γ) + ea43::T # = c₄²/(4γ) + ea51::T # = 1 + b3 - ã53 - ã54 + ea52::T # = -b3 + ea53::T # = ã53 + ea54::T # = ã54 + eb1::T # = a51 (= 1 - b3 - b4 - γ) + eb3::T # = b3 + eb4::T # = b4 + eb5::T # = γ + c2::T2 # = 2γ + c4::T2 # = 1.5 +end + +function BHR553Tableau(T, T2) + γ = convert(T2, 0.435866521508460) + b3 = 0.362863385578740 + b4 = -0.168124349878957 + c4val = 1.5 + ã53 = 1.195970114894582 + ã54 = -0.150831109536248 + a41val = 3 * c4val / 2 - c4val^2 / (4 * γ) - γ + a43val = c4val^2 / (4 * γ) - c4val / 2 + a51val = 1 - b3 - b4 - γ + ea41val = c4val - c4val^2 / (4 * γ) + ea43val = c4val^2 / (4 * γ) + ea51val = 1 + b3 - ã53 - ã54 + return BHR553Tableau{T, T2}( + γ, + convert(T, γ), # a21 + convert(T, γ), # a31 + convert(T, a41val), # a41 + convert(T, a43val), # a43 + convert(T, a51val), # a51 + convert(T, b3), # a53 + convert(T, b4), # a54 + convert(T, 2 * γ), # ea21 + convert(T, γ), # 
ea31 + convert(T, γ), # ea32 + convert(T, ea41val), # ea41 + convert(T, ea43val), # ea43 + convert(T, ea51val), # ea51 + convert(T, -b3), # ea52 + convert(T, ã53), # ea53 + convert(T, ã54), # ea54 + convert(T, a51val), # eb1 + convert(T, b3), # eb3 + convert(T, b4), # eb4 + convert(T, γ), # eb5 + convert(T2, 2 * γ), # c2 + convert(T2, c4val), # c4 + ) +end From 0a53fae3471eef30ef8fcb20a3f4493d8b93aed4 Mon Sep 17 00:00:00 2001 From: Harsh Singh Date: Wed, 8 Apr 2026 12:31:06 +0530 Subject: [PATCH 3/3] Add IMEX RK solvers implementation (tableaus, perform_step, caches) --- lib/OrdinaryDiffEqCore/src/algorithms.jl | 2 + .../src/OrdinaryDiffEqSDIRK.jl | 6 +- lib/OrdinaryDiffEqSDIRK/src/algorithms.jl | 8 +- .../src/generic_imex_perform_step.jl | 269 +++++++ lib/OrdinaryDiffEqSDIRK/src/imex_tableaus.jl | 221 ++++++ .../src/kencarp_kvaerno_caches.jl | 281 ------- .../src/kencarp_kvaerno_perform_step.jl | 738 +----------------- lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl | 250 ------ 8 files changed, 501 insertions(+), 1274 deletions(-) create mode 100644 lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl create mode 100644 lib/OrdinaryDiffEqSDIRK/src/imex_tableaus.jl diff --git a/lib/OrdinaryDiffEqCore/src/algorithms.jl b/lib/OrdinaryDiffEqCore/src/algorithms.jl index 087bb4814df..8127b2776b7 100644 --- a/lib/OrdinaryDiffEqCore/src/algorithms.jl +++ b/lib/OrdinaryDiffEqCore/src/algorithms.jl @@ -45,6 +45,8 @@ abstract type OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ} <: OrdinaryDiffEqAlgorithm end abstract type OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} <: OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ} end +abstract type OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ} <: +OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} end abstract type OrdinaryDiffEqRosenbrockAlgorithm{CS, AD, FDT, ST, CJ} <: OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ} end const NewtonAlgorithm = Union{ diff --git 
a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl index 5d021c996b6..24790ee0d8e 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl @@ -7,6 +7,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache, OrdinaryDiffEqNewtonAdaptiveAlgorithm, OrdinaryDiffEqNewtonAlgorithm, + OrdinaryDiffEqNewtonESDIRKAlgorithm, DEFAULT_PRECS, OrdinaryDiffEqAdaptiveAlgorithm, CompiledFloats, uses_uprev, alg_cache, _vec, _reshape, @cache, isfsal, full_cache, @@ -37,6 +38,8 @@ include("kencarp_kvaerno_caches.jl") include("sdirk_perform_step.jl") include("kencarp_kvaerno_perform_step.jl") include("sdirk_tableaus.jl") +include("imex_tableaus.jl") +include("generic_imex_perform_step.jl") export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22, Kvaerno3, KenCarp3, Cash4, Hairer4, Hairer42, SSPSDIRK2, Kvaerno4, @@ -44,8 +47,7 @@ export ImplicitEuler, ImplicitMidpoint, Trapezoid, TRBDF2, SDIRK2, SDIRK22, SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, Kvaerno5, KenCarp4, KenCarp5, SFSDIRK4, SFSDIRK5, CFNLIRK3, SFSDIRK6, SFSDIRK7, SFSDIRK8, ESDIRK436L2SA2, ESDIRK437L2SA, ESDIRK547L2SA2, ESDIRK659L2SA, - IMEXSSP222, IMEXSSP2322, IMEXSSP3332, IMEXSSP3433, - ARS222, ARS232, ARS443, BHR553 + IMEXSSP222, IMEXSSP2322, IMEXSSP3332, IMEXSSP3433 import PrecompileTools import Preferences diff --git a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl index ef62010473a..50cff4a2b25 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl @@ -1798,7 +1798,7 @@ const _ARS_BHR_REFERENCE = "@article{ascher1997implicit, """ ) struct ARS222{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 
precs::P @@ -1835,7 +1835,7 @@ end """ ) struct ARS232{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P @@ -1872,7 +1872,7 @@ end """ ) struct ARS443{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P @@ -1917,7 +1917,7 @@ end """ ) struct BHR553{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: - OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ} + OrdinaryDiffEqNewtonESDIRKAlgorithm{CS, AD, FDT, ST, CJ} linsolve::F nlsolve::F2 precs::P
diff --git a/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl
new file mode 100644
index 00000000000..efb6c952b13
--- /dev/null
+++ b/lib/OrdinaryDiffEqSDIRK/src/generic_imex_perform_step.jl
@@ -0,0 +1,269 @@
+# Out-of-place cache shared by the table-driven ESDIRK IMEX methods: the
+# nonlinear solver and the Butcher tableau (`tab`) selected for the algorithm.
+mutable struct ESDIRKIMEXConstantCache{Tab, N} <: SDIRKConstantCache
+    nlsolver::N
+    tab::Tab
+end
+
+# In-place cache shared by the table-driven ESDIRK IMEX methods.
+# `zs` holds one buffer per stage for the implicit increments; `ks` holds one
+# buffer per stage for the explicit increments when the problem is a
+# `SplitFunction`, and is a vector of `nothing` otherwise (see `alg_cache`).
+mutable struct ESDIRKIMEXCache{uType, rateType, N, Tab, kType, StepLimiter} <:
+    SDIRKMutableCache
+    u::uType
+    uprev::uType
+    fsalfirst::rateType
+    zs::Vector{uType}
+    ks::Vector{kType}
+    nlsolver::N
+    tab::Tab
+    step_limiter!::StepLimiter
+end
+
+# Return the cache arrays as a flat tuple: `u`, `uprev`, `fsalfirst`, every
+# entry of `zs`, and the entries of `ks` only when they are real buffers
+# (i.e. `ks` is not a `Vector{Nothing}`).
+function full_cache(c::ESDIRKIMEXCache)
+    base = (c.u, c.uprev, c.fsalfirst, c.zs...)
+    if eltype(c.ks) !== Nothing
+        return tuple(base..., c.ks...)
+    end
+    return base
+end
+
+# Algorithms served by the generic caches and steppers defined in this file.
+const ESDIRKIMEXAlgorithm = Union{ARS222, ARS232, ARS443, BHR553}
+
+# Build the out-of-place cache (`Val{false}`): look up the tableau for `alg`
+# and construct the nonlinear solver from the second-stage diagonal entry
+# `Ai[2, 2]` (used as γ) and the second abscissa `c[2]`.
+function alg_cache(
+        alg::ESDIRKIMEXAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
+        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits},
+        uprev, uprev2, f, t, dt, reltol, p, calck,
+        ::Val{false}, verbose
+    ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
+    tab = ESDIRKIMEXTableau(alg, constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
+    γ = tab.Ai[2, 2]
+    c = tab.c[2]
+    nlsolver = build_nlsolver(
+        alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
+        uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose
+    )
+    return ESDIRKIMEXConstantCache(nlsolver, tab)
+end
+
+# Build the in-place cache (`Val{true}`). Besides the nonlinear solver this
+# allocates the per-stage buffers: `s - 1` fresh arrays for `zs` plus
+# `nlsolver.z` itself as the last entry (so `zs[s]` aliases `nlsolver.z`), and
+# `s` arrays for `ks` only when `f` is a `SplitFunction`.
+function alg_cache(
+        alg::ESDIRKIMEXAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
+        ::Type{uBottomEltypeNoUnits},
+        ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
+        ::Val{true}, verbose
+    ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
+    tab = ESDIRKIMEXTableau(alg, constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
+    γ = tab.Ai[2, 2]
+    c = tab.c[2]
+    nlsolver = build_nlsolver(
+        alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
+        uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose
+    )
+    fsalfirst = zero(rate_prototype)
+
+    s = tab.s
+    if f isa SplitFunction
+        ks = [zero(u) for _ in 1:s]
+    else
+        # No explicit part: keep a placeholder vector so the cache layout is uniform.
+        ks = Vector{Nothing}(nothing, s)
+    end
+
+    zs = [zero(u) for _ in 1:(s - 1)]
+    push!(zs, nlsolver.z)
+
+    return ESDIRKIMEXCache(
+        u, uprev, fsalfirst, zs, ks, nlsolver, tab, alg.step_limiter!
+    )
+end
+
+# Out-of-place initialisation: size `integrator.k` to two slots, evaluate the
+# full right-hand side at `(uprev, t)` into `fsalfirst` (counted as one `nf`),
+# and zero `fsallast`.
+function initialize!(integrator, cache::ESDIRKIMEXConstantCache)
+    integrator.kshortsize = 2
+    integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
+    integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t)
+    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
+    integrator.fsallast = zero(integrator.fsalfirst)
+    integrator.k[1] = integrator.fsalfirst
+    integrator.k[2] = integrator.fsallast
+    return nothing
+end
+
+# In-place initialisation: point `integrator.k[1:2]` at the existing
+# `fsalfirst`/`fsallast` arrays, then fill `fsalfirst` with the full right-hand
+# side at `(uprev, t)` (counted as one `nf`).
+function initialize!(integrator, cache::ESDIRKIMEXCache)
+    integrator.kshortsize = 2
+    resize!(integrator.k, integrator.kshortsize)
+    integrator.k[1] = integrator.fsalfirst
+    integrator.k[2] = integrator.fsallast
+    integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t)
+    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
+    return nothing
+end
+
+# Out-of-place step for any tableau-driven ESDIRK IMEX method.
+#
+# `z[i]` are the implicit stage increments and, for `SplitFunction` problems,
+# `k[i] = dt * f2(stage value)` are the explicit ones. For each stage `i ≥ 2`
+# the known part `tmp = uprev + Σ_j Ai[i, j] z[j] (+ Σ_j Ae[i, j] k[j])` is
+# handed to `nlsolve!`, and the stage value is then `tmp + γ z[i]`.
+# Returns early, leaving `integrator.u` untouched, if a nonlinear solve fails.
+@muladd function perform_step!(
+        integrator, cache::ESDIRKIMEXConstantCache, repeat_step = false
+    )
+    (; t, dt, uprev, u, p) = integrator
+    nlsolver = cache.nlsolver
+    tab = cache.tab
+    (; Ai, bi, Ae, be, c, s) = tab
+    γ = Ai[2, 2]
+
+    f2 = nothing
+    # NOTE(review): `k` is allocated even for non-split problems, where it is never used.
+    k = Vector{typeof(u)}(undef, s)
+    if integrator.f isa SplitFunction
+        f_impl = integrator.f.f1
+        f2 = integrator.f.f2
+    else
+        f_impl = integrator.f
+    end
+
+    markfirststage!(nlsolver)
+
+    z = Vector{typeof(u)}(undef, s)
+
+    # Stage 1: explicit (ESDIRK: a₁₁ = 0)
+    if integrator.f isa SplitFunction
+        # Split problems need the implicit part alone, so f1 is re-evaluated.
+        z[1] = dt * f_impl(uprev, p, t)
+    else
+        z[1] = dt * integrator.fsalfirst
+    end
+
+    if integrator.f isa SplitFunction
+        # Explicit increment = dt * (full RHS stored in fsalfirst) - implicit increment.
+        k[1] = dt * integrator.fsalfirst - z[1]
+    end
+
+    # Stages 2..s
+    for i in 2:s
+        # Known part of the stage equation from all earlier stages.
+        tmp = uprev
+        for j in 1:(i - 1)
+            tmp = tmp + Ai[i, j] * z[j]
+        end
+
+        if integrator.f isa SplitFunction
+            for j in 1:(i - 1)
+                tmp = tmp + Ae[i, j] * k[j]
+            end
+        end
+
+        # Initial guess for the nonlinear solve: the stage-1 implicit increment
+        # for split problems, zero otherwise.
+        if integrator.f isa SplitFunction
+            z_guess = z[1]
+        else
+            z_guess = zero(u)
+        end
+
+        nlsolver.z = z_guess
+        nlsolver.tmp = tmp
+        nlsolver.c = c[i]
+        nlsolver.γ = γ
+        z[i] = nlsolve!(nlsolver, integrator, cache, repeat_step)
+        nlsolvefail(nlsolver) && return
+
+        if integrator.f isa SplitFunction && i < s
+            # Explicit increment at the freshly solved stage value.
+            u_stage = tmp + γ * z[i]
+            k[i] = dt * f2(u_stage, p, t + c[i] * dt)
+            integrator.stats.nf2 += 1
+        end
+    end
+
+    # Compute solution
+    # The last stage value doubles as the new solution for non-split problems.
+    u = nlsolver.tmp + γ * z[s]
+    if integrator.f isa SplitFunction
+        # NOTE(review): f2 is evaluated here even when `be[s] == 0` (the
+        # ARS(2,2,2) tableau leaves `be[3]` at zero); confirm the extra
+        # evaluation is acceptable.
+        k[s] = dt * f2(u, p, t + dt)
+        integrator.stats.nf2 += 1
+        u = uprev
+        for i in 1:s
+            u = u + bi[i] * z[i] + be[i] * k[i]
+        end
+    end
+
+    if integrator.f isa SplitFunction
+        integrator.k[1] = integrator.fsalfirst
+        integrator.fsallast = integrator.f(u, p, t + dt)
+        integrator.k[2] = integrator.fsallast
+    else
+        # Reuse the last stage increment as the derivative instead of calling f.
+        integrator.fsallast = z[s] ./ dt
+        integrator.k[1] = integrator.fsalfirst
+        integrator.k[2] = integrator.fsallast
+    end
+    integrator.u = u
+end
+
+# In-place counterpart of the out-of-place stepper above: same stage recursion,
+# but all work happens in the preallocated `zs`/`ks` buffers, `nlsolver.tmp`
+# and `u`, and `step_limiter!` is applied to the new `u` before `fsallast` is
+# refreshed. Returns early if a nonlinear solve fails.
+@muladd function perform_step!(integrator, cache::ESDIRKIMEXCache, repeat_step = false)
+    (; t, dt, uprev, u, p) = integrator
+    (; zs, ks, nlsolver, step_limiter!) = cache
+    (; tmp) = nlsolver
+    tab = cache.tab
+    (; Ai, bi, Ae, be, c, s) = tab
+    γ = Ai[2, 2]
+
+    f2 = nothing
+    if integrator.f isa SplitFunction
+        f_impl = integrator.f.f1
+        f2 = integrator.f.f2
+    else
+        f_impl = integrator.f
+    end
+
+    markfirststage!(nlsolver)
+
+    # Stage 1: explicit (ESDIRK: a₁₁ = 0)
+    # NOTE(review): for a split problem on a repeated step or after a failed
+    # step this takes the `else` branch, so `zs[1] = dt * fsalfirst` and the
+    # `ks[1]` computed below is identically zero. The out-of-place method
+    # always re-evaluates `f_impl` instead; confirm the asymmetry is intended.
+    if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail
+        f_impl(zs[1], integrator.uprev, p, integrator.t)
+        zs[1] .*= dt
+    else
+        @.. broadcast=false zs[1] = dt * integrator.fsalfirst
+    end
+
+    if integrator.f isa SplitFunction
+        @.. broadcast=false ks[1] = dt * integrator.fsalfirst - zs[1]
+    end
+
+    # Stages 2..s
+    for i in 2:s
+        # Known part of the stage equation, accumulated into nlsolver.tmp.
+        @.. broadcast=false tmp = uprev
+        for j in 1:(i - 1)
+            @.. broadcast=false tmp += Ai[i, j] * zs[j]
+        end
+
+        if integrator.f isa SplitFunction
+            for j in 1:(i - 1)
+                @.. broadcast=false tmp += Ae[i, j] * ks[j]
+            end
+        end
+
+        # Initial guess written into the stage buffer before it is handed to
+        # the nonlinear solver.
+        if integrator.f isa SplitFunction
+            copyto!(zs[i], zs[1])
+        else
+            fill!(zs[i], zero(eltype(u)))
+        end
+
+        nlsolver.z = zs[i]
+        nlsolver.c = c[i]
+        nlsolver.γ = γ
+        zs[i] = nlsolve!(nlsolver, integrator, cache, repeat_step)
+        nlsolvefail(nlsolver) && return
+        # NOTE(review): `set_new_W!(nlsolver, false)` is skipped after the
+        # first implicit stage (i == 2) and only issued from stage 3 onwards;
+        # confirm that `i > 2` rather than `i >= 2` is intended.
+        if i > 2
+            isnewton(nlsolver) && set_new_W!(nlsolver, false)
+        end
+
+        if integrator.f isa SplitFunction && i < s
+            # `u` is used as scratch space for the stage value here.
+            @.. broadcast=false u = tmp + γ * zs[i]
+            f2(ks[i], u, p, t + c[i] * dt)
+            ks[i] .*= dt
+            integrator.stats.nf2 += 1
+        end
+    end
+
+    # Compute solution
+    @.. broadcast=false u = tmp + γ * zs[s]
+    if integrator.f isa SplitFunction
+        f2(ks[s], u, p, t + dt)
+        ks[s] .*= dt
+        integrator.stats.nf2 += 1
+        # Rebuild u from the weights: uprev + Σ_i (bi[i] zs[i] + be[i] ks[i]).
+        @.. broadcast=false u = uprev
+        for i in 1:s
+            @.. broadcast=false u += bi[i] * zs[i] + be[i] * ks[i]
+        end
+    end
+
+    step_limiter!(u, integrator, p, t + dt)
+
+    if integrator.f isa SplitFunction
+        integrator.f(integrator.fsallast, u, p, t + dt)
+    else
+        # Reuse the last stage increment as the derivative instead of calling f.
+        @.. broadcast=false integrator.fsallast = zs[s] / dt
+    end
+end
diff --git a/lib/OrdinaryDiffEqSDIRK/src/imex_tableaus.jl b/lib/OrdinaryDiffEqSDIRK/src/imex_tableaus.jl
new file mode 100644
index 00000000000..5f8b51aacdd
--- /dev/null
+++ b/lib/OrdinaryDiffEqSDIRK/src/imex_tableaus.jl
@@ -0,0 +1,221 @@
+# Butcher tableau pair for an `s`-stage ESDIRK IMEX method.
+# `Ai`/`bi` are the implicit coefficients and weights, `Ae`/`be` the explicit
+# ones, and `c` the abscissae. `btilde`, `ebtilde` and `α` are passed as
+# `nothing` by every constructor in this file and are not read by the generic
+# steppers — presumably reserved for embedded error estimates / stage
+# predictors (TODO confirm).
+struct ESDIRKIMEXTableau{T, T2}
+    Ai::Matrix{T}
+    bi::Vector{T}
+    Ae::Matrix{T}
+    be::Vector{T}
+    c::Vector{T2}
+    btilde::Union{Vector{T}, Nothing}
+    ebtilde::Union{Vector{T}, Nothing}
+    α::Union{Matrix{T2}, Nothing}
+    order::Int
+    s::Int
+end
+
+# Dispatch: each algorithm type maps to its tableau constructor
+ESDIRKIMEXTableau(::ARS222, T, T2) = ARS222ESDIRKIMEXTableau(T, T2)
+ESDIRKIMEXTableau(::ARS232, T, T2) = ARS232ESDIRKIMEXTableau(T, T2)
+ESDIRKIMEXTableau(::ARS443, T, T2) = ARS443ESDIRKIMEXTableau(T, T2)
+ESDIRKIMEXTableau(::BHR553, T, T2) = BHR553ESDIRKIMEXTableau(T, T2)
+
+#
+# ARS(2,2,2) Tableau — Ascher, Ruuth & Spiteri (1997)
+# 3-stage, 2nd order, ESDIRK, ISA+GSA
+#
+function
ARS222ESDIRKIMEXTableau(T, T2)
+    # γ = 1 - √2/2 and δ = 1 - 1/(2γ), both evaluated in the time type T2.
+    γ = convert(T2, 1 - sqrt(T2(2)) / 2)
+    δ = 1 - 1 / (2γ)
+    γc = 1 - γ
+    δc = 1 - δ
+    nstages = 3
+
+    # Implicit coefficients and weights; the typed literals convert each entry to T.
+    Ai = T[0 0 0; 0 γ 0; 0 γc γ]
+    bi = T[0, γc, γ]
+
+    # Explicit coefficients and weights (zero weight on the last stage).
+    Ae = T[0 0 0; γ 0 0; δ δc 0]
+    be = T[δ, δc, 0]
+
+    # Abscissae, stored in the time type.
+    c = T2[0, γ, 1]
+
+    return ESDIRKIMEXTableau(Ai, bi, Ae, be, c, nothing, nothing, nothing, 2, nstages)
+end
+
+#
+# ARS(2,3,2) — Ascher, Ruuth & Spiteri (1997)
+# Three stages, second order, ESDIRK, ISA. The implicit half is the one used
+# by ARS(2,2,2); only the explicit half differs.
+#
+function ARS232ESDIRKIMEXTableau(T, T2)
+    γ = convert(T2, 1 - sqrt(T2(2)) / 2)
+    # δ = -2√2/3, evaluated directly in the state type T.
+    δ = convert(T, -2 * sqrt(T(2)) / 3)
+    γc = 1 - γ
+    δc = 1 - δ
+    nstages = 3
+
+    # Implicit coefficients and weights.
+    Ai = T[0 0 0; 0 γ 0; 0 γc γ]
+    bi = T[0, γc, γ]
+
+    # Explicit coefficients; the explicit weights repeat the implicit ones.
+    Ae = T[0 0 0; γ 0 0; δ δc 0]
+    be = T[0, γc, γ]
+
+    c = T2[0, γ, 1]
+
+    return ESDIRKIMEXTableau(Ai, bi, Ae, be, c, nothing, nothing, nothing, 2, nstages)
+end
+
+#
+# ARS(4,4,3) — Ascher, Ruuth & Spiteri (1997), Table IV
+# Five stages, third order, ESDIRK, ISA, with γ = 1/2.
+#
+function ARS443ESDIRKIMEXTableau(T, T2)
+    γ = convert(T, 1 // 2)
+    nstages = 5
+
+    # Implicit coefficients (one row per stage) and weights (= last row).
+    Ai = T[
+        0 0 0 0 0
+        0 γ 0 0 0
+        0 1//6 γ 0 0
+        0 -1//2 1//2 γ 0
+        0 3//2 -3//2 1//2 γ
+    ]
+    bi = T[0, 3//2, -3//2, 1//2, 1//2]
+
+    # Explicit coefficients (strictly lower triangular).
+    Ae = T[
+        0 0 0 0 0
+        1//2 0 0 0 0
+        11//18 1//18 0 0 0
+        5//6 -5//6 1//2 0 0
+        1//4 7//4 3//4 -7//4 0
+    ]
+    # The explicit weights take the same values as the implicit ones, held in
+    # their own array.
+    be = copy(bi)
+
+    c = T2[0, 1//2, 2//3, 1//2, 1]
+
+    return ESDIRKIMEXTableau(Ai, bi, Ae, be, c, nothing, nothing, nothing, 3, nstages)
+end
+
+#
+# BHR(5,5,3)* — Boscarino & Russo (2009)
+# Five stages, third order, ESDIRK, ISA, L-stable.
+# The decimal coefficients below are Float64 literals, so element types wider
+# than Float64 still receive double-precision values.
+#
+function BHR553ESDIRKIMEXTableau(T, T2)
+    γ = convert(T2, 0.435866521508460)
+    b3 = 0.362863385578740
+    b4 = -0.168124349878957
+    c4 = 1.5
+    ae53 = 1.195970114894582
+    ae54 = -0.150831109536248
+
+    # Derived entries, written with the same operation order as the closed forms.
+    a41 = 3 * c4 / 2 - c4^2 / (4 * γ) - γ
+    a43 = c4^2 / (4 * γ) - c4 / 2
+    a51 = 1 - b3 - b4 - γ
+    ae41 = c4 - c4^2 / (4 * γ)
+    ae43 = c4^2 / (4 * γ)
+    ae51 = 1 + b3 - ae53 - ae54
+    ae52 = -b3
+    twoγ = 2 * γ
+    nstages = 5
+
+    # Implicit coefficients and weights (weights = last row).
+    Ai = T[
+        0 0 0 0 0
+        γ γ 0 0 0
+        γ 0 γ 0 0
+        a41 0 a43 γ 0
+        a51 0 b3 b4 γ
+    ]
+    bi = T[a51, 0, b3, b4, γ]
+
+    # Explicit coefficients; the explicit weights repeat the implicit ones.
+    Ae = T[
+        0 0 0 0 0
+        twoγ 0 0 0 0
+        γ γ 0 0 0
+        ae41 0 ae43 0 0
+        ae51 ae52 ae53 ae54 0
+    ]
+    be = T[a51, 0, b3, b4, γ]
+
+    c = T2[0, twoγ, twoγ, c4, 1]
+
+    return ESDIRKIMEXTableau(Ai, bi, Ae, be, c, nothing, nothing, nothing, 3, nstages)
+end
diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl index b6f0a39efe2..5d9059cc597 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_caches.jl @@ -986,284 +986,3 @@ function alg_cache( end return IMEXSSP3433Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, k2, k3, k4, nlsolver, tab, alg.step_limiter!)
end - -# ---- ARS222 ---- - -@cache mutable struct ARS222ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ARS222, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ARS222Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return ARS222ConstantCache(nlsolver, tab) -end - -@cache mutable struct ARS222Cache{uType, rateType, N, Tab, kType, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - k1::kType - k2::kType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::ARS222, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ARS222Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - z₁ = zero(u) - z₂ = zero(u) - z₃ = nlsolver.z - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - else - k1 = nothing - k2 = nothing - end - return ARS222Cache(u, uprev, fsalfirst, z₁, z₂, z₃, k1, k2, nlsolver, tab, alg.step_limiter!) 
-end - -# ---- ARS232 ---- - -@cache mutable struct ARS232ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ARS232, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ARS232Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return ARS232ConstantCache(nlsolver, tab) -end - -@cache mutable struct ARS232Cache{uType, rateType, N, Tab, kType, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - k1::kType - k2::kType - k3::kType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::ARS232, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ARS232Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - z₁ = zero(u) - z₂ = zero(u) - z₃ = nlsolver.z - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - end - return ARS232Cache(u, uprev, fsalfirst, z₁, z₂, z₃, k1, k2, k3, nlsolver, tab, alg.step_limiter!) 
-end - -# ---- ARS443 ---- - -@cache mutable struct ARS443ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::ARS443, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ARS443Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return ARS443ConstantCache(nlsolver, tab) -end - -@cache mutable struct ARS443Cache{uType, rateType, N, Tab, kType, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::ARS443, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = ARS443Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.γ - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = nlsolver.z - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - end - return ARS443Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, k1, k2, k3, k4, k5, nlsolver, tab, alg.step_limiter!) 
-end - -# ---- BHR553 ---- - -@cache mutable struct BHR553ConstantCache{N, Tab} <: SDIRKConstantCache - nlsolver::N - tab::Tab -end - -function alg_cache( - alg::BHR553, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{false}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = BHR553Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c2 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false), verbose - ) - return BHR553ConstantCache(nlsolver, tab) -end - -@cache mutable struct BHR553Cache{uType, rateType, N, Tab, kType, StepLimiter} <: - SDIRKMutableCache - u::uType - uprev::uType - fsalfirst::rateType - z₁::uType - z₂::uType - z₃::uType - z₄::uType - z₅::uType - k1::kType - k2::kType - k3::kType - k4::kType - k5::kType - nlsolver::N - tab::Tab - step_limiter!::StepLimiter -end - -function alg_cache( - alg::BHR553, u, rate_prototype, ::Type{uEltypeNoUnits}, - ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, - uprev, uprev2, f, t, dt, reltol, p, calck, - ::Val{true}, verbose - ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - tab = BHR553Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) - γ, c = tab.γ, tab.c2 - nlsolver = build_nlsolver( - alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, - uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true), verbose - ) - fsalfirst = zero(rate_prototype) - z₁ = zero(u) - z₂ = zero(u) - z₃ = zero(u) - z₄ = zero(u) - z₅ = nlsolver.z - if f isa SplitFunction - k1 = zero(u) - k2 = zero(u) - k3 = zero(u) - k4 = zero(u) - k5 = zero(u) - else - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - end - return BHR553Cache(u, uprev, fsalfirst, z₁, z₂, z₃, z₄, z₅, k1, k2, k3, k4, k5, nlsolver, tab, alg.step_limiter!) 
-end diff --git a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl index c78d5e29b18..dfc1cb89261 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl @@ -3374,740 +3374,4 @@ end end end -# --------------------------------------------------------------------------- -# ARS(2,2,2) — 3-stage, 2nd order, ESDIRK, ISA (Ascher et al. 1997 Table II) -# --------------------------------------------------------------------------- - -@muladd function perform_step!( - integrator, cache::ARS222ConstantCache, repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a32, ea21, ea31, ea32) = cache.tab - - f2 = nothing - k1 = nothing - k2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - γdt = γ * dt - markfirststage!(nlsolver) - - # Stage 1 (ESDIRK: trivial, FSAL) - if integrator.f isa SplitFunction - z₁ = dt .* f(uprev, p, t) - k1 = dt * integrator.fsalfirst - z₁ - else - z₁ = dt * integrator.fsalfirst - end - - # Stage 2 - nlsolver.z = z₂ = z₁ - nlsolver.tmp = uprev # a21=0 - if integrator.f isa SplitFunction - nlsolver.tmp = nlsolver.tmp + ea21 * k1 - end - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 3 - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₂ # Y₂ - k2 = dt * f2(u, p, t + γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - tmp = uprev + a32 * z₂ - end - nlsolver.z = z₃ = z₂ - nlsolver.tmp = tmp - nlsolver.c = 1 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Final update: u = tmp + γ*z₃ - # For split: tmp already contains ea31*k1+ea32*k2 = eb1*k1+eb2*k2 (GSA: b̃ = last row of Ã) - u = nlsolver.tmp + γ * z₃ - - if 
integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₃ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::ARS222Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, k1, k2, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a32, ea21, ea31, ea32) = cache.tab - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - γdt = γ * dt - markfirststage!(nlsolver) - - # Stage 1 (ESDIRK: trivial, FSAL) - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - f(z₁, uprev, p, t) - z₁ .*= dt - else - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - # Stage 2 - copyto!(z₂, z₁) - nlsolver.z = z₂ - @.. broadcast = false tmp = uprev # a21=0 - if integrator.f isa SplitFunction - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - # Stage 3 - if integrator.f isa SplitFunction - @.. broadcast = false u = tmp + γ * z₂ # Y₂ - f2(k2, u, p, t + γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - @.. broadcast = false tmp = uprev + a32 * z₂ - end - copyto!(z₃, z₂) - nlsolver.z = z₃ - nlsolver.c = 1 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Final update: u = tmp + γ*z₃ (works for both split and non-split: GSA for split) - @.. 
broadcast = false u = tmp + γ * z₃ - - step_limiter!(u, integrator, p, t + dt) - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast = false integrator.fsallast = z₃ / dt - end -end - -# --------------------------------------------------------------------------- -# ARS(2,3,2) — 3-stage, 2nd order, ESDIRK, ISA (Ascher et al. 1997) -# --------------------------------------------------------------------------- - -@muladd function perform_step!( - integrator, cache::ARS232ConstantCache, repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a32, ea21, ea31, ea32, eb2, eb3) = cache.tab - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - γdt = γ * dt - markfirststage!(nlsolver) - - # Stage 1 (ESDIRK: trivial, FSAL) - if integrator.f isa SplitFunction - z₁ = dt .* f(uprev, p, t) - k1 = dt * integrator.fsalfirst - z₁ - else - z₁ = dt * integrator.fsalfirst - end - - # Stage 2 - nlsolver.z = z₂ = z₁ - nlsolver.tmp = uprev # a21=0 - if integrator.f isa SplitFunction - nlsolver.tmp = nlsolver.tmp + ea21 * k1 - end - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 3 - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₂ # Y₂ - k2 = dt * f2(u, p, t + γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - tmp = uprev + a32 * z₂ - end - nlsolver.z = z₃ = z₂ - nlsolver.tmp = tmp - nlsolver.c = 1 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Final update - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₃ # Y₃ (ISA implicit part) - k3 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - # b̃ = [0, 1-γ, γ] ≠ last row of à → must reconstruct explicitly - u = uprev + eb2 * (z₂ + k2) + 
eb3 * (z₃ + k3) - else - u = nlsolver.tmp + γ * z₃ # ISA - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₃ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::ARS232Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, k1, k2, k3, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a32, ea21, ea31, ea32, eb2, eb3) = cache.tab - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - γdt = γ * dt - markfirststage!(nlsolver) - - # Stage 1 (ESDIRK: trivial, FSAL) - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - f(z₁, uprev, p, t) - z₁ .*= dt - else - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - # Stage 2 - copyto!(z₂, z₁) - nlsolver.z = z₂ - @.. broadcast = false tmp = uprev # a21=0 - if integrator.f isa SplitFunction - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - # Stage 3 - if integrator.f isa SplitFunction - @.. broadcast = false u = tmp + γ * z₂ # Y₂ - f2(k2, u, p, t + γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - @.. broadcast = false tmp = uprev + a32 * z₂ - end - copyto!(z₃, z₂) - nlsolver.z = z₃ - nlsolver.c = 1 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Final update - @.. 
broadcast = false u = tmp + γ * z₃ # ISA (or Y₃ for split) - if integrator.f isa SplitFunction - f2(k3, u, p, t + dt) - k3 .*= dt - integrator.stats.nf2 += 1 - # b̃ = [0, 1-γ, γ]; explicit reconstruction needed - @.. broadcast = false u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) - end - - step_limiter!(u, integrator, p, t + dt) - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast = false integrator.fsallast = z₃ / dt - end -end - -# --------------------------------------------------------------------------- -# ARS(4,4,3) — 5-stage, 3rd order, ESDIRK, ISA (Ascher et al. 1997 Table IV) -# --------------------------------------------------------------------------- - -@muladd function perform_step!( - integrator, cache::ARS443ConstantCache, repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a32, a42, a43, a52, a53, a54, - ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, - eb2, eb3, eb4, eb5, c3, c4) = cache.tab - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - γdt = γ * dt - markfirststage!(nlsolver) - - # Stage 1 (ESDIRK: trivial, FSAL) - if integrator.f isa SplitFunction - z₁ = dt .* f(uprev, p, t) - k1 = dt * integrator.fsalfirst - z₁ - else - z₁ = dt * integrator.fsalfirst - end - - # Stage 2 (a21=0, c₂=γ) - nlsolver.z = z₂ = z₁ - nlsolver.tmp = uprev # a21=0 - if integrator.f isa SplitFunction - nlsolver.tmp = nlsolver.tmp + ea21 * k1 - end - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 3 (c₃=2/3) - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₂ # Y₂ - k2 = dt * f2(u, p, t + γdt) - integrator.stats.nf2 += 1 - tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - tmp = uprev + a32 * z₂ - end - nlsolver.z 
= z₃ = z₂ - nlsolver.tmp = tmp - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 4 (c₄=1/2) - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₃ # Y₃ - k3 = dt * f2(u, p, t + c3 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a42 * z₂ + a43 * z₃ + ea41 * k1 + ea42 * k2 + ea43 * k3 - else - tmp = uprev + a42 * z₂ + a43 * z₃ - end - nlsolver.z = z₄ = z₃ - nlsolver.tmp = tmp - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 5 (c₅=1) - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₄ # Y₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a52 * z₂ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 - else - tmp = uprev + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ = z₄ - nlsolver.tmp = tmp - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Final update: ISA → u = tmp + γ*z₅; for split add explicit contributions - u = nlsolver.tmp + γ * z₅ - if integrator.f isa SplitFunction - k5 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - # b̃ = b_i = [0, eb2, eb3, eb4, eb5] - u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + eb4 * (z₄ + k4) + eb5 * (z₅ + k5) - end - - if integrator.f isa SplitFunction - integrator.k[1] = integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₅ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::ARS443Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, k1, k2, k3, k4, k5, nlsolver, step_limiter!) 
= cache - (; tmp) = nlsolver - (; γ, a32, a42, a43, a52, a53, a54, - ea21, ea31, ea32, ea41, ea42, ea43, ea51, ea52, ea53, ea54, - eb2, eb3, eb4, eb5, c3, c4) = cache.tab - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - γdt = γ * dt - markfirststage!(nlsolver) - - # Stage 1 (ESDIRK: trivial, FSAL) - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - f(z₁, uprev, p, t) - z₁ .*= dt - else - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - # Stage 2 (a21=0, c₂=γ) - copyto!(z₂, z₁) - nlsolver.z = z₂ - @.. broadcast = false tmp = uprev # a21=0 - if integrator.f isa SplitFunction - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - nlsolver.c = γ - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - # Stage 3 (c₃=2/3) - if integrator.f isa SplitFunction - @.. broadcast = false u = tmp + γ * z₂ # Y₂ - f2(k2, u, p, t + γdt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a32 * z₂ + ea31 * k1 + ea32 * k2 - else - @.. broadcast = false tmp = uprev + a32 * z₂ - end - copyto!(z₃, z₂) - nlsolver.z = z₃ - nlsolver.c = c3 - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 4 (c₄=1/2) - if integrator.f isa SplitFunction - @.. broadcast = false u = tmp + γ * z₃ # Y₃ - f2(k3, u, p, t + c3 * dt) - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a42 * z₂ + a43 * z₃ + ea41 * k1 + - ea42 * k2 + ea43 * k3 - else - @.. broadcast = false tmp = uprev + a42 * z₂ + a43 * z₃ - end - copyto!(z₄, z₃) - nlsolver.z = z₄ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 5 (c₅=1) - if integrator.f isa SplitFunction - @.. 
broadcast = false u = tmp + γ * z₄ # Y₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a52 * z₂ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 - else - @.. broadcast = false tmp = uprev + a52 * z₂ + a53 * z₃ + a54 * z₄ - end - copyto!(z₅, z₄) - nlsolver.z = z₅ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Final update - @.. broadcast = false u = tmp + γ * z₅ # ISA (or Y₅ for split) - if integrator.f isa SplitFunction - f2(k5, u, p, t + dt) - k5 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false u = uprev + eb2 * (z₂ + k2) + eb3 * (z₃ + k3) + - eb4 * (z₄ + k4) + eb5 * (z₅ + k5) - end - - step_limiter!(u, integrator, p, t + dt) - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. broadcast = false integrator.fsallast = z₅ / dt - end -end - -# --------------------------------------------------------------------------- -# BHR(5,5,3)* — 5-stage, 3rd order, ESDIRK, ISA (Boscarino & Russo 2009) -# --------------------------------------------------------------------------- - -@muladd function perform_step!( - integrator, cache::BHR553ConstantCache, repeat_step = false - ) - (; t, dt, uprev, u, p) = integrator - nlsolver = cache.nlsolver - (; γ, a21, a31, a41, a43, a51, a53, a54, - ea21, ea31, ea32, ea41, ea43, ea51, ea52, ea53, ea54, - eb1, eb3, eb4, eb5, c2, c4) = cache.tab - - f2 = nothing - k1 = nothing - k2 = nothing - k3 = nothing - k4 = nothing - k5 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - markfirststage!(nlsolver) - - # Stage 1 (ESDIRK: trivial, FSAL) - if integrator.f isa SplitFunction - z₁ = dt .* f(uprev, p, t) - k1 = dt * integrator.fsalfirst - z₁ - else - z₁ = dt * integrator.fsalfirst - end - - # Stage 2 (a21=γ, c₂=2γ) - nlsolver.z = z₂ = z₁ - nlsolver.tmp = uprev + a21 
* z₁ - if integrator.f isa SplitFunction - nlsolver.tmp = nlsolver.tmp + ea21 * k1 - end - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 3 (a31=γ, a32=0, c₃=2γ; same c as stage 2) - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₂ # Y₂ - k2 = dt * f2(u, p, t + c2 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a31 * z₁ + ea31 * k1 + ea32 * k2 # a32=0 implicit - else - tmp = uprev + a31 * z₁ - end - nlsolver.z = z₃ = z₂ - nlsolver.tmp = tmp - nlsolver.c = c2 # c₃ = c₂ = 2γ - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 4 (a41, a42=0, a43, c₄=1.5) - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₃ # Y₃ - k3 = dt * f2(u, p, t + c2 * dt) # c₃ = 2γ - integrator.stats.nf2 += 1 - tmp = uprev + a41 * z₁ + a43 * z₃ + ea41 * k1 + ea43 * k3 # a42=0, ae42=0 - else - tmp = uprev + a41 * z₁ + a43 * z₃ - end - nlsolver.z = z₄ = z₃ - nlsolver.tmp = tmp - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 5 (a51, a52=0, a53, a54, c₅=1) - if integrator.f isa SplitFunction - u = nlsolver.tmp + γ * z₄ # Y₄ - k4 = dt * f2(u, p, t + c4 * dt) - integrator.stats.nf2 += 1 - tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 # a52=0 - else - tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ - end - nlsolver.z = z₅ = z₄ - nlsolver.tmp = tmp - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Final update: ISA → u = tmp + γ*z₅; for split use b̃ = b_i - u = nlsolver.tmp + γ * z₅ - if integrator.f isa SplitFunction - k5 = dt * f2(u, p, t + dt) - integrator.stats.nf2 += 1 - # b̃ = [eb1, 0, eb3, eb4, eb5]; b̃₂=0 so no z₂+k2 term - u = uprev + eb1 * (z₁ + k1) + eb3 * (z₃ + k3) + eb4 * (z₄ + k4) + eb5 * (z₅ + k5) - end - - if integrator.f isa SplitFunction - integrator.k[1] 
= integrator.fsalfirst - integrator.fsallast = integrator.f(u, p, t + dt) - integrator.k[2] = integrator.fsallast - else - integrator.fsallast = z₅ ./ dt - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - end - integrator.u = u -end - -@muladd function perform_step!(integrator, cache::BHR553Cache, repeat_step = false) - (; t, dt, uprev, u, p) = integrator - (; z₁, z₂, z₃, z₄, z₅, k1, k2, k3, k4, k5, nlsolver, step_limiter!) = cache - (; tmp) = nlsolver - (; γ, a21, a31, a41, a43, a51, a53, a54, - ea21, ea31, ea32, ea41, ea43, ea51, ea52, ea53, ea54, - eb1, eb3, eb4, eb5, c2, c4) = cache.tab - - f2 = nothing - if integrator.f isa SplitFunction - f = integrator.f.f1 - f2 = integrator.f.f2 - else - f = integrator.f - end - - markfirststage!(nlsolver) - - # Stage 1 (ESDIRK: trivial, FSAL) - if integrator.f isa SplitFunction && !repeat_step && !integrator.last_stepfail - f(z₁, uprev, p, t) - z₁ .*= dt - else - @.. broadcast = false z₁ = dt * integrator.fsalfirst - end - - # Stage 2 (a21=γ, c₂=2γ) - copyto!(z₂, z₁) - nlsolver.z = z₂ - @.. broadcast = false tmp = uprev + a21 * z₁ - if integrator.f isa SplitFunction - @.. broadcast = false k1 = dt * integrator.fsalfirst - z₁ - @.. broadcast = false tmp += ea21 * k1 - end - nlsolver.c = c2 - z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - isnewton(nlsolver) && set_new_W!(nlsolver, false) - - # Stage 3 (a31=γ, a32=0, c₃=2γ) - if integrator.f isa SplitFunction - @.. broadcast = false u = tmp + γ * z₂ # Y₂ - f2(k2, u, p, t + c2 * dt) - k2 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a31 * z₁ + ea31 * k1 + ea32 * k2 # a32=0 implicit - else - @.. broadcast = false tmp = uprev + a31 * z₁ - end - copyto!(z₃, z₂) - nlsolver.z = z₃ - nlsolver.c = c2 # c₃ = 2γ - z₃ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 4 (a41, a42=0, a43, c₄=1.5) - if integrator.f isa SplitFunction - @.. 
broadcast = false u = tmp + γ * z₃ # Y₃ - f2(k3, u, p, t + c2 * dt) # c₃ = 2γ - k3 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a41 * z₁ + a43 * z₃ + - ea41 * k1 + ea43 * k3 # a42=0, ae42=0 - else - @.. broadcast = false tmp = uprev + a41 * z₁ + a43 * z₃ - end - copyto!(z₄, z₃) - nlsolver.z = z₄ - nlsolver.c = c4 - z₄ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Stage 5 (a51, a52=0, a53, a54, c₅=1) - if integrator.f isa SplitFunction - @.. broadcast = false u = tmp + γ * z₄ # Y₄ - f2(k4, u, p, t + c4 * dt) - k4 .*= dt - integrator.stats.nf2 += 1 - @.. broadcast = false tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ + - ea51 * k1 + ea52 * k2 + ea53 * k3 + ea54 * k4 # a52=0 - else - @.. broadcast = false tmp = uprev + a51 * z₁ + a53 * z₃ + a54 * z₄ - end - copyto!(z₅, z₄) - nlsolver.z = z₅ - nlsolver.c = 1 - z₅ = nlsolve!(nlsolver, integrator, cache, repeat_step) - nlsolvefail(nlsolver) && return - - # Final update - @.. broadcast = false u = tmp + γ * z₅ # ISA (or Y₅ for split) - if integrator.f isa SplitFunction - f2(k5, u, p, t + dt) - k5 .*= dt - integrator.stats.nf2 += 1 - # b̃ = [eb1, 0, eb3, eb4, eb5]; b̃₂=0 - @.. broadcast = false u = uprev + eb1 * (z₁ + k1) + eb3 * (z₃ + k3) + - eb4 * (z₄ + k4) + eb5 * (z₅ + k5) - end - - step_limiter!(u, integrator, p, t + dt) - - if integrator.f isa SplitFunction - integrator.f(integrator.fsallast, u, p, t + dt) - else - @.. 
broadcast = false integrator.fsallast = z₅ / dt - end -end +# ARS and BHR IMEX methods are handled by generic_imex_perform_step.jl diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl index c0c75f30f62..31dac2087e8 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl @@ -3009,253 +3009,3 @@ function IMEXSSP3433Tableau(T, T2) eb4 = convert(T, 2 // 3) return IMEXSSP3433Tableau(γ, a21, a32, a41, a42, a43, c3, c4, ea32, ea42, ea43, eb2, eb3, eb4) end - -#= -ARS(2,2,2) — Ascher, Ruuth & Spiteri (1997), Table II -3-stage, 2nd order, L-stable, ESDIRK, ISA - -γ = 1 - √2/2 - -Implicit: c = [0, γ, 1] - A = [0 0 0] - [0 γ 0] - [0 1-γ γ] - b = [0, 1-γ, γ] (ISA = last row) - -Explicit: - à = [0 0 0] - [γ 0 0] - [δ 1-δ 0] - b̃ = [δ, 1-δ, 0] where δ = 1 - 1/(2γ) -=# -struct ARS222Tableau{T, T2} - γ::T2 - a32::T # = 1-γ - ea21::T # = γ - ea31::T # = δ = 1 - 1/(2γ) - ea32::T # = 1-δ - eb1::T # = δ - eb2::T # = 1-δ -end - -function ARS222Tableau(T, T2) - γ = convert(T2, 1 - sqrt(T2(2)) / 2) - δ = 1 - 1 / (2γ) - return ARS222Tableau{T, T2}( - γ, - convert(T, 1 - γ), # a32 - convert(T, γ), # ea21 - convert(T, δ), # ea31 - convert(T, 1 - δ), # ea32 - convert(T, δ), # eb1 - convert(T, 1 - δ), # eb2 - ) -end - -#= -ARS(2,3,2) — Ascher, Ruuth & Spiteri (1997) -3-stage, 2nd order, ESDIRK, ISA - -Same implicit tableau as ARS(2,2,2), different explicit part. 
- -Explicit: - à = [0 0 0] - [γ 0 0] - [δ' 1-δ' 0] - b̃ = [0, 1-γ, γ] where δ' = -2√2/3 -=# -struct ARS232Tableau{T, T2} - γ::T2 - a32::T # = 1-γ - ea21::T # = γ - ea31::T # = δ' = -2√2/3 - ea32::T # = 1-δ' - eb2::T # = 1-γ (b̃₁=0) - eb3::T # = γ -end - -function ARS232Tableau(T, T2) - γ = convert(T2, 1 - sqrt(T2(2)) / 2) - δ = convert(T, -2 * sqrt(T(2)) / 3) - return ARS232Tableau{T, T2}( - γ, - convert(T, 1 - γ), # a32 - convert(T, γ), # ea21 - δ, # ea31 = δ' - 1 - δ, # ea32 = 1-δ' - convert(T, 1 - γ), # eb2 - convert(T, γ), # eb3 - ) -end - -#= -ARS(4,4,3) — Ascher, Ruuth & Spiteri (1997), Table IV -5-stage, 3rd order, ESDIRK, ISA, γ = 1/2 - -Implicit: - c = [0, 1/2, 2/3, 1/2, 1] - A = [0 0 0 0 0 ] - [0 1/2 0 0 0 ] - [0 1/6 1/2 0 0 ] - [0 -1/2 1/2 1/2 0 ] - [0 3/2 -3/2 1/2 1/2 ] (= b, ISA) - b = [0, 3/2, -3/2, 1/2, 1/2] - -Explicit: - à = [0 0 0 0 0] - [1/2 0 0 0 0] - [11/18 1/18 0 0 0] - [5/6 -5/6 1/2 0 0] - [1/4 7/4 3/4 -7/4 0] - b̃ = [0, 3/2, -3/2, 1/2, 1/2] (= b_i) -=# -struct ARS443Tableau{T, T2} - γ::T2 # = 1/2 - a32::T # = 1/6 - a42::T # = -1/2 - a43::T # = 1/2 - a52::T # = 3/2 - a53::T # = -3/2 - a54::T # = 1/2 - ea21::T # = 1/2 - ea31::T # = 11/18 - ea32::T # = 1/18 - ea41::T # = 5/6 - ea42::T # = -5/6 - ea43::T # = 1/2 - ea51::T # = 1/4 - ea52::T # = 7/4 - ea53::T # = 3/4 - ea54::T # = -7/4 - eb2::T # = 3/2 (b̃₁=0) - eb3::T # = -3/2 - eb4::T # = 1/2 - eb5::T # = 1/2 - c3::T2 # = 2/3 - c4::T2 # = 1/2 -end - -function ARS443Tableau(T, T2) - γ = convert(T2, 1 // 2) - return ARS443Tableau{T, T2}( - γ, - convert(T, 1 // 6), # a32 - convert(T, -1 // 2), # a42 - convert(T, 1 // 2), # a43 - convert(T, 3 // 2), # a52 - convert(T, -3 // 2), # a53 - convert(T, 1 // 2), # a54 - convert(T, 1 // 2), # ea21 - convert(T, 11 // 18), # ea31 - convert(T, 1 // 18), # ea32 - convert(T, 5 // 6), # ea41 - convert(T, -5 // 6), # ea42 - convert(T, 1 // 2), # ea43 - convert(T, 1 // 4), # ea51 - convert(T, 7 // 4), # ea52 - convert(T, 3 // 4), # ea53 - convert(T, -7 // 4), # 
ea54 - convert(T, 3 // 2), # eb2 - convert(T, -3 // 2), # eb3 - convert(T, 1 // 2), # eb4 - convert(T, 1 // 2), # eb5 - convert(T2, 2 // 3), # c3 - convert(T2, 1 // 2), # c4 - ) -end - -#= -BHR(5,5,3)* — Boscarino & Russo (2009) -5-stage, 3rd order, ESDIRK, ISA, L-stable - -γ = 0.435866521508460 - -Implicit: - c = [0, 2γ, 2γ, c₄=1.5, 1] - A = [0 0 0 0 0] - [γ γ 0 0 0] - [γ 0 γ 0 0] - [a41 0 a43 γ 0] - [a51 0 a53 a54 γ] (= b, ISA) - b = [a51, 0, a53, a54, γ] - where: - a41 = 3c₄/2 - c₄²/(4γ) - γ, a43 = c₄²/(4γ) - c₄/2 - b3 = 0.362863385578740, b4 = -0.168124349878957 - -Explicit: - à = [0 0 0 0 0] - [2γ 0 0 0 0] - [γ γ 0 0 0] - [ea41 0 ea43 0 0] - [ea51 ea52 ã53 ã54 0] - b̃ = b_i - where: - ea41 = c₄ - c₄²/(4γ), ea43 = c₄²/(4γ) - ã53 = 1.195970114894582, ã54 = -0.150831109536248 - ea51 = 1 + b3 - ã53 - ã54, ea52 = -b3 -=# -struct BHR553Tableau{T, T2} - γ::T2 - a21::T # = γ - a31::T # = γ - a41::T # = 3c₄/2 - c₄²/(4γ) - γ - a43::T # = c₄²/(4γ) - c₄/2 - a51::T # = 1 - b3 - b4 - γ - a53::T # = b3 - a54::T # = b4 - ea21::T # = 2γ - ea31::T # = γ - ea32::T # = γ - ea41::T # = c₄ - c₄²/(4γ) - ea43::T # = c₄²/(4γ) - ea51::T # = 1 + b3 - ã53 - ã54 - ea52::T # = -b3 - ea53::T # = ã53 - ea54::T # = ã54 - eb1::T # = a51 (= 1 - b3 - b4 - γ) - eb3::T # = b3 - eb4::T # = b4 - eb5::T # = γ - c2::T2 # = 2γ - c4::T2 # = 1.5 -end - -function BHR553Tableau(T, T2) - γ = convert(T2, 0.435866521508460) - b3 = 0.362863385578740 - b4 = -0.168124349878957 - c4val = 1.5 - ã53 = 1.195970114894582 - ã54 = -0.150831109536248 - a41val = 3 * c4val / 2 - c4val^2 / (4 * γ) - γ - a43val = c4val^2 / (4 * γ) - c4val / 2 - a51val = 1 - b3 - b4 - γ - ea41val = c4val - c4val^2 / (4 * γ) - ea43val = c4val^2 / (4 * γ) - ea51val = 1 + b3 - ã53 - ã54 - return BHR553Tableau{T, T2}( - γ, - convert(T, γ), # a21 - convert(T, γ), # a31 - convert(T, a41val), # a41 - convert(T, a43val), # a43 - convert(T, a51val), # a51 - convert(T, b3), # a53 - convert(T, b4), # a54 - convert(T, 2 * γ), # ea21 - convert(T, γ), # 
ea31 - convert(T, γ), # ea32 - convert(T, ea41val), # ea41 - convert(T, ea43val), # ea43 - convert(T, ea51val), # ea51 - convert(T, -b3), # ea52 - convert(T, ã53), # ea53 - convert(T, ã54), # ea54 - convert(T, a51val), # eb1 - convert(T, b3), # eb3 - convert(T, b4), # eb4 - convert(T, γ), # eb5 - convert(T2, 2 * γ), # c2 - convert(T2, c4val), # c4 - ) -end