Commit 9a245f5
try to add support for AbstractRegularized through an extended interface
1 parent a271756 commit 9a245f5

File tree

7 files changed (+30, -5 lines)


Project.toml

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@ DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
 DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
@@ -23,6 +24,7 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe"
 ChainRulesCore = "1"
 DensityInterface = "0.4.0"
 DifferentiableFrankWolfe = "0.1.2"
+RequiredInterfaces = "0.1.3"
 StatsBase = "0.33, 0.34"
 TestItemRunner = "0.2.2"
 ThreadsX = "0.1.11"
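
For reference, the same dependency addition can be made from the Julia REPL; a minimal sketch, with the version taken from the compat entry above:

using Pkg
# Add RequiredInterfaces at the version matching the declared compat bound
Pkg.add(name="RequiredInterfaces", version="0.1.3")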

src/InferOpt.jl

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ using Random: AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
 using Statistics: mean
 using StatsBase: StatsBase, sample
 using ThreadsX: ThreadsX
+using RequiredInterfaces

 include("interface.jl")

src/imitation/fenchel_young_loss.jl

Lines changed: 16 additions & 1 deletion
@@ -43,6 +43,21 @@ function fenchel_young_loss_and_grad(
     return l, g
 end

+function fenchel_young_loss_and_grad(
+    fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs...
+) where {O<:AbstractRegularized{<:GeneralizedMaximizer}}
+    (; optimization_layer) = fyl
+    ŷ = optimization_layer(θ; kwargs...)
+    Ωy_true = compute_regularization(optimization_layer, y_true)
+    Ωŷ = compute_regularization(optimization_layer, ŷ)
+    maximizer = get_maximizer(optimization_layer)
+    l =
+        (Ωy_true - objective_value(maximizer, θ, y_true; kwargs...)) -
+        (Ωŷ - objective_value(maximizer, θ, ŷ; kwargs...))
+    g = maximizer.g(ŷ; kwargs...) - maximizer.g(y_true; kwargs...)
+    return l, g
+end
+
 function fenchel_young_loss_and_grad(
     fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs...
 ) where {O<:AbstractPerturbed}
@@ -61,7 +76,7 @@ function fenchel_young_loss_and_grad(
         optimization_layer, θ; kwargs...
     )
     l = F - objective_value(optimization_layer.oracle, θ, y_true; kwargs...)
-    g = almost_g_of_ŷ - optimization_layer.oracle.g(y_true)
+    g = almost_g_of_ŷ - optimization_layer.oracle.g(y_true; kwargs...)
     return l, g
 end
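
For context: assuming InferOpt's convention that objective_value(maximizer, θ, y) returns θᵀg(y) + h(y) for a GeneralizedMaximizer, the new method evaluates the Fenchel-Young loss of the regularized layer ŷ(θ) = argmaxᵧ θᵀg(y) + h(y) − Ω(y), which in LaTeX notation reads:

% Fenchel-Young loss and gradient, assuming objective_value(m, θ, y) = θᵀg(y) + h(y)
\ell(\theta, y) =
    \left[ \Omega(y) - \theta^\top g(y) - h(y) \right]
  - \left[ \Omega(\hat{y}) - \theta^\top g(\hat{y}) - h(\hat{y}) \right],
\qquad
\nabla_\theta \ell(\theta, y) = g(\hat{y}) - g(y).

The gradient formula is exactly the line g = maximizer.g(ŷ; kwargs...) - maximizer.g(y_true; kwargs...) above.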

src/regularized/abstract_regularized.jl

Lines changed: 8 additions & 1 deletion
@@ -17,11 +17,18 @@ Convex regularization perturbation of a black box optimizer
 - [`SparseArgmax`](@ref)
 - [`RegularizedFrankWolfe`](@ref)
 """
-abstract type AbstractRegularized <: AbstractOptimizationLayer end
+abstract type AbstractRegularized{O} <: AbstractOptimizationLayer end

 """
     compute_regularization(regularized, y)

 Return the convex penalty `Ω(y)` associated with an `AbstractRegularized` layer.
 """
 function compute_regularization end
+
+function get_maximizer end
+
+@required AbstractRegularized begin
+    compute_regularization(::AbstractRegularized, ::Any)
+    get_maximizer(::AbstractRegularized)
+end
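
The @required macro comes from RequiredInterfaces.jl and turns these two functions into a formal contract for concrete subtypes. A hypothetical subtype satisfying it could look like this (the name and the quadratic penalty are illustrative, not part of the package):

# Hypothetical concrete layer, shown only to illustrate the required interface
struct QuadRegularized{M} <: AbstractRegularized{M}
    maximizer::M
end

compute_regularization(::QuadRegularized, y) = 0.5 * sum(abs2, y)  # Ω(y) = ½‖y‖²
get_maximizer(layer::QuadRegularized) = layer.maximizer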

src/regularized/regularized_frank_wolfe.jl

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ Some values you can tune:

 See the documentation of FrankWolfe.jl for details.
 """
-struct RegularizedFrankWolfe{M,RF,RG,FWK} <: AbstractRegularized
+struct RegularizedFrankWolfe{M,RF,RG,FWK} <: AbstractRegularized{M}
     linear_maximizer::M
     Ω::RF
     Ω_grad::RG
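
Subtyping AbstractRegularized{M} exposes the linear maximizer at the type level. The matching get_maximizer method is not part of this diff, but a plausible implementation would be:

# Plausible accessor satisfying the new @required interface (not in this commit)
get_maximizer(rfw::RegularizedFrankWolfe) = rfw.linear_maximizer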

src/regularized/soft_argmax.jl

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ Soft argmax activation function `s(z) = (e^zᵢ / ∑ e^zⱼ)ᵢ`.

 Corresponds to regularized prediction on the probability simplex with entropic penalty.
 """
-struct SoftArgmax <: AbstractRegularized end
+struct SoftArgmax <: AbstractRegularized{nothing} end

 (::SoftArgmax)(z; kwargs...) = soft_argmax(z)
 compute_regularization(::SoftArgmax, y) = soft_argmax_regularization(y)
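
For intuition, the map and the entropic penalty named in the docstring can be sketched standalone (re-implementations for illustration, not InferOpt's internals):

# s(z)ᵢ = exp(zᵢ) / ∑ⱼ exp(zⱼ), shifted by maximum(z) for numerical stability
softmax_sketch(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# Entropic penalty Ω(y) = ∑ᵢ yᵢ log(yᵢ), with the convention 0 log 0 = 0
entropy_penalty_sketch(y) = sum(yi -> yi > 0 ? yi * log(yi) : zero(yi), y)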

src/regularized/sparse_argmax.jl

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ Compute the Euclidean projection of the vector `z` onto the probability simplex.

 Corresponds to regularized prediction on the probability simplex with square norm penalty.
 """
-struct SparseArgmax <: AbstractRegularized end
+struct SparseArgmax <: AbstractRegularized{nothing} end

 (::SparseArgmax)(z; kwargs...) = sparse_argmax(z)
 compute_regularization(::SparseArgmax, y) = sparse_argmax_regularization(y)
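
The Euclidean projection onto the probability simplex has a classic sort-based closed form; a standalone sketch follows (InferOpt's actual sparse_argmax implementation may differ):

# Project z onto {y : yᵢ ≥ 0, ∑ yᵢ = 1} via the sort-and-threshold algorithm
function simplex_projection_sketch(z::AbstractVector)
    u = sort(z; rev=true)                 # entries in decreasing order
    cssv = cumsum(u)                      # running sums of the sorted entries
    ρ = findlast(k -> u[k] + (1 - cssv[k]) / k > 0, 1:length(u))
    λ = (1 - cssv[ρ]) / ρ                 # shift that makes the support sum to 1
    return max.(z .+ λ, 0)
end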
