Skip to content

Commit 887a34f

Browse files
Harsh SinghHarsh Singh
authored andcommitted
Fix init_χ! to be type-correct and GPU-compatible
- Add Random to [deps] so rand! is explicitly available so it works correctly for non-Float64 arrays and GPU arrays
1 parent 157e080 commit 887a34f

3 files changed

Lines changed: 5 additions & 4 deletions

File tree

lib/StochasticDiffEqROCK/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "1.1.0"
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
10+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1011
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
1112
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1213
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -31,6 +32,7 @@ LinearAlgebra = "1.10"
3132
MuladdMacro = "0.2.4"
3233
OrdinaryDiffEqCore = "3.22"
3334
Pkg = "1"
35+
Random = "<0.0.1, 1"
3436
RecursiveArrayTools = "2, 3"
3537
Reexport = "0.2, 1.0"
3638
SciMLBase = "2.146"

lib/StochasticDiffEqROCK/src/StochasticDiffEqROCK.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import MuladdMacro: @muladd
2020
import SciMLBase
2121

2222
using LinearAlgebra
23+
using Random: rand!
2324
using StaticArrays
2425
using RecursiveArrayTools
2526

lib/StochasticDiffEqROCK/src/perform_step/SROCK_perform_step.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,10 +1847,8 @@ end
18471847
end
18481848

18491849
function init_χ!(vec_χ, W)
1850-
r = rng(W)
1851-
for i in eachindex(vec_χ)
1852-
vec_χ[i] = 2 * floor(rand(r) + 0.5) - 1
1853-
end
1850+
rand!(rng(W), vec_χ)
1851+
@.. vec_χ = 2 * floor(vec_χ + 1 // 2) - 1
18541852
end
18551853

18561854
rng(W) = hasfield(typeof(W), :rng) ? W.rng : W.source.rng

0 commit comments

Comments
 (0)