From da35919fbcc6798066f82ad7fe89754f9ae88222 Mon Sep 17 00:00:00 2001 From: AdityaPandeyCN Date: Sun, 12 Apr 2026 09:38:17 +0530 Subject: [PATCH 1/3] Fix nested ForwardDiff tag mismatch in UJacobianWrapper Signed-off-by: AdityaPandeyCN --- ext/SciMLBaseForwardDiffExt.jl | 11 ++++++++++- src/function_wrappers.jl | 3 ++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ext/SciMLBaseForwardDiffExt.jl b/ext/SciMLBaseForwardDiffExt.jl index 0f9589e344..89ccae82e2 100644 --- a/ext/SciMLBaseForwardDiffExt.jl +++ b/ext/SciMLBaseForwardDiffExt.jl @@ -8,7 +8,7 @@ import SciMLBase: AbstractTimeseriesSolution, NonlinearProblem, NonlinearLeastSquaresProblem, ODEProblem, SDEProblem, RODEProblem, DDEProblem, PDEProblem, DAEProblem, RecursiveArrayTools, totallength, sse, anyeltypedual, reduce_tup, - unitfulvalue + unitfulvalue, _promote_jac_p eltypedual(x) = eltype(x) <: ForwardDiff.Dual isdualtype(::Type{<:ForwardDiff.Dual}) = true @@ -460,4 +460,13 @@ function SciMLBase.totallength(x::ForwardDiff.Dual) sum(SciMLBase.totallength, ForwardDiff.partials(x)) end +function _promote_jac_p(p::AbstractArray{<:ForwardDiff.Dual}, u::AbstractArray{<:ForwardDiff.Dual}) + DualU = eltype(u) + DualP = eltype(p) + if !(DualP <: DualU) + return DualU.(p) + end + return p +end + end diff --git a/src/function_wrappers.jl b/src/function_wrappers.jl index 0da377da1c..31416718c1 100644 --- a/src/function_wrappers.jl +++ b/src/function_wrappers.jl @@ -82,7 +82,8 @@ function UJacobianWrapper{iip}(f::F, t, p) where {F, iip} end UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p) -(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t) +(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, _promote_jac_p(ff.p, uprev), ff.t) +_promote_jac_p(p, u) = p function (ff::UJacobianWrapper{true})(uprev) (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1) end From 5008f5fda5656e90aea6aa0d91cf4b409443fae9 Mon Sep 17 00:00:00 2001 From: AdityaPandeyCN Date: Sun, 12 Apr 2026 15:51:48 +0530 Subject: [PATCH 2/3] Gate _promote_jac_p to skip FWW-wrapped functions for AutoSpecialize compat Signed-off-by: AdityaPandeyCN --- ext/SciMLBaseForwardDiffExt.jl | 5 ++++- src/function_wrappers.jl | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ext/SciMLBaseForwardDiffExt.jl b/ext/SciMLBaseForwardDiffExt.jl index 89ccae82e2..2372ff8356 100644 --- a/ext/SciMLBaseForwardDiffExt.jl +++ b/ext/SciMLBaseForwardDiffExt.jl @@ -460,7 +460,10 @@ function SciMLBase.totallength(x::ForwardDiff.Dual) sum(SciMLBase.totallength, ForwardDiff.partials(x)) end -function _promote_jac_p(p::AbstractArray{<:ForwardDiff.Dual}, u::AbstractArray{<:ForwardDiff.Dual}) +function _promote_jac_p(p::AbstractArray{<:ForwardDiff.Dual}, u::AbstractArray{<:ForwardDiff.Dual}, f) + if hasfield(typeof(f), :f) && getfield(f, :f) isa SciMLBase.FunctionWrappersWrappers.FunctionWrappersWrapper + return p + end DualU = eltype(u) DualP = eltype(p) if !(DualP <: DualU) diff --git a/src/function_wrappers.jl b/src/function_wrappers.jl index 31416718c1..f699fac063 100644 --- a/src/function_wrappers.jl +++ b/src/function_wrappers.jl @@ -82,8 +82,8 @@ function UJacobianWrapper{iip}(f::F, t, p) where {F, iip} end UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p) -(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, _promote_jac_p(ff.p, uprev), ff.t) -_promote_jac_p(p, u) = p +(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, _promote_jac_p(ff.p, uprev, ff.f), ff.t) +_promote_jac_p(p, u, f) = p function (ff::UJacobianWrapper{true})(uprev) (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1) end From 03be015906d96ca022b0a9c2eb740b5697f264af Mon Sep 17 00:00:00 2001 From: AdityaPandeyCN Date: Sun, 12 Apr 2026 17:04:34 +0530 Subject: [PATCH 3/3] remove per call allocation Signed-off-by: AdityaPandeyCN --- ext/SciMLBaseForwardDiffExt.jl | 22 +++++++++++----------- src/function_wrappers.jl | 9 +++++---- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/ext/SciMLBaseForwardDiffExt.jl b/ext/SciMLBaseForwardDiffExt.jl index 2372ff8356..55803167ac 100644 --- a/ext/SciMLBaseForwardDiffExt.jl +++ b/ext/SciMLBaseForwardDiffExt.jl @@ -8,7 +8,7 @@ import SciMLBase: AbstractTimeseriesSolution, NonlinearProblem, NonlinearLeastSquaresProblem, ODEProblem, SDEProblem, RODEProblem, DDEProblem, PDEProblem, DAEProblem, RecursiveArrayTools, totallength, sse, anyeltypedual, reduce_tup, - unitfulvalue, _promote_jac_p + unitfulvalue, _promote_jac_p! eltypedual(x) = eltype(x) <: ForwardDiff.Dual isdualtype(::Type{<:ForwardDiff.Dual}) = true @@ -460,16 +460,16 @@ function SciMLBase.totallength(x::ForwardDiff.Dual) sum(SciMLBase.totallength, ForwardDiff.partials(x)) end -function _promote_jac_p(p::AbstractArray{<:ForwardDiff.Dual}, u::AbstractArray{<:ForwardDiff.Dual}, f) - if hasfield(typeof(f), :f) && getfield(f, :f) isa SciMLBase.FunctionWrappersWrappers.FunctionWrappersWrapper - return p - end +function _promote_jac_p!(ff::SciMLBase.UJacobianWrapper, u::AbstractArray{<:ForwardDiff.Dual}) + p = ff.p + p isa AbstractArray{<:ForwardDiff.Dual} || return p + hasfield(typeof(ff.f), :f) && getfield(ff.f, :f) isa SciMLBase.FunctionWrappersWrappers.FunctionWrappersWrapper && return p DualU = eltype(u) - DualP = eltype(p) - if !(DualP <: DualU) - return DualU.(p) - end - return p + eltype(p) <: DualU && return p + cached = ff._promoted_p + cached isa AbstractArray{DualU} && return cached + ff._promoted_p = DualU.(p) + return ff._promoted_p end -end +end \ No newline at end of file diff --git a/src/function_wrappers.jl b/src/function_wrappers.jl index f699fac063..153c656e7e 100644 --- a/src/function_wrappers.jl +++ b/src/function_wrappers.jl @@ -75,15 +75,16 @@ mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractWrappedFunction f::fType t::tType p::P + _promoted_p::Any end function UJacobianWrapper{iip}(f::F, t, p) where {F, iip} - return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p) + return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p, nothing) end UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p) -(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, _promote_jac_p(ff.p, uprev, ff.f), ff.t) -_promote_jac_p(p, u, f) = p +(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, _promote_jac_p!(ff, uprev), ff.t) +_promote_jac_p!(ff, u) = ff.p function (ff::UJacobianWrapper{true})(uprev) (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1) end @@ -237,4 +238,4 @@ JacobianWrapper(f::F, p) where {F} = JacobianWrapper{isinplace(f, 3)}(f, p) (uf::JacobianWrapper{false})(u) = uf.f(u, uf.p) (uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p))) -(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p) +(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p) \ No newline at end of file