22
33using LinearAlgebra
44using LinearAlgebra: BlasComplex, BlasFloat, BlasReal
5- using .. CUBLAS: CublasFloat
5+ using .. CUBLAS: CublasFloat, trsm!
66
77function copy_cublasfloat (As... )
88 eltypes = eltype .(As)
@@ -20,17 +20,49 @@ _copywitheltype(::Type{T}, As...) where {T} = map(A -> copyto!(similar(A, T), A)
2020
# matrix division
2222
# A GPU matrix with element type `T`, either stored directly or wrapped in a
# lazy `Adjoint`/`Transpose`. The first member is parameterized by `T` so that
# `CuMatOrAdj{T}` really constrains the element type of a plain `CuMatrix`.
const CuMatOrAdj{T} = Union{CuMatrix{T},
                            LinearAlgebra.Adjoint{T, <:CuMatrix{T}},
                            LinearAlgebra.Transpose{T, <:CuMatrix{T}}}

# Same as `CuMatOrAdj`, but also admits vectors (right-hand sides may be either).
const CuOrAdj{T} = Union{CuVecOrMat{T},
                         LinearAlgebra.Adjoint{T, <:CuVecOrMat{T}},
                         LinearAlgebra.Transpose{T, <:CuVecOrMat{T}}}
2929
# Solve `_A * X = _B` on the GPU and return `X`.
#
# The inputs are never modified: `copy_cublasfloat` hands back working copies
# (presumably promoted to a CUBLAS-supported float type -- confirm against its
# definition), and all factorizations below operate on those copies.
#
# The factorization is chosen from the shape of `A` (n rows, m columns):
#   * n <  m : QR of `A'` (i.e. an LQ of `A`), then a triangular solve and an
#              application of Q to a zero-padded right-hand side.
#   * n == m : LU with partial pivoting.
#   * n >  m : QR of `A`, apply Qᴴ to the right-hand side, then a triangular
#              solve on its first m rows.
#
# Throws `DimensionMismatch` when `_A` and `_B` do not have the same number of
# rows.
function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
    # Fail early with a clear message instead of deep inside a CUBLAS/CUSOLVER
    # wrapper after the operands have already been copied and partly overwritten.
    size(_A, 1) == size(_B, 1) || throw(DimensionMismatch(
        "left-hand side has $(size(_A, 1)) rows, but right-hand side has $(size(_B, 1)) rows"))

    A, B = copy_cublasfloat(_A, _B)
    T = eltype(A)
    n, m = size(A)
    if n < m
        # LQ decomposition
        At = CuMatrix(A')
        F, tau = CUSOLVER.geqrf!(At)  # A = RᴴQᴴ
        # Solve Rᴴ Y = B in place (R is the leading n×n upper triangle of F),
        # then embed Y into the first n rows of a zero-initialized m-row array.
        if B isa CuVector{T}
            CUBLAS.trsv!('U', 'C', 'N', view(F, 1:n, 1:n), B)
            X = CUDA.zeros(T, m)
            view(X, 1:n) .= B
        else
            CUBLAS.trsm!('L', 'U', 'C', 'N', one(T), view(F, 1:n, 1:n), B)
            X = CUDA.zeros(T, m, size(B, 2))
            view(X, 1:n, :) .= B
        end
        # X <- Q * X, overwriting X in place.
        CUSOLVER.ormqr!('L', 'N', F, tau, X)
    elseif n == m
        # LU decomposition with partial pivoting
        # Only the factors and pivots are needed; any further return values of
        # `getrf!` (e.g. a status code) are not inspected here.
        F, ipiv = CUSOLVER.getrf!(A)  # PA = LU
        X = CUSOLVER.getrs!('N', F, ipiv, B)
    else
        # QR decomposition
        F, tau = CUSOLVER.geqrf!(A)  # A = QR
        # B <- Qᴴ * B in place; the transpose flag depends on whether T is real.
        CUSOLVER.ormqr!('L', T <: Real ? 'T' : 'C', F, tau, B)
        # Keep the first m rows (indexing copies) and solve R X = (QᴴB)[1:m]
        # in place on that copy.
        if B isa CuVector{T}
            X = B[1:m]
            CUBLAS.trsv!('U', 'N', 'N', view(F, 1:m, 1:m), X)
        else
            X = B[1:m, :]
            CUBLAS.trsm!('L', 'U', 'N', 'N', one(T), view(F, 1:m, 1:m), X)
        end
    end
    return X
end
3567
# patch JuliaLang/julia#40899 to create a CuArray
0 commit comments