Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 40 additions & 42 deletions lib/OrdinaryDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
@@ -1,71 +1,69 @@
name = "OrdinaryDiffEqFIRK"
uuid = "5960d6e9-dd7a-4743-88e7-cf307b64f125"
authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]
version = "1.26.0"
authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b"
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Runic = "62bfec6d-59d7-401d-8490-b29ee721c001"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

[sources]
DiffEqBase = {path = "../DiffEqBase"}
DiffEqDevTools = {path = "../DiffEqDevTools"}
OrdinaryDiffEqCore = {path = "../OrdinaryDiffEqCore"}
OrdinaryDiffEqDifferentiation = {path = "../OrdinaryDiffEqDifferentiation"}
OrdinaryDiffEqNonlinearSolve = {path = "../OrdinaryDiffEqNonlinearSolve"}

[compat]
Pkg = "1"
Test = "<0.0.1, 1"
FastBroadcast = "1.3"
Random = "<0.0.1, 1"
ADTypes = "1.16"
DiffEqBase = "7"
DiffEqDevTools = "3"
FastBroadcast = "1.3"
FastGaussQuadrature = "1.0.2"
MuladdMacro = "0.2"
LinearSolve = "3.46"
LinearAlgebra = "1.10"
OrdinaryDiffEqDifferentiation = "2"
SciMLBase = "3"
OrdinaryDiffEqCore = "4"
GenericSchur = "0.5"
julia = "1.10"
ADTypes = "1.16"
RecursiveArrayTools = "4"
FastPower = "1.1"
GenericSchur = "0.5"
LinearAlgebra = "1.10"
LinearSolve = "3.46"
MuladdMacro = "0.2"
ODEProblemLibrary = "1"
OrdinaryDiffEqCore = "4"
OrdinaryDiffEqDifferentiation = "2"
OrdinaryDiffEqNonlinearSolve = "1.16.0"
DiffEqBase = "7"
Pkg = "1"
Polyester = "0.7"
Random = "<0.0.1, 1"
RecursiveArrayTools = "4"
Reexport = "1.2"
Runic = "1.7.0"
SafeTestsets = "0.1.0"
SciMLBase = "3"
SciMLOperators = "1.4"
Test = "<0.0.1, 1"
julia = "1.10"

[extras]
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DiffEqDevTools", "GenericSchur", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "Pkg"]

[sources.OrdinaryDiffEqDifferentiation]
path = "../OrdinaryDiffEqDifferentiation"

[sources.OrdinaryDiffEqNonlinearSolve]
path = "../OrdinaryDiffEqNonlinearSolve"

[sources.OrdinaryDiffEqCore]
path = "../OrdinaryDiffEqCore"
15 changes: 8 additions & 7 deletions lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module OrdinaryDiffEqFIRK

import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
initialize!, perform_step!, unwrap_alg,
calculate_residuals, default_controller, PredictiveController,
calculate_residuals, default_controller, PredictiveController, PIController,
OrdinaryDiffEqAlgorithm, OrdinaryDiffEqNewtonAdaptiveAlgorithm,
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
OrdinaryDiffEqAdaptiveAlgorithm, CompiledFloats, uses_uprev,
Expand All @@ -14,15 +14,16 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
qmax_default, alg_adaptive_order,
stepsize_controller!, step_accept_controller!,
step_reject_controller!,
PredictiveController, alg_can_repeat_jac, NewtonAlgorithm,
alg_can_repeat_jac, NewtonAlgorithm,
fac_default_gamma,
get_current_adaptive_order, get_fsalfirstlast,
isfirk, generic_solver_docstring, _ad_chunksize_int, _ad_fdtype, _fixup_ad,
get_current_alg_order,
isfirk, generic_solver_docstring,
_ad_chunksize_int, _ad_fdtype, _fixup_ad,
LinearAliasSpecifier
using MuladdMacro, DiffEqBase, RecursiveArrayTools
isfirk, generic_solver_docstring
using MuladdMacro, DiffEqBase, RecursiveArrayTools, Polyester
using SciMLOperators: AbstractSciMLOperator
using LinearAlgebra: I, UniformScaling, mul!, lu
using LinearAlgebra: I, UniformScaling, mul!, lu, dot, eigvals
import LinearSolve
import FastBroadcast: @..
import OrdinaryDiffEqCore
Expand Down Expand Up @@ -62,6 +63,6 @@ include("firk_interpolants.jl")
include("firk_addsteps.jl")
include("integrator_interface.jl")

export RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau
export RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau, GaussLegendre

end
16 changes: 16 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,19 @@ get_current_adaptive_order(alg::AdaptiveRadau, cache) = cache.num_stages
function has_stiff_interpolation(::Union{RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau})
return true
end

qmax_default(alg::GaussLegendre) = 8

alg_order(alg::GaussLegendre) = 2 * alg.num_stages

default_controller(QT, alg::GaussLegendre) = PIController(QT, alg)

isfirk(alg::GaussLegendre) = true

# Richardson step-doubling controller
isadaptive(alg::GaussLegendre) = alg.num_stages >= 2
alg_adaptive_order(alg::GaussLegendre) = 2 * alg.num_stages
has_stiff_interpolation(::GaussLegendre) = true

get_current_alg_order(alg::GaussLegendre, cache) = 2 * alg.num_stages
get_current_adaptive_order(alg::GaussLegendre, cache) = 2 * alg.num_stages
68 changes: 68 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,71 @@ function AdaptiveRadau(;

)
end


