11# generic APIs
22
33export gather!, scatter!, axpby!, rot!
4- export vv!, sv!, sm!, gemv, gemm, gemm!, sddmm!
4+ export vv!, sv!, sm!, mv!, mm!, gemv, gemm, gemm!, sddmm!
55export bmm!
66
7+ """
8+ mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
9+
10+ Performs `Y = alpha * op(A) * X + beta * Y`, where `op` can be nothing (`transa = N`),
11+ transpose (`transa = T`) or conjugate transpose (`transa = C`).
12+ `X` and `Y` are dense vectors.
13+ """
14+ function mv! end
15+
16+ """
17+ mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)
18+ mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuMatrix, B::Union{CuSparseMatrixCSC,CuSparseMatrixCSR,CuSparseMatrixCOO}, beta::Number, C::CuMatrix, index::SparseChar)
19+
20+ Performs `C = alpha * op(A) * op(B) + beta * C`, where each `op` can be nothing (`transa`/`transb = N`),
21+ transpose (`transa`/`transb = T`) or conjugate transpose (`transa`/`transb = C`).
22+ """
23+ function mm! end
24+
725# # API functions
826
927function sparsetodense (A:: Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}} , index:: SparseChar , algo:: cusparseSparseToDenseAlg_t = CUSPARSE_SPARSETODENSE_ALG_DEFAULT) where {T}
@@ -191,9 +209,11 @@ function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{TA},C
191209 return Y
192210end
193211
194- function mm! (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T} } ,
212+ function mm! (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrix{T } ,
195213 B:: DenseCuMatrix{T} , beta:: Number , C:: DenseCuMatrix{T} , index:: SparseChar , algo:: cusparseSpMMAlg_t = CUSPARSE_SPMM_ALG_DEFAULT) where {T}
196214
215+ (A isa CuSparseMatrixBSR) && (CUSPARSE. version () < v " 12.5.1" ) && throw (ErrorException (" This operation is not supported by the current CUDA version." ))
216+
197217 # Support `transa = 'C'` and `transb = 'C'` for real matrices
198218 transa = T <: Real && transa == ' C' ? ' T' : transa
199219 transb = T <: Real && transb == ' C' ? ' T' : transb
@@ -235,10 +255,10 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::Union{CuS
235255 # cusparseCsrSetStridedBatch(obj, batchsize, 0, nnz(A))
236256 # end
237257
238- # Set default buffer for small matrices (10000 chosen arbitrarly)
258+ # Set default buffer for small matrices (1000 chosen arbitrarily)
239259 # Otherwise tries to allocate 120TB of memory (see #2296)
240260 function bufferSize ()
241- out = Ref {Csize_t} (10000 )
261+ out = Ref {Csize_t} (1000 )
242262 cusparseSpMM_bufferSize (
243263 handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
244264 descC, T, algo, out)
@@ -274,7 +294,6 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
274294 throw (ErrorException (" Batched dense-matrix times batched sparse-matrix (bmm!) requires a CUSPARSE version ≥ 11.7.2 (yours: $(CUSPARSE. version ()) )." ))
275295 end
276296
277-
278297 # Support `transa = 'C'` and `transb = 'C'` for real matrices
279298 transa = T <: Real && transa == ' C' ? ' T' : transa
280299 transb = T <: Real && transb == ' C' ? ' T' : transb
@@ -313,10 +332,10 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
313332 strideC = stride (C, 3 )
314333 cusparseDnMatSetStridedBatch (descC, b, strideC)
315334
316- # Set default buffer for small matrices (10000 chosen arbitrarly)
335+ # Set default buffer for small matrices (1000 chosen arbitrarily)
317336 # Otherwise tries to allocate 120TB of memory (see #2296)
318337 function bufferSize ()
319- out = Ref {Csize_t} (10000 )
338+ out = Ref {Csize_t} (1000 )
320339 cusparseSpMM_bufferSize (
321340 handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
322341 descC, T, algo, out)
@@ -337,10 +356,11 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
337356end
338357
339358function mm! (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: DenseCuMatrix{T} ,
340- B:: Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}} ,
341- beta :: Number , C:: DenseCuMatrix{T} , index:: SparseChar , algo:: cusparseSpMMAlg_t = CUSPARSE_SPMM_ALG_DEFAULT) where {T}
359+ B:: Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}} , beta :: Number ,
360+ C:: DenseCuMatrix{T} , index:: SparseChar , algo:: cusparseSpMMAlg_t = CUSPARSE_SPMM_ALG_DEFAULT) where {T}
342361
343362 CUSPARSE. version () < v " 11.7.4" && throw (ErrorException (" This operation is not supported by the current CUDA version." ))
363+
344364 # Support `transa = 'C'` and `transb = 'C'` for real matrices
345365 transa = T <: Real && transa == ' C' ? ' T' : transa
346366 transb = T <: Real && transb == ' C' ? ' T' : transb
@@ -373,10 +393,10 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMa
373393 descB = CuSparseMatrixDescriptor (B, index, transposed= true )
374394 descC = CuDenseMatrixDescriptor (C, transposed= true )
375395
376- # Set default buffer for small matrices (10000 chosen arbitrarly)
396+ # Set default buffer for small matrices (1000 chosen arbitrarily)
377397 # Otherwise tries to allocate 120TB of memory (see #2296)
378398 function bufferSize ()
379- out = Ref {Csize_t} (10000 )
399+ out = Ref {Csize_t} (1000 )
380400 cusparseSpMM_bufferSize (
381401 handle (), transb, transa, Ref {T} (alpha), descB, descA, Ref {T} (beta),
382402 descC, T, algo, out)
@@ -736,9 +756,10 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
736756end
737757
738758function sddmm! (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: DenseCuMatrix{T} , B:: DenseCuMatrix{T} ,
739- beta:: Number , C:: CuSparseMatrixCSR{T} , index:: SparseChar , algo:: cusparseSDDMMAlg_t = CUSPARSE_SDDMM_ALG_DEFAULT) where {T}
759+ beta:: Number , C:: Union{ CuSparseMatrixCSR{T},CuSparseMatrixBSR{T} } , index:: SparseChar , algo:: cusparseSDDMMAlg_t = CUSPARSE_SDDMM_ALG_DEFAULT) where {T}
740760
741761 CUSPARSE. version () < v " 11.4.1" && throw (ErrorException (" This operation is not supported by the current CUDA version." ))
762+ (C isa CuSparseMatrixBSR) && (CUSPARSE. version () < v " 12.1.0" ) && throw (ErrorException (" This operation is not supported by the current CUDA version." ))
742763
743764 # Support `transa = 'C'` and `transb = 'C'` for real matrices
744765 transa = T <: Real && transa == ' C' ? ' T' : transa
0 commit comments