Skip to content

Commit c2615cd

Browse files
committed
revert to 0.2 for now
1 parent c5c1bfd commit c2615cd

File tree

4 files changed

+14
-27
lines changed

4 files changed

+14
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe"
2727
ChainRulesCore = "1"
2828
DensityInterface = "0.4.0"
2929
DifferentiableExpectations = "0.2"
30-
DifferentiableFrankWolfe = "0.3"
30+
DifferentiableFrankWolfe = "0.2"
3131
Distributions = "0.25"
3232
DocStringExtensions = "0.9.3"
3333
LinearAlgebra = "<0.0.1,1"

src/utils/linear_maximizer.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ end
4545
# default is oracles of the form argmax_y θᵀy
4646
"""
    objective_value(f, θ, y; kwargs...)

Default objective value of a linear oracle of the form `argmax_y θᵀy`:
the inner product between `θ` and `y`.
"""
function objective_value(::Any, θ, y; kwargs...)
    return dot(θ, y)
end
4747
"""
    apply_g(f, y; kwargs...)

Default post-processing `g` for a linear oracle: the identity, returning `y` unchanged.
"""
function apply_g(::Any, y; kwargs...)
    return y
end
48-
# apply_h(::Any, y; kwargs...) = zero(eltype(y)) is not needed
4948

5049
"""
5150
$TYPEDSIGNATURES
@@ -65,12 +64,3 @@ Applies the function `g` of the LinearMaximizer `f` to `y`.
6564
function apply_g(f::LinearMaximizer, y; kwargs...)
6665
return f.g(y; kwargs...)
6766
end
68-
69-
# """
70-
# $TYPEDSIGNATURES
71-
72-
# Applies the function `h` of the LinearMaximizer `f` to `y`.
73-
# """
74-
# function apply_h(f::LinearMaximizer, y; kwargs...)
75-
# return f.h(y; kwargs...)
76-
# end

src/utils/pushforward.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
"""
2-
Pushforward <: AbstractLayer
2+
$TYPEDEF
33
44
Differentiable pushforward of a probabilistic optimization layer with an arbitrary function post-processing function.
55
66
`Pushforward` can be used for direct regret minimization (aka learning by experience) when the post-processing returns a cost.
77
88
# Fields
9-
- `optimization_layer::AbstractOptimizationLayer`: probabilistic optimization layer
10-
- `post_processing`: callable
11-
12-
See also: `FixedAtomsProbabilityDistribution`.
9+
$TYPEDFIELDS
1310
"""
1411
struct Pushforward{O<:AbstractOptimizationLayer,P} <: AbstractLayer
12+
"probabilistic optimization layer"
1513
optimization_layer::O
14+
"callable"
1615
post_processing::P
1716
end
1817

@@ -22,13 +21,11 @@ function Base.show(io::IO, pushforward::Pushforward)
2221
end
2322

2423
"""
25-
(pushforward::Pushforward)(θ; kwargs...)
24+
$TYPEDSIGNATURES
2625
2726
Output the expectation of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`.
2827
2928
This function is differentiable, even if `pushforward.post_processing` isn't.
30-
31-
See also: `compute_expectation`.
3229
"""
3330
function (pushforward::Pushforward)(θ::AbstractArray; kwargs...)
3431
(; optimization_layer, post_processing) = pushforward

src/utils/some_functions.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
"""
2-
positive_part(x)
2+
$TYPEDSIGNATURES
33
4-
Compute `max(x,0)`.
4+
Compute `max(x, 0)`.
55
"""
66
# Type-stable positive part: `x` when nonnegative, `zero(x)` otherwise.
# `ifelse` with `>=` reproduces the original ternary exactly (NaN maps to zero,
# since `NaN >= 0` is false).
positive_part(x) = ifelse(x >= zero(x), x, zero(x))
77

88
"""
9-
isproba(x)
9+
$TYPEDSIGNATURES
1010
1111
Check whether `x ∈ [0,1]`.
1212
"""
1313
# A real number is a valid probability iff it lies in the closed interval [0, 1].
isproba(x::Real) = (zero(x) <= x) && (x <= one(x))
1414

1515
"""
16-
isprobadist(p)
16+
$TYPEDSIGNATURES
1717
1818
Check whether the elements of `p` are nonnegative and sum to 1.
1919
"""
2020
# Check that `p` is a probability distribution: every entry lies in [0, 1]
# and the entries sum (approximately, via `≈`) to 1.
# NOTE(review): restored the `≈` (isapprox) operator between `sum(p)` and
# `one(R)`, which was lost in text extraction and left the line invalid.
isprobadist(p::AbstractVector{R}) where {R<:Real} = all(isproba, p) && sum(p) ≈ one(R)
2121

2222
"""
23-
half_square_norm(x)
23+
$TYPEDSIGNATURES
2424
2525
Compute the squared Euclidean norm of `x` and divide it by 2.
2626
"""
@@ -29,7 +29,7 @@ function half_square_norm(x::AbstractArray)
2929
end
3030

3131
"""
32-
shannon_entropy(p)
32+
$TYPEDSIGNATURES
3333
3434
Compute the Shannon entropy of a probability distribution: `H(p) = -∑ pᵢlog(pᵢ)`.
3535
"""
@@ -46,7 +46,7 @@ end
4646
"""
    negative_shannon_entropy(p)

Compute the opposite of the Shannon entropy of the probability distribution `p`.
"""
function negative_shannon_entropy(p::AbstractVector)
    return -shannon_entropy(p)
end
4747

4848
"""
49-
one_hot_argmax(z)
49+
$TYPEDSIGNATURES
5050
5151
One-hot encoding of the argmax function.
5252
"""
@@ -57,7 +57,7 @@ function one_hot_argmax(z::AbstractVector{R}; kwargs...) where {R<:Real}
5757
end
5858

5959
"""
60-
ranking(θ[; rev])
60+
$TYPEDSIGNATURES
6161
6262
Compute the vector `r` such that `rᵢ` is the rank of `θᵢ` in `θ`.
6363
"""

0 commit comments

Comments
 (0)