Skip to content

Commit 799378b

Browse files
committed
Fix constructors to allow non unicode
1 parent 16b1e5d commit 799378b

File tree

6 files changed

+50
-22
lines changed

6 files changed

+50
-22
lines changed

src/frank_wolfe/differentiable_frank_wolfe.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Parameterized version of the Frank-Wolfe algorithm `θ -> argmin_{x ∈ C} f(x,
88
# Fields
99
- `f::F`: function `f(x, θ)` to minimize wrt `x`
1010
- `f_grad1::G`: gradient `∇ₓf(x, θ)` of `f` wrt `x`
11-
- `lmo::M`: linear minimization oracle `θ -> argmin_{x ∈ C} θᵀx` which implicitly defines the polytope `C`
11+
- `lmo::M`: linear minimization oracle `θ -> argmin_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
1212
- `linear_solver::S`: solver for linear systems of equations, used during implicit differentiation
1313
1414
# Applicable methods
@@ -24,7 +24,7 @@ struct DifferentiableFrankWolfe{F,G,M<:LinearMinimizationOracle,S}
2424
linear_solver::S
2525
end
2626

27-
function DifferentiableFrankWolfe(f, f_grad1, lmo; linear_solver=gmres)
27+
function DifferentiableFrankWolfe(f, f_grad1, lmo, linear_solver=gmres)
2828
return DifferentiableFrankWolfe(f, f_grad1, lmo, linear_solver)
2929
end
3030

