Skip to content
8 changes: 6 additions & 2 deletions lib/StochasticDiffEqROCK/src/SROCK_utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# This function calculates the largest eigenvalue
# (absolute value wise) by power iteration.
function maxeig!(integrator, cache::StochasticDiffEqConstantCache)
isfirst = integrator.iter == 1 || integrator.derivative_discontinuity
isfirst = integrator.iter == 1 ||
(hasfield(typeof(integrator), :derivative_discontinuity) &&
integrator.derivative_discontinuity)
(; t, dt, uprev, u, p) = integrator
maxiter = 50
safe = 1.2
Expand Down Expand Up @@ -72,7 +74,9 @@ function maxeig!(integrator, cache::StochasticDiffEqConstantCache)
end

function maxeig!(integrator, cache::StochasticDiffEqMutableCache)
isfirst = integrator.iter == 1 || integrator.derivative_discontinuity
isfirst = integrator.iter == 1 ||
(hasfield(typeof(integrator), :derivative_discontinuity) &&
integrator.derivative_discontinuity)
(; t, dt, uprev, u, p) = integrator
fz, z, fsalfirst = cache.atmp, cache.tmp, cache.fsalfirst
integrator.f(fsalfirst, uprev, p, t)
Expand Down
1 change: 1 addition & 0 deletions lib/StochasticDiffEqROCK/src/StochasticDiffEqROCK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import MuladdMacro: @muladd
import SciMLBase

using LinearAlgebra
using Random: rand!
using StaticArrays
using RecursiveArrayTools

Expand Down
6 changes: 3 additions & 3 deletions lib/StochasticDiffEqROCK/src/caches/SROCK_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ function alg_cache(
uᵢ₋₁ = zero(u)
uᵢ₋₂ = zero(u)
Gₛ = zero(noise_rate_prototype)
if (!alg.strong_order_1 || is_diagonal_noise(prob) || ΔW isa Number || length(ΔW) == 1)
if (!alg.strong_order_1 || is_diagonal_noise(prob) || ΔW isa Number)
Gₛ₁ = Gₛ
else
Gₛ₁ = zero(noise_rate_prototype)
Expand Down Expand Up @@ -295,7 +295,7 @@ function alg_cache(
uᵢ₋₁ = zero(u)
uᵢ₋₂ = zero(u)
Gₛ = zero(noise_rate_prototype)
if ΔW isa Number || length(ΔW) == 1 || is_diagonal_noise(prob)
if ΔW isa Number || is_diagonal_noise(prob)
Gₛ₁ = Gₛ
else
Gₛ₁ = zero(noise_rate_prototype)
Expand Down Expand Up @@ -377,7 +377,7 @@ function alg_cache(
Xₛ₋₃ = zero(noise_rate_prototype)
vec_χ = false .* vec(ΔW)
WikRange = false .* vec(ΔW)
if ΔW isa Number || length(ΔW) == 1 || is_diagonal_noise(prob)
if ΔW isa Number || is_diagonal_noise(prob)
Gₛ = Xₛ₋₁
SXₛ₋₁ = utmp
SXₛ₋₂ = utmp
Expand Down
81 changes: 47 additions & 34 deletions lib/StochasticDiffEqROCK/src/perform_step/SROCK_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,12 @@ end
(; recf, recf2, mα, mσ, mτ) = cache

gen_prob = !(
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number) ||
(length(W.dW) == 1)
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number)
)
gen_prob && (vec_χ = 2 .* floor.(false .* W.dW .+ 1 // 2 .+ oftype(W.dW, rand(W.rng, length(W.dW)))) .- true)
if gen_prob
vec_χ = similar(W.dW)
init_χ!(vec_χ, W)
end

alg = unwrap_alg(integrator, true)
alg.eigen_est === nothing ? maxeig!(integrator, cache) : alg.eigen_est(integrator)
Expand Down Expand Up @@ -265,7 +267,7 @@ end
# Now uᵢ₋₂ = uₛ₋₂, uᵢ₋₁ = uₛ₋₁, uᵢ = uₛ
# Similarly tᵢ₋₂ = tₛ₋₂, tᵢ₋₁ = tₛ₋₁, tᵢ = tₛ

if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
Gₛ = integrator.f.g(uᵢ₋₁, p, tᵢ₋₁)
u += Gₛ .* W.dW
Gₛ = integrator.f.g(uᵢ, p, tᵢ)
Expand Down Expand Up @@ -300,7 +302,9 @@ end
for i in 1:length(W.dW)
WikJ = W.dW[i]
WikJ2 = vec_χ[i]
WikRange = 1 // 2 .* (W.dW .* WikJ .- (1:length(W.dW) .== i) .* abs(dt)) #.- (1:length(W.dW) .> i) .* dt .* vec_χ .+ (1:length(W.dW) .< i) .* dt .* WikJ2)
WikRange = 1 // 2 .* (W.dW .* WikJ .- (1:length(W.dW) .== i) .* abs(dt) .-
(1:length(W.dW) .> i) .* abs(dt) .* vec_χ .+
(1:length(W.dW) .< i) .* abs(dt) .* WikJ2)
uₓ = Gₛ * WikRange
WikRange = 1 // 2 .* (1:length(W.dW) .== i)
uᵢ₋₂ = uᵢ + uₓ
Expand Down Expand Up @@ -332,8 +336,7 @@ end
(; recf, recf2, mα, mσ, mτ) = cache.constantcache
ccache = cache.constantcache
gen_prob = !(
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number) ||
(length(W.dW) == 1)
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number)
)

