diff --git a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl index e6e9a07d884..e8cb7cf5e1f 100644 --- a/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl +++ b/lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl @@ -23,7 +23,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, get_fsalfirstlast, generic_solver_docstring, _bool_to_ADType, _process_AD_choice, _ode_interpolant, _ode_interpolant!, has_stiff_interpolation, - _ode_addsteps!, DerivativeOrderNotPossibleError + _ode_addsteps!, DerivativeOrderNotPossibleError, find_algebraic_vars_eqs using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache using TruncatedStacktraces: @truncate_stacktrace diff --git a/lib/OrdinaryDiffEqBDF/src/bdf_caches.jl b/lib/OrdinaryDiffEqBDF/src/bdf_caches.jl index 80b92a6f70c..5d71d3995ea 100644 --- a/lib/OrdinaryDiffEqBDF/src/bdf_caches.jl +++ b/lib/OrdinaryDiffEqBDF/src/bdf_caches.jl @@ -62,7 +62,7 @@ function alg_cache( atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) algebraic_vars = f.mass_matrix === I ? nothing : - [all(iszero, x) for x in eachcol(f.mass_matrix)] + find_algebraic_vars_eqs(f.mass_matrix)[1] eulercache = ImplicitEulerCache( u, uprev, uprev2, fsalfirst, atmp, nlsolver, algebraic_vars, alg.step_limiter! diff --git a/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreSparseArraysExt.jl b/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreSparseArraysExt.jl index 6b57562971f..6a311b0792b 100644 --- a/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreSparseArraysExt.jl +++ b/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreSparseArraysExt.jl @@ -1,7 +1,7 @@ module OrdinaryDiffEqCoreSparseArraysExt using SparseArrays: SparseMatrixCSC -import OrdinaryDiffEqCore: _isdiag +import OrdinaryDiffEqCore: _isdiag, find_algebraic_vars_eqs # Efficient O(nnz) isdiag check for sparse matrices. # Standard isdiag is O(n²) which is prohibitively slow for large sparse matrices. @@ -22,4 +22,28 @@ function _isdiag(A::SparseMatrixCSC) return true end +""" + find_algebraic_vars_eqs(M::SparseMatrixCSC) + +O(nnz) detection of algebraic variables (zero columns) and equations (zero rows). +""" +function find_algebraic_vars_eqs(M::SparseMatrixCSC) + n_cols = size(M, 2) + n_rows = size(M, 1) + + algebraic_vars = fill(true, n_cols) + algebraic_eqs = fill(true, n_rows) + + @inbounds for j in 1:n_cols + for idx in M.colptr[j]:(M.colptr[j + 1] - 1) + if !iszero(M.nzval[idx]) + algebraic_vars[j] = false + algebraic_eqs[M.rowval[idx]] = false + end + end + end + + return algebraic_vars, algebraic_eqs +end + end diff --git a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl index fbfd8a743b7..d774a540d2d 100644 --- a/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl +++ b/lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl @@ -15,7 +15,7 @@ import Logging: @logmsg, LogLevel using MuladdMacro: @muladd -using LinearAlgebra: opnorm, I, UniformScaling, diag, rank, isdiag +using LinearAlgebra: opnorm, I, UniformScaling, diag, rank, isdiag, Diagonal import PrecompileTools diff --git a/lib/OrdinaryDiffEqCore/src/misc_utils.jl b/lib/OrdinaryDiffEqCore/src/misc_utils.jl index a5c672cba95..75c75f3d636 100644 --- a/lib/OrdinaryDiffEqCore/src/misc_utils.jl +++ b/lib/OrdinaryDiffEqCore/src/misc_utils.jl @@ -131,6 +131,30 @@ end # Sparse specialization is provided in OrdinaryDiffEqCoreSparseArraysExt _isdiag(A::AbstractMatrix) = isdiag(A) +""" + find_algebraic_vars_eqs(M) + +Find algebraic variables (zero columns) and algebraic equations (zero rows) from mass matrix. +Returns `(algebraic_vars, algebraic_eqs)` as boolean arrays (true = algebraic). + +Works on CPU and GPU arrays. Sparse specialization (O(nnz)) is provided in +OrdinaryDiffEqCoreSparseArraysExt. +""" +function find_algebraic_vars_eqs(M::Diagonal) + _idxs = map(iszero, diag(M)) + return _idxs, _idxs +end + +function find_algebraic_vars_eqs(M::AbstractMatrix) + algebraic_vars = vec(all(iszero, M, dims = 1)) + algebraic_eqs = vec(all(iszero, M, dims = 2)) + return algebraic_vars, algebraic_eqs +end + +function find_algebraic_vars_eqs(M::AbstractSciMLOperator) + return find_algebraic_vars_eqs(convert(AbstractMatrix, M)) +end + isnewton(::Any) = false function _bool_to_ADType(::Val{true}, ::Val{CS}, _) where {CS} diff --git a/lib/OrdinaryDiffEqNonlinearSolve/test/sparse_algebraic_detection_tests.jl b/lib/OrdinaryDiffEqCore/test/algebraic_vars_detection_tests.jl similarity index 86% rename from lib/OrdinaryDiffEqNonlinearSolve/test/sparse_algebraic_detection_tests.jl rename to lib/OrdinaryDiffEqCore/test/algebraic_vars_detection_tests.jl index 894b87af451..0c88682844c 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/test/sparse_algebraic_detection_tests.jl +++ b/lib/OrdinaryDiffEqCore/test/algebraic_vars_detection_tests.jl @@ -1,6 +1,7 @@ using Test using SparseArrays -using OrdinaryDiffEqNonlinearSolve: find_algebraic_vars_eqs +using OrdinaryDiffEqCore: find_algebraic_vars_eqs +using LinearAlgebra @testset "Sparse Algebraic Detection Performance" begin # Test 1: Correctness - results should match between sparse and dense methods @@ -79,4 +80,14 @@ using OrdinaryDiffEqNonlinearSolve: find_algebraic_vars_eqs @test vars == [false, true] @test eqs == [false, true] end + + # Test 4: Test Diagonal case + @testset "Test Diagonal cast" begin + M_diag = Diagonal([1.0, 0.0, 1.0, 1.0, 0.0]) + vars, eqs = find_algebraic_vars_eqs(M_diag) + @test vars == [false, true, false, false, true] + @test eqs == [false, true, false, false, true] + # compare to dense + @test find_algebraic_vars_eqs(M_diag) == find_algebraic_vars_eqs(collect(M_diag)) + end end diff --git a/lib/OrdinaryDiffEqCore/test/runtests.jl b/lib/OrdinaryDiffEqCore/test/runtests.jl index 2a9ad4e8615..f5a294c3d20 100644 --- a/lib/OrdinaryDiffEqCore/test/runtests.jl +++ b/lib/OrdinaryDiffEqCore/test/runtests.jl @@ -24,4 +24,5 @@ end # Functional tests if TEST_GROUP == "Core" || TEST_GROUP == "ALL" @time @safetestset "Sparse isdiag Performance" include("sparse_isdiag_tests.jl") + @time @safetestset "Algebraic Vars Detection" include("algebraic_vars_detection_tests.jl") end diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 51608029b91..a976e2e28bf 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -14,7 +14,7 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step) tf.p = p alg = unwrap_alg(integrator, true) - autodiff_alg = ADTypes.dense_ad(gpu_safe_autodiff(alg_autodiff(alg), u)) + autodiff_alg = gpu_safe_autodiff(ADTypes.dense_ad(alg_autodiff(alg)), u) # Convert t to eltype(dT) if using ForwardDiff, to make FunctionWrappers work t = autodiff_alg isa AutoForwardDiff ? convert(eltype(dT), t) : t @@ -59,7 +59,7 @@ function calc_tderivative(integrator, cache) tf.u = uprev tf.p = p - autodiff_alg = ADTypes.dense_ad(gpu_safe_autodiff(alg_autodiff(alg), u)) + autodiff_alg = gpu_safe_autodiff(ADTypes.dense_ad(alg_autodiff(alg)), u) if alg_autodiff isa AutoFiniteDiff autodiff_alg = SciMLBase.@set autodiff_alg.dir = diffdir(integrator) @@ -403,6 +403,12 @@ function jacobian2W!( else @.. broadcast = false @view(W[idxs]) = muladd(λ, invdtgamma, @view(J[idxs])) end + elseif is_sparse(W) && !ArrayInterface.fast_scalar_indexing(nonzeros(W)) + # Sparse GPU arrays (e.g. CuSparseMatrixCSC/CSR) don't support broadcasting. + # ArrayInterface.fast_scalar_indexing is not specialized for AbstractGPUSparseArray, + # so we detect them by checking if the underlying nonzeros storage is a GPU array. + # we then fall back to allocating matrix arithmetic + copyto!(W, J - invdtgamma * mass_matrix) else @.. broadcast = false W = muladd(-mass_matrix, invdtgamma, J) end diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl index 60cd480372f..f72e8189de2 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl @@ -50,7 +50,7 @@ using OrdinaryDiffEqCore: resize_nlsolver!, _initialize_dae!, FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence, NLStatus, MethodType, alg_order, error_constant, - alg_extrapolates, resize_J_W!, has_autodiff + alg_extrapolates, resize_J_W!, has_autodiff, find_algebraic_vars_eqs import OrdinaryDiffEqCore: _initialize_dae!, _default_dae_init!, isnewton, get_W, isfirstcall, isfirststage, diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl index bee55d5f6b6..b2c4ba9cc5e 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl @@ -1,46 +1,3 @@ -# Efficient algebraic variable/equation detection for sparse mass matrices. -# O(nnz) instead of O(n²) for sparse matrices. -""" - find_algebraic_vars_eqs(M::SparseMatrixCSC) - -Find algebraic variables (zero columns) and algebraic equations (zero rows) from mass matrix. -Returns (algebraic_vars::Vector{Bool}, algebraic_eqs::Vector{Bool}). - -For sparse matrices, uses O(nnz) traversal of CSC structure instead of O(n²) iteration. -""" -function find_algebraic_vars_eqs(M::SparseMatrixCSC) - n_cols = size(M, 2) - n_rows = size(M, 1) - - # Initialize all as algebraic (true = zero column/row) - algebraic_vars = fill(true, n_cols) - algebraic_eqs = fill(true, n_rows) - - # Mark columns/rows with non-zero values as differential (false) - @inbounds for j in 1:n_cols - for idx in M.colptr[j]:(M.colptr[j + 1] - 1) - if !iszero(M.nzval[idx]) - algebraic_vars[j] = false - algebraic_eqs[M.rowval[idx]] = false - end - end - end - - return algebraic_vars, algebraic_eqs -end - -# Fallback for non-sparse matrices (original behavior) -function find_algebraic_vars_eqs(M::AbstractMatrix) - algebraic_vars = vec(all(iszero, M, dims = 1)) - algebraic_eqs = vec(all(iszero, M, dims = 2)) - return algebraic_vars, algebraic_eqs -end - -# Handle SciMLOperators (e.g., MatrixOperator) by converting to matrix -function find_algebraic_vars_eqs(M::AbstractSciMLOperator) - return find_algebraic_vars_eqs(convert(AbstractMatrix, M)) -end - # Optimized tolerance checking that avoids allocations @inline function check_dae_tolerance(integrator, err, abstol, t, ::Val{true}) if abstol isa Number diff --git a/lib/OrdinaryDiffEqNonlinearSolve/test/runtests.jl b/lib/OrdinaryDiffEqNonlinearSolve/test/runtests.jl index d2f6725597c..0839c157818 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/test/runtests.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/test/runtests.jl @@ -20,7 +20,6 @@ end # Run functional tests if TEST_GROUP ∉ ("QA", "ModelingToolkit") @time @safetestset "Newton Tests" include("newton_tests.jl") - @time @safetestset "Sparse Algebraic Detection" include("sparse_algebraic_detection_tests.jl") @time @safetestset "Sparse DAE Initialization" include("sparse_dae_initialization_tests.jl") @time @safetestset "Linear Nonlinear Solver Tests" include("linear_nonlinear_tests.jl") @time @safetestset "Linear Solver Tests" include("linear_solver_tests.jl") diff --git a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl index 23913546d00..9b269d923e8 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/OrdinaryDiffEqRosenbrock.jl @@ -13,7 +13,8 @@ import OrdinaryDiffEqCore: alg_order, alg_adaptive_order, isWmethod, isfsal, _un calculate_residuals, has_stiff_interpolation, ODEIntegrator, resize_non_user_cache!, _ode_addsteps!, full_cache, DerivativeOrderNotPossibleError, _bool_to_ADType, - _process_AD_choice, LinearAliasSpecifier, copyat_or_push! + _process_AD_choice, LinearAliasSpecifier, copyat_or_push!, + find_algebraic_vars_eqs using MuladdMacro, FastBroadcast, RecursiveArrayTools import MacroTools: namify using MacroTools: @capture diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index d99fd25496c..7db92ffbb8c 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -179,7 +179,7 @@ function alg_cache( ) algebraic_vars = f.mass_matrix === I ? nothing : - [all(iszero, x) for x in eachcol(f.mass_matrix)] + find_algebraic_vars_eqs(f.mass_matrix)[1] return Rosenbrock23Cache( u, uprev, k₁, k₂, k₃, du1, du2, f₁, @@ -239,7 +239,7 @@ function alg_cache( ) algebraic_vars = f.mass_matrix === I ? nothing : - [all(iszero, x) for x in eachcol(f.mass_matrix)] + find_algebraic_vars_eqs(f.mass_matrix)[1] return Rosenbrock32Cache( u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W, diff --git a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl index 37eb64be2c2..c22a23b8bd3 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/OrdinaryDiffEqSDIRK.jl @@ -14,7 +14,8 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!, trivial_limiter!, _ode_interpolant!, isesdirk, issplit, ssp_coefficient, get_fsalfirstlast, generic_solver_docstring, - _bool_to_ADType, _process_AD_choice, current_extrapolant! + _bool_to_ADType, _process_AD_choice, current_extrapolant!, + find_algebraic_vars_eqs using TruncatedStacktraces: @truncate_stacktrace using MuladdMacro, MacroTools, FastBroadcast, RecursiveArrayTools using SciMLBase: SplitFunction diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl index 978120e3b86..70f63fd4708 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl @@ -35,7 +35,7 @@ function alg_cache( recursivefill!(atmp, false) algebraic_vars = f.mass_matrix === I ? nothing : - [all(iszero, x) for x in eachcol(f.mass_matrix)] + find_algebraic_vars_eqs(f.mass_matrix)[1] return ImplicitEulerCache( u, uprev, uprev2, fsalfirst, atmp, nlsolver, algebraic_vars, alg.step_limiter! diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml index aecabf6d1aa..478a27e51ad 100644 --- a/test/gpu/Project.toml +++ b/test/gpu/Project.toml @@ -1,16 +1,19 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" -FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" OrdinaryDiffEqBDF = "6ad6398a-0878-4a85-9266-38940aa047c8" +OrdinaryDiffEqFIRK = "5960d6e9-dd7a-4743-88e7-cf307b64f125" OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8" OrdinaryDiffEqRKIP = "a4daff8c-1d43-4ff3-8eff-f78720aeecdc" OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce" +OrdinaryDiffEqSDIRK = "2d112036-d095-4a1e-ab9a-08536f3ecdbf" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" @@ -23,17 +26,20 @@ OrdinaryDiffEqRosenbrock = {path = "../../lib/OrdinaryDiffEqRosenbrock"} [compat] Adapt = "4" -CUDA = "4, 5" +CUDA = "5" +CUDSS = "0.6.7" ComponentArrays = "0.15" DiffEqBase = "6.182" -FastBroadcast = "0.3" FFTW = "1.8" +FastBroadcast = "0.3" FillArrays = "1" OrdinaryDiffEq = "6" OrdinaryDiffEqBDF = "1" +OrdinaryDiffEqFIRK = "1" OrdinaryDiffEqNonlinearSolve = "1" OrdinaryDiffEqRKIP = "1" OrdinaryDiffEqRosenbrock = "1" +OrdinaryDiffEqSDIRK = "1" RecursiveArrayTools = "3" SciMLBase = "2.99" SciMLOperators = "1.3" diff --git a/test/gpu/simple_dae.jl b/test/gpu/simple_dae.jl index 1b39a42cd1a..44803183385 100644 --- a/test/gpu/simple_dae.jl +++ b/test/gpu/simple_dae.jl @@ -1,10 +1,17 @@ using OrdinaryDiffEqRosenbrock +using OrdinaryDiffEqSDIRK +using OrdinaryDiffEqBDF +using OrdinaryDiffEqFIRK using OrdinaryDiffEqNonlinearSolve using CUDA using LinearAlgebra using Adapt using SparseArrays using Test +using CUDSS +using Printf +using OrdinaryDiffEqNonlinearSolve.LinearSolve: KrylovJL_GMRES + #= du[1] = -u[1] @@ -24,44 +31,266 @@ p = [ -1 1 0 -1 ] -# mass_matrix = [1 0 0 0 -# 0 1 0 0 -# 0 0 0 0 -# 0 0 0 0] mass_matrix = Diagonal([1, 1, 0, 0]) jac_prototype = sparse(map(x -> iszero(x) ? 0.0 : 1.0, p)) u0 = [1.0, 1.0, 0.5, 0.5] # force init -odef = ODEFunction(dae!, mass_matrix = mass_matrix, jac_prototype = jac_prototype) - tspan = (0.0, 5.0) + +# CPU reference solution (Rodas5P) +odef = ODEFunction(dae!, mass_matrix = mass_matrix, jac_prototype = jac_prototype) prob = ODEProblem(odef, u0, tspan, p) -sol = solve(prob, Rodas5P()) - -# gpu version -mass_matrix_d = adapt(CuArray, mass_matrix) - -# TODO: jac_prototype fails -# jac_prototype_d = adapt(CuArray, jac_prototype) -# jac_prototype_d = CUDA.CUSPARSE.CuSparseMatrixCSR(jac_prototype) -jac_prototype_d = nothing - -u0_d = adapt(CuArray, u0) -p_d = adapt(CuArray, p) -odef_d = ODEFunction(dae!, mass_matrix = mass_matrix_d, jac_prototype = jac_prototype_d) -prob_d = ODEProblem(odef_d, u0_d, tspan, p_d) -sol_d = solve(prob_d, Rodas5P()) - -@testset "Test constraints in GPU sol" begin - for t in sol_d.t - u = Vector(sol_d(t)) - @test isapprox(u[1] + u[2], u[3]; atol = 1.0e-6) - @test isapprox(-u[1] + u[2], u[4]; atol = 1.0e-6) - end +sol_ref = solve(prob, Rodas5P()) +sol_ref_krylov = solve(prob, Rodas5P(linsolve = KrylovJL_GMRES())) + +# GPU data -- we use F64 for higher accuracy for comparison +u0_d = adapt(CuArray{Float64}, u0) +p_d = adapt(CuArray{Float64}, p) + +# dense or sparse mass matrix does not work yet! +mass_matrix_d = cu(mass_matrix) + +# ── Jacobian prototype options ──────────────────────────────────────────────── +jac_prots = [ + "none" => nothing, + "CSC" => CUDA.CUSPARSE.CuSparseMatrixCSC(jac_prototype), + "CSR" => CUDA.CUSPARSE.CuSparseMatrixCSR(jac_prototype), +] + +# ── Solver definitions ──────────────────────────────────────────────────────── +# +# Each entry: SolverType => (; tol overrides...) +# method_tol = (; atol, rtol) – cpu solver vs Rodas5P ref (all jac); some methods are poor fits for DAEs +# method_csc_tol = (; atol, rtol) – cpu solver vs Rodas5P ref (CSC/Krylov path only) +# gpu_tol = (; atol, rtol) – gpu vs cpu (none/CSR fallback) +# csc_tol = (; atol, rtol) – gpu vs cpu (CSC jac_prot) +# csr_tol = (; atol, rtol) – gpu vs cpu (CSR jac_prot) +# Only specify fields that deviate from the defaults. + +const DEFAULT_METHOD_TOL = (; atol = 1.0e-2, rtol = 1.0e-2) +const DEFAULT_GPU_TOL = (; atol = 1.0e-4, rtol = 1.0e-4) + +solvers = [ + # ── Rosenbrock 2nd order ── + Rosenbrock23 => (; + csc_tol = (; atol = 2.0e-4, rtol = 5.0e-4), + ), + Rosenbrock32 => (; + method_tol = (; atol = 2.0e-2, rtol = 1.5), + csc_tol = (; atol = Inf, rtol = Inf), + ), + ROS2 => (;), + ROS2PR => (; + csc_tol = (; atol = 3.0e-4, rtol = 4.0e-4), + ), + ROS2S => (; + csc_tol = (; atol = 7.0e-4, rtol = 1.2e-3), + ), + # ── Rosenbrock 3rd order ── + ROS3 => (;), + ROS3PR => (; + method_tol = (; atol = 0.25, rtol = 15.0), + ), + ROS3PRL => (; + csc_tol = (; atol = 2.0e-3, rtol = 4.0e-3), + ), + ROS3PRL2 => (; + csc_tol = (; atol = 3.0e-3, rtol = 4.0e-3), + ), + ROS3P => (; + method_tol = (; atol = 0.25, rtol = 15.0), + ), + Rodas3 => (; + csc_tol = (; atol = 2.0e-3, rtol = 3.0e-3), + ), + # Rodas23W() # scalar indexing, requires large changes to `calculate_interpoldiff!` + # Rodas3P() # scalar indexing + Scholz4_7 => (;), + # ── Rosenbrock 4th order ── + ROS34PW1a => (; + method_tol = (; atol = 0.25, rtol = 5.0), + ), + ROS34PW1b => (; + method_tol = (; atol = 0.25, rtol = 5.0), + ), + ROS34PW2 => (; + csc_tol = (; atol = 1.2e-3, rtol = 2.2e-3), + ), + ROS34PW3 => (;), + ROS34PRw => (; + csc_tol = (; atol = 6.0e-4, rtol = 1.0e-3), + ), + RosShamp4 => (;), + Veldd4 => (;), + Velds4 => (;), + GRK4T => (;), + GRK4A => (;), + Ros4LStab => (;), + Rodas4 => (;), + Rodas42 => (;), + Rodas4P => (;), + Rodas4P2 => (;), + ROK4a => (;), + # ── Rosenbrock 5th order ── + Rodas5 => (;), + Rodas5P => (;), + Rodas5Pe => (;), + Rodas5Pr => (;), + # ── Rosenbrock 6th order ── + Rodas6P => (;), + # ── SDIRK (don't include fixed time step which need explicit dt) ── + ImplicitEuler => (;), + Trapezoid => (; + csc_tol = (; atol = 8.0e-3, rtol = 2.5e-2), + ), + SDIRK2 => (; + csc_tol = (; atol = 1.5e-4, rtol = 3.5e-4), + ), + Cash4 => (; + csc_tol = (; atol = 5.0e-3, rtol = 8.0e-3), + ), + Hairer4 => (; + csc_tol = (; atol = 8.0e-3, rtol = 6.0e-2), + ), + Hairer42 => (; + csc_tol = (; atol = 7.0e-3, rtol = 5.0e-2), + ), + # ── BDF ── + ABDF2 => (; + method_csc_tol = (; atol = Inf, rtol = Inf), # ABDF2 + Krylov diverges vs Rodas5P + Krylov + csc_tol = (; atol = 2.0e-4, rtol = 7.0e-4), + ), + QNDF1 => (;), + QNDF2 => (; + csc_tol = (; atol = 4.0e-3, rtol = 5.0e-3), + ), + # QNDF() # 🔧 DeviceMemory error in LinAlg + QBDF1 => (; + method_tol = (; atol = 2.0e-2, rtol = 0.2), + ), + QBDF2 => (; + gpu_tol = (; atol = 2.0e-3, rtol = 2.0e-3), + csc_tol = (; atol = 4.0e-3, rtol = 7.0e-3), + csr_tol = (; atol = 2.0e-3, rtol = 2.0e-3), + ), + # QBDF() # DeviceMemory error in LinAlg + # FBDF() # scalar indexing -> needs extensive work on reinitFBDF! + # ── FIRK -> all need substantial changes to `perform_step!` for FIRK methods ── + # RadauIIA3() # scalar indexing, ComplexF64 sparse unsupported + # RadauIIA5() # scalar indexing, ComplexF64 sparse unsupported + # RadauIIA9() # scalar indexing, ComplexF64 sparse unsupported + # AdaptiveRadau() # scalar indexing, ComplexF64 sparse unsupported +] + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +function _get_method_tol(overrides, jac_name) + jac_name == "CSC" && hasproperty(overrides, :method_csc_tol) && return overrides.method_csc_tol + hasproperty(overrides, :method_tol) && return overrides.method_tol + return DEFAULT_METHOD_TOL +end + +function _get_tol(overrides, jac_name) + _tol_keys = Dict("none" => :gpu_tol, "CSC" => :csc_tol, "CSR" => :csr_tol) + key = _tol_keys[jac_name] + hasproperty(overrides, key) && return getproperty(overrides, key) + # gpu_tol acts as default for all GPU combos + hasproperty(overrides, :gpu_tol) && return overrides.gpu_tol + return DEFAULT_GPU_TOL end -@testset "Compare GPU to CPU solution" begin +function maxerrs(sol_a, sol_b) + max_abs = 0.0 + max_rel = 0.0 for t in tspan[begin]:0.1:tspan[end] - @test isapprox(Vector(sol_d(t)), sol(t); rtol = 1.0e-4) + a = Vector(sol_a(t)) + b = Vector(sol_b(t)) + diff = abs.(a - b) + ref = abs.(b) + max_abs = max(max_abs, maximum(diff)) + max_rel = max(max_rel, maximum(diff ./ max.(ref, eps()))) end + return max_abs, max_rel +end + +function run_dae_tests() + results = Any[] + + for (sv, overrides) in solvers, (jn, jp) in jac_prots + println("Test $sv with prototype $jn") + sn = string(sv) + gtol = _get_tol(overrides, jn) + mtol = _get_method_tol(overrides, jn) + + # CSC will always fall back to krylov so the ref solution should do to + krylov = (jn == "CSC") + _sol_ref = krylov ? sol_ref_krylov : sol_ref + + # CPU: this solver vs reference + cpu_alg = krylov ? sv(linsolve = KrylovJL_GMRES()) : sv() + sol_cpu = solve(prob, cpu_alg) + cpu_abs, cpu_rel = maxerrs(sol_cpu, _sol_ref) + + # GPU solve + odef_d = ODEFunction(dae!, mass_matrix = mass_matrix_d, jac_prototype = jp) + prob_d = ODEProblem(odef_d, u0_d, tspan, p_d) + sol_d = solve(prob_d, sv()) + + # GPU vs CPU (same solver) + gpu_abs, gpu_rel = maxerrs(sol_d, sol_cpu) + + method_passed = (cpu_abs < mtol.atol) || (cpu_rel < mtol.rtol) + gpu_passed = (gpu_abs < gtol.atol) || (gpu_rel < gtol.rtol) + passed = method_passed && gpu_passed + push!( + results, (; + solver = sn, jac = jn, cpu_abs, cpu_rel, gpu_abs, gpu_rel, + gtol, mtol, method_passed, gpu_passed, passed, error = "", + ) + ) + @test passed + end + + return results +end + +function show_results(results) + function _fmt(val; threshold = 1.0e-3) + isnan(val) && return " --- " + s = @sprintf("%.2e", val) + if val > threshold + return "\e[31m$s\e[0m" + end + return s + end + + println(rpad("Solver / Jac", 30), "cpu_abs cpu_rel gpu_abs gpu_rel status") + println("-"^85) + + for r in results + label = rpad("$(r.solver) / $(r.jac)", 30) + if r.error != "" + printstyled(label, "ERROR: ", r.error, "\n"; color = :yellow) + else + print(label) + print(_fmt(r.cpu_abs, threshold = 1.0e-2), " ", _fmt(r.cpu_rel, threshold = 1.0e-2), " ") + print(_fmt(r.gpu_abs), " ", _fmt(r.gpu_rel), " ") + if r.passed + printstyled("✓\n"; color = :green) + elseif !r.method_passed && !r.gpu_passed + printstyled("✗ method+gpu\n"; color = :red) + elseif !r.method_passed + printstyled("✗ method\n"; color = :red) + else + printstyled("✗ gpu\n"; color = :red) + end + end + end + return println("-"^85) +end + +@testset "GPU DAE solver compatibility" begin + global results + results = run_dae_tests() end +show_results(results)