@@ -1213,10 +1213,27 @@ end
 
 # create a batch of pointers in device memory from a strided device array
 @inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
-    batchsize = last(size(strided))
-    stride = prod(size(strided)[1:end-1])
-    ptrs = [pointer(strided, (i-1)*stride + 1) for i in 1:batchsize]
-    return CuArray(ptrs)
+    batch_size = last(size(strided))
+    batch_stride = prod(size(strided)[1:end-1])
+    # ptrs = [pointer(strided, (i-1)*batch_stride + 1) for i in 1:batch_size]
+    # fill the array on the GPU to avoid synchronous copies and support larger batch sizes
+    ptrs = CuArray{CuPtr{T}}(undef, batch_size)
+    function compute_pointers()
+        i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
+        grid_stride = gridDim().x * blockDim().x
+        while i <= length(ptrs)
+            @inbounds ptrs[i] =
+                reinterpret(CuPtr{T}, pointer(strided, (i - 1i32) * batch_stride + 1i32))
+            i += grid_stride
+        end
+        return
+    end
+    kernel = @cuda launch=false compute_pointers()
+    config = launch_configuration(kernel.fun)
+    threads = min(config.threads, batch_size)
+    blocks = min(config.blocks, cld(batch_size, threads))
+    @cuda threads blocks compute_pointers()
+    return ptrs
 end
 
 ## (GE) general matrix-matrix multiplication grouped batched