Skip to content

Commit 6f57229

Browse files
Merge pull request #3070 from ChrisRackauckas-Claude/add-tsit5da-method
Add Tsit5DA: hybrid explicit/linear-implicit solver for DAEs
2 parents 2f95bda + 7c77b45 commit 6f57229

9 files changed

Lines changed: 730 additions & 8 deletions

lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import LinearSolve
2323
import LinearSolve: UniformScaling
2424
import ForwardDiff
2525
using FiniteDiff
26-
using LinearAlgebra: mul!, diag, diagm, I, Diagonal, norm, lu!
26+
using LinearAlgebra: mul!, diag, diagm, I, Diagonal, norm, lu, lu!
2727
using ADTypes
2828
import OrdinaryDiffEqCore, OrdinaryDiffEqDifferentiation
2929

@@ -293,7 +293,7 @@ end
293293

294294
export Rosenbrock23, Rosenbrock32, RosShamp4, Veldd4, Velds4, GRK4T, GRK4A,
295295
Ros4LStab, ROS3P, Rodas3, Rodas23W, Rodas3P, Rodas4, Rodas42, Rodas4P, Rodas4P2,
296-
Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr, Rodas6P,
296+
Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr, Rodas6P, HybridExplicitImplicitRK, Tsit5DA,
297297
RosenbrockW6S4OS, ROS34PW1a, ROS34PW1b, ROS34PW2, ROS34PW3, ROS34PRw, ROS3PRL,
298298
ROS3PRL2, ROK4a,
299299
ROS2, ROS2PR, ROS2S, ROS3, ROS3PR, Scholz4_7

lib/OrdinaryDiffEqRosenbrock/src/alg_utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ alg_order(alg::Rodas5P) = 5
3434
alg_order(alg::Rodas5Pr) = 5
3535
alg_order(alg::Rodas5Pe) = 5
3636
alg_order(alg::Rodas6P) = 6
37+
alg_order(alg::HybridExplicitImplicitRK) = alg.order
3738

3839
alg_adaptive_order(alg::Rosenbrock32) = 2
3940
alg_adaptive_order(alg::Rosenbrock23) = 3
@@ -61,12 +62,14 @@ isfsal(alg::Rodas42) = false
6162
isfsal(alg::Rodas4P) = false
6263
isfsal(alg::Rodas4P2) = false
6364
isfsal(alg::Rodas6P) = false
65+
isfsal(alg::HybridExplicitImplicitRK) = false
6466

6567
function has_stiff_interpolation(
6668
::Union{
6769
Rosenbrock23, Rosenbrock32, Rodas23W,
6870
Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5,
6971
Rodas5P, Rodas5Pe, Rodas5Pr, Rodas6P,
72+
HybridExplicitImplicitRK,
7073
}
7174
)
7275
return true

lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,68 @@ for (Alg, desc, refs, is_W) in [
139139
end
140140
end
141141
end
142+
143+
################################################################################
144+
# HybridExplicitImplicitRK — generic tableau-based hybrid explicit/linear-implicit method
145+
################################################################################
146+
147+
struct HybridExplicitImplicitRK{TabType, CS, AD, F, P, FDT, ST, CJ, StepLimiter, StageLimiter} <:
148+
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
149+
tab::TabType
150+
order::Int
151+
linsolve::F
152+
precs::P
153+
step_limiter!::StepLimiter
154+
stage_limiter!::StageLimiter
155+
autodiff::AD
156+
end
157+
158+
function HybridExplicitImplicitRK(tab;
159+
order,
160+
chunk_size = Val{0}(), autodiff = AutoForwardDiff(),
161+
standardtag = Val{true}(), concrete_jac = nothing,
162+
diff_type = Val{:forward}(), linsolve = nothing,
163+
precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!,
164+
stage_limiter! = trivial_limiter!
165+
)
166+
AD_choice, chunk_size, diff_type = _process_AD_choice(
167+
autodiff, chunk_size, diff_type
168+
)
169+
return HybridExplicitImplicitRK{
170+
typeof(tab), _unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve),
171+
typeof(precs), diff_type, _unwrap_val(standardtag),
172+
_unwrap_val(concrete_jac), typeof(step_limiter!),
173+
typeof(stage_limiter!),
174+
}(
175+
tab, order, linsolve, precs, step_limiter!,
176+
stage_limiter!, AD_choice
177+
)
178+
end
179+
180+
# Keyword-only constructor for remake support
181+
function HybridExplicitImplicitRK(;
182+
tab,
183+
order,
184+
chunk_size = Val{0}(), autodiff = AutoForwardDiff(),
185+
standardtag = Val{true}(), concrete_jac = nothing,
186+
diff_type = Val{:forward}(), linsolve = nothing,
187+
precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!,
188+
stage_limiter! = trivial_limiter!
189+
)
190+
return HybridExplicitImplicitRK(tab;
191+
order, chunk_size, autodiff, standardtag, concrete_jac,
192+
diff_type, linsolve, precs, step_limiter!, stage_limiter!
193+
)
194+
end
195+
196+
"""
197+
A 12-stage order 5(4) hybrid explicit/linear-implicit method for semi-explicit index-1 DAEs.
198+
Differential variables are treated explicitly (like Tsit5), algebraic variables use Rosenbrock-type
199+
linear-implicit stages. Only the small algebraic Jacobian block needs factorization.
200+
For pure ODEs (no algebraic constraints), reduces to an explicit Runge-Kutta method.
201+
202+
References:
203+
- Steinebach G., Rodas6P and Tsit5DA - two new Rosenbrock-type methods for DAEs.
204+
arXiv:2511.21252, 2025.
205+
"""
206+
Tsit5DA(; kwargs...) = HybridExplicitImplicitRK(Tsit5DATableau; order = 5, kwargs...)

lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ function SciMLBase.interp_summary(
2121
Union{
2222
RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
2323
RosenbrockCache, Rodas23WCache, Rodas3PCache,
24+
HybridExplicitImplicitConstantCache, HybridExplicitImplicitCache,
2425
},
2526
}
2627
return dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" :

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,3 +959,159 @@ end
959959
### RosenbrockW6S4O
960960

