@@ -17,54 +17,72 @@ error_function(ŷ, y) = hamming_distance(ŷ, y)
1717
1818# # Pipelines
1919
20- pipelines_imitation_θ = [(
21- encoder= encoder_factory (), maximizer= identity, loss= SPOPlusLoss (true_maximizer)
22- )]
20+ pipelines_imitation_θ = [
21+ # SPO+
22+ (encoder= encoder_factory (), maximizer= identity, loss= SPOPlusLoss (true_maximizer)),
23+ ]
2324
2425pipelines_imitation_y = [
25- # Fenchel-Young loss (test forward pass)
26+ # Structured SVM
2627 (
2728 encoder= encoder_factory (),
2829 maximizer= identity,
29- loss= FenchelYoungLoss (PerturbedAdditive (true_maximizer; ε= 1.0 , nb_samples= 3 )),
30+ loss= StructuredSVMLoss (ZeroOneBaseLoss ()),
31+ ),
32+ # Perturbed + FYL
33+ (
34+ encoder= encoder_factory (),
35+ maximizer= identity,
36+ loss= FenchelYoungLoss (PerturbedAdditive (true_maximizer; ε= 1.0 , nb_samples= 5 )),
3037 ),
3138 (
3239 encoder= encoder_factory (),
3340 maximizer= identity,
3441 loss= FenchelYoungLoss (PerturbedMultiplicative (true_maximizer; ε= 1.0 , nb_samples= 5 )),
3542 ),
36- # Other differentiable loss (test backward pass)
43+ # Perturbed + other loss
3744 (
3845 encoder= encoder_factory (),
39- maximizer= PerturbedAdditive (true_maximizer; ε= 1.0 , nb_samples= 3 ),
46+ maximizer= PerturbedAdditive (true_maximizer; ε= 1.0 , nb_samples= 10 ),
4047 loss= Flux. Losses. mse,
4148 ),
4249 (
4350 encoder= encoder_factory (),
44- maximizer= PerturbedMultiplicative (true_maximizer; ε= 1.0 , nb_samples= 5 ),
51+ maximizer= PerturbedMultiplicative (true_maximizer; ε= 1.0 , nb_samples= 10 ),
4552 loss= Flux. Losses. mse,
4653 ),
47- # Structured SVM
54+ # Explicit regularized + FYL
55+ (encoder= encoder_factory (), maximizer= identity, loss= FenchelYoungLoss (sparse_argmax)),
56+ (encoder= encoder_factory (), maximizer= identity, loss= FenchelYoungLoss (soft_argmax)),
57+ # Explicit regularized + other loss
58+ (encoder= encoder_factory (), maximizer= sparse_argmax, loss= Flux. Losses. mse),
59+ (encoder= encoder_factory (), maximizer= soft_argmax, loss= Flux. Losses. mse),
60+ # Generic regularized + FYL
4861 (
4962 encoder= encoder_factory (),
5063 maximizer= identity,
51- loss= StructuredSVMLoss (ZeroOneBaseLoss ()),
64+ loss= FenchelYoungLoss (
65+ RegularizedGeneric (half_square_norm, identity, true_maximizer)
66+ ),
67+ ),
68+ # Generic regularized + other loss
69+ (
70+ encoder= encoder_factory (),
71+ maximizer= RegularizedGeneric (half_square_norm, identity, true_maximizer),
72+ loss= Flux. Losses. mse,
5273 ),
53- # Regularized prediction: explicit
54- (encoder= encoder_factory (), maximizer= identity, loss= FenchelYoungLoss (sparse_argmax)),
55- (encoder= encoder_factory (), maximizer= identity, loss= FenchelYoungLoss (soft_argmax)),
5674]
5775
5876pipelines_experience = [
5977 (
6078 encoder= encoder_factory (),
6179 maximizer= identity,
62- loss= cost ∘ PerturbedAdditive (true_maximizer; ε= 1.0 , nb_samples= 3 ),
80+ loss= cost ∘ PerturbedAdditive (true_maximizer; ε= 1.0 , nb_samples= 10 ),
6381 ),
6482 (
6583 encoder= encoder_factory (),
6684 maximizer= identity,
67- loss= cost ∘ PerturbedMultiplicative (true_maximizer; ε= 1.0 , nb_samples= 5 ),
85+ loss= cost ∘ PerturbedMultiplicative (true_maximizer; ε= 1.0 , nb_samples= 10 ),
6886 ),
6987]
7088
0 commit comments