Skip to content

Commit ab3b340

Browse files
Add promote_eltype for CuArray / ROCArray / MtlArray
`ArrayInterface.promote_eltype` only had a method for plain `Array{T, N}` (with an explicit "no generic fallback is given" note in the docstring). Downstream packages that pass GPU array types through `promote_eltype` therefore hit a `MethodError` — for example, SciML/NonlinearSolve.jl#910 tripped this on `test/cuda_tests.jl:33 "GeneralizedFirstOrderAlgorithm"` when deriving a Dual-eltype wrapper-signature array type for `CuArray{Float32}`: MethodError: no method matching promote_eltype( ::Type{CuArray{Float32, 1, CUDA.DeviceMemory}}, ::Type{ForwardDiff.Dual{Tag{NonlinearSolveBase.NonlinearSolveTag, Float32}, Float32, 1}}) Adds the obvious eltype-swapping method in each GPU extension, preserving the non-eltype type parameters (`M` for `CuArray` memory kind, `B` for `ROCArray` buffer type, `S` for `MtlArray` storage mode): ArrayInterface.promote_eltype( ::Type{<:CuArray{T, N, M}}, ::Type{T2} ) where {T, N, M, T2} = CuArray{promote_type(T, T2), N, M} ArrayInterface.promote_eltype( ::Type{<:ROCArray{T, N, B}}, ::Type{T2} ) where {T, N, B, T2} = ROCArray{promote_type(T, T2), N, B} ArrayInterface.promote_eltype( ::Type{<:MtlArray{T, N, S}}, ::Type{T2} ) where {T, N, S, T2} = MtlArray{promote_type(T, T2), N, S} Bumps patch version 7.23.0 → 7.24.0 so downstream packages can compat-bound the new method. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 192a325 commit ab3b340

File tree

4 files changed

+19
-1
lines changed

4 files changed

+19
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "7.23.0"
3+
version = "7.24.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/ArrayInterfaceAMDGPUExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,10 @@ end
1212

1313
ArrayInterface.device(::Type{<:AMDGPU.ROCArray}) = ArrayInterface.GPU()
1414

15+
function ArrayInterface.promote_eltype(
16+
::Type{<:AMDGPU.ROCArray{T, N, B}}, ::Type{T2}
17+
) where {T, N, B, T2}
18+
return AMDGPU.ROCArray{promote_type(T, T2), N, B}
19+
end
20+
1521
end # module

ext/ArrayInterfaceCUDAExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,10 @@ end
1313

1414
ArrayInterface.device(::Type{<:CUDA.CuArray}) = ArrayInterface.GPU()
1515

16+
function ArrayInterface.promote_eltype(
17+
::Type{<:CUDA.CuArray{T, N, M}}, ::Type{T2}
18+
) where {T, N, M, T2}
19+
return CUDA.CuArray{promote_type(T, T2), N, M}
20+
end
21+
1622
end # module

ext/ArrayInterfaceMetalExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,10 @@ end
1212

1313
ArrayInterface.device(::Type{<:Metal.MtlArray}) = ArrayInterface.GPU()
1414

15+
function ArrayInterface.promote_eltype(
16+
::Type{<:Metal.MtlArray{T, N, S}}, ::Type{T2}
17+
) where {T, N, S, T2}
18+
return Metal.MtlArray{promote_type(T, T2), N, S}
19+
end
20+
1521
end # module

0 commit comments

Comments
 (0)