Skip to content

Commit f8c45c2

Browse files
committed
Add a test for setthreadsafe
1 parent 74ea4ac commit f8c45c2

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

test/threadsafe.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212
@test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread)
1313
end
1414

15+
@testset "setthreadsafe" begin
16+
@model f() = x ~ Normal()
17+
model = f()
18+
@test !DynamicPPL._requires_threadsafe(model)
19+
model = setthreadsafe(model, true)
20+
@test DynamicPPL._requires_threadsafe(model)
21+
model = setthreadsafe(model, false)
22+
@test !DynamicPPL._requires_threadsafe(model)
23+
end
24+
1525
# TODO: Add more tests of the public API
1626
@testset "API" begin
1727
vi = VarInfo(gdemo_default)
@@ -41,8 +51,6 @@
4151
end
4252

4353
@testset "model" begin
44-
println("Peforming threading tests with $(Threads.nthreads()) threads")
45-
4654
x = rand(10_000)
4755

4856
@model function wthreads(x)

0 commit comments

Comments
 (0)