Skip to content

Commit 9148c3d

Browse files
committed
fix bug in PerturbedMultiplicative
1 parent 01bbd5a commit 9148c3d

File tree

3 files changed

+14
-0
lines changed

3 files changed

+14
-0
lines changed

src/InferOpt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ using DifferentiableExpectations:
1515
using Distributions:
1616
Distributions,
1717
ContinuousUnivariateDistribution,
18+
AffineDistribution,
19+
Dirac,
1820
LogNormal,
1921
Normal,
2022
product_distribution,

src/layers/perturbed/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,7 @@ It is equal to ``logpdf(d, log(x)) - log(x)``.
2626
function Distributions.logpdf(d::ExponentialOf, x::Real)
2727
return logpdf(d.dist, log(x)) - log(x)
2828
end
29+
30+
function Base.:*(x::Real, d::ExponentialOf)
31+
return iszero(x) ? Dirac(0.0) : AffineDistribution(zero(x), x, d)
32+
end

test/perturbed.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,11 @@ end
112112
@test Ja Jz rtol = 0.01
113113
@test Jm Jz rtol = 0.01
114114
end
115+
116+
@testitem "Perturbed - Misc" begin
117+
# Make sure that having 0s in θ works
118+
perturbed = PerturbedMultiplicative(identity; ε=1.0, nb_samples=1e4, seed=0)
119+
θ = [1.0, 0.0]
120+
y = perturbed(θ)
121+
@test iszero(y[2])
122+
end

0 commit comments

Comments
 (0)