Skip to content

Commit 762379f

Browse files
committed
Rewrite everything using probability distributions
1 parent 02d523a commit 762379f

24 files changed

+448
-333
lines changed

CITATION.bib

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
@misc{InferOpt.jl,
2-
author = {Axel Parmentier, Guillaume Dalle, Léo Baty, Louis Bouvier},
2+
author = {Guillaume Dalle and Léo Baty and Louis Bouvier and Axel Parmentier},
33
title = {InferOpt.jl},
44
url = {https://github.com/axelparmentier/InferOpt.jl},
5-
version = {v0.1.0},
5+
version = {v0.2.0},
66
year = {2022},
7-
month = {4}
7+
month = {6}
88
}

Project.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "InferOpt"
22
uuid = "4846b161-c94e-4150-8dac-c7ae193c601f"
3-
authors = ["Axel Parmentier", "Guillaume Dalle", "Léo Baty", "Louis Bouvier"]
3+
authors = ["Guillaume Dalle", "Léo Baty", "Louis Bouvier", "Axel Parmentier"]
44
version = "0.2.0"
55

66
[deps]
@@ -13,14 +13,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
16+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718

1819
[compat]
1920
ChainRulesCore = "1"
20-
FrankWolfe = "0.2"
21-
Krylov = "0.8.2"
22-
LinearOperators = "2.3.2"
21+
FrankWolfe = "0.2.1"
22+
Krylov = "0.8"
23+
LinearOperators = "2.3"
2324
SimpleTraits = "0.9"
25+
StatsBase = "0.33"
2426
julia = "1.7"
2527

2628
[extras]

docs/Manifest.toml

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ version = "0.0.1"
1616

1717
[[deps.AbstractFFTs]]
1818
deps = ["ChainRulesCore", "LinearAlgebra"]
19-
git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4"
19+
git-tree-sha1 = "69f7020bd72f069c219b5e8c236c1fa90d2cb409"
2020
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
21-
version = "1.1.0"
21+
version = "1.2.1"
2222

2323
[[deps.Adapt]]
2424
deps = ["LinearAlgebra"]
@@ -49,15 +49,15 @@ version = "3.5.1+1"
4949

5050
[[deps.ArrayInterface]]
5151
deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "Static"]
52-
git-tree-sha1 = "1d062b8ab719670c16024105ace35e6d32988d4f"
52+
git-tree-sha1 = "6ccb71b40b04ad69152f1f83d5925de13911417e"
5353
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
54-
version = "6.0.18"
54+
version = "6.0.19"
5555

5656
[[deps.ArrayInterfaceCore]]
5757
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
58-
git-tree-sha1 = "5e732808bcf7bbf730e810a9eaafc52705b38bb5"
58+
git-tree-sha1 = "7d255eb1d2e409335835dc8624c35d97453011eb"
5959
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
60-
version = "0.1.13"
60+
version = "0.1.14"
6161

6262
[[deps.Artifacts]]
6363
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -96,15 +96,15 @@ version = "3.11.0"
9696

9797
[[deps.ChainRules]]
9898
deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"]
99-
git-tree-sha1 = "97fd0a3b7703948a847265156a41079730805c77"
99+
git-tree-sha1 = "b06ed86d99c982cbe9047a45a93ac62d9605a361"
100100
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
101-
version = "1.36.0"
101+
version = "1.36.2"
102102

103103
[[deps.ChainRulesCore]]
104104
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
105-
git-tree-sha1 = "9489214b993cd42d17f44c36e359bf6a7c919abf"
105+
git-tree-sha1 = "2dd813e5f2f7eec2d1268c57cf2373d3ee91fcea"
106106
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
107-
version = "1.15.0"
107+
version = "1.15.1"
108108

109109
[[deps.ChangesOfVariables]]
110110
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
@@ -172,15 +172,14 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
172172

173173
[[deps.ConstructionBase]]
174174
deps = ["LinearAlgebra"]
175-
git-tree-sha1 = "f74e9d5388b8620b4cee35d4c5a618dd4dc547f4"
175+
git-tree-sha1 = "59d00b3139a9de4eb961057eabb65ac6522be954"
176176
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
177-
version = "1.3.0"
177+
version = "1.4.0"
178178

