Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions lib/OrdinaryDiffEqCore/src/initdt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
f = prob.f
p = integrator.p
oneunit_tType = oneunit(_tType)
# a number that's about 1 that we multiply by to prevent numerical coincedences
arbitrary_const = oneunit_tType*(93//83)
dtmax_tdir = tdir * dtmax

dtmin = nextfloat(max(integrator.opts.dtmin, eps(t)))
smalldt = max(dtmin, convert(_tType, oneunit_tType * 1 // 10^(6)))
smalldt = max(dtmin, convert(_tType, arbitrary_const * 1 // 10^(6)))

if integrator.isdae
result_dt = tdir * max(smalldt, dtmin)
Expand Down Expand Up @@ -132,10 +134,10 @@

if u0 isa Array
@inbounds @simd ivdep for i in eachindex(u0)
tmp[i] = f₀[i] / sk[i] * oneunit_tType
tmp[i] = f₀[i] / sk[i] * arbitrary_const
end
else
@.. broadcast = false tmp = f₀ / sk * oneunit_tType
@.. broadcast = false tmp = f₀ / sk * arbitrary_const
end

d₁ = internalnorm(tmp, t)
Expand All @@ -156,20 +158,15 @@
(d₁ < 1 // 10^(5)), smalldt,
convert(
_tType,
oneunit_tType * DiffEqBase.value(
arbitrary_const * DiffEqBase.value(
(d₀ / d₁) /
100
)
)
)
# if d₀ < 1//10^(5) || d₁ < 1//10^(5)
# dt₀ = smalldt
# else
# dt₀ = convert(_tType,oneunit_tType*(d₀/d₁)/100)
# end
dt₀ = min(dt₀, dtmax_tdir)

if typeof(one(_tType)) <: AbstractFloat && dt₀ < 10eps(_tType) * oneunit(_tType)
if typeof(one(_tType)) <: AbstractFloat && dt₀ < 10eps(_tType) * arbitrary_const
# This catches Andreas' non-singular example
# should act like it's singular
result_dt = tdir * max(smalldt, dtmin)
Expand Down Expand Up @@ -209,22 +206,22 @@

if u0 isa Array
@inbounds @simd ivdep for i in eachindex(u0)
tmp[i] = (f₁[i] - f₀[i]) / sk[i] * oneunit_tType
tmp[i] = (f₁[i] - f₀[i]) / sk[i] * arbitrary_const
end
else
@.. broadcast = false tmp = (f₁ - f₀) / sk * oneunit_tType
@.. tmp = (f₁ - f₀) / sk * arbitrary_const
end

d₂ = internalnorm(tmp, t) / dt₀ * oneunit_tType
d₂ = internalnorm(tmp, t) / dt₀ * arbitrary_const
# Hairer has d₂ = sqrt(sum(abs2,tmp))/dt₀, note the lack of norm correction

max_d₁d₂ = max(d₁, d₂)
if max_d₁d₂ <= 1 // Int64(10)^(15)
dt₁ = max(convert(_tType, oneunit_tType * 1 // 10^(6)), dt₀ * 1 // 10^(3))
dt₁ = max(convert(_tType, arbitrary_const * 1 // 10^(6)), dt₀ * 1 // 10^(3))
else
dt₁ = convert(
_tType,
oneunit_tType *
arbitrary_const *
DiffEqBase.value(
10.0^(
-(2 + log10(max_d₁d₂)) /
Expand Down Expand Up @@ -277,10 +274,12 @@ end
f = prob.f
p = prob.p
oneunit_tType = oneunit(_tType)
# a number that's about 1 that we multiply by to prevent numerical coincedences
arbitrary_const = oneunit_tType*(93//83)
dtmax_tdir = tdir * dtmax

dtmin = nextfloat(max(integrator.opts.dtmin, eps(t)))
smalldt = max(dtmin, convert(_tType, oneunit_tType * 1 // 10^(6)))
smalldt = max(dtmin, convert(_tType, arbitrary_const * 1 // 10^(6)))

if integrator.isdae
return tdir * max(smalldt, dtmin)
Expand All @@ -303,12 +302,12 @@ end
throw(TypeNotConstantError(inferredtype, typeof(f₀)))
end

d₁ = internalnorm(f₀ ./ sk .* oneunit_tType, t)
d₁ = internalnorm(f₀ ./ sk .* arbitrary_const, t)

if d₀ < 1 // 10^(5) || d₁ < 1 // 10^(5)
dt₀ = smalldt
else
dt₀ = convert(_tType, oneunit_tType * DiffEqBase.value((d₀ / d₁) / 100))
dt₀ = convert(_tType, arbitrary_const * DiffEqBase.value((d₀ / d₁) / 100))
end
dt₀ = min(dt₀, dtmax_tdir)
dt₀_tdir = tdir * dt₀
Expand All @@ -321,16 +320,16 @@ end
# Avoids AD issues
f₀ == f₁ && return tdir * max(dtmin, 100dt₀)

d₂ = internalnorm((f₁ .- f₀) ./ sk .* oneunit_tType, t) / dt₀ * oneunit_tType
d₂ = internalnorm((f₁ .- f₀) ./ sk .* arbitrary_const, t) / dt₀ * arbitrary_const

max_d₁d₂ = max(d₁, d₂)
if max_d₁d₂ <= 1 // Int64(10)^(15)
dt₁ = max(smalldt, dt₀ * 1 // 10^(3))
dt₁ = max(smalldt, dt₀ * 1 // 10^3)
else
dt₁ = _tType(
oneunit_tType *
arbitrary_const *
DiffEqBase.value(
10^(
exp10(
-(2 + log10(max_d₁d₂)) /
get_current_alg_order(
integrator.alg,
Expand All @@ -354,6 +353,6 @@ end
_tType = eltype(tType)
tspan = prob.tspan
init_dt = abs(tspan[2] - tspan[1])
init_dt = isfinite(init_dt) ? init_dt : oneunit(_tType)
return convert(_tType, init_dt * 1 // 10^(6))
init_dt = isfinite(init_dt) ? init_dt : oneunit(_tType) * 84 // 83 * 1 // 10^6
return convert(_tType, init_dt)
end
Loading