From 949f6a5fd885716bfcc89b71990b369a17c90cc1 Mon Sep 17 00:00:00 2001
From: Katharine Hyatt
Date: Wed, 13 May 2026 12:11:25 +0200
Subject: [PATCH 1/2] Use a pass-through for gemm

---
Reviewer notes (free-form text between the three-dash line and the diffstat;
not part of the commit message):

 * src/linalg.jl gains `_gemm!`, which forwards all of its arguments to
   `LinearAlgebra.BLAS.gemm!`. The `nthreads == 1 || m * n < 1024` branch of
   `_threaded_blas_mul!` now calls `_gemm!` instead of calling BLAS directly.
 * ext/StridedGPUArraysExt.jl adds a `_gemm!` method for three
   `GPUStridedView` operands that calls `GPUArrays.generic_matmatmul!` on
   `LinearAlgebra.wrap(A, opA)` and `LinearAlgebra.wrap(B, opB)`.
 * NOTE(review): `LinearAlgebra.wrap` does not appear to be documented public
   API; confirm it exists on every Julia version this package supports.
 * NOTE(review): line breaks, blank context lines and indentation in the
   hunks below were restored to agree with the hunk headers' line counts;
   verify them against the target tree before applying.

 ext/StridedGPUArraysExt.jl | 5 +++++
 src/linalg.jl              | 4 +++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl
index 6bfb21a..e44da7b 100644
--- a/ext/StridedGPUArraysExt.jl
+++ b/ext/StridedGPUArraysExt.jl
@@ -1,6 +1,7 @@
 module StridedGPUArraysExt
 
 using Strided, GPUArrays, LinearAlgebra
+import Strided: _gemm!
 using GPUArrays: Adapt, KernelAbstractions
 using GPUArrays.KernelAbstractions: @kernel, @index
 using StridedViews: ParentIndex
@@ -20,6 +21,10 @@ function Base.Array(a::GPUStridedView)
     return Array(b)
 end
 
+function Strided._gemm!(opA::Char, opB::Char, α, A::TA, B::TB, β, C::TC) where {TA <: GPUStridedView, TB <: GPUStridedView, TC <: GPUStridedView}
+    GPUArrays.generic_matmatmul!(C, LinearAlgebra.wrap(A, opA), LinearAlgebra.wrap(B, opB), α, β)
+end
+
 # ---------- GPU mapreduce support ----------
 
 @inline _gpu_init_acc(::Nothing, current_val) = current_val
diff --git a/src/linalg.jl b/src/linalg.jl
index febc573..ac73822 100644
--- a/src/linalg.jl
+++ b/src/linalg.jl
@@ -2,6 +2,8 @@
 LinearAlgebra.rmul!(dst::StridedView, α::Number) = mul!(dst, dst, α)
 LinearAlgebra.lmul!(α::Number, dst::StridedView) = mul!(dst, α, dst)
 
+_gemm!(args...) = LinearAlgebra.BLAS.gemm!(args...)
+
 
 function LinearAlgebra.mul!(
         dst::StridedView{<:Number, N}, α::Number, src::StridedView{<:Number, N}
@@ -117,7 +119,7 @@ function _threaded_blas_mul!(
     return if nthreads == 1 || m * n < 1024
         A2, CA = getblasmatrix(A)
         B2, CB = getblasmatrix(B)
-        LinearAlgebra.BLAS.gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C)
+        _gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C)
     else
         if m > n
             m2 = round(Int, m / 16) * 8

From 7531b2296b97f74fd9425ce6ebdeb79109048bdf Mon Sep 17 00:00:00 2001
From: Katharine Hyatt
Date: Wed, 13 May 2026 12:16:31 +0200
Subject: [PATCH 2/2] Formatter

---
Reviewer note (not part of the commit message): this patch only adds an
explicit `return` to the `_gemm!` method introduced in PATCH 1/2, so it
could be squashed into that patch.

 ext/StridedGPUArraysExt.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl
index e44da7b..152b032 100644
--- a/ext/StridedGPUArraysExt.jl
+++ b/ext/StridedGPUArraysExt.jl
@@ -22,7 +22,7 @@ function Base.Array(a::GPUStridedView)
 end
 
 function Strided._gemm!(opA::Char, opB::Char, α, A::TA, B::TB, β, C::TC) where {TA <: GPUStridedView, TB <: GPUStridedView, TC <: GPUStridedView}
-    GPUArrays.generic_matmatmul!(C, LinearAlgebra.wrap(A, opA), LinearAlgebra.wrap(B, opB), α, β)
+    return GPUArrays.generic_matmatmul!(C, LinearAlgebra.wrap(A, opA), LinearAlgebra.wrap(B, opB), α, β)
 end
 
 # ---------- GPU mapreduce support ----------