src/perturbed/abstract_perturbed.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ These subtypes share the following fields:
1717
1818
- `maximizer`: black box optimizer
1919
- `ε`: magnitude of the perturbation
20+
- `nb_samples::Int`: number of random samples for Monte-Carlo computations
2021
- `rng::AbstractRNG`: random number generator
2122
- `seed::Union{Nothing,Int}`: random seed
22-
- `nb_samples::Int`: number of random samples for Monte-Carlo computations
2323
"""
2424
abstract type AbstractPerturbed end
2525

src/perturbed/additive.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
PerturbedAdditive{F}
33
4-
Differentiable normal perturbation of a black-box optimizer: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`.
4+
Differentiable normal perturbation of a black-box optimizer of type `F`: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`.
55
66
See also: [`AbstractPerturbed`](@ref).
77
@@ -10,25 +10,30 @@ Reference: <https://arxiv.org/abs/2002.08676>
1010
struct PerturbedAdditive{F,R<:AbstractRNG,S<:Union{Nothing,Int}} <: AbstractPerturbed
1111
maximizer::F
1212
ε::Float64
13+
nb_samples::Int
1314
rng::R
1415
seed::S
15-
nb_samples::Int
1616
end
1717

1818
function Base.show(io::IO, perturbed::PerturbedAdditive)
1919
(; maximizer, ε, rng, seed, nb_samples) = perturbed
2020
return print(
21-
io, "PerturbedAdditive($maximizer, , $(typeof(rng)), $seed, $nb_samples)"
21+
io, "PerturbedAdditive($maximizer, , $nb_samples, $(typeof(rng)), $seed)"
2222
)
2323
end
2424

25+
"""
26+
PerturbedAdditive(maximizer[; ε=1.0, nb_samples=1])
27+
28+
Shorter constructor with defaults.
29+
"""
2530
function PerturbedAdditive(
26-
maximizer; ε=1.0, epsilon=nothing, rng=MersenneTwister(0), seed=nothing, nb_samples=2
31+
maximizer; ε=1.0, epsilon=nothing, nb_samples=1, rng=MersenneTwister(0), seed=nothing
2732
)
2833
if isnothing(epsilon)
29-
return PerturbedAdditive(maximizer, float(ε), rng, seed, nb_samples)
34+
return PerturbedAdditive(maximizer, float(ε), nb_samples, rng, seed)
3035
else
31-
return PerturbedAdditive(maximizer, float(epsilon), rng, seed, nb_samples)
36+
return PerturbedAdditive(maximizer, float(epsilon), nb_samples, rng, seed)
3237
end
3338
end
3439

src/perturbed/multiplicative.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
PerturbedMultiplicative{F}
33
4-
Differentiable log-normal perturbation of a black-box optimizer: the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ N(0, I)`.
4+
Differentiable log-normal perturbation of a black-box optimizer of type `F`: the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ N(0, I)`.
55
66
See also: [`AbstractPerturbed`](@ref).
77
@@ -10,25 +10,30 @@ Reference: preprint coming soon.
1010
struct PerturbedMultiplicative{F,R<:AbstractRNG,S<:Union{Nothing,Int}} <: AbstractPerturbed
1111
maximizer::F
1212
ε::Float64
13+
nb_samples::Int
1314
rng::R
1415
seed::S
15-
nb_samples::Int
1616
end
1717

1818
function Base.show(io::IO, perturbed::PerturbedMultiplicative)
1919
(; maximizer, ε, rng, seed, nb_samples) = perturbed
2020
return print(
21-
io, "PerturbedMultiplicative($maximizer, , $(typeof(rng)), $seed, $nb_samples)"
21+
io, "PerturbedMultiplicative($maximizer, , $nb_samples, $(typeof(rng)), $seed)"
2222
)
2323
end
2424

25+
"""
26+
PerturbedMultiplicative(maximizer[; ε=1.0, nb_samples=1])
27+
28+
Shorter constructor with defaults.
29+
"""
2530
function PerturbedMultiplicative(
26-
maximizer; ε=1.0, epsilon=nothing, rng=MersenneTwister(0), seed=nothing, nb_samples=2
31+
maximizer; ε=1.0, epsilon=nothing, nb_samples=1, rng=MersenneTwister(0), seed=nothing
2732
)
2833
if isnothing(epsilon)
29-
return PerturbedMultiplicative(maximizer, float(ε), rng, seed, nb_samples)
34+
return PerturbedMultiplicative(maximizer, float(ε), nb_samples, rng, seed)
3035
else
31-
return PerturbedMultiplicative(maximizer, float(epsilon), rng, seed, nb_samples)
36+
return PerturbedMultiplicative(maximizer, float(epsilon), nb_samples, rng, seed)
3237
end
3338
end
3439

src/regularized/regularized_generic.jl

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Differentiable regularized prediction function `ŷ(θ) = argmax_{y ∈ C} {θ
66
Relies on the Frank-Wolfe algorithm to minimize a concave objective on a polytope.
77
88
# Fields
9-
- `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx` which implicitly defines the polytope `C`
9+
- `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
1010
- `Ω::RF`: regularization function `Ω(y)`
1111
- `Ω_grad::RG`: gradient of the regularization function `∇Ω(y)`
1212
- `f::F`: objective function `f(x, θ) = Ω(y) - θᵀy` minimized by Frank-Wolfe (computed automatically)
@@ -34,17 +34,32 @@ function Base.show(io::IO, regularized::RegularizedGeneric)
3434
return print(io, "RegularizedGeneric($maximizer, , $Ω_grad, $linear_solver)")
3535
end
3636

37-
"""
38-
RegularizedGeneric(maximizer, Ω, Ω_grad[; linear_solver=gmres])
39-
40-
Short form constructor with a default linear solver.
41-
"""
42-
function RegularizedGeneric(maximizer, Ω, Ω_grad; linear_solver=gmres)
37+
function RegularizedGeneric(maximizer, Ω, Ω_grad, linear_solver=gmres)
4338
f(y, θ) = Ω(y) - dot(θ, y)
4439
f_grad1(y, θ) = Ω_grad(y) - θ
4540
return RegularizedGeneric(maximizer, Ω, Ω_grad, f, f_grad1, linear_solver)
4641
end
4742

43+
"""
44+
RegularizedGeneric(maximizer[; Ω, Ω_grad, linear_solver=gmres])
45+
46+
Shorter constructor with defaults.
47+
"""
48+
function RegularizedGeneric(
49+
maximizer;
50+
Ω=zero_regularization,
51+
Ω_grad=zero_gradient,
52+
omega=nothing,
53+
omega_grad=nothing,
54+
linear_solver=gmres,
55+
)
56+
if isnothing(omega) || isnothing(omega_grad)
57+
return RegularizedGeneric(maximizer, Ω, Ω_grad, linear_solver)
58+
else
59+
return RegularizedGeneric(maximizer, omega, omega_grad, linear_solver)
60+
end
61+
end
62+
4863
@traitimpl IsRegularized{RegularizedGeneric}
4964

5065
function compute_regularization(regularized::RegularizedGeneric, y::AbstractArray{<:Real})

src/regularized/regularized_utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,6 @@ Compute the vector `r` such that `rᵢ` is the rank of `θᵢ` in `θ`.
6464
function ranking::AbstractVector{<:Real}; rev::Bool=false, kwargs...)
6565
return invperm(sortperm(θ; rev=rev))
6666
end
67+
68+
zero_regularization(y) = zero(eltype(y))
69+
zero_gradient(y) = zero(y)

0 commit comments

Comments
 (0)