Skip to content

Commit a66d63b

Browse files
Fix backward integration, GPU compatibility, and Runic formatting in CVHin
- Fix backward integration bug: use abs(dtmax_tdir) for magnitude bounds since dtmax_tdir is negative for backward integration (tdir=-1) - Replace scalar indexing loops in non-Array CVHin paths with broadcasts (internalnorm./ifelse./maximum) for GPU array compatibility - Replace scalar isfinite loop in IIP non-Array path with any() broadcast - Fix Runic formatting: max.() args, yddnrm closing paren, continuation indent Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent c1e9972 commit a66d63b

1 file changed

Lines changed: 31 additions & 38 deletions

File tree

lib/OrdinaryDiffEqCore/src/initdt.jl

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,10 @@
235235
g₁ .*= 3
236236
ΔgMax = max.(internalnorm.(g₀ .- g₁, t), internalnorm.(g₀ .+ g₁, t))
237237
d₂ = internalnorm(
238-
max.(internalnorm.(f₁ .- f₀ .+ ΔgMax, t),
239-
internalnorm.(f₁ .- f₀ .- ΔgMax, t)) ./ sk,
238+
max.(
239+
internalnorm.(f₁ .- f₀ .+ ΔgMax, t),
240+
internalnorm.(f₁ .- f₀ .- ΔgMax, t)
241+
) ./ sk,
240242
t
241243
) / dt₀
242244
# Hairer has d₂ = sqrt(sum(abs2,tmp))/dt₀, note the lack of norm correction
@@ -300,23 +302,20 @@
300302
end
301303
end
302304
else
303-
for i in eachindex(u0)
304-
atol_i = abstol isa Number ? abstol : abstol[i]
305-
rtol_i = reltol isa Number ? reltol : reltol[i]
306-
tol_i = rtol_i * internalnorm(u0[i], t) + atol_i
307-
denom = convert(_tType, 0.1) * internalnorm(u0[i], t) + tol_i
308-
numer = internalnorm(f₀[i], t) * oneunit_tType
309-
if denom > 0
310-
hub_inv = max(hub_inv, numer / denom)
311-
end
312-
end
305+
u0_norms = internalnorm.(u0, t)
306+
f₀_norms = internalnorm.(f₀, t)
307+
tols = @.. broadcast = false reltol * u0_norms + abstol
308+
denoms = @.. broadcast = false convert(_tType, 0.1) * u0_norms + tols
309+
numers = @.. broadcast = false f₀_norms * oneunit_tType
310+
hub_inv_vals = ifelse.(denoms .> 0, numers ./ denoms, zero(_tType))
311+
hub_inv = maximum(hub_inv_vals)
313312
end
314313

315314
hub = convert(_tType, 0.1) * tdist
316315
if hub * hub_inv > 1
317316
hub = oneunit_tType / hub_inv
318317
end
319-
hub = min(hub, dtmax_tdir)
318+
hub = min(hub, abs(dtmax_tdir))
320319

321320
if hub < hlb
322321
return tdir * max(dtmin, sqrt(hlb * hub))
@@ -361,12 +360,7 @@
361360
end
362361
end
363362
else
364-
for i in eachindex(f₁)
365-
if !isfinite(f₁[i])
366-
ydd_ok = false
367-
break
368-
end
369-
end
363+
ydd_ok = !any(x -> any(!isfinite, x), f₁)
370364
end
371365

372366
if ydd_ok
@@ -397,7 +391,7 @@
397391