179179
[[deps.Contour]]
180-
deps = ["StaticArrays"]
181-
git-tree-sha1 = "9f02045d934dc030edad45944ea80dbd1f0ebea7"
180+
git-tree-sha1 = "a599cfb8b1909b0f97c5e1b923ab92e1c0406076"
182181
uuid = "d38c429a-6771-53c6-b99e-75d170b6e991"
183-
version = "0.5.7"
182+
version = "0.6.1"
184183

185184
[[deps.Crayons]]
186185
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
@@ -294,9 +293,9 @@ version = "0.10.30"
294293

295294
[[deps.FrankWolfe]]
296295
deps = ["Arpack", "GenericSchur", "Hungarian", "LinearAlgebra", "MathOptInterface", "Printf", "ProgressMeter", "Random", "Setfield", "SparseArrays", "TimerOutputs"]
297-
git-tree-sha1 = "75f6e18896767729d011bd473958043dc9948641"
296+
git-tree-sha1 = "87622f91e2256920debd418da506a9490f5c3efc"
298297
uuid = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
299-
version = "0.2.0"
298+
version = "0.2.1"
300299

301300
[[deps.FreeType]]
302301
deps = ["CEnum", "FreeType2_jll"]
@@ -391,7 +390,7 @@ uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
391390
version = "0.1.1"
392391

393392
[[deps.InferOpt]]
394-
deps = ["ChainRulesCore", "FrankWolfe", "Krylov", "LinearAlgebra", "LinearOperators", "Random", "SimpleTraits", "SparseArrays", "Statistics", "Test"]
393+
deps = ["ChainRulesCore", "FrankWolfe", "Krylov", "LinearAlgebra", "LinearOperators", "Random", "SimpleTraits", "SparseArrays", "Statistics", "StatsBase", "Test"]
395394
path = ".."
396395
uuid = "4846b161-c94e-4150-8dac-c7ae193c601f"
397396
version = "0.2.0"
@@ -743,15 +742,15 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
743742

744743
[[deps.SpecialFunctions]]
745744
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
746-
git-tree-sha1 = "a9e798cae4867e3a41cae2dd9eb60c047f1212db"
745+
git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d"
747746
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
748-
version = "2.1.6"
747+
version = "2.1.7"
749748

750749
[[deps.Static]]
751750
deps = ["IfElse"]
752-
git-tree-sha1 = "11f1b69a28b6e4ca1cc18342bfab7adb7ff3a090"
751+
git-tree-sha1 = "46638763d3a25ad7818a15d441e0c3446a10742d"
753752
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
754-
version = "0.7.3"
753+
version = "0.7.5"
755754

756755
[[deps.StaticArrays]]
757756
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
@@ -760,9 +759,9 @@ uuid = "90137ffa-7385-5640-81b9-e52037218182"
760759
version = "1.5.0"
761760

762761
[[deps.StaticArraysCore]]
763-
git-tree-sha1 = "6edcea211d224fa551ec8a85debdc6d732f155dc"
762+
git-tree-sha1 = "66fe9eb253f910fe8cf161953880cfdaef01cdf0"
764763
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
765-
version = "1.0.0"
764+
version = "1.0.1"
766765

767766
[[deps.Statistics]]
768767
deps = ["LinearAlgebra", "SparseArrays"]
@@ -776,15 +775,15 @@ version = "1.4.0"
776775

777776
[[deps.StatsBase]]
778777
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
779-
git-tree-sha1 = "642f08bf9ff9e39ccc7b710b2eb9a24971b52b1a"
778+
git-tree-sha1 = "48598584bacbebf7d30e20880438ed1d24b7c7d6"
780779
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
781-
version = "0.33.17"
780+
version = "0.33.18"
782781

783782
[[deps.StructArrays]]
784783
deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"]
785-
git-tree-sha1 = "55ef24d228f9396fa9e2317eb60c953b8cec1ae7"
784+
git-tree-sha1 = "ec47fb6069c57f1cee2f67541bf8f23415146de7"
786785
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
787-
version = "0.6.10"
786+
version = "0.6.11"
788787

