Skip to content

Commit f70453c

Browse files
committed
add more tests
1 parent 8f7f103 commit f70453c

File tree

4 files changed

+44
-11
lines changed

4 files changed

+44
-11
lines changed

src/InferOpt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ export Pushforward
6060
export IdentityRelaxation
6161
export Interpolation
6262

63-
export AbstractRegularized
63+
export AbstractRegularized, AbstractRegularizedGeneralizedMaximizer
6464
export SoftArgmax, soft_argmax
6565
export SparseArgmax, sparse_argmax
6666
export RegularizedFrankWolfe

src/utils/generalized_maximizer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ struct GeneralizedMaximizer{F,G,H}
1414
h::H
1515
end
1616

17-
GeneralizedMaximizer(f; g=identity, h=zero) = GeneralizedMaximizer(f, g, h)
17+
GeneralizedMaximizer(f; g=identity_kw, h=zero eltype_kw) = GeneralizedMaximizer(f, g, h)
1818

1919
function Base.show(io::IO, f::GeneralizedMaximizer)
2020
(; maximizer, g, h) = f

src/utils/some_functions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,6 @@ end
6767

6868
zero_regularization(y) = zero(eltype(y))
6969
zero_gradient(y) = zero(y)
70+
71+
identity_kw(x; kwargs...) = identity(x)
72+
eltype_kw(x; kwargs...) = eltype(x)

test/generalized_maximizer.jl

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@test y == [1 0 1; 0 1 0; 1 1 1]
1515

16-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
16+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
1717

1818
@test generalized_maximizer(θ; instance) == y
1919

@@ -29,7 +29,7 @@ end
2929

3030
true_encoder = encoder_factory()
3131

32-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
32+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
3333
function cost(y; instance)
3434
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
3535
end
@@ -54,7 +54,7 @@ end
5454

5555
true_encoder = encoder_factory()
5656

57-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
57+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
5858
function cost(y; instance)
5959
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
6060
end
@@ -78,7 +78,7 @@ end
7878

7979
true_encoder = encoder_factory()
8080

81-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
81+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
8282
function cost(y; instance)
8383
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
8484
end
@@ -105,7 +105,7 @@ end
105105

106106
true_encoder = encoder_factory()
107107

108-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
108+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
109109
function cost(y; instance)
110110
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
111111
end
@@ -131,7 +131,7 @@ end
131131

132132
true_encoder = encoder_factory()
133133

134-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
134+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
135135
function cost(y; instance)
136136
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
137137
end
@@ -155,7 +155,7 @@ end
155155

156156
true_encoder = encoder_factory()
157157

158-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
158+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
159159
function cost(y; instance)
160160
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
161161
end
@@ -180,7 +180,7 @@ end
180180

181181
true_encoder = encoder_factory()
182182

183-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
183+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
184184
function cost(y; instance)
185185
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
186186
end
@@ -207,7 +207,7 @@ end
207207

208208
true_encoder = encoder_factory()
209209

210-
generalized_maximizer = GeneralizedMaximizer(max_pricing, g, h)
210+
generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
211211
function cost(y; instance)
212212
return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
213213
end
@@ -225,3 +225,33 @@ end
225225
cost,
226226
)
227227
end
228+
229+
@testitem "Regularized with a GeneralizedMaximizer" default_imports = false begin
230+
include("InferOptTestUtils/InferOptTestUtils.jl")
231+
using InferOpt, .InferOptTestUtils, Random, RequiredInterfaces, Test
232+
const RI = RequiredInterfaces
233+
Random.seed!(63)
234+
235+
struct MyRegularized{M<:GeneralizedMaximizer} <: AbstractRegularizedGeneralizedMaximizer
236+
maximizer::M
237+
end
238+
239+
(regularized::MyRegularized)(θ; kwargs...) = regularized.maximizer(θ; kwargs...)
240+
function InferOpt.compute_regularization(regularized::MyRegularized, y::AbstractArray)
241+
return InferOpt.sparse_argmax_regularization(y)
242+
end
243+
InferOpt.get_maximizer(regularized::MyRegularized) = regularized.maximizer
244+
245+
@test RI.check_interface_implemented(AbstractRegularized, MyRegularized)
246+
247+
regularized = MyRegularized(GeneralizedMaximizer(sparse_argmax))
248+
249+
test_pipeline!(
250+
PipelineLossImitation;
251+
instance_dim=5,
252+
true_maximizer=one_hot_argmax,
253+
maximizer=identity_kw,
254+
loss=FenchelYoungLoss(regularized),
255+
error_function=hamming_distance,
256+
)
257+
end

0 commit comments

Comments
 (0)