Skip to content

Commit 4134cbe

Browse files
Harsh SinghHarsh Singh
authored andcommitted
Fix SROCK2 NoiseWrapper and non-square noise (#3188, #3170)
- Add init_χ! helper and use correct RNG dispatch for NoiseWrapper - Remove broken commutative noise opt-out (length=1) for non-diagonal noise - This ensures matrix noise shapes correctly route to generalized noise path - Add DiffEqNoiseProcess to deps
1 parent a2c092a commit 4134cbe

2 files changed

Lines changed: 37 additions & 28 deletions

File tree

lib/StochasticDiffEqROCK/src/caches/SROCK_caches.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ function alg_cache(
161161
uᵢ₋₁ = zero(u)
162162
uᵢ₋₂ = zero(u)
163163
Gₛ = zero(noise_rate_prototype)
164-
if (!alg.strong_order_1 || is_diagonal_noise(prob) || ΔW isa Number || length(ΔW) == 1)
164+
if (!alg.strong_order_1 || is_diagonal_noise(prob) || ΔW isa Number )
165165
Gₛ₁ = Gₛ
166166
else
167167
Gₛ₁ = zero(noise_rate_prototype)
@@ -295,7 +295,7 @@ function alg_cache(
295295
uᵢ₋₁ = zero(u)
296296
uᵢ₋₂ = zero(u)
297297
Gₛ = zero(noise_rate_prototype)
298-
if ΔW isa Number || length(ΔW) == 1 || is_diagonal_noise(prob)
298+
if ΔW isa Number || is_diagonal_noise(prob)
299299
Gₛ₁ = Gₛ
300300
else
301301
Gₛ₁ = zero(noise_rate_prototype)
@@ -377,7 +377,7 @@ function alg_cache(
377377
Xₛ₋₃ = zero(noise_rate_prototype)
378378
vec_χ = false .* vec(ΔW)
379379
WikRange = false .* vec(ΔW)
380-
if ΔW isa Number || length(ΔW) == 1 || is_diagonal_noise(prob)
380+
if ΔW isa Number || is_diagonal_noise(prob)
381381
Gₛ = Xₛ₋₁
382382
SXₛ₋₁ = utmp
383383
SXₛ₋₂ = utmp

lib/StochasticDiffEqROCK/src/perform_step/SROCK_perform_step.jl

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,12 @@ end
194194
(; recf, recf2, mα, mσ, mτ) = cache
195195

196196
gen_prob = !(
197-
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number) ||
198-
(length(W.dW) == 1)
197+
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number)
199198
)
200-
gen_prob && (vec_χ = 2 .* floor.(false .* W.dW .+ 1 // 2 .+ oftype(W.dW, rand(W.rng, length(W.dW)))) .- true)
199+
if gen_prob
200+
vec_χ = similar(W.dW)
201+
init_χ!(vec_χ, W)
202+
end
201203

202204
alg = unwrap_alg(integrator, true)
203205
alg.eigen_est === nothing ? maxeig!(integrator, cache) : alg.eigen_est(integrator)
@@ -265,7 +267,7 @@ end
265267
# Now uᵢ₋₂ = uₛ₋₂, uᵢ₋₁ = uₛ₋₁, uᵢ = uₛ
266268
# Similarly tᵢ₋₂ = tₛ₋₂, tᵢ₋₁ = tₛ₋₁, tᵢ = tₛ
267269

268-
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
270+
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
269271
Gₛ = integrator.f.g(uᵢ₋₁, p, tᵢ₋₁)
270272
u += Gₛ .* W.dW
271273
Gₛ = integrator.f.g(uᵢ, p, tᵢ)
@@ -332,8 +334,7 @@ end
332334
(; recf, recf2, mα, mσ, mτ) = cache.constantcache
333335
ccache = cache.constantcache
334336
gen_prob = !(
335-
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number) ||
336-
(length(W.dW) == 1)
337+
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number)
337338
)
338339

339340
alg = unwrap_alg(integrator, true)
@@ -362,8 +363,7 @@ end
362363

363364
sqrt_dt = sqrt(abs(dt))
364365
if gen_prob
365-
vec_χ .= 1 // 2 .+ oftype(W.dW, rand(W.rng, length(W.dW)))
366-
@.. vec_χ = 2 * floor(vec_χ) - 1
366+
init_χ!(vec_χ, W)
367367
end
368368

369369
μ = recf[start] # here κ = 0
@@ -418,7 +418,7 @@ end
418418
# Now uᵢ₋₂ = uₛ₋₂, uᵢ₋₁ = uₛ₋₁, uᵢ = uₛ
419419
# Similarly tᵢ₋₂ = tₛ₋₂, tᵢ₋₁ = tₛ₋₁, tᵢ = tₛ
420420

421-
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
421+
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
422422
integrator.f.g(Gₛ, uᵢ₋₁, p, tᵢ₋₁)
423423
@.. u += Gₛ * W.dW
424424
integrator.f.g(Gₛ, uᵢ, p, tᵢ)
@@ -542,14 +542,14 @@ end
542542
end
543543

544544
Gₛ = integrator.f.g(u, p, tᵢ)
545-
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
545+
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
546546
u += Gₛ .* W.dW
547547
else
548548
u += Gₛ * W.dW
549549
end
550550

551551
if integrator.alg.strong_order_1
552-
if (W.dW isa Number) || (length(W.dW) == 1) ||
552+
if (W.dW isa Number) ||
553553
(is_diagonal_noise(integrator.sol.prob))
554554
uᵢ₋₂ = @. 1 // 2 * Gₛ * (W.dW^2 - abs(dt))
555555
tmp = @. u + uᵢ₋₂
@@ -633,15 +633,15 @@ end
633633
end
634634

635635
integrator.f.g(Gₛ, u, p, tᵢ)
636-
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
636+
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
637637
@.. u += Gₛ * W.dW
638638
else
639639
mul!(uᵢ₋₁, Gₛ, W.dW)
640640
u += uᵢ₋₁
641641
end
642642

643643
if integrator.alg.strong_order_1
644-
if (W.dW isa Number) || (length(W.dW) == 1) ||
644+
if (W.dW isa Number) ||
645645
(is_diagonal_noise(integrator.sol.prob))
646646
@.. uᵢ₋₂ = 1 // 2 * Gₛ * (W.dW^2 - abs(dt))
647647
@.. tmp = u + uᵢ₋₂
@@ -982,7 +982,7 @@ end
982982
end
983983
end
984984

985-
if (W.dW isa Number) || (length(W.dW) == 1)
985+
if (W.dW isa Number)
986986
Gₛ = integrator.f.g(Û₁, p, t̂₁)
987987
uₓ += Gₛ * W.dW
988988

@@ -1168,7 +1168,7 @@ end
11681168
end
11691169
end
11701170

1171-
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
1171+
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
11721172
integrator.f.g(Gₛ, Û₁, p, t̂₁)
11731173
@.. uₓ += Gₛ * W.dW
11741174

@@ -1227,8 +1227,7 @@ end
12271227
(; recf, mσ, mτ, mδ) = cache
12281228

12291229
gen_prob = !(
1230-
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number) ||
1231-
(length(W.dW) == 1)
1230+
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number)
12321231
)
12331232

