Skip to content

Commit d6ca482

Browse files
Fix Unitful, zero-length vector, and mixed-unit ArrayPartition compatibility in CVHin
- Add early return for zero-length u0 vectors (dt ≈ 1e-6) - Introduce _fType = typeof(real(one(_tType))) for dimensionless scalar constants - Replace convert(_tType, constant) with convert(_fType, constant) to avoid adding time units to dimensionless values (0.1, 0.2, 0.5) - Use abs(u0[i]) and sk[i] instead of internalnorm(u0[i]) in IIP scalar loop to preserve physical units for component-wise hub_inv computation - Use eps(_fType) .* oneunit.(denoms) for unit-aware division guard in broadcast path, supporting mixed-unit ArrayPartition where eltype is abstract Quantity - Compare hub * hub_inv > oneunit_tType (not > 1) for dimensional correctness Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 01f2e90 commit d6ca482

1 file changed

Lines changed: 37 additions & 26 deletions

File tree

lib/OrdinaryDiffEqCore/src/initdt.jl

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,14 @@
269269
# iteration f calls are tracked individually inside the loop.
270270
integrator.stats.nf -= 1
271271

272+
# Zero-length vectors: no state to evolve, use default small dt
273+
if length(u0) == 0
274+
return tdir * max(smalldt, dtmin)
275+
end
276+
277+
# Dimensionless float type for scalar constants (handles Unitful)
278+
_fType = typeof(real(one(_tType)))
279+
272280
# NaN check via d₁ = norm(f₀/sk)
273281
if u0 isa Array
274282
@inbounds @simd ivdep for i in eachindex(u0)
@@ -290,33 +298,29 @@
290298
# CVHin Step 1: Compute lower and upper bounds on |h|
291299
tspan = prob.tspan
292300
tdist = abs(tspan[2] - tspan[1])
293-
eps_tType = eps(_tType)
294-
hlb = convert(_tType, 100 * eps_tType * oneunit_tType)
301+
hlb = 100 * eps(_fType) * oneunit_tType
295302

296303
# Upper bound: most restrictive component of |f₀| / (0.1*|u0| + tol)
297-
hub_inv = zero(_tType)
304+
hub_inv = zero(_fType)
298305
if u0 isa Array
299306
@inbounds for i in eachindex(u0)
300-
atol_i = abstol isa Number ? abstol : abstol[i]
301-
rtol_i = reltol isa Number ? reltol : reltol[i]
302-
tol_i = rtol_i * internalnorm(u0[i], t) + atol_i
303-
denom = convert(_tType, 0.1) * internalnorm(u0[i], t) + tol_i
304-
numer = internalnorm(f₀[i], t) * oneunit_tType
305-
if denom > 0
306-
hub_inv = max(hub_inv, numer / denom)
307+
denom_i = convert(_fType, 0.1) * abs(u0[i]) + sk[i]
308+
numer_i = abs(f₀[i]) * oneunit_tType
309+
if denom_i > zero(denom_i)
310+
hub_inv = max(hub_inv, DiffEqBase.value(numer_i / denom_i))
307311
end
308312
end
309313
else
310314
# GPU-compatible: use abs/max broadcasts instead of scalar indexing
311-
denoms = @.. broadcast = false convert(_tType, 0.1) * abs(u0) + sk
315+
denoms = @.. broadcast = false convert(_fType, 0.1) * abs(u0) + sk
312316
numers = @.. broadcast = false abs(f₀) * oneunit_tType
313-
hub_inv = maximum(numers ./ max.(denoms, eps(eltype(denoms))))
317+
hub_inv = maximum(numers ./ max.(denoms, eps(_fType) .* oneunit.(denoms)))
314318
end
315319
# Strip ForwardDiff.Dual tracking — step size bounds don't need AD
316320
hub_inv = DiffEqBase.value(hub_inv)
317321