alg = unwrap_alg(integrator, true)
Expand Down Expand Up @@ -362,8 +365,7 @@ end

sqrt_dt = sqrt(abs(dt))
if gen_prob
vec_χ .= 1 // 2 .+ oftype(W.dW, rand(W.rng, length(W.dW)))
@.. vec_χ = 2 * floor(vec_χ) - 1
init_χ!(vec_χ, W)
end

μ = recf[start] # here κ = 0
Expand Down Expand Up @@ -418,7 +420,7 @@ end
# Now uᵢ₋₂ = uₛ₋₂, uᵢ₋₁ = uₛ₋₁, uᵢ = uₛ
# Similarly tᵢ₋₂ = tₛ₋₂, tᵢ₋₁ = tₛ₋₁, tᵢ = tₛ

if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this removed?

integrator.f.g(Gₛ, uᵢ₋₁, p, tᵢ₋₁)
@.. u += Gₛ * W.dW
integrator.f.g(Gₛ, uᵢ, p, tᵢ)
Expand Down Expand Up @@ -458,7 +460,10 @@ end
WikJ2 = vec_χ[i]
dwrange = 1:length(W.dW)
abs_dt = abs(dt)
@.. WikRange = 1 // 2 * (W.dW * WikJ - (dwrange == i) * abs_dt) #+ (dwrange < i) * dt * WikJ2 - (dwrange > i) * dt * vec_χ)
@.. WikRange = 1 // 2 *
(W.dW * WikJ - (dwrange == i) * abs_dt -
(dwrange > i) * abs_dt * vec_χ +
(dwrange < i) * abs_dt * WikJ2)
mul!(uₓ, Gₛ, WikRange)
@.. uᵢ₋₂ = uᵢ + uₓ
@.. WikRange = 1 // 2 * (dwrange == i)
Expand Down Expand Up @@ -542,14 +547,14 @@ end
end

Gₛ = integrator.f.g(u, p, tᵢ)
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
u += Gₛ .* W.dW
else
u += Gₛ * W.dW
end

if integrator.alg.strong_order_1
if (W.dW isa Number) || (length(W.dW) == 1) ||
if (W.dW isa Number) ||
(is_diagonal_noise(integrator.sol.prob))
uᵢ₋₂ = @. 1 // 2 * Gₛ * (W.dW^2 - abs(dt))
tmp = @. u + uᵢ₋₂
Expand Down Expand Up @@ -633,15 +638,15 @@ end
end

