Skip to content

Commit 41a99ef

Browse files
committed
Make threadsafe evaluation opt-in
1 parent 8547e25 commit 41a99ef

File tree

8 files changed

+195
-125
lines changed

8 files changed

+195
-125
lines changed

HISTORY.md

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,41 @@
99
This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
1010
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
1111

12-
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
12+
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments.
1313

1414
As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
1515
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
1616
If you were previously relying on this behaviour, you will need to store a VarInfo separately.
1717

18+
#### Threadsafe evaluation
19+
20+
DynamicPPL models are by default no longer thread-safe.
21+
If you have threading in a model, you **must** now manually mark it as so, using:
22+
23+
```julia
24+
@model f() = ...
25+
model = f()
26+
model = setthreadsafe(model, true)
27+
```
28+
29+
It used to be that DynamicPPL would 'automatically' enable thread-safe evaluation if Julia was launched with more than one thread (i.e., by checking `Threads.nthreads() > 1`).
30+
31+
The problem with this approach is that it sacrifices a huge amount of performance.
32+
Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation.
33+
34+
**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
35+
This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:
36+
37+
- tilde-statements
38+
- calls to `@addlogprob!`
39+
- any direct manipulation of the special `__varinfo__` variable
40+
41+
If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
42+
**Notably, the following do not require threadsafe evaluation:**
43+
44+
- Using threading for anything that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
45+
- Sampling with `AbstractMCMC.MCMCThreads()`.
46+
1847
#### Parent and leaf contexts
1948

