Skip to content

Commit 0aa6afd

Browse files
committed
fix CUDA permutedims
1 parent ba699cb commit 0aa6afd

File tree

1 file changed

+0
-17
lines changed

1 file changed

+0
-17
lines changed

src/cuda.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,4 @@ end
2727

2828
function togpu(t::LabeledTensor)
2929
LabeledTensor(CuArray(t.array), t.labels)
30-
end
31-
32-
function GPUArrays.genperm(I::NTuple{N}, perm::NTuple{N}) where N
33-
ntuple(d-> (@inbounds return I[perm[d]]), Val(N))
34-
end
35-
36-
function LinearAlgebra.permutedims!(dest::GPUArrays.AbstractGPUArray, src::GPUArrays.AbstractGPUArray, perm)
37-
perm isa Tuple || (perm = Tuple(perm))
38-
size_dest = size(dest)
39-
size_src = size(src)
40-
CUDA.gpu_call(vec(dest), vec(src), perm; name="permutedims!") do ctx, dest, src, perm
41-
i = @linearidx src
42-
I = l2c(size_src, i)
43-
@inbounds dest[c2l(size_dest, GPUArrays.genperm(I, perm))] = src[i]
44-
return
45-
end
46-
return reshape(dest, size(dest))
4730
end

0 commit comments

Comments
 (0)