diff --git a/lib/OrdinaryDiffEqFIRK/Project.toml b/lib/OrdinaryDiffEqFIRK/Project.toml index 662531b6de..bca70b709d 100644 --- a/lib/OrdinaryDiffEqFIRK/Project.toml +++ b/lib/OrdinaryDiffEqFIRK/Project.toml @@ -1,71 +1,69 @@ name = "OrdinaryDiffEqFIRK" uuid = "5960d6e9-dd7a-4743-88e7-cf307b64f125" -authors = ["ParamThakkar123 "] version = "1.26.0" +authors = ["ParamThakkar123 "] [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" -MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b" -SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b" +OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b" OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8" -DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Runic = "62bfec6d-59d7-401d-8490-b29ee721c001" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" -[extras] -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -DiffEqDevTools = 
"f3b72e0c-5b89-59e1-b016-84e28bfd966d" -ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" -GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" - [sources] DiffEqBase = {path = "../DiffEqBase"} DiffEqDevTools = {path = "../DiffEqDevTools"} +OrdinaryDiffEqCore = {path = "../OrdinaryDiffEqCore"} +OrdinaryDiffEqDifferentiation = {path = "../OrdinaryDiffEqDifferentiation"} +OrdinaryDiffEqNonlinearSolve = {path = "../OrdinaryDiffEqNonlinearSolve"} [compat] -Pkg = "1" -Test = "<0.0.1, 1" -FastBroadcast = "1.3" -Random = "<0.0.1, 1" +ADTypes = "1.16" +DiffEqBase = "7" DiffEqDevTools = "3" +FastBroadcast = "1.3" FastGaussQuadrature = "1.0.2" -MuladdMacro = "0.2" -LinearSolve = "3.46" -LinearAlgebra = "1.10" -OrdinaryDiffEqDifferentiation = "2" -SciMLBase = "3" -OrdinaryDiffEqCore = "4" -GenericSchur = "0.5" -julia = "1.10" -ADTypes = "1.16" -RecursiveArrayTools = "4" FastPower = "1.1" +GenericSchur = "0.5" +LinearAlgebra = "1.10" +LinearSolve = "3.46" +MuladdMacro = "0.2" ODEProblemLibrary = "1" +OrdinaryDiffEqCore = "4" +OrdinaryDiffEqDifferentiation = "2" OrdinaryDiffEqNonlinearSolve = "1.16.0" -DiffEqBase = "7" +Pkg = "1" +Polyester = "0.7" +Random = "<0.0.1, 1" +RecursiveArrayTools = "4" Reexport = "1.2" +Runic = "1.7.0" SafeTestsets = "0.1.0" +SciMLBase = "3" SciMLOperators = "1.4" +Test = "<0.0.1, 1" +julia = "1.10" + +[extras] +DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" +GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" +ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["DiffEqDevTools", "GenericSchur", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "Pkg"] - -[sources.OrdinaryDiffEqDifferentiation] -path = "../OrdinaryDiffEqDifferentiation" - 
-[sources.OrdinaryDiffEqNonlinearSolve] -path = "../OrdinaryDiffEqNonlinearSolve" - -[sources.OrdinaryDiffEqCore] -path = "../OrdinaryDiffEqCore" diff --git a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl index 1946d90f3e..1181d3777f 100644 --- a/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl +++ b/lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl @@ -2,7 +2,7 @@ module OrdinaryDiffEqFIRK import OrdinaryDiffEqCore: alg_order, calculate_residuals!, initialize!, perform_step!, unwrap_alg, - calculate_residuals, default_controller, PredictiveController, + calculate_residuals, default_controller, PredictiveController, PIController, OrdinaryDiffEqAlgorithm, OrdinaryDiffEqNewtonAdaptiveAlgorithm, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache, OrdinaryDiffEqAdaptiveAlgorithm, CompiledFloats, uses_uprev, @@ -14,15 +14,16 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, qmax_default, alg_adaptive_order, stepsize_controller!, step_accept_controller!, step_reject_controller!, - PredictiveController, alg_can_repeat_jac, NewtonAlgorithm, + alg_can_repeat_jac, NewtonAlgorithm, fac_default_gamma, get_current_adaptive_order, get_fsalfirstlast, - isfirk, generic_solver_docstring, _ad_chunksize_int, _ad_fdtype, _fixup_ad, + get_current_alg_order, + isfirk, generic_solver_docstring, + _ad_chunksize_int, _ad_fdtype, _fixup_ad, LinearAliasSpecifier -using MuladdMacro, DiffEqBase, RecursiveArrayTools -isfirk, generic_solver_docstring +using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester using SciMLOperators: AbstractSciMLOperator -using LinearAlgebra: I, UniformScaling, mul!, lu +using LinearAlgebra: I, UniformScaling, mul!, lu, dot, eigvals import LinearSolve import FastBroadcast: @.. 
import OrdinaryDiffEqCore @@ -62,6 +63,6 @@ include("firk_interpolants.jl") include("firk_addsteps.jl") include("integrator_interface.jl") -export RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau +export RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau, GaussLegendre end diff --git a/lib/OrdinaryDiffEqFIRK/src/alg_utils.jl b/lib/OrdinaryDiffEqFIRK/src/alg_utils.jl index 622cca5e43..fcb65a6a11 100644 --- a/lib/OrdinaryDiffEqFIRK/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqFIRK/src/alg_utils.jl @@ -25,3 +25,19 @@ get_current_adaptive_order(alg::AdaptiveRadau, cache) = cache.num_stages function has_stiff_interpolation(::Union{RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau}) return true end + +qmax_default(alg::GaussLegendre) = 8 + +alg_order(alg::GaussLegendre) = 2 * alg.num_stages + +default_controller(QT, alg::GaussLegendre) = PIController(QT, alg) + +isfirk(alg::GaussLegendre) = true + +# Richardson step-doubling controller +isadaptive(alg::GaussLegendre) = alg.num_stages >= 2 +alg_adaptive_order(alg::GaussLegendre) = 2 * alg.num_stages +has_stiff_interpolation(::GaussLegendre) = true + +get_current_alg_order(alg::GaussLegendre, cache) = 2 * alg.num_stages +get_current_adaptive_order(alg::GaussLegendre, cache) = 2 * alg.num_stages diff --git a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl index 1f64bd41b1..ff7e8c40e4 100644 --- a/lib/OrdinaryDiffEqFIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqFIRK/src/algorithms.jl @@ -220,3 +220,71 @@ function AdaptiveRadau(; ) end + + +gauss_legendre_docstring = """@article{butcher2008numerical, +title={Numerical Methods for Ordinary Differential Equations}, +author={Butcher, John Charles}, +year={2008}, +publisher={Wiley}}""" + +@doc differentiation_rk_docstring( + "A symplectic fully implicit Runge-Kutta method based on Gauss-Legendre quadrature. +With s stages, the method has order 2s. 
Symplectic and A-stable, making it suitable +for Hamiltonian systems and problems requiring long-time geometric integration. + +!!! warning \"Experimental\" + `GaussLegendre` is experimental. Adaptive stepping currently uses Richardson + step-doubling (roughly 3× the work per accepted step) and requires `num_stages ≥ 2`. + Details may change as the implementation is refined.", + "GaussLegendre", + "Fully-Implicit Runge-Kutta Method."; + references = gauss_legendre_docstring, + extra_keyword_description = extra_keyword_description, + extra_keyword_default = extra_keyword_default +) +struct GaussLegendre{AD, F, P, Tol, C1, C2, StepLimiter, CJ} <: + OrdinaryDiffEqNewtonAdaptiveAlgorithm + linsolve::F + precs::P + smooth_est::Bool + extrapolant::Symbol + κ::Tol + maxiters::Int + fast_convergence_cutoff::C1 + new_W_γdt_cutoff::C2 + controller::Symbol + step_limiter!::StepLimiter + num_stages::Int + autodiff::AD + concrete_jac::CJ +end + +function GaussLegendre(; + num_stages = 2, + autodiff = AutoForwardDiff(), + concrete_jac = nothing, + linsolve = nothing, precs = nothing, + extrapolant = :dense, fast_convergence_cutoff = 1 // 5, + new_W_γdt_cutoff = 1 // 5, + controller = :PI, κ = nothing, maxiters = 10, smooth_est = true, + step_limiter! = trivial_limiter! 
+ ) + autodiff = _fixup_ad(autodiff) + + return GaussLegendre( + linsolve, + precs, + smooth_est, + extrapolant, + κ, + maxiters, + fast_convergence_cutoff, + new_W_γdt_cutoff, + controller, + step_limiter!, + num_stages, + autodiff, + _unwrap_val(concrete_jac) + ) +end diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl b/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl index cf3301a7c1..8d7a134a62 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_caches.jl @@ -723,3 +723,142 @@ function alg_cache( Convergence, alg.step_limiter!, num_stages, 1, 0.0, index ) end + +mutable struct GaussLegendreConstantCache{F, Tab, Tol, Dt, U, JType} <: + OrdinaryDiffEqConstantCache + uf::F + tab::Tab + κ::Tol + ηold::Tol + iter::Int + cont::Vector{U} + dtprev::Dt + W_γdt::Dt + status::NLStatus + J::JType + num_stages::Int +end + +function alg_cache( + alg::GaussLegendre, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, + ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{false}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + uf = UDerivativeWrapper(f, t, p) + uToltype = constvalue(uBottomEltypeNoUnits) + num_stages = alg.num_stages + tab = GaussLegendreTableau(uToltype, constvalue(tTypeNoUnits), num_stages) + κ = alg.κ !== nothing ? 
convert(uToltype, alg.κ) : convert(uToltype, 1 // 100) + J = false .* _vec(rate_prototype) .* _vec(rate_prototype)' + cont = Vector{typeof(u)}(undef, num_stages) + for i in 1:num_stages + cont[i] = zero(u) + end + return GaussLegendreConstantCache( + uf, tab, κ, one(uToltype), 10000, cont, dt, dt, + Convergence, J, num_stages + ) +end + + +mutable struct GaussLegendreCache{ + uType, uNoUnitsType, rateType, JType, WType, Buff, + UF, JC, F1, Tab, Tol, Dt, rTol, aTol, StepLimiter, + } <: FIRKMutableCache + u::uType + uprev::uType + z::Vector{uType} + z_last::Vector{uType} + w::Vector{uType} + dw::Vector{uType} + ubuff::Buff + u_full::uType + u_half::uType + du1::rateType + fsalfirst::rateType + k::rateType + ks::Vector{rateType} + fw::Vector{rateType} + J::JType + W::WType + uf::UF + tab::Tab + κ::Tol + ηold::Tol + iter::Int + tmp::uType + atmp::uNoUnitsType + jac_config::JC + linsolve::F1 + rtol::rTol + atol::aTol + dtprev::Dt + W_γdt::Dt + status::NLStatus + step_limiter!::StepLimiter + num_stages::Int +end + +function alg_cache( + alg::GaussLegendre, u, rate_prototype, ::Type{uEltypeNoUnits}, + ::Type{uBottomEltypeNoUnits}, + ::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, + ::Val{true}, verbose + ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} + uf = UJacobianWrapper(f, t, p) + uToltype = constvalue(uBottomEltypeNoUnits) + num_stages = alg.num_stages + tab = GaussLegendreTableau(uToltype, constvalue(tTypeNoUnits), num_stages) + κ = alg.κ !== nothing ? 
convert(uToltype, alg.κ) : convert(uToltype, 1 // 100) + + z = [zero(u) for _ in 1:num_stages] + z_last = [zero(u) for _ in 1:num_stages] + w = [zero(u) for _ in 1:num_stages] + dw = [zero(u) for _ in 1:num_stages] + n = length(_vec(u)) + ubuff = similar(_vec(u), num_stages * n) + recursivefill!(ubuff, false) + u_full = zero(u) + u_half = zero(u) + + fsalfirst = zero(rate_prototype) + k = zero(rate_prototype) + ks = [zero(rate_prototype) for _ in 1:num_stages] + fw = [zero(rate_prototype) for _ in 1:num_stages] + du1 = zero(rate_prototype) + + tmp = zero(u) + atmp = similar(u, uEltypeNoUnits) + recursivefill!(atmp, false) + jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw[1]) + + J, _ = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true)) + if J isa AbstractSciMLOperator + error("Non-concrete Jacobian not yet supported by GaussLegendre.") + end + + W = similar(J, num_stages * n, num_stages * n) + recursivefill!(W, false) + linu0 = similar(_vec(ubuff)) + recursivefill!(linu0, false) + linprob = LinearProblem(W, _vec(ubuff); u0 = linu0) + linsolve = init( + linprob, alg.linsolve, + alias = LinearAliasSpecifier(alias_A = true, alias_b = true), + assumptions = LinearSolve.OperatorAssumptions(true), + verbose = verbose.linear_verbosity + ) + + rtol = reltol isa Number ? reltol : zero(reltol) + atol = reltol isa Number ? 
reltol : zero(reltol) + + return GaussLegendreCache( + u, uprev, z, z_last, w, dw, ubuff, u_full, u_half, + du1, fsalfirst, k, ks, fw, + J, W, + uf, tab, κ, one(uToltype), 10000, + tmp, atmp, jac_config, linsolve, rtol, atol, dt, dt, + Convergence, alg.step_limiter!, num_stages + ) +end diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl index 74dc8eec2f..85139b7862 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl @@ -2241,3 +2241,429 @@ end integrator.stats.nf += 1 return end + + +function initialize!(integrator, cache::GaussLegendreConstantCache) + integrator.kshortsize = integrator.alg.num_stages + 2 + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) + integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + integrator.fsallast = zero(integrator.fsalfirst) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + for i in 3:(integrator.alg.num_stages + 2) + integrator.k[i] = zero(integrator.fsalfirst) + end + + # adaptive Richardson controller requires num_stages >= 2 + if integrator.opts.adaptive && integrator.alg.num_stages < 2 + throw( + ArgumentError( + "GaussLegendre with num_stages = $(integrator.alg.num_stages) " * + "does not support adaptive stepping (Richardson controller " * + "requires num_stages ≥ 2). Use num_stages ≥ 2 or pass adaptive = false." 
+ ) + ) + end + return nothing +end + +function initialize!(integrator, cache::GaussLegendreCache) + integrator.kshortsize = integrator.alg.num_stages + 2 + resize!(integrator.k, integrator.kshortsize) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + for i in 3:(integrator.alg.num_stages + 2) + integrator.k[i] = similar(integrator.fsallast) + end + integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + if integrator.opts.adaptive + if cache.num_stages < 2 + throw( + ArgumentError( + "GaussLegendre with num_stages = $(cache.num_stages) " * + "does not support adaptive stepping (Richardson controller " * + "requires num_stages ≥ 2). Use num_stages ≥ 2 or pass adaptive = false." + ) + ) + end + (; abstol, reltol) = integrator.opts + if reltol isa Number + cache.rtol = reltol^((cache.num_stages + 1) / (2 * cache.num_stages)) / 10 + cache.atol = cache.rtol * (abstol / reltol) + else + @.. cache.rtol = reltol^((cache.num_stages + 1) / (2 * cache.num_stages)) / 10 + @.. 
cache.atol = cache.rtol * (abstol / reltol) + end + end + return nothing +end + +# Newton iteration helper for a single GaussLegendre sub-step + +@muladd function _gausslegendre_substep_constant( + integrator, cache::GaussLegendreConstantCache, alg, + uprev_local, t_local, dt_local, J_local, + atol, rtol + ) + (; tab, κ, num_stages) = cache + (; A, b, c) = tab + (; internalnorm) = integrator.opts + (; maxiters) = alg + f, p = integrator.f, integrator.p + + n = length(uprev_local) + W_full = kron(I(num_stages), I(n)) - dt_local * kron(A, J_local) + LU_full = lu(W_full) + integrator.stats.nw += 1 + + z = [map(zero, uprev_local) for _ in 1:num_stages] + + ndw = one(eltype(uprev_local)) + η = max(cache.ηold, eps(eltype(integrator.opts.reltol)))^(0.8) + success = false + iter = 0 + local ff + + while iter < maxiters + iter += 1 + integrator.stats.nnonliniter += 1 + + ff = [f(uprev_local + z[i], p, t_local + c[i] * dt_local) for i in 1:num_stages] + OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages) + + r = zeros(eltype(uprev_local), num_stages * n) + for i in 1:num_stages + acc = zero(uprev_local) + for j in 1:num_stages + acc = @.. acc + A[i, j] * ff[j] + end + + ri = ((i - 1) * n + 1):(i * n) + @views r[ri] .= @.. z[i] - dt_local * acc + end + + dw_flat = LU_full \ r + integrator.stats.nsolve += 1 + + dw = if n == 1 + [dw_flat[i] for i in 1:num_stages] + else + [dw_flat[((i - 1) * n + 1):(i * n)] for i in 1:num_stages] + end + + ndwprev = ndw + ndw = sum( + internalnorm( + calculate_residuals(dw[i], uprev_local, uprev_local, atol, rtol, internalnorm, t_local), + t_local + ) + for i in 1:num_stages + ) + + if iter > 1 + θ = ndw / ndwprev + (diverge = θ > 1) && (cache.status = Divergence) + (veryslowconvergence = ndw * θ^(maxiters - iter) > κ * (1 - θ)) && + (cache.status = VerySlowConvergence) + if diverge || veryslowconvergence + break + end + η = θ / (1 - θ) + end + + for i in 1:num_stages + z[i] = @.. 
z[i] - dw[i] + end + + if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter)) + cache.status = η < alg.fast_convergence_cutoff ? FastConvergence : Convergence + success = true + break + end + end + + cache.ηold = η + cache.iter = iter + + if !success + return (uprev_local, z, false) + end + + ff_final = [f(uprev_local + z[i], p, t_local + c[i] * dt_local) for i in 1:num_stages] + OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages) + u_out = copy(uprev_local) + for i in 1:num_stages + u_out = @.. u_out + dt_local * b[i] * ff_final[i] + end + + return (u_out, z, true) +end + +@muladd function perform_step!( + integrator, cache::GaussLegendreConstantCache, + repeat_step = false + ) + (; t, dt, uprev, f, p) = integrator + (; num_stages) = cache + (; internalnorm, abstol, reltol, adaptive) = integrator.opts + alg = unwrap_alg(integrator, true) + + rtol = @.. reltol^((num_stages + 1) / (num_stages * 2)) / 10 + atol = @.. rtol * (abstol / reltol) + + J = calc_J(integrator, cache) + + if adaptive && num_stages >= 2 + # Richardson step-doubling: one full step at dt, two successive half-steps. + u_H, z_full, ok1 = _gausslegendre_substep_constant( + integrator, cache, alg, uprev, t, dt, J, atol, rtol + ) + if !ok1 + integrator.force_stepfail = true + integrator.stats.nnonlinconvfail += 1 + return + end + + half_dt = dt / 2 + u_h1, _, ok2 = _gausslegendre_substep_constant( + integrator, cache, alg, uprev, t, half_dt, J, atol, rtol + ) + if !ok2 + integrator.force_stepfail = true + integrator.stats.nnonlinconvfail += 1 + return + end + + u_h, _, ok3 = _gausslegendre_substep_constant( + integrator, cache, alg, u_h1, t + half_dt, half_dt, J, atol, rtol + ) + if !ok3 + integrator.force_stepfail = true + integrator.stats.nnonlinconvfail += 1 + return + end + + p_order = 2 * num_stages + utilde = @.. 
(u_h - u_H) / (2^p_order - 1) + OrdinaryDiffEqCore.set_EEst!( + integrator, + internalnorm(calculate_residuals(utilde, uprev, u_h, atol, rtol, internalnorm, t), t), + ) + + u = u_h + z_for_fsal = z_full + else + u_out, z_out, ok = _gausslegendre_substep_constant( + integrator, cache, alg, uprev, t, dt, J, atol, rtol + ) + if !ok + integrator.force_stepfail = true + integrator.stats.nnonlinconvfail += 1 + return + end + u = u_out + z_for_fsal = z_out + end + + if OrdinaryDiffEqCore.get_EEst(integrator) <= oneunit(OrdinaryDiffEqCore.get_EEst(integrator)) + cache.dtprev = dt + if alg.extrapolant != :constant + for i in 1:num_stages + integrator.k[i + 2] = z_for_fsal[i] + end + end + end + + integrator.fsallast = f(u, p, t + dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + integrator.k[1] = integrator.fsalfirst + integrator.k[2] = integrator.fsallast + integrator.u = u + return +end + +# Mutable-cache Newton iteration helper + +@muladd function _gausslegendre_substep!( + u_dest, uprev_local, t_local, dt_local, J_local, + cache::GaussLegendreCache, integrator, alg + ) + (; + tab, κ, z, dw, ks, W, ubuff, tmp, atmp, linsolve, + rtol, atol, num_stages, + ) = cache + (; A, b, c) = tab + (; internalnorm) = integrator.opts + (; maxiters) = alg + f, p = integrator.f, integrator.p + n = length(uprev_local) + + W .= kron(I(num_stages), I(n)) .- dt_local .* kron(A, J_local) + integrator.stats.nw += 1 + + for i in 1:num_stages + @.. z[i] = zero(eltype(uprev_local)) + end + + ndw = one(eltype(uprev_local)) + η = max(cache.ηold, eps(eltype(integrator.opts.reltol)))^(0.8) + success = false + iter = 0 + + while iter < maxiters + iter += 1 + integrator.stats.nnonliniter += 1 + + for i in 1:num_stages + @.. tmp = uprev_local + z[i] + f(ks[i], tmp, p, t_local + c[i] * dt_local) + end + OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages) + + for i in 1:num_stages + acc = zero(uprev_local) + for j in 1:num_stages + @.. 
acc = acc + A[i, j] * ks[j] + end + ubuff[((i - 1) * n + 1):(i * n)] .= _vec(z[i]) .- dt_local .* _vec(acc) + end + + needfactor = iter == 1 + linsolve.b = ubuff + linsolve.u = ubuff + if needfactor + LinearSolve.reinit!(linsolve; A = W) + end + linres = LinearSolve.solve!(linsolve; reltol = integrator.opts.reltol) + cache.linsolve = linres.cache + integrator.stats.nsolve += 1 + + for i in 1:num_stages + @views copyto!(_vec(dw[i]), ubuff[((i - 1) * n + 1):(i * n)]) + end + + ndwprev = ndw + ndw = sum( + begin + calculate_residuals!(atmp, dw[i], uprev_local, uprev_local, atol, rtol, internalnorm, t_local) + internalnorm(atmp, t_local) + end for i in 1:num_stages + ) + + if iter > 1 + θ = ndw / ndwprev + (diverge = θ > 1) && (cache.status = Divergence) + (veryslowconvergence = ndw * θ^(maxiters - iter) > κ * (1 - θ)) && + (cache.status = VerySlowConvergence) + if diverge || veryslowconvergence + break + end + η = θ / (1 - θ) + end + + for i in 1:num_stages + @.. z[i] = z[i] - dw[i] + end + + if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter)) + cache.status = η < alg.fast_convergence_cutoff ? FastConvergence : Convergence + success = true + break + end + end + + cache.ηold = η + cache.iter = iter + + if !success + return false + end + + @.. u_dest = uprev_local + for i in 1:num_stages + @.. tmp = uprev_local + z[i] + f(ks[i], tmp, p, t_local + c[i] * dt_local) + @.. 
u_dest = u_dest + dt_local * b[i] * ks[i] + end + OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages) + + return true +end + +@muladd function perform_step!(integrator, cache::GaussLegendreCache, repeat_step = false) + (; t, dt, uprev, u, f, p, fsallast) = integrator + (; + atmp, J, z, z_last, u_full, u_half, rtol, atol, + step_limiter!, num_stages, + ) = cache + (; internalnorm, adaptive) = integrator.opts + alg = unwrap_alg(integrator, true) + + new_jac = do_newJ(integrator, alg, cache, repeat_step) + new_jac && (calc_J!(J, integrator, cache); cache.W_γdt = dt) + + if adaptive && num_stages >= 2 + # full step at dt + ok1 = _gausslegendre_substep!(u_full, uprev, t, dt, J, cache, integrator, alg) + if !ok1 + integrator.force_stepfail = true + integrator.stats.nnonlinconvfail += 1 + return + end + for i in 1:num_stages + @.. z_last[i] = z[i] + end + + half_dt = dt / 2 + + # first half step at dt/2 + ok2 = _gausslegendre_substep!(u_half, uprev, t, half_dt, J, cache, integrator, alg) + if !ok2 + integrator.force_stepfail = true + integrator.stats.nnonlinconvfail += 1 + return + end + + # second half step at dt/2 + ok3 = _gausslegendre_substep!(u, u_half, t + half_dt, half_dt, J, cache, integrator, alg) + if !ok3 + integrator.force_stepfail = true + integrator.stats.nnonlinconvfail += 1 + return + end + + step_limiter!(u, integrator, p, t + dt) + + p_order = 2 * num_stages + denom = 2^p_order - 1 + @.. u_full = (u - u_full) / denom + calculate_residuals!(atmp, u_full, uprev, u, atol, rtol, internalnorm, t) + OrdinaryDiffEqCore.set_EEst!(integrator, internalnorm(atmp, t)) + else + ok = _gausslegendre_substep!(u, uprev, t, dt, J, cache, integrator, alg) + if !ok + integrator.force_stepfail = true + integrator.stats.nnonlinconvfail += 1 + return + end + for i in 1:num_stages + @.. 
z_last[i] = z[i] + end + step_limiter!(u, integrator, p, t + dt) + end + + if OrdinaryDiffEqCore.get_EEst(integrator) <= oneunit(OrdinaryDiffEqCore.get_EEst(integrator)) + cache.dtprev = dt + if alg.extrapolant != :constant + for i in 1:num_stages + integrator.k[i + 2] .= z_last[i] + end + end + end + + f(fsallast, u, p, t + dt) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) + return +end diff --git a/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl b/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl index fe7d05f8b9..93cc9fc02b 100644 --- a/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl +++ b/lib/OrdinaryDiffEqFIRK/src/firk_tableaus.jl @@ -349,3 +349,82 @@ const RadauIIATableauCache = Dict{ (Float64, Float64, 5) => generateRadauTableau(Float64, Float64, 5), (Float64, Float64, 7) => generateRadauTableau(Float64, Float64, 7) ) + + +struct GaussLegendreTableau{T1, T2} + A::Matrix{T1} + b::Vector{T1} + c::Vector{T2} + e::Vector{T1} +end + +import FastGaussQuadrature: gausslegendre + +function GaussLegendreTableau(T1, T2, num_stages::Int) + tab = get(GaussLegendreTableauCache, (T1, T2, num_stages)) do + tab = generateGaussLegendreTableau(T1, T2, num_stages) + GaussLegendreTableauCache[(T1, T2, num_stages)] = tab + tab + end + return GaussLegendreTableau{T1, T2}(tab.A, tab.b, tab.c, tab.e) +end + +# TODO: embedded error coefficients use s-1 GL rule interpolated to s nodes +# a proper derivation following Hairer Vol I would improve adaptive performance + +# TODO: add the symplectic integrator stage decoupling from Antoñana et al. to increase efficiency + +function generateGaussLegendreTableau(T1, T2, num_stages::Int) + x, w = gausslegendre(num_stages) + c = T2.((x .+ 1) ./ 2) + b = T1.(w ./ 2) + + A = Matrix{T1}(undef, num_stages, num_stages) + for i in 1:num_stages + for j in 1:num_stages + nodes, weights = gausslegendre(2 * num_stages) + nodes_ij = T1.((nodes .+ 1) ./ 2 .* c[i]) + weights_ij = T1.(weights ./ 2 .* c[i]) + Lj = ones(T1, length(nodes_ij)) + for k in 1:num_stages 
+ if k != j + Lj .*= (nodes_ij .- c[k]) ./ (c[j] - c[k]) + end + end + A[i, j] = dot(weights_ij, Lj) + end + end + + + # error estimate coefficients: embed using s-1 GL weights via interpolation + if num_stages > 1 + x_low, w_low = gausslegendre(num_stages - 1) + c_low = T1.((x_low .+ 1) ./ 2) + b_low = zeros(T1, num_stages) + for i in 1:num_stages + for j in 1:(num_stages - 1) + Lj = one(T1) + for k in 1:(num_stages - 1) + if k != j + Lj *= (c[i] - c_low[k]) / (c_low[j] - c_low[k]) + end + end + b_low[i] += T1(w_low[j] / 2) * Lj + end + end + e = b .- b_low + else + e = b + end + + return GaussLegendreTableau{T1, T2}(A, b, c, e) +end + +const GaussLegendreTableauCache = Dict{ + Tuple{Type, Type, Int}, GaussLegendreTableau{T1, T2} where {T1, T2}, +}( + (Float64, Float64, 2) => generateGaussLegendreTableau(Float64, Float64, 2), + (Float64, Float64, 3) => generateGaussLegendreTableau(Float64, Float64, 3), + (Float64, Float64, 4) => generateGaussLegendreTableau(Float64, Float64, 4), + (Float64, Float64, 5) => generateGaussLegendreTableau(Float64, Float64, 5), +) diff --git a/lib/OrdinaryDiffEqFIRK/src/integrator_interface.jl b/lib/OrdinaryDiffEqFIRK/src/integrator_interface.jl index ca2ecde909..ab38b8bc4b 100644 --- a/lib/OrdinaryDiffEqFIRK/src/integrator_interface.jl +++ b/lib/OrdinaryDiffEqFIRK/src/integrator_interface.jl @@ -1,5 +1,6 @@ @inline function SciMLBase.get_tmp_cache( - integrator, alg::Union{RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau}, + integrator, + alg::Union{RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau, GaussLegendre}, cache::OrdinaryDiffEqMutableCache ) return (cache.tmp, cache.atmp) diff --git a/lib/OrdinaryDiffEqFIRK/test/gausslegendre_adaptive_tests.jl b/lib/OrdinaryDiffEqFIRK/test/gausslegendre_adaptive_tests.jl new file mode 100644 index 0000000000..2e379c9701 --- /dev/null +++ b/lib/OrdinaryDiffEqFIRK/test/gausslegendre_adaptive_tests.jl @@ -0,0 +1,93 @@ +using OrdinaryDiffEqFIRK, DiffEqDevTools, Test, LinearAlgebra +import 
OrdinaryDiffEqCore +import ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear + +testTol = 0.6 + +# test orders 2 and 3 for convergence with fixed dt, anything higher is too sensitive to floating point precision +@testset "GaussLegendre: fixed-dt empirical order (s = 2, 3)" begin + dts = Float64.(1 ./ 2 .^ (5:-1:2)) # 1/32 … 1/4 + for s in 2:3 + alg = GaussLegendre(num_stages = s; maxiters = 100) + sim = test_convergence( + dts, prob_ode_linear, alg; + dense_errors = false, + abstol = 1.0e-12, reltol = 1.0e-12 + ) + @test sim.𝒪est[:final] ≈ 2 * s atol = testTol + end +end + +# test accuracy of high-order fixed-dt method (s = 4, order 8) empirically +@testset "GaussLegendre: fixed-dt accuracy (s = 4, order 8)" begin + s = 4 + alg = GaussLegendre(num_stages = s; maxiters = 100) + sol = solve( + prob_ode_linear, alg; adaptive = false, dt = 1 // 256, + abstol = 1.0e-14, reltol = 1.0e-14 + ) + @test sol.retcode == ReturnCode.Success + exact = prob_ode_linear.u0 * exp(1.01 * (sol.t[end] - sol.t[1])) + @test isapprox(sol.u[end], exact; rtol = 1.0e-9, atol = 1.0e-12) +end + +# test adaptive stepping with tolerance matching +@testset "GaussLegendre: adaptive run matches tolerance" begin + for s in 2:4 + reltol = 1.0e-6 + abstol = 1.0e-9 + sol = solve( + prob_ode_linear, GaussLegendre(num_stages = s); + reltol = reltol, abstol = abstol + ) + @test sol.retcode == ReturnCode.Success + exact = prob_ode_linear.u0 * exp(1.01 * (sol.t[end] - sol.t[1])) + @test isapprox(sol.u[end], exact; rtol = 1.0e-3, atol = 1.0e-6) + end +end + +@testset "GaussLegendre: adaptive controller defaults to PI" begin + alg = GaussLegendre(num_stages = 3) + integrator = init( + prob_ode_linear, alg; + reltol = 1.0e-6, abstol = 1.0e-9 + ) + + controller = integrator.controller_cache.controller + order = 2 * alg.num_stages + @test alg.controller === :PI + @test integrator.controller_cache isa OrdinaryDiffEqCore.PIControllerCache + @test controller.beta1 ≈ 7 / (10 * order) + @test controller.beta2 ≈ 2 
/ (5 * order) +end + +# test that Richardson step-doubling tightens step count with tolerance +@testset "GaussLegendre: PI/Richardson tightens step count when tol tightens" begin + s = 3 + sol_loose = solve( + prob_ode_linear, GaussLegendre(num_stages = s); + reltol = 1.0e-3, abstol = 1.0e-6 + ) + sol_tight = solve( + prob_ode_linear, GaussLegendre(num_stages = s); + reltol = 1.0e-8, abstol = 1.0e-10 + ) + @test length(sol_tight.t) >= length(sol_loose.t) +end + +# test that num_stages = 1 with adaptive throws +@testset "GaussLegendre: num_stages = 1 with adaptive throws" begin + @test_throws ArgumentError solve( + prob_ode_linear, GaussLegendre(num_stages = 1); + reltol = 1.0e-4, abstol = 1.0e-7 + ) +end + +# test that num_stages = 1 with adaptive = false runs +@testset "GaussLegendre: num_stages = 1 with adaptive = false runs" begin + sol = solve( + prob_ode_linear, GaussLegendre(num_stages = 1); + adaptive = false, dt = 1 // 32 + ) + @test sol.retcode == ReturnCode.Success +end diff --git a/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl b/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl index 2c34c20e4e..1a3334e949 100644 --- a/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl +++ b/lib/OrdinaryDiffEqFIRK/test/ode_firk_tests.jl @@ -106,6 +106,37 @@ for prob in [prob_ode_linear, prob_ode_2Dlinear] @test sim.𝒪est[:L2] ≈ 3 atol = 0.25 end +# GL4 on the convergence tests +for prob in [prob_ode_linear, prob_ode_2Dlinear] + dts = Float64.(1 ./ 2 .^ (5:-1:2)) + sim = test_convergence( + dts, + prob, + GaussLegendre(num_stages = 2; maxiters = 100); + dense_errors = false, + abstol = 1.0e-12, + reltol = 1.0e-12, + ) + @test sim.𝒪est[:final] ≈ 4 atol = testTol +end + +# GL6 on the 2D linear problem only due to scalar log–log slope being noisier at high order +dts = Float64.(1 ./ 2 .^ (5:-1:2)) +sim_gl3 = test_convergence( + dts, + prob_ode_2Dlinear, + GaussLegendre(num_stages = 3; maxiters = 100); + dense_errors = false, + abstol = 1.0e-12, + reltol = 1.0e-12, +) +@test 
sim_gl3.𝒪est[:final] ≈ 6 atol = testTol + +for prob in [prob_ode_linear, prob_ode_2Dlinear] + sol = solve(prob, GaussLegendre(num_stages = 3); reltol = 1.0e-5, abstol = 1.0e-8) + @test SciMLBase.successful_retcode(sol) +end + # test adaptivity for iip in (true, false) vanstiff = ODEProblem{iip}(vanderpol_firk, [sqrt(3), 0], (0.0, 1.0), [1.0e6]) diff --git a/lib/OrdinaryDiffEqFIRK/test/qa/allocation_tests.jl b/lib/OrdinaryDiffEqFIRK/test/qa/allocation_tests.jl index 97295884d4..c23da989e2 100644 --- a/lib/OrdinaryDiffEqFIRK/test/qa/allocation_tests.jl +++ b/lib/OrdinaryDiffEqFIRK/test/qa/allocation_tests.jl @@ -13,14 +13,18 @@ using Test # Use FullSpecialize to avoid FunctionWrappers dynamic dispatch noise prob = ODEProblem{true, FullSpecialize}(simple_system!, [1.0, 1.0], (0.0, 1.0)) - firk_solvers = [RadauIIA3(), RadauIIA5(), RadauIIA9(), AdaptiveRadau()] + firk_solvers = [ + RadauIIA3(), RadauIIA5(), RadauIIA9(), AdaptiveRadau(), + GaussLegendre(num_stages = 2), + ] @testset "FIRK perform_step! Static Analysis" begin for solver in firk_solvers @testset "$(typeof(solver)) perform_step! allocation check" begin integrator = init( prob, solver, dt = 0.1, save_everystep = false, - abstol = 1.0e-6, reltol = 1.0e-6 + abstol = 1.0e-6, reltol = 1.0e-6; + adaptive = !(solver isa GaussLegendre), ) step!(integrator) diff --git a/lib/OrdinaryDiffEqFIRK/test/runtests.jl b/lib/OrdinaryDiffEqFIRK/test/runtests.jl index 48ba6e086c..1fc4b8a590 100644 --- a/lib/OrdinaryDiffEqFIRK/test/runtests.jl +++ b/lib/OrdinaryDiffEqFIRK/test/runtests.jl @@ -11,6 +11,7 @@ end # Run functional tests if TEST_GROUP == "Core" || TEST_GROUP == "ALL" @time @safetestset "FIRK Tests" include("ode_firk_tests.jl") + @time @safetestset "GaussLegendre Adaptive Tests" include("gausslegendre_adaptive_tests.jl") end # Run QA tests (AllocCheck, JET, Aqua) - skip on pre-release Julia