318-
hub = convert(_tType, 0.1) * tdist
319-
if hub * hub_inv > 1
322+
hub = convert(_fType, 0.1) * tdist
323+
if hub * hub_inv > oneunit_tType
320324
hub = oneunit_tType / hub_inv
321325
end
322326
hub = min(hub, abs(dtmax_tdir))
@@ -372,7 +376,7 @@
372376
hg_ok = true
373377
break
374378
end
375-
hg *= convert(_tType, 0.2)
379+
hg *= convert(_fType, 0.2)
376380
end
377381

378382
if !hg_ok
@@ -409,7 +413,7 @@
409413

410414
count1 == 4 && break
411415
hrat = hnew / hg
412-
if hrat > convert(_tType, 0.5) && hrat < 2
416+
if hrat > convert(_fType, 0.5) && hrat < 2
413417
break
414418
end
415419
if count1 > 1 && hrat > 2
@@ -420,7 +424,7 @@
420424
end
421425

422426
# CVHin Step 3: Apply 0.5 safety factor and bounds
423-
h0 = convert(_tType, 0.5) * hnew
427+
h0 = convert(_fType, 0.5) * hnew
424428
h0 = clamp(h0, hlb, hub)
425429

426430
return tdir * max(dtmin, min(h0, abs(dtmax_tdir)))
@@ -568,6 +572,14 @@ end
568572
# iteration f calls are tracked individually inside the loop.
569573
integrator.stats.nf -= 1
570574

575+
# Zero-length vectors: no state to evolve, use default small dt
576+
if length(u0) == 0
577+
return tdir * max(smalldt, dtmin)
578+
end
579+
580+
# Dimensionless float type for scalar constants (handles Unitful)
581+
_fType = typeof(real(one(_tType)))
582+
571583
# NaN check via d₁ = norm(f₀/sk)
572584
d₁ = internalnorm(f₀ ./ sk .* oneunit_tType, t)
573585
if isnan(d₁)
@@ -581,16 +593,15 @@ end
581593
# CVHin Step 1: Compute lower and upper bounds
582594
tspan = prob.tspan
583595
tdist = abs(tspan[2] - tspan[1])
584-
eps_tType = eps(_tType)
585-
hlb = convert(_tType, 100 * eps_tType * oneunit_tType)
596+
hlb = 100 * eps(_fType) * oneunit_tType
586597

587598
# Upper bound: most restrictive component of |f₀| / (0.1*|u0| + tol)
588-
denoms = @.. broadcast = false convert(_tType, 0.1) * abs(u0) + sk
599+
denoms = @.. broadcast = false convert(_fType, 0.1) * abs(u0) + sk
589600
numers = @.. broadcast = false abs(f₀) * oneunit_tType
590-
hub_inv = DiffEqBase.value(maximum(numers ./ max.(denoms, eps(eltype(denoms)))))
601+
hub_inv = DiffEqBase.value(maximum(numers ./ max.(denoms, eps(_fType) .* oneunit.(denoms))))
591602

592-
hub = convert(_tType, 0.1) * tdist
593-
if hub * hub_inv > 1
603+
hub = convert(_fType, 0.1) * tdist
604+
if hub * hub_inv > oneunit_tType
594605
hub = oneunit_tType / hub_inv
595606
end
596607
hub = min(hub, abs(dtmax_tdir))
@@ -618,7 +629,7 @@ end
618629
hg_ok = true
619630
break
620631
end
621-
hg *= convert(_tType, 0.2)
632+
hg *= convert(_fType, 0.2)
622633
end
623634

624635
if !hg_ok
@@ -650,7 +661,7 @@ end
650661

651662
count1 == 4 && break
652663
hrat = hnew / hg
653-
if hrat > convert(_tType, 0.5) && hrat < 2
664+
if hrat > convert(_fType, 0.5) && hrat < 2
654665
break
655666
end
656667
if count1 > 1 && hrat > 2
@@ -661,7 +672,7 @@ end
661672
end
662673

663674
# CVHin Step 3: Apply 0.5 safety factor and bounds
664-
h0 = convert(_tType, 0.5) * hnew
675+
h0 = convert(_fType, 0.5) * hnew
665676
h0 = clamp(h0, hlb, hub)
666677

667678
return tdir * max(dtmin, min(h0, abs(dtmax_tdir)))

0 commit comments

Comments
 (0)