789788
[[deps.SuiteSparse]]
790789
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
@@ -841,9 +840,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
841840

842841
[[deps.UnicodePlots]]
843842
deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "FileIO", "FreeTypeAbstraction", "LazyModules", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "SparseArrays", "StaticArrays", "StatsBase", "Unitful"]
844-
git-tree-sha1 = "3cfa7a287c202c4c574ab535971c4c5b6c104371"
843+
git-tree-sha1 = "e8192bf70f28cf0e79ae9c215008e2f1464edbd6"
845844
uuid = "b8865327-cd53-5732-bb35-84acbb429228"
846-
version = "3.0.0"
845+
version = "3.0.3"
847846

848847
[[deps.Unitful]]
849848
deps = ["ConstructionBase", "Dates", "LinearAlgebra", "Random"]
@@ -856,10 +855,10 @@ deps = ["Libdl"]
856855
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
857856

858857
[[deps.Zygote]]
859-
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
860-
git-tree-sha1 = "a49267a2e5f113c7afe93843deea7461c0f6b206"
858+
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
859+
git-tree-sha1 = "3cfdb31b517eec4173584fba2b1aa65daad46e09"
861860
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
862-
version = "0.6.40"
861+
version = "0.6.41"
863862

864863
[[deps.ZygoteRules]]
865864
deps = ["MacroTools"]