gauss_legendre_docstring = """@article{butcher2008numerical,
title={Numerical Methods for Ordinary Differential Equations},
author={Butcher, John Charles},
year={2008},
publisher={Wiley}}"""

@doc differentiation_rk_docstring(
"A symplectic fully implicit Runge-Kutta method based on Gauss-Legendre quadrature.
With s stages, the method has order 2s. Symplectic and A-stable, making it suitable
for Hamiltonian systems and problems requiring long-time geometric integration.

!!! warning \"Experimental\"
`GaussLegendre` is experimental. Adaptive stepping currently uses Richardson
step-doubling (roughly 3× the work per accepted step) and requires `num_stages ≥ 2`;
Details may change as the implementation is refined.",
"GaussLegendre",
"Fully-Implicit Runge-Kutta Method.";
references = gauss_legendre_docstring,
extra_keyword_description = extra_keyword_description,
extra_keyword_default = extra_keyword_default
)
Comment thread
Sreeram-Shankar marked this conversation as resolved.
struct GaussLegendre{AD, F, P, Tol, C1, C2, StepLimiter, CJ} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm
linsolve::F
precs::P
smooth_est::Bool
extrapolant::Symbol
κ::Tol
maxiters::Int
fast_convergence_cutoff::C1
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
num_stages::Int
autodiff::AD
concrete_jac::CJ
end

function GaussLegendre(;
num_stages = 2,
autodiff = AutoForwardDiff(),
concrete_jac = nothing,
linsolve = nothing, precs = nothing,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
controller = :PI, κ = nothing, maxiters = 10, smooth_est = true,
step_limiter! = trivial_limiter!
)
autodiff = _fixup_ad(autodiff)

return GaussLegendre(
linsolve,
precs,
smooth_est,
extrapolant,
κ,
maxiters,
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!,
num_stages,
autodiff,
_unwrap_val(concrete_jac)
)
end
139 changes: 139 additions & 0 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,142 @@ function alg_cache(
Convergence, alg.step_limiter!, num_stages, 1, 0.0, index
)
end

mutable struct GaussLegendreConstantCache{F, Tab, Tol, Dt, U, JType} <:
OrdinaryDiffEqConstantCache
uf::F
tab::Tab
κ::Tol
ηold::Tol
iter::Int
cont::Vector{U}
dtprev::Dt
W_γdt::Dt
status::NLStatus
J::JType
num_stages::Int
end

function alg_cache(
alg::GaussLegendre, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{false}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages
tab = GaussLegendreTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
cont = Vector{typeof(u)}(undef, num_stages)
for i in 1:num_stages
cont[i] = zero(u)
end
return GaussLegendreConstantCache(
uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
Convergence, J, num_stages
)
end


mutable struct GaussLegendreCache{
uType, uNoUnitsType, rateType, JType, WType, Buff,
UF, JC, F1, Tab, Tol, Dt, rTol, aTol, StepLimiter,
} <: FIRKMutableCache
u::uType
uprev::uType
z::Vector{uType}
z_last::Vector{uType}
w::Vector{uType}
dw::Vector{uType}
ubuff::Buff
u_full::uType
u_half::uType
du1::rateType
fsalfirst::rateType
k::rateType
ks::Vector{rateType}
fw::Vector{rateType}
J::JType
W::WType
uf::UF
tab::Tab
κ::Tol
ηold::Tol
iter::Int
tmp::uType
atmp::uNoUnitsType
jac_config::JC
linsolve::F1
rtol::rTol
atol::aTol
dtprev::Dt
W_γdt::Dt
status::NLStatus
step_limiter!::StepLimiter
num_stages::Int
end

function alg_cache(
alg::GaussLegendre, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits},
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{true}, verbose
) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages
tab = GaussLegendreTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

z = [zero(u) for _ in 1:num_stages]
z_last = [zero(u) for _ in 1:num_stages]
w = [zero(u) for _ in 1:num_stages]
dw = [zero(u) for _ in 1:num_stages]
n = length(_vec(u))
ubuff = similar(_vec(u), num_stages * n)
recursivefill!(ubuff, false)
u_full = zero(u)
u_half = zero(u)

fsalfirst = zero(rate_prototype)
k = zero(rate_prototype)
ks = [zero(rate_prototype) for _ in 1:num_stages]
fw = [zero(rate_prototype) for _ in 1:num_stages]
du1 = zero(rate_prototype)

tmp = zero(u)
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw[1])

J, _ = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by GaussLegendre.")
end

W = similar(J, num_stages * n, num_stages * n)
recursivefill!(W, false)
linu0 = similar(_vec(ubuff))
recursivefill!(linu0, false)
linprob = LinearProblem(W, _vec(ubuff); u0 = linu0)
linsolve = init(
linprob, alg.linsolve,
alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
assumptions = LinearSolve.OperatorAssumptions(true),
verbose = verbose.linear_verbosity
)

rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)

return GaussLegendreCache(
u, uprev, z, z_last, w, dw, ubuff, u_full, u_half,
du1, fsalfirst, k, ks, fw,
J, W,
uf, tab, κ, one(uToltype), 10000,
tmp, atmp, jac_config, linsolve, rtol, atol, dt, dt,
Convergence, alg.step_limiter!, num_stages
)
end
Loading
Loading