Skip to content

Commit aa7f12e

Browse files
Merge pull request #834 from hexaeder/hw/cusparse_defaults
fix default solver selection for CuSparseMatrix
2 parents 49d76ed + 3d0f557 commit aa7f12e

File tree

7 files changed

+93
-9
lines changed

7 files changed

+93
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ ArrayInterface = "7.17"
8686
BandedMatrices = "1.8"
8787
BlockDiagonals = "0.2"
8888
CUDA = "5.5"
89-
CUDSS = "0.4, 0.6.1"
89+
CUDSS = "0.6.3"
9090
CUSOLVERRF = "0.2.6"
9191
ChainRulesCore = "1.25"
9292
CliqueTrees = "1.11.0"

ext/LinearSolveCUDAExt.jl

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@ function LinearSolve.is_cusparse(A::Union{
1717
CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})
1818
true
1919
end
20+
LinearSolve.is_cusparse_csr(::CUDA.CUSPARSE.CuSparseMatrixCSR) = true
21+
LinearSolve.is_cusparse_csc(::CUDA.CUSPARSE.CuSparseMatrixCSC) = true
2022

2123
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
2224
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
@@ -31,6 +33,16 @@ function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
3133
end
3234
end
3335

36+
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC, b,
37+
assump::OperatorAssumptions{Bool})
38+
if LinearSolve.cudss_loaded(A)
39+
@warn("CUDSS.jl does not support CuSparseMatrixCSC for LU Factorizations, consider using CuSparseMatrixCSR instead. Falling back to Krylov", maxlog=1)
40+
else
41+
@warn("CuSparseMatrixCSC does not support LU Factorization, falling back to Krylov. Consider using CUDSS.jl together with CuSparseMatrixCSR", maxlog=1)
42+
end
43+
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
44+
end
45+
3446
function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
3547
if !LinearSolve.cudss_loaded(A)
3648
error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")

ext/LinearSolveCUSOLVERRFExt.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
module LinearSolveCUSOLVERRFExt
22

3-
using LinearSolve: LinearSolve, @get_cacheval, pattern_changed, OperatorAssumptions
3+
using LinearSolve: LinearSolve, @get_cacheval, pattern_changed, OperatorAssumptions, LinearVerbosity
44
using CUSOLVERRF: CUSOLVERRF, RFLU, CUDA
55
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz
66
using CUSOLVERRF.CUDA.CUSPARSE: CuSparseMatrixCSR

