diff --git a/lib/OrdinaryDiffEqRKN/src/OrdinaryDiffEqRKN.jl b/lib/OrdinaryDiffEqRKN/src/OrdinaryDiffEqRKN.jl index bea847d3ca2..1e6dff67e3d 100644 --- a/lib/OrdinaryDiffEqRKN/src/OrdinaryDiffEqRKN.jl +++ b/lib/OrdinaryDiffEqRKN/src/OrdinaryDiffEqRKN.jl @@ -27,6 +27,8 @@ include("rkn_caches.jl") include("interp_func.jl") include("interpolants.jl") include("rkn_perform_step.jl") +include("generic_rkn_vi_perform_step.jl") +include("generic_rkn_vd_perform_step.jl") export Nystrom4, FineRKN4, FineRKN5, Nystrom4VelocityIndependent, Nystrom5VelocityIndependent, diff --git a/lib/OrdinaryDiffEqRKN/src/generic_rkn_vd_perform_step.jl b/lib/OrdinaryDiffEqRKN/src/generic_rkn_vd_perform_step.jl new file mode 100644 index 00000000000..f55fbbf6c6a --- /dev/null +++ b/lib/OrdinaryDiffEqRKN/src/generic_rkn_vd_perform_step.jl @@ -0,0 +1,177 @@ +## Generic Nyström velocity-DEPENDENT perform_step! +## Solves: y'' = f(t, y, y') where f depends on both position and velocity +## kᵢ = f1(duprev + dt*Σⱼ abar[i,j]*kⱼ, uprev + dt*c[i]*duprev + dt²*Σⱼ a[i,j]*kⱼ, p, t+c[i]*dt) +## y₁ = y₀ + h*y'₀ + h²*Σᵢ bᵢ*kᵢ +## y'₁ = y'₀ + h*Σᵢ bpᵢ*kᵢ + +function initialize!(integrator, cache::NystromVDConstantCache) + integrator.kshortsize = 2 + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) + + duprev, uprev = integrator.uprev.x + kdu = integrator.f.f1(duprev, uprev, integrator.p, integrator.t) + ku = integrator.f.f2(duprev, uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + integrator.stats.nf2 += 1 + integrator.fsalfirst = ArrayPartition((kdu, ku)) + integrator.fsallast = zero(integrator.fsalfirst) + integrator.k[1] = integrator.fsalfirst + return integrator.k[2] = integrator.fsallast +end + +function initialize!(integrator, cache::NystromVDCache) + integrator.kshortsize = 2 + resize!(integrator.k, integrator.kshortsize) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + duprev, uprev = integrator.uprev.x + integrator.f.f1(integrator.fsalfirst.x[1], duprev, uprev, integrator.p, integrator.t) + integrator.f.f2(integrator.fsalfirst.x[2], duprev, uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + return integrator.stats.nf2 += 1 +end + +@muladd function perform_step!( + integrator, cache::NystromVDConstantCache, repeat_step = false + ) + (; t, dt, f, p) = integrator + duprev, uprev = integrator.uprev.x + (; tab) = cache + (; a, abar, b, bp, btilde, bptilde, c) = tab + k1 = integrator.fsalfirst.x[1] + nstages = length(b) + dtsq = dt^2 + + # Compute intermediate stages k2..knstages + ks = Vector{typeof(k1)}(undef, nstages) + ks[1] = k1 + for i in 2:nstages + ku = uprev + dt * c[i - 1] * duprev + kdu = duprev + for j in 1:(i - 1) + if !iszero(a[i, j]) + ku = ku + dtsq * a[i, j] * ks[j] + end + if !iszero(abar[i, j]) + kdu = kdu + dt * abar[i, j] * ks[j] + end + end + ks[i] = f.f1(kdu, ku, p, t + dt * c[i - 1]) + end + + # Position and velocity updates + u = uprev + dt * duprev + for i in 1:nstages + if !iszero(b[i]) + u = u + dtsq * b[i] * ks[i] + end + end + du = duprev + for i in 1:nstages + if !iszero(bp[i]) + du = du + dt * bp[i] * ks[i] + end + end + + integrator.u = ArrayPartition((du, u)) + integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, tab.nf_per_step) + integrator.stats.nf2 += 1 + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + + if integrator.opts.adaptive && !isempty(btilde) + uhat = zero(uprev) + duhat = zero(duprev) + for i in 1:nstages + if !iszero(btilde[i]) + uhat = uhat + dtsq * btilde[i] * ks[i] + end + if !isempty(bptilde) && !iszero(bptilde[i]) + duhat = duhat + dt * bptilde[i] * ks[i] + end + end + utilde = ArrayPartition((duhat, uhat)) + atmp = calculate_residuals( + utilde, integrator.uprev, integrator.u, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t + ) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end +end + +@muladd function perform_step!( + integrator, cache::NystromVDCache, repeat_step = false + ) + (; t, dt, f, p) = integrator + du, u = integrator.u.x + duprev, uprev = integrator.uprev.x + (; ks, k, utilde, tmp, atmp, tab) = cache + (; a, abar, b, bp, btilde, bptilde, c) = tab + ku = tmp.x[2] + kdu = tmp.x[1] + k1 = integrator.fsalfirst.x[1] + nstages = length(b) + dtsq = dt^2 + + # Compute intermediate stages k2..knstages, stored in ks[1..nstages-1] + for i in 2:nstages + @.. broadcast = false ku = uprev + dt * c[i - 1] * duprev + @.. broadcast = false kdu = duprev + for j in 1:(i - 1) + kj = (j == 1) ? k1 : ks[j - 1] + if !iszero(a[i, j]) + @.. broadcast = false ku = ku + dtsq * a[i, j] * kj + end + if !iszero(abar[i, j]) + @.. broadcast = false kdu = kdu + dt * abar[i, j] * kj + end + end + f.f1(ks[i - 1], kdu, ku, p, t + dt * c[i - 1]) + end + + # Position update: u = uprev + dt*duprev + dt^2 * sum(b[i]*ki) + @.. broadcast = false u = uprev + dt * duprev + for i in 1:nstages + if !iszero(b[i]) + ki = (i == 1) ? k1 : ks[i - 1] + @.. broadcast = false u = u + dtsq * b[i] * ki + end + end + + # Velocity update: du = duprev + dt * sum(bp[i]*ki) + @.. broadcast = false du = duprev + for i in 1:nstages + if !iszero(bp[i]) + ki = (i == 1) ? k1 : ks[i - 1] + @.. broadcast = false du = du + dt * bp[i] * ki + end + end + + f.f1(k.x[1], du, u, p, t + dt) + f.f2(k.x[2], du, u, p, t + dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, tab.nf_per_step) + integrator.stats.nf2 += 1 + + if integrator.opts.adaptive && !isempty(btilde) + duhat, uhat = utilde.x + @.. broadcast = false uhat = zero(uhat) + @.. broadcast = false duhat = zero(duhat) + for i in 1:nstages + ki = (i == 1) ? k1 : ks[i - 1] + if !iszero(btilde[i]) + @.. broadcast = false uhat = uhat + dtsq * btilde[i] * ki + end + if !isempty(bptilde) && !iszero(bptilde[i]) + @.. broadcast = false duhat = duhat + dt * bptilde[i] * ki + end + end + calculate_residuals!( + atmp, utilde, integrator.uprev, integrator.u, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t + ) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end +end diff --git a/lib/OrdinaryDiffEqRKN/src/generic_rkn_vi_perform_step.jl b/lib/OrdinaryDiffEqRKN/src/generic_rkn_vi_perform_step.jl new file mode 100644 index 00000000000..5df400324a7 --- /dev/null +++ b/lib/OrdinaryDiffEqRKN/src/generic_rkn_vi_perform_step.jl @@ -0,0 +1,196 @@ +## Generic Nyström velocity-independent perform_step! +## Solves: y'' = f(t, y) where f is velocity-independent +## kᵢ = f1(duprev, yᵢ, p, t + cᵢ*dt) (duprev constant throughout) +## yᵢ = y₀ + cᵢ*h*y'₀ + h²*Σⱼ<ᵢ aᵢⱼ*kⱼ +## y₁ = y₀ + h*y'₀ + h²*Σᵢ bᵢ*kᵢ +## y'₁ = y'₀ + h*Σᵢ bpᵢ*kᵢ + +function initialize!(integrator, cache::NystromVIConstantCache) + integrator.kshortsize = 2 + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) + duprev, uprev = integrator.uprev.x + kdu = integrator.f.f1(duprev, uprev, integrator.p, integrator.t) + ku = integrator.f.f2(duprev, uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + integrator.stats.nf2 += 1 + integrator.fsalfirst = ArrayPartition((kdu, ku)) + integrator.fsallast = zero(integrator.fsalfirst) + integrator.k[1] = integrator.fsalfirst + return integrator.k[2] = integrator.fsallast +end + +function initialize!(integrator, cache::NystromVICache) + integrator.kshortsize = 2 + resize!(integrator.k, integrator.kshortsize) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + duprev, uprev = integrator.uprev.x + integrator.f.f1(integrator.fsalfirst.x[1], duprev, uprev, integrator.p, integrator.t) + integrator.f.f2(integrator.fsalfirst.x[2], duprev, uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + return integrator.stats.nf2 += 1 +end + +@muladd function perform_step!( + integrator, cache::NystromVIConstantCache, repeat_step = false + ) + (; t, dt, f, p) = integrator + duprev, uprev = integrator.uprev.x + (; tab) = cache + (; a, b, bp, btilde, bptilde, c, pos_only_error) = tab + k1 = integrator.fsalfirst.x[1] + nstages = length(b) + dtsq = dt^2 + + # Compute intermediate stages + ks = Vector{typeof(k1)}(undef, nstages) + ks[1] = k1 + for i in 2:nstages + ku = uprev + dt * c[i - 1] * duprev + for j in 1:(i - 1) + if !iszero(a[i, j]) + ku = ku + dtsq * a[i, j] * ks[j] + end + end + ks[i] = f.f1(duprev, ku, p, t + dt * c[i - 1]) + end + + # Position and velocity updates + u = uprev + dt * duprev + for i in 1:nstages + if !iszero(b[i]) + u = u + dtsq * b[i] * ks[i] + end + end + du = duprev + for i in 1:nstages + if !iszero(bp[i]) + du = du + dt * bp[i] * ks[i] + end + end + + integrator.u = ArrayPartition((du, u)) + integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, nstages) + integrator.stats.nf2 += 1 + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + + if integrator.opts.adaptive && !isempty(btilde) + uhat = zero(uprev) + for i in 1:nstages + if !iszero(btilde[i]) + uhat = uhat + dtsq * btilde[i] * ks[i] + end + end + if pos_only_error + atmp = calculate_residuals( + uhat, integrator.uprev.x[2], integrator.u.x[2], + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t + ) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + else + duhat = zero(duprev) + for i in 1:nstages + if !isempty(bptilde) && !iszero(bptilde[i]) + duhat = duhat + dt * bptilde[i] * ks[i] + end + end + utilde = ArrayPartition((duhat, uhat)) + atmp = calculate_residuals( + utilde, integrator.uprev, integrator.u, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t + ) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end + end +end + +@muladd function perform_step!( + integrator, cache::NystromVICache, repeat_step = false + ) + (; t, dt, f, p) = integrator + du, u = integrator.u.x + duprev, uprev = integrator.uprev.x + (; ks, k, utilde, tmp, atmp, tab) = cache + (; a, b, bp, btilde, bptilde, c, pos_only_error) = tab + ku = tmp.x[2] + k1 = integrator.fsalfirst.x[1] + nstages = length(b) + dtsq = dt^2 + + # Compute intermediate stages k2..knstages, stored in ks[1..nstages-1] + for i in 2:nstages + @.. broadcast = false ku = uprev + dt * c[i - 1] * duprev + for j in 1:(i - 1) + if !iszero(a[i, j]) + kj = (j == 1) ? k1 : ks[j - 1] + @.. broadcast = false ku = ku + dtsq * a[i, j] * kj + end + end + f.f1(ks[i - 1], duprev, ku, p, t + dt * c[i - 1]) + end + + # Position update: u = uprev + dt*duprev + dt^2 * sum(b[i]*ki) + @.. broadcast = false u = uprev + dt * duprev + for i in 1:nstages + if !iszero(b[i]) + ki = (i == 1) ? k1 : ks[i - 1] + @.. broadcast = false u = u + dtsq * b[i] * ki + end + end + + # Velocity update: du = duprev + dt * sum(bp[i]*ki) + @.. broadcast = false du = duprev + for i in 1:nstages + if !iszero(bp[i]) + ki = (i == 1) ? k1 : ks[i - 1] + @.. broadcast = false du = du + dt * bp[i] * ki + end + end + + f.f1(k.x[1], du, u, p, t + dt) + f.f2(k.x[2], du, u, p, t + dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, nstages) + integrator.stats.nf2 += 1 + + if integrator.opts.adaptive && !isempty(btilde) + if pos_only_error + uhat = utilde.x[2] + @.. broadcast = false uhat = zero(uhat) + for i in 1:nstages + if !iszero(btilde[i]) + ki = (i == 1) ? k1 : ks[i - 1] + @.. broadcast = false uhat = uhat + dtsq * btilde[i] * ki + end + end + calculate_residuals!( + atmp.x[2], uhat, integrator.uprev.x[2], integrator.u.x[2], + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t + ) + integrator.EEst = integrator.opts.internalnorm(atmp.x[2], t) + else + duhat, uhat = utilde.x + @.. broadcast = false uhat = zero(uhat) + @.. broadcast = false duhat = zero(duhat) + for i in 1:nstages + ki = (i == 1) ? k1 : ks[i - 1] + if !iszero(btilde[i]) + @.. broadcast = false uhat = uhat + dtsq * btilde[i] * ki + end + if !isempty(bptilde) && !iszero(bptilde[i]) + @.. broadcast = false duhat = duhat + dt * bptilde[i] * ki + end + end + calculate_residuals!( + atmp, utilde, integrator.uprev, integrator.u, + integrator.opts.abstol, integrator.opts.reltol, + integrator.opts.internalnorm, t + ) + integrator.EEst = integrator.opts.internalnorm(atmp, t) + end + end +end diff --git a/lib/OrdinaryDiffEqRKN/src/rkn_caches.jl b/lib/OrdinaryDiffEqRKN/src/rkn_caches.jl index 27e4317b986..fde820592c6 100644 --- a/lib/OrdinaryDiffEqRKN/src/rkn_caches.jl +++ b/lib/OrdinaryDiffEqRKN/src/rkn_caches.jl @@ -1,18 +1,43 @@ abstract type NystromMutableCache <: OrdinaryDiffEqMutableCache end get_fsalfirstlast(cache::NystromMutableCache, u) = (cache.fsalfirst, cache.k) -@cache struct Nystrom4Cache{uType, rateType, reducedRateType} <: NystromMutableCache +## Generic velocity-independent Nyström caches + +struct NystromVIConstantCache{T, T2} <: NystromConstantCache + tab::NystromVITableau{T, T2} +end + +@cache struct NystromVICache{uType, rateType, reducedRateType, uNoUnitsType, T, T2} <: + NystromMutableCache u::uType uprev::uType fsalfirst::rateType - k₂::reducedRateType - k₃::reducedRateType - k₄::reducedRateType + ks::Vector{reducedRateType} # stage derivatives k2..kN (length nstages-1) k::rateType + utilde::uType tmp::uType + atmp::uNoUnitsType + tab::NystromVITableau{T, T2} +end + +## Generic velocity-dependent Nyström caches + +struct NystromVDConstantCache{T, T2} <: NystromConstantCache + tab::NystromVDTableau{T, T2} end -# struct Nystrom4ConstantCache <: NystromConstantCache end +@cache struct NystromVDCache{uType, rateType, reducedRateType, uNoUnitsType, T, T2} <: + NystromMutableCache + u::uType + uprev::uType + fsalfirst::rateType + ks::Vector{reducedRateType} # stage derivatives k2..kN (length nstages-1) + k::rateType + utilde::uType + tmp::uType + atmp::uNoUnitsType + tab::NystromVDTableau{T, T2} +end function alg_cache( alg::Nystrom4, u, rate_prototype, ::Type{uEltypeNoUnits}, @@ -21,42 +46,27 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - k₁ = zero(rate_prototype) - k₂ = zero(reduced_rate_prototype) - k₃ = zero(reduced_rate_prototype) - k₄ = zero(reduced_rate_prototype) + tab = Nystrom4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) + k1 = zero(rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) + utilde = zero(u) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) tmp = zero(u) - return Nystrom4Cache(u, uprev, k₁, k₂, k₃, k₄, k, tmp) + return NystromVDCache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end -struct Nystrom4ConstantCache <: NystromConstantCache end - function alg_cache( alg::Nystrom4, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return Nystrom4ConstantCache() -end - -# alg_cache(alg::Nystrom4,u,rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = Nystrom4ConstantCache(constvalue(uBottomEltypeNoUnits),constvalue(tTypeNoUnits)) - -@cache struct FineRKN4Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k5::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVDConstantCache( + Nystrom4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -66,18 +76,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = FineRKN4ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = FineRKN4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) - k5 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return FineRKN4Cache(u, uprev, k1, k2, k3, k4, k5, k, utilde, tmp, atmp, tab) + return NystromVDCache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -86,25 +94,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return FineRKN4ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) -end - -@cache struct FineRKN5Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k5::reducedRateType - k6::reducedRateType - k7::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVDConstantCache( + FineRKN4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -114,20 +106,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = FineRKN5ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = FineRKN5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) - k5 = zero(reduced_rate_prototype) - k6 = zero(reduced_rate_prototype) - k7 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return FineRKN5Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k, utilde, tmp, atmp, tab) + return NystromVDCache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -136,7 +124,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return FineRKN5ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + return NystromVDConstantCache( + FineRKN5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end @cache struct Nystrom4VelocityIndependentCache{uType, rateType, reducedRateType} <: @@ -157,12 +147,19 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - k₁ = zero(rate_prototype) - k₂ = zero(reduced_rate_prototype) - k₃ = zero(reduced_rate_prototype) + tab = Nystrom4VelocityIndependentTableau( + constvalue(uBottomEltypeNoUnits), + constvalue(tTypeNoUnits) + ) + nstages = length(tab.b) + k1 = zero(rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) + utilde = zero(u) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) tmp = zero(u) - return Nystrom4VelocityIndependentCache(u, uprev, k₁, k₂, k₃, k, tmp) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end struct Nystrom4VelocityIndependentConstantCache <: NystromConstantCache end @@ -173,7 +170,11 @@ function alg_cache( ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return Nystrom4VelocityIndependentConstantCache() + tab = Nystrom4VelocityIndependentTableau( + constvalue(uBottomEltypeNoUnits), + constvalue(tTypeNoUnits) + ) + return NystromVIConstantCache(tab) end @cache struct IRKN3Cache{uType, rateType, TabType} <: NystromMutableCache @@ -260,19 +261,6 @@ function alg_cache( return IRKN4ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) end -@cache struct Nystrom5VelocityIndependentCache{uType, rateType, reducedRateType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k₂::reducedRateType - k₃::reducedRateType - k₄::reducedRateType - k::rateType - tmp::uType - tab::TabType -end - function alg_cache( alg::Nystrom5VelocityIndependent, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, @@ -280,17 +268,19 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - k₁ = zero(rate_prototype) - k₂ = zero(reduced_rate_prototype) - k₃ = zero(reduced_rate_prototype) - k₄ = zero(reduced_rate_prototype) - k = zero(rate_prototype) - tmp = zero(u) - tab = Nystrom5VelocityIndependentConstantCache( + tab = Nystrom5VelocityIndependentTableau( constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits) ) - return Nystrom5VelocityIndependentCache(u, uprev, k₁, k₂, k₃, k₄, k, tmp, tab) + nstages = length(tab.b) + k1 = zero(rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] + k = zero(rate_prototype) + utilde = zero(u) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) + tmp = zero(u) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -299,25 +289,11 @@ function alg_cache( ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return Nystrom5VelocityIndependentConstantCache( + tab = Nystrom5VelocityIndependentTableau( constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits) ) -end - -struct DPRKN4Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVIConstantCache(tab) end function alg_cache( @@ -327,17 +303,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = DPRKN4ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = DPRKN4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return DPRKN4Cache(u, uprev, k1, k2, k3, k4, k, utilde, tmp, atmp, tab) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -346,24 +321,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return DPRKN4ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) -end - -@cache struct DPRKN5Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k5::reducedRateType - k6::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVIConstantCache( + DPRKN4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -373,19 +333,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = DPRKN5ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = DPRKN5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) - k5 = zero(reduced_rate_prototype) - k6 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return DPRKN5Cache(u, uprev, k1, k2, k3, k4, k5, k6, k, utilde, tmp, atmp, tab) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -394,7 +351,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return DPRKN5ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + return NystromVIConstantCache( + DPRKN5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end @cache struct DPRKN6Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: @@ -445,23 +404,6 @@ function alg_cache( return DPRKN6ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) end -@cache struct DPRKN6FMCache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k5::reducedRateType - k6::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType -end - function alg_cache( alg::DPRKN6FM, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, @@ -469,19 +411,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = DPRKN6FMConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = DPRKN6FMTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) - k5 = zero(reduced_rate_prototype) - k6 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return DPRKN6FMCache(u, uprev, k1, k2, k3, k4, k5, k6, k, utilde, tmp, atmp, tab) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -490,27 +429,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return DPRKN6FMConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) -end - -@cache struct DPRKN8Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k5::reducedRateType - k6::reducedRateType - k7::reducedRateType - k8::reducedRateType - k9::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVIConstantCache( + DPRKN6FMTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -520,22 +441,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = DPRKN8ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = DPRKN8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) - k5 = zero(reduced_rate_prototype) - k6 = zero(reduced_rate_prototype) - k7 = zero(reduced_rate_prototype) - k8 = zero(reduced_rate_prototype) - k9 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return DPRKN8Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k, utilde, tmp, atmp, tab) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -544,35 +459,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return DPRKN8ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) -end - -@cache struct DPRKN12Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k5::reducedRateType - k6::reducedRateType - k7::reducedRateType - k8::reducedRateType - k9::reducedRateType - k10::reducedRateType - k11::reducedRateType - k12::reducedRateType - k13::reducedRateType - k14::reducedRateType - k15::reducedRateType - k16::reducedRateType - k17::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVIConstantCache( + DPRKN8Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -582,33 +471,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = DPRKN12ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = DPRKN12Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) - k5 = zero(reduced_rate_prototype) - k6 = zero(reduced_rate_prototype) - k7 = zero(reduced_rate_prototype) - k8 = zero(reduced_rate_prototype) - k9 = zero(reduced_rate_prototype) - k10 = zero(reduced_rate_prototype) - k11 = zero(reduced_rate_prototype) - k12 = zero(reduced_rate_prototype) - k13 = zero(reduced_rate_prototype) - k14 = zero(reduced_rate_prototype) - k15 = zero(reduced_rate_prototype) - k16 = zero(reduced_rate_prototype) - k17 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return DPRKN12Cache( - u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, - k16, k17, k, utilde, tmp, atmp, tab - ) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -617,22 +489,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return DPRKN12ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) -end - -@cache struct ERKN4Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVIConstantCache( + DPRKN12Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -642,17 +501,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = ERKN4ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = ERKN4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return ERKN4Cache(u, uprev, k1, k2, k3, k4, k, utilde, tmp, atmp, tab) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -661,22 +519,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return ERKN4ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) -end - -@cache struct ERKN5Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVIConstantCache( + ERKN4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -686,17 +531,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = ERKN5ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = ERKN5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return ERKN5Cache(u, uprev, k1, k2, k3, k4, k, utilde, tmp, atmp, tab) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -705,25 +549,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return ERKN5ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) -end - -@cache struct ERKN7Cache{uType, rateType, reducedRateType, uNoUnitsType, TabType} <: - NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k2::reducedRateType - k3::reducedRateType - k4::reducedRateType - k5::reducedRateType - k6::reducedRateType - k7::reducedRateType - k::rateType - utilde::uType - tmp::uType - atmp::uNoUnitsType - tab::TabType + return NystromVIConstantCache( + ERKN5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -733,20 +561,16 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - tab = ERKN7ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + tab = ERKN7Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) k1 = zero(rate_prototype) - k2 = zero(reduced_rate_prototype) - k3 = zero(reduced_rate_prototype) - k4 = zero(reduced_rate_prototype) - k5 = zero(reduced_rate_prototype) - k6 = zero(reduced_rate_prototype) - k7 = zero(reduced_rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) utilde = zero(u) atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) tmp = zero(u) - return ERKN7Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k, utilde, tmp, atmp, tab) + return NystromVICache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end function alg_cache( @@ -755,17 +579,9 @@ function alg_cache( dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return ERKN7ConstantCache(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) -end - -@cache struct RKN4Cache{uType, rateType, reducedRateType} <: NystromMutableCache - u::uType - uprev::uType - fsalfirst::rateType - k₂::reducedRateType - k₃::reducedRateType - k::rateType - tmp::uType + return NystromVIConstantCache( + ERKN7Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end function alg_cache( @@ -775,21 +591,25 @@ function alg_cache( ::Val{true}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} reduced_rate_prototype = rate_prototype.x[2] - k₁ = zero(rate_prototype) - k₂ = zero(reduced_rate_prototype) - k₃ = zero(reduced_rate_prototype) + tab = RKN4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + nstages = length(tab.b) + k1 = zero(rate_prototype) + ks = [zero(reduced_rate_prototype) for _ in 2:nstages] k = zero(rate_prototype) + utilde = zero(u) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) tmp = zero(u) - return RKN4Cache(u, uprev, k₁, k₂, k₃, k, tmp) + return NystromVDCache(u, uprev, k1, ks, k, utilde, tmp, atmp, tab) end -struct RKN4ConstantCache <: NystromConstantCache end - function alg_cache( alg::RKN4, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}, verbose ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} - return RKN4ConstantCache() + return NystromVDConstantCache( + RKN4Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits)) + ) end diff --git a/lib/OrdinaryDiffEqRKN/src/rkn_perform_step.jl b/lib/OrdinaryDiffEqRKN/src/rkn_perform_step.jl index 7d36f0fd95a..c2346895fa3 100644 --- a/lib/OrdinaryDiffEqRKN/src/rkn_perform_step.jl +++ b/lib/OrdinaryDiffEqRKN/src/rkn_perform_step.jl @@ -5,15 +5,8 @@ ## y'₁ = y'₀ + h∑bᵢk'ᵢ const NystromCCDefaultInitialization = Union{ - Nystrom4ConstantCache, FineRKN4ConstantCache, - FineRKN5ConstantCache, Nystrom4VelocityIndependentConstantCache, - Nystrom5VelocityIndependentConstantCache, IRKN3ConstantCache, IRKN4ConstantCache, - DPRKN4ConstantCache, DPRKN5ConstantCache, - DPRKN6FMConstantCache, DPRKN8ConstantCache, - DPRKN12ConstantCache, ERKN4ConstantCache, - ERKN5ConstantCache, ERKN7ConstantCache, RKN4ConstantCache, } function initialize!(integrator, cache::NystromCCDefaultInitialization) @@ -29,14 +22,8 @@ function initialize!(integrator, cache::NystromCCDefaultInitialization) end const NystromDefaultInitialization = Union{ - Nystrom4Cache, FineRKN4Cache, FineRKN5Cache, Nystrom4VelocityIndependentCache, - Nystrom5VelocityIndependentCache, IRKN3Cache, IRKN4Cache, - DPRKN4Cache, DPRKN5Cache, - DPRKN6FMCache, DPRKN8Cache, - DPRKN12Cache, ERKN4Cache, - ERKN5Cache, ERKN7Cache, } function initialize!(integrator, cache::NystromDefaultInitialization) @@ -52,359 +39,7 @@ function initialize!(integrator, cache::NystromDefaultInitialization) return integrator.stats.nf2 += 1 end -@muladd function perform_step!( - integrator, cache::Nystrom4ConstantCache, - repeat_step = false - ) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - k₁ = integrator.fsalfirst.x[1] - halfdt = dt / 2 - dtsq = dt^2 - eighth_dtsq = dtsq / 8 - half_dtsq = dtsq / 2 - ttmp = t + halfdt - - ## y₁ = y₀ + hy'₀ + h²∑b̄ᵢk'ᵢ - ku = uprev + halfdt * duprev + eighth_dtsq * k₁ - ## y'₁ = y'₀ + h∑bᵢk'ᵢ - kdu = duprev + halfdt * k₁ - - k₂ = f.f1(kdu, ku, p, ttmp) - ku = uprev + halfdt * duprev + eighth_dtsq * k₁ - kdu = duprev + halfdt * k₂ - - k₃ = f.f1(kdu, ku, p, ttmp) - ku = uprev + dt * duprev + half_dtsq * k₃ - kdu = duprev + dt * k₃ - - k₄ = f.f1(kdu, ku, p, t + dt) - u = uprev + (dtsq / 6) * (k₁ + k₂ + k₃) + dt * duprev - du = duprev + (dt / 6) * (k₁ + k₄ + 2 * (k₂ + k₃)) - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast -end - -@muladd function perform_step!(integrator, cache::Nystrom4Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, fsalfirst, k₂, k₃, k₄, k) = cache - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - k₁ = integrator.fsalfirst.x[1] - halfdt = dt / 2 - dtsq = dt^2 - eighth_dtsq = dtsq / 8 - half_dtsq = dtsq / 2 - ttmp = t + halfdt - - ## y₁ = y₀ + hy'₀ + h²∑b̄ᵢk'ᵢ - @.. broadcast = false ku = uprev + halfdt * duprev + eighth_dtsq * k₁ - ## y'₁ = y'₀ + h∑bᵢk'ᵢ - @.. broadcast = false kdu = duprev + halfdt * k₁ - - f.f1(k₂, kdu, ku, p, ttmp) - @.. broadcast = false ku = uprev + halfdt * duprev + eighth_dtsq * k₁ - @.. broadcast = false kdu = duprev + halfdt * k₂ - - f.f1(k₃, kdu, ku, p, ttmp) - @.. broadcast = false ku = uprev + dt * duprev + half_dtsq * k₃ - @.. broadcast = false kdu = duprev + dt * k₃ - - f.f1(k₄, kdu, ku, p, t + dt) - @.. broadcast = false u = uprev + (dtsq / 6) * (k₁ + k₂ + k₃) + dt * duprev - @.. broadcast = false du = duprev + (dt / 6) * (k₁ + k₄ + 2 * (k₂ + k₃)) - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 -end - -@muladd function perform_step!( - integrator, cache::FineRKN4ConstantCache, - repeat_step = false - ) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; - c2, c3, c4, c5, a21, a31, a32, a41, a43, a51, - a52, a53, a54, abar21, abar31, abar32, abar41, abar42, abar43, abar51, - abar52, abar53, abar54, b1, b3, b4, b5, bbar1, bbar3, bbar4, bbar5, btilde1, btilde3, btilde4, btilde5, bptilde1, - bptilde3, bptilde4, bptilde5, - ) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c2 * duprev + dt * (a21 * k1)) - kdu = duprev + dt * (abar21 * k1) - - k2 = f.f1(kdu, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a31 * k1 + a32 * k2)) - kdu = duprev + dt * (abar31 * k1 + abar32 * k2) - - k3 = f.f1(kdu, ku, p, t + dt * c3) - ku = uprev + dt * (c4 * duprev + dt * (a41 * k1 + a43 * k3)) # a42 = 0 - kdu = duprev + dt * (abar41 * k1 + abar42 * k2 + abar43 * k3) - - k4 = f.f1(kdu, ku, p, t + dt * c4) - ku = uprev + dt * (c5 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - kdu = duprev + dt * (abar51 * k1 + abar52 * k2 + abar53 * k3 + abar54 * k4) - - k5 = f.f1(kdu, ku, p, t + dt * c5) - - u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # b2 = 0 - du = duprev + dt * (bbar1 * k1 + bbar3 * k3 + bbar4 * k4 + bbar5 * k5) # bbar2 = 0 - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * (btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + btilde5 * k5) # btilde2 = 0 - duhat = dt * (bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + bptilde5 * k5) # bptilde2 = 0 - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::FineRKN4Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k5, k, utilde) = cache - (; - c2, c3, c4, c5, a21, a31, a32, a41, a43, a51, - a52, a53, a54, abar21, abar31, abar32, abar41, abar42, abar43, abar51, - abar52, abar53, abar54, b1, b3, b4, b5, bbar1, bbar3, bbar4, bbar5, btilde1, btilde3, btilde4, btilde5, bptilde1, - bptilde3, bptilde4, bptilde5, - ) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a21 * k1)) - @.. broadcast = false kdu = duprev + dt * (abar21 * k1) - - f.f1(k2, kdu, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + dt * (c3 * duprev + dt * (a31 * k1 + a32 * k2)) - @.. broadcast = false kdu = duprev + dt * (abar31 * k1 + abar32 * k2) - - f.f1(k3, kdu, ku, p, t + dt * c3) - @.. broadcast = false ku = uprev + - dt * (c4 * duprev + dt * (a41 * k1 + a43 * k3)) # a42 = 0 - @.. broadcast = false kdu = duprev + dt * (abar41 * k1 + abar42 * k2 + abar43 * k3) - - f.f1(k4, kdu, ku, p, t + dt * c4) - @.. broadcast = false ku = uprev + - dt * - (c5 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - @.. broadcast = false kdu = duprev + - dt * (abar51 * k1 + abar52 * k2 + abar53 * k3 + abar54 * k4) - - f.f1(k5, kdu, ku, p, t + dt * c5) - @.. broadcast = false u = uprev + - dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # b2 = 0 - @.. broadcast = false du = duprev + - dt * - (bbar1 * k1 + bbar3 * k3 + bbar4 * k4 + bbar5 * k5) # bbar2 = 0 - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @.. broadcast = false uhat = dtsq * - ( - btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + - btilde5 * k5 - ) # btilde2 = 0 - @.. broadcast = false duhat = dt * - ( - bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + - bptilde5 * k5 - ) # bptilde2 = 0 - - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!( - integrator, cache::FineRKN5ConstantCache, - repeat_step = false - ) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c2, c3, c4, c5, c6, c7, a21, a31, a32, a41, a43, a51, a52, a53, a54, a61, a62, a63, a64, a71, a73, a74, a75, abar21, abar31, abar32, abar41, abar42, abar43, abar51, abar52, abar53, abar54, abar61, abar62, abar63, abar64, abar65, abar71, abar73, abar74, abar75, abar76, b1, b3, b4, b5, bbar1, bbar3, bbar4, bbar5, bbar6, btilde1, btilde3, btilde4, btilde5, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6, bptilde7) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c2 * duprev + dt * (a21 * k1)) - kdu = duprev + dt * (abar21 * k1) - - k2 = f.f1(kdu, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a31 * k1 + a32 * k2)) - kdu = duprev + dt * (abar31 * k1 + abar32 * k2) - - k3 = f.f1(kdu, ku, p, t + dt * c3) - ku = uprev + dt * (c4 * duprev + dt * (a41 * k1 + a43 * k3)) # a42 = 0 - kdu = duprev + dt * (abar41 * k1 + abar42 * k2 + abar43 * k3) - - k4 = f.f1(kdu, ku, p, t + dt * c4) - ku = uprev + dt * (c5 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - kdu = duprev + dt * (abar51 * k1 + abar52 * k2 + abar53 * k3 + abar54 * k4) - - k5 = f.f1(kdu, ku, p, t + dt * c5) - ku = uprev + - dt * (c6 * duprev + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4)) # a65 = 0 - kdu = duprev + - dt * (abar61 * k1 + abar62 * k2 + abar63 * k3 + abar64 * k4 + abar65 * k5) - - k6 = f.f1(kdu, ku, p, t + dt * c6) - ku = uprev + - dt * ( - c7 * duprev + - dt * (a71 * k1 + a73 * k3 + a74 * k4 + a75 * k5) - ) # a72 = a76 = 0 - kdu = duprev + - dt * ( - abar71 * k1 + abar73 * k3 + abar74 * k4 + abar75 * k5 + - abar76 * k6 - ) # abar72 = 0 - - k7 = f.f1(kdu, ku, p, t + dt * c7) - u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # no b6, b7 - du = duprev + dt * (bbar1 * k1 + bbar3 * k3 + bbar4 * k4 + bbar5 * k5 + bbar6 * k6) # no b2, b7 - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 7) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * (btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + btilde5 * k5) - duhat = dt * ( - bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + bptilde5 * k5 + - bptilde6 * k6 + bptilde7 * k7 - ) - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::FineRKN5Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k5, k6, k7, k, utilde) = cache - (; c1, c2, c3, c4, c5, c6, c7, a21, a31, a32, a41, a43, a51, a52, a53, a54, a61, a62, a63, a64, a71, a73, a74, a75, abar21, abar31, abar32, abar41, abar42, abar43, abar51, abar52, abar53, abar54, abar61, abar62, abar63, abar64, abar65, abar71, abar73, abar74, abar75, abar76, b1, b3, b4, b5, bbar1, bbar3, bbar4, bbar5, bbar6, btilde1, btilde3, btilde4, btilde5, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6, bptilde7) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a21 * k1)) - @.. broadcast = false kdu = duprev + dt * (abar21 * k1) - - f.f1(k2, kdu, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + dt * (c3 * duprev + dt * (a31 * k1 + a32 * k2)) - @.. broadcast = false kdu = duprev + dt * (abar31 * k1 + abar32 * k2) - - f.f1(k3, kdu, ku, p, t + dt * c3) - @.. broadcast = false ku = uprev + - dt * (c4 * duprev + dt * (a41 * k1 + a43 * k3)) # a42 = 0 - @.. broadcast = false kdu = duprev + dt * (abar41 * k1 + abar42 * k2 + abar43 * k3) - - f.f1(k4, kdu, ku, p, t + dt * c4) - @.. broadcast = false ku = uprev + - dt * - (c5 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - @.. broadcast = false kdu = duprev + - dt * (abar51 * k1 + abar52 * k2 + abar53 * k3 + abar54 * k4) - - f.f1(k5, kdu, ku, p, t + dt * c5) - @.. broadcast = false ku = uprev + - dt * ( - c6 * duprev + - dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4) - ) # a65 = 0 - @.. broadcast = false kdu = duprev + - dt * ( - abar61 * k1 + abar62 * k2 + abar63 * k3 + abar64 * k4 + - abar65 * k5 - ) - - f.f1(k6, kdu, ku, p, t + dt * c6) - @.. broadcast = false ku = uprev + - dt * ( - c7 * duprev + - dt * (a71 * k1 + a73 * k3 + a74 * k4 + a75 * k5) - ) # a72 = a76 = 0 - @.. broadcast = false kdu = duprev + - dt * ( - abar71 * k1 + abar73 * k3 + abar74 * k4 + - abar75 * k5 + abar76 * k6 - ) # abar72 = 0 - - f.f1(k7, kdu, ku, p, t + dt * c7) - @.. broadcast = false u = uprev + - dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) - @.. broadcast = false du = duprev + - dt * - (bbar1 * k1 + bbar3 * k3 + bbar4 * k4 + bbar5 * k5 + bbar6 * k6) - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 7) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @.. broadcast = false uhat = dtsq * - ( - btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + - btilde5 * k5 - ) - @.. broadcast = false duhat = dt * - ( - bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + - bptilde5 * k5 + bptilde6 * k6 + bptilde7 * k7 - ) - - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end +## Nystrom4VelocityIndependent perform_step! kept for IRKN3/IRKN4 bootstrap use @muladd function perform_step!( integrator, cache::Nystrom4VelocityIndependentConstantCache, @@ -419,7 +54,6 @@ end half_dtsq = dtsq / 2 ttmp = t + halfdt - ## y₁ = y₀ + hy'₀ + h²∑b̄ᵢk'ᵢ ku = uprev + halfdt * duprev + eighth_dtsq * k₁ k₂ = f.f1(duprev, ku, p, ttmp) @@ -453,7 +87,6 @@ end half_dtsq = dtsq / 2 ttmp = t + halfdt - ## y₁ = y₀ + hy'₀ + h²∑b̄ᵢk'ᵢ @.. broadcast = false ku = uprev + halfdt * duprev + eighth_dtsq * k₁ f.f1(k₂, duprev, ku, p, ttmp) @@ -607,78 +240,29 @@ end end # end if end -@muladd function perform_step!( - integrator, cache::Nystrom5VelocityIndependentConstantCache, - repeat_step = false - ) - (; t, dt, f, p) = integrator +function initialize!(integrator, cache::DPRKN6ConstantCache) duprev, uprev = integrator.uprev.x - (; c1, c2, a21, a31, a32, a41, a42, a43, bbar1, bbar2, bbar3, b1, b2, b3, b4) = cache - k₁ = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k₁) - - k₂ = f.f1(duprev, ku, p, t + c1 * dt) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k₁ + a32 * k₂)) - - k₃ = f.f1(duprev, ku, p, t + c2 * dt) - ku = uprev + dt * (duprev + dt * (a41 * k₁ + a42 * k₂ + a43 * k₃)) - - k₄ = f.f1(duprev, ku, p, t + dt) - u = uprev + dt * (duprev + dt * (bbar1 * k₁ + bbar2 * k₂ + bbar3 * k₃)) - du = duprev + dt * (b1 * k₁ + b2 * k₂ + b3 * k₃ + b4 * k₄) + integrator.kshortsize = 3 + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) + kdu = integrator.f.f1(duprev, uprev, integrator.p, integrator.t) + ku = integrator.f.f2(duprev, uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast -end - -@muladd function perform_step!( - integrator, cache::Nystrom5VelocityIndependentCache, - repeat_step = false - ) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - uidx = eachindex(integrator.uprev.x[1]) - (; tmp, fsalfirst, k₂, k₃, k₄, k) = cache - (; c1, c2, a21, a31, a32, a41, a42, a43, bbar1, bbar2, bbar3, b1, b2, b3, b4) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - k₁ = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k₁) - - f.f1(k₂, du, ku, p, t + c1 * dt) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k₁ + a32 * k₂)) - - f.f1(k₃, du, ku, p, t + c2 * dt) - #@tight_loop_macros for i in uidx - # @inbounds ku[i] = uprev[i] + dt*(duprev[i] + dt*(a41*k₁[i] + a42*k₂[i] + a43*k₃[i])) - #end - @.. broadcast = false ku = uprev + dt * (duprev + dt * (a41 * k₁ + a42 * k₂ + a43 * k₃)) + integrator.fsalfirst = ArrayPartition((kdu, ku)) + integrator.fsallast = zero(integrator.fsalfirst) - f.f1(k₄, duprev, ku, p, t + dt) - #@tight_loop_macros for i in uidx - # @inbounds u[i] = uprev[i] + dt*(duprev[i] + dt*(bbar1*k₁[i] + bbar2*k₂[i] + bbar3*k₃[i])) - # @inbounds du[i] = duprev[i] + dt*(b1*k₁[i] + b2*k₂[i] + b3*k₃[i] + b4*k₄[i]) - #end - @.. broadcast = false u = uprev + - dt * (duprev + dt * (bbar1 * k₁ + bbar2 * k₂ + bbar3 * k₃)) - @.. broadcast = false du = duprev + dt * (b1 * k₁ + b2 * k₂ + b3 * k₃ + b4 * k₄) - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - return nothing + integrator.k[1] = integrator.fsalfirst + @inbounds for i in 2:(integrator.kshortsize - 1) + integrator.k[i] = zero(integrator.fsalfirst) + end + return integrator.k[integrator.kshortsize] = integrator.fsallast end -@muladd function perform_step!(integrator, cache::DPRKN4ConstantCache, repeat_step = false) +@muladd function perform_step!(integrator, cache::DPRKN6ConstantCache, repeat_step = false) (; t, dt, f, p) = integrator duprev, uprev = integrator.uprev.x - (; c1, c2, c3, a21, a31, a32, a41, a42, a43, b1, b2, b3, bp1, bp2, bp3, bp4, btilde1, btilde2, btilde3, btilde4, bptilde1, bptilde2, bptilde3, bptilde4) = cache + (; c1, c2, c3, c4, c5, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, b1, b3, b4, b5, bp1, bp3, bp4, bp5, bp6, btilde1, btilde2, btilde3, btilde4, btilde5, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6) = cache k1 = integrator.fsalfirst.x[1] ku = uprev + dt * (c1 * duprev + dt * a21 * k1) @@ -690,275 +274,21 @@ end ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) k4 = f.f1(duprev, ku, p, t + dt * c3) + ku = uprev + dt * (c4 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - u = uprev + dt * (duprev + dt * (b1 * k1 + b2 * k2 + b3 * k3)) - du = duprev + dt * (bp1 * k1 + bp2 * k2 + bp3 * k3 + bp4 * k4) + k5 = f.f1(duprev, ku, p, t + dt * c4) + ku = uprev + dt * (c5 * duprev + dt * (a61 * k1 + a63 * k3 + a64 * k4 + a65 * k5)) # no a62 + + k6 = f.f1(duprev, ku, p, t + dt * c5) + u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # b1 -- b5, no b2 + du = duprev + dt * (bp1 * k1 + bp3 * k3 + bp4 * k4 + bp5 * k5 + bp6 * k6) # bp1 -- bp6, no bp2 integrator.u = ArrayPartition((du, u)) integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * (btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + btilde4 * k4) - duhat = dt * (bptilde1 * k1 + bptilde2 * k2 + bptilde3 * k3 + bptilde4 * k4) - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::DPRKN4Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k, utilde) = cache - (; c1, c2, c3, a21, a31, a32, a41, a42, a43, b1, b2, b3, bp1, bp2, bp3, bp4, btilde1, btilde2, btilde3, btilde4, bptilde1, bptilde2, bptilde3, bptilde4) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - f.f1(k2, duprev, ku, p, t + dt * c1) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - f.f1(k3, duprev, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + - dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - f.f1(k4, duprev, ku, p, t + dt * c3) - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + - dt * ( - duprev[i] + - dt * (b1 * k1[i] + b2 * k2[i] + b3 * k3[i]) - ) - @inbounds du[i] = duprev[i] + - dt * (bp1 * k1[i] + bp2 * k2[i] + bp3 * k3[i] + bp4 * k4[i]) - end - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @tight_loop_macros for i in uidx - @inbounds uhat[i] = dtsq * - ( - btilde1 * k1[i] + btilde2 * k2[i] + btilde3 * k3[i] + - btilde4 * k4[i] - ) - @inbounds duhat[i] = dt * - ( - bptilde1 * k1[i] + bptilde2 * k2[i] + bptilde3 * k3[i] + - bptilde4 * k4[i] - ) - end - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::DPRKN5ConstantCache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c1, c2, c3, c4, c5, a21, a31, a32, a41, a43, a51, a53, a54, a61, a63, a64, a65, b1, b3, b4, b5, bp1, bp3, bp4, bp5, bp6, btilde1, btilde3, btilde4, btilde5, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - k2 = f.f1(duprev, ku, p, t + dt * c1) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - k3 = f.f1(duprev, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a43 * k3)) - - k4 = f.f1(duprev, ku, p, t + dt * c3) - ku = uprev + dt * (c4 * duprev + dt * (a51 * k1 + a53 * k3 + a54 * k4)) - - k5 = f.f1(duprev, ku, p, t + dt * c4) - ku = uprev + - dt * (c5 * duprev + dt * (a61 * k1 + a63 * k3 + a64 * k4 + a65 * k5)) - - k6 = f.f1(duprev, ku, p, t + dt * c5) - u = uprev + - dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) - du = duprev + - dt * (bp1 * k1 + bp3 * k3 + bp4 * k4 + bp5 * k5 + bp6 * k6) - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 6) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * (btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + btilde5 * k5) - duhat = dt * ( - bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + bptilde5 * k5 + - bptilde6 * k6 - ) - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::DPRKN5Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k5, k6, k, utilde) = cache - (; c1, c2, c3, c4, c5, a21, a31, a32, a41, a43, a51, a53, a54, a61, a63, a64, a65, b1, b3, b4, b5, bp1, bp3, bp4, bp5, bp6, btilde1, btilde3, btilde4, btilde5, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - f.f1(k2, duprev, ku, p, t + dt * c1) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - f.f1(k3, duprev, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + - dt * (c3 * duprev + dt * (a41 * k1 + a43 * k3)) - - f.f1(k4, duprev, ku, p, t + dt * c3) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c4 * duprev[i] + - dt * (a51 * k1[i] + a53 * k3[i] + a54 * k4[i]) - ) - end - - f.f1(k5, duprev, ku, p, t + dt * c4) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c5 * duprev[i] + - dt * (a61 * k1[i] + a63 * k3[i] + a64 * k4[i] + a65 * k5[i]) - ) - end - - f.f1(k6, duprev, ku, p, t + dt * c5) - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + - dt * ( - duprev[i] + - dt * (b1 * k1[i] + b3 * k3[i] + b4 * k4[i] + b5 * k5[i]) - ) - @inbounds du[i] = duprev[i] + - dt * ( - bp1 * k1[i] + bp3 * k3[i] + bp4 * k4[i] + bp5 * k5[i] + - bp6 * k6[i] - ) - end - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 6) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @tight_loop_macros for i in uidx - @inbounds uhat[i] = dtsq * - ( - btilde1 * k1[i] + btilde3 * k3[i] + btilde4 * k4[i] + - btilde5 * k5[i] - ) - @inbounds duhat[i] = dt * - ( - bptilde1 * k1[i] + bptilde3 * k3[i] + bptilde4 * k4[i] + - bptilde5 * k5[i] + bptilde6 * k6[i] - ) - end - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -function initialize!(integrator, cache::DPRKN6ConstantCache) - duprev, uprev = integrator.uprev.x - integrator.kshortsize = 3 - integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) - - kdu = integrator.f.f1(duprev, uprev, integrator.p, integrator.t) - ku = integrator.f.f2(duprev, uprev, integrator.p, integrator.t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - integrator.stats.nf2 += 1 - integrator.fsalfirst = ArrayPartition((kdu, ku)) - integrator.fsallast = zero(integrator.fsalfirst) - - integrator.k[1] = integrator.fsalfirst - @inbounds for i in 2:(integrator.kshortsize - 1) - integrator.k[i] = zero(integrator.fsalfirst) - end - return integrator.k[integrator.kshortsize] = integrator.fsallast -end - -@muladd function perform_step!(integrator, cache::DPRKN6ConstantCache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c1, c2, c3, c4, c5, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a63, a64, a65, b1, b3, b4, b5, bp1, bp3, bp4, bp5, bp6, btilde1, btilde2, btilde3, btilde4, btilde5, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - k2 = f.f1(duprev, ku, p, t + dt * c1) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - k3 = f.f1(duprev, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - k4 = f.f1(duprev, ku, p, t + dt * c3) - ku = uprev + dt * (c4 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - - k5 = f.f1(duprev, ku, p, t + dt * c4) - ku = uprev + dt * (c5 * duprev + dt * (a61 * k1 + a63 * k3 + a64 * k4 + a65 * k5)) # no a62 - - k6 = f.f1(duprev, ku, p, t + dt * c5) - u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5)) # b1 -- b5, no b2 - du = duprev + dt * (bp1 * k1 + bp3 * k3 + bp4 * k4 + bp5 * k5 + bp6 * k6) # bp1 -- bp6, no bp2 - - #= - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + dt*(duprev[i] + dt*(bhat1*k1.x[2][i] + bhat2*k2.x[2][i] + bhat3*k3.x[2][i])) - @inbounds du[i] = duprev[i]+ dt*(bphat1*k1.x[2][i] + bphat3*k3.x[2][i] + bphat4*k4.x[2][i] + bphat5*k5.x[2][i] + bphat6*k6.x[2][i]) - end - =# - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - integrator.k[1] = ArrayPartition(integrator.fsalfirst.x[1], k2) - integrator.k[2] = ArrayPartition(k3, k4) - integrator.k[3] = ArrayPartition(k5, k6) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 6) + integrator.k[1] = ArrayPartition(integrator.fsalfirst.x[1], k2) + integrator.k[2] = ArrayPartition(k3, k4) + integrator.k[3] = ArrayPartition(k5, k6) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 6) integrator.stats.nf2 += 1 if integrator.opts.adaptive @@ -1029,13 +359,6 @@ end @.. broadcast = false du = duprev + dt * (bp1 * k1 + bp3 * k3 + bp4 * k4 + bp5 * k5 + bp6 * k6) # bp1 -- bp6, no bp2 - #= - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + dt*(duprev[i] + dt*(bhat1*k1.x[2][i] + bhat2*k2.x[2][i] + bhat3*k3.x[2][i])) - @inbounds du[i] = duprev[i]+ dt*(bphat1*k1.x[2][i] + bphat3*k3.x[2][i] + bphat4*k4.x[2][i] + bphat5*k5.x[2][i] + bphat6*k6.x[2][i]) - end - =# - f.f1(k.x[1], du, u, p, t + dt) f.f2(k.x[2], du, u, p, t + dt) OrdinaryDiffEqCore.increment_nf!(integrator.stats, 6) @@ -1059,1150 +382,3 @@ end integrator.EEst = integrator.opts.internalnorm(atmp, t) end end - -@muladd function perform_step!( - integrator, cache::DPRKN6FMConstantCache, - repeat_step = false - ) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c1, c2, c3, c4, c5, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, b1, b2, b3, b4, b5, bp1, bp2, bp3, bp4, bp5, bp6, btilde1, btilde2, btilde3, btilde4, btilde5, bptilde1, bptilde2, bptilde3, bptilde4, bptilde5) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - k2 = f.f1(duprev, ku, p, t + dt * c1) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - k3 = f.f1(duprev, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - k4 = f.f1(duprev, ku, p, t + dt * c3) - ku = uprev + dt * (c4 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - - k5 = f.f1(duprev, ku, p, t + dt * c4) - ku = uprev + - dt * (c5 * duprev + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5)) - - k6 = f.f1(duprev, ku, p, t + dt * c5) - u = uprev + - dt * (duprev + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 + b5 * k5)) - du = duprev + - dt * (bp1 * k1 + bp2 * k2 + bp3 * k3 + bp4 * k4 + bp5 * k5 + bp6 * k6) - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 6) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * - (btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + btilde4 * k4 + btilde5 * k5) - duhat = dt * ( - bptilde1 * k1 + bptilde2 * k2 + bptilde3 * k3 + bptilde4 * k4 + - bptilde5 * k5 - ) - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::DPRKN6FMCache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k5, k6, k, utilde) = cache - (; c1, c2, c3, c4, c5, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, b1, b2, b3, b4, b5, bp1, bp2, bp3, bp4, bp5, bp6, btilde1, btilde2, btilde3, btilde4, btilde5, bptilde1, bptilde2, bptilde3, bptilde4, bptilde5) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - f.f1(k2, duprev, ku, p, t + dt * c1) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - f.f1(k3, duprev, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + - dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - f.f1(k4, duprev, ku, p, t + dt * c3) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c4 * duprev[i] + - dt * (a51 * k1[i] + a52 * k2[i] + a53 * k3[i] + a54 * k4[i]) - ) - end - - f.f1(k5, duprev, ku, p, t + dt * c4) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c5 * duprev[i] + - dt * ( - a61 * k1[i] + a62 * k2[i] + a63 * k3[i] + a64 * k4[i] + - a65 * k5[i] - ) - ) - end - - f.f1(k6, duprev, ku, p, t + dt * c5) - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + - dt * ( - duprev[i] + - dt * - (b1 * k1[i] + b2 * k2[i] + b3 * k3[i] + b4 * k4[i] + b5 * k5[i]) - ) - @inbounds du[i] = duprev[i] + - dt * ( - bp1 * k1[i] + bp2 * k2[i] + bp3 * k3[i] + bp4 * k4[i] + - bp5 * k5[i] + bp6 * k6[i] - ) - end - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 6) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @tight_loop_macros for i in uidx - @inbounds uhat[i] = dtsq * - ( - btilde1 * k1[i] + btilde2 * k2[i] + btilde3 * k3[i] + - btilde4 * k4[i] + btilde5 * k5[i] - ) - @inbounds duhat[i] = dt * - ( - bptilde1 * k1[i] + bptilde2 * k2[i] + bptilde3 * k3[i] + - bptilde4 * k4[i] + bptilde5 * k5[i] - ) - end - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::DPRKN8ConstantCache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c1, c2, c3, c4, c5, c6, c7, c8, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a81, a82, a83, a84, a85, a86, a87, a91, a93, a94, a95, a96, a97, b1, b3, b4, b5, b6, b7, bp1, bp3, bp4, bp5, bp6, bp7, bp8, btilde1, btilde3, btilde4, btilde5, btilde6, btilde7, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6, bptilde7, bptilde8, bptilde9) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - k2 = f.f1(duprev, ku, p, t + dt * c1) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - k3 = f.f1(duprev, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - k4 = f.f1(duprev, ku, p, t + dt * c3) - ku = uprev + dt * (c4 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - - k5 = f.f1(duprev, ku, p, t + dt * c4) - ku = uprev + - dt * (c5 * duprev + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5)) - - k6 = f.f1(duprev, ku, p, t + dt * c5) - ku = uprev + - dt * ( - c6 * duprev + - dt * (a71 * k1 + a72 * k2 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6) - ) - - k7 = f.f1(duprev, ku, p, t + dt * c6) - ku = uprev + - dt * ( - c7 * duprev + - dt * (a81 * k1 + a82 * k2 + a83 * k3 + a84 * k4 + a85 * k5 + a86 * k6 + a87 * k7) - ) - - k8 = f.f1(duprev, ku, p, t + dt * c7) - ku = uprev + - dt * ( - c8 * duprev + - dt * (a91 * k1 + a93 * k3 + a94 * k4 + a95 * k5 + a96 * k6 + a97 * k7) - ) # no a92 & a98 - - k9 = f.f1(duprev, ku, p, t + dt * c8) - u = uprev + - dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6 + b7 * k7)) # b1 -- b7, no b2 - du = duprev + - dt * (bp1 * k1 + bp3 * k3 + bp4 * k4 + bp5 * k5 + bp6 * k6 + bp7 * k7 + bp8 * k8) # bp1 -- bp8, no bp2 - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 9) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * - ( - btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + btilde5 * k5 + btilde6 * k6 + - btilde7 * k7 - ) - duhat = dt * ( - bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + bptilde5 * k5 + - bptilde6 * k6 + bptilde7 * k7 + bptilde8 * k8 + bptilde9 * k9 - ) - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::DPRKN8Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k5, k6, k7, k8, k9, k, utilde) = cache - (; c1, c2, c3, c4, c5, c6, c7, c8, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a72, a73, a74, a75, a76, a81, a82, a83, a84, a85, a86, a87, a91, a93, a94, a95, a96, a97, b1, b3, b4, b5, b6, b7, bp1, bp3, bp4, bp5, bp6, bp7, bp8, btilde1, btilde3, btilde4, btilde5, btilde6, btilde7, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6, bptilde7, bptilde8, bptilde9) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - f.f1(k2, duprev, ku, p, t + dt * c1) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - f.f1(k3, duprev, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + - dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - f.f1(k4, duprev, ku, p, t + dt * c3) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c4 * duprev[i] + - dt * (a51 * k1[i] + a52 * k2[i] + a53 * k3[i] + a54 * k4[i]) - ) - end - - f.f1(k5, duprev, ku, p, t + dt * c4) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c5 * duprev[i] + - dt * ( - a61 * k1[i] + a62 * k2[i] + a63 * k3[i] + a64 * k4[i] + - a65 * k5[i] - ) - ) - end - - f.f1(k6, duprev, ku, p, t + dt * c5) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c6 * duprev[i] + - dt * ( - a71 * k1[i] + a72 * k2[i] + a73 * k3[i] + a74 * k4[i] + - a75 * k5[i] + a76 * k6[i] - ) - ) - end - - f.f1(k7, duprev, ku, p, t + dt * c6) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c7 * duprev[i] + - dt * ( - a81 * k1[i] + a82 * k2[i] + a83 * k3[i] + a84 * k4[i] + - a85 * k5[i] + a86 * k6[i] + a87 * k7[i] - ) - ) - end - - f.f1(k8, duprev, ku, p, t + dt * c7) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c8 * duprev[i] + - dt * ( - a91 * k1[i] + a93 * k3[i] + a94 * k4[i] + a95 * k5[i] + - a96 * k6[i] + a97 * k7[i] - ) - ) # no a92 & a98 - end - - f.f1(k9, duprev, ku, p, t + dt * c8) - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + - dt * ( - duprev[i] + - dt * - ( - b1 * k1[i] + b3 * k3[i] + b4 * k4[i] + b5 * k5[i] + b6 * k6[i] + - b7 * k7[i] - ) - ) # b1 -- b7, no b2 - @inbounds du[i] = duprev[i] + - dt * ( - bp1 * k1[i] + bp3 * k3[i] + bp4 * k4[i] + bp5 * k5[i] + - bp6 * k6[i] + bp7 * k7[i] + bp8 * k8[i] - ) # bp1 -- bp8, no bp2 - end - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 9) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @tight_loop_macros for i in uidx - @inbounds uhat[i] = dtsq * - ( - btilde1 * k1[i] + btilde3 * k3[i] + btilde4 * k4[i] + - btilde5 * k5[i] + btilde6 * k6[i] + btilde7 * k7[i] - ) - @inbounds duhat[i] = dt * - ( - bptilde1 * k1[i] + bptilde3 * k3[i] + bptilde4 * k4[i] + - bptilde5 * k5[i] + bptilde6 * k6[i] + bptilde7 * k7[i] + - bptilde8 * k8[i] + bptilde9 * k9[i] - ) - end - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::DPRKN12ConstantCache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, a21, a31, a32, a41, a42, a43, a51, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, a81, a84, a85, a86, a87, a91, a93, a94, a95, a96, a97, a98, a101, a103, a104, a105, a106, a107, a108, a109, a111, a113, a114, a115, a116, a117, a118, a119, a1110, a121, a123, a124, a125, a126, a127, a128, a129, a1210, a1211, a131, a133, a134, a135, a136, a137, a138, a139, a1310, a1311, a1312, a141, a143, a144, a145, a146, a147, a148, a149, a1410, a1411, a1412, a1413, a151, a153, a154, a155, a156, a157, a158, a159, a1510, a1511, a1512, a1513, a1514, a161, a163, a164, a165, a166, a167, a168, a169, a1610, a1611, a1612, a1613, a1614, a1615, a171, a173, a174, a175, a176, a177, a178, a179, a1710, a1711, a1712, a1713, a1714, a1715, b1, b7, b8, b9, b10, b11, b12, b13, b14, b15, bp1, bp7, bp8, bp9, bp10, bp11, bp12, bp13, bp14, bp15, bp16, bp17, btilde1, btilde7, btilde8, btilde9, btilde10, btilde11, btilde12, btilde13, btilde14, btilde15, bptilde1, bptilde7, bptilde8, bptilde9, bptilde10, bptilde11, bptilde12, bptilde13, bptilde14, bptilde15, bptilde16, bptilde17) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - k2 = f.f1(duprev, ku, p, t + dt * c1) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - k3 = f.f1(duprev, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - k4 = f.f1(duprev, ku, p, t + dt * c3) - ku = uprev + dt * (c4 * duprev + dt * (a51 * k1 + a53 * k3 + a54 * k4)) # no a52 - - k5 = f.f1(duprev, ku, p, t + dt * c4) - ku = uprev + dt * (c5 * duprev + dt * (a61 * k1 + a63 * k3 + a64 * k4 + a65 * k5)) # no a62 - - k6 = f.f1(duprev, ku, p, t + dt * c5) - ku = uprev + - dt * (c6 * duprev + dt * (a71 * k1 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6)) # no a72 - - k7 = f.f1(duprev, ku, p, t + dt * c6) - ku = uprev + - dt * (c7 * duprev + dt * (a81 * k1 + a84 * k4 + a85 * k5 + a86 * k6 + a87 * k7)) # no a82, a83 - - k8 = f.f1(duprev, ku, p, t + dt * c7) - ku = uprev + - dt * ( - c8 * duprev + - dt * (a91 * k1 + a93 * k3 + a94 * k4 + a95 * k5 + a96 * k6 + a97 * k7 + a98 * k8) - ) # no a92 - - k9 = f.f1(duprev, ku, p, t + dt * c8) - ku = uprev + - dt * ( - c9 * duprev + - dt * ( - a101 * k1 + a103 * k3 + a104 * k4 + a105 * k5 + a106 * k6 + a107 * k7 + - a108 * k8 + a109 * k9 - ) - ) # no a102 - - k10 = f.f1(duprev, ku, p, t + dt * c9) - ku = uprev + - dt * ( - c10 * duprev + - dt * ( - a111 * k1 + a113 * k3 + a114 * k4 + a115 * k5 + a116 * k6 + a117 * k7 + - a118 * k8 + a119 * k9 + a1110 * k10 - ) - ) # no a112 - - k11 = f.f1(duprev, ku, p, t + dt * c10) - ku = uprev + - dt * ( - c11 * duprev + - dt * ( - a121 * k1 + a123 * k3 + a124 * k4 + a125 * k5 + a126 * k6 + a127 * k7 + - a128 * k8 + a129 * k9 + a1210 * k10 + a1211 * k11 - ) - ) # no a122 - - k12 = f.f1(duprev, ku, p, t + dt * c11) - ku = uprev + - dt * ( - c12 * duprev + - dt * ( - a131 * k1 + a133 * k3 + a134 * k4 + a135 * k5 + a136 * k6 + a137 * k7 + - a138 * k8 + a139 * k9 + a1310 * k10 + a1311 * k11 + a1312 * k12 - ) - ) # no a132 - - k13 = f.f1(duprev, ku, p, t + dt * c12) - ku = uprev + - dt * ( - c13 * duprev + - dt * ( - a141 * k1 + a143 * k3 + a144 * k4 + a145 * k5 + a146 * k6 + a147 * k7 + - a148 * k8 + a149 * k9 + a1410 * k10 + a1411 * k11 + a1412 * k12 + a1413 * k13 - ) - ) # no a142 - - k14 = f.f1(duprev, ku, p, t + dt * c13) - ku = uprev + - dt * ( - c14 * duprev + - dt * ( - a151 * k1 + a153 * k3 + a154 * k4 + a155 * k5 + a156 * k6 + a157 * k7 + - a158 * k8 + a159 * k9 + a1510 * k10 + a1511 * k11 + a1512 * k12 + a1513 * k13 + - a1514 * k14 - ) - ) # no a152 - - k15 = f.f1(duprev, ku, p, t + dt * c14) - ku = uprev + - dt * ( - c15 * duprev + - dt * ( - a161 * k1 + a163 * k3 + a164 * k4 + a165 * k5 + a166 * k6 + a167 * k7 + - a168 * k8 + a169 * k9 + a1610 * k10 + a1611 * k11 + a1612 * k12 + a1613 * k13 + - a1614 * k14 + a1615 * k15 - ) - ) # no a162 - - k16 = f.f1(duprev, ku, p, t + dt * c15) - ku = uprev + - dt * ( - c16 * duprev + - dt * ( - a171 * k1 + a173 * k3 + a174 * k4 + a175 * k5 + a176 * k6 + a177 * k7 + - a178 * k8 + a179 * k9 + a1710 * k10 + a1711 * k11 + a1712 * k12 + a1713 * k13 + - a1714 * k14 + a1715 * k15 - ) - ) # no a172, a1716 - - k17 = f.f1(duprev, ku, p, t + dt * c16) - u = uprev + - dt * ( - duprev + - dt * ( - b1 * k1 + b7 * k7 + b8 * k8 + b9 * k9 + b10 * k10 + b11 * k11 + b12 * k12 + - b13 * k13 + b14 * k14 + b15 * k15 - ) - ) # b1 & b7 -- b15 - du = duprev + - dt * - ( - bp1 * k1 + bp7 * k7 + bp8 * k8 + bp9 * k9 + bp10 * k10 + bp11 * k11 + bp12 * k12 + - bp13 * k13 + bp14 * k14 + bp15 * k15 + bp16 * k16 + bp17 * k17 - ) # bp1 & bp7 -- bp17 - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 17) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * - ( - btilde1 * k1 + btilde7 * k7 + btilde8 * k8 + btilde9 * k9 + btilde10 * k10 + - btilde11 * k11 + btilde12 * k12 + btilde13 * k13 + btilde14 * k14 + - btilde15 * k15 - ) # btilde1 & btilde7 -- btilde15 - duhat = dt * ( - bptilde1 * k1 + bptilde7 * k7 + bptilde8 * k8 + bptilde9 * k9 + - bptilde10 * k10 + bptilde11 * k11 + bptilde12 * k12 + bptilde13 * k13 + - bptilde14 * k14 + bptilde15 * k15 + bptilde16 * k16 + bptilde17 * k17 - ) # bptilde1 & bptilde7 -- bptilde17 - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::DPRKN12Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, k17, k, utilde) = cache - (; c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, a21, a31, a32, a41, a42, a43, a51, a53, a54, a61, a63, a64, a65, a71, a73, a74, a75, a76, a81, a84, a85, a86, a87, a91, a93, a94, a95, a96, a97, a98, a101, a103, a104, a105, a106, a107, a108, a109, a111, a113, a114, a115, a116, a117, a118, a119, a1110, a121, a123, a124, a125, a126, a127, a128, a129, a1210, a1211, a131, a133, a134, a135, a136, a137, a138, a139, a1310, a1311, a1312, a141, a143, a144, a145, a146, a147, a148, a149, a1410, a1411, a1412, a1413, a151, a153, a154, a155, a156, a157, a158, a159, a1510, a1511, a1512, a1513, a1514, a161, a163, a164, a165, a166, a167, a168, a169, a1610, a1611, a1612, a1613, a1614, a1615, a171, a173, a174, a175, a176, a177, a178, a179, a1710, a1711, a1712, a1713, a1714, a1715, b1, b7, b8, b9, b10, b11, b12, b13, b14, b15, bp1, bp7, bp8, bp9, bp10, bp11, bp12, bp13, bp14, bp15, bp16, bp17, btilde1, btilde7, btilde8, btilde9, btilde10, btilde11, btilde12, btilde13, btilde14, btilde15, bptilde1, bptilde7, bptilde8, bptilde9, bptilde10, bptilde11, bptilde12, bptilde13, bptilde14, bptilde15, bptilde16, bptilde17) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - f.f1(k2, duprev, ku, p, t + dt * c1) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - f.f1(k3, duprev, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + - dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - f.f1(k4, duprev, ku, p, t + dt * c3) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * - (c4 * duprev[i] + dt * (a51 * k1[i] + a53 * k3[i] + a54 * k4[i])) # no a52 - end - - f.f1(k5, duprev, ku, p, t + dt * c4) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c5 * duprev[i] + - dt * (a61 * k1[i] + a63 * k3[i] + a64 * k4[i] + a65 * k5[i]) - ) # no a62 - end - - f.f1(k6, duprev, ku, p, t + dt * c5) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c6 * duprev[i] + - dt * ( - a71 * k1[i] + a73 * k3[i] + a74 * k4[i] + a75 * k5[i] + - a76 * k6[i] - ) - ) # no a72 - end - - f.f1(k7, duprev, ku, p, t + dt * c6) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c7 * duprev[i] + - dt * ( - a81 * k1[i] + a84 * k4[i] + a85 * k5[i] + a86 * k6[i] + - a87 * k7[i] - ) - ) # no a82, a83 - end - - f.f1(k8, duprev, ku, p, t + dt * c7) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c8 * duprev[i] + - dt * ( - a91 * k1[i] + a93 * k3[i] + a94 * k4[i] + a95 * k5[i] + - a96 * k6[i] + a97 * k7[i] + a98 * k8[i] - ) - ) # no a92 - end - - f.f1(k9, duprev, ku, p, t + dt * c8) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c9 * duprev[i] + - dt * - ( - a101 * k1[i] + a103 * k3[i] + a104 * k4[i] + a105 * k5[i] + - a106 * k6[i] + a107 * k7[i] + a108 * k8[i] + a109 * k9[i] - ) - ) # no a102 - end - - f.f1(k10, duprev, ku, p, t + dt * c9) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c10 * duprev[i] + - dt * - ( - a111 * k1[i] + a113 * k3[i] + a114 * k4[i] + a115 * k5[i] + - a116 * k6[i] + a117 * k7[i] + a118 * k8[i] + a119 * k9[i] + - a1110 * k10[i] - ) - ) # no a112 - end - - f.f1(k11, duprev, ku, p, t + dt * c10) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c11 * duprev[i] + - dt * - ( - a121 * k1[i] + a123 * k3[i] + a124 * k4[i] + a125 * k5[i] + - a126 * k6[i] + a127 * k7[i] + a128 * k8[i] + a129 * k9[i] + - a1210 * k10[i] + a1211 * k11[i] - ) - ) # no a122 - end - - f.f1(k12, duprev, ku, p, t + dt * c11) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c12 * duprev[i] + - dt * - ( - a131 * k1[i] + a133 * k3[i] + a134 * k4[i] + a135 * k5[i] + - a136 * k6[i] + a137 * k7[i] + a138 * k8[i] + a139 * k9[i] + - a1310 * k10[i] + a1311 * k11[i] + a1312 * k12[i] - ) - ) # no a132 - end - - f.f1(k13, duprev, ku, p, t + dt * c12) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c13 * duprev[i] + - dt * - ( - a141 * k1[i] + a143 * k3[i] + a144 * k4[i] + a145 * k5[i] + - a146 * k6[i] + a147 * k7[i] + a148 * k8[i] + a149 * k9[i] + - a1410 * k10[i] + a1411 * k11[i] + a1412 * k12[i] + - a1413 * k13[i] - ) - ) # no a142 - end - - f.f1(k14, duprev, ku, p, t + dt * c13) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c14 * duprev[i] + - dt * - ( - a151 * k1[i] + a153 * k3[i] + a154 * k4[i] + a155 * k5[i] + - a156 * k6[i] + a157 * k7[i] + a158 * k8[i] + a159 * k9[i] + - a1510 * k10[i] + a1511 * k11[i] + a1512 * k12[i] + - a1513 * k13[i] + a1514 * k14[i] - ) - ) # no a152 - end - - f.f1(k15, duprev, ku, p, t + dt * c14) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c15 * duprev[i] + - dt * - ( - a161 * k1[i] + a163 * k3[i] + a164 * k4[i] + a165 * k5[i] + - a166 * k6[i] + a167 * k7[i] + a168 * k8[i] + a169 * k9[i] + - a1610 * k10[i] + a1611 * k11[i] + a1612 * k12[i] + - a1613 * k13[i] + a1614 * k14[i] + a1615 * k15[i] - ) - ) # no a162 - end - - f.f1(k16, duprev, ku, p, t + dt * c15) - @tight_loop_macros for i in uidx - @inbounds ku[i] = uprev[i] + - dt * ( - c16 * duprev[i] + - dt * - ( - a171 * k1[i] + a173 * k3[i] + a174 * k4[i] + a175 * k5[i] + - a176 * k6[i] + a177 * k7[i] + a178 * k8[i] + a179 * k9[i] + - a1710 * k10[i] + a1711 * k11[i] + a1712 * k12[i] + - a1713 * k13[i] + a1714 * k14[i] + a1715 * k15[i] - ) - ) # no a172, a1716 - end - - f.f1(k17, duprev, ku, p, t + dt * c16) - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + - dt * ( - duprev[i] + - dt * ( - b1 * k1[i] + b7 * k7[i] + b8 * k8[i] + b9 * k9[i] + - b10 * k10[i] + b11 * k11[i] + b12 * k12[i] + b13 * k13[i] + - b14 * k14[i] + b15 * k15[i] - ) - ) # b1 & b7 -- b15 - @inbounds du[i] = duprev[i] + - dt * ( - bp1 * k1[i] + bp7 * k7[i] + bp8 * k8[i] + bp9 * k9[i] + - bp10 * k10[i] + bp11 * k11[i] + bp12 * k12[i] + bp13 * k13[i] + - bp14 * k14[i] + bp15 * k15[i] + bp16 * k16[i] + bp17 * k17[i] - ) # bp1 & bp7 -- bp17 - end - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 17) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @tight_loop_macros for i in uidx - @inbounds uhat[i] = dtsq * - ( - btilde1 * k1[i] + btilde7 * k7[i] + btilde8 * k8[i] + - btilde9 * k9[i] + btilde10 * k10[i] + btilde11 * k11[i] + - btilde12 * k12[i] + btilde13 * k13[i] + btilde14 * k14[i] + - btilde15 * k15[i] - ) # btilde1 & btilde7 -- btilde15 - @inbounds duhat[i] = dt * - ( - bptilde1 * k1[i] + bptilde7 * k7[i] + bptilde8 * k8[i] + - bptilde9 * k9[i] + bptilde10 * k10[i] + - bptilde11 * k11[i] + bptilde12 * k12[i] + - bptilde13 * k13[i] + bptilde14 * k14[i] + - bptilde15 * k15[i] + bptilde16 * k16[i] + - bptilde17 * k17[i] - ) # bptilde1 & bptilde7 -- bptilde17 - end - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::ERKN4ConstantCache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c1, c2, c3, a21, a31, a32, a41, a42, a43, b1, b2, b3, b4, bp1, bp2, bp3, bp4, btilde1, btilde2, btilde3, btilde4, bptilde1, bptilde2, bptilde3, bptilde4) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - k2 = f.f1(duprev, ku, p, t + dt * c1) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - k3 = f.f1(duprev, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - k4 = f.f1(duprev, ku, p, t + dt * c3) - u = uprev + dt * (duprev + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4)) - du = duprev + dt * (bp1 * k1 + bp2 * k2 + bp3 * k3 + bp4 * k4) - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * (btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + btilde4 * k4) - duhat = dt * (bptilde1 * k1 + bptilde2 * k2 + bptilde3 * k3 + bptilde4 * k4) - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::ERKN4Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k, utilde) = cache - (; c1, c2, c3, a21, a31, a32, a41, a42, a43, b1, b2, b3, b4, bp1, bp2, bp3, bp4, btilde1, btilde2, btilde3, btilde4, bptilde1, bptilde2, bptilde3, bptilde4) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - f.f1(k2, duprev, ku, p, t + dt * c1) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - f.f1(k3, duprev, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + - dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - f.f1(k4, duprev, ku, p, t + dt * c3) - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + - dt * ( - duprev[i] + - dt * (b1 * k1[i] + b2 * k2[i] + b3 * k3[i] + b4 * k4[i]) - ) - @inbounds du[i] = duprev[i] + - dt * (bp1 * k1[i] + bp2 * k2[i] + bp3 * k3[i] + bp4 * k4[i]) - end - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @tight_loop_macros for i in uidx - @inbounds uhat[i] = dtsq * - ( - btilde1 * k1[i] + btilde2 * k2[i] + btilde3 * k3[i] + - btilde4 * k4[i] - ) - @inbounds duhat[i] = dt * - ( - bptilde1 * k1[i] + bptilde2 * k2[i] + bptilde3 * k3[i] + - bptilde4 * k4[i] - ) - end - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::ERKN5ConstantCache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c1, c2, c3, a21, a31, a32, a41, a42, a43, b1, b2, b3, b4, bp1, bp2, bp3, bp4, btilde1, btilde2, btilde3, btilde4) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - k2 = f.f1(duprev, ku, p, t + dt * c1) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - k3 = f.f1(duprev, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - k4 = f.f1(duprev, ku, p, t + dt * c3) - u = uprev + dt * (duprev + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4)) - du = duprev + dt * (bp1 * k1 + bp2 * k2 + bp3 * k3 + bp4 * k4) - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * (btilde1 * k1 + btilde2 * k2 + btilde3 * k3 + btilde4 * k4) - atmp = calculate_residuals( - uhat, integrator.uprev.x[2], integrator.u.x[2], - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::ERKN5Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k, utilde) = cache - (; c1, c2, c3, a21, a31, a32, a41, a42, a43, b1, b2, b3, b4, bp1, bp2, bp3, bp4, btilde1, btilde2, btilde3, btilde4) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - f.f1(k2, duprev, ku, p, t + dt * c1) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - f.f1(k3, duprev, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + - dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - f.f1(k4, duprev, ku, p, t + dt * c3) - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + - dt * ( - duprev[i] + - dt * (b1 * k1[i] + b2 * k2[i] + b3 * k3[i] + b4 * k4[i]) - ) - @inbounds du[i] = duprev[i] + - dt * (bp1 * k1[i] + bp2 * k2[i] + bp3 * k3[i] + bp4 * k4[i]) - end - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @tight_loop_macros for i in uidx - @inbounds uhat[i] = dtsq * - ( - btilde1 * k1[i] + btilde2 * k2[i] + btilde3 * k3[i] + - btilde4 * k4[i] - ) - end - calculate_residuals!( - atmp.x[2], uhat, integrator.uprev.x[2], integrator.u.x[2], - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp.x[2], t) - end -end - -@muladd function perform_step!(integrator, cache::ERKN7ConstantCache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - (; c1, c2, c3, c4, c5, c6, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a73, a74, a75, a76, b1, b3, b4, b5, b6, bp1, bp3, bp4, bp5, bp6, bp7, btilde1, btilde3, btilde4, btilde5, btilde6, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6, bptilde7) = cache - k1 = integrator.fsalfirst.x[1] - - ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - k2 = f.f1(duprev, ku, p, t + dt * c1) - ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - k3 = f.f1(duprev, ku, p, t + dt * c2) - ku = uprev + dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - k4 = f.f1(duprev, ku, p, t + dt * c3) - ku = uprev + dt * (c4 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - - k5 = f.f1(duprev, ku, p, t + dt * c4) - ku = uprev + - dt * (c5 * duprev + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5)) - - k6 = f.f1(duprev, ku, p, t + dt * c5) - ku = uprev + - dt * (c6 * duprev + dt * (a71 * k1 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6)) - - k7 = f.f1(duprev, ku, p, t + dt * c6) - u = uprev + dt * (duprev + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6)) - du = duprev + dt * (bp1 * k1 + bp3 * k3 + bp4 * k4 + bp5 * k5 + bp6 * k6 + bp7 * k7) - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - - if integrator.opts.adaptive - dtsq = dt^2 - uhat = dtsq * - (btilde1 * k1 + btilde3 * k3 + btilde4 * k4 + btilde5 * k5 + btilde6 * k6) - duhat = dt * ( - bptilde1 * k1 + bptilde3 * k3 + bptilde4 * k4 + bptilde5 * k5 + - bptilde6 * k6 + bptilde7 * k7 - ) - utilde = ArrayPartition((duhat, uhat)) - atmp = calculate_residuals( - utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -@muladd function perform_step!(integrator, cache::ERKN7Cache, repeat_step = false) - (; t, dt, f, p) = integrator - du, u = integrator.u.x - duprev, uprev = integrator.uprev.x - (; tmp, atmp, fsalfirst, k2, k3, k4, k5, k6, k7, k, utilde) = cache - (; c1, c2, c3, c4, c5, c6, a21, a31, a32, a41, a42, a43, a51, a52, a53, a54, a61, a62, a63, a64, a65, a71, a73, a74, a75, a76, b1, b3, b4, b5, b6, bp1, bp3, bp4, bp5, bp6, bp7, btilde1, btilde3, btilde4, btilde5, btilde6, bptilde1, bptilde3, bptilde4, bptilde5, bptilde6, bptilde7) = cache.tab - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - uidx = eachindex(integrator.uprev.x[2]) - k1 = integrator.fsalfirst.x[1] - - @.. broadcast = false ku = uprev + dt * (c1 * duprev + dt * a21 * k1) - - f.f1(k2, duprev, ku, p, t + dt * c1) - @.. broadcast = false ku = uprev + dt * (c2 * duprev + dt * (a31 * k1 + a32 * k2)) - - f.f1(k3, duprev, ku, p, t + dt * c2) - @.. broadcast = false ku = uprev + - dt * (c3 * duprev + dt * (a41 * k1 + a42 * k2 + a43 * k3)) - - f.f1(k4, duprev, ku, p, t + dt * c3) - @.. broadcast = false ku = uprev + - dt * - (c4 * duprev + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4)) - - f.f1(k5, duprev, ku, p, t + dt * c4) - @.. broadcast = false ku = uprev + - dt * ( - c5 * duprev + - dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5) - ) - - f.f1(k6, duprev, ku, p, t + dt * c5) - @.. broadcast = false ku = uprev + - dt * ( - c6 * duprev + - dt * (a71 * k1 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6) - ) - - f.f1(k7, duprev, ku, p, t + dt * c6) - @tight_loop_macros for i in uidx - @inbounds u[i] = uprev[i] + - dt * ( - duprev[i] + - dt * - (b1 * k1[i] + b3 * k3[i] + b4 * k4[i] + b5 * k5[i] + b6 * k6[i]) - ) - @inbounds du[i] = duprev[i] + - dt * ( - bp1 * k1[i] + bp3 * k3[i] + bp4 * k4[i] + bp5 * k5[i] + - bp6 * k6[i] + bp7 * k7[i] - ) - end - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 4) - integrator.stats.nf2 += 1 - if integrator.opts.adaptive - duhat, uhat = utilde.x - dtsq = dt^2 - @tight_loop_macros for i in uidx - @inbounds uhat[i] = dtsq * - ( - btilde1 * k1[i] + btilde3 * k3[i] + btilde4 * k4[i] + - btilde5 * k5[i] + btilde6 * k6[i] - ) - @inbounds duhat[i] = dt * - ( - bptilde1 * k1[i] + bptilde3 * k3[i] + bptilde4 * k4[i] + - bptilde5 * k5[i] + bptilde6 * k6[i] + bptilde7 * k7[i] - ) - end - calculate_residuals!( - atmp, utilde, integrator.uprev, integrator.u, - integrator.opts.abstol, integrator.opts.reltol, - integrator.opts.internalnorm, t - ) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - end -end - -function initialize!(integrator, cache::RKN4Cache) - (; fsalfirst, k) = cache - duprev, uprev = integrator.uprev.x - integrator.kshortsize = 2 - resize!(integrator.k, integrator.kshortsize) - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast - integrator.f.f1(integrator.k[1].x[1], duprev, uprev, integrator.p, integrator.t) - integrator.f.f2(integrator.k[1].x[2], duprev, uprev, integrator.p, integrator.t) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) - return integrator.stats.nf2 += 1 -end - -@muladd function perform_step!(integrator, cache::RKN4ConstantCache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - u, du = integrator.u.x - #define dt values - halfdt = dt / 2 - dtsq = dt^2 - eightdtsq = dtsq / 8 - halfdtsq = dtsq / 2 - sixthdtsq = dtsq / 6 - sixthdt = dt / 6 - ttmp = t + halfdt - - #perform operations to find k values - k₁ = integrator.fsalfirst.x[1] - ku = uprev + halfdt * duprev + eightdtsq * k₁ - kdu = duprev + halfdt * k₁ - - k₂ = f.f1(kdu, ku, p, ttmp) - ku = uprev + dt * duprev + halfdtsq * k₂ - kdu = duprev + dt * k₂ - - k₃ = f.f1(kdu, ku, p, t + dt) - - #perform final calculations to determine new y and y'. - u = uprev + sixthdtsq * (1 * k₁ + 2 * k₂ + 0 * k₃) + dt * duprev - du = duprev + sixthdt * (1 * k₁ + 4 * k₂ + 1 * k₃) - - integrator.u = ArrayPartition((du, u)) - integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt))) - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2) - integrator.stats.nf2 += 1 - integrator.k[1] = integrator.fsalfirst - integrator.k[2] = integrator.fsallast -end - -@muladd function perform_step!(integrator, cache::RKN4Cache, repeat_step = false) - (; t, dt, f, p) = integrator - duprev, uprev = integrator.uprev.x - du, u = integrator.u.x - (; tmp, fsalfirst, k₂, k₃, k) = cache - kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2] - - #define dt values - halfdt = dt / 2 - dtsq = dt^2 - eightdtsq = dtsq / 8 - halfdtsq = dtsq / 2 - sixthdtsq = dtsq / 6 - sixthdt = dt / 6 - ttmp = t + halfdt - - #perform operations to find k values - k₁ = integrator.fsalfirst.x[1] - @.. broadcast = false ku = uprev + halfdt * duprev + eightdtsq * k₁ - @.. broadcast = false kdu = duprev + halfdt * k₁ - - f.f1(k₂, kdu, ku, p, ttmp) - @.. broadcast = false ku = uprev + dt * duprev + halfdtsq * k₂ - @.. broadcast = false kdu = duprev + dt * k₂ - - f.f1(k₃, kdu, ku, p, t + dt) - - #perform final calculations to determine new y and y'. - @.. broadcast = false u = uprev + sixthdtsq * (1 * k₁ + 2 * k₂ + 0 * k₃) + dt * duprev - @.. broadcast = false du = duprev + sixthdt * (1 * k₁ + 4 * k₂ + 1 * k₃) - - f.f1(k.x[1], du, u, p, t + dt) - f.f2(k.x[2], du, u, p, t + dt) - - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2) - integrator.stats.nf2 += 1 -end diff --git a/lib/OrdinaryDiffEqRKN/src/rkn_tableaus.jl b/lib/OrdinaryDiffEqRKN/src/rkn_tableaus.jl index da3968a02e5..8848034385e 100644 --- a/lib/OrdinaryDiffEqRKN/src/rkn_tableaus.jl +++ b/lib/OrdinaryDiffEqRKN/src/rkn_tableaus.jl @@ -1,4 +1,360 @@ abstract type NystromConstantCache <: OrdinaryDiffEqConstantCache end + +""" + NystromVITableau{T, T2} + +Tableau for velocity-independent Nyström methods. +Fields: +- `a`: nstages × nstages lower-triangular position coupling matrix +- `b`: position update weights (length nstages; b[end] may be 0 if last stage not in position update) +- `bp`: velocity update weights (length nstages) +- `btilde`: embedded position error weights (empty if non-adaptive) +- `bptilde`: embedded velocity error weights (empty if non-adaptive) +- `c`: time nodes for stages 2..nstages (length nstages-1); c[i] is node for stage i+1 +- `pos_only_error`: if true, error estimate uses only position components (ERKN5 behaviour) +""" +struct NystromVITableau{T, T2} + a::Matrix{T} + b::Vector{T} + bp::Vector{T} + btilde::Vector{T} + bptilde::Vector{T} + c::Vector{T2} + pos_only_error::Bool +end + +function DPRKN4Tableau(T::Type, T2::Type) + tab = DPRKN4ConstantCache(T, T2) + # 4 stages: k1 = fsalfirst, k2, k3, k4 + nstages = 4 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 2] = tab.a42; a[4, 3] = tab.a43 + # DPRKN4: u = uprev + dt*(duprev + dt*(b1*k1 + b2*k2 + b3*k3)) (b4=0) + b = [tab.b1, tab.b2, tab.b3, zero(T)] + bp = [tab.bp1, tab.bp2, tab.bp3, tab.bp4] + btilde = [tab.btilde1, tab.btilde2, tab.btilde3, tab.btilde4] + bptilde = [tab.bptilde1, tab.bptilde2, tab.bptilde3, tab.bptilde4] + c = [tab.c1, tab.c2, tab.c3] # c for stages 2, 3, 4 + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +function DPRKN5Tableau(T::Type, T2::Type) + tab = DPRKN5ConstantCache(T, T2) + # 6 stages: k1..k6; c1..c5 for stages 2..6 + nstages = 6 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 3] = tab.a43 # a42=0 + a[5, 1] = tab.a51; a[5, 3] = tab.a53; a[5, 4] = tab.a54 # a52=0 + a[6, 1] = tab.a61; a[6, 3] = tab.a63; a[6, 4] = tab.a64; a[6, 5] = tab.a65 # a62=0 + # u = uprev + dt*(duprev + dt*(b1*k1 + b3*k3 + b4*k4 + b5*k5)) b2=b6=0 + b = [tab.b1, zero(T), tab.b3, tab.b4, tab.b5, zero(T)] + bp = [tab.bp1, zero(T), tab.bp3, tab.bp4, tab.bp5, tab.bp6] + btilde = [tab.btilde1, zero(T), tab.btilde3, tab.btilde4, tab.btilde5, zero(T)] + bptilde = [tab.bptilde1, zero(T), tab.bptilde3, tab.bptilde4, tab.bptilde5, tab.bptilde6] + c = [tab.c1, tab.c2, tab.c3, tab.c4, tab.c5] # c for stages 2..6 + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +function DPRKN6FMTableau(T::Type, T2::Type) + tab = DPRKN6FMConstantCache(T, T2) + # 6 stages: k1..k6; c1..c5 for stages 2..6 + nstages = 6 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 2] = tab.a42; a[4, 3] = tab.a43 + a[5, 1] = tab.a51; a[5, 2] = tab.a52; a[5, 3] = tab.a53; a[5, 4] = tab.a54 + a[6, 1] = tab.a61; a[6, 2] = tab.a62; a[6, 3] = tab.a63; a[6, 4] = tab.a64; a[6, 5] = tab.a65 + # u = uprev + dt*(duprev + dt*(b1*k1+b2*k2+b3*k3+b4*k4+b5*k5)) b6=0 + b = [tab.b1, tab.b2, tab.b3, tab.b4, tab.b5, zero(T)] + bp = [tab.bp1, tab.bp2, tab.bp3, tab.bp4, tab.bp5, tab.bp6] + btilde = [tab.btilde1, tab.btilde2, tab.btilde3, tab.btilde4, tab.btilde5, zero(T)] + bptilde = [tab.bptilde1, tab.bptilde2, tab.bptilde3, tab.bptilde4, tab.bptilde5, zero(T)] + c = [tab.c1, tab.c2, tab.c3, tab.c4, tab.c5] + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +function DPRKN8Tableau(T::Type, T2::Type) + tab = DPRKN8ConstantCache(T, T2) + # 9 stages: k1..k9; c1..c8 for stages 2..9 + nstages = 9 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 2] = tab.a42; a[4, 3] = tab.a43 + a[5, 1] = tab.a51; a[5, 2] = tab.a52; a[5, 3] = tab.a53; a[5, 4] = tab.a54 + a[6, 1] = tab.a61; a[6, 2] = tab.a62; a[6, 3] = tab.a63; a[6, 4] = tab.a64; a[6, 5] = tab.a65 + a[7, 1] = tab.a71; a[7, 2] = tab.a72; a[7, 3] = tab.a73; a[7, 4] = tab.a74; a[7, 5] = tab.a75; a[7, 6] = tab.a76 + a[8, 1] = tab.a81; a[8, 2] = tab.a82; a[8, 3] = tab.a83; a[8, 4] = tab.a84; a[8, 5] = tab.a85; a[8, 6] = tab.a86; a[8, 7] = tab.a87 + a[9, 1] = tab.a91; a[9, 3] = tab.a93; a[9, 4] = tab.a94; a[9, 5] = tab.a95; a[9, 6] = tab.a96; a[9, 7] = tab.a97 # a92=a98=0 + # u uses b1,b3..b7 (b2=b8=b9=0) + b = [tab.b1, zero(T), tab.b3, tab.b4, tab.b5, tab.b6, tab.b7, zero(T), zero(T)] + bp = [tab.bp1, zero(T), tab.bp3, tab.bp4, tab.bp5, tab.bp6, tab.bp7, tab.bp8, zero(T)] + btilde = [tab.btilde1, zero(T), tab.btilde3, tab.btilde4, tab.btilde5, tab.btilde6, tab.btilde7, zero(T), zero(T)] + bptilde = [tab.bptilde1, zero(T), tab.bptilde3, tab.bptilde4, tab.bptilde5, tab.bptilde6, tab.bptilde7, tab.bptilde8, tab.bptilde9] + c = [tab.c1, tab.c2, tab.c3, tab.c4, tab.c5, tab.c6, tab.c7, tab.c8] + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +function DPRKN12Tableau(T::Type, T2::Type) + tab = DPRKN12ConstantCache(T, T2) + # 17 stages: k1..k17; c1..c16 for stages 2..17 + nstages = 17 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 2] = tab.a42; a[4, 3] = tab.a43 + a[5, 1] = tab.a51; a[5, 3] = tab.a53; a[5, 4] = tab.a54 # a52=0 + a[6, 1] = tab.a61; a[6, 3] = tab.a63; a[6, 4] = tab.a64; a[6, 5] = tab.a65 # a62=0 + a[7, 1] = tab.a71; a[7, 3] = tab.a73; a[7, 4] = tab.a74; a[7, 5] = tab.a75; a[7, 6] = tab.a76 # a72=0 + a[8, 1] = tab.a81; a[8, 4] = tab.a84; a[8, 5] = tab.a85; a[8, 6] = tab.a86; a[8, 7] = tab.a87 # a82=a83=0 + a[9, 1] = tab.a91; a[9, 3] = tab.a93; a[9, 4] = tab.a94; a[9, 5] = tab.a95; a[9, 6] = tab.a96; a[9, 7] = tab.a97; a[9, 8] = tab.a98 # a92=0 + a[10, 1] = tab.a101; a[10, 3] = tab.a103; a[10, 4] = tab.a104; a[10, 5] = tab.a105; a[10, 6] = tab.a106; a[10, 7] = tab.a107; a[10, 8] = tab.a108; a[10, 9] = tab.a109 # a102=0 + a[11, 1] = tab.a111; a[11, 3] = tab.a113; a[11, 4] = tab.a114; a[11, 5] = tab.a115; a[11, 6] = tab.a116; a[11, 7] = tab.a117; a[11, 8] = tab.a118; a[11, 9] = tab.a119; a[11, 10] = tab.a1110 # a112=0 + a[12, 1] = tab.a121; a[12, 3] = tab.a123; a[12, 4] = tab.a124; a[12, 5] = tab.a125; a[12, 6] = tab.a126; a[12, 7] = tab.a127; a[12, 8] = tab.a128; a[12, 9] = tab.a129; a[12, 10] = tab.a1210; a[12, 11] = tab.a1211 # a122=0 + a[13, 1] = tab.a131; a[13, 3] = tab.a133; a[13, 4] = tab.a134; a[13, 5] = tab.a135; a[13, 6] = tab.a136; a[13, 7] = tab.a137; a[13, 8] = tab.a138; a[13, 9] = tab.a139; a[13, 10] = tab.a1310; a[13, 11] = tab.a1311; a[13, 12] = tab.a1312 # a132=0 + a[14, 1] = tab.a141; a[14, 3] = tab.a143; a[14, 4] = tab.a144; a[14, 5] = tab.a145; a[14, 6] = tab.a146; a[14, 7] = tab.a147; a[14, 8] = tab.a148; a[14, 9] = tab.a149; a[14, 10] = tab.a1410; a[14, 11] = tab.a1411; a[14, 12] = tab.a1412; a[14, 13] = tab.a1413 # a142=0 + a[15, 1] = tab.a151; a[15, 3] = tab.a153; a[15, 4] = tab.a154; a[15, 5] = tab.a155; a[15, 6] = tab.a156; a[15, 7] = tab.a157; a[15, 8] = tab.a158; a[15, 9] = tab.a159; a[15, 10] = tab.a1510; a[15, 11] = tab.a1511; a[15, 12] = tab.a1512; a[15, 13] = tab.a1513; a[15, 14] = tab.a1514 # a152=0 + a[16, 1] = tab.a161; a[16, 3] = tab.a163; a[16, 4] = tab.a164; a[16, 5] = tab.a165; a[16, 6] = tab.a166; a[16, 7] = tab.a167; a[16, 8] = tab.a168; a[16, 9] = tab.a169; a[16, 10] = tab.a1610; a[16, 11] = tab.a1611; a[16, 12] = tab.a1612; a[16, 13] = tab.a1613; a[16, 14] = tab.a1614; a[16, 15] = tab.a1615 # a162=0 + a[17, 1] = tab.a171; a[17, 3] = tab.a173; a[17, 4] = tab.a174; a[17, 5] = tab.a175; a[17, 6] = tab.a176; a[17, 7] = tab.a177; a[17, 8] = tab.a178; a[17, 9] = tab.a179; a[17, 10] = tab.a1710; a[17, 11] = tab.a1711; a[17, 12] = tab.a1712; a[17, 13] = tab.a1713; a[17, 14] = tab.a1714; a[17, 15] = tab.a1715 # a172=a1716=0 + # u uses b1, b7..b15 (b2..b6=b16=b17=0) + b = [tab.b1, zero(T), zero(T), zero(T), zero(T), zero(T), tab.b7, tab.b8, tab.b9, tab.b10, tab.b11, tab.b12, tab.b13, tab.b14, tab.b15, zero(T), zero(T)] + bp = [tab.bp1, zero(T), zero(T), zero(T), zero(T), zero(T), tab.bp7, tab.bp8, tab.bp9, tab.bp10, tab.bp11, tab.bp12, tab.bp13, tab.bp14, tab.bp15, tab.bp16, tab.bp17] + btilde = [tab.btilde1, zero(T), zero(T), zero(T), zero(T), zero(T), tab.btilde7, tab.btilde8, tab.btilde9, tab.btilde10, tab.btilde11, tab.btilde12, tab.btilde13, tab.btilde14, tab.btilde15, zero(T), zero(T)] + bptilde = [tab.bptilde1, zero(T), zero(T), zero(T), zero(T), zero(T), tab.bptilde7, tab.bptilde8, tab.bptilde9, tab.bptilde10, tab.bptilde11, tab.bptilde12, tab.bptilde13, tab.bptilde14, tab.bptilde15, tab.bptilde16, tab.bptilde17] + c = [tab.c1, tab.c2, tab.c3, tab.c4, tab.c5, tab.c6, tab.c7, tab.c8, tab.c9, tab.c10, tab.c11, tab.c12, tab.c13, tab.c14, tab.c15, tab.c16] + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +function ERKN4Tableau(T::Type, T2::Type) + tab = ERKN4ConstantCache(T, T2) + # 4 stages: k1..k4; c1,c2,c3 for stages 2,3,4 (c3=1 stored) + nstages = 4 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 2] = tab.a42; a[4, 3] = tab.a43 + b = [tab.b1, tab.b2, tab.b3, tab.b4] + bp = [tab.bp1, tab.bp2, tab.bp3, tab.bp4] + btilde = [tab.btilde1, tab.btilde2, tab.btilde3, tab.btilde4] + bptilde = [tab.bptilde1, tab.bptilde2, tab.bptilde3, tab.bptilde4] + c = [tab.c1, tab.c2, tab.c3] + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +function ERKN5Tableau(T::Type, T2::Type) + tab = ERKN5ConstantCache(T, T2) + # 4 stages: k1..k4; c1,c2,c3 for stages 2,3,4 + # Note: ERKN5 has no bptilde coefficients (velocity error not tracked) + nstages = 4 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 2] = tab.a42; a[4, 3] = tab.a43 + b = [tab.b1, tab.b2, tab.b3, tab.b4] + bp = [tab.bp1, tab.bp2, tab.bp3, tab.bp4] + btilde = [tab.btilde1, tab.btilde2, tab.btilde3, tab.btilde4] + bptilde = T[] # no velocity error for ERKN5 + c = [tab.c1, tab.c2, tab.c3] + return NystromVITableau(a, b, bp, btilde, bptilde, c, true) # pos_only_error = true +end + +function ERKN7Tableau(T::Type, T2::Type) + tab = ERKN7ConstantCache(T, T2) + # 7 stages: k1..k7; c1..c6 for stages 2..7 + # Note: b2=0, bp2=0, btilde2=0, bptilde2=0 + nstages = 7 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 2] = tab.a42; a[4, 3] = tab.a43 + a[5, 1] = tab.a51; a[5, 2] = tab.a52; a[5, 3] = tab.a53; a[5, 4] = tab.a54 + a[6, 1] = tab.a61; a[6, 2] = tab.a62; a[6, 3] = tab.a63; a[6, 4] = tab.a64; a[6, 5] = tab.a65 + a[7, 1] = tab.a71; a[7, 3] = tab.a73; a[7, 4] = tab.a74; a[7, 5] = tab.a75; a[7, 6] = tab.a76 # a72=0 + # u uses b1,b3,b4,b5,b6 (b2=b7=0) + b = [tab.b1, zero(T), tab.b3, tab.b4, tab.b5, tab.b6, zero(T)] + bp = [tab.bp1, zero(T), tab.bp3, tab.bp4, tab.bp5, tab.bp6, tab.bp7] + btilde = [tab.btilde1, zero(T), tab.btilde3, tab.btilde4, tab.btilde5, tab.btilde6, zero(T)] + bptilde = [tab.bptilde1, zero(T), tab.bptilde3, tab.bptilde4, tab.bptilde5, tab.bptilde6, tab.bptilde7] + c = [tab.c1, tab.c2, tab.c3, tab.c4, tab.c5, tab.c6] + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +function Nystrom5VelocityIndependentTableau(T::Type, T2::Type) + tab = Nystrom5VelocityIndependentConstantCache(T, T2) + # 4 stages: k1..k4; c1=1/5, c2=2/3 for stages 2,3; stage 4 uses c3=1 (not stored in tab) + # In the perform_step!: k2 at c1, k3 at c2, k4 at 1 (t+dt) + nstages = 4 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 2] = tab.a42; a[4, 3] = tab.a43 + # u = uprev + dt*(duprev + dt*(bbar1*k1 + bbar2*k2 + bbar3*k3)) (b4 not in u) + b = [tab.bbar1, tab.bbar2, tab.bbar3, zero(T)] + bp = [tab.b1, tab.b2, tab.b3, tab.b4] + btilde = T[] # non-adaptive + bptilde = T[] + c = [tab.c1, tab.c2, one(T2)] # c for stages 2,3,4 + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +""" + NystromVDTableau{T, T2} + +Tableau for velocity-dependent Nyström methods. +Fields: +- `a`: nstages × nstages lower-triangular position coupling matrix +- `abar`: nstages × nstages lower-triangular velocity coupling matrix +- `b`: position update weights (length nstages) +- `bp`: velocity update weights (length nstages) +- `btilde`: embedded position error weights (empty if non-adaptive) +- `bptilde`: embedded velocity error weights (empty if non-adaptive) +- `c`: time nodes for stages 2..nstages (length nstages-1); c[i] is node for stage i+1 +- `nf_per_step`: number of f1 evaluations to count per step (default = nstages, i.e., + loop stages k2..kN plus fsallast; set to nstages-1 to exclude fsallast from the count) +""" +struct NystromVDTableau{T, T2} + a::Matrix{T} + abar::Matrix{T} + b::Vector{T} + bp::Vector{T} + btilde::Vector{T} + bptilde::Vector{T} + c::Vector{T2} + nf_per_step::Int +end + +function Nystrom4VelocityIndependentTableau(T::Type, T2::Type) + # 3 stages, velocity-independent: kᵢ = f1(duprev, kuᵢ, p, t+cᵢ*dt) + # Coefficients from perform_step!: + # c = [1/2, 1]; a[2,1]=1/8, a[3,2]=1/2 + # b = [1/6, 2/6, 0], bp = [1/6, 4/6, 1/6] + nstages = 3 + a = zeros(T, nstages, nstages) + a[2, 1] = convert(T, 1 // 8) + a[3, 2] = convert(T, 1 // 2) + b = [convert(T, 1 // 6), convert(T, 2 // 6), zero(T)] + bp = [convert(T, 1 // 6), convert(T, 4 // 6), convert(T, 1 // 6)] + btilde = T[] # non-adaptive + bptilde = T[] + c = [convert(T2, 1 // 2), one(T2)] # c for stages 2,3 + return NystromVITableau(a, b, bp, btilde, bptilde, c, false) +end + +function RKN4Tableau(T::Type, T2::Type) + # 3 stages, velocity-dependent + # k2 at c[1]=1/2: ku = uprev + dt*(1/2)*duprev + dt²*(1/8)*k1 + # kdu = duprev + dt*(1/2)*k1 + # k3 at c[2]=1: ku = uprev + dt*1*duprev + dt²*(1/2)*k2 + # kdu = duprev + dt*1*k2 + # u = uprev + dt*duprev + dt²*(1/6*k1 + 2/6*k2 + 0*k3) + # du = duprev + dt*(1/6*k1 + 4/6*k2 + 1/6*k3) + # nf_per_step=2: matches original counting convention (k2+k3 only, not fsallast) + nstages = 3 + a = zeros(T, nstages, nstages) + a[2, 1] = convert(T, 1 // 8) + a[3, 2] = convert(T, 1 // 2) + abar = zeros(T, nstages, nstages) + abar[2, 1] = convert(T, 1 // 2) + abar[3, 2] = convert(T, 1 // 1) + b = [convert(T, 1 // 6), convert(T, 2 // 6), zero(T)] + bp = [convert(T, 1 // 6), convert(T, 4 // 6), convert(T, 1 // 6)] + btilde = T[] # non-adaptive + bptilde = T[] + c = [convert(T2, 1 // 2), one(T2)] # c for stages 2,3 + return NystromVDTableau(a, abar, b, bp, btilde, bptilde, c, nstages - 1) +end + +function Nystrom4Tableau(T::Type, T2::Type) + # 4 stages, velocity-dependent + # From perform_step! Nystrom4ConstantCache: + # k2 at c[1]=1/2: ku = uprev + dt*(1/2)*duprev + dt²*(1/8)*k1 + # kdu = duprev + dt*(1/2)*k1 + # k3 at c[2]=1/2: ku = uprev + dt*(1/2)*duprev + dt²*(1/8)*k1 (same ku as k2!) + # kdu = duprev + dt*(1/2)*k2 + # k4 at c[3]=1: ku = uprev + dt*1*duprev + dt²*(1/2)*k3 + # kdu = duprev + dt*1*k3 + # u = uprev + dt*duprev + dt²*(1/6*(k1+k2+k3)) [b4=0] + # du = duprev + dt*(1/6*(k1+k4) + 2/6*(k2+k3)) + nstages = 4 + a = zeros(T, nstages, nstages) + a[2, 1] = convert(T, 1 // 8) + a[3, 1] = convert(T, 1 // 8) # a[3,2] = 0 + a[4, 3] = convert(T, 1 // 2) + abar = zeros(T, nstages, nstages) + abar[2, 1] = convert(T, 1 // 2) + abar[3, 2] = convert(T, 1 // 2) + abar[4, 3] = convert(T, 1 // 1) + b = [convert(T, 1 // 6), convert(T, 1 // 6), convert(T, 1 // 6), zero(T)] + bp = [convert(T, 1 // 6), convert(T, 2 // 6), convert(T, 2 // 6), convert(T, 1 // 6)] + btilde = T[] # non-adaptive + bptilde = T[] + c = [convert(T2, 1 // 2), convert(T2, 1 // 2), one(T2)] # c for stages 2,3,4 + return NystromVDTableau(a, abar, b, bp, btilde, bptilde, c, nstages) +end + +function FineRKN4Tableau(T::Type, T2::Type) + tab = FineRKN4ConstantCache(T, T2) + # 5 stages, velocity-dependent, adaptive + # c for stages 2..5: [c2, c3, c4, c5] + nstages = 5 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 3] = tab.a43 # a42 = 0 + a[5, 1] = tab.a51; a[5, 2] = tab.a52; a[5, 3] = tab.a53; a[5, 4] = tab.a54 + abar = zeros(T, nstages, nstages) + abar[2, 1] = tab.abar21 + abar[3, 1] = tab.abar31; abar[3, 2] = tab.abar32 + abar[4, 1] = tab.abar41; abar[4, 2] = tab.abar42; abar[4, 3] = tab.abar43 + abar[5, 1] = tab.abar51; abar[5, 2] = tab.abar52; abar[5, 3] = tab.abar53; abar[5, 4] = tab.abar54 + # b2 = 0 + b = [tab.b1, zero(T), tab.b3, tab.b4, tab.b5] + # bbar2 = 0 + bp = [tab.bbar1, zero(T), tab.bbar3, tab.bbar4, tab.bbar5] + # btilde2 = 0 + btilde = [tab.btilde1, zero(T), tab.btilde3, tab.btilde4, tab.btilde5] + # bptilde2 = 0 + bptilde = [tab.bptilde1, zero(T), tab.bptilde3, tab.bptilde4, tab.bptilde5] + c = [tab.c2, tab.c3, tab.c4, tab.c5] # c for stages 2..5 + return NystromVDTableau(a, abar, b, bp, btilde, bptilde, c, nstages) +end + +function FineRKN5Tableau(T::Type, T2::Type) + tab = FineRKN5ConstantCache(T, T2) + # 7 stages, velocity-dependent, adaptive + # c for stages 2..7: [c2, c3, c4, c5, c6, c7] + nstages = 7 + a = zeros(T, nstages, nstages) + a[2, 1] = tab.a21 + a[3, 1] = tab.a31; a[3, 2] = tab.a32 + a[4, 1] = tab.a41; a[4, 3] = tab.a43 # a42 = 0 + a[5, 1] = tab.a51; a[5, 2] = tab.a52; a[5, 3] = tab.a53; a[5, 4] = tab.a54 + a[6, 1] = tab.a61; a[6, 2] = tab.a62; a[6, 3] = tab.a63; a[6, 4] = tab.a64 # a65 = 0 + a[7, 1] = tab.a71; a[7, 3] = tab.a73; a[7, 4] = tab.a74; a[7, 5] = tab.a75 # a72 = a76 = 0 + abar = zeros(T, nstages, nstages) + abar[2, 1] = tab.abar21 + abar[3, 1] = tab.abar31; abar[3, 2] = tab.abar32 + abar[4, 1] = tab.abar41; abar[4, 2] = tab.abar42; abar[4, 3] = tab.abar43 + abar[5, 1] = tab.abar51; abar[5, 2] = tab.abar52; abar[5, 3] = tab.abar53; abar[5, 4] = tab.abar54 + abar[6, 1] = tab.abar61; abar[6, 2] = tab.abar62; abar[6, 3] = tab.abar63; abar[6, 4] = tab.abar64; abar[6, 5] = tab.abar65 + abar[7, 1] = tab.abar71; abar[7, 3] = tab.abar73; abar[7, 4] = tab.abar74; abar[7, 5] = tab.abar75; abar[7, 6] = tab.abar76 # abar72 = 0 + # b2 = b6 = b7 = 0 + b = [tab.b1, zero(T), tab.b3, tab.b4, tab.b5, zero(T), zero(T)] + # bbar2 = bbar7 = 0 + bp = [tab.bbar1, zero(T), tab.bbar3, tab.bbar4, tab.bbar5, tab.bbar6, zero(T)] + # btilde2 = btilde6 = btilde7 = 0 + btilde = [tab.btilde1, zero(T), tab.btilde3, tab.btilde4, tab.btilde5, zero(T), zero(T)] + # bptilde2 = 0; bptilde7 is included + bptilde = [tab.bptilde1, zero(T), tab.bptilde3, tab.bptilde4, tab.bptilde5, tab.bptilde6, tab.bptilde7] + c = [tab.c2, tab.c3, tab.c4, tab.c5, tab.c6, tab.c7] # c for stages 2..7 + return NystromVDTableau(a, abar, b, bp, btilde, bptilde, c, nstages) +end + struct FineRKN4ConstantCache{T, T2} <: NystromConstantCache c1::T2 c2::T2 diff --git a/lib/OrdinaryDiffEqRKN/test/nystrom_convergence_tests.jl b/lib/OrdinaryDiffEqRKN/test/nystrom_convergence_tests.jl index 2b1ccaf1e2a..70e3a5f3678 100644 --- a/lib/OrdinaryDiffEqRKN/test/nystrom_convergence_tests.jl +++ b/lib/OrdinaryDiffEqRKN/test/nystrom_convergence_tests.jl @@ -412,19 +412,13 @@ end @test sol_i.stats.naccept == sol_o.stats.naccept @test 19 <= sol_i.stats.naccept <= 21 @test abs(sol_i.stats.nf - 5 * sol_i.stats.naccept) < 4 - # adaptive time step — IIP broadcast vs OOP array ops produce - # per-step FP rounding differences that cascade through the step - # controller; on Julia 1.10 the LLVM codegen amplifies this enough - # to change the accepted step sequence. + # adaptive time step — IIP @.. broadcast vs OOP scalar ops produce + # per-step FP rounding differences on all platforms/versions that + # cascade through the step controller to change the accepted sequence. sol_i = solve(ode_i, alg) sol_o = solve(ode_o, alg) - if VERSION >= v"1.11" - @test sol_i.t ≈ sol_o.t - @test sol_i.u ≈ sol_o.u - else - @test_broken sol_i.t ≈ sol_o.t - @test_broken sol_i.u ≈ sol_o.u - end + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u end @testset "FineRKN5" begin @@ -460,16 +454,12 @@ end @test sol_i.stats.naccept == sol_o.stats.naccept @test 19 <= sol_i.stats.naccept <= 21 @test abs(sol_i.stats.nf - 4 * sol_i.stats.naccept) < 4 - # adaptive time step — see FineRKN4 comment on Julia 1.10 FP divergence + # adaptive time step — IIP @.. broadcast vs OOP scalar ops produce + # per-step FP rounding differences on all platforms/versions. sol_i = solve(ode_i, alg) sol_o = solve(ode_o, alg) - if VERSION >= v"1.11" - @test sol_i.t ≈ sol_o.t - @test sol_i.u ≈ sol_o.u - else - @test_broken sol_i.t ≈ sol_o.t - @test_broken sol_i.u ≈ sol_o.u - end + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u end @testset "DPRKN5" begin @@ -485,16 +475,12 @@ end @test sol_i.stats.naccept == sol_o.stats.naccept @test 19 <= sol_i.stats.naccept <= 21 @test abs(sol_i.stats.nf - 6 * sol_i.stats.naccept) < 4 - # adaptive time step — see FineRKN4 comment on Julia 1.10 FP divergence + # adaptive time step — IIP @.. broadcast vs OOP scalar ops produce + # per-step FP rounding differences on all platforms/versions. sol_i = solve(ode_i, alg) sol_o = solve(ode_o, alg) - if VERSION >= v"1.11" - @test sol_i.t ≈ sol_o.t - @test sol_i.u ≈ sol_o.u - else - @test_broken sol_i.t ≈ sol_o.t - @test_broken sol_i.u ≈ sol_o.u - end + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u end @testset "DPRKN6" begin @@ -535,16 +521,12 @@ end @test sol_i.stats.naccept == sol_o.stats.naccept @test 19 <= sol_i.stats.naccept <= 21 @test abs(sol_i.stats.nf - 6 * sol_i.stats.naccept) < 4 - # adaptive time step + # adaptive time step — IIP @.. broadcast vs OOP scalar ops produce + # per-step FP rounding differences on all platforms/versions. sol_i = solve(ode_i, alg) sol_o = solve(ode_o, alg) - if VERSION >= v"1.11" - @test sol_i.t ≈ sol_o.t - @test sol_i.u ≈ sol_o.u - else - @test_broken sol_i.t ≈ sol_o.t - @test_broken sol_i.u ≈ sol_o.u - end + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u end @testset "DPRKN8" begin @@ -560,16 +542,12 @@ end @test sol_i.stats.naccept == sol_o.stats.naccept @test 19 <= sol_i.stats.naccept <= 21 @test abs(sol_i.stats.nf - 9 * sol_i.stats.naccept) < 4 - # adaptive time step + # adaptive time step — IIP @.. broadcast vs OOP scalar ops produce + # per-step FP rounding differences on all platforms/versions. sol_i = solve(ode_i, alg) sol_o = solve(ode_o, alg) - if VERSION >= v"1.11" - @test sol_i.t ≈ sol_o.t - @test sol_i.u ≈ sol_o.u - else - @test_broken sol_i.t ≈ sol_o.t - @test_broken sol_i.u ≈ sol_o.u - end + @test_broken sol_i.t ≈ sol_o.t + @test_broken sol_i.u ≈ sol_o.u end @testset "DPRKN12" begin