@@ -396,89 +396,3 @@ for op in (:(+), :(-))
396396 @eval Base.$ op (A:: $TypeA , B:: $TypeB ) where {T <: CublasFloat } = geam ($ transa (T), $ transb (T), one (T), $ (unwrapa (:A )), $ (op)(one (T)), $ (unwrapb (:B )))
397397 end
398398end
399-
400- # Kronecker product
401- function LinearAlgebra. kron! (C:: CuMatrix{TC} , A:: CuMatrix{TA} , B:: CuMatrix{TB} ) where {TA,TB,TC}
402-
403- function _kron_mat_kernelA! (C, A, B, m, n, p, q)
404- index_i = (blockIdx (). x - 1 ) * blockDim (). x + threadIdx (). x
405- index_j = (blockIdx (). y - 1 ) * blockDim (). y + threadIdx (). y
406-
407- stride_i = blockDim (). x * gridDim (). x
408- stride_j = blockDim (). y * gridDim (). y
409-
410- index_i > m && return
411- index_j > n && return
412-
413- for i in index_i: stride_i: m
414- for j in index_j: stride_j: n
415- for k in 1 : p
416- for l in 1 : q
417- @inbounds C[(i- 1 )* p+ k, (j- 1 )* q+ l] = A[i,j] * B[k,l]
418- end
419- end
420- end
421- end
422- return nothing
423- end
424-
425- function _kron_mat_kernelB! (C, A, B, m, n, p, q)
426- index_p = (blockIdx (). x - 1 ) * blockDim (). x + threadIdx (). x
427- index_q = (blockIdx (). y - 1 ) * blockDim (). y + threadIdx (). y
428-
429- stride_p = blockDim (). x * gridDim (). x
430- stride_q = blockDim (). y * gridDim (). y
431-
432- index_p > p && return
433- index_q > q && return
434-
435- for i in 1 : m
436- for j in 1 : n
437- for k in index_p: stride_p: p
438- for l in index_q: stride_q: q
439- @inbounds C[(i- 1 )* p+ k, (j- 1 )* q+ l] = A[i,j] * B[k,l]
440- end
441- end
442- end
443- end
444- return nothing
445- end
446-
447- m, n = size (A)
448- p, q = size (B)
449-
450- # Use different kernels depending on the size of the matrices
451- # choosing to parallelize the matrix with the largest number of elements
452- m* n >= p* q ? (kernel = @cuda launch= false _kron_mat_kernelA! (C, A, B, m, n, p, q)) :
453- (kernel = @cuda launch= false _kron_mat_kernelB! (C, A, B, m, n, p, q))
454-
455- m* n >= p* q ? (sizes = (m, n)) : (sizes = (p, q))
456-
457- config = launch_configuration (kernel. fun)
458- dim_ratio = sizes[1 ] / sizes[2 ]
459- max_threads_i = max (1 , floor (Int, sqrt (config. threads * dim_ratio)))
460- max_threads_j = max (1 , floor (Int, sqrt (config. threads / dim_ratio)))
461- max_blocks_i = max (1 , floor (Int, sqrt (config. blocks * dim_ratio)))
462- max_blocks_j = max (1 , floor (Int, sqrt (config. blocks / dim_ratio)))
463-
464- threads_i = min (sizes[1 ], max_threads_i)
465- threads_j = min (sizes[2 ], max_threads_j)
466- threads = (threads_i, threads_j)
467- blocks_i = min (cld (sizes[1 ], threads_i), max_blocks_i)
468- blocks_j = min (cld (sizes[2 ], threads_j), max_blocks_j)
469- blocks = (blocks_i, blocks_j)
470-
471- kernel (C, A, B, m, n, p, q; threads= threads, blocks= blocks)
472-
473- return C
474- end
475-
476- function LinearAlgebra. kron (A:: CuMatrix{TA} , B:: CuMatrix{TB} ) where {TA,TB}
477- m, n = size (A)
478- p, q = size (B)
479-
480- T = promote_type (TA, TB)
481- C = similar (A, T, m* p, n* q)
482-
483- kron! (C, A, B)
484- end
0 commit comments