961961
@RosenbrockW6S4OS(:cache)
962+
963+
################################################################################
964+
965+
### Tsit5DA - hybrid explicit/linear-implicit method for DAEs
966+
967+
struct HybridExplicitImplicitConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
968+
tf::TF
969+
uf::UF
970+
tab::Tab
971+
J::JType
972+
W::WType
973+
linsolve::F
974+
autodiff::AD
975+
interp_order::Int
976+
end
977+
978+
mutable struct HybridExplicitImplicitCache{
979+
uType, rateType, uNoUnitsType, JType, WType, TabType,
980+
TFType, UFType, F, JCType, GCType, RTolType, A,
981+
StepLimiter, StageLimiter, DVType, AVType,
982+
GZType, GYType, WZType, FZ,
983+
} <: RosenbrockMutableCache
984+
u::uType
985+
uprev::uType
986+
dense::Vector{rateType}
987+
du::rateType
988+
du1::rateType
989+
du2::rateType
990+
ks::Vector{rateType}
991+
fsalfirst::rateType
992+
fsallast::rateType
993+
dT::rateType
994+
J::JType
995+
W::WType
996+
tmp::rateType
997+
atmp::uNoUnitsType
998+
weight::uNoUnitsType
999+
tab::TabType
1000+
tf::TFType
1001+
uf::UFType
1002+
linsolve_tmp::rateType
1003+
linsolve::F
1004+
jac_config::JCType
1005+
grad_config::GCType
1006+
reltol::RTolType
1007+
alg::A
1008+
step_limiter!::StepLimiter
1009+
stage_limiter!::StageLimiter
1010+
interp_order::Int
1011+
# DAE-specific fields
1012+
diff_vars::DVType
1013+
alg_vars::AVType
1014+
g_z::GZType # n_g x n_g algebraic Jacobian block
1015+
g_y::GYType # n_g x n_f coupling block
1016+
W_z::WZType # -gamma * g_z (used for linear solve)
1017+
linsolve_tmp_z::FZ # n_g-sized RHS for algebraic solve
1018+
end
1019+
function full_cache(c::HybridExplicitImplicitCache)
1020+
return [
1021+
c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
1022+
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp,
1023+
]
1024+
end
1025+
1026+
function alg_cache(
1027+
alg::HybridExplicitImplicitRK, u, rate_prototype, ::Type{uEltypeNoUnits},
1028+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
1029+
dt, reltol, p, calck,
1030+
::Val{false}, verbose
1031+
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
1032+
tf = TimeDerivativeWrapper(f, u, p)
1033+
uf = UDerivativeWrapper(f, t, p)
1034+
J, W = build_J_W(alg, u, uprev, p, t, dt, f, nothing, uEltypeNoUnits, Val(false))
1035+
tab = alg.tab(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
1036+
return HybridExplicitImplicitConstantCache(
1037+
tf, uf, tab, J, W, nothing, alg_autodiff(alg), size(tab.H, 1)
1038+
)
1039+
end
1040+
1041+
function alg_cache(
1042+
alg::HybridExplicitImplicitRK, u, rate_prototype, ::Type{uEltypeNoUnits},
1043+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
1044+
dt, reltol, p, calck,
1045+
::Val{true}, verbose
1046+
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
1047+
tab = alg.tab(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
1048+
num_stages = size(tab.A, 1)
1049+
interp_order = size(tab.H, 1)
1050+
1051+
# Initialize vectors
1052+
dense = [zero(rate_prototype) for _ in 1:interp_order]
1053+
ks = [zero(rate_prototype) for _ in 1:num_stages]
1054+
du = zero(rate_prototype)
1055+
du1 = zero(rate_prototype)
1056+
du2 = zero(rate_prototype)
1057+
fsalfirst = zero(rate_prototype)
1058+
fsallast = zero(rate_prototype)
1059+
dT = zero(rate_prototype)
1060+
tmp = zero(rate_prototype)
1061+
atmp = similar(u, uEltypeNoUnits)
1062+
recursivefill!(atmp, false)
1063+
weight = similar(u, uEltypeNoUnits)
1064+
recursivefill!(weight, false)
1065+
linsolve_tmp = zero(rate_prototype)
1066+
1067+
tf = TimeGradientWrapper(f, uprev, p)
1068+
uf = UJacobianWrapper(f, t, p)
1069+
1070+
grad_config = build_grad_config(alg, f, tf, du1, t)
1071+
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
1072+
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
1073+
1074+
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
1075+
Pl, Pr = wrapprecs(
1076+
alg.precs(
1077+
W, nothing, u, p, t, nothing, nothing, nothing,
1078+
nothing
1079+
)..., weight, tmp
1080+
)
1081+
linsolve = init(
1082+
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
1083+
Pl = Pl, Pr = Pr,
1084+
assumptions = LinearSolve.OperatorAssumptions(true),
1085+
verbose = verbose.linear_verbosity
1086+
)
1087+
1088+
# Detect algebraic variables from mass matrix
1089+
mass_matrix = f.mass_matrix
1090+
if mass_matrix === I
1091+
diff_vars = collect(1:length(u))
1092+
alg_vars = Int[]
1093+
g_z = zeros(eltype(u), 0, 0)
1094+
g_y = zeros(eltype(u), 0, 0)
1095+
W_z = zeros(eltype(u), 0, 0)
1096+
linsolve_tmp_z = zeros(eltype(u), 0)
1097+
linsolve_z = nothing
1098+
else
1099+
n = length(u)
1100+
diff_vars = findall(i -> mass_matrix[i, i] != 0, 1:n)
1101+
alg_vars = findall(i -> mass_matrix[i, i] == 0, 1:n)
1102+
n_g = length(alg_vars)
1103+
n_f = length(diff_vars)
1104+
g_z = zeros(eltype(u), n_g, n_g)
1105+
g_y = zeros(eltype(u), n_g, n_f)
1106+
W_z = zeros(eltype(u), n_g, n_g)
1107+
linsolve_tmp_z = zeros(eltype(u), n_g)
1108+
end
1109+
1110+
return HybridExplicitImplicitCache(
1111+
u, uprev, dense, du, du1, du2, ks,
1112+
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
1113+
linsolve_tmp, linsolve, jac_config, grad_config, reltol, alg,
1114+
alg.step_limiter!, alg.stage_limiter!, interp_order,
1115+
diff_vars, alg_vars, g_z, g_y, W_z, linsolve_tmp_z
1116+
)
1117+
end

0 commit comments

Comments
 (0)