Skip to content

Commit 289765e

Browse files
committed
fix merge
1 parent 54f6810 commit 289765e

File tree

2 files changed

+38
-45
lines changed

2 files changed

+38
-45
lines changed

src/model.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -937,10 +937,7 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object.
937937
)
938938
ctx = InitContext(rng, strategy)
939939
model = DynamicPPL.setleafcontext(model, ctx)
940-
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
941-
# it _should_ do, but this is wrong regardless.
942-
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
943-
return if Threads.nthreads() > 1
940+
return if _requires_threadsafe(model)
944941
# TODO(penelopeysm): The logic for setting eltype of accs is very similar to that
945942
# used in `unflatten`. The reason why we need it here is because the VarInfo `vi`
946943
# won't have been filled with parameters prior to `init!!` being called.

test/logdensityfunction.jl

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,19 @@ using Mooncake: Mooncake
5151
end
5252

5353
@testset "Threaded observe" begin
54-
if Threads.nthreads() > 1
55-
@model function threaded(y)
56-
x ~ Normal()
57-
Threads.@threads for i in eachindex(y)
58-
y[i] ~ Normal(x)
59-
end
54+
@model function threaded(y)
55+
x ~ Normal()
56+
Threads.@threads for i in eachindex(y)
57+
y[i] ~ Normal(x)
6058
end
61-
N = 100
62-
model = threaded(zeros(N))
63-
ldf = DynamicPPL.LogDensityFunction(model)
64-
65-
xs = [1.0]
66-
@test LogDensityProblems.logdensity(ldf, xs)
67-
logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0)
6859
end
60+
N = 100
61+
model = setthreadsafe(threaded(zeros(N)), true)
62+
ldf = DynamicPPL.LogDensityFunction(model)
63+
64+
xs = [1.0]
65+
@test LogDensityProblems.logdensity(ldf, xs)
66+
logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0)
6967
end
7068
end
7169

@@ -109,34 +107,32 @@ end
109107
end
110108

111109
@testset "LogDensityFunction: performance" begin
112-
if Threads.nthreads() == 1
113-
# Evaluating these three models should not lead to any allocations (but only when
114-
# not using TSVI).
115-
@model function f()
116-
x ~ Normal()
117-
return 1.0 ~ Normal(x)
118-
end
119-
@model function submodel_inner()
120-
m ~ Normal(0, 1)
121-
s ~ Exponential()
122-
return (m=m, s=s)
123-
end
124-
# Note that for the allocation tests to work on this one, `inner` has
125-
# to be passed as an argument to `submodel_outer`, instead of just
126-
# being called inside the model function itself
127-
@model function submodel_outer(inner)
128-
params ~ to_submodel(inner)
129-
y ~ Normal(params.m, params.s)
130-
return 1.0 ~ Normal(y)
131-
end
132-
@testset for model in
133-
(f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner()))
134-
vi = VarInfo(model)
135-
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi)
136-
x = vi[:]
137-
bench = median(@be LogDensityProblems.logdensity(ldf, x))
138-
@test iszero(bench.allocs)
139-
end
110+
# Evaluating these three models should not lead to any allocations (but only when
111+
# not using TSVI).
112+
@model function f()
113+
x ~ Normal()
114+
return 1.0 ~ Normal(x)
115+
end
116+
@model function submodel_inner()
117+
m ~ Normal(0, 1)
118+
s ~ Exponential()
119+
return (m=m, s=s)
120+
end
121+
# Note that for the allocation tests to work on this one, `inner` has
122+
# to be passed as an argument to `submodel_outer`, instead of just
123+
# being called inside the model function itself
124+
@model function submodel_outer(inner)
125+
params ~ to_submodel(inner)
126+
y ~ Normal(params.m, params.s)
127+
return 1.0 ~ Normal(y)
128+
end
129+
@testset for model in
130+
(f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner()))
131+
vi = VarInfo(model)
132+
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi)
133+
x = vi[:]
134+
bench = median(@be LogDensityProblems.logdensity(ldf, x))
135+
@test iszero(bench.allocs)
140136
end
141137
end
142138

0 commit comments

Comments
 (0)