@@ -959,3 +959,159 @@ end
959959# ## RosenbrockW6S4O
960960
961961@RosenbrockW6S4OS (:cache )
962+
963+ # ###############################################################################
964+
965+ # ## Tsit5DA - hybrid explicit/linear-implicit method for DAEs
966+
967+ struct HybridExplicitImplicitConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
968+ tf:: TF
969+ uf:: UF
970+ tab:: Tab
971+ J:: JType
972+ W:: WType
973+ linsolve:: F
974+ autodiff:: AD
975+ interp_order:: Int
976+ end
977+
978+ mutable struct HybridExplicitImplicitCache{
979+ uType, rateType, uNoUnitsType, JType, WType, TabType,
980+ TFType, UFType, F, JCType, GCType, RTolType, A,
981+ StepLimiter, StageLimiter, DVType, AVType,
982+ GZType, GYType, WZType, FZ,
983+ } <: RosenbrockMutableCache
984+ u:: uType
985+ uprev:: uType
986+ dense:: Vector{rateType}
987+ du:: rateType
988+ du1:: rateType
989+ du2:: rateType
990+ ks:: Vector{rateType}
991+ fsalfirst:: rateType
992+ fsallast:: rateType
993+ dT:: rateType
994+ J:: JType
995+ W:: WType
996+ tmp:: rateType
997+ atmp:: uNoUnitsType
998+ weight:: uNoUnitsType
999+ tab:: TabType
1000+ tf:: TFType
1001+ uf:: UFType
1002+ linsolve_tmp:: rateType
1003+ linsolve:: F
1004+ jac_config:: JCType
1005+ grad_config:: GCType
1006+ reltol:: RTolType
1007+ alg:: A
1008+ step_limiter!:: StepLimiter
1009+ stage_limiter!:: StageLimiter
1010+ interp_order:: Int
1011+ # DAE-specific fields
1012+ diff_vars:: DVType
1013+ alg_vars:: AVType
1014+ g_z:: GZType # n_g x n_g algebraic Jacobian block
1015+ g_y:: GYType # n_g x n_f coupling block
1016+ W_z:: WZType # -gamma * g_z (used for linear solve)
1017+ linsolve_tmp_z:: FZ # n_g-sized RHS for algebraic solve
1018+ end
1019+ function full_cache (c:: HybridExplicitImplicitCache )
1020+ return [
1021+ c. u, c. uprev, c. dense... , c. du, c. du1, c. du2,
1022+ c. ks... , c. fsalfirst, c. fsallast, c. dT, c. tmp, c. atmp, c. weight, c. linsolve_tmp,
1023+ ]
1024+ end
1025+
1026+ function alg_cache (
1027+ alg:: HybridExplicitImplicitRK , u, rate_prototype, :: Type{uEltypeNoUnits} ,
1028+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
1029+ dt, reltol, p, calck,
1030+ :: Val{false} , verbose
1031+ ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
1032+ tf = TimeDerivativeWrapper (f, u, p)
1033+ uf = UDerivativeWrapper (f, t, p)
1034+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, nothing , uEltypeNoUnits, Val (false ))
1035+ tab = alg. tab (constvalue (uBottomEltypeNoUnits), constvalue (tTypeNoUnits))
1036+ return HybridExplicitImplicitConstantCache (
1037+ tf, uf, tab, J, W, nothing , alg_autodiff (alg), size (tab. H, 1 )
1038+ )
1039+ end
1040+
1041+ function alg_cache (
1042+ alg:: HybridExplicitImplicitRK , u, rate_prototype, :: Type{uEltypeNoUnits} ,
1043+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
1044+ dt, reltol, p, calck,
1045+ :: Val{true} , verbose
1046+ ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
1047+ tab = alg. tab (constvalue (uBottomEltypeNoUnits), constvalue (tTypeNoUnits))
1048+ num_stages = size (tab. A, 1 )
1049+ interp_order = size (tab. H, 1 )
1050+
1051+ # Initialize vectors
1052+ dense = [zero (rate_prototype) for _ in 1 : interp_order]
1053+ ks = [zero (rate_prototype) for _ in 1 : num_stages]
1054+ du = zero (rate_prototype)
1055+ du1 = zero (rate_prototype)
1056+ du2 = zero (rate_prototype)
1057+ fsalfirst = zero (rate_prototype)
1058+ fsallast = zero (rate_prototype)
1059+ dT = zero (rate_prototype)
1060+ tmp = zero (rate_prototype)
1061+ atmp = similar (u, uEltypeNoUnits)
1062+ recursivefill! (atmp, false )
1063+ weight = similar (u, uEltypeNoUnits)
1064+ recursivefill! (weight, false )
1065+ linsolve_tmp = zero (rate_prototype)
1066+
1067+ tf = TimeGradientWrapper (f, uprev, p)
1068+ uf = UJacobianWrapper (f, t, p)
1069+
1070+ grad_config = build_grad_config (alg, f, tf, du1, t)
1071+ jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
1072+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val (true ))
1073+
1074+ linprob = LinearProblem (W, _vec (linsolve_tmp); u0 = _vec (tmp))
1075+ Pl, Pr = wrapprecs (
1076+ alg. precs (
1077+ W, nothing , u, p, t, nothing , nothing , nothing ,
1078+ nothing
1079+ )... , weight, tmp
1080+ )
1081+ linsolve = init (
1082+ linprob, alg. linsolve, alias = LinearAliasSpecifier (alias_A = true , alias_b = true ),
1083+ Pl = Pl, Pr = Pr,
1084+ assumptions = LinearSolve. OperatorAssumptions (true ),
1085+ verbose = verbose. linear_verbosity
1086+ )
1087+
1088+ # Detect algebraic variables from mass matrix
1089+ mass_matrix = f. mass_matrix
1090+ if mass_matrix === I
1091+ diff_vars = collect (1 : length (u))
1092+ alg_vars = Int[]
1093+ g_z = zeros (eltype (u), 0 , 0 )
1094+ g_y = zeros (eltype (u), 0 , 0 )
1095+ W_z = zeros (eltype (u), 0 , 0 )
1096+ linsolve_tmp_z = zeros (eltype (u), 0 )
1097+ linsolve_z = nothing
1098+ else
1099+ n = length (u)
1100+ diff_vars = findall (i -> mass_matrix[i, i] != 0 , 1 : n)
1101+ alg_vars = findall (i -> mass_matrix[i, i] == 0 , 1 : n)
1102+ n_g = length (alg_vars)
1103+ n_f = length (diff_vars)
1104+ g_z = zeros (eltype (u), n_g, n_g)
1105+ g_y = zeros (eltype (u), n_g, n_f)
1106+ W_z = zeros (eltype (u), n_g, n_g)
1107+ linsolve_tmp_z = zeros (eltype (u), n_g)
1108+ end
1109+
1110+ return HybridExplicitImplicitCache (
1111+ u, uprev, dense, du, du1, du2, ks,
1112+ fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
1113+ linsolve_tmp, linsolve, jac_config, grad_config, reltol, alg,
1114+ alg. step_limiter!, alg. stage_limiter!, interp_order,
1115+ diff_vars, alg_vars, g_z, g_y, W_z, linsolve_tmp_z
1116+ )
1117+ end
0 commit comments