Skip to content

Commit 8edaf48

Browse files
committed
Try to fix RegularizedFrankWolfe issues with ImplicitDifferentiation
1 parent f292ea0 commit 8edaf48

File tree

4 files changed

+34
-15
lines changed

4 files changed

+34
-15
lines changed

Project.toml

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ DifferentiableExpectations = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36"
1010
DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
13+
FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
14+
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
1315
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1416
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1517
RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
@@ -19,19 +21,23 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1921

2022
[weakdeps]
2123
DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
24+
FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
25+
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
2226

2327
[extensions]
24-
InferOptFrankWolfeExt = "DifferentiableFrankWolfe"
28+
InferOptFrankWolfeExt = ["DifferentiableFrankWolfe", "FrankWolfe", "ImplicitDifferentiation"]
2529

2630
[compat]
2731
ChainRulesCore = "1"
2832
DensityInterface = "0.4.0"
2933
DifferentiableExpectations = "0.2"
30-
DifferentiableFrankWolfe = "0.2"
34+
DifferentiableFrankWolfe = "0.3"
3135
Distributions = "0.25"
32-
DocStringExtensions = "0.9.3"
33-
LinearAlgebra = "<0.0.1,1"
34-
Random = "<0.0.1,1"
36+
DocStringExtensions = "0.9"
37+
FrankWolfe = "0.3"
38+
ImplicitDifferentiation = "0.6"
39+
LinearAlgebra = "1"
40+
Random = "1"
3541
RequiredInterfaces = "0.1.3"
3642
Statistics = "1"
3743
StatsBase = "0.33, 0.34"
@@ -49,6 +55,7 @@ FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
4955
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
5056
GridGraphs = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb"
5157
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
58+
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
5259
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
5360
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
5461
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
@@ -65,4 +72,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
6572
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6673

6774
[targets]
68-
test = ["Aqua", "DifferentiableFrankWolfe", "Distributions", "Documenter", "FiniteDifferences", "Flux", "FrankWolfe", "Graphs", "GridGraphs", "HiGHS", "JET", "JuliaFormatter", "JuMP", "LinearAlgebra", "Literate", "Pkg", "ProgressMeter", "Random", "Revise", "Statistics", "Test", "TestItemRunner", "UnicodePlots", "Zygote"]
75+
test = ["Aqua", "DifferentiableFrankWolfe", "Distributions", "Documenter", "FiniteDifferences", "Flux", "FrankWolfe", "Graphs", "GridGraphs", "HiGHS", "ImplicitDifferentiation", "JET", "JuliaFormatter", "JuMP", "LinearAlgebra", "Literate", "Pkg", "ProgressMeter", "Random", "Revise", "Statistics", "Test", "TestItemRunner", "UnicodePlots", "Zygote"]

ext/InferOptFrankWolfeExt.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ module InferOptFrankWolfeExt
33
using DifferentiableExpectations:
44
DifferentiableExpectations, FixedAtomsProbabilityDistribution
55
using DifferentiableFrankWolfe: DifferentiableFrankWolfe, DiffFW
6-
using DifferentiableFrankWolfe: LinearMinimizationOracle # from FrankWolfe
7-
using DifferentiableFrankWolfe: IterativeLinearSolver # from ImplicitDifferentiation
6+
using FrankWolfe: LinearMinimizationOracle
7+
using ImplicitDifferentiation: KrylovLinearSolver
88
using InferOpt: InferOpt, RegularizedFrankWolfe
99
using LinearAlgebra: dot
1010

@@ -41,14 +41,18 @@ Keyword arguments are passed to the underlying linear maximizer.
4141
function InferOpt.compute_probability_distribution(
4242
regularized::RegularizedFrankWolfe, θ::AbstractArray; kwargs...
4343
)
44+
shape = size(θ)
4445
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized
4546
f(y, θ) = Ω(y) - dot(θ, y)
4647
f_grad1(y, θ) = Ω_grad(y) - θ
47-
lmo = LinearMaximizationOracleWithKwargs(linear_maximizer, kwargs)
48-
implicit_kwargs = (; linear_solver=IterativeLinearSolver(; accept_inconsistent=true))
48+
maximizer(θ; shape, kwargs...) = vec(linear_maximizer(reshape(θ, shape); kwargs...))
49+
lmo = LinearMaximizationOracleWithKwargs(maximizer, (; shape, kwargs...))
50+
implicit_kwargs = (; linear_solver=KrylovLinearSolver())
4951
dfw = DiffFW(f, f_grad1, lmo; implicit_kwargs)
50-
weights, atoms = dfw.implicit(θ; frank_wolfe_kwargs=frank_wolfe_kwargs)
51-
probadist = FixedAtomsProbabilityDistribution(atoms, weights)
52+
weights, atoms = dfw.implicit(vec(θ); frank_wolfe_kwargs=frank_wolfe_kwargs)
53+
probadist = FixedAtomsProbabilityDistribution(
54+
map(atom -> reshape(atom, shape), atoms), weights
55+
)
5256
return probadist
5357
end
5458

test/InferOptTestUtils/src/pipeline.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ function test_pipeline!(
2121
perf_storage = init_perf()
2222

2323
for ep in 1:epochs
24+
@info ep
2425
update_perf!(
2526
pl,
2627
perf_storage;

test/paths.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ end
9090

9191
@testitem "Paths - imit - MSE RegularizedFrankWolfe" default_imports = false begin
9292
include("InferOptTestUtils/src/InferOptTestUtils.jl")
93-
using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random
93+
using DifferentiableFrankWolfe,
94+
FrankWolfe, ImplicitDifferentiation, InferOpt, .InferOptTestUtils, Random
9495
Random.seed!(63)
9596

9697
test_pipeline!(
@@ -167,7 +168,8 @@ end
167168

168169
@testitem "Paths - imit - FYL RegularizedFrankWolfe" default_imports = false begin
169170
include("InferOptTestUtils/src/InferOptTestUtils.jl")
170-
using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random
171+
using DifferentiableFrankWolfe,
172+
FrankWolfe, ImplicitDifferentiation, InferOpt, .InferOptTestUtils, Random
171173
Random.seed!(63)
172174

173175
test_pipeline!(
@@ -235,7 +237,12 @@ end
235237
@testitem "Paths - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin
236238
include("InferOptTestUtils/src/InferOptTestUtils.jl")
237239
using DifferentiableFrankWolfe,
238-
FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random
240+
FrankWolfe,
241+
ImplicitDifferentiation,
242+
InferOpt,
243+
.InferOptTestUtils,
244+
LinearAlgebra,
245+
Random
239246
Random.seed!(63)
240247

241248
true_encoder = encoder_factory()

0 commit comments

Comments
 (0)