Skip to content

Commit a695765

Browse files
Harsh Singh and claude
authored and committed
refactor(RKN): unify velocity-dependent Nyström methods via NystromVDTableau (Phase 2)
Adds NystromVDTableau{T,T2} struct and generic NystromVDCache/NystromVDConstantCache, eliminating per-method caches and perform_step! for FineRKN4, FineRKN5, RKN4, Nystrom4, and Nystrom4VelocityIndependent. Unchanged: DPRKN6 (kshortsize=3 dense output), IRKN3, IRKN4 (implicit). All nystrom_convergence_tests pass (70 pass, 16 pre-existing broken). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8cce3b5 commit a695765

5 files changed

Lines changed: 391 additions & 533 deletions

File tree

lib/OrdinaryDiffEqRKN/src/OrdinaryDiffEqRKN.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ include("interp_func.jl")
2828
include("interpolants.jl")
2929
include("rkn_perform_step.jl")
3030
include("generic_rkn_vi_perform_step.jl")
31+
include("generic_rkn_vd_perform_step.jl")
3132

3233
export Nystrom4, FineRKN4, FineRKN5, Nystrom4VelocityIndependent,
3334
Nystrom5VelocityIndependent,
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
## Generic Nyström velocity-DEPENDENT perform_step!
2+
## Solves: y'' = f(t, y, y') where f depends on both position and velocity
3+
## kᵢ = f1(duprev + dt*Σⱼ abar[i,j]*kⱼ, uprev + dt*c[i-1]*duprev + dt²*Σⱼ a[i,j]*kⱼ, p, t+c[i-1]*dt) for i ≥ 2 (in the code, c[i-1] is the node of stage i; k₁ is taken from fsalfirst)
4+
## y₁ = y₀ + h*y'₀ + h²*Σᵢ bᵢ*kᵢ
5+
## y'₁ = y'₀ + h*Σᵢ bpᵢ*kᵢ
6+
7+
# Set up the integrator for the out-of-place (constant-cache) velocity-dependent
# Nyström stepper: allocate the two-slot `k` container, evaluate `f1` and `f2` once
# at the initial state, and store the results as `fsalfirst` (with a zeroed
# `fsallast` of the same structure).
function initialize!(integrator, cache::NystromVDConstantCache)
    (; f, p, t) = integrator

    integrator.kshortsize = 2
    integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)

    # `uprev.x` is unpacked as (first part, second part); both rhs parts are
    # evaluated at the same state and time.
    vel0, pos0 = integrator.uprev.x
    rhs_first = f.f1(vel0, pos0, p, t)
    rhs_second = f.f2(vel0, pos0, p, t)

    # One `f1` call and one `f2` call are recorded in the statistics.
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.stats.nf2 += 1

    integrator.fsalfirst = ArrayPartition((rhs_first, rhs_second))
    integrator.fsallast = zero(integrator.fsalfirst)

    integrator.k[1] = integrator.fsalfirst
    return integrator.k[2] = integrator.fsallast
end
21+
22+
# Set up the integrator for the in-place (mutable-cache) velocity-dependent Nyström
# stepper: size `integrator.k` to two slots pointing at `fsalfirst`/`fsallast`, then
# fill both parts of `fsalfirst` in place with `f1` and `f2` evaluated at the
# initial state.
function initialize!(integrator, cache::NystromVDCache)
    (; f, p, t) = integrator

    integrator.kshortsize = 2
    resize!(integrator.k, integrator.kshortsize)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast

    vel0, pos0 = integrator.uprev.x
    out_first, out_second = integrator.fsalfirst.x[1], integrator.fsalfirst.x[2]
    f.f1(out_first, vel0, pos0, p, t)
    f.f2(out_second, vel0, pos0, p, t)

    # One `f1` call and one `f2` call are recorded in the statistics.
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    return integrator.stats.nf2 += 1
end
33+
34+
# Out-of-place step of the generic velocity-dependent Nyström scheme described by
# `cache.tab`.
#
# Stage 1 is read from `integrator.fsalfirst.x[1]`; stages 2..nstages are evaluated
# with `f.f1`, using `c[s - 1]` as the node of stage `s`. The new state is written to
# `integrator.u`, `f1`/`f2` at the new state go to `integrator.fsallast`, and, when
# the integrator is adaptive and `btilde` is non-empty, `integrator.EEst` is set from
# the `btilde`/`bptilde` weighted stage sums.
@muladd function perform_step!(
        integrator, cache::NystromVDConstantCache, repeat_step = false)
    (; t, dt, f, p) = integrator
    vel0, pos0 = integrator.uprev.x
    (; tab) = cache
    (; a, abar, b, bp, btilde, bptilde, c) = tab
    first_stage = integrator.fsalfirst.x[1]
    nstages = length(b)
    dtsq = dt^2

    # Stage container; slot 1 holds the stage taken from `fsalfirst`.
    stages = Vector{typeof(first_stage)}(undef, nstages)
    stages[1] = first_stage
    for s in 2:nstages
        node = c[s - 1]
        pos_s = pos0 + dt * node * vel0
        vel_s = vel0
        for j in 1:(s - 1)
            a_sj = a[s, j]
            if !iszero(a_sj)
                pos_s = pos_s + dtsq * a_sj * stages[j]
            end
            abar_sj = abar[s, j]
            if !iszero(abar_sj)
                vel_s = vel_s + dt * abar_sj * stages[j]
            end
        end
        stages[s] = f.f1(vel_s, pos_s, p, t + dt * node)
    end

    # New position: pos0 + dt*vel0 + dt^2 * sum(b[i] * stage_i), skipping zero weights.
    pos1 = pos0 + dt * vel0
    for (i, weight) in enumerate(b)
        iszero(weight) && continue
        pos1 = pos1 + dtsq * weight * stages[i]
    end

    # New velocity: vel0 + dt * sum(bp[i] * stage_i), skipping zero weights.
    vel1 = vel0
    for (i, weight) in enumerate(bp)
        iszero(weight) && continue
        vel1 = vel1 + dt * weight * stages[i]
    end

    t_new = t + dt
    integrator.u = ArrayPartition((vel1, pos1))
    integrator.fsallast = ArrayPartition(
        (f.f1(vel1, pos1, p, t_new), f.f2(vel1, pos1, p, t_new)))
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, tab.nf_per_step)
    integrator.stats.nf2 += 1
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast

    # No error estimate unless the integrator is adaptive and the tableau has
    # `btilde` weights.
    integrator.opts.adaptive || return nothing
    isempty(btilde) && return nothing

    has_bptilde = !isempty(bptilde)
    pos_err = zero(pos0)
    vel_err = zero(vel0)
    for i in 1:nstages
        if !iszero(btilde[i])
            pos_err = pos_err + dtsq * btilde[i] * stages[i]
        end
        if has_bptilde && !iszero(bptilde[i])
            vel_err = vel_err + dt * bptilde[i] * stages[i]
        end
    end
    err_estimate = ArrayPartition((vel_err, pos_err))
    residuals = calculate_residuals(err_estimate, integrator.uprev, integrator.u,
        integrator.opts.abstol, integrator.opts.reltol,
        integrator.opts.internalnorm, t)
    return integrator.EEst = integrator.opts.internalnorm(residuals, t)
