Skip to content

Commit 4e612f1

Browse files
Add DiffEqBase.stripunits helper and use it for tTypeNoUnits
Adds a narrow "strip units only" primitive that is orthogonal to the existing `value` (strips everything, including AD) and `unitfulvalue` (strips AD, keeps units) helpers. Defaults to the identity; Unitful, FlexUnits, and DynamicQuantities extensions override it to return the underlying numeric value. Switches the `tTypeNoUnits` derivation in `OrdinaryDiffEqCore/src/solve.jl` from `DiffEqBase.value(oneunit(first(tspan)))` to `DiffEqBase.stripunits(oneunit(first(tspan)))` so that ForwardDiff Duals in t0 are preserved through `tTypeNoUnits`. The previous `value`-based form collapsed Duals to Float64, which then made `NLSolver{true, tTypeNoUnits}` use Float64 fields for ηold etc., and SDIRK solvers (KenCarp4/47/5/58, TRBDF2) that run with a Dual t0 errored with `MethodError: Float64(::Dual)`. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent f84ae68 commit 4e612f1

5 files changed

Lines changed: 13 additions & 1 deletion

File tree

lib/DiffEqBase/ext/DiffEqBaseDynamicQuantitiesExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import DiffEqBase: default_factorize
1010
return real(abs2(ustrip(x)))
1111
end
1212

13+
DiffEqBase.stripunits(x::UnionAbstractQuantity) = ustrip(x)
14+
1315
DiffEqBase._rate_prototype(u, t::UnionAbstractQuantity, onet) = u / oneunit(t)
1416
DiffEqBase.timedepentdtmin(t::UnionAbstractQuantity, dtmin) =
1517
abs(ustrip(dtmin / oneunit(t)) * oneunit(t))

lib/DiffEqBase/ext/DiffEqBaseFlexUnitsExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ value(x::Quantity{T, U}) where {T, U} = dstrip(x)
1111
unitfulvalue(::Type{T}) where {T <: Quantity} = T
1212
unitfulvalue(x::Quantity) = x
1313

14+
DiffEqBase.stripunits(x::Quantity) = dstrip(x)
15+
1416
@inline function DiffEqBase.ODE_DEFAULT_NORM(
1517
u::AbstractArray{
1618
<:Quantity,

lib/DiffEqBase/ext/DiffEqBaseUnitfulExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ value(x::Unitful.AbstractQuantity) = x.val
1111
unitfulvalue(x::Type{T}) where {T <: Unitful.AbstractQuantity} = T
1212
unitfulvalue(x::Unitful.AbstractQuantity) = x
1313

14+
DiffEqBase.stripunits(x::Unitful.AbstractQuantity) = Unitful.ustrip(x)
15+
1416
@inline function DiffEqBase.ODE_DEFAULT_NORM(
1517
u::AbstractArray{
1618
<:Unitful.AbstractQuantity,

lib/DiffEqBase/src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ end
5252
# for the non-unitful case the correct type is just u
5353
_rate_prototype(u, t::T, onet::T) where {T} = u
5454

55+
# Strip only the unit wrapper, leaving AD/uncertainty wrappers (Dual, Measurement,
56+
# Tracker, etc.) intact. Extensions for Unitful, DynamicQuantities, and FlexUnits
57+
# override this to return the underlying numeric value.
58+
# Complementary to `value` (strips everything) and `unitfulvalue` (strips AD, keeps units).
59+
stripunits(x) = x
60+
5561
# Nonlinear Solve functionality
5662
@inline __fast_scalar_indexing(args...) = all(ArrayInterface.fast_scalar_indexing, args)
5763

lib/OrdinaryDiffEqCore/src/solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ Base.@constprop :aggressive function _ode_init(
317317
uBottomEltypeNoUnits = recursive_unitless_bottom_eltype(u)
318318

319319
uEltypeNoUnits = recursive_unitless_eltype(u)
320-
tTypeNoUnits = typeof(DiffEqBase.value(oneunit(first(tspan))))
320+
tTypeNoUnits = typeof(DiffEqBase.stripunits(oneunit(first(tspan))))
321321

322322
scalar_type_tol =
323323
uBottomEltypeNoUnits == uBottomEltype &&

0 commit comments

Comments
 (0)