Skip to content

Commit 16b1e5d

Browse files
committed
Improve documentation
1 parent 762379f commit 16b1e5d

File tree

12 files changed

+224
-77
lines changed

12 files changed

+224
-77
lines changed

docs/src/algorithms.md

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
# API Reference
22

3-
## Index
4-
5-
```@index
6-
Modules = [InferOpt]
7-
```
8-
93
## Probability distributions
104

115
```@autodocs
@@ -48,6 +42,9 @@ Pages = ["ssvm/isbaseloss.jl", "ssvm/ssvm_loss.jl", "ssvm/zeroone_baseloss.jl"]
4842
!!! note "Reference"
4943
[Efficient and Modular Implicit Differentiation](http://arxiv.org/abs/2105.15183)
5044

45+
!!! note "Reference"
46+
[FrankWolfe.jl: a high-performance and flexible toolbox for Frank-Wolfe algorithms and Conditional Gradients](https://arxiv.org/abs/2104.06675)
47+
5148
```@autodocs
5249
Modules = [InferOpt]
5350
Pages = ["frank_wolfe/frank_wolfe_utils.jl", "frank_wolfe/differentiable_frank_wolfe.jl"]
@@ -60,7 +57,7 @@ Pages = ["frank_wolfe/frank_wolfe_utils.jl", "frank_wolfe/differentiable_frank_w
6057

6158
```@autodocs
6259
Modules = [InferOpt]
63-
Pages = ["regularized/frank_wolfe.jl", "regularized/isregularized.jl", "regularized/soft_argmax.jl", "regularized/sparse_argmax.jl", "regularized/regularized_generic.jl", "regularized/regularized_utils.jl"]
60+
Pages = ["regularized/isregularized.jl", "regularized/regularized_generic.jl", "regularized/regularized_utils.jl", "regularized/soft_argmax.jl", "regularized/sparse_argmax.jl"]
6461
```
6562

6663
## Perturbed optimizers
@@ -70,7 +67,7 @@ Pages = ["regularized/frank_wolfe.jl", "regularized/isregularized.jl", "regulari
7067

7168
```@autodocs
7269
Modules = [InferOpt]
73-
Pages = ["perturbed/abstract_perturbed.jl", "perturbed/additive.jl", "perturbed/composition.jl", "perturbed/multiplicative.jl"]
70+
Pages = ["perturbed/abstract_perturbed.jl", "perturbed/additive.jl", "perturbed/multiplicative.jl"]
7471
```
7572

7673
## Fenchel-Young losses
@@ -80,5 +77,11 @@ Pages = ["perturbed/abstract_perturbed.jl", "perturbed/additive.jl", "perturbed/
8077

8178
```@autodocs
8279
Modules = [InferOpt]
83-
Pages = ["fenchel_young/fenchel_young.jl"]
80+
Pages = ["fenchel_young/fenchel_young.jl", "fenchel_young/perturbed.jl"]
8481
```
82+
83+
## Index
84+
85+
```@index
86+
Modules = [InferOpt]
87+
```

src/InferOpt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ export compute_probability_distribution
4949
export Interpolation
5050

5151
export DifferentiableFrankWolfe
52+
export LMOWrapper
5253

5354
export half_square_norm
5455
export shannon_entropy, negative_shannon_entropy

src/frank_wolfe/differentiable_frank_wolfe.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
"""
44
DifferentiableFrankWolfe{F,G,M,S}
55
6-
Parameterized version of the Frank-Wolfe algorithm `θ -> argmin_{x ∈ C} f(x, θ)`.
7-
8-
Compatible with implicit differentiation.
6+
Parameterized version of the Frank-Wolfe algorithm `θ -> argmin_{x ∈ C} f(x, θ)`, which can be differentiated implicitly wrt `θ`.
97
108
# Fields
119
- `f::F`: function `f(x, θ)` to minimize wrt `x`
1210
- `f_grad1::G`: gradient `∇ₓf(x, θ)` of `f` wrt `x`
13-
- `lmo::M`: linear minimization oracle `θ -> argmin_{x ∈ C} θᵀx`
14-
- `linear_solver::S`: solver for linear systems of equations
11+
- `lmo::M`: linear minimization oracle `θ -> argmin_{x ∈ C} θᵀx` which implicitly defines the polytope `C`
12+
- `linear_solver::S`: solver for linear systems of equations, used during implicit differentiation
13+
14+
# Applicable methods
15+
16+
- [`compute_probability_distribution(dfw::DifferentiableFrankWolfe, θ, x0)`](@ref)
17+
- `(dfw::DifferentiableFrankWolfe)(θ, x0)`
18+
1519
"""
1620
struct DifferentiableFrankWolfe{F,G,M<:LinearMinimizationOracle,S}
1721
f::F
@@ -26,11 +30,19 @@ end
2630

2731
## Forward pass
2832

33+
"""
34+
compute_probability_distribution(dfw::DifferentiableFrankWolfe, θ, x0[; fw_kwargs=(;)])
35+
36+
Compute the optimal active set by applying the away-step Frank-Wolfe algorithm with initial point `x0`, then turn it into a probability distribution.
37+
38+
The named tuple `fw_kwargs` is passed as keyword arguments to `FrankWolfe.away_frank_wolfe`.
39+
"""
2940
function compute_probability_distribution(
3041
dfw::DifferentiableFrankWolfe,
3142
θ::AbstractArray{<:Real},
3243
x0::AbstractArray{<:Real};
3344
fw_kwargs=(;),
45+
kwargs...,
3446
)
3547
(; f, f_grad1, lmo) = dfw
3648
obj(x) = f(x, θ)
@@ -44,8 +56,13 @@ function compute_probability_distribution(
4456
return probadist
4557
end
4658

59+
"""
60+
(dfw::DifferentiableFrankWolfe)(θ, x0[; fw_kwargs=(;)])
61+
62+
Apply `compute_probability_distribution(dfw, θ, x0)` and return the expectation.
63+
"""
4764
function (dfw::DifferentiableFrankWolfe)(
48-
θ::AbstractArray{<:Real}, x0::AbstractArray{<:Real}; fw_kwargs=(;)
65+
θ::AbstractArray{<:Real}, x0::AbstractArray{<:Real}; fw_kwargs=(;), kwargs...
4966
)
5067
probadist = compute_probability_distribution(dfw, θ, x0; fw_kwargs=fw_kwargs)
5168
return compute_expectation(probadist)
@@ -74,6 +91,7 @@ function ChainRulesCore.rrule(
7491
θ::AbstractArray{R1},
7592
x0::AbstractArray{R2};
7693
fw_kwargs=(;),
94+
kwargs...,
7795
) where {R1<:Real,R2<:Real}
7896
R = promote_type(R1, R2)
7997
(; linear_solver) = dfw
@@ -88,8 +106,8 @@ function ChainRulesCore.rrule(
88106
pullback_Aᵀ = last ∘ rrule_via_ad(rc, conditions_p, p)[2]
89107
pullback_Bᵀ = last ∘ rrule_via_ad(rc, conditions_θ, θ)[2]
90108

91-
mul_Aᵀ!(res, u::AbstractVector) = res .= vec(pullback_Aᵀ(reshape(u, size(p))))
92-
mul_Bᵀ!(res, v::AbstractVector) = res .= vec(pullback_Bᵀ(reshape(v, size(p))))
109+
mul_Aᵀ!(res, u::AbstractVector) = res .= vec(pullback_Aᵀ(u))
110+
mul_Bᵀ!(res, v::AbstractVector) = res .= vec(pullback_Bᵀ(v))
93111

94112
n, m = length(θ), length(p)
95113
Aᵀ = LinearOperator(R, m, m, false, false, mul_Aᵀ!)

src/frank_wolfe/frank_wolfe_utils.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
Default configuration for the Frank-Wolfe wrapper.
55
66
# Parameters
7-
- `away_steps`
8-
- `epsilon`
9-
- `lazy`
10-
- `line_search`
11-
- `max_iteration`
12-
- `timeout`
13-
- `verbose`
7+
- `away_steps=true`: activate away steps to avoid zig-zagging
8+
- `epsilon=1e-2`: precision
9+
- `lazy=true`: caching strategy
10+
- `line_search=FrankWolfe.Adaptive()`: step size selection
11+
- `max_iteration=10`: number of iterations
12+
- `timeout=1.0`: maximum time in seconds
13+
- `verbose=false`: console output
1414
"""
1515
const DEFAULT_FRANK_WOLFE_KWARGS = (
1616
away_steps=true,
@@ -24,13 +24,25 @@ const DEFAULT_FRANK_WOLFE_KWARGS = (
2424

2525
## Wrapper for linear maximizers to use them within Frank-Wolfe
2626

27+
"""
28+
LMOWrapper{F,K}
29+
30+
Wraps a linear maximizer as a `FrankWolfe.LinearMinimizationOracle`.
31+
32+
# Fields
33+
- `maximizer::F`: black box linear maximizer
34+
- `maximizer_kwargs::K`: keyword arguments passed to the maximizer whenever it is called
35+
"""
2736
struct LMOWrapper{F,K} <: LinearMinimizationOracle
2837
maximizer::F
2938
maximizer_kwargs::K
3039
end
3140

3241
LMOWrapper(maximizer) = LMOWrapper(maximizer, (;))
3342

43+
"""
44+
FrankWolfe.compute_extreme_point(lmo_wrapper::LMOWrapper, direction)
45+
"""
3446
function FrankWolfe.compute_extreme_point(lmo_wrapper::LMOWrapper, direction; kwargs...)
3547
(; maximizer, maximizer_kwargs) = lmo_wrapper
3648
v = maximizer(-direction; maximizer_kwargs...)

src/perturbed/abstract_perturbed.jl

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,40 @@
11
"""
22
AbstractPerturbed
33
4-
Differentiable perturbation of a black-box optimizer.
4+
Differentiable perturbation of a black box optimizer.
5+
6+
# Applicable methods
7+
8+
- [`compute_probability_distribution(perturbed::AbstractPerturbed, θ)`](@ref)
9+
- `(perturbed::AbstractPerturbed)(θ)`
510
611
# Available subtypes
7-
- [`PerturbedAdditive{F}`](@ref)
8-
- [`PerturbedMultiplicative{F}`](@ref)
912
10-
# Required fields
13+
- [`PerturbedAdditive`](@ref)
14+
- [`PerturbedMultiplicative`](@ref)
15+
16+
These subtypes share the following fields:
17+
18+
- `maximizer`: black box optimizer
19+
- `ε`: magnitude of the perturbation
1120
- `rng::AbstractRNG`: random number generator
1221
- `seed::Union{Nothing,Int}`: random seed
1322
- `nb_samples::Int`: number of random samples for Monte-Carlo computations
1423
"""
15-
abstract type AbstractPerturbed{F} end
24+
abstract type AbstractPerturbed end
1625

26+
"""
27+
sample_perturbations(perturbed::AbstractPerturbed, θ)
28+
29+
Draw random perturbations `Z` which will be applied to the objective direction `θ`.
30+
"""
1731
function sample_perturbations(perturbed::AbstractPerturbed, θ::AbstractArray{<:Real})
1832
(; rng, seed, nb_samples) = perturbed
1933
seed!(rng, seed)
2034
Z_samples = [randn(rng, size(θ)) for _ in 1:nb_samples]
2135
return Z_samples
2236
end
2337

24-
"""
25-
perturb_and_optimize(perturbed, θ, Z; kwargs...)
26-
"""
27-
function perturb_and_optimize(
28-
perturbed::AbstractPerturbed,
29-
θ::AbstractArray{<:Real},
30-
Z::AbstractArray{<:Real};
31-
kwargs...,
32-
)
33-
return error("Not implemented")
34-
end
35-
3638
function compute_probability_distribution(
3739
perturbed::AbstractPerturbed,
3840
θ::AbstractArray{<:Real},
@@ -45,23 +47,24 @@ function compute_probability_distribution(
4547
return probadist
4648
end
4749

50+
"""
51+
compute_probability_distribution(perturbed::AbstractPerturbed, θ)
52+
53+
Turn random perturbations of `θ` into a distribution on polytope vertices.
54+
"""
4855
function compute_probability_distribution(
4956
perturbed::AbstractPerturbed, θ::AbstractArray{<:Real}; kwargs...
5057
)
5158
Z_samples = sample_perturbations(perturbed, θ)
5259
return compute_probability_distribution(perturbed, θ, Z_samples; kwargs...)
5360
end
5461

62+
"""
63+
(perturbed::AbstractPerturbed)(θ)
64+
65+
Apply `compute_probability_distribution(perturbed, θ)` and return the expectation.
66+
"""
5567
function (perturbed::AbstractPerturbed)(θ::AbstractArray{<:Real}; kwargs...)
5668
probadist = compute_probability_distribution(perturbed, θ; kwargs...)
5769
return compute_expectation(probadist)
5870
end
59-
60-
function ChainRulesCore.rrule(
61-
::typeof(compute_probability_distribution),
62-
perturbed::AbstractPerturbed,
63-
θ::AbstractArray{<:Real};
64-
kwargs...,
65-
)
66-
return error("Not implemented")
67-
end

src/perturbed/additive.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
44
Differentiable normal perturbation of a black-box optimizer: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`.
55
6-
See also: [`AbstractPerturbed{F}`](@ref).
6+
See also: [`AbstractPerturbed`](@ref).
77
88
Reference: <https://arxiv.org/abs/2002.08676>
99
"""
10-
struct PerturbedAdditive{F,R<:AbstractRNG,S<:Union{Nothing,Int}} <: AbstractPerturbed{F}
10+
struct PerturbedAdditive{F,R<:AbstractRNG,S<:Union{Nothing,Int}} <: AbstractPerturbed
1111
maximizer::F
1212
ε::Float64
1313
rng::R

src/perturbed/multiplicative.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
44
Differentiable log-normal perturbation of a black-box optimizer: the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ N(0, I)`.
55
6-
See also: [`AbstractPerturbed{F}`](@ref).
6+
See also: [`AbstractPerturbed`](@ref).
7+
8+
Reference: preprint coming soon.
79
"""
8-
struct PerturbedMultiplicative{F,R<:AbstractRNG,S<:Union{Nothing,Int}} <:
9-
AbstractPerturbed{F}
10+
struct PerturbedMultiplicative{F,R<:AbstractRNG,S<:Union{Nothing,Int}} <: AbstractPerturbed
1011
maximizer::F
1112
ε::Float64
1213
rng::R

src/regularized/isregularized.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@ For `predictor::P` to comply with this interface, the following methods must exi
1111
- [`one_hot_argmax`](@ref)
1212
- [`soft_argmax`](@ref)
1313
- [`sparse_argmax`](@ref)
14+
- [`RegularizedGeneric`](@ref)
1415
"""
1516
@traitdef IsRegularized{P}
1617

17-
@traitfn function compute_regularization(predictor::P, y) where {P; IsRegularized{P}} end
18+
"""
19+
compute_regularization(predictor::P, y)
20+
21+
Compute the convex regularization function `Ω(y)`.
22+
"""
23+
function compute_regularization end

src/regularized/regularized_generic.jl

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
"""
22
RegularizedGeneric{M,RF,RG,F,G,S}
33
4-
Generic and differentiable regularized prediction function `ŷ(θ) = argmax {θᵀy - Ω(y)}`.
4+
Differentiable regularized prediction function `ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)}`.
55
66
Relies on the Frank-Wolfe algorithm to minimize a convex objective on a polytope.
77
88
# Fields
9-
- `maximizer::M`
10-
- `Ω::RF`
11-
- `Ω_grad::RG`
12-
- `f::F`
13-
- `f_grad1::G`
14-
- `linear_solver::S`
9+
- `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx` which implicitly defines the polytope `C`
10+
- `Ω::RF`: regularization function `Ω(y)`
11+
- `Ω_grad::RG`: gradient of the regularization function `∇Ω(y)`
12+
- `f::F`: objective function `f(y, θ) = Ω(y) - θᵀy` minimized by Frank-Wolfe (computed automatically)
13+
- `f_grad1::G`: gradient of the objective function `∇_y f(y, θ) = ∇Ω(y) - θ` with respect to `y` (computed automatically)
14+
- `linear_solver::S`: solver for linear systems of equations, used during implicit differentiation
15+
16+
# Applicable methods
17+
18+
- [`compute_probability_distribution(regularized::RegularizedGeneric, θ)`](@ref)
19+
- `(regularized::RegularizedGeneric)(θ)`
1520
1621
See also: [`DifferentiableFrankWolfe`](@ref).
1722
"""
@@ -29,6 +34,11 @@ function Base.show(io::IO, regularized::RegularizedGeneric)
2934
return print(io, "RegularizedGeneric($maximizer, $Ω, $Ω_grad, $linear_solver)")
3035
end
3136

37+
"""
38+
RegularizedGeneric(maximizer, Ω, Ω_grad[; linear_solver=gmres])
39+
40+
Short form constructor with a default linear solver.
41+
"""
3242
function RegularizedGeneric(maximizer, Ω, Ω_grad; linear_solver=gmres)
3343
f(y, θ) = Ω(y) - dot(θ, y)
3444
f_grad1(y, θ) = Ω_grad(y) - θ
@@ -43,6 +53,14 @@ end
4353

4454
## Forward pass
4555

56+
"""
57+
compute_probability_distribution(regularized::RegularizedGeneric, θ[; maximizer_kwargs=(;), fw_kwargs=(;)])
58+
59+
Construct a [`DifferentiableFrankWolfe`](@ref) struct and call `compute_probability_distribution` on it.
60+
61+
The named tuple `maximizer_kwargs` is passed as keyword arguments to the underlying maximizer, which is wrapped inside a [`LMOWrapper`](@ref).
62+
The named tuple `fw_kwargs` is passed as keyword arguments to `FrankWolfe.away_frank_wolfe`.
63+
"""
4664
function compute_probability_distribution(
4765
regularized::RegularizedGeneric,
4866
θ::AbstractArray{<:Real};
@@ -58,8 +76,13 @@ function compute_probability_distribution(
5876
return probadist
5977
end
6078

79+
"""
80+
(regularized::RegularizedGeneric)(θ[; maximizer_kwargs=(;), fw_kwargs=(;)])
81+
82+
Apply `compute_probability_distribution(regularized, θ)` and return the expectation.
83+
"""
6184
function (regularized::RegularizedGeneric)(
62-
θ::AbstractArray{<:Real}; maximizer_kwargs=(;), fw_kwargs=(;)
85+
θ::AbstractArray{<:Real}; maximizer_kwargs=(;), fw_kwargs=(;), kwargs...
6386
)
6487
probadist = compute_probability_distribution(
6588
regularized, θ; maximizer_kwargs=maximizer_kwargs, fw_kwargs=fw_kwargs

0 commit comments

Comments
 (0)