diff --git a/lib/DiffEqBase/ext/DiffEqBaseDynamicQuantitiesExt.jl b/lib/DiffEqBase/ext/DiffEqBaseDynamicQuantitiesExt.jl index de56a4e0fc..f02891c2c6 100644 --- a/lib/DiffEqBase/ext/DiffEqBaseDynamicQuantitiesExt.jl +++ b/lib/DiffEqBase/ext/DiffEqBaseDynamicQuantitiesExt.jl @@ -10,6 +10,8 @@ import DiffEqBase: default_factorize return real(abs2(ustrip(x))) end +DiffEqBase.stripunits(x::UnionAbstractQuantity) = ustrip(x) + DiffEqBase._rate_prototype(u, t::UnionAbstractQuantity, onet) = u / oneunit(t) DiffEqBase.timedepentdtmin(t::UnionAbstractQuantity, dtmin) = abs(ustrip(dtmin / oneunit(t)) * oneunit(t)) diff --git a/lib/DiffEqBase/ext/DiffEqBaseFlexUnitsExt.jl b/lib/DiffEqBase/ext/DiffEqBaseFlexUnitsExt.jl index d14b4ca96c..266c718e6a 100644 --- a/lib/DiffEqBase/ext/DiffEqBaseFlexUnitsExt.jl +++ b/lib/DiffEqBase/ext/DiffEqBaseFlexUnitsExt.jl @@ -11,6 +11,8 @@ value(x::Quantity{T, U}) where {T, U} = dstrip(x) unitfulvalue(::Type{T}) where {T <: Quantity} = T unitfulvalue(x::Quantity) = x +DiffEqBase.stripunits(x::Quantity) = dstrip(x) + @inline function DiffEqBase.ODE_DEFAULT_NORM( u::AbstractArray{ <:Quantity, diff --git a/lib/DiffEqBase/ext/DiffEqBaseUnitfulExt.jl b/lib/DiffEqBase/ext/DiffEqBaseUnitfulExt.jl index 6f7ae8c73e..ded17045f0 100644 --- a/lib/DiffEqBase/ext/DiffEqBaseUnitfulExt.jl +++ b/lib/DiffEqBase/ext/DiffEqBaseUnitfulExt.jl @@ -11,6 +11,8 @@ value(x::Unitful.AbstractQuantity) = x.val unitfulvalue(x::Type{T}) where {T <: Unitful.AbstractQuantity} = T unitfulvalue(x::Unitful.AbstractQuantity) = x +DiffEqBase.stripunits(x::Unitful.AbstractQuantity) = Unitful.ustrip(x) + @inline function DiffEqBase.ODE_DEFAULT_NORM( u::AbstractArray{ <:Unitful.AbstractQuantity, diff --git a/lib/DiffEqBase/src/utils.jl b/lib/DiffEqBase/src/utils.jl index 24ae5d11f1..7d024005dd 100644 --- a/lib/DiffEqBase/src/utils.jl +++ b/lib/DiffEqBase/src/utils.jl @@ -52,6 +52,12 @@ end # for the non-unitful case the correct type is just u _rate_prototype(u, t::T, onet::T) where {T} = u +# Strip only the unit wrapper, leaving AD/uncertainty wrappers (Dual, Measurement, +# Tracker, etc.) intact. Extensions for Unitful, DynamicQuantities, and FlexUnits +# override this to return the underlying numeric value. +# Complementary to `value` (strips everything) and `unitfulvalue` (strips AD, keeps units). +stripunits(x) = x + # Nonlinear Solve functionality @inline __fast_scalar_indexing(args...) = all(ArrayInterface.fast_scalar_indexing, args) diff --git a/lib/OrdinaryDiffEqCore/src/initdt.jl b/lib/OrdinaryDiffEqCore/src/initdt.jl index 0aa9f4dd85..1850135bae 100644 --- a/lib/OrdinaryDiffEqCore/src/initdt.jl +++ b/lib/OrdinaryDiffEqCore/src/initdt.jl @@ -13,10 +13,10 @@ _tType = eltype(t) f = prob.f p = integrator.p - oneunit_tType = oneunit(_tType) + oneunit_tType = oneunit(t) dtmax_tdir = tdir * dtmax - dtmin = nextfloat(max(integrator.opts.dtmin, eps(t))) + dtmin = nextfloat(max(integrator.opts.dtmin, convert(_tType, oneunit_tType * eps(DiffEqBase.value(t))))) smalldt = max(dtmin, convert(_tType, oneunit_tType * 1 // 10^(6))) if integrator.isdae @@ -314,10 +314,10 @@ end _tType = eltype(t) f = prob.f p = prob.p - oneunit_tType = oneunit(_tType) + oneunit_tType = oneunit(t) dtmax_tdir = tdir * dtmax - dtmin = nextfloat(max(integrator.opts.dtmin, eps(t))) + dtmin = nextfloat(max(integrator.opts.dtmin, convert(_tType, oneunit_tType * eps(DiffEqBase.value(t))))) smalldt = max(dtmin, convert(_tType, oneunit_tType * 1 // 10^(6))) if integrator.isdae diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index d50714a07d..b9a515af0c 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -66,8 +66,10 @@ Base.@constprop :aggressive function _ode_init( calck = (callback !== nothing && callback !== CallbackSet()) || (dense) || !isempty(saveat), # and no dense output dt = nothing, - dtmin = eltype(prob.tspan)(0), - dtmax = eltype(prob.tspan)((prob.tspan[end] - prob.tspan[1])), + # For runtime-unit quantities (DynamicQuantities), eltype(prob.tspan)(0) would + # drop units; use a value-based zero to preserve units. + dtmin = zero(prob.tspan[1]), + dtmax = (prob.tspan[end] - prob.tspan[1]), force_dtmin = false, adaptive = anyadaptive(alg), abstol = nothing, @@ -274,21 +276,17 @@ Base.@constprop :aggressive function _ode_init( uBottomEltypeNoUnits = recursive_unitless_bottom_eltype(u) uEltypeNoUnits = recursive_unitless_eltype(u) - tTypeNoUnits = typeof(one(tType)) + tTypeNoUnits = typeof(DiffEqBase.stripunits(oneunit(first(tspan)))) + + scalar_type_tol = + uBottomEltypeNoUnits == uBottomEltype && + uBottomEltype <: Union{Real, Complex} if prob isa SciMLBase.AbstractDiscreteProblem && abstol === nothing abstol_internal = false elseif abstol === nothing - if uBottomEltypeNoUnits == uBottomEltype - abstol_internal = unitfulvalue( - real( - convert( - uBottomEltype, - oneunit(uBottomEltype) * - 1 // 10^6 - ) - ) - ) + if scalar_type_tol + abstol_internal = unitfulvalue(real(convert(uBottomEltype, oneunit(uBottomEltype) * 1 // 10^6))) else abstol_internal = unitfulvalue.(real.(oneunit.(u) .* 1 // 10^6)) end @@ -299,15 +297,8 @@ Base.@constprop :aggressive function _ode_init( if prob isa SciMLBase.AbstractDiscreteProblem && reltol === nothing reltol_internal = false elseif reltol === nothing - if uBottomEltypeNoUnits == uBottomEltype - reltol_internal = unitfulvalue( - real( - convert( - uBottomEltype, - oneunit(uBottomEltype) * 1 // 10^3 - ) - ) - ) + if scalar_type_tol + reltol_internal = unitfulvalue(real(convert(uBottomEltype, oneunit(uBottomEltype) * 1 // 10^3))) else reltol_internal = unitfulvalue.(real.(oneunit.(u) .* 1 // 10^3)) end @@ -333,7 +324,7 @@ Base.@constprop :aggressive function _ode_init( eltype(u) <: Enum rate_prototype = u else # has units! - rate_prototype = u / oneunit(tType) + rate_prototype = u / oneunit(first(tspan)) end end rateType = typeof(rate_prototype) ## Can be different if united diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 4e04394e9e..85cd0b4fe0 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,6 +1,7 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" @@ -18,6 +19,7 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" [compat] ADTypes = "1.16" DifferentiationInterface = "0.6.54, 0.7" +DynamicQuantities = "1.8" Enzyme = "0.13" FiniteDiff = "2.27" Measurements = "2.9" diff --git a/test/downstream/dynamicquantities_measurements.jl b/test/downstream/dynamicquantities_measurements.jl new file mode 100644 index 0000000000..5e00142337 --- /dev/null +++ b/test/downstream/dynamicquantities_measurements.jl @@ -0,0 +1,21 @@ +using OrdinaryDiffEq +using DynamicQuantities +using Measurements +using Test + +@testset "DynamicQuantities units + Measurements uncertainty" begin + u0 = (1.0 ± 0.1) * (1.0u"m") + tspan = (0.0u"s", 1.0u"s") + + f(u, p, t) = u / (1u"s") + prob = ODEProblem(f, u0, tspan) + + sol = solve(prob, Tsit5(); abstol = 1.0e-9, reltol = 1.0e-9) + + @test sol.u[end] isa typeof(u0) + @test eltype(sol.u) == typeof(u0) + + uend_m = ustrip(u"m", sol.u[end]) + @test isapprox(Measurements.value(uend_m), exp(1); atol = 1.0e-6) + @test Measurements.uncertainty(uend_m) > 0 +end diff --git a/test/runtests.jl b/test/runtests.jl index 7ce6999c9d..596d6c8830 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -208,6 +208,7 @@ end activate_downstream_env() @time @safetestset "Measurements Tests" include("downstream/measurements.jl") @time @safetestset "Time derivative Tests" include("downstream/time_derivative_test.jl") + @time @safetestset "DynamicQuantities + Measurements Tests" include("downstream/dynamicquantities_measurements.jl") end # AD tests - Enzyme/Zygote only on Julia <= 1.11 (see https://github.com/EnzymeAD/Enzyme.jl/issues/2699)