Route the direct BLAS call in `_threaded_blas_mul!` (the
`nthreads == 1 || m * n < 1024` branch) through a new internal hook,
`Strided._gemm!`, whose default method forwards all arguments to
`LinearAlgebra.BLAS.gemm!`. The GPU extension adds a method of that hook for
`GPUStridedView` operands which calls `GPUArrays.generic_matmatmul!` on the
operands wrapped according to `opA` / `opB`.

NOTE(review): this patch was recovered from a copy whose newlines and
indentation had been collapsed. Line breaks were re-derived from the hunk
line counts; leading whitespace on context lines and the exact split of the
`mul!` argument list are reconstructed -- confirm against the target files
(or apply with `git apply --ignore-whitespace`).

diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl
index 6bfb21a..152b032 100644
--- a/ext/StridedGPUArraysExt.jl
+++ b/ext/StridedGPUArraysExt.jl
@@ -1,6 +1,7 @@
 module StridedGPUArraysExt
 
 using Strided, GPUArrays, LinearAlgebra
+import Strided: _gemm!
 using GPUArrays: Adapt, KernelAbstractions
 using GPUArrays.KernelAbstractions: @kernel, @index
 using StridedViews: ParentIndex
@@ -20,6 +21,10 @@ function Base.Array(a::GPUStridedView)
     return Array(b)
 end
 
+function Strided._gemm!(opA::Char, opB::Char, α, A::TA, B::TB, β, C::TC) where {TA <: GPUStridedView, TB <: GPUStridedView, TC <: GPUStridedView}
+    return GPUArrays.generic_matmatmul!(C, LinearAlgebra.wrap(A, opA), LinearAlgebra.wrap(B, opB), α, β)
+end
+
 # ---------- GPU mapreduce support ----------
 
 @inline _gpu_init_acc(::Nothing, current_val) = current_val
diff --git a/src/linalg.jl b/src/linalg.jl
index febc573..ac73822 100644
--- a/src/linalg.jl
+++ b/src/linalg.jl
@@ -2,6 +2,8 @@
 LinearAlgebra.rmul!(dst::StridedView, α::Number) = mul!(dst, dst, α)
 LinearAlgebra.lmul!(α::Number, dst::StridedView) = mul!(dst, α, dst)
 
+_gemm!(args...) = LinearAlgebra.BLAS.gemm!(args...)
+
 function LinearAlgebra.mul!(
         dst::StridedView{<:Number, N}, α::Number,
         src::StridedView{<:Number, N}
@@ -117,7 +119,7 @@ function _threaded_blas_mul!(
     return if nthreads == 1 || m * n < 1024
         A2, CA = getblasmatrix(A)
         B2, CB = getblasmatrix(B)
-        LinearAlgebra.BLAS.gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C)
+        _gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C)
     else
         if m > n
             m2 = round(Int, m / 16) * 8