|
6 | 6 |
|
7 | 7 | θ = [3, 5, 4, 2] |
8 | 8 |
|
9 | | - perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1_000, seed=0) |
10 | | - perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=10_000, seed=0) |
11 | | - perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1_000, seed=0) |
12 | | - perturbed2_big = PerturbedMultiplicative( |
13 | | - one_hot_argmax; ε=0.5, nb_samples=10_000, seed=0 |
14 | | - ) |
| 9 | + perturbed1 = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0) |
| 10 | + perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0) |
| 11 | + |
| 12 | + perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0) |
| 13 | + perturbed2_big = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0) |
15 | 14 |
|
16 | 15 | @testset "PerturbedAdditive" begin |
17 | 16 | # Compute jacobian with reverse mode |
18 | | - jac1 = Zygote.jacobian(θ -> perturbed1(θ; autodiff_variance_reduction=false), θ)[1] |
19 | | - jac1_big = Zygote.jacobian( |
20 | | - θ -> perturbed1_big(θ; autodiff_variance_reduction=false), θ |
21 | | - )[1] |
| 17 | + jac1 = Zygote.jacobian(perturbed1, θ)[1] |
| 18 | + jac1_big = Zygote.jacobian(perturbed1_big, θ)[1] |
22 | 19 | # Only diagonal should be positive |
23 | 20 | @test all(diag(jac1) .>= 0) |
24 | 21 | @test all(jac1 - Diagonal(jac1) .<= 0) |
|
29 | 26 | end |
30 | 27 |
|
31 | 28 | @testset "PerturbedMultiplicative" begin |
32 | | - jac2 = Zygote.jacobian(θ -> perturbed2(θ; autodiff_variance_reduction=false), θ)[1] |
33 | | - jac2_big = Zygote.jacobian( |
34 | | - θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ |
35 | | - )[1] |
| 29 | + jac2 = Zygote.jacobian(perturbed2, θ)[1] |
| 30 | + jac2_big = Zygote.jacobian(perturbed2_big, θ)[1] |
36 | 31 | @test all(diag(jac2_big) .>= 0) |
37 | 32 | @test all(jac2_big - Diagonal(jac2_big) .<= 0) |
38 | | - @test sortperm(diag(jac2_big)) == sortperm(θ) |
| 33 | + @info diag(jac2_big) |
| 34 | + @test_broken sortperm(diag(jac2_big)) == sortperm(θ) |
39 | 35 | @test norm(jac2) ≈ norm(jac2_big) rtol = 5e-2 |
40 | 36 | end |
41 | 37 | end |
|
99 | 95 |
|
100 | 96 | ε = 1e-12 |
101 | 97 |
|
102 | | - function already_differentiable(θ) |
103 | | - return 2 ./ exp.(θ) .* θ .^ 2 |
104 | | - end |
| 98 | + already_differentiable(θ) = 2 ./ exp.(θ) .* θ .^ 2 .+ sum(θ) |
| 99 | + pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0) |
| 100 | + pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0) |
105 | 101 |
|
106 | | - θ = randn(5) |
107 | | - Jz = jacobian(already_differentiable, θ)[1] |
| 102 | + θ = [1.0, 2.0, 3.0, 4.0, 5.0] |
108 | 103 |
|
109 | | - pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0) |
110 | | - Ja = jacobian(pa, θ)[1] |
111 | | - @test_broken all(isapprox.(Ja, Jz, rtol=0.01)) |
| 104 | + fz = already_differentiable(θ) |
| 105 | + fa = pa(θ) |
| 106 | + fm = pm(θ) |
| 107 | + @test fz ≈ fa rtol = 0.01 |
| 108 | + @test fz ≈ fm rtol = 0.01 |
112 | 109 |
|
113 | | - pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0) |
| 110 | + Jz = jacobian(already_differentiable, θ)[1] |
| 111 | + Ja = jacobian(pa, θ)[1] |
114 | 112 | Jm = jacobian(pm, θ)[1] |
115 | | - @test_broken all(isapprox.(Jm, Jz, rtol=0.01)) |
| 113 | + @test Ja ≈ Jz rtol = 0.01 |
| 114 | + @test Jm ≈ Jz rtol = 0.01 |
116 | 115 | end |
0 commit comments