Skip to content

Commit 6874197

Browse files
committed
rework interface
1 parent 9a245f5 commit 6874197

File tree

5 files changed

+39
-10
lines changed

5 files changed

+39
-10
lines changed

src/imitation/fenchel_young_loss.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545

4646
function fenchel_young_loss_and_grad(
4747
fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs...
48-
) where {O<:AbstractRegularized{<:GeneralizedMaximizer}}
48+
) where {O<:AbstractRegularizedGeneralizedMaximizer}
4949
(; optimization_layer) = fyl
5050
= optimization_layer(θ; kwargs...)
5151
Ωy_true = compute_regularization(optimization_layer, y_true)

src/regularized/abstract_regularized.jl

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
2-
AbstractRegularized{parallel} <: AbstractOptimizationLayer
2+
AbstractRegularized <: AbstractOptimizationLayer
33
4-
Convex regularization perturbation of a black box optimizer
4+
Convex regularization perturbation of a black box linear optimizer
55
```
66
ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)}
77
```
@@ -17,7 +17,29 @@ Convex regularization perturbation of a black box optimizer
1717
- [`SparseArgmax`](@ref)
1818
- [`RegularizedFrankWolfe`](@ref)
1919
"""
20-
abstract type AbstractRegularized{O} <: AbstractOptimizationLayer end
20+
abstract type AbstractRegularized <: AbstractOptimizationLayer end
21+
22+
"""
23+
AbstractRegularizedGeneralizedMaximizer <: AbstractRegularized
24+
25+
Convex regularization perturbation of a black box **generalized** optimizer
26+
```
27+
ŷ(θ) = argmax_{y ∈ C} {θᵀg(y) + h(y) - Ω(y)}
28+
with g and h functions of y.
29+
```
30+
31+
# Interface
32+
33+
- `(regularized::AbstractRegularized)(θ; kwargs...)`: return `ŷ(θ)`
34+
- `compute_regularization(regularized, y)`: return `Ω(y)`
35+
36+
# Available implementations
37+
38+
- [`SoftArgmax`](@ref)
39+
- [`SparseArgmax`](@ref)
40+
- [`RegularizedFrankWolfe`](@ref)
41+
"""
42+
abstract type AbstractRegularizedGeneralizedMaximizer <: AbstractRegularized end
2143

2244
"""
2345
compute_regularization(regularized, y)
@@ -26,9 +48,16 @@ Return the convex penalty `Ω(y)` associated with an `AbstractRegularized` layer
2648
"""
2749
function compute_regularization end
2850

51+
@required AbstractRegularized begin
52+
#(regularized::AbstractRegularized)(θ::AbstractArray; kwargs...)
53+
compute_regularization(::AbstractRegularized, ::AbstractArray)
54+
end
55+
56+
"""
57+
TODO
58+
"""
2959
function get_maximizer end
3060

31-
@required AbstractRegularized begin
32-
compute_regularization(::AbstractRegularized, ::Any)
33-
get_maximizer(::AbstractRegularized)
61+
@required AbstractRegularizedGeneralizedMaximizer begin
62+
get_maximizer(::AbstractRegularizedGeneralizedMaximizer)
3463
end

src/regularized/regularized_frank_wolfe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Some values you can tune:
2929
3030
See the documentation of FrankWolfe.jl for details.
3131
"""
32-
struct RegularizedFrankWolfe{M,RF,RG,FWK} <: AbstractRegularized{M}
32+
struct RegularizedFrankWolfe{M,RF,RG,FWK} <: AbstractRegularized
3333
linear_maximizer::M
3434
Ω::RF
3535
Ω_grad::RG

src/regularized/soft_argmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Soft argmax activation function `s(z) = (e^zᵢ / ∑ e^zⱼ)ᵢ`.
55
66
Corresponds to regularized prediction on the probability simplex with entropic penalty.
77
"""
8-
struct SoftArgmax <: AbstractRegularized{nothing} end
8+
struct SoftArgmax <: AbstractRegularized end
99

1010
(::SoftArgmax)(z; kwargs...) = soft_argmax(z)
1111
compute_regularization(::SoftArgmax, y) = soft_argmax_regularization(y)

src/regularized/sparse_argmax.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Compute the Euclidean projection of the vector `z` onto the probability simplex.
55
66
Corresponds to regularized prediction on the probability simplex with square norm penalty.
77
"""
8-
struct SparseArgmax <: AbstractRegularized{nothing} end
8+
struct SparseArgmax <: AbstractRegularized end
99

1010
(::SparseArgmax)(z; kwargs...) = sparse_argmax(z)
1111
compute_regularization(::SparseArgmax, y) = sparse_argmax_regularization(y)

0 commit comments

Comments
 (0)