ext/LinearSolveSparseArraysExt.jl

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ function LinearSolve.init_cacheval(
129129
maxiters::Int, abstol, reltol,
130130
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
131131
if LinearSolve.is_cusparse(A)
132-
ArrayInterface.lu_instance(A)
132+
LinearSolve.cudss_loaded(A) ? ArrayInterface.lu_instance(A) : nothing
133133
else
134134
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(
135135
zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
@@ -141,7 +141,7 @@ function LinearSolve.init_cacheval(
141141
maxiters::Int, abstol, reltol,
142142
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
143143
if LinearSolve.is_cusparse(A)
144-
ArrayInterface.lu_instance(A)
144+
LinearSolve.cudss_loaded(A) ? ArrayInterface.lu_instance(A) : nothing
145145
else
146146
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(
147147
zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
@@ -344,7 +344,13 @@ function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
344344
Symmetric{T, <:AbstractSparseArray{T}}}, b, u, Pl, Pr,
345345
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
346346
assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
347-
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
347+
if LinearSolve.is_cusparse_csc(A)
348+
nothing
349+
elseif LinearSolve.is_cusparse_csr(A) && !LinearSolve.cudss_loaded(A)
350+
nothing
351+
else
352+
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
353+
end
348354
end
349355

350356
# Specialize QR for the non-square case

src/LinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -478,6 +478,8 @@ ALREADY_WARNED_CUDSS = Ref{Bool}(false)
478478
error_no_cudss_lu(A) = nothing
479479
cudss_loaded(A) = false
480480
is_cusparse(A) = false
481+
is_cusparse_csr(A) = false
482+
is_cusparse_csc(A) = false
481483

482484
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
483485
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, ButterflyFactorization,

src/factorization.jl

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -395,7 +395,13 @@ end
395395
function init_cacheval(
396396
alg::CholeskyFactorization, A::AbstractArray{<:BLASELTYPES}, b, u, Pl, Pr,
397397
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
398-
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
398+
if LinearSolve.is_cusparse_csc(A)
399+
nothing
400+
elseif LinearSolve.is_cusparse_csr(A) && !LinearSolve.cudss_loaded(A)
401+
nothing
402+
else
403+
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
404+
end
399405
end
400406

401407
const PREALLOCATED_CHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())

test/gpu/cuda.jl

Lines changed: 61 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,65 @@
11
using LinearSolve, CUDA, LinearAlgebra, SparseArrays, StableRNGs
2-
using CUDA.CUSPARSE, CUDSS
2+
using CUDA.CUSPARSE
33
using Test
44

5+
@testset "Test default solver choice for CuSparse" begin
6+
b = Float64[1, 2, 3, 4]
7+
b_gpu = CUDA.adapt(CuArray, b)
8+
9+
A = Float64[1 1 0 0
10+
0 1 1 0
11+
0 0 3 1
12+
0 0 0 4]
13+
A_gpu_csr = CUDA.CUSPARSE.CuSparseMatrixCSR(sparse(A))
14+
A_gpu_csc = CUDA.CUSPARSE.CuSparseMatrixCSC(sparse(A))
15+
prob_csr = LinearProblem(A_gpu_csr, b_gpu)
16+
prob_csc = LinearProblem(A_gpu_csc, b_gpu)
17+
18+
A_sym = Float64[1 1 0 0
19+
1 0 0 2
20+
0 0 3 0
21+
0 2 0 0]
22+
A_gpu_sym_csr = CUDA.CUSPARSE.CuSparseMatrixCSR(sparse(A_sym))
23+
A_gpu_sym_csc = CUDA.CUSPARSE.CuSparseMatrixCSC(sparse(A_sym))
24+
prob_sym_csr = LinearProblem(A_gpu_sym_csr, b_gpu)
25+
prob_sym_csc = LinearProblem(A_gpu_sym_csc, b_gpu)
26+
27+
@testset "Test without CUDSS loaded" begin
28+
# assert CUDSS is not loaded yet
29+
@test !LinearSolve.cudss_loaded(A_gpu_csr)
30+
# csr fallback to krylov
31+
alg = solve(prob_csr).alg
32+
@test alg.alg == LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES
33+
# csc fallback to krylov
34+
alg = solve(prob_csc).alg
35+
@test alg.alg == LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES
36+
# csr symmetric fallback to krylov
37+
alg = solve(prob_sym_csr).alg
38+
@test alg.alg == LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES
39+
# csc symmetric fallback to krylov
40+
alg = solve(prob_sym_csc).alg
41+
@test alg.alg == LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES
42+
end
43+
44+
using CUDSS
45+
46+
@testset "Test with CUDSS loaded" begin
47+
@test LinearSolve.cudss_loaded(A_gpu_csr)
48+
# csr uses LU
49+
alg = solve(prob_csr).alg
50+
@test alg.alg == LinearSolve.DefaultAlgorithmChoice.LUFactorization
51+
# csc fallback to krylov
52+
alg = solve(prob_csc).alg
53+
@test alg.alg == LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES
54+
# csr symmetric uses LU/cholesky
55+
alg = solve(prob_sym_csr).alg
56+
@test alg.alg == LinearSolve.DefaultAlgorithmChoice.LUFactorization
57+
# csc symmetric fallback to krylov
58+
alg = solve(prob_sym_csc).alg
59+
@test alg.alg == LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES
60+
end
61+
end
62+
563
CUDA.allowscalar(false)
664

765
n = 8
@@ -96,9 +154,9 @@ end
96154
@testset "CUDSS" begin
97155
T = Float32
98156
n = 100
99-
A_cpu = sprand(T, n, n, 0.05) + I
157+
A_cpu = sprand(rng, T, n, n, 0.05) + I
100158
x_cpu = zeros(T, n)
101-
b_cpu = rand(T, n)
159+
b_cpu = rand(rng, T, n)
102160

103161
A_gpu_csr = CuSparseMatrixCSR(A_cpu)
104162
b_gpu = CuVector(b_cpu)

0 commit comments

Comments (0)