docs/src/algorithms.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
Modules = [InferOpt]
77
```
88

9+
## Probability distributions
10+
11+
```@autodocs
12+
Modules = [InferOpt]
13+
Pages = ["utils/probability_distribution.jl", "utils/composition.jl"]
14+
```
15+
916
## Interpolation
1017

1118
!!! note "Reference"

src/InferOpt.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@ using FrankWolfe: away_frank_wolfe, compute_extreme_point
77
using Krylov: gmres
88
using LinearAlgebra
99
using LinearOperators: LinearOperator
10-
using Random
10+
using Random: AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
1111
using SimpleTraits: SimpleTraits
1212
using SimpleTraits: @traitdef, @traitfn, @traitimpl
1313
using SparseArrays
1414
using Statistics
15+
using StatsBase: StatsBase, sample
1516
using Test
1617

18+
include("utils/probability_distribution.jl")
19+
include("utils/composition.jl")
20+
1721
include("interpolation/interpolation.jl")
1822

1923
include("frank_wolfe/frank_wolfe_utils.jl")
@@ -26,10 +30,10 @@ include("regularized/sparse_argmax.jl")
2630
include("regularized/regularized_generic.jl")
2731

2832
include("perturbed/abstract_perturbed.jl")
29-
include("perturbed/composition.jl")
3033
include("perturbed/additive.jl")
3134
include("perturbed/multiplicative.jl")
3235

36+
include("fenchel_young/perturbed.jl")
3337
include("fenchel_young/fenchel_young.jl")
3438

3539
include("spo/spoplus_loss.jl")
@@ -38,7 +42,9 @@ include("ssvm/isbaseloss.jl")
3842
include("ssvm/zeroone_baseloss.jl")
3943
include("ssvm/ssvm_loss.jl")
4044

41-
export get_probability_distribution
45+
export FixedAtomsProbabilityDistribution, sample, compute_expectation
46+
export ProbabilisticComposition
47+
export compute_probability_distribution
4248

4349
export Interpolation
4450

@@ -53,7 +59,6 @@ export RegularizedGeneric
5359

5460
export PerturbedAdditive
5561
export PerturbedMultiplicative
56-
export PerturbedComposition
5762

5863
export FenchelYoungLoss
5964

src/fenchel_young/fenchel_young.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ end
2020
## Forward pass
2121

2222
function (fyl::FenchelYoungLoss)(
23-
θ::AbstractArray{<:Real}, y::AbstractArray{<:Real}; kwargs...
23+
θ::AbstractArray{<:Real}, y_true::AbstractArray{<:Real}; kwargs...
2424
)
25-
_, l = prediction_and_loss(fyl, θ, y; kwargs...)
25+
l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...)
2626
return l
2727
end
2828

29-
@traitfn function prediction_and_loss(
29+
@traitfn function fenchel_young_loss_and_grad(
3030
fyl::FenchelYoungLoss{P},
3131
θ::AbstractArray{<:Real},
3232
y_true::AbstractArray{<:Real};
@@ -37,24 +37,32 @@ end
3737
Ωy_true = compute_regularization(predictor, y_true)
3838
Ωŷ = compute_regularization(predictor, ŷ)
3939
l = (Ωy_true - dot(θ, y_true)) - (Ωŷ - dot(θ, ŷ))
40-
return ŷ, l
40+
g = ŷ - y_true
41+
return l, g
4142
end
4243

43-
function prediction_and_loss(
44-
fyl::FenchelYoungLoss{P}, θ::AbstractArray{<:Real}, y::AbstractArray{<:Real}; kwargs...
44+
function fenchel_young_loss_and_grad(
45+
fyl::FenchelYoungLoss{P},
46+
θ::AbstractArray{<:Real},
47+
y_true::AbstractArray{<:Real};
48+
kwargs...,
4549
) where {P<:AbstractPerturbed}
4650
(; predictor) = fyl
47-
ŷ, F = compute_y_and_F(predictor, θ; kwargs...)
48-
l = F - dot(θ, y)
49-
return ŷ, l
51+
F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(predictor, θ; kwargs...)
52+
l = F - dot(θ, y_true)
53+
g = almost_ŷ - y_true
54+
return l, g
5055
end
5156

5257
## Backward pass
5358

5459
function ChainRulesCore.rrule(
55-
fyl::FenchelYoungLoss, θ::AbstractArray{<:Real}, y::AbstractArray{<:Real}; kwargs...
60+
fyl::FenchelYoungLoss,
61+
θ::AbstractArray{<:Real},
62+
y_true::AbstractArray{<:Real};
63+
kwargs...,
5664
)
57-
ŷ, l = prediction_and_loss(fyl, θ, y; kwargs...)
58-
fyl_pullback(dl) = NoTangent(), dl * (ŷ - y), NoTangent()
65+
l, g = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...)
66+
fyl_pullback(dl) = NoTangent(), dl * g, NoTangent()
5967
return l, fyl_pullback
6068
end

src/fenchel_young/perturbed.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
    fenchel_young_F_and_first_part_of_grad(perturbed, θ; kwargs...)

Sample-averaged version of the three-argument method.

Perturbation samples are obtained from `sample_perturbations(perturbed, θ)`. The
three-argument method is evaluated once per sample `Z` (forwarding `kwargs`), and
the result is the pair `(mean of first components, mean of second components)`.
"""
function fenchel_young_F_and_first_part_of_grad(
    perturbed::AbstractPerturbed, θ::AbstractArray{<:Real}; kwargs...
)
    # One `(F, y)` tuple per perturbation sample.
    per_sample = map(sample_perturbations(perturbed, θ)) do Z
        fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...)
    end
    F_mean = mean(first, per_sample)
    y_mean = mean(last, per_sample)
    return F_mean, y_mean
end
11+
12+
function fenchel_young_F_and_first_part_of_grad(
13+
perturbed::PerturbedAdditive,
14+
θ::AbstractArray{<:Real},
15+
Z::AbstractArray{<:Real};
16+
kwargs...,
17+
)
18+
(; maximizer, ε) = perturbed
19+
θ_perturbed = θ .+ ε .* Z
20+
y = maximizer(θ_perturbed; kwargs...)
21+
F = dot(θ_perturbed, y)
22+
return F, y
23+
end
24+
25+
function fenchel_young_F_and_first_part_of_grad(
26+
perturbed::PerturbedMultiplicative,
27+
θ::AbstractArray{<:Real},
28+
Z::AbstractArray{<:Real};
29+
kwargs...,
30+
)
31+
(; maximizer, ε) = perturbed
32+
eZ = exp.(ε .* Z .- ε^2)
33+
θ_perturbed = θ .* eZ
34+
y = maximizer(θ_perturbed; kwargs...)
35+
F = dot(θ_perturbed, y)
36+
y_scaled = y .* eZ
37+
return F, y_scaled
38+
end

0 commit comments

Comments
 (0)