File tree Expand file tree Collapse file tree 3 files changed +14
-10
lines changed
Expand file tree Collapse file tree 3 files changed +14
-10
lines changed Original file line number Diff line number Diff line change 9797Forward pass. Compute the expectation of the underlying distribution.
9898"""
9999function (perturbed:: AbstractPerturbed )(
100- θ:: AbstractArray ; autodiff_variance_reduction:: Bool = false , kwargs...
100+ θ:: AbstractArray ; autodiff_variance_reduction:: Bool = true , kwargs...
101101)
102102 probadist = compute_probability_distribution (
103103 perturbed, θ; autodiff_variance_reduction, kwargs...
@@ -118,7 +118,7 @@ function ChainRulesCore.rrule(
118118 :: typeof (compute_probability_distribution),
119119 perturbed:: AbstractPerturbed ,
120120 θ:: AbstractArray ;
121- autodiff_variance_reduction:: Bool = false ,
121+ autodiff_variance_reduction:: Bool = true ,
122122 kwargs... ,
123123)
124124 η_samples = sample_perturbations (perturbed, θ)
Original file line number Diff line number Diff line change 1515
1616 @testset " PerturbedAdditive" begin
1717 # Compute jacobian with reverse mode
18- jac1 = Zygote. jacobian (perturbed1, θ)[1 ]
19- jac1_big = Zygote. jacobian (perturbed1_big, θ)[1 ]
18+ jac1 = Zygote. jacobian (θ -> perturbed1 (θ; autodiff_variance_reduction= false ), θ)[1 ]
19+ jac1_big = Zygote. jacobian (
20+ θ -> perturbed1_big (θ; autodiff_variance_reduction= false ), θ
21+ )[1 ]
2022 # Only diagonal should be positive
2123 @test all (diag (jac1) .>= 0 )
2224 @test all (jac1 - Diagonal (jac1) .<= 0 )
2729 end
2830
2931 @testset " PerturbedMultiplicative" begin
30- jac2 = Zygote. jacobian (perturbed2, θ)[1 ]
31- jac2_big = Zygote. jacobian (perturbed2_big, θ)[1 ]
32+ jac2 = Zygote. jacobian (θ -> perturbed2 (θ; autodiff_variance_reduction= false ), θ)[1 ]
33+ jac2_big = Zygote. jacobian (
34+ θ -> perturbed2_big (θ; autodiff_variance_reduction= false ), θ
35+ )[1 ]
3236 @test all (diag (jac2) .>= 0 )
3337 @test all (jac2 - Diagonal (jac2) .<= 0 )
3438 @test sortperm (diag (jac2)) != sortperm (θ)
Original file line number Diff line number Diff line change 3131 n = 10
3232 θ = randn (10 )
3333
34- Ja = jacobian (pa , θ)[1 ]
35- Ja_reduced_variance = jacobian (x -> pa (x; autodiff_variance_reduction = true ) , θ)[1 ]
34+ Ja = jacobian (θ -> pa (θ; autodiff_variance_reduction = false ) , θ)[1 ]
35+ Ja_reduced_variance = jacobian (pa , θ)[1 ]
3636
37- Jm = jacobian (pm , θ)[1 ]
38- Jm_reduced_variance = jacobian (x -> pm (x; autodiff_variance_reduction = true ) , θ)[1 ]
37+ Jm = jacobian (x -> pm (x; autodiff_variance_reduction = false ) , θ)[1 ]
38+ Jm_reduced_variance = jacobian (pm , θ)[1 ]
3939
4040 J_true = Matrix (I, n, n) # exact jacobian is the identity matrix
4141
You can’t perform that action at this time.
0 commit comments