integrator.f.g(Gₛ, u, p, tᵢ)
if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
@.. u += Gₛ * W.dW
else
mul!(uᵢ₋₁, Gₛ, W.dW)
u += uᵢ₋₁
end

if integrator.alg.strong_order_1
if (W.dW isa Number) || (length(W.dW) == 1) ||
if (W.dW isa Number) ||
(is_diagonal_noise(integrator.sol.prob))
@.. uᵢ₋₂ = 1 // 2 * Gₛ * (W.dW^2 - abs(dt))
@.. tmp = u + uᵢ₋₂
Expand Down Expand Up @@ -982,7 +987,7 @@ end
end
end

if (W.dW isa Number) || (length(W.dW) == 1)
if (W.dW isa Number)
Gₛ = integrator.f.g(Û₁, p, t̂₁)
uₓ += Gₛ * W.dW

Expand Down Expand Up @@ -1168,7 +1173,7 @@ end
end
end

if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
integrator.f.g(Gₛ, Û₁, p, t̂₁)
@.. uₓ += Gₛ * W.dW

Expand Down Expand Up @@ -1227,8 +1232,7 @@ end
(; recf, mσ, mτ, mδ) = cache

gen_prob = !(
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number) ||
(length(W.dW) == 1)
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number)
)

alg = unwrap_alg(integrator, true)
Expand All @@ -1245,7 +1249,10 @@ end
τ = mτ[deg_index]

