Skip to content

Commit aa22430

Browse files
committed
Merge branch 'main' into generalized-maximizer
2 parents 32c261d + f9d8dab commit aa22430

File tree

17 files changed

+1081
-197
lines changed

17 files changed

+1081
-197
lines changed

test/InferOptTestUtils/Manifest.toml

Lines changed: 898 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
name = "InferOptTestUtils"
2+
uuid = "da316ebf-0808-4679-97b0-38aabd8f7bf9"
3+
authors = ["Guillaume Dalle", "Léo Baty", "Louis Bouvier", "Axel Parmentier"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
8+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
9+
GridGraphs = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb"
10+
InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
13+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
14+
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

test/InferOptTestUtils/loss.jl

Lines changed: 0 additions & 57 deletions
This file was deleted.

test/InferOptTestUtils/InferOptTestUtils.jl renamed to test/InferOptTestUtils/src/InferOptTestUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ include("const.jl")
1313
include("maximizers.jl")
1414
include("dataset.jl")
1515
include("error.jl")
16-
include("perf.jl")
1716
include("loss.jl")
17+
include("perf.jl")
1818
include("pipeline.jl")
1919

2020
export DECREASE, EPOCHS, NB_FEATURES, NB_INSTANCES, NOISE_STD, VERBOSE

test/InferOptTestUtils/src/loss.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
abstract type PipelineLoss end
2+
3+
struct PipelineLossExperience <: PipelineLoss end
4+
struct PipelineLossImitation <: PipelineLoss end
5+
struct PipelineLossImitationθ <: PipelineLoss end
6+
struct PipelineLossImitationθy <: PipelineLoss end
7+
struct PipelineLossImitationLoss <: PipelineLoss end
8+
9+
get_loss(::PipelineLossExperience, loss, res, x, θ, y) = loss(res; instance=x)
10+
get_loss(::PipelineLossImitation, loss, res, x, θ, y) = loss(res, y; instance=x)
11+
get_loss(::PipelineLossImitationθ, loss, res, x, θ, y) = loss(res, θ; instance=x)
12+
get_loss(::PipelineLossImitationθy, loss, res, x, θ, y) = loss(res, θ, y; instance=x)
13+
function get_loss(::PipelineLossImitationLoss, loss, res, x, θ, y)
14+
return loss(res, (; y_true=y, θ_true=θ); instance=x)
15+
end
Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ function generate_predictions(encoder, maximizer, X)
1919
end
2020

2121
function update_perf!(
22+
pl::PipelineLoss,
2223
perf_storage::NamedTuple;
2324
data_train,
2425
data_test,
2526
true_encoder,
2627
encoder,
2728
true_maximizer,
28-
pipeline_loss,
29+
maximizer,
30+
loss,
2931
error_function,
3032
cost,
3133
)
@@ -42,8 +44,14 @@ function update_perf!(
4244
(X_train, thetas_train, Y_train) = data_train
4345
(X_test, thetas_test, Y_test) = data_test
4446

45-
train_loss = sum(pipeline_loss(x, θ, y) for (x, θ, y) in zip(data_train...))
46-
test_loss = sum(pipeline_loss(x, θ, y) for (x, θ, y) in zip(data_test...))
47+
train_loss = sum(
48+
get_loss(pl, loss, maximizer(encoder(x); instance=x), x, θ, y) for
49+
(x, θ, y) in zip(data_train...)
50+
)
51+
test_loss = sum(
52+
get_loss(pl, loss, maximizer(encoder(x); instance=x), x, θ, y) for
53+
(x, θ, y) in zip(data_test...)
54+
)
4755

4856
Y_train_pred = generate_predictions(encoder, true_maximizer, X_train)
4957
Y_test_pred = generate_predictions(encoder, true_maximizer, X_test)

0 commit comments

Comments
 (0)