@@ -7,14 +7,18 @@ Regularized optimization layer which relies on the Frank-Wolfe algorithm to defi
77```
88
99!!! warning "Warning"
10- Since this is a conditional dependency, you need to have loaded the package DifferentiableFrankWolfe.jl before using `RegularizedFrankWolfe`.
10+ Since this is a conditional dependency, you need to have loaded the following packages before using `RegularizedFrankWolfe`:
11+ - `DifferentiableFrankWolfe.jl`
12+ - `FrankWolfe.jl`
13+ - `ImplicitDifferentiation.jl`
1114
1215# Fields
1316
1417- `linear_maximizer`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
1518- `Ω`: regularization function `Ω(y)`
1619- `Ω_grad`: gradient function of the regularization function `∇Ω(y)`
1720- `frank_wolfe_kwargs`: named tuple of keyword arguments passed to the Frank-Wolfe algorithm
21+ - `implicit_kwargs`: named tuple of keyword arguments passed to the implicit differentiation algorithm (in particular, the needed linear solver)
1822
1923# Frank-Wolfe parameters
2024
@@ -29,24 +33,19 @@ Some values you can tune:
2933
3034See the documentation of FrankWolfe.jl for details.
3135"""
32- struct RegularizedFrankWolfe{M,RF,RG,FWK} <: AbstractRegularized
36+ struct RegularizedFrankWolfe{M,RF,RG,FWK,IK } <: AbstractRegularized
3337 linear_maximizer:: M
3438 Ω:: RF
3539 Ω_grad:: RG
3640 frank_wolfe_kwargs:: FWK
37- end
38-
39- """
40- RegularizedFrankWolfe(linear_maximizer; Ω, Ω_grad, frank_wolfe_kwargs=(;))
41- """
42- function RegularizedFrankWolfe (linear_maximizer; Ω, Ω_grad, frank_wolfe_kwargs= NamedTuple ())
43- return RegularizedFrankWolfe (linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs)
41+ implicit_kwargs:: IK
4442end
4543
4644function Base. show (io:: IO , regularized:: RegularizedFrankWolfe )
47- (; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized
45+ (; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs, implicit_kwargs ) = regularized
4846 return print (
49- io, " RegularizedFrankWolfe($linear_maximizer , $Ω , $Ω_grad , $frank_wolfe_kwargs )"
47+ io,
48+ " RegularizedFrankWolfe($linear_maximizer , $Ω , $Ω_grad , $frank_wolfe_kwargs , $implicit_kwargs )" ,
5049 )
5150end
5251
0 commit comments