Skip to content

Commit ffb5ad7

Browse files
committed
Make Frank-Wolfe regularized predictor compatible with generic arrays
1 parent 53f01dc commit ffb5ad7

File tree

7 files changed

+75
-74
lines changed

7 files changed

+75
-74
lines changed

src/frank_wolfe/differentiable_frank_wolfe.jl

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ end
2727
## Forward pass
2828

2929
function optimal_active_set(
30-
dfw::DifferentiableFrankWolfe, θ::AbstractArray{<:Real}; fw_kwargs=(;)
30+
dfw::DifferentiableFrankWolfe,
31+
θ::AbstractArray{<:Real},
32+
x0::AbstractArray{<:Real};
33+
fw_kwargs=(;),
3134
)
3235
(; f, ∇ₓf, lmo) = dfw
3336
obj(x) = f(x, θ)
3437
grad!(g, x) = g .= ∇ₓf(x, θ)
35-
x0 = compute_extreme_point(lmo, zero(θ))
3638
full_fw_kwargs = merge(DEFAULT_FRANK_WOLFE_KWARGS, fw_kwargs)
3739
x, v, primal, dual_gap, traj_data, active_set = away_frank_wolfe(
3840
obj, grad!, lmo, x0; full_fw_kwargs...
@@ -41,62 +43,66 @@ function optimal_active_set(
4143
return active_set
4244
end
4345

44-
function (dfw::DifferentiableFrankWolfe)(θ::AbstractArray{<:Real}; fw_kwargs=(;))
45-
active_set::ActiveSet = optimal_active_set(dfw, θ; fw_kwargs=fw_kwargs)
46+
function (dfw::DifferentiableFrankWolfe)(
47+
θ::AbstractArray{<:Real}, x0::AbstractArray{<:Real}; fw_kwargs=(;)
48+
)
49+
active_set::ActiveSet = optimal_active_set(dfw, θ, x0; fw_kwargs=fw_kwargs)
4650
return active_set.x
4751
end
4852

49-
## Backward pass, only works with vectors
53+
## Backward pass
5054

5155
function frank_wolfe_optimality_conditions(
5256
dfw::DifferentiableFrankWolfe,
53-
θ::AbstractVector{<:Real},
57+
θ::AbstractArray{<:Real},
5458
p::AbstractVector{<:Real},
55-
V::AbstractMatrix{<:Real},
59+
A::AbstractVector{<:AbstractArray{<:Real}},
5660
)
5761
(; ∇ₓf) = dfw
58-
∇ₚg = V' * ∇ₓf(V * p, θ)
62+
x = sum(pᵢ * Aᵢ for (pᵢ, Aᵢ) in zip(p, A))
63+
b = ∇ₓf(x, θ)
64+
∇ₚg = [dot(Aᵢ, b) for Aᵢ in A]
5965
T = sparse_argmax(p - ∇ₚg)
6066
return T - p
6167
end
6268

6369
function ChainRulesCore.rrule(
64-
rc::RuleConfig, dfw::DifferentiableFrankWolfe, θ::AbstractVector{<:Real}; fw_kwargs=(;)
65-
)
70+
rc::RuleConfig,
71+
dfw::DifferentiableFrankWolfe,
72+
θ::AbstractArray{R1},
73+
x0::AbstractArray{R2};
74+
fw_kwargs=(;),
75+
) where {R1<:Real,R2<:Real}
76+
R = promote_type(R1, R2)
6677
(; linear_solver) = dfw
6778

68-
active_set::ActiveSet = optimal_active_set(dfw, θ; fw_kwargs=fw_kwargs)
69-
V = reduce(hcat, active_set.atoms)
79+
active_set::ActiveSet = optimal_active_set(dfw, θ, x0; fw_kwargs=fw_kwargs)
80+
A = active_set.atoms
7081
p = active_set.weights
71-
n, m = length(θ), length(p)
82+
x = active_set.x
7283

73-
conditions_θ(θ_bis) = frank_wolfe_optimality_conditions(dfw, θ_bis, p, V)
74-
conditions_p(p_bis) = -frank_wolfe_optimality_conditions(dfw, θ, p_bis, V)
84+
conditions_θ(θ_bis) = frank_wolfe_optimality_conditions(dfw, θ_bis, p, A)
85+
conditions_p(p_bis) = -frank_wolfe_optimality_conditions(dfw, θ, p_bis, A)
7586

7687
pullback_Aᵀ = last ∘ rrule_via_ad(rc, conditions_p, p)[2]
7788
pullback_Bᵀ = last ∘ rrule_via_ad(rc, conditions_θ, θ)[2]
7889

79-
mul_Aᵀ!(res, v) = res .= pullback_Aᵀ(v)
80-
mul_Bᵀ!(res, v) = res .= pullback_Bᵀ(v)
90+
mul_Aᵀ!(res, u::AbstractVector) = res .= vec(pullback_Aᵀ(reshape(u, size(p))))
91+
mul_Bᵀ!(res, v::AbstractVector) = res .= vec(pullback_Bᵀ(reshape(v, size(p))))
8192

82-
Aᵀ = LinearOperator(Float64, m, m, false, false, mul_Aᵀ!)
83-
Bᵀ = LinearOperator(Float64, n, m, false, false, mul_Bᵀ!)
93+
n, m = length(θ), length(p)
94+
Aᵀ = LinearOperator(R, m, m, false, false, mul_Aᵀ!)
95+
Bᵀ = LinearOperator(R, n, m, false, false, mul_Bᵀ!)
8496

8597
function frank_wolfe_pullback(dx)
86-
dp = V' * Vector(unthunk(dx))
98+
dx = unthunk(dx)
99+
dp = [dot(Aᵢ, dx) for Aᵢ in A]
87100
u, stats = linear_solver(Aᵀ, dp)
88-
if !stats.solved
89-
error("The linear solver failed to converge")
90-
end
91-
dθ = Bᵀ * u
92-
return (NoTangent(), dθ)
101+
stats.solved || error("Linear solver failed to converge")
102+
dθ_vec = Bᵀ * u
103+
dθ = reshape(dθ_vec, size(θ))
104+
return (NoTangent(), dθ, NoTangent())
93105
end
94106

95-
return active_set.x, frank_wolfe_pullback
96-
end
97-
98-
function ChainRulesCore.rrule(
99-
rc::RuleConfig, dfw::DifferentiableFrankWolfe, θ::AbstractArray{<:Real}; fw_kwargs=(;)
100-
)
101-
throw(ArgumentError("θ must be a vector and not a higher-dimensional array"))
107+
return x, frank_wolfe_pullback
102108
end
Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
11
"""
2-
RegularizedGeneric{RF,RG,F,G,M,S}
2+
RegularizedGeneric{M,RF,RG,F,G,S}
33
44
Generic and differentiable regularized prediction function `ŷ(θ) = argmax {θᵀy - Ω(y)}`.
55
66
Relies on the Frank-Wolfe algorithm to minimize a concave objective on a polytope.
77
88
# Fields
9+
- `maximizer::M`
910
- `Ω::RF`
1011
- `∇Ω::RG`
1112
- `f::F`
1213
- `∇ₓf::G`
13-
- `maximizer::M`
1414
- `linear_solver::S`
1515
1616
See also: [`DifferentiableFrankWolfe`](@ref).
1717
"""
18-
struct RegularizedGeneric{RF,RG,F,G,M,S}
18+
struct RegularizedGeneric{M,RF,RG,F,G,S}
19+
maximizer::M
1920
Ω::RF
2021
∇Ω::RG
2122
f::F
2223
∇ₓf::G
23-
maximizer::M
2424
linear_solver::S
2525
end
2626

2727
function Base.show(io::IO, regularized::RegularizedGeneric)
28-
(; Ω, ∇Ω, maximizer, linear_solver) = regularized
29-
return print(io, "RegularizedGeneric($Ω, $Ω, $maximizer, $linear_solver)")
28+
(; maximizer, Ω, ∇Ω, linear_solver) = regularized
29+
return print(io, "RegularizedGeneric($maximizer, $Ω, $∇Ω, $linear_solver)")
3030
end
3131

32-
function RegularizedGeneric(Ω, ∇Ω, maximizer; linear_solver=gmres)
32+
function RegularizedGeneric(maximizer, Ω, ∇Ω; linear_solver=gmres)
3333
f(y, θ) = Ω(y) - dot(θ, y)
3434
∇ₓf(y, θ) = ∇Ω(y) - θ
35-
return RegularizedGeneric(Ω, ∇Ω, f, ∇ₓf, maximizer, linear_solver)
35+
return RegularizedGeneric(maximizer, Ω, ∇Ω, f, ∇ₓf, linear_solver)
3636
end
3737

3838
@traitimpl IsRegularized{RegularizedGeneric}
@@ -43,15 +43,6 @@ end
4343

4444
## Forward pass
4545

46-
function (regularized::RegularizedGeneric)(
47-
θ::AbstractArray{<:Real}; maximizer_kwargs=(;), fw_kwargs=(;)
48-
)
49-
(; f, ∇ₓf, maximizer, linear_solver) = regularized
50-
lmo = LMOWrapper(maximizer, maximizer_kwargs)
51-
dfw = DifferentiableFrankWolfe(f, ∇ₓf, lmo, linear_solver)
52-
return dfw(θ; fw_kwargs=fw_kwargs)
53-
end
54-
5546
function optimal_active_set(
5647
regularized::RegularizedGeneric,
5748
θ::AbstractArray{<:Real};
@@ -61,30 +52,33 @@ function optimal_active_set(
6152
(; f, ∇ₓf, maximizer, linear_solver) = regularized
6253
lmo = LMOWrapper(maximizer, maximizer_kwargs)
6354
dfw = DifferentiableFrankWolfe(f, ∇ₓf, lmo, linear_solver)
64-
return optimal_active_set(dfw, θ; fw_kwargs=fw_kwargs)
55+
x0 = compute_extreme_point(lmo, θ)
56+
return optimal_active_set(dfw, θ, x0; fw_kwargs=fw_kwargs)
57+
end
58+
59+
function (regularized::RegularizedGeneric)(
60+
θ::AbstractArray{<:Real}; maximizer_kwargs=(;), fw_kwargs=(;)
61+
)
62+
active_set = optimal_active_set(
63+
regularized, θ; maximizer_kwargs=maximizer_kwargs, fw_kwargs=fw_kwargs
64+
)
65+
return active_set.x
6566
end
6667

6768
## Backward pass, only works with vectors
6869

6970
function ChainRulesCore.rrule(
7071
rc::RuleConfig,
7172
regularized::RegularizedGeneric,
72-
θ::AbstractVector{<:Real};
73+
θ::AbstractArray{<:Real};
7374
maximizer_kwargs=(;),
7475
fw_kwargs=(;),
7576
)
7677
(; f, ∇ₓf, maximizer, linear_solver) = regularized
7778
lmo = LMOWrapper(maximizer, maximizer_kwargs)
7879
dfw = DifferentiableFrankWolfe(f, ∇ₓf, lmo, linear_solver)
79-
return rrule(rc, dfw, θ; fw_kwargs=fw_kwargs)
80-
end
81-
82-
function ChainRulesCore.rrule(
83-
rc::RuleConfig,
84-
regularized::RegularizedGeneric,
85-
θ::AbstractArray{<:Real};
86-
maximizer_kwargs=(;),
87-
fw_kwargs=(;),
88-
)
89-
throw(ArgumentError("θ must be a vector and not a higher-dimensional array"))
80+
x0 = compute_extreme_point(lmo, θ)
81+
x, frank_wolfe_pullback = rrule(rc, dfw, θ, x0; fw_kwargs=fw_kwargs)
82+
regularized_generic_pullback(dx) = frank_wolfe_pullback(dx)[1:2]
83+
return x, regularized_generic_pullback
9084
end

test/argmax.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ pipelines_imitation_y = [
6262
encoder=encoder_factory(),
6363
maximizer=identity,
6464
loss=FenchelYoungLoss(
65-
RegularizedGeneric(half_square_norm, identity, true_maximizer)
65+
RegularizedGeneric(true_maximizer, half_square_norm, identity)
6666
),
6767
),
6868
# Generic regularized + other loss
6969
(
7070
encoder=encoder_factory(),
71-
maximizer=RegularizedGeneric(half_square_norm, identity, true_maximizer),
71+
maximizer=RegularizedGeneric(true_maximizer, half_square_norm, identity),
7272
loss=Flux.Losses.mse,
7373
),
7474
]

test/frank_wolfe.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Zygote
88
Random.seed!(63)
99

1010
d = 100
11+
x0 = ones(d) / d;
1112
θ = rand(d);
1213
v = rand(d);
1314
rc = Zygote.ZygoteRuleConfig()
@@ -23,20 +24,20 @@ f(x, θ) = half_square_norm(x - θ)
2324
lmo = FrankWolfe.UnitSimplexOracle(1.0)
2425

2526
dfw = DifferentiableFrankWolfe(f, ∇ₓf, lmo)
26-
_, pullback_dfw = rrule_via_ad(rc, dfw, θ; fw_kwargs=fw_kwargs);
27+
_, pullback_dfw = rrule_via_ad(rc, dfw, θ, x0; fw_kwargs=fw_kwargs);
2728

2829
@testset verbose = true "DifferentiableFrankWolfe" begin
29-
@test mean(abs, dfw(θ; fw_kwargs=fw_kwargs) - sparse_argmax(θ)) < 1e-3
30+
@test mean(abs, dfw(θ, x0; fw_kwargs=fw_kwargs) - sparse_argmax(θ)) < 1e-3
3031
@test mean(abs, pullback_dfw(v)[2] - pullback_sparse_argmax(v)[2]) < 1e-3
3132
end
3233

3334
## RegularizedGeneric
3435

36+
maximizer(θ) = one_hot_argmax(θ)
3537
Ω(y) = half_square_norm(y)
3638
∇Ω(y) = y
37-
maximizer(θ) = one_hot_argmax(θ)
3839

39-
regularized = RegularizedGeneric(Ω, ∇Ω, maximizer)
40+
regularized = RegularizedGeneric(maximizer, Ω, ∇Ω)
4041
_, pullback_regularized = rrule_via_ad(rc, regularized, θ; fw_kwargs=fw_kwargs);
4142

4243
@testset verbose = true "RegularizedGeneric" begin

test/paths.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pipelines_imitation_y = [
5858
encoder=encoder_factory(),
5959
maximizer=identity,
6060
loss=FenchelYoungLoss(
61-
RegularizedGeneric(half_square_norm, identity, true_maximizer)
61+
RegularizedGeneric(true_maximizer, half_square_norm, identity)
6262
),
6363
),
6464
]

test/ranking.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ pipelines_imitation_y = [
5656
encoder=encoder_factory(),
5757
maximizer=identity,
5858
loss=FenchelYoungLoss(
59-
RegularizedGeneric(half_square_norm, identity, true_maximizer)
59+
RegularizedGeneric(true_maximizer, half_square_norm, identity)
6060
),
6161
),
6262
# Generic regularized + other loss
6363
(
6464
encoder=encoder_factory(),
65-
maximizer=RegularizedGeneric(half_square_norm, identity, true_maximizer),
65+
maximizer=RegularizedGeneric(true_maximizer, half_square_norm, identity),
6666
loss=Flux.Losses.mse,
6767
),
6868
]

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ includet("utils/pipeline.jl")
1818
@testset verbose = true "Jacobian approx" begin
1919
include("jacobian_approx.jl")
2020
end
21-
@testset verbose = true "Frank-Wolfe" begin
22-
include("frank_wolfe.jl")
23-
end
2421
@testset verbose = true "Argmax" begin
2522
include("argmax.jl")
2623
end
@@ -30,6 +27,9 @@ includet("utils/pipeline.jl")
3027
@testset verbose = true "Paths" begin
3128
include("paths.jl")
3229
end
30+
@testset verbose = true "Frank-Wolfe" begin
31+
include("frank_wolfe.jl")
32+
end
3333
@testset verbose = true "Tutorial" begin
3434
include("tutorial.jl")
3535
end

0 commit comments

Comments
 (0)