@@ -100,12 +100,15 @@ mutable struct Model <: DiffOpt.AbstractModel
100100 # sensitivity input cache using MOI like sparse format
101101 input_cache:: DiffOpt.InputCache
102102
103+ # linear solving function to use
104+ linear_solver:: Any
105+
103106 x:: Vector{Float64} # Primal
104107 λ:: Vector{Float64} # Dual of inequalities
105108 ν:: Vector{Float64} # Dual of equalities
106109end
107110function Model ()
108- return Model (Form {Float64} (), nothing , nothing , nothing , DiffOpt. InputCache (), Float64[], Float64[], Float64[])
111+ return Model (Form {Float64} (), nothing , nothing , nothing , DiffOpt. InputCache (), nothing , Float64[], Float64[], Float64[])
109112end
110113
111114function MOI. is_empty (model:: Model )
@@ -281,11 +284,9 @@ function DiffOpt.reverse_differentiate!(model::Model)
281284
282285 nv = length (model. x)
283286 Q = view (LHS, 1 : nv, 1 : nv)
284- partial_grads = if norm (Q) ≈ 0
285- - IterativeSolvers. lsqr (LHS, RHS)
286- else
287- - LHS \ RHS
288- end
287+ iterative = norm (Q) ≈ 0
288+ solver = model. linear_solver
289+ partial_grads = - solve_system (solver, LHS, RHS, iterative)
289290
290291 dz = partial_grads[1 : nv]
291292 dλ = partial_grads[nv+ 1 : nv+ nineq]
@@ -302,10 +303,6 @@ function DiffOpt.reverse_differentiate!(model::Model)
302303 # todo, check MOI signs for dA and dG
303304end
304305
305- _linsolve (A, b) = A \ b
306- # See https://github.com/JuliaLang/julia/issues/32668
307- _linsolve (A, b:: SparseVector ) = A \ Vector (b)
308-
309306# Just a hack that will be removed once we use `MOI.Utilities.MatrixOfConstraints`
310307struct _QPSets end
311308MOI. Utilities. rows (:: _QPSets , ci:: MOI.ConstraintIndex ) = ci. value
@@ -353,13 +350,9 @@ function DiffOpt.forward_differentiate!(model::Model)
353350 ]
354351
355352 Q = view (LHS, 1 : nv, 1 : nv)
356- partial_grads = if norm (Q) ≈ 0
357- - IterativeSolvers. lsqr (LHS' , RHS)
358- else
359- - _linsolve (LHS' , RHS)
360- end
361-
362-
353+ iterative = norm (Q) ≈ 0
354+ solver = model. linear_solver
355+ partial_grads = - solve_system (solver, LHS' , RHS, iterative)
363356 dz = partial_grads[1 : nv]
364357 dλ = partial_grads[nv+ 1 : nv+ length (λ)]
365358 dν = partial_grads[nv+ length (λ)+ 1 : end ]
@@ -395,4 +388,31 @@ function DiffOpt._get_dA(model::Model, ci::LE)
395388 return DiffOpt. lazy_combination (+ , l * dλ[i], model. x, l, dz)
396389end
397390
391+ """
392+ LinearAlgebraSolver
393+
394+ Optimizer attribute for the solver to use for the linear algebra operations.
395+ Each solver must implement: `solve_system(solver, LHS, RHS, iterative::Bool)`.
396+ """
397+ struct LinearAlgebraSolver <: MOI.AbstractOptimizerAttribute end
398+
399+ """
400+ Default `solve_system` method: uses `IterativeSolvers.lsqr` when `iterative` is true, otherwise the built-in `LHS \ RHS` solve.
401+ """
402+ function solve_system (:: Any , LHS, RHS, iterative)
403+ if iterative
404+ IterativeSolvers. lsqr (LHS, RHS)
405+ else
406+ LHS \ RHS
407+ end
408+ end
409+ # Work around `\` failing on a `SparseVector` right-hand side by densifying it; see https://github.com/JuliaLang/julia/issues/32668
410+ solve_system (:: Nothing , LHS, RHS:: SparseVector , iterative) = solve_system (nothing , LHS, Vector (RHS), iterative)
411+
412+ MOI. supports (:: Model , :: LinearAlgebraSolver ) = true
413+ MOI. get (model:: Model , :: LinearAlgebraSolver ) = model. linear_solver
414+ function MOI. set (model:: Model , :: LinearAlgebraSolver , linear_solver)
415+ model. linear_solver = linear_solver
416+ end
417+
398418end
0 commit comments