sqrt_dt = sqrt(abs(dt))
(gen_prob) && (vec_χ = 2 .* floor.(1 // 2 .+ false .* W.dW .+ rand(length(W.dW))) .- 1)
if gen_prob
vec_χ = similar(W.dW)
init_χ!(vec_χ, W)
end

tᵢ₋₂ = t
uᵢ₋₂ = uprev
Expand Down Expand Up @@ -1289,7 +1296,7 @@ end
tᵢ₋₁ += θₛ₋₃ * (tᵢ₋₁ - tᵢ₋₂)
tᵢ₋₂ = ttmp

if W.dW isa Number || length(W.dW) == 1 || is_diagonal_noise(integrator.sol.prob)
if W.dW isa Number || is_diagonal_noise(integrator.sol.prob)
# stage s-3
yₛ₋₃ = integrator.f(uᵢ₋₁, p, tᵢ₋₁)
utmp = uᵢ₋₁ + μₛ₋₃ * yₛ₋₃
Expand Down Expand Up @@ -1424,15 +1431,14 @@ end
@muladd function perform_step!(integrator, cache::KomBurSROCK2Cache)
(;
utmp, uᵢ₋₁, uᵢ₋₂, k, yₛ₋₁, yₛ₋₂, yₛ₋₃, SXₛ₋₁, SXₛ₋₂,
SXₛ₋₃, Gₛ, Xₛ₋₁, Xₛ₋₂, Xₛ₋₃, vec_χ,
SXₛ₋₃, Gₛ, Xₛ₋₁, Xₛ₋₂, Xₛ₋₃, vec_χ, WikRange,
) = cache
(; t, dt, uprev, u, W, p, f) = integrator
(; recf, mσ, mτ, mδ) = cache.constantcache

ccache = cache.constantcache
gen_prob = !(
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number) ||
(length(W.dW) == 1)
(is_diagonal_noise(integrator.sol.prob)) || (W.dW isa Number)
)

alg = unwrap_alg(integrator, true)
Expand All @@ -1459,7 +1465,9 @@ end
τ = mτ[deg_index]

sqrt_dt = sqrt(abs(dt))
(gen_prob) && (vec_χ .= 2 .* floor.(1 // 2 .+ false .* vec_χ .+ rand(length(vec_χ))) .- 1)
if gen_prob
init_χ!(vec_χ, W)
end

tᵢ₋₂ = t
@.. uᵢ₋₂ = uprev
Expand Down Expand Up @@ -1502,7 +1510,7 @@ end
tᵢ₋₁ += θₛ₋₃ * (tᵢ₋₁ - tᵢ₋₂)
tᵢ₋₂ = ttmp

if W.dW isa Number || length(W.dW) == 1 || is_diagonal_noise(integrator.sol.prob)
if W.dW isa Number || is_diagonal_noise(integrator.sol.prob)
# stage s-3
integrator.f(yₛ₋₃, uᵢ₋₁, p, tᵢ₋₁)
@.. utmp = uᵢ₋₁ + μₛ₋₃ * yₛ₋₃
Expand Down Expand Up @@ -1563,8 +1571,7 @@ end
ttmp = tᵢ₋₂ + C₁
integrator.f.g(Gₛ, utmp, p, ttmp)
WikRange .= 1 .* (1:length(W.dW) .== i)
# @.. @view(Xₛ₋₂[:,i]) = @view(Gₛ[:,i])
@.. Xₛ₋₂ = Gₛ * W.dW
@view(Xₛ₋₂[:, i]) .= @view(Gₛ[:, i])
end
mul!(SXₛ₋₂, Xₛ₋₂, W.dW)
@.. u += μₛ₋₂ * yₛ₋₂ + 3 // 8 * SXₛ₋₂
Expand All @@ -1578,14 +1585,13 @@ end
# @.. utmp = uᵢ₋₁ + μₛ₋₃*yₛ₋₃ + δ₁*yₛ₋₂ - 1//6*W.dW[i]*@view(Xₛ₋₃[:,i]) - 1//2*W.dW[i]*@view(Xₛ₋₂[:,i]) + 1//4*SXₛ₋₃ + 3//4*SXₛ₋₂
@.. utmp = uᵢ₋₁ + μₛ₋₃ * yₛ₋₃ + δ₁ * yₛ₋₂ + 1 // 4 * SXₛ₋₃ + 3 // 4 * SXₛ₋₂
mul!(SXₛ₋₁, Xₛ₋₃, WikRange)
@.. utmp += 1 // 6 * SXₛ₋₁
@.. utmp -= 1 // 6 * SXₛ₋₁
mul!(SXₛ₋₁, Xₛ₋₂, WikRange)
@.. utmp += 1 // 2 * SXₛ₋₁
@.. utmp -= 1 // 2 * SXₛ₋₁
ttmp = tᵢ₋₁ + μₛ₋₃ + δ₁
integrator.f.g(Gₛ, utmp, p, ttmp)
WikRange .= 1 .* (1:length(W.dW) .== i)
# @.. @view(Xₛ₋₁[:,i]) = @view(Gₛ[:,i])
@.. Xₛ₋₁ = Gₛ * WikRange
@view(Xₛ₋₁[:, i]) .= @view(Gₛ[:, i])
end
mul!(SXₛ₋₁, Xₛ₋₁, W.dW)
@.. u += (σ - τ) * dt * yₛ₋₁ + 3 // 8 * SXₛ₋₁
Expand Down Expand Up @@ -1713,7 +1719,7 @@ end
uᵢ₋₂ = integrator.f(uᵢ₋₂, p, tᵢ₋₂)
u += dt * (σ + τ) * uᵢ₋₂

if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
Gₛ = integrator.f.g(uᵢ₋₁, p, tᵢ₋₁)
u += Gₛ .* W.dW

Expand Down Expand Up @@ -1808,7 +1814,7 @@ end
integrator.f(k, uᵢ₋₂, p, tᵢ₋₂)
@.. u += dt * (σ + τ) * k

if (W.dW isa Number) || (length(W.dW) == 1) || is_diagonal_noise(integrator.sol.prob)
if (W.dW isa Number) || is_diagonal_noise(integrator.sol.prob)
integrator.f.g(Gₛ, uᵢ₋₁, p, tᵢ₋₁)
@.. u += Gₛ * W.dW

Expand Down Expand Up @@ -1842,3 +1848,10 @@ end

integrator.u = u
end

function init_χ!(vec_χ, W)
rand!(rng(W), vec_χ)
@.. vec_χ = 2 * floor(vec_χ + 1 // 2) - 1
end

rng(W) = hasfield(typeof(W), :rng) ? W.rng : W.source.rng
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using StochasticDiffEqROCK, DiffEqDevTools, DiffEqNoiseProcess, Test, Random

# Non-diagonal SDE: 2-component system driven by 1 Brownian motion (n=2, m=1)
# du_i = μ u_i dt + σ u_i dW (same scalar Brownian motion for both components)
# Weak solution: E[u_i(T)] = u_i(0) exp(μ T)

const μ_test = -0.5
const σ_test = 0.1

f_nd_oop(u, p, t) = μ_test .* u
g_nd_oop(u, p, t) = σ_test .* reshape(u, length(u), 1)

f_nd_iip!(du, u, p, t) = (du .= μ_test .* u)
function g_nd_iip!(du, u, p, t)
for i in axes(du, 1)
du[i, 1] = σ_test * u[i]
end
end

# Analytic (Itô): u_i(t) = u_i(0) exp((μ - σ²/2)t + σ W(t))
analytic_nd(u0, p, t, W) = u0 .* exp.((μ_test - σ_test^2 / 2) * t .+ σ_test .* W[1])

u0 = [1.0, 1.0]
tspan = (0.0, 1.0)

prob_oop = SDEProblem(
SDEFunction(f_nd_oop, g_nd_oop; analytic = analytic_nd),
u0, tspan;
noise_rate_prototype = zeros(2, 1)
)
prob_iip = SDEProblem(
SDEFunction(f_nd_iip!, g_nd_iip!; analytic = analytic_nd),
u0, tspan;
noise_rate_prototype = zeros(2, 1)
)

Random.seed!(100)
dts = 1 .// 2 .^ (6:-1:2)

@testset "SROCK2 non-diagonal OOP weak convergence (order ~2)" begin
sim = test_convergence(dts, prob_oop, SROCK2(), trajectories = Int(5e4),
save_everystep = false, weak_timeseries_errors = false)
@test abs(sim.𝒪est[:weak_final] - 2.0) < 0.5
end

@testset "SROCK2 non-diagonal IIP weak convergence (order ~2)" begin
sim = test_convergence(dts, prob_iip, SROCK2(), trajectories = Int(5e4),
save_everystep = false, weak_timeseries_errors = false)
@test abs(sim.𝒪est[:weak_final] - 2.0) < 0.5
end

# KomBurSROCK2 smoke test: non-diagonal noise must not crash with DimensionMismatch
@testset "KomBurSROCK2 non-diagonal IIP does not crash" begin
prob_smoke = SDEProblem(f_nd_iip!, g_nd_iip!, u0, tspan;
noise_rate_prototype = zeros(2, 1))
sol = solve(prob_smoke, KomBurSROCK2(), dt = 0.01)
@test sol.retcode == ReturnCode.Success
end

# Smoke test: NoiseWrapper must not crash
@testset "SROCK2 NoiseWrapper does not crash" begin
prob_base = SDEProblem(f_nd_iip!, g_nd_iip!, u0, tspan;
noise_rate_prototype = zeros(2, 1))
sol_base = solve(prob_base, SROCK2(), dt = 0.01, save_noise = true)
W_wrap = NoiseWrapper(sol_base.W)
prob_wrap = SDEProblem(f_nd_iip!, g_nd_iip!, u0, tspan;
noise = W_wrap, noise_rate_prototype = zeros(2, 1))
sol_wrap = solve(prob_wrap, SROCK2(), dt = 0.01)
@test sol_wrap.retcode == ReturnCode.Success
end
6 changes: 6 additions & 0 deletions lib/StochasticDiffEqROCK/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ if TEST_GROUP == "ALL" || TEST_GROUP == "SROCKC2WeakConvergence"
include("weak_convergence/weak_srockc2.jl")
end
end

if TEST_GROUP == "ALL" || TEST_GROUP == "SROCK2NonDiagonalConvergence"
@time @safetestset "Non-Diagonal Noise Convergence Tests (SROCK2, #3188, #3170)" begin
include("convergence/nondiagonal_convergence.jl")
end
end
Loading
Loading