Skip to content

Commit de4575c

Browse files
committed
cleanup
1 parent 5807766 commit de4575c

File tree

7 files changed

+80
-14
lines changed

7 files changed

+80
-14
lines changed

src/InferOpt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ include("losses/ssvm_loss.jl")
6262
include("losses/zero_one_loss.jl")
6363
include("losses/imitation_loss.jl")
6464

65+
export compute_probability_distribution
66+
6567
export half_square_norm
6668
export shannon_entropy, negative_shannon_entropy
6769
export one_hot_argmax, ranking

src/layers/perturbed/perturbation.jl

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ $TYPEDEF
44
Abstract type for a perturbation.
55
It's a function that takes a parameter `θ` and returns a perturbed parameter by a distribution `perturbation_dist`.
66
7-
All subtypes should have a `perturbation_dist`
7+
!!! warning
8+
All subtypes should implement a `perturbation_dist` field, which is a `ContinuousUnivariateDistribution`.
89
910
# Existing implementations
1011
- [`AdditivePerturbation`](@ref)
@@ -44,19 +45,37 @@ function (pdc::AdditivePerturbation)(θ::AbstractArray)
4445
return product_distribution(θ .+ ε * perturbation_dist)
4546
end
4647

48+
"""
49+
$TYPEDEF
50+
51+
Method with parameters to compute the gradient of the logdensity of η = θ + εZ w.r.t. θ., with Z ∼ N(0, 1).
52+
53+
# Fields
54+
$TYPEDFIELDS
55+
"""
56+
struct NormalAdditiveGradLogdensity
57+
"perturbation size"
58+
ε::Float64
59+
end
60+
61+
function NormalAdditiveGradLogdensity(pdc::AdditivePerturbation)
62+
return NormalAdditiveGradLogdensity(pdc.ε)
63+
end
64+
4765
"""
4866
$TYPEDSIGNATURES
4967
5068
Compute the gradient of the logdensity of η = θ + εZ w.r.t. θ., with Z ∼ N(0, 1).
5169
"""
52-
function normal_additive_grad_logdensity(ε, η, θ)
70+
function (f::NormalAdditiveGradLogdensity)(η::AbstractArray, θ::AbstractArray)
71+
(; ε) = f
5372
return ((η .- θ) ./ ε^2,)
5473
end
5574

