33"""
44 DifferentiableFrankWolfe{F,G,M,S}
55
6- Parameterized version of the Frank-Wolfe algorithm `θ -> argmin_{x ∈ C} f(x, θ)`.
7-
8- Compatible with implicit differentiation.
6+ Parameterized version of the Frank-Wolfe algorithm `θ -> argmin_{x ∈ C} f(x, θ)`, which can be differentiated implicitly wrt `θ`.
97
108# Fields
119- `f::F`: function `f(x, θ)` to minimize wrt `x`
1210- `f_grad1::G`: gradient `∇ₓf(x, θ)` of `f` wrt `x`
13- - `lmo::M`: linear minimization oracle `θ -> argmin_{x ∈ C} θᵀx`
14- - `linear_solver::S`: solver for linear systems of equations
11+ - `lmo::M`: linear minimization oracle `θ -> argmin_{x ∈ C} θᵀx` which implicitly defines the polytope `C`
12+ - `linear_solver::S`: solver for linear systems of equations, used during implicit differentiation
13+
14+ # Applicable methods
15+
16+ - [`compute_probability_distribution(dfw::DifferentiableFrankWolfe, θ, x0)`](@ref)
17+ - `(dfw::DifferentiableFrankWolfe)(θ, x0)`
18+
1519"""
1620struct DifferentiableFrankWolfe{F,G,M<: LinearMinimizationOracle ,S}
1721 f:: F
2630
2731# # Forward pass
2832
33+ """
34+ compute_probability_distribution(dfw::DifferentiableFrankWolfe, θ, x0[; fw_kwargs=(;)])
35+
36+ Compute the optimal active set by applying the away-step Frank-Wolfe algorithm with initial point `x0`, then turn it into a probability distribution.
37+
38+ The named tuple `fw_kwargs` is passed as keyword arguments to `FrankWolfe.away_frank_wolfe`.
39+ """
2940function compute_probability_distribution (
3041 dfw:: DifferentiableFrankWolfe ,
3142 θ:: AbstractArray{<:Real} ,
3243 x0:: AbstractArray{<:Real} ;
3344 fw_kwargs= (;),
45+ kwargs... ,
3446)
3547 (; f, f_grad1, lmo) = dfw
3648 obj (x) = f (x, θ)
@@ -44,8 +56,13 @@ function compute_probability_distribution(
4456 return probadist
4557end
4658
59+ """
60+ (dfw::DifferentiableFrankWolfe)(θ, x0[; fw_kwargs=(;)])
61+
62+ Apply `compute_probability_distribution(dfw, θ, x0)` and return the expectation.
63+ """
4764function (dfw:: DifferentiableFrankWolfe )(
48- θ:: AbstractArray{<:Real} , x0:: AbstractArray{<:Real} ; fw_kwargs= (;)
65+ θ:: AbstractArray{<:Real} , x0:: AbstractArray{<:Real} ; fw_kwargs= (;), kwargs ...
4966)
5067 probadist = compute_probability_distribution (dfw, θ, x0; fw_kwargs= fw_kwargs)
5168 return compute_expectation (probadist)
@@ -74,6 +91,7 @@ function ChainRulesCore.rrule(
7491 θ:: AbstractArray{R1} ,
7592 x0:: AbstractArray{R2} ;
7693 fw_kwargs= (;),
94+ kwargs... ,
7795) where {R1<: Real ,R2<: Real }
7896 R = promote_type (R1, R2)
7997 (; linear_solver) = dfw
@@ -88,8 +106,8 @@ function ChainRulesCore.rrule(
88106 pullback_Aᵀ = last ∘ rrule_via_ad (rc, conditions_p, p)[2 ]
89107 pullback_Bᵀ = last ∘ rrule_via_ad (rc, conditions_θ, θ)[2 ]
90108
91- mul_Aᵀ! (res, u:: AbstractVector ) = res .= vec (pullback_Aᵀ (reshape (u, size (p)) ))
92- mul_Bᵀ! (res, v:: AbstractVector ) = res .= vec (pullback_Bᵀ (reshape (v, size (p)) ))
109+ mul_Aᵀ! (res, u:: AbstractVector ) = res .= vec (pullback_Aᵀ (u ))
110+ mul_Bᵀ! (res, v:: AbstractVector ) = res .= vec (pullback_Bᵀ (v ))
93111
94112 n, m = length (θ), length (p)
95113 Aᵀ = LinearOperator (R, m, m, false , false , mul_Aᵀ!)
0 commit comments