Skip to content

Commit 934dc0a

Browse files
Fix StiffInitDt DAE guard, type stability with BigFloat, and Julia 1.10 compatibility
- Change StiffInitDt DAE guard from 0.001*tdist to max(smalldt, dtmin). The old value was too large for DAEs with inconsistent initial conditions (e.g., ROBER with tspan (0, 1e5) gave dt=100, causing solver instability). Fixes OrdinaryDiffEqNonlinearSolve and OrdinaryDiffEqRosenbrock failures. - Fix type stability when u0 is BigFloat but tspan is Float64: hub_inv and yddnrm accumulations were promoted to BigFloat by u0 element operations, contaminating the step size h and causing t+h to become BigFloat. This broke FunctionWrappersWrappers which only has Float64 time wrappers. Fix: convert hub_inv, yddnrm, and f call time arguments back to _tType. Fixes InterfaceIV precision_mixing test failures. - Change BDF test to use `import OrdinaryDiffEqCore` with qualified access instead of `using OrdinaryDiffEqCore: initdt_alg` which fails on Julia 1.10 where unexported names cannot be accessed via `using M: name` syntax. Fixes OrdinaryDiffEqBDF LTS test failures. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent d114d04 commit 934dc0a

2 files changed

Lines changed: 28 additions & 24 deletions

File tree

lib/OrdinaryDiffEqBDF/test/stiff_initdt_tests.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
using Test
22
using OrdinaryDiffEqBDF
33
using OrdinaryDiffEqSDIRK
4-
using OrdinaryDiffEqCore: initdt_alg, DefaultInitDt, StiffInitDt
4+
import OrdinaryDiffEqCore
55

66
@testset "StiffInitDt Algorithm" begin
77

88
@testset "Trait dispatch" begin
99
# All implicit methods should use StiffInitDt
10-
@test initdt_alg(FBDF()) isa StiffInitDt
11-
@test initdt_alg(QNDF()) isa StiffInitDt
12-
@test initdt_alg(QNDF1()) isa StiffInitDt
13-
@test initdt_alg(QNDF2()) isa StiffInitDt
14-
@test initdt_alg(ABDF2()) isa StiffInitDt
15-
@test initdt_alg(ImplicitEuler()) isa StiffInitDt
10+
@test OrdinaryDiffEqCore.initdt_alg(FBDF()) isa OrdinaryDiffEqCore.StiffInitDt
11+
@test OrdinaryDiffEqCore.initdt_alg(QNDF()) isa OrdinaryDiffEqCore.StiffInitDt
12+
@test OrdinaryDiffEqCore.initdt_alg(QNDF1()) isa OrdinaryDiffEqCore.StiffInitDt
13+
@test OrdinaryDiffEqCore.initdt_alg(QNDF2()) isa OrdinaryDiffEqCore.StiffInitDt
14+
@test OrdinaryDiffEqCore.initdt_alg(ABDF2()) isa OrdinaryDiffEqCore.StiffInitDt
15+
@test OrdinaryDiffEqCore.initdt_alg(ImplicitEuler()) isa OrdinaryDiffEqCore.StiffInitDt
1616
end
1717

1818
@testset "Simple exponential decay (in-place)" begin

lib/OrdinaryDiffEqCore/src/initdt.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -394,12 +394,12 @@ end
394394
tspan = prob.tspan
395395
tdist = abs(tspan[2] - tspan[1])
396396

