diff --git a/ext/SciMLBaseForwardDiffExt.jl b/ext/SciMLBaseForwardDiffExt.jl index 0f9589e34..55803167a 100644 --- a/ext/SciMLBaseForwardDiffExt.jl +++ b/ext/SciMLBaseForwardDiffExt.jl @@ -8,7 +8,7 @@ import SciMLBase: AbstractTimeseriesSolution, NonlinearProblem, NonlinearLeastSquaresProblem, ODEProblem, SDEProblem, RODEProblem, DDEProblem, PDEProblem, DAEProblem, RecursiveArrayTools, totallength, sse, anyeltypedual, reduce_tup, - unitfulvalue + unitfulvalue, _promote_jac_p! eltypedual(x) = eltype(x) <: ForwardDiff.Dual isdualtype(::Type{<:ForwardDiff.Dual}) = true @@ -460,4 +460,16 @@ function SciMLBase.totallength(x::ForwardDiff.Dual) sum(SciMLBase.totallength, ForwardDiff.partials(x)) end +function _promote_jac_p!(ff::SciMLBase.UJacobianWrapper, u::AbstractArray{<:ForwardDiff.Dual}) + p = ff.p + p isa AbstractArray{<:ForwardDiff.Dual} || return p + hasfield(typeof(ff.f), :f) && getfield(ff.f, :f) isa SciMLBase.FunctionWrappersWrappers.FunctionWrappersWrapper && return p + DualU = eltype(u) + eltype(p) <: DualU && return p + cached = ff._promoted_p + cached isa AbstractArray{DualU} && return cached + ff._promoted_p = DualU.(p) + return ff._promoted_p end + +end \ No newline at end of file diff --git a/src/function_wrappers.jl b/src/function_wrappers.jl index 0da377da1..153c656e7 100644 --- a/src/function_wrappers.jl +++ b/src/function_wrappers.jl @@ -75,14 +75,16 @@ mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractWrappedFunction f::fType t::tType p::P + _promoted_p::Any end function UJacobianWrapper{iip}(f::F, t, p) where {F, iip} - return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p) + return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p, nothing) end UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p) -(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t) +(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, _promote_jac_p!(ff, uprev), ff.t) +_promote_jac_p!(ff, u) = ff.p function (ff::UJacobianWrapper{true})(uprev) (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1) end @@ -236,4 +238,4 @@ JacobianWrapper(f::F, p) where {F} = JacobianWrapper{isinplace(f, 3)}(f, p) (uf::JacobianWrapper{false})(u) = uf.f(u, uf.p) (uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p))) -(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p) +(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p) \ No newline at end of file