Skip to content

Commit 5184395

Browse files
Merge pull request #483 from ChrisRackauckas-Claude/add-promote-eltype-gpu
Add promote_eltype for CuArray / ROCArray / MtlArray
2 parents 192a325 + ab3b340 commit 5184395

File tree

4 files changed

+19
-1
lines changed

4 files changed

+19
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "7.23.0"
3+
version = "7.24.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/ArrayInterfaceAMDGPUExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,10 @@ end
1212

1313
ArrayInterface.device(::Type{<:AMDGPU.ROCArray}) = ArrayInterface.GPU()
1414

15+
function ArrayInterface.promote_eltype(
16+
::Type{<:AMDGPU.ROCArray{T, N, B}}, ::Type{T2}
17+
) where {T, N, B, T2}
18+
return AMDGPU.ROCArray{promote_type(T, T2), N, B}
19+
end
20+
1521
end # module

ext/ArrayInterfaceCUDAExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,10 @@ end
1313

1414
ArrayInterface.device(::Type{<:CUDA.CuArray}) = ArrayInterface.GPU()
1515

16+
function ArrayInterface.promote_eltype(
17+
::Type{<:CUDA.CuArray{T, N, M}}, ::Type{T2}
18+
) where {T, N, M, T2}
19+
return CUDA.CuArray{promote_type(T, T2), N, M}
20+
end
21+
1622
end # module

ext/ArrayInterfaceMetalExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,10 @@ end
1212

1313
ArrayInterface.device(::Type{<:Metal.MtlArray}) = ArrayInterface.GPU()
1414

15+
function ArrayInterface.promote_eltype(
16+
::Type{<:Metal.MtlArray{T, N, S}}, ::Type{T2}
17+
) where {T, N, S, T2}
18+
return Metal.MtlArray{promote_type(T, T2), N, S}
19+
end
20+
1521
end # module

0 commit comments

Comments
 (0)