Skip to content

Commit e884bb5

Browse files
committed
upadtes
1 parent 9d60062 commit e884bb5

File tree

2 files changed

+47
-24
lines changed

2 files changed

+47
-24
lines changed

src/pardiso.jl

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,57 @@
33

44
import Pardiso
55

6-
export PardisoJL, PardisoJLFactorize, PardisoJLIterate
6+
export PardisoJL, MKLPardisoFactorize, MKLPardisoIterate
77

88
Base.@kwdef struct PardisoJL <: SciMLLinearSolveAlgorithm
99
nprocs::Union{Int, Nothing} = nothing
1010
solver_type::Union{Int, Pardiso.Solver, Nothing} = nothing
1111
matrix_type::Union{Int, Pardiso.MatrixType, Nothing} = nothing
12+
fact_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
1213
solve_phase::Union{Int, Pardiso.Phase, Nothing} = nothing
1314
release_phase::Union{Int, Nothing} = nothing
1415
iparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
1516
dparm::Union{Vector{Tuple{Int,Int}}, Nothing} = nothing
1617
end
1718

18-
PardisoJLFactorize(;kwargs...) = PardisoJL(;solver_type=0,
19-
solve_phase=Pardiso.NUM_FACT,
20-
kwargs...)
21-
PardisoJLIterate(;kwargs...) = PardisoJL(;solver_type=1,
22-
solve_phase=Pardiso.SOLVE_ITERATIVE_REFINE,
23-
kwargs...)
19+
MKLPardisoFactorize(;kwargs...) = PardisoJL(;fact_phase=Pardiso.NUM_FACT,
20+
solve_phase=Pardiso.SOLVE_ITERATIVE_REFINE,
21+
kwargs...)
22+
MKLPardisoIterate(;kwargs...) = PardisoJL(;solve_phase=Pardiso.NUM_FACT_SOLVE_REFINE,
23+
kwargs...)
2424

2525
# TODO schur complement functionality
2626

2727
function init_cacheval(alg::PardisoJL, cache::LinearCache)
28-
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
28+
@unpack nprocs, solver_type, matrix_type, fact_phase, solve_phase, iparm, dparm = alg
29+
@unpack A, b, u = cache
30+
31+
if A isa DiffEqArrayOperator
32+
A = A.A
33+
end
2934

3035
solver =
3136
if Pardiso.PARDISO_LOADED[]
32-
Pardiso.PardisoSolver()
37+
solver = Pardiso.PardisoSolver()
3338
solver_type !== nothing && Pardiso.set_solver!(solver, solver_type)
39+
40+
solver
3441
else
35-
Pardiso.MKLPardisoSolver()
42+
solver = Pardiso.MKLPardisoSolver()
3643
nprocs !== nothing && Pardiso.set_nprocs!(solver, nprocs)
44+
45+
solver
3746
end
3847

3948
Pardiso.pardisoinit(solver) # default initialization
4049

41-
@show solver
4250
matrix_type !== nothing && Pardiso.set_matrixtype!(solver, matrix_type)
4351
cache.verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
4452

45-
if iparm !== nothing # pass in vector of tuples like [(iparm, key) ...]
53+
"""
54+
pass in vector of tuples like [(iparm::Int, key::Int) ...]
55+
"""
56+
if iparm !== nothing
4657
for i in length(iparm)
4758
Pardiso.set_iparm!(solver, iparm[i]...)
4859
end
@@ -54,7 +65,21 @@ function init_cacheval(alg::PardisoJL, cache::LinearCache)
5465
end
5566
end
5667

57-
Pardiso.set_phase!(cacheval, Pardiso.ANALYSIS)
68+
if (fact_phase !== nothing) | (solve_phase !== nothing)
69+
# ensure phase is being changed afterwards?
70+
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
71+
Pardiso.pardiso(solver, u, A, b)
72+
end
73+
74+
if fact_phase !== nothing
75+
Pardiso.set_phase!(solver, fact_phase)
76+
Pardiso.pardiso(solver, u, A, b)
77+
end
78+
79+
# ipram/dpram
80+
# abstol = cache.abstol
81+
# reltol = cache.reltol
82+
# kwargs = (abstol=abstol, reltol=reltol)
5883

5984
return solver
6085
end
@@ -70,13 +95,9 @@ function SciMLBase.solve(cache::LinearCache, alg::PardisoJL; kwargs...)
7095
cache = set_cacheval(cache, solver)
7196
end
7297

73-
abstol = cache.abstol
74-
reltol = cache.reltol
75-
kwargs = (abstol=abstol, reltol=reltol)
76-
77-
alg.solve_phase !== nothing && Pardiso.set_phase!(cacheval, alg.solve_phase)
98+
alg.solve_phase !== nothing && Pardiso.set_phase!(cache.cacheval, alg.solve_phase)
7899
Pardiso.pardiso(cache.cacheval, u, A, b)
79-
alg.release_phase !== nothing && Pardiso.set_phase!(cacheval, alg.release_phase)
100+
alg.release_phase !== nothing && Pardiso.set_phase!(cache.cacheval, alg.release_phase)
80101

81102
return cache.u
82103
end

test/runtests.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,9 @@ end
127127
end
128128

129129
@testset "PardisoJL" begin
130-
@test_broken alg = PardisoJL()
130+
@test_throws UndefVarError alg = PardisoJL()
131131

132132
using Pardiso, SparseArrays
133-
verbose = true
134133

135134
A = sparse([ 1. 0 -2 3
136135
0 5 1 2
@@ -140,12 +139,15 @@ end
140139
u = zero(b)
141140

142141
prob = LinearProblem(A, b)
143-
for alg in (PardisoJL(),
144-
PardisoJLFactorize(),
145-
PardisoJLIterate(), # not with MKLPardisoSolver
142+
for alg in (
143+
PardisoJL(),
144+
MKLPardisoFactorize(),
145+
MKLPardisoIterate(),
146146
)
147147

148148
u = solve(prob, alg; verbose=true)
149+
150+
@test A * u b
149151
end
150152

151153
end

0 commit comments

Comments
 (0)