397-
# DAE guard: use IDA-style h = 0.001 * tdist for mass-matrix DAEs
398-
# Must be before f₀ evaluation to avoid type issues with AD (ForwardDiff Duals)
397+
# DAE guard: use conservative small dt for mass-matrix DAEs.
398+
# Must be before f₀ evaluation to avoid type issues with AD (ForwardDiff Duals).
399+
# Uses smalldt (≈1e-6) rather than 0.001*tdist which can be too large for
400+
# DAE initialization with inconsistent initial conditions.
399401
if integrator.isdae
400-
h = convert(_tType, 1 // 1000) * tdist * oneunit_tType
401-
h = clamp(h, dtmin, tdir * dtmax)
402-
return tdir * h
402+
return tdir * max(smalldt, dtmin)
403403
end
404404

405405
# Fall back to DefaultInitDt for non-Array types (GPU arrays need broadcast)
@@ -521,6 +521,9 @@ end
521521
end
522522
end
523523

524+
# Ensure hub_inv stays as _tType (avoid promotion from BigFloat u0 elements)
525+
hub_inv = convert(_tType, hub_inv)
526+
524527
hub = convert(_tType, 0.1) * tdist * oneunit_tType
525528
if hub * hub_inv > 1
526529
hub = 1 / hub_inv
@@ -559,7 +562,7 @@ end
559562
end
560563

561564
# Evaluate f at stepped point
562-
f(f₁, u₁, p, t + hgs)
565+
f(f₁, u₁, p, t + convert(_tType, hgs))
563566
integrator.stats.nf += 1
564567

565568
# Handle mass matrix
@@ -630,11 +633,11 @@ end
630633
yddnrm += (ydd_i * ewt_i)^2
631634
end
632635
end
633-
yddnrm = sqrt(yddnrm / N)
636+
yddnrm = convert(_tType, sqrt(yddnrm / N))
634637

635638
# Compute new step proposal
636639
if yddnrm * hub^2 > 2
637-
hnew = sqrt(2 / yddnrm)
640+
hnew = sqrt(convert(_tType, 2) / yddnrm)
638641
else
639642
hnew = sqrt(hg * hub)
640643
end
@@ -685,12 +688,10 @@ end
685688
tspan = prob.tspan
686689
tdist = abs(tspan[2] - tspan[1])
687690

688-
# DAE guard: use IDA-style h = 0.001 * tdist for mass-matrix DAEs
689-
# Must be before f₀ evaluation to avoid type issues with AD (ForwardDiff Duals)
691+
# DAE guard: use conservative small dt for mass-matrix DAEs.
692+
# Must be before f₀ evaluation to avoid type issues with AD (ForwardDiff Duals).
690693
if integrator.isdae
691-
h = convert(_tType, 1 // 1000) * tdist * oneunit_tType
692-
h = clamp(h, dtmin, tdir * dtmax)
693-
return tdir * h
694+
return tdir * max(smalldt, dtmin)
694695
end
695696

696697
# Fall back to DefaultInitDt for non-Array types (GPU arrays need broadcast)
@@ -736,6 +737,9 @@ end
736737
end
737738
end
738739

740+
# Ensure hub_inv stays as _tType (avoid promotion from BigFloat u0 elements)
741+
hub_inv = convert(_tType, hub_inv)
742+
739743
hub = convert(_tType, 0.1) * tdist * oneunit_tType
740744
if hub * hub_inv > 1
741745
hub = 1 / hub_inv
@@ -762,7 +766,7 @@ end
762766
hgs = hg * tdir
763767

764768
u₁ = @.. broadcast = false u0 + hgs * f₀
765-
f₁ = f(u₁, p, t + hgs)
769+
f₁ = f(u₁, p, t + convert(_tType, hgs))
766770
integrator.stats.nf += 1
767771

768772
ydd_ok = !any(x -> any(!isfinite, x), f₁)
@@ -787,7 +791,7 @@ end
787791

788792
hgs = hg * tdir
789793
u₁ = @.. broadcast = false u0 + hgs * f₀
790-
f₁ = f(u₁, p, t + hgs)
794+
f₁ = f(u₁, p, t + convert(_tType, hgs))
791795
integrator.stats.nf += 1
792796

793797
yddnrm = zero(_tType)
@@ -799,10 +803,10 @@ end
799803
ydd_i = (f₁[i] - f₀[i]) / hg * oneunit_tType
800804
yddnrm += (ydd_i * ewt_i)^2
801805
end
802-
yddnrm = sqrt(yddnrm / N)
806+
yddnrm = convert(_tType, sqrt(yddnrm / N))
803807

804808
if yddnrm * hub^2 > 2
805-
hnew = sqrt(2 / yddnrm)
809+
hnew = sqrt(convert(_tType, 2) / yddnrm)
806810
else
807811
hnew = sqrt(hg * hub)
808812
end

0 commit comments

Comments
 (0)