2049
The `DynamicPPL.NodeTrait` function has been removed.

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ The context of a model can be set using [`contextualize`](@ref):
4242
contextualize
4343
```
4444

45+
Some models require threadsafe evaluation (see https://turinglang.org/docs/THIS_DOESNT_EXIST_YET for more information on when this is necessary).
46+
If this is the case, one must enable threadsafe evaluation for a model:
47+
48+
```@docs
49+
setthreadsafe
50+
```
51+
4552
## Evaluation
4653

4754
With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ export AbstractVarInfo,
9090
Model,
9191
getmissings,
9292
getargnames,
93+
setthreadsafe,
9394
extract_priors,
9495
values_as_in_model,
9596
# evaluation

src/compiler.jl

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
301301
modeldef = build_model_definition(expr)
302302

303303
# Generate main body
304-
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
304+
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, false)
305305

306306
return build_output(modeldef, linenumbernode)
307307
end
@@ -346,36 +346,64 @@ Generate the body of the main evaluation function from expression `expr` and arg
346346
If `warn` is true, a warning is displayed if internal variables are used in the model
347347
definition.
348348
"""
349-
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
349+
generate_mainbody(mod, expr, warn, warned_about_threads_threads) =
350+
generate_mainbody!(mod, Symbol[], expr, warn, warned_about_threads_threads)
350351

351-
generate_mainbody!(mod, found, x, warn) = x
352-
function generate_mainbody!(mod, found, sym::Symbol, warn)
352+
generate_mainbody!(mod, found, x, warn, warned_about_threads_threads) = x
353+
function generate_mainbody!(mod, found, sym::Symbol, warn, warned_about_threads_threads)
353354
if warn && sym in INTERNALNAMES && sym found
354355
@warn "you are using the internal variable `$sym`"
355356
push!(found, sym)
356357
end
357358

358359
return sym
359360
end
360-
function generate_mainbody!(mod, found, expr::Expr, warn)
361+
function generate_mainbody!(mod, found, expr::Expr, warn, warned_about_threads_threads)
361362
# Do not touch interpolated expressions
362363
expr.head === :$ && return expr.args[1]
363364

365+
# Flag to determine whether we've issued a warning for threadsafe macros Note that this
366+
# detection is not fully correct. We can only detect the presence of a macro that has
367+
# the symbol `Threads.@threads`, however, we can't detect if that *is actually*
368+
# Threads.@threads from Base.Threads.
369+
364370
# Do we don't want escaped expressions because we unfortunately
365371
# escape the entire body afterwards.
366-
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
372+
Meta.isexpr(expr, :escape) && return generate_mainbody(
373+
mod, found, expr.args[1], warn, warned_about_threads_threads
374+
)
367375

368376
# If it's a macro, we expand it
369377
if Meta.isexpr(expr, :macrocall)
370-
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
378+
if expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) &&
379+
!warned_about_threads_threads
380+
warned_about_threads_threads = true
381+
@warn (
382+
"It looks like you are using `Threads.@threads` in your model definition." *
383+
"\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
384+
" If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
385+
"\n\nAvoiding threadsafe evaluation can often lead to significant performance improvements. Please see https://turinglang.org/docs/THIS_PAGE_DOESNT_EXIST_YET for more details of when threadsafe evaluation is actually required."
386+
)
387+
end
388+
return generate_mainbody!(
389+
mod,
390+
found,
391+
macroexpand(mod, expr; recursive=true),
392+
warn,
393+
warned_about_threads_threads,
394+
)
371395
end
372396

373397
# Modify dotted tilde operators.
374398
args_dottilde = getargs_dottilde(expr)
375399
if args_dottilde !== nothing
376400
L, R = args_dottilde
377401
return generate_mainbody!(
378-
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
402+
mod,
403+
found,
404+
Base.remove_linenums!(generate_dot_tilde(L, R)),
405+
warn,
406+
warned_about_threads_threads,
379407
)
380408
end
381409

@@ -385,8 +413,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
385413
L, R = args_tilde
386414
return Base.remove_linenums!(
387415
generate_tilde(
388-
generate_mainbody!(mod, found, L, warn),
389-
generate_mainbody!(mod, found, R, warn),
416+
generate_mainbody!(mod, found, L, warn, warned_about_threads_threads),
417+
generate_mainbody!(mod, found, R, warn, warned_about_threads_threads),
390418
),
391419
)
392420
end
@@ -397,13 +425,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
397425
L, R = args_assign
398426
return Base.remove_linenums!(
399427
generate_assign(
400-
generate_mainbody!(mod, found, L, warn),
401-
generate_mainbody!(mod, found, R, warn),
428+
generate_mainbody!(mod, found, L, warn, warned_about_threads_threads),
429+
generate_mainbody!(mod, found, R, warn, warned_about_threads_threads),
402430
),
403431
)
404432
end
405433

406-
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
434+
return Expr(
435+
expr.head,
436+
map(
437+
x -> generate_mainbody!(mod, found, x, warn, warned_about_threads_threads),
438+
expr.args,
439+
)...,
440+
)
407441
end
408442

409443
function generate_assign(left, right)

src/debug_utils.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,11 @@ function check_model_and_trace(
424424
# Perform checks before evaluating the model.
425425
issuccess = check_model_pre_evaluation(model)
426426

427-
# Force single-threaded execution.
428-
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
427+
# TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a
428+
# check on the merged accumulator, rather than checking it in the accumulate_assume
429+
# calls. That way we can also support multi-threaded evaluation and use `evaluate!!`
430+
# here instead of `_evaluate!!`.
431+
_, varinfo = DynamicPPL._evaluate!!(model, varinfo)
429432

430433
# Perform checks after evaluating the model.
431434
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))

src/model.jl

Lines changed: 58 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
2+
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded}
33
f::F
44
args::NamedTuple{argnames,Targs}
55
defaults::NamedTuple{defaultnames,Tdefaults}
@@ -17,6 +17,10 @@ An argument with a type of `Missing` will be in `missings` by default. However,
1717
non-traditional use-cases `missings` can be defined differently. All variables in `missings`
1818
are treated as random variables rather than observations.
1919
20+
The `Threaded` type parameter indicates whether the model requires threadsafe evaluation
21+
(i.e., whether the model contains statements which modify the internal VarInfo that are
22+
executed in parallel). By default, this is set to `false`.
23+
2024
The default arguments are used internally when constructing instances of the same model with
2125
different arguments.
2226
@@ -33,8 +37,9 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
3337
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
3438
```
3539
"""
36-
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
37-
AbstractProbabilisticProgram
40+
struct Model{
41+
F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded
42+
} <: AbstractProbabilisticProgram
3843
f::F
3944
args::NamedTuple{argnames,Targs}
4045
defaults::NamedTuple{defaultnames,Tdefaults}
@@ -46,13 +51,13 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
4651
Create a model with evaluation function `f` and missing arguments overwritten by
4752
`missings`.
4853
"""
49-
function Model{missings}(
54+
function Model{missings,Threaded}(
5055
f::F,
5156
args::NamedTuple{argnames,Targs},
5257
defaults::NamedTuple{defaultnames,Tdefaults},
5358
context::Ctx=DefaultContext(),
54-
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
55-
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
59+
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Threaded}
60+
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Threaded}(
5661
f, args, defaults, context
5762
)
5863
end
@@ -71,18 +76,27 @@ model with different arguments.
7176
args::NamedTuple{argnames,Targs},
7277
defaults::NamedTuple{kwargnames,Tkwargs},
7378
context::AbstractContext=DefaultContext(),
79+
threadsafe::Bool=false,
7480
) where {F,argnames,Targs,kwargnames,Tkwargs}
7581
missing_args = Tuple(
7682
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
7783
)
7884
missing_kwargs = Tuple(
7985
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
8086
)
81-
return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
87+
return :(Model{$(missing_args..., missing_kwargs...),threadsafe}(
88+
f, args, defaults, context
89+
))
8290
end
8391

