Skip to content

Commit 2b0c8f4

Browse files
committed
Add RegularizedGeneric to prediction tests
1 parent 006cb47 commit 2b0c8f4

File tree

7 files changed

+102
-42
lines changed

7 files changed

+102
-42
lines changed

src/InferOpt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ export Interpolation
4242

4343
export DifferentiableFrankWolfe
4444

45-
export shannon_entropy, half_square_norm
45+
export half_square_norm
46+
export shannon_entropy, negative_shannon_entropy
4647
export one_hot_argmax, ranking
4748
export IsRegularized
4849
export soft_argmax, sparse_argmax

src/frank_wolfe/differentiable_frank_wolfe.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ function frank_wolfe_optimality_conditions(
6161
end
6262

6363
function ChainRulesCore.rrule(
64-
rc::RuleConfig, dfw::DifferentiableFrankWolfe, θ::AbstractVector; fw_kwargs=(;)
64+
rc::RuleConfig, dfw::DifferentiableFrankWolfe, θ::AbstractVector{<:Real}; fw_kwargs=(;)
6565
)
6666
(; linear_solver) = dfw
6767

@@ -94,3 +94,9 @@ function ChainRulesCore.rrule(
9494

9595
return active_set.x, frank_wolfe_pullback
9696
end
97+
98+
function ChainRulesCore.rrule(
99+
rc::RuleConfig, dfw::DifferentiableFrankWolfe, θ::AbstractArray{<:Real}; fw_kwargs=(;)
100+
)
101+
throw(ArgumentError("θ must be a vector and not a higher-dimensional array"))
102+
end

src/regularized/regularized_generic.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ struct RegularizedGeneric{ΩF,ΩG,F,G,M,S}
2424
linear_solver::S
2525
end
2626

27+
function Base.show(io::IO, regularized::RegularizedGeneric)
28+
(; Ω, ∇Ω, maximizer, linear_solver) = regularized
29+
    return print(io, "RegularizedGeneric($Ω, $∇Ω, $maximizer, $linear_solver)")
30+
end
31+
2732
function RegularizedGeneric(Ω, ∇Ω, maximizer; linear_solver=gmres)
2833
f(y, θ) = Ω(y) - dot(θ, y)
2934
∇ₓf(y, θ) = ∇Ω(y) - θ
@@ -39,7 +44,7 @@ end
3944
## Forward pass
4045

4146
function (regularized::RegularizedGeneric)(
42-
θ::AbstractArray; maximizer_kwargs=(;), fw_kwargs=(;)
47+
θ::AbstractArray{<:Real}; maximizer_kwargs=(;), fw_kwargs=(;)
4348
)
4449
(; f, ∇ₓf, maximizer, linear_solver) = regularized
4550
lmo = LMOWrapper(maximizer, maximizer_kwargs)
@@ -61,3 +66,13 @@ function ChainRulesCore.rrule(
6166
dfw = DifferentiableFrankWolfe(f, ∇ₓf, lmo, linear_solver)
6267
return rrule(rc, dfw, θ; fw_kwargs=fw_kwargs)
6368
end
69+
70+
function ChainRulesCore.rrule(
71+
rc::RuleConfig,
72+
regularized::RegularizedGeneric,
73+
θ::AbstractArray{<:Real};
74+
maximizer_kwargs=(;),
75+
fw_kwargs=(;),
76+
)
77+
throw(ArgumentError("θ must be a vector and not a higher-dimensional array"))
78+
end

src/regularized/soft_argmax.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ Soft argmax activation function `s(z) = (e^zᵢ / ∑ e^zⱼ)ᵢ`.
66
Corresponds to regularized prediction on the probability simplex with entropic penalty.
77
"""
88
function soft_argmax(z::AbstractVector{<:Real}; kwargs...)
9-
s = exp.(z)
10-
s ./= sum(s)
9+
s = exp.(z) / sum(exp, z)
1110
return s
1211
end
1312

test/argmax.jl

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,54 +17,72 @@ error_function(ŷ, y) = hamming_distance(ŷ, y)
1717

1818
## Pipelines
1919

20-
pipelines_imitation_θ = [(
21-
encoder=encoder_factory(), maximizer=identity, loss=SPOPlusLoss(true_maximizer)
22-
)]
20+
pipelines_imitation_θ = [
21+
# SPO+
22+
(encoder=encoder_factory(), maximizer=identity, loss=SPOPlusLoss(true_maximizer)),
23+
]
2324

2425
pipelines_imitation_y = [
25-
# Fenchel-Young loss (test forward pass)
26+
# Structured SVM
2627
(
2728
encoder=encoder_factory(),
2829
maximizer=identity,
29-
loss=FenchelYoungLoss(PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=3)),
30+
loss=StructuredSVMLoss(ZeroOneBaseLoss()),
31+
),
32+
# Perturbed + FYL
33+
(
34+
encoder=encoder_factory(),
35+
maximizer=identity,
36+
loss=FenchelYoungLoss(PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=5)),
3037
),
3138
(
3239
encoder=encoder_factory(),
3340
maximizer=identity,
3441
loss=FenchelYoungLoss(PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5)),
3542
),
36-
# Other differentiable loss (test backward pass)
43+
# Perturbed + other loss
3744
(
3845
encoder=encoder_factory(),
39-
maximizer=PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=3),
46+
maximizer=PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=10),
4047
loss=Flux.Losses.mse,
4148
),
4249
(
4350
encoder=encoder_factory(),
44-
maximizer=PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5),
51+
maximizer=PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10),
4552
loss=Flux.Losses.mse,
4653
),
47-
# Structured SVM
54+
# Explicit regularized + FYL
55+
(encoder=encoder_factory(), maximizer=identity, loss=FenchelYoungLoss(sparse_argmax)),
56+
(encoder=encoder_factory(), maximizer=identity, loss=FenchelYoungLoss(soft_argmax)),
57+
# Explicit regularized + other loss
58+
(encoder=encoder_factory(), maximizer=sparse_argmax, loss=Flux.Losses.mse),
59+
(encoder=encoder_factory(), maximizer=soft_argmax, loss=Flux.Losses.mse),
60+
# Generic regularized + FYL
4861
(
4962
encoder=encoder_factory(),
5063
maximizer=identity,
51-
loss=StructuredSVMLoss(ZeroOneBaseLoss()),
64+
loss=FenchelYoungLoss(
65+
RegularizedGeneric(half_square_norm, identity, true_maximizer)
66+
),
67+
),
68+
# Generic regularized + other loss
69+
(
70+
encoder=encoder_factory(),
71+
maximizer=RegularizedGeneric(half_square_norm, identity, true_maximizer),
72+
loss=Flux.Losses.mse,
5273
),
53-
# Regularized prediction: explicit
54-
(encoder=encoder_factory(), maximizer=identity, loss=FenchelYoungLoss(sparse_argmax)),
55-
(encoder=encoder_factory(), maximizer=identity, loss=FenchelYoungLoss(soft_argmax)),
5674
]
5775

5876
pipelines_experience = [
5977
(
6078
encoder=encoder_factory(),
6179
maximizer=identity,
62-
        loss=cost ∘ PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=3),
80+
        loss=cost ∘ PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=10),
6381
),
6482
(
6583
encoder=encoder_factory(),
6684
maximizer=identity,
67-
        loss=cost ∘ PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5),
85+
        loss=cost ∘ PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10),
6886
),
6987
]
7088

test/paths.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ end
2525

2626
## Pipelines
2727

28-
pipelines_imitation_θ = [(
29-
encoder=encoder_factory(), maximizer=identity, loss=SPOPlusLoss(true_maximizer)
30-
)]
28+
pipelines_imitation_θ = [
29+
# SPO+
30+
(encoder=encoder_factory(), maximizer=identity, loss=SPOPlusLoss(true_maximizer)),
31+
]
3132

3233
pipelines_imitation_y = [
33-
# Fenchel-Young loss (test forward pass)
34+
# Perturbed + FYL
3435
(
3536
encoder=encoder_factory(),
3637
maximizer=identity,
@@ -39,21 +40,27 @@ pipelines_imitation_y = [
3940
(
4041
encoder=encoder_factory(),
4142
maximizer=identity,
42-
loss=FenchelYoungLoss(
43-
PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10)
44-
),
43+
loss=FenchelYoungLoss(PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5)),
4544
),
46-
# Other differentiable loss (test backward pass)
45+
# Perturbed + other loss
4746
(
4847
encoder=encoder_factory(),
49-
maximizer=PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=5),
48+
maximizer=PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=10),
5049
loss=Flux.Losses.mse,
5150
),
5251
(
5352
encoder=encoder_factory(),
5453
maximizer=PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10),
5554
loss=Flux.Losses.mse,
5655
),
56+
# Generic regularized + FYL
57+
(
58+
encoder=encoder_factory(),
59+
maximizer=identity,
60+
loss=FenchelYoungLoss(
61+
RegularizedGeneric(half_square_norm, identity, true_maximizer)
62+
),
63+
),
5764
]
5865

5966
pipelines_experience = [

test/ranking.jl

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,52 @@ error_function(ŷ, y) = hamming_distance(ŷ, y)
1717

1818
## Pipelines
1919

20-
pipelines_imitation_θ = [(
21-
encoder=encoder_factory(), maximizer=identity, loss=SPOPlusLoss(true_maximizer)
22-
)]
20+
pipelines_imitation_θ = [
21+
# SPO+
22+
(encoder=encoder_factory(), maximizer=identity, loss=SPOPlusLoss(true_maximizer)),
23+
]
2324

2425
pipelines_imitation_y = [
25-
# Fenchel-Young loss (test forward pass)
26+
# Interpolation
27+
(
28+
encoder=encoder_factory(),
29+
maximizer=Interpolation(true_maximizer; λ=5.0),
30+
loss=Flux.Losses.mse,
31+
),
32+
# Perturbed + FYL
2633
(
2734
encoder=encoder_factory(),
2835
maximizer=identity,
29-
loss=FenchelYoungLoss(PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=3)),
36+
loss=FenchelYoungLoss(PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=5)),
3037
),
3138
(
3239
encoder=encoder_factory(),
3340
maximizer=identity,
3441
loss=FenchelYoungLoss(PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5)),
3542
),
36-
# Other differentiable loss (test backward pass)
43+
# Perturbed + other loss
3744
(
3845
encoder=encoder_factory(),
39-
maximizer=PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=3),
46+
maximizer=PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=10),
4047
loss=Flux.Losses.mse,
4148
),
4249
(
4350
encoder=encoder_factory(),
44-
maximizer=PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5),
51+
maximizer=PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10),
4552
loss=Flux.Losses.mse,
4653
),
47-
# Interpolation
54+
# Generic regularized + FYL
4855
(
4956
encoder=encoder_factory(),
50-
maximizer=Interpolation(true_maximizer; λ=5.0),
57+
maximizer=identity,
58+
loss=FenchelYoungLoss(
59+
RegularizedGeneric(half_square_norm, identity, true_maximizer)
60+
),
61+
),
62+
# Generic regularized + other loss
63+
(
64+
encoder=encoder_factory(),
65+
maximizer=RegularizedGeneric(half_square_norm, identity, true_maximizer),
5166
loss=Flux.Losses.mse,
5267
),
5368
]
@@ -56,14 +71,13 @@ pipelines_experience = [
5671
(
5772
encoder=encoder_factory(),
5873
maximizer=identity,
59-
        loss=cost ∘ PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=3),
74+
        loss=cost ∘ PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=10),
6075
),
6176
(
6277
encoder=encoder_factory(),
6378
maximizer=identity,
64-
        loss=cost ∘ PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5),
79+
        loss=cost ∘ PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10),
6580
),
66-
(encoder=encoder_factory(), maximizer=Interpolation(true_maximizer; λ=5.0), loss=cost),
6781
]
6882

6983
## Dataset generation

0 commit comments

Comments
 (0)