@@ -56,3 +56,58 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
5656 SciMLBase. solve (cache, alg, args... ; kwargs... )
5757 end
5858end
59+
60+ function init_cacheval (alg:: Nothing , A, b, u)
61+ if A isa DiffEqArrayOperator
62+ A = A. A
63+ end
64+
65+ # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
66+ # it makes sense according to the benchmarks, which is dependent on
67+ # whether MKL or OpenBLAS is being used
68+ if A isa Matrix
69+ if eltype (A) <: Union{Float32,Float64,ComplexF32,ComplexF64} &&
70+ ArrayInterface. can_setindex (b) && (size (A,1 ) <= 100 ||
71+ (isopenblas () && size (A,1 ) <= 500 )
72+ )
73+ alg = RFLUFactorization ()
74+ init_cacheval (alg, A, b, u)
75+ else
76+ alg = LUFactorization ()
77+ init_cacheval (alg, A, b, u)
78+ end
79+
80+ # These few cases ensure the choice is optimal without the
81+ # dynamic dispatching of factorize
82+ elseif A isa Tridiagonal
83+ alg = GenericFactorization (;fact_alg= lu!)
84+ init_cacheval (alg, A, b, u)
85+ elseif A isa SymTridiagonal
86+ alg = GenericFactorization (;fact_alg= ldlt!)
87+ init_cacheval (alg, A, b, u)
88+ elseif A isa SparseMatrixCSC
89+ alg = UMFPACKFactorization ()
90+ init_cacheval (alg, A, b, u)
91+
92+ # This catches the cases where a factorization overload could exist
93+ # For example, BlockBandedMatrix
94+ elseif ArrayInterface. isstructured (A)
95+ alg = GenericFactorization ()
96+ init_cacheval (alg, A, b, u)
97+
98+ # This catches the case where A is a CuMatrix
99+ # Which does not have LU fully defined
100+ elseif ! (A isa AbstractDiffEqOperator)
101+ alg = QRFactorization ()
102+ init_cacheval (alg, A, b, u)
103+
104+ # Not factorizable operator, default to only using A*x
105+ # IterativeSolvers is faster on CPU but not GPU-compatible
106+ elseif cache. u isa Array
107+ alg = IterativeSolversJL_GMRES ()
108+ init_cacheval (alg, A, b, u)
109+ else
110+ alg = KrylovJL_GMRES ()
111+ init_cacheval (alg, A, b, u)
112+ end
113+ end
0 commit comments