Skip to content

Commit b4d8c1e

Browse files
committed
implicit kwargs as an option, and verbose=false by default
1 parent 8edaf48 commit b4d8c1e

File tree

3 files changed

+29
-14
lines changed

3 files changed

+29
-14
lines changed

ext/InferOptFrankWolfeExt.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@ using ImplicitDifferentiation: KrylovLinearSolver
88
using InferOpt: InferOpt, RegularizedFrankWolfe
99
using LinearAlgebra: dot
1010

11+
"""
12+
RegularizedFrankWolfe(linear_maximizer; Ω, Ω_grad, frank_wolfe_kwargs=(;), implicit_kwargs=(; linear_solver=KrylovLinearSolver(; verbose=false)))
13+
14+
Construct a `RegularizedFrankWolfe` struct with a linear maximizer and the necessary components for the Frank-Wolfe algorithm.
15+
Set `implicit_kwargs` to `(; linear_solver=KrylovLinearSolver(; verbose=true))` if you want to see the solver potential warnings.
16+
"""
17+
function RegularizedFrankWolfe(
18+
linear_maximizer;
19+
Ω,
20+
Ω_grad,
21+
frank_wolfe_kwargs=NamedTuple(),
22+
implicit_kwargs=(; linear_solver=KrylovLinearSolver(; verbose=false)),
23+
)
24+
return RegularizedFrankWolfe(
25+
linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs, implicit_kwargs
26+
)
27+
end
28+
1129
"""
1230
LinearMaximizationOracleWithKwargs{F,K}
1331
Wraps a linear maximizer as a `FrankWolfe.LinearMinimizationOracle` with a sign switch and predefined keyword arguments.
@@ -42,12 +60,11 @@ function InferOpt.compute_probability_distribution(
4260
regularized::RegularizedFrankWolfe, θ::AbstractArray; kwargs...
4361
)
4462
shape = size(θ)
45-
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized
63+
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs, implicit_kwargs) = regularized
4664
f(y, θ) = Ω(y) - dot(θ, y)
4765
f_grad1(y, θ) = Ω_grad(y) - θ
4866
maximizer(θ; shape, kwargs...) = vec(linear_maximizer(reshape(θ, shape); kwargs...))
4967
lmo = LinearMaximizationOracleWithKwargs(maximizer, (; shape, kwargs...))
50-
implicit_kwargs = (; linear_solver=KrylovLinearSolver())
5168
dfw = DiffFW(f, f_grad1, lmo; implicit_kwargs)
5269
weights, atoms = dfw.implicit(vec(θ); frank_wolfe_kwargs=frank_wolfe_kwargs)
5370
probadist = FixedAtomsProbabilityDistribution(

src/layers/regularized/regularized_frank_wolfe.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@ Regularized optimization layer which relies on the Frank-Wolfe algorithm to defi
77
```
88
99
!!! warning "Warning"
10-
Since this is a conditional dependency, you need to have loaded the package DifferentiableFrankWolfe.jl before using `RegularizedFrankWolfe`.
10+
Since this is a conditional dependency, you need to have loaded the following packages before using `RegularizedFrankWolfe`:
11+
- `DifferentiableFrankWolfe.jl`
12+
- `FrankWolfe.jl`
13+
- `ImplicitDifferentiation.jl`
1114
1215
# Fields
1316
1417
- `linear_maximizer`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
1518
- `Ω`: regularization function `Ω(y)`
1619
- `Ω_grad`: gradient function of the regularization function `∇Ω(y)`
1720
- `frank_wolfe_kwargs`: named tuple of keyword arguments passed to the Frank-Wolfe algorithm
21+
- `implicit_kwargs`: named tuple of keyword arguments passed to the implicit differentiation algorithm (in particular, the needed linear solver)
1822
1923
# Frank-Wolfe parameters
2024
@@ -29,24 +33,19 @@ Some values you can tune:
2933
3034
See the documentation of FrankWolfe.jl for details.
3135
"""
32-
struct RegularizedFrankWolfe{M,RF,RG,FWK} <: AbstractRegularized
36+
struct RegularizedFrankWolfe{M,RF,RG,FWK,IK} <: AbstractRegularized
3337
linear_maximizer::M
3438
Ω::RF
3539
Ω_grad::RG
3640
frank_wolfe_kwargs::FWK
37-
end
38-
39-
"""
40-
RegularizedFrankWolfe(linear_maximizer; Ω, Ω_grad, frank_wolfe_kwargs=(;))
41-
"""
42-
function RegularizedFrankWolfe(linear_maximizer; Ω, Ω_grad, frank_wolfe_kwargs=NamedTuple())
43-
return RegularizedFrankWolfe(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs)
41+
implicit_kwargs::IK
4442
end
4543

4644
function Base.show(io::IO, regularized::RegularizedFrankWolfe)
47-
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized
45+
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs, implicit_kwargs) = regularized
4846
return print(
49-
io, "RegularizedFrankWolfe($linear_maximizer, , $Ω_grad, $frank_wolfe_kwargs)"
47+
io,
48+
"RegularizedFrankWolfe($linear_maximizer, , $Ω_grad, $frank_wolfe_kwargs, $implicit_kwargs)",
5049
)
5150
end
5251

test/InferOptTestUtils/src/pipeline.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ function test_pipeline!(
2121
perf_storage = init_perf()
2222

2323
for ep in 1:epochs
24-
@info ep
2524
update_perf!(
2625
pl,
2726
perf_storage;

0 commit comments

Comments
 (0)