Skip to content

Commit d114d04

Browse files
Fix nf stat tracking in initdt and remove unnecessary DAE guard from DefaultInitDt
- Track integrator.stats.nf directly in each initdt function instead of hardcoding increment_nf!(stats, 2) in auto_dt_reset!. This correctly tracks the variable number of f calls in StiffInitDt (which makes 1+N calls depending on iteration count) vs DefaultInitDt (always 2 calls). Fixes InterfaceIII stats_tests failures. - Remove DAE guard from DefaultInitDt (both in-place and out-of-place). DefaultInitDt is only dispatched for explicit algorithms which never have isdae=true. The guard was causing issues for DAE problems that could end up in DefaultInitDt through fallback paths, returning h=0.001*tdist which was too large for stiff problems like ROBER. StiffInitDt retains its DAE guard since implicit/DAE algorithms dispatch there. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 141aade commit d114d04

2 files changed

Lines changed: 10 additions & 19 deletions

File tree

lib/OrdinaryDiffEqCore/src/initdt.jl

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,6 @@ end
2828
dtmin = nextfloat(max(integrator.opts.dtmin, eps(t)))
2929
smalldt = max(dtmin, convert(_tType, oneunit_tType * 1 // 10^(6)))
3030

31-
# DAE guard: use IDA-style h = 0.001 * tdist for mass-matrix DAEs
32-
if integrator.isdae
33-
tspan = prob.tspan
34-
tdist = abs(tspan[2] - tspan[1])
35-
h = convert(_tType, 1 // 1000) * tdist * oneunit_tType
36-
h = clamp(h, dtmin, tdir * dtmax)
37-
return tdir * h
38-
end
39-
4031
if eltype(u0) <: Number && !(integrator.alg isa CompositeAlgorithm)
4132
cache = get_tmp_cache(integrator)
4233
sk = first(cache)
@@ -74,6 +65,7 @@ end
7465
end
7566
f(f₀, u0, p, t)
7667
end
68+
integrator.stats.nf += 1
7769

7870
# TODO: use more caches
7971
#tmp = cache[2]
@@ -207,6 +199,7 @@ end
207199
end
208200
f₁ = zero(f₀)
209201
f(f₁, u₁, p, t + dt₀_tdir)
202+
integrator.stats.nf += 1
210203

211204
if prob.f.mass_matrix != I && (
212205
!(prob.f isa DynamicalODEFunction) ||
@@ -310,19 +303,11 @@ end
310303
dtmin = nextfloat(max(integrator.opts.dtmin, eps(t)))
311304
smalldt = max(dtmin, convert(_tType, oneunit_tType * 1 // 10^(6)))
312305

313-
# DAE guard: use IDA-style h = 0.001 * tdist for mass-matrix DAEs
314-
if integrator.isdae
315-
tspan = prob.tspan
316-
tdist = abs(tspan[2] - tspan[1])
317-
h = convert(_tType, 1 // 1000) * tdist * oneunit_tType
318-
h = clamp(h, dtmin, tdir * dtmax)
319-
return tdir * h
320-
end
321-
322306
sk = @.. broadcast = false abstol + internalnorm(u0, t) * reltol
323307
d₀ = internalnorm(u0 ./ sk, t)
324308

325309
f₀ = f(u0, p, t)
310+
integrator.stats.nf += 1
326311

327312
if any(x -> any(isnan, x), f₀)
328313
@SciMLMessage(
@@ -348,6 +333,7 @@ end
348333

349334
u₁ = @.. broadcast = false u0 + dt₀_tdir * f₀
350335
f₁ = f(u₁, p, t + dt₀_tdir)
336+
integrator.stats.nf += 1
351337

352338
# Constant zone before callback
353339
# Just return first guess
@@ -440,6 +426,7 @@ end
440426
end
441427
f(f₀, u0, p, t)
442428
end
429+
integrator.stats.nf += 1
443430

444431
# Handle mass matrix
445432
ftmp = nothing
@@ -573,6 +560,7 @@ end
573560

574561
# Evaluate f at stepped point
575562
f(f₁, u₁, p, t + hgs)
563+
integrator.stats.nf += 1
576564

577565
# Handle mass matrix
578566
if prob.f.mass_matrix != I && (
@@ -715,6 +703,7 @@ end
715703
end
716704

717705
f₀ = f(u0, p, t)
706+
integrator.stats.nf += 1
718707

719708
if any(x -> any(isnan, x), f₀)
720709
@SciMLMessage(
@@ -774,6 +763,7 @@ end
774763

775764
u₁ = @.. broadcast = false u0 + hgs * f₀
776765
f₁ = f(u₁, p, t + hgs)
766+
integrator.stats.nf += 1
777767

778768
ydd_ok = !any(x -> any(!isfinite, x), f₁)
779769

@@ -798,6 +788,7 @@ end
798788
hgs = hg * tdir
799789
u₁ = @.. broadcast = false u0 + hgs * f₀
800790
f₁ = f(u₁, p, t + hgs)
791+
integrator.stats.nf += 1
801792

802793
yddnrm = zero(_tType)
803794
N = length(u0)

lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ function SciMLBase.auto_dt_reset!(integrator::ODEIntegrator)
516516
integrator
517517
)
518518
integrator.dtpropose = integrator.dt
519-
return increment_nf!(integrator.stats, 2)
519+
return nothing
520520
end
521521

522522
function increment_nf!(stats, amt = 1)

0 commit comments

Comments
 (0)