2727# # Forward pass
2828
2929function optimal_active_set (
30- dfw:: DifferentiableFrankWolfe , θ:: AbstractArray{<:Real} ; fw_kwargs= (;)
30+ dfw:: DifferentiableFrankWolfe ,
31+ θ:: AbstractArray{<:Real} ,
32+ x0:: AbstractArray{<:Real} ;
33+ fw_kwargs= (;),
3134)
3235 (; f, ∇ₓf, lmo) = dfw
3336 obj (x) = f (x, θ)
3437 grad! (g, x) = g .= ∇ₓf (x, θ)
35- x0 = compute_extreme_point (lmo, zero (θ))
3638 full_fw_kwargs = merge (DEFAULT_FRANK_WOLFE_KWARGS, fw_kwargs)
3739 x, v, primal, dual_gap, traj_data, active_set = away_frank_wolfe (
3840 obj, grad!, lmo, x0; full_fw_kwargs...
@@ -41,62 +43,66 @@ function optimal_active_set(
4143 return active_set
4244end
4345
44- function (dfw:: DifferentiableFrankWolfe )(θ:: AbstractArray{<:Real} ; fw_kwargs= (;))
45- active_set:: ActiveSet = optimal_active_set (dfw, θ; fw_kwargs= fw_kwargs)
46+ function (dfw:: DifferentiableFrankWolfe )(
47+ θ:: AbstractArray{<:Real} , x0:: AbstractArray{<:Real} ; fw_kwargs= (;)
48+ )
49+ active_set:: ActiveSet = optimal_active_set (dfw, θ, x0; fw_kwargs= fw_kwargs)
4650 return active_set. x
4751end
4852
49- # # Backward pass, only works with vectors
53+ # # Backward pass
5054
5155function frank_wolfe_optimality_conditions (
5256 dfw:: DifferentiableFrankWolfe ,
53- θ:: AbstractVector {<:Real} ,
57+ θ:: AbstractArray {<:Real} ,
5458 p:: AbstractVector{<:Real} ,
55- V :: AbstractMatrix {<:Real} ,
59+ A :: AbstractVector {<:AbstractArray{<: Real} } ,
5660)
5761 (; ∇ₓf) = dfw
58- ∇ₚg = V' * ∇ₓf (V * p, θ)
62+ x = sum (pᵢ * Aᵢ for (pᵢ, Aᵢ) in zip (p, A))
63+ b = ∇ₓf (x, θ)
64+ ∇ₚg = [dot (Aᵢ, b) for Aᵢ in A]
5965 T = sparse_argmax (p - ∇ₚg)
6066 return T - p
6167end
6268
6369function ChainRulesCore. rrule (
64- rc:: RuleConfig , dfw:: DifferentiableFrankWolfe , θ:: AbstractVector{<:Real} ; fw_kwargs= (;)
65- )
70+ rc:: RuleConfig ,
71+ dfw:: DifferentiableFrankWolfe ,
72+ θ:: AbstractArray{R1} ,
73+ x0:: AbstractArray{R2} ;
74+ fw_kwargs= (;),
75+ ) where {R1<: Real ,R2<: Real }
76+ R = promote_type (R1, R2)
6677 (; linear_solver) = dfw
6778
68- active_set:: ActiveSet = optimal_active_set (dfw, θ; fw_kwargs= fw_kwargs)
69- V = reduce (hcat, active_set. atoms)
79+ active_set:: ActiveSet = optimal_active_set (dfw, θ, x0 ; fw_kwargs= fw_kwargs)
80+ A = active_set. atoms
7081 p = active_set. weights
71- n, m = length (θ), length (p)
82+ x = active_set . x
7283
73- conditions_θ (θ_bis) = frank_wolfe_optimality_conditions (dfw, θ_bis, p, V )
74- conditions_p (p_bis) = - frank_wolfe_optimality_conditions (dfw, θ, p_bis, V )
84+ conditions_θ (θ_bis) = frank_wolfe_optimality_conditions (dfw, θ_bis, p, A )
85+ conditions_p (p_bis) = - frank_wolfe_optimality_conditions (dfw, θ, p_bis, A )
7586
7687 pullback_Aᵀ = last ∘ rrule_via_ad (rc, conditions_p, p)[2 ]
7788 pullback_Bᵀ = last ∘ rrule_via_ad (rc, conditions_θ, θ)[2 ]
7889
79- mul_Aᵀ! (res, v ) = res .= pullback_Aᵀ (v )
80- mul_Bᵀ! (res, v) = res .= pullback_Bᵀ (v )
90+ mul_Aᵀ! (res, u :: AbstractVector ) = res .= vec ( pullback_Aᵀ (reshape (u, size (p))) )
91+ mul_Bᵀ! (res, v:: AbstractVector ) = res .= vec ( pullback_Bᵀ (reshape (v, size (p))) )
8192
82- Aᵀ = LinearOperator (Float64, m, m, false , false , mul_Aᵀ!)
83- Bᵀ = LinearOperator (Float64, n, m, false , false , mul_Bᵀ!)
93+ n, m = length (θ), length (p)
94+ Aᵀ = LinearOperator (R, m, m, false , false , mul_Aᵀ!)
95+ Bᵀ = LinearOperator (R, n, m, false , false , mul_Bᵀ!)
8496
8597 function frank_wolfe_pullback (dx)
86- dp = V' * Vector (unthunk (dx))
98+ dx = unthunk (dx)
99+ dp = [dot (Aᵢ, dx) for Aᵢ in A]
87100 u, stats = linear_solver (Aᵀ, dp)
88- if ! stats. solved
89- error (" The linear solver failed to converge" )
90- end
91- dθ = Bᵀ * u
92- return (NoTangent (), dθ)
101+ stats. solved || error (" Linear solver failed to converge" )
102+ dθ_vec = Bᵀ * u
103+ dθ = reshape (dθ_vec, size (θ))
104+ return (NoTangent (), dθ, NoTangent ())
93105 end
94106
95- return active_set. x, frank_wolfe_pullback
96- end
97-
98- function ChainRulesCore. rrule (
99- rc:: RuleConfig , dfw:: DifferentiableFrankWolfe , θ:: AbstractArray{<:Real} ; fw_kwargs= (;)
100- )
101- throw (ArgumentError (" θ must be a vector and not a higher-dimensional array" ))
107+ return x, frank_wolfe_pullback
102108end
0 commit comments