end
100+
101+
# In-place step of the generic velocity-dependent Nyström scheme described by
# `cache.tab`.
#
# Stage 1 is read from `integrator.fsalfirst.x[1]`; stage `s` (for s ≥ 2) is written
# into `ks[s - 1]` by `f.f1`, using `c[s - 1]` as its node and the two parts of
# `cache.tmp` as scratch for the stage state. The new state is accumulated directly
# into `integrator.u`, `f1`/`f2` at the new state are written into `cache.k`, and,
# when the integrator is adaptive and `btilde` is non-empty, `integrator.EEst` is set
# from the `btilde`/`bptilde` weighted stage sums held in `cache.utilde`.
@muladd function perform_step!(
        integrator, cache::NystromVDCache, repeat_step = false)
    (; t, dt, f, p) = integrator
    vel_new, pos_new = integrator.u.x
    vel_old, pos_old = integrator.uprev.x
    (; ks, k, utilde, tmp, atmp, tab) = cache
    (; a, abar, b, bp, btilde, bptilde, c) = tab
    pos_stage = tmp.x[2]
    vel_stage = tmp.x[1]
    first_stage = integrator.fsalfirst.x[1]
    nstages = length(b)
    dtsq = dt^2

    # Stages 2..nstages; stage s lands in ks[s - 1].
    for s in 2:nstages
        node = c[s - 1]
        @.. broadcast=false pos_stage = pos_old + dt * node * vel_old
        @.. broadcast=false vel_stage = vel_old
        for j in 1:(s - 1)
            kj = if j == 1
                first_stage
            else
                ks[j - 1]
            end
            a_sj = a[s, j]
            if !iszero(a_sj)
                @.. broadcast=false pos_stage = pos_stage + dtsq * a_sj * kj
            end
            abar_sj = abar[s, j]
            if !iszero(abar_sj)
                @.. broadcast=false vel_stage = vel_stage + dt * abar_sj * kj
            end
        end
        f.f1(ks[s - 1], vel_stage, pos_stage, p, t + dt * node)
    end

    # New position: pos_old + dt*vel_old + dt^2 * sum(b[i] * stage_i), skipping
    # zero weights.
    @.. broadcast=false pos_new = pos_old + dt * vel_old
    for (i, weight) in enumerate(b)
        iszero(weight) && continue
        ki = (i == 1) ? first_stage : ks[i - 1]
        @.. broadcast=false pos_new = pos_new + dtsq * weight * ki
    end

    # New velocity: vel_old + dt * sum(bp[i] * stage_i), skipping zero weights.
    @.. broadcast=false vel_new = vel_old
    for (i, weight) in enumerate(bp)
        iszero(weight) && continue
        ki = (i == 1) ? first_stage : ks[i - 1]
        @.. broadcast=false vel_new = vel_new + dt * weight * ki
    end

    t_new = t + dt
    f.f1(k.x[1], vel_new, pos_new, p, t_new)
    f.f2(k.x[2], vel_new, pos_new, p, t_new)
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, tab.nf_per_step)
    integrator.stats.nf2 += 1

    # No error estimate unless the integrator is adaptive and the tableau has
    # `btilde` weights.
    integrator.opts.adaptive || return nothing
    isempty(btilde) && return nothing

    has_bptilde = !isempty(bptilde)
    vel_err, pos_err = utilde.x
    @.. broadcast=false pos_err = zero(pos_err)
    @.. broadcast=false vel_err = zero(vel_err)
    for i in 1:nstages
        ki = (i == 1) ? first_stage : ks[i - 1]
        if !iszero(btilde[i])
            @.. broadcast=false pos_err = pos_err + dtsq * btilde[i] * ki
        end
        if has_bptilde && !iszero(bptilde[i])
            @.. broadcast=false vel_err = vel_err + dt * bptilde[i] * ki
        end
    end
    calculate_residuals!(atmp, utilde, integrator.uprev, integrator.u,
        integrator.opts.abstol, integrator.opts.reltol,
        integrator.opts.internalnorm, t)
    return integrator.EEst = integrator.opts.internalnorm(atmp, t)
end

0 commit comments

Comments
 (0)