|
13 | 13 |
|
14 | 14 | @test y == [1 0 1; 0 1 0; 1 1 1] |
15 | 15 |
|
16 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 16 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
17 | 17 |
|
18 | 18 | @test generalized_maximizer(θ; instance) == y |
19 | 19 |
|
|
29 | 29 |
|
30 | 30 | true_encoder = encoder_factory() |
31 | 31 |
|
32 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 32 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
33 | 33 | function cost(y; instance) |
34 | 34 | return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) |
35 | 35 | end |
|
54 | 54 |
|
55 | 55 | true_encoder = encoder_factory() |
56 | 56 |
|
57 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 57 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
58 | 58 | function cost(y; instance) |
59 | 59 | return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) |
60 | 60 | end |
|
78 | 78 |
|
79 | 79 | true_encoder = encoder_factory() |
80 | 80 |
|
81 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 81 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
82 | 82 | function cost(y; instance) |
83 | 83 | return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) |
84 | 84 | end |
|
105 | 105 |
|
106 | 106 | true_encoder = encoder_factory() |
107 | 107 |
|
108 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 108 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
109 | 109 | function cost(y; instance) |
110 | 110 | return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) |
111 | 111 | end |
|
131 | 131 |
|
132 | 132 | true_encoder = encoder_factory() |
133 | 133 |
|
134 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 134 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
135 | 135 | function cost(y; instance) |
136 | 136 | return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) |
137 | 137 | end |
|
155 | 155 |
|
156 | 156 | true_encoder = encoder_factory() |
157 | 157 |
|
158 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 158 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
159 | 159 | function cost(y; instance) |
160 | 160 | return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) |
161 | 161 | end |
|
180 | 180 |
|
181 | 181 | true_encoder = encoder_factory() |
182 | 182 |
|
183 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 183 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
184 | 184 | function cost(y; instance) |
185 | 185 | return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) |
186 | 186 | end |
|
207 | 207 |
|
208 | 208 | true_encoder = encoder_factory() |
209 | 209 |
|
210 | | - generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h) |
| 210 | + generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) |
211 | 211 | function cost(y; instance) |
212 | 212 | return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) |
213 | 213 | end |
|
225 | 225 | cost, |
226 | 226 | ) |
227 | 227 | end |
| 228 | + |
| 229 | +@testitem "Regularized with a GeneralizedMaximizer" default_imports = false begin |
| 230 | + include("InferOptTestUtils/InferOptTestUtils.jl") |
| 231 | + using InferOpt, .InferOptTestUtils, Random, RequiredInterfaces, Test |
| 232 | + const RI = RequiredInterfaces |
| 233 | + Random.seed!(63) |
| 234 | + |
| 235 | + struct MyRegularized{M<:GeneralizedMaximizer} <: AbstractRegularizedGeneralizedMaximizer |
| 236 | + maximizer::M |
| 237 | + end |
| 238 | + |
| 239 | + (regularized::MyRegularized)(θ; kwargs...) = regularized.maximizer(θ; kwargs...) |
| 240 | + function InferOpt.compute_regularization(regularized::MyRegularized, y::AbstractArray) |
| 241 | + return InferOpt.sparse_argmax_regularization(y) |
| 242 | + end |
| 243 | + InferOpt.get_maximizer(regularized::MyRegularized) = regularized.maximizer |
| 244 | + |
| 245 | + @test RI.check_interface_implemented(AbstractRegularized, MyRegularized) |
| 246 | + |
| 247 | + regularized = MyRegularized(GeneralizedMaximizer(sparse_argmax)) |
| 248 | + |
| 249 | + test_pipeline!( |
| 250 | + PipelineLossImitation; |
| 251 | + instance_dim=5, |
| 252 | + true_maximizer=one_hot_argmax, |
| 253 | + maximizer=identity_kw, |
| 254 | + loss=FenchelYoungLoss(regularized), |
| 255 | + error_function=hamming_distance, |
| 256 | + ) |
| 257 | +end |
0 commit comments