84-
function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
85-
return Model(f, args, NamedTuple(kwargs), context)
92+
function Model(
93+
f,
94+
args::NamedTuple,
95+
context::AbstractContext=DefaultContext(),
96+
threadsafe=false;
97+
kwargs...,
98+
)
99+
return Model(f, args, NamedTuple(kwargs), context, threadsafe)
86100
end
87101

88102
"""
@@ -91,8 +105,10 @@ end
91105
Return a new `Model` with the same evaluation function and other arguments, but
92106
with its underlying context set to `context`.
93107
"""
94-
function contextualize(model::Model, context::AbstractContext)
95-
return Model(model.f, model.args, model.defaults, context)
108+
function contextualize(
109+
model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, context::AbstractContext
110+
) where {F,A,D,M,Ta,Td,Ctx,Threaded}
111+
return Model(model.f, model.args, model.defaults, context, Threaded)
96112
end
97113

98114
"""
@@ -105,6 +121,31 @@ function setleafcontext(model::Model, context::AbstractContext)
105121
return contextualize(model, setleafcontext(model.context, context))
106122
end
107123

124+
"""
125+
setthreadsafe(model::Model, threadsafe::Bool)
126+
127+
Returns a new `Model` with its threadsafe flag set to `threadsafe`.
128+
129+
Threadsafe evaluation allows for parallel execution of model statements that mutate the
130+
internal `VarInfo` object. For example, this is needed if tilde-statements are nested inside
131+
`Threads.@threads` or similar constructs.
132+
133+
It is not needed for generic multithreaded operations that don't involve VarInfo. For
134+
example, calculating a log-likelihood term in parallel and then calling `@addlogprob!`
135+
outside of the parallel region is safe without needing to set `threadsafe=true`.
136+
137+
It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`.
138+
"""
139+
function setthreadsafe(
140+
model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, threadsafe::Bool
141+
) where {F,A,D,M,Ta,Td,Ctx,Threaded}
142+
return if Threaded == threadsafe
143+
model
144+
else
145+
Model{M,threadsafe}(model.f, model.args, model.defaults, model.context)
146+
end
147+
end
148+
108149
"""
109150
model | (x = 1.0, ...)
110151
@@ -863,16 +904,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf
863904
return first(init!!(rng, model, varinfo))
864905
end
865906

866-
"""
867-
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
868-
869-
Return `true` if evaluation of a model using `context` and `varinfo` should
870-
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
871-
"""
872-
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
873-
return Threads.nthreads() > 1
874-
end
875-
876907
"""
877908
init!!(
878909
[rng::Random.AbstractRNG,]
@@ -944,40 +975,14 @@ If multiple threads are available, the varinfo provided will be wrapped in a
944975
Returns a tuple of the model's return value, plus the updated `varinfo`
945976
(unwrapped if necessary).
946977
"""
947-
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
948-
return if use_threadsafe_eval(model.context, varinfo)
949-
evaluate_threadsafe!!(model, varinfo)
950-
else
951-
evaluate_threadunsafe!!(model, varinfo)
952-
end
953-
end
954-
955-
"""
956-
evaluate_threadunsafe!!(model, varinfo)
957-
958-
Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
959-
960-
If the `model` makes use of Julia's multithreading this will lead to undefined behaviour.
961-
This method is not exposed and supposed to be used only internally in DynamicPPL.
962-
963-
See also: [`evaluate_threadsafe!!`](@ref)
964-
"""
965-
function evaluate_threadunsafe!!(model, varinfo)
978+
function AbstractPPL.evaluate!!(
979+
model::Model{F,A,D,M,Ta,Td,Ctx,false}, varinfo::AbstractVarInfo
980+
) where {F,A,D,M,Ta,Td,Ctx}
966981
return _evaluate!!(model, resetaccs!!(varinfo))
967982
end
968-
969-
"""
970-
evaluate_threadsafe!!(model, varinfo, context)
971-
972-
Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`.
973-
974-
With the wrapper, Julia's multithreading can be used for observe statements in the `model`
975-
but parallel sampling will lead to undefined behaviour.
976-
This method is not exposed and supposed to be used only internally in DynamicPPL.
977-
978-
See also: [`evaluate_threadunsafe!!`](@ref)
979-
"""
980-
function evaluate_threadsafe!!(model, varinfo)
983+
function AbstractPPL.evaluate!!(
984+
model::Model{F,A,D,M,Ta,Td,Ctx,true}, varinfo::AbstractVarInfo
985+
) where {F,A,D,M,Ta,Td,Ctx}
981986
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
982987
result, wrapper_new = _evaluate!!(model, wrapper)
983988
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it

0 commit comments

Comments
 (0)