Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Expand Down Expand Up @@ -53,6 +54,7 @@ RecursiveArrayTools = "4"
FastPower = "1.1"
ODEProblemLibrary = "1"
OrdinaryDiffEqNonlinearSolve = "1.16.0"
Polyester = "0.7"
DiffEqBase = "7"
Reexport = "1.2"
SafeTestsets = "0.1.0"
Expand Down
13 changes: 7 additions & 6 deletions lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
qmax_default, alg_adaptive_order,
stepsize_controller!, step_accept_controller!,
step_reject_controller!,
PredictiveController, alg_can_repeat_jac, NewtonAlgorithm,
alg_can_repeat_jac, NewtonAlgorithm,
fac_default_gamma,
get_current_adaptive_order, get_fsalfirstlast,
isfirk, generic_solver_docstring, _ad_chunksize_int, _ad_fdtype, _fixup_ad,
get_current_alg_order,
isfirk, generic_solver_docstring,
_ad_chunksize_int, _ad_fdtype, _fixup_ad,
LinearAliasSpecifier
using MuladdMacro, DiffEqBase, RecursiveArrayTools
isfirk, generic_solver_docstring
using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester
using SciMLOperators: AbstractSciMLOperator
using LinearAlgebra: I, UniformScaling, mul!, lu
using LinearAlgebra: I, UniformScaling, mul!, lu, dot, eigvals
import LinearSolve
import FastBroadcast: @..
import OrdinaryDiffEqCore
Expand Down Expand Up @@ -62,6 +63,6 @@ include("firk_interpolants.jl")
include("firk_addsteps.jl")
include("integrator_interface.jl")

export RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau
export RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau, GaussLegendre

end
14 changes: 14 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,17 @@ get_current_adaptive_order(alg::AdaptiveRadau, cache) = cache.num_stages
function has_stiff_interpolation(::Union{RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau})
return true
end

qmax_default(alg::GaussLegendre) = 8

alg_order(alg::GaussLegendre) = 2 * alg.num_stages

isfirk(alg::GaussLegendre) = true

# Richardson step-doubling controller
isadaptive(alg::GaussLegendre) = alg.num_stages >= 2
alg_adaptive_order(alg::GaussLegendre) = 2 * alg.num_stages
has_stiff_interpolation(::GaussLegendre) = true

get_current_alg_order(alg::GaussLegendre, cache) = 2 * alg.num_stages
get_current_adaptive_order(alg::GaussLegendre, cache) = 2 * alg.num_stages
78 changes: 78 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,81 @@ function AdaptiveRadau(;

)
end


gauss_legendre_docstring = """@article{butcher2008numerical,
title={Numerical Methods for Ordinary Differential Equations},
author={Butcher, John Charles},
year={2008},
publisher={Wiley}}"""

@doc differentiation_rk_docstring(
"A symplectic fully implicit Runge-Kutta method based on Gauss-Legendre quadrature.
With s stages, the method has order 2s. Symplectic and A-stable, making it suitable
for Hamiltonian systems and problems requiring long-time geometric integration.

!!! warning \"Experimental\"
`GaussLegendre` is experimental. Adaptive stepping currently uses Richardson
step-doubling (roughly 3× the work per accepted step) and requires `num_stages ≥ 2`;
Details may change as the implementation is refined.",
"GaussLegendre",
"Fully-Implicit Runge-Kutta Method.";
references = gauss_legendre_docstring,
extra_keyword_description = extra_keyword_description,
extra_keyword_default = extra_keyword_default
)
Comment thread
Sreeram-Shankar marked this conversation as resolved.
struct GaussLegendre{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
Comment thread
oscardssmith marked this conversation as resolved.
Outdated
OrdinaryDiffEqNewtonAdaptiveAlgorithm
linsolve::F
precs::P
smooth_est::Bool
extrapolant::Symbol
κ::Tol
maxiters::Int
fast_convergence_cutoff::C1
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
num_stages::Int
autodiff::AD
concrete_jac::CJ
end

@inline function _process_AD_choice(autodiff, chunk_size, diff_type)
return _fixup_ad(autodiff), chunk_size, diff_type
end

function GaussLegendre(;
num_stages = 2,
chunk_size = Val{0}(), autodiff = AutoForwardDiff(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}(),
linsolve = nothing, precs = nothing,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
controller = :Predictive, κ = nothing, maxiters = 10, smooth_est = true,
Comment thread
oscardssmith marked this conversation as resolved.
Outdated
step_limiter! = trivial_limiter!
)
AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type)