5675
"""
5776
$TYPEDEF
5877
59-
Multiplicative perturbation: θ ↦ θ ⊙ exp(εZ - ε²/2)
78+
Multiplicative perturbation: θ ↦ θ ⊙ exp(εZ - shift)
6079
6180
# Fields
6281
$TYPEDFIELDS
@@ -66,6 +85,17 @@ struct MultiplicativePerturbation{F}
6685
perturbation_dist::F
6786
"perturbation size"
6887
ε::Float64
88+
"optional shift to have 0 mean, default value is ε²/2"
89+
shift::Float64
90+
end
91+
92+
"""
93+
$TYPEDSIGNATURES
94+
95+
Constructor for [`MultiplicativePerturbation`](@ref).
96+
"""
97+
function MultiplicativePerturbation(perturbation_dist, ε, shift=ε^2 / 2)
98+
return MultiplicativePerturbation(perturbation_dist, ε, shift)
6999
end
70100

71101
"""
@@ -74,16 +104,42 @@ $TYPEDSIGNATURES
74104
Apply the multiplicative perturbation to the parameter `θ`.
75105
"""
76106
function (pdc::MultiplicativePerturbation)(θ::AbstractArray)
77-
(; perturbation_dist, ε) = pdc
78-
return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2))
107+
(; perturbation_dist, ε, shift) = pdc
108+
return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - shift))
79109
end
110+
111+
"""
112+
$TYPEDEF
113+
114+
Method with parameters to compute the gradient of the logdensity of η = θ ⊙ exp(εZ - shift) w.r.t. θ., with Z ∼ N(0, 1).
115+
116+
# Fields
117+
$TYPEDFIELDS
118+
"""
119+
struct NormalMultiplicativeGradLogdensity
120+
"perturbation size"
121+
ε::Float64
122+
"optional shift to have 0 mean"
123+
shift::Float64
124+
end
125+
126+
function NormalMultiplicativeGradLogdensity(pdc::MultiplicativePerturbation)
127+
return NormalMultiplicativeGradLogdensity(pdc.ε, pdc.shift)
128+
end
129+
130+
function NormalMultiplicativeGradLogdensity(ε::Float64, shift=ε^2 / 2)
131+
return NormalMultiplicativeGradLogdensity(ε, shift)
132+
end
133+
80134
"""
81135
$TYPEDSIGNATURES
82136
83-
Compute the gradient of the logdensity of η = θ ⊙ exp(εZ - ε²/2) w.r.t. θ., with Z ∼ N(0, 1).
137+
Compute the gradient of the logdensity of η = θ ⊙ exp(εZ - shift) w.r.t. θ., with Z ∼ N(0, 1).
138+
84139
!!! warning
85140
η should be a realization of θ, i.e. should be of the same sign.
86141
"""
87-
function normal_multiplicative_grad_logdensity(ε, η, θ)
88-
return (inv.(ε^2 .* θ) .* (log.(abs.(η)) - log.(abs.(θ)) .+ (ε^2 / 2)),)
142+
function (f::NormalMultiplicativeGradLogdensity)(η::AbstractArray, θ::AbstractArray)
143+
(; ε, shift) = f
144+
return (inv.(ε^2 .* θ) .* (log.(abs.(η)) - log.(abs.(θ)) .+ shift),)
89145
end

src/layers/perturbed/perturbed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function PerturbedAdditive(
9393
threaded=false,
9494
rng=Random.default_rng(),
9595
dist_logdensity_grad=if (perturbation_dist == Normal(0, 1))
96-
FixFirst(normal_additive_grad_logdensity, ε)
96+
NormalAdditiveGradLogdensity(ε)
9797
else
9898
nothing
9999
end,
@@ -126,7 +126,7 @@ function PerturbedMultiplicative(
126126
threaded=false,
127127
rng=Random.default_rng(),
128128
dist_logdensity_grad=if (perturbation_dist == Normal(0, 1))
129-
FixFirst(normal_multiplicative_grad_logdensity, ε)
129+
NormalMultiplicativeGradLogdensity(float(ε))
130130
else
131131
nothing
132132
end,

src/layers/perturbed/utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
$TYPEDSIGNATURES
33
44
Data structure modeling the exponential of a continuous univariate random variable.
5+
6+
`Random.rand` and `Distributions.logpdf` are defined for the [`ExponentialOf`](@ref) distribution.
57
"""
68
struct ExponentialOf{D<:ContinuousUnivariateDistribution} <:
79
ContinuousUnivariateDistribution
@@ -19,7 +21,7 @@ end
1921
$TYPEDSIGNATURES
2022
2123
Return the log-density of the [`ExponentialOf`](@ref) distribution at `x`.
22-
It is equal to ``logpdf(d, log(x)) - log(x)``
24+
It is equal to ``logpdf(d, log(x)) - log(x)``.
2325
"""
2426
function Distributions.logpdf(d::ExponentialOf, x::Real)
2527
return logpdf(d.dist, log(x)) - log(x)

src/layers/regularized/abstract_regularized.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ with g and h functions of y.
1010
# Interface
1111
- `(regularized::AbstractRegularized)(θ; kwargs...)`: return `ŷ(θ)`
1212
- `compute_regularization(regularized, y)`: return `Ω(y)
13-
- `get_maximizer(regularized)`: return the associated `GeneralizedMaximizer` optimizer
13+
- `get_maximizer(regularized)`: return the associated optimizer
1414
1515
# Available implementations
1616
- [`SoftArgmax`](@ref)

src/losses/fenchel_young_loss.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ Reference: <https://arxiv.org/abs/1901.02324>
1010
1111
# Fields
1212
- `optimization_layer::AbstractOptimizationLayer`: optimization layer that can be formulated as `ŷ(θ) = argmax {θᵀy - Ω(y)}` (either regularized or perturbed)
13+
14+
# Compatibility
15+
This loss is compatible with:
16+
- [`LinearMaximizer`](@ref)-based layers.
17+
- [`PerturbedOracle`](@ref) layers, with additive or multiplicative perturbations (generic perturbations are not supported).
18+
- any [`AbstractRegularized`](@ref) layer.
1319
"""
1420
struct FenchelYoungLoss{O<:AbstractOptimizationLayer} <: AbstractLossLayer
1521
optimization_layer::O

src/utils/linear_maximizer.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ function (f::LinearMaximizer)(θ::AbstractArray; kwargs...)
4343
end
4444

4545
# default is oracles of the form argmax_y θᵀy
46-
objective_value(::Any, θ, y; kwargs...) = dot(θ, y)
47-
apply_g(::Any, y; kwargs...) = y
46+
@inline objective_value(::Any, θ, y; kwargs...) = dot(θ, y)
47+
@inline apply_g(::Any, y; kwargs...) = y
4848

4949
"""
5050
$TYPEDSIGNATURES

0 commit comments

Comments
 (0)