398392
# Order-dependent step proposal: h ~ (2/yddnrm)^(1/(p+1))
399393
if DiffEqBase.value(yddnrm) *
400-
DiffEqBase.value(hub / oneunit_tType)^(p_order + 1) > 2
394+
DiffEqBase.value(hub / oneunit_tType)^(p_order + 1) > 2
401395
hnew = convert(
402396
_tType,
403397
oneunit_tType * DiffEqBase.value(
@@ -424,7 +418,7 @@
424418
h0 = convert(_tType, 0.5) * hnew
425419
h0 = clamp(h0, hlb, hub)
426420

427-
return tdir * max(dtmin, min(h0, dtmax_tdir))
421+
return tdir * max(dtmin, min(h0, abs(dtmax_tdir)))
428422
end
429423
end
430424

@@ -539,8 +533,10 @@ end
539533
g₁ = 3g(u₁, p, t + dt₀_tdir)
540534
ΔgMax = max.(internalnorm.(g₀ .- g₁, t), internalnorm.(g₀ .+ g₁, t))
541535
d₂ = internalnorm(
542-
max.(internalnorm.(f₁ .- f₀ .+ ΔgMax, t),
543-
internalnorm.(f₁ .- f₀ .- ΔgMax, t)) ./ sk,
536+
max.(
537+
internalnorm.(f₁ .- f₀ .+ ΔgMax, t),
538+
internalnorm.(f₁ .- f₀ .- ΔgMax, t)
539+
) ./ sk,
544540
t
545541
) / dt₀
546542

@@ -579,23 +575,19 @@ end
579575
hlb = convert(_tType, 100 * eps_tType * oneunit_tType)
580576

581577
# Upper bound: most restrictive component of |f₀| / (0.1*|u0| + tol)
582-
hub_inv = zero(_tType)
583-
for i in eachindex(u0)
584-
atol_i = abstol isa Number ? abstol : abstol[i]
585-
rtol_i = reltol isa Number ? reltol : reltol[i]
586-
tol_i = rtol_i * internalnorm(u0[i], t) + atol_i
587-
denom = convert(_tType, 0.1) * internalnorm(u0[i], t) + tol_i
588-
numer = internalnorm(f₀[i], t) * oneunit_tType
589-
if denom > 0
590-
hub_inv = max(hub_inv, numer / denom)
591-
end
592-
end
578+
u0_norms = internalnorm.(u0, t)
579+
f₀_norms = internalnorm.(f₀, t)
580+
tols = @.. broadcast = false reltol * u0_norms + abstol
581+
denoms = @.. broadcast = false convert(_tType, 0.1) * u0_norms + tols
582+
numers = @.. broadcast = false f₀_norms * oneunit_tType
583+
hub_inv_vals = ifelse.(denoms .> 0, numers ./ denoms, zero(_tType))
584+
hub_inv = maximum(hub_inv_vals)
593585

594586
hub = convert(_tType, 0.1) * tdist
595587
if hub * hub_inv > 1
596588
hub = oneunit_tType / hub_inv
597589
end
598-
hub = min(hub, dtmax_tdir)
590+
hub = min(hub, abs(dtmax_tdir))
599591

600592
if hub < hlb
601593
return tdir * max(dtmin, sqrt(hlb * hub))
@@ -633,11 +625,12 @@ end
633625

634626
# Second derivative estimate
635627
yddnrm = internalnorm(
636-
(f₁ .- f₀) ./ sk .* oneunit_tType, t) / hg * oneunit_tType
628+
(f₁ .- f₀) ./ sk .* oneunit_tType, t
629+
) / hg * oneunit_tType
637630

638631
# Order-dependent step proposal: h ~ (2/yddnrm)^(1/(p+1))
639632
if DiffEqBase.value(yddnrm) *
640-
DiffEqBase.value(hub / oneunit_tType)^(p_order + 1) > 2
633+
DiffEqBase.value(hub / oneunit_tType)^(p_order + 1) > 2
641634
hnew = convert(
642635
_tType,
643636
oneunit_tType * DiffEqBase.value(
@@ -664,7 +657,7 @@ end
664657
h0 = convert(_tType, 0.5) * hnew
665658
h0 = clamp(h0, hlb, hub)
666659

667-
return tdir * max(dtmin, min(h0, dtmax_tdir))
660+
return tdir * max(dtmin, min(h0, abs(dtmax_tdir)))
668661
end
669662
end
670663

0 commit comments

Comments
 (0)