return GaussLegendre{
_unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), typeof(_unwrap_val(concrete_jac)),
typeof(κ), typeof(fast_convergence_cutoff),
typeof(new_W_γdt_cutoff), typeof(step_limiter!),
}(
linsolve,
precs,
smooth_est,
extrapolant,
κ,
maxiters,
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!,
num_stages,
AD_choice,
_unwrap_val(concrete_jac)
)
end
139 changes: 139 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,142 @@ function alg_cache(
Convergence, alg.step_limiter!, num_stages, 1, 0.0, index
)
end

mutable struct GaussLegendreConstantCache{F, Tab, Tol, Dt, U, JType} <:
OrdinaryDiffEqConstantCache
uf::F
tab::Tab
κ::Tol
ηold::Tol
iter::Int
cont::Vector{U}
dtprev::Dt
W_γdt::Dt
status::NLStatus
J::JType
num_stages::Int
end

function alg_cache(
alg::GaussLegendre, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{false}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages
tab = GaussLegendreTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
cont = Vector{typeof(u)}(undef, num_stages)
for i in 1:num_stages
cont[i] = zero(u)
end
return GaussLegendreConstantCache(
uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
Convergence, J, num_stages
)
end


mutable struct GaussLegendreCache{
uType, uNoUnitsType, rateType, JType, WType, Buff,
UF, JC, F1, Tab, Tol, Dt, rTol, aTol, StepLimiter,
} <: FIRKMutableCache
u::uType
uprev::uType
z::Vector{uType}
z_last::Vector{uType}
w::Vector{uType}
dw::Vector{uType}
ubuff::Buff
u_full::uType
u_half::uType
du1::rateType
fsalfirst::rateType
k::rateType
ks::Vector{rateType}
fw::Vector{rateType}
J::JType
W::WType
uf::UF
tab::Tab
κ::Tol
ηold::Tol
iter::Int
tmp::uType
atmp::uNoUnitsType
jac_config::JC
linsolve::F1
rtol::rTol
atol::aTol
dtprev::Dt
W_γdt::Dt
status::NLStatus
step_limiter!::StepLimiter
num_stages::Int
end

function alg_cache(
alg::GaussLegendre, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{true}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages
tab = GaussLegendreTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

z = [zero(u) for _ in 1:num_stages]
z_last = [zero(u) for _ in 1:num_stages]
w = [zero(u) for _ in 1:num_stages]
dw = [zero(u) for _ in 1:num_stages]
n = length(_vec(u))
ubuff = similar(_vec(u), num_stages * n)
recursivefill!(ubuff, false)
u_full = zero(u)
u_half = zero(u)

fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
ks = [zero(rate_prototype) for _ in 1:num_stages]
fw = [zero(rate_prototype) for _ in 1:num_stages]
du1 = zero(rate_prototype)

tmp = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw[1])

J, _ = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by GaussLegendre.")
end

W = similar(J, num_stages * n, num_stages * n)
recursivefill!(W, false)
linu0 = similar(_vec(ubuff))
recursivefill!(linu0, false)
linprob = LinearProblem(W, _vec(ubuff); u0 = linu0)
linsolve = init(
linprob, alg.linsolve,
alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
assumptions = LinearSolve.OperatorAssumptions(true),
verbose = verbose.linear_verbosity
)

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)

return GaussLegendreCache(
u, uprev, z, z_last, w, dw, ubuff, u_full, u_half,
du1, fsalfirst, k, ks, fw,
J, W,
uf, tab, κ, one(uToltype), 10000,
tmp, atmp, jac_config, linsolve, rtol, atol, dt, dt,
Convergence, alg.step_limiter!, num_stages
)
end
Loading
Loading