Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/DiffEqBase/ext/DiffEqBaseDynamicQuantitiesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import DiffEqBase: default_factorize
return real(abs2(ustrip(x)))
end

DiffEqBase.stripunits(x::UnionAbstractQuantity) = ustrip(x)

DiffEqBase._rate_prototype(u, t::UnionAbstractQuantity, onet) = u / oneunit(t)
DiffEqBase.timedepentdtmin(t::UnionAbstractQuantity, dtmin) =
abs(ustrip(dtmin / oneunit(t)) * oneunit(t))
Expand Down
2 changes: 2 additions & 0 deletions lib/DiffEqBase/ext/DiffEqBaseFlexUnitsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ value(x::Quantity{T, U}) where {T, U} = dstrip(x)
unitfulvalue(::Type{T}) where {T <: Quantity} = T
unitfulvalue(x::Quantity) = x

DiffEqBase.stripunits(x::Quantity) = dstrip(x)

@inline function DiffEqBase.ODE_DEFAULT_NORM(
u::AbstractArray{
<:Quantity,
Expand Down
2 changes: 2 additions & 0 deletions lib/DiffEqBase/ext/DiffEqBaseUnitfulExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ value(x::Unitful.AbstractQuantity) = x.val
unitfulvalue(x::Type{T}) where {T <: Unitful.AbstractQuantity} = T
unitfulvalue(x::Unitful.AbstractQuantity) = x

DiffEqBase.stripunits(x::Unitful.AbstractQuantity) = Unitful.ustrip(x)

@inline function DiffEqBase.ODE_DEFAULT_NORM(
u::AbstractArray{
<:Unitful.AbstractQuantity,
Expand Down
6 changes: 6 additions & 0 deletions lib/DiffEqBase/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ end
# for the non-unitful case the correct type is just u
_rate_prototype(u, t::T, onet::T) where {T} = u

# Strip only the unit wrapper, leaving AD/uncertainty wrappers (Dual, Measurement,
# Tracker, etc.) intact. Extensions for Unitful, DynamicQuantities, and FlexUnits
# override this to return the underlying numeric value.
# Complementary to `value` (strips everything) and `unitfulvalue` (strips AD, keeps units).
stripunits(x) = x

# Nonlinear Solve functionality
@inline __fast_scalar_indexing(args...) = all(ArrayInterface.fast_scalar_indexing, args)

Expand Down
8 changes: 4 additions & 4 deletions lib/OrdinaryDiffEqCore/src/initdt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
_tType = eltype(t)
f = prob.f
p = integrator.p
oneunit_tType = oneunit(_tType)
oneunit_tType = oneunit(t)
dtmax_tdir = tdir * dtmax

dtmin = nextfloat(max(integrator.opts.dtmin, eps(t)))
dtmin = nextfloat(max(integrator.opts.dtmin, convert(_tType, oneunit_tType * eps(DiffEqBase.value(t)))))
smalldt = max(dtmin, convert(_tType, oneunit_tType * 1 // 10^(6)))

if integrator.isdae
Expand Down Expand Up @@ -314,10 +314,10 @@ end
_tType = eltype(t)
f = prob.f
p = prob.p
oneunit_tType = oneunit(_tType)
oneunit_tType = oneunit(t)
dtmax_tdir = tdir * dtmax

dtmin = nextfloat(max(integrator.opts.dtmin, eps(t)))
dtmin = nextfloat(max(integrator.opts.dtmin, convert(_tType, oneunit_tType * eps(DiffEqBase.value(t)))))
smalldt = max(dtmin, convert(_tType, oneunit_tType * 1 // 10^(6)))

if integrator.isdae
Expand Down
37 changes: 14 additions & 23 deletions lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ Base.@constprop :aggressive function _ode_init(
calck = (callback !== nothing && callback !== CallbackSet()) ||
(dense) || !isempty(saveat), # and no dense output
dt = nothing,
dtmin = eltype(prob.tspan)(0),
dtmax = eltype(prob.tspan)((prob.tspan[end] - prob.tspan[1])),
# For runtime-unit quantities (DynamicQuantities), eltype(prob.tspan)(0) would
# drop units; use a value-based zero to preserve units.
dtmin = zero(prob.tspan[1]),
dtmax = (prob.tspan[end] - prob.tspan[1]),
force_dtmin = false,
adaptive = anyadaptive(alg),
abstol = nothing,
Expand Down Expand Up @@ -274,21 +276,17 @@ Base.@constprop :aggressive function _ode_init(
uBottomEltypeNoUnits = recursive_unitless_bottom_eltype(u)

uEltypeNoUnits = recursive_unitless_eltype(u)
tTypeNoUnits = typeof(one(tType))
tTypeNoUnits = typeof(DiffEqBase.stripunits(oneunit(first(tspan))))

scalar_type_tol =
uBottomEltypeNoUnits == uBottomEltype &&
uBottomEltype <: Union{Real, Complex}

if prob isa SciMLBase.AbstractDiscreteProblem && abstol === nothing
abstol_internal = false
elseif abstol === nothing
if uBottomEltypeNoUnits == uBottomEltype
abstol_internal = unitfulvalue(
real(
convert(
uBottomEltype,
oneunit(uBottomEltype) *
1 // 10^6
)
)
)
if scalar_type_tol
abstol_internal = unitfulvalue(real(convert(uBottomEltype, oneunit(uBottomEltype) * 1 // 10^6)))
else
abstol_internal = unitfulvalue.(real.(oneunit.(u) .* 1 // 10^6))
end
Expand All @@ -299,15 +297,8 @@ Base.@constprop :aggressive function _ode_init(
if prob isa SciMLBase.AbstractDiscreteProblem && reltol === nothing
reltol_internal = false
elseif reltol === nothing
if uBottomEltypeNoUnits == uBottomEltype
reltol_internal = unitfulvalue(
real(
convert(
uBottomEltype,
oneunit(uBottomEltype) * 1 // 10^3
)
)
)
if scalar_type_tol
reltol_internal = unitfulvalue(real(convert(uBottomEltype, oneunit(uBottomEltype) * 1 // 10^3)))
else
reltol_internal = unitfulvalue.(real.(oneunit.(u) .* 1 // 10^3))
end
Expand All @@ -333,7 +324,7 @@ Base.@constprop :aggressive function _ode_init(
eltype(u) <: Enum
rate_prototype = u
else # has units!
rate_prototype = u / oneunit(tType)
rate_prototype = u / oneunit(first(tspan))
end
end
rateType = typeof(rate_prototype) ## Can be different if united
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Expand All @@ -18,6 +19,7 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
[compat]
ADTypes = "1.16"
DifferentiationInterface = "0.6.54, 0.7"
DynamicQuantities = "1.8"
Enzyme = "0.13"
FiniteDiff = "2.27"
Measurements = "2.9"
Expand Down
21 changes: 21 additions & 0 deletions test/downstream/dynamicquantities_measurements.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using OrdinaryDiffEq
using DynamicQuantities
using Measurements
using Test

@testset "DynamicQuantities units + Measurements uncertainty" begin
u0 = (1.0 ± 0.1) * (1.0u"m")
tspan = (0.0u"s", 1.0u"s")

f(u, p, t) = u / (1u"s")
prob = ODEProblem(f, u0, tspan)

sol = solve(prob, Tsit5(); abstol = 1.0e-9, reltol = 1.0e-9)

@test sol.u[end] isa typeof(u0)
@test eltype(sol.u) == typeof(u0)

uend_m = ustrip(u"m", sol.u[end])
@test isapprox(Measurements.value(uend_m), exp(1); atol = 1.0e-6)
@test Measurements.uncertainty(uend_m) > 0
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ end
activate_downstream_env()
@time @safetestset "Measurements Tests" include("downstream/measurements.jl")
@time @safetestset "Time derivative Tests" include("downstream/time_derivative_test.jl")
@time @safetestset "DynamicQuantities + Measurements Tests" include("downstream/dynamicquantities_measurements.jl")
end

# AD tests - Enzyme/Zygote only on Julia <= 1.11 (see https://github.com/EnzymeAD/Enzyme.jl/issues/2699)
Expand Down
Loading