Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ext/StridedGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module StridedGPUArraysExt

using Strided, GPUArrays, LinearAlgebra
import Strided: _gemm!
using GPUArrays: Adapt, KernelAbstractions
using GPUArrays.KernelAbstractions: @kernel, @index
using StridedViews: ParentIndex
Expand All @@ -20,6 +21,10 @@ function Base.Array(a::GPUStridedView)
return Array(b)
end

# GPU method of `Strided._gemm!`, selected when `A`, `B` and `C` are all `GPUStridedView`s.
# `A` and `B` are each passed through `LinearAlgebra.wrap` together with their op character,
# and the product is delegated to `GPUArrays.generic_matmatmul!`, which is handed `C`, the
# two wrapped operands, and then `α` and `β`.
# NOTE(review): `opA`/`opB` are presumably BLAS-style op characters ('N', 'T', 'C') — confirm
# against the caller.
function Strided._gemm!(
        opA::Char, opB::Char, α, A::GPUStridedView, B::GPUStridedView, β, C::GPUStridedView
    )
    wrappedA = LinearAlgebra.wrap(A, opA)
    wrappedB = LinearAlgebra.wrap(B, opB)
    return GPUArrays.generic_matmatmul!(C, wrappedA, wrappedB, α, β)
end

# ---------- GPU mapreduce support ----------

# When the first argument is `nothing`, the second argument is returned unchanged.
# NOTE(review): presumably the first argument is an optional initial value for the
# accumulator — confirm against the other methods / call sites.
@inline function _gpu_init_acc(::Nothing, current_val)
    return current_val
end
Expand Down
4 changes: 3 additions & 1 deletion src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# In-place scaling of a `StridedView` by `α`, with the scalar passed as the last argument
# of `mul!` and `dst` used as both destination and source.
function LinearAlgebra.rmul!(dst::StridedView, α::Number)
    return mul!(dst, dst, α)
end
# In-place scaling of a `StridedView` by `α`, with the scalar passed as the middle argument
# of `mul!` and `dst` used as both destination and source.
function LinearAlgebra.lmul!(α::Number, dst::StridedView)
    return mul!(dst, α, dst)
end

# Fallback matrix-multiply backend: forwards every argument, unchanged and in order, to
# `LinearAlgebra.BLAS.gemm!`. Being a separate function, it can be given more specific
# methods for other array types.
function _gemm!(args...)
    return LinearAlgebra.BLAS.gemm!(args...)
end

function LinearAlgebra.mul!(
dst::StridedView{<:Number, N}, α::Number,
src::StridedView{<:Number, N}
Expand Down Expand Up @@ -117,7 +119,7 @@ function _threaded_blas_mul!(
return if nthreads == 1 || m * n < 1024
A2, CA = getblasmatrix(A)
B2, CB = getblasmatrix(B)
LinearAlgebra.BLAS.gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C)
_gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C)
else
if m > n
m2 = round(Int, m / 16) * 8
Expand Down
Loading