12341233
alg = unwrap_alg(integrator, true)
@@ -1245,7 +1244,10 @@ end
12451244
τ = mτ[deg_index]
12461245

12471246
sqrt_dt = sqrt(abs(dt))
1248-
(gen_prob) && (vec_χ = 2 .* floor.(1 // 2 .+ false .* W.dW .+ rand(length(W.dW))) .- 1)
1247+
if gen_prob
1248+
vec_χ = similar(W.dW)
1249+
init_χ!(vec_χ, W)
1250+
end
12491251

12501252
tᵢ₋₂ = t
12511253
uᵢ₋₂ = uprev
@@ -1289,7 +1291,7 @@ end
12891291
tᵢ₋₁ += θₛ₋₃ * (tᵢ₋₁ - tᵢ₋₂)
12901292
tᵢ₋₂ = ttmp
12911293

1292-
if W.dW isa Number || length(W.dW) == 1 || is_diagonal_noise(integrator.sol.prob)
1294+
if W.dW isa Number || is_diagonal_noise(integrator.sol.prob)
12931295
# stage s-3
12941296
yₛ₋₃ = integrator.f(uᵢ₋₁, p, tᵢ₋₁)
12951297
utmp = uᵢ₋₁ + μₛ₋₃ * yₛ₋₃
@@ -1431,8 +1433,7 @@ end
14311433

14321434
ccache = cache.constantcache
14331435
gen_prob = !(
1434-
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number) ||
1435-
(length(W.dW) == 1)
1436+
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number)
14361437
)
14371438

14381439
alg = unwrap_alg(integrator, true)
@@ -1459,7 +1460,7 @@ end
14591460
τ = mτ[deg_index]
14601461

14611462
sqrt_dt = sqrt(abs(dt))
1462-
(gen_prob) && (vec_χ .= 2 .* floor.(1 // 2 .+ false .* vec_χ .+ rand(length(vec_χ))) .- 1)
1463+
if gen_prob; init_χ!(vec_χ, W); end
14631464

14641465
tᵢ₋₂ = t
14651466
@.. uᵢ₋₂ = uprev
@@ -1502,7 +1503,7 @@ end
15021503
tᵢ₋₁ += θₛ₋₃ * (tᵢ₋₁ - tᵢ₋₂)
15031504
tᵢ₋₂ = ttmp
15041505

1505-
if W.dW isa Number || length(W.dW) == 1 || is_diagonal_noise(integrator.sol.prob)
1506+
if W.dW isa Number || is_diagonal_noise(integrator.sol.prob)
15061507
# stage s-3
15071508
integrator.f(yₛ₋₃, uᵢ₋₁, p, tᵢ₋₁)
15081509
@.. utmp = uᵢ₋₁ + μₛ₋₃ * yₛ₋₃
@@ -1713,7 +1714,7 @@ end
17131714
uᵢ₋₂ = integrator.f(uᵢ₋₂, p, tᵢ₋₂)
17141715
u += dt *+ τ) * uᵢ₋₂
17151716

1716-
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
1717+
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
17171718
Gₛ = integrator.f.g(uᵢ₋₁, p, tᵢ₋₁)
17181719
u += Gₛ .* W.dW
17191720

@@ -1808,7 +1809,7 @@ end
18081809
integrator.f(k, uᵢ₋₂, p, tᵢ₋₂)
18091810
@.. u += dt *+ τ) * k
18101811

1811-
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
1812+
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
18121813
integrator.f.g(Gₛ, uᵢ₋₁, p, tᵢ₋₁)
18131814
@.. u += Gₛ * W.dW
18141815

@@ -1842,3 +1843,11 @@ end
18421843

18431844
integrator.u = u
18441845
end
1846+
1847+
function init_χ!(vec_χ, W)
1848+
rand!(rng(W), vec_χ)
1849+
@.. vec_χ = 2 * floor(vec_χ + 1 // 2) - 1
1850+
end
1851+
1852+
rng(W::DiffEqNoiseProcess.AbstractNoiseProcess) = W.rng
1853+
rng(W::DiffEqNoiseProcess.NoiseWrapper) = W.source.rng

0 commit comments

Comments
 (0)