Skip to content

Commit e0fc7af

Browse files
matbesancon and blegat authored
Abstract linear solving method instead of \ (#229)
* add LinearSolve * remove LinearSolve for abstract method * fix constructor * fix constructor * fix issue with sparse ldiv * Update src/QuadraticProgram/QuadraticProgram.jl Co-authored-by: Benoît Legat <benoit.legat@gmail.com> * test attribute setting * docstring * remove ambiguity * dont reset on empty * Apply suggestions from code review --------- Co-authored-by: Benoît Legat <benoit.legat@gmail.com>
1 parent 47f06f0 commit e0fc7af

File tree

3 files changed

+56
-18
lines changed

3 files changed

+56
-18
lines changed

docs/src/examples/custom-relu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ dataset = repeated((train_X, train_Y), epochs);
9999

100100
# training loss function, Flux optimizer
101101
custom_loss(x, y) = Flux.crossentropy(m(x), y)
102-
opt = Flux.ADAM()
102+
opt = Flux.Adam()
103103
evalcb = () -> @show(custom_loss(train_X, train_Y))
104104

105105
# Train to optimize network parameters

src/QuadraticProgram/QuadraticProgram.jl

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,15 @@ mutable struct Model <: DiffOpt.AbstractModel
100100
# sensitivity input cache using MOI like sparse format
101101
input_cache::DiffOpt.InputCache
102102

103+
# linear solving function to use
104+
linear_solver::Any
105+
103106
x::Vector{Float64} # Primal
104107
λ::Vector{Float64} # Dual of inequalities
105108
ν::Vector{Float64} # Dual of equalities
106109
end
107110
function Model()
108-
return Model(Form{Float64}(), nothing, nothing, nothing, DiffOpt.InputCache(), Float64[], Float64[], Float64[])
111+
return Model(Form{Float64}(), nothing, nothing, nothing, DiffOpt.InputCache(), nothing, Float64[], Float64[], Float64[])
109112
end
110113

111114
function MOI.is_empty(model::Model)
@@ -281,11 +284,9 @@ function DiffOpt.reverse_differentiate!(model::Model)
281284

282285
nv = length(model.x)
283286
Q = view(LHS, 1:nv, 1:nv)
284-
partial_grads = if norm(Q) ≈ 0
285-
-IterativeSolvers.lsqr(LHS, RHS)
286-
else
287-
-LHS \ RHS
288-
end
287+
iterative = norm(Q) ≈ 0
288+
solver = model.linear_solver
289+
partial_grads = -solve_system(solver, LHS, RHS, iterative)
289290

290291
dz = partial_grads[1:nv]
291292
dλ = partial_grads[nv+1:nv+nineq]
@@ -302,10 +303,6 @@ function DiffOpt.reverse_differentiate!(model::Model)
302303
# todo, check MOI signs for dA and dG
303304
end
304305

305-
_linsolve(A, b) = A \ b
306-
# See https://github.com/JuliaLang/julia/issues/32668
307-
_linsolve(A, b::SparseVector) = A \ Vector(b)
308-
309306
# Just a hack that will be removed once we use `MOI.Utilities.MatrixOfConstraints`
310307
struct _QPSets end
311308
MOI.Utilities.rows(::_QPSets, ci::MOI.ConstraintIndex) = ci.value
@@ -353,13 +350,9 @@ function DiffOpt.forward_differentiate!(model::Model)
353350
]
354351

355352
Q = view(LHS, 1:nv, 1:nv)
356-
partial_grads = if norm(Q) ≈ 0
357-
-IterativeSolvers.lsqr(LHS', RHS)
358-
else
359-
-_linsolve(LHS', RHS)
360-
end
361-
362-
353+
iterative = norm(Q) ≈ 0
354+
solver = model.linear_solver
355+
partial_grads = -solve_system(solver, LHS', RHS, iterative)
363356
dz = partial_grads[1:nv]
364357
dλ = partial_grads[nv+1:nv+length(λ)]
365358
dν = partial_grads[nv+length(λ)+1:end]
@@ -395,4 +388,31 @@ function DiffOpt._get_dA(model::Model, ci::LE)
395388
return DiffOpt.lazy_combination(+, l * dλ[i], model.x, l, dz)
396389
end
397390

391+
"""
392+
LinearAlgebraSolver
393+
394+
Optimizer attribute for the solver to use for the linear algebra operations.
395+
Each solver must implement: `solve_system(solver, LHS, RHS, iterative::Bool)`.
396+
"""
397+
struct LinearAlgebraSolver <: MOI.AbstractOptimizerAttribute end
398+
399+
"""
400+
Default `solve_system` call uses IterativeSolvers or the default linear solve
401+
"""
402+
function solve_system(::Any, LHS, RHS, iterative)
403+
if iterative
404+
IterativeSolvers.lsqr(LHS, RHS)
405+
else
406+
LHS \ RHS
407+
end
408+
end
409+
# See https://github.com/JuliaLang/julia/issues/32668
410+
solve_system(::Nothing, LHS, RHS::SparseVector, iterative) = solve_system(nothing, LHS, Vector(RHS), iterative)
411+
412+
MOI.supports(::Model, ::LinearAlgebraSolver) = true
413+
MOI.get(model::Model, ::LinearAlgebraSolver) = model.linear_solver
414+
function MOI.set(model::Model, ::LinearAlgebraSolver, linear_solver)
415+
model.linear_solver = linear_solver
416+
end
417+
398418
end

test/solver_interface.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,21 @@ end
3232
@test_throws ErrorException DiffOpt.forward_differentiate!(model)
3333
@test_throws ErrorException DiffOpt.reverse_differentiate!(model)
3434
end
35+
36+
struct TestSolver
37+
end
38+
39+
# always use IterativeSolvers
40+
function DiffOpt.QuadraticProgram.solve_system(::TestSolver, LHS, RHS, iterative::Bool)
41+
IterativeSolvers.lsqr(LHS, RHS)
42+
end
43+
44+
@testset "Setting the linear solver in the quadratic solver" begin
45+
model = DiffOpt.QuadraticProgram.Model()
46+
@test MOI.supports(model, DiffOpt.QuadraticProgram.LinearAlgebraSolver())
47+
@test MOI.get(model, DiffOpt.QuadraticProgram.LinearAlgebraSolver()) === nothing
48+
MOI.set(model, DiffOpt.QuadraticProgram.LinearAlgebraSolver(), TestSolver())
49+
@test MOI.get(model, DiffOpt.QuadraticProgram.LinearAlgebraSolver()) == TestSolver()
50+
MOI.empty!(model)
51+
@test MOI.get(model, DiffOpt.QuadraticProgram.LinearAlgebraSolver()) == TestSolver()
52+
end

0 commit comments

Comments
 (0)