@@ -10,71 +10,80 @@ module InferOpt
1010using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, Tangent, ZeroTangent
1111using ChainRulesCore: rrule, rrule_via_ad, unthunk
1212using DensityInterface: logdensityof
13+ using DifferentiableExpectations:
14+ DifferentiableExpectations, Reinforce, empirical_predistribution, empirical_distribution
15+ using Distributions:
16+ Distributions,
17+ ContinuousUnivariateDistribution,
18+ LogNormal,
19+ Normal,
20+ product_distribution,
21+ logpdf
22+ using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
1323using LinearAlgebra: dot
14- using Random: AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
24+ using Random: Random, AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
1525using Statistics: mean
1626using StatsBase: StatsBase, sample
1727using StatsFuns: logaddexp, softmax
18- using ThreadsX: ThreadsX
1928using RequiredInterfaces
2029
2130include (" interface.jl" )
2231
32+ include (" utils/utils.jl" )
2333include (" utils/some_functions.jl" )
24- include (" utils/probability_distribution.jl" )
2534include (" utils/pushforward.jl" )
26- include (" utils/generalized_maximizer .jl" )
35+ include (" utils/linear_maximizer .jl" )
2736include (" utils/isotonic_regression/isotonic_l2.jl" )
2837include (" utils/isotonic_regression/isotonic_kl.jl" )
2938include (" utils/isotonic_regression/projection.jl" )
3039
31- include (" simple/interpolation.jl" )
32- include (" simple/identity.jl" )
40+ # Layers
41+ include (" layers/simple/interpolation.jl" )
42+ include (" layers/simple/identity.jl" )
3343
34- include (" regularized/abstract_regularized.jl" )
35- include (" regularized/soft_argmax.jl" )
36- include (" regularized/sparse_argmax.jl" )
37- include (" regularized/soft_rank.jl" )
38- include (" regularized/regularized_frank_wolfe.jl" )
44+ include (" layers/perturbed/utils.jl" )
45+ include (" layers/perturbed/perturbation.jl" )
46+ include (" layers/perturbed/perturbed.jl" )
3947
40- include (" perturbed/abstract_perturbed.jl" )
41- include (" perturbed/additive.jl" )
42- include (" perturbed/multiplicative.jl" )
43- include (" perturbed/perturbed_oracle.jl" )
44-
45- include (" imitation/spoplus_loss.jl" )
46- include (" imitation/ssvm_loss.jl" )
47- include (" imitation/fenchel_young_loss.jl" )
48- include (" imitation/imitation_loss.jl" )
49- include (" imitation/zero_one_loss.jl" )
48+ include (" layers/regularized/abstract_regularized.jl" )
49+ include (" layers/regularized/soft_argmax.jl" )
50+ include (" layers/regularized/sparse_argmax.jl" )
51+ include (" layers/regularized/soft_rank.jl" )
52+ include (" layers/regularized/regularized_frank_wolfe.jl" )
5053
5154if ! isdefined (Base, :get_extension )
5255 include (" ../ext/InferOptFrankWolfeExt.jl" )
5356end
5457
58+ # Losses
59+ include (" losses/fenchel_young_loss.jl" )
60+ include (" losses/spoplus_loss.jl" )
61+ include (" losses/ssvm_loss.jl" )
62+ include (" losses/zero_one_loss.jl" )
63+ include (" losses/imitation_loss.jl" )
64+
65+ export compute_probability_distribution
66+
5567export half_square_norm
5668export shannon_entropy, negative_shannon_entropy
5769export one_hot_argmax, ranking
58- export GeneralizedMaximizer , objective_value
70+ export LinearMaximizer, apply_g , objective_value
5971
60- export FixedAtomsProbabilityDistribution
61- export compute_expectation
62- export compute_probability_distribution
6372export Pushforward
6473
6574export IdentityRelaxation
6675export Interpolation
6776
68- export AbstractRegularized, AbstractRegularizedGeneralizedMaximizer
77+ export AbstractRegularized
6978export SoftArgmax, soft_argmax
7079export SparseArgmax, sparse_argmax
7180export SoftRank, soft_rank, soft_rank_l2, soft_rank_kl
7281export SoftSort, soft_sort, soft_sort_l2, soft_sort_kl
7382export RegularizedFrankWolfe
7483
84+ export PerturbedOracle
7585export PerturbedAdditive
7686export PerturbedMultiplicative
77- export PerturbedOracle
7887
7988export FenchelYoungLoss
8089export StructuredSVMLoss
0 commit comments