Skip to content

Commit 91d5cd5

Browse files
Harsh SinghHarsh Singh
authored andcommitted
fix(Rosenbrock): extract DAE fields into typed function for inference stability
The if mass_matrix === I branch in HybridExplicitImplicitRK alg_cache caused type instability. Extract into _build_dae_fields with explicit return type annotation to create a type stability barrier. Fixes @inferred solve(prob, Tsit5DA()) failure.
1 parent 87ea627 commit 91d5cd5

1 file changed

Lines changed: 23 additions & 21 deletions

File tree

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -365,27 +365,11 @@ function alg_cache(
365365
verbose = verbose.linear_verbosity
366366
)
367367

368-
# Detect algebraic variables from mass matrix
369-
mass_matrix = f.mass_matrix
370-
if mass_matrix === I
371-
diff_vars = collect(1:length(u))
372-
alg_vars = Int[]
373-
g_z = zeros(eltype(u), 0, 0)
374-
g_y = zeros(eltype(u), 0, 0)
375-
W_z = zeros(eltype(u), 0, 0)
376-
linsolve_tmp_z = zeros(eltype(u), 0)
377-
linsolve_z = nothing
378-
else
379-
n = length(u)
380-
diff_vars = findall(i -> mass_matrix[i, i] != 0, 1:n)
381-
alg_vars = findall(i -> mass_matrix[i, i] == 0, 1:n)
382-
n_g = length(alg_vars)
383-
n_f = length(diff_vars)
384-
g_z = zeros(eltype(u), n_g, n_g)
385-
g_y = zeros(eltype(u), n_g, n_f)
386-
W_z = zeros(eltype(u), n_g, n_g)
387-
linsolve_tmp_z = zeros(eltype(u), n_g)
388-
end
368+
# Detect algebraic variables from mass matrix and build DAE workspace.
369+
# Wrapped in a function call to create a type stability barrier —
370+
# the concrete cache type must not depend on the mass_matrix branch.
371+
diff_vars, alg_vars, g_z, g_y, W_z, linsolve_tmp_z = _build_dae_fields(
372+
f.mass_matrix, u)
389373

390374
return HybridExplicitImplicitCache(
391375
u, uprev, dense, du, du1, du2, ks,
@@ -395,3 +379,21 @@ function alg_cache(
395379
diff_vars, alg_vars, g_z, g_y, W_z, linsolve_tmp_z
396380
)
397381
end
382+
383+
function _build_dae_fields(
384+
mass_matrix, u
385+
)::Tuple{Vector{Int}, Vector{Int}, Matrix{eltype(u)}, Matrix{eltype(u)}, Matrix{eltype(u)}, Vector{eltype(u)}}
386+
T = eltype(u)
387+
n = length(u)
388+
if mass_matrix === I
389+
diff_vars = collect(1:n)
390+
alg_vars = Int[]
391+
else
392+
diff_vars = findall(i -> mass_matrix[i, i] != 0, 1:n)
393+
alg_vars = findall(i -> mass_matrix[i, i] == 0, 1:n)
394+
end
395+
n_g = length(alg_vars)
396+
n_f = length(diff_vars)
397+
return diff_vars, alg_vars, zeros(T, n_g, n_g), zeros(T, n_g, n_f),
398+
zeros(T, n_g, n_g), zeros(T, n_g)
399+
end

0 commit comments

Comments
 (0)