Commit c27f5e0

penelopeysm and mhauru authored
Make threadsafe evaluation opt-in (#1151)
* Make threadsafe evaluation opt-in
* Reduce number of type parameters in methods
* Make `warned_warn_about_threads_threads_threads_threads` shorter
* Improve `setthreadsafe` docstring
* warn on bare `@threads` as well
* fix merge
* Fix performance issues
* Use maxthreadid() in TSVI
* Move convert_eltype code to threadsafe eval function
* Point to new Turing docs page
* Add a test for setthreadsafe
* Tidy up check_model
* Apply suggestions from code review

  Fix outdated docstrings

  Co-authored-by: Markus Hauru <markus@mhauru.org>
* Improve warning message
* Export `requires_threadsafe`
* Add an actual docstring for `requires_threadsafe`

---------

Co-authored-by: Markus Hauru <markus@mhauru.org>
1 parent a6d56a2 commit c27f5e0

12 files changed, with 281 additions and 241 deletions

HISTORY.md

Lines changed: 38 additions & 1 deletion
@@ -9,12 +9,49 @@
99
This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
1010
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
1111

12-
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
12+
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments.
1313

1414
As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
1515
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
1616
If you were previously relying on this behaviour, you will need to store a VarInfo separately.
1717

18+
#### Threadsafe evaluation
19+
20+
DynamicPPL models have traditionally supported running some probabilistic statements (e.g. tilde-statements, or `@addlogprob!`) in parallel.
21+
Prior to DynamicPPL 0.39, thread safety for such models was enabled by default if Julia was launched with more than one thread.
22+
23+
In DynamicPPL 0.39, **thread-safe evaluation is now disabled by default**.
24+
If you need it (see below for more discussion of when you _do_ need it), you **must** now enable it explicitly for each model, using:
25+
26+
```julia
27+
@model f() = ...
28+
model = f()
29+
model = setthreadsafe(model, true)
30+
```
31+
32+
The problem with the previous on-by-default behaviour is that it can sacrifice a large amount of performance when thread safety is not needed.
33+
This is especially true when running Julia in a notebook, where multiple threads are often enabled by default.
34+
Furthermore, it is not actually the correct approach: just because Julia has multiple threads does not mean that a particular model actually requires threadsafe evaluation.
35+
36+
**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
37+
This can occur if any of the following are inside `Threads.@threads` or other concurrency functions / macros:
38+
39+
- tilde-statements
40+
- calls to `@addlogprob!`
41+
- any direct manipulation of the special `__varinfo__` variable
42+
43+
If you have none of these inside threaded blocks, then you do not need to mark your model as threadsafe.
44+
**Notably, the following do not require threadsafe evaluation:**
45+
46+
- Using threading for any computation that does not involve VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block. This does not require threadsafe evaluation.
47+
- Sampling with `AbstractMCMC.MCMCThreads()`.
48+
49+
For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/).
50+
51+
When threadsafe evaluation is enabled for a model, an internal flag is set on the model.
52+
The value of this flag can be queried using `DynamicPPL.requires_threadsafe(model)`, which returns a boolean.
53+
This function is newly exported in this version of DynamicPPL.
54+
1855
#### Parent and leaf contexts
1956

2057
The `DynamicPPL.NodeTrait` function has been removed.
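
To make the "Threadsafe evaluation" note in `HISTORY.md` above more concrete, here is a minimal sketch of the two situations it distinguishes. It assumes DynamicPPL 0.39 or later and Distributions.jl are installed. The model names `threaded_obs` and `threaded_logp` are hypothetical and chosen only for illustration; `setthreadsafe`, `requires_threadsafe`, and `@addlogprob!` are the names used in the changelog. Defining either model is expected to trigger the new `Threads.@threads` warning, since the macro appears in the model body.

```julia
using DynamicPPL, Distributions

# Hypothetical model: tilde-statements run inside a threaded loop, so the
# VarInfo is manipulated from several threads at once.
@model function threaded_obs(y)
    μ ~ Normal()
    Threads.@threads for i in eachindex(y)
        y[i] ~ Normal(μ)
    end
end

# Hypothetical model: only plain Julia computation is threaded. The single
# `@addlogprob!` call happens outside the threaded block.
@model function threaded_logp(y)
    μ ~ Normal()
    lps = zeros(typeof(μ), length(y))
    Threads.@threads for i in eachindex(y)
        lps[i] = logpdf(Normal(μ), y[i])
    end
    @addlogprob! sum(lps)
end

y = randn(100)

# The first model has to be opted in explicitly; the second can be left as-is.
needs_flag = setthreadsafe(threaded_obs(y), true)
no_flag = threaded_logp(y)

# Query the internal flag described in the changelog.
requires_threadsafe(needs_flag)
requires_threadsafe(no_flag)
```

Under these assumptions the first query should report that the flag is set and the second that it is not, because models are constructed with threadsafe evaluation disabled unless `setthreadsafe` is called.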

docs/src/api.md

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,14 @@ The context of a model can be set using [`contextualize`](@ref):
4242
contextualize
4343
```
4444

45+
Some models require threadsafe evaluation (see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more information on when this is necessary).
46+
If this is the case, one must enable threadsafe evaluation for a model:
47+
48+
```@docs
49+
setthreadsafe
50+
requires_threadsafe
51+
```
52+
4553
## Evaluation
4654

4755
With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,8 @@ export AbstractVarInfo,
9090
Model,
9191
getmissings,
9292
getargnames,
93+
setthreadsafe,
94+
requires_threadsafe,
9395
extract_priors,
9496
values_as_in_model,
9597
# evaluation

src/compiler.jl

Lines changed: 40 additions & 14 deletions
@@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
301301
modeldef = build_model_definition(expr)
302302

303303
# Generate main body
304-
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
304+
modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, true)
305305

306306
return build_output(modeldef, linenumbernode)
307307
end
@@ -346,36 +346,59 @@ Generate the body of the main evaluation function from expression `expr` and arg
346346
If `warn` is true, a warning is displayed if internal variables are used in the model
347347
definition.
348348
"""
349-
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
349+
generate_mainbody(mod, expr, warn, warn_threads) =
350+
generate_mainbody!(mod, Symbol[], expr, warn, warn_threads)
350351

351-
generate_mainbody!(mod, found, x, warn) = x
352-
function generate_mainbody!(mod, found, sym::Symbol, warn)
352+
generate_mainbody!(mod, found, x, warn, warn_threads) = x
353+
function generate_mainbody!(mod, found, sym::Symbol, warn, warn_threads)
353354
if warn && sym in INTERNALNAMES && sym ∉ found
354355
@warn "you are using the internal variable `$sym`"
355356
push!(found, sym)
356357
end
357358

358359
return sym
359360
end
360-
function generate_mainbody!(mod, found, expr::Expr, warn)
361+
function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads)
361362
# Do not touch interpolated expressions
362363
expr.head === :$ && return expr.args[1]
363364

365+
# Flag to determine whether we've issued a warning for threadsafe macros. Note that this
366+
# detection is not fully correct. We can only detect the presence of a macro that has
367+
# the symbol `Threads.@threads`, however, we can't detect if that *is actually*
368+
# Threads.@threads from Base.Threads.
369+
364370
# We don't want escaped expressions because we unfortunately
365371
# escape the entire body afterwards.
366-
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
372+
Meta.isexpr(expr, :escape) &&
373+
return generate_mainbody(mod, found, expr.args[1], warn, warn_threads)
367374

368375
# If it's a macro, we expand it
369376
if Meta.isexpr(expr, :macrocall)
370-
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
377+
if (
378+
expr.args[1] == Symbol("@threads") ||
379+
expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads"))) &&
380+
warn_threads
381+
)
382+
warn_threads = false
383+
@warn (
384+
"It looks like you are using `Threads.@threads` in your model definition." *
385+
"\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
386+
" If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
387+
"\n\nThreadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements." *
388+
"\n\nPlease see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required."
389+
)
390+
end
391+
return generate_mainbody!(
392+
mod, found, macroexpand(mod, expr; recursive=true), warn, warn_threads
393+
)
371394
end
372395

373396
# Modify dotted tilde operators.
374397
args_dottilde = getargs_dottilde(expr)
375398
if args_dottilde !== nothing
376399
L, R = args_dottilde
377400
return generate_mainbody!(
378-
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
401+
mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn, warn_threads
379402
)
380403
end
381404

@@ -385,8 +408,8 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
385408
L, R = args_tilde
386409
return Base.remove_linenums!(
387410
generate_tilde(
388-
generate_mainbody!(mod, found, L, warn),
389-
generate_mainbody!(mod, found, R, warn),
411+
generate_mainbody!(mod, found, L, warn, warn_threads),
412+
generate_mainbody!(mod, found, R, warn, warn_threads),
390413
),
391414
)
392415
end
@@ -397,13 +420,16 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
397420
L, R = args_assign
398421
return Base.remove_linenums!(
399422
generate_assign(
400-
generate_mainbody!(mod, found, L, warn),
401-
generate_mainbody!(mod, found, R, warn),
423+
generate_mainbody!(mod, found, L, warn, warn_threads),
424+
generate_mainbody!(mod, found, R, warn, warn_threads),
402425
),
403426
)
404427
end
405428

406-
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
429+
return Expr(
430+
expr.head,
431+
map(x -> generate_mainbody!(mod, found, x, warn, warn_threads), expr.args)...,
432+
)
407433
end
408434

409435
function generate_assign(left, right)
@@ -699,7 +725,7 @@ function build_output(modeldef, linenumbernode)
699725
# to the call site
700726
modeldef[:body] = MacroTools.@q begin
701727
$(linenumbernode)
702-
return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...))
728+
return $(DynamicPPL.Model){false}($name, $args_nt; $(kwargs_inclusion...))
703729
end
704730

705731
return MacroTools.@q begin
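
The `@threads` detection added to `generate_mainbody!` above works purely on the name of the macro in the expression tree. The following standalone sketch uses only Base Julia to show what those comparisons match. The helpers `looks_like_threads_macro` and `should_warn` are hypothetical and not part of DynamicPPL. Because `&&` binds more tightly than `||` in Julia, the sketch performs the name check first and only then combines it with the flag, so the flag applies to both spellings of the macro.

```julia
# Hypothetical helper (not part of DynamicPPL): does `ex` look like a call to
# `@threads` or `Threads.@threads`? Only the macro's name is inspected, so a
# user-defined macro with the same name would match as well.
function looks_like_threads_macro(ex)
    Meta.isexpr(ex, :macrocall) || return false
    name = ex.args[1]
    return name == Symbol("@threads") ||
           name == Expr(:., :Threads, QuoteNode(Symbol("@threads")))
end

# Hypothetical helper: warn only if the name matches AND the flag is still set,
# i.e. `(bare || qualified) && warn_threads`.
should_warn(ex, warn_threads::Bool) = looks_like_threads_macro(ex) && warn_threads

# Quoting does not expand macros, so these parse even without `using Base.Threads`.
bare = :(@threads for i in 1:2 end)
qualified = :(Threads.@threads for i in 1:2 end)
other = :(@inbounds x[1])

should_warn(bare, true)
should_warn(qualified, true)
should_warn(bare, false)
should_warn(other, true)
```

As the source comment notes, this kind of check cannot tell whether the matched macro really is `Threads.@threads` from `Base.Threads`; it is a heuristic for issuing a warning, not a guarantee.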

src/debug_utils.jl

Lines changed: 4 additions & 2 deletions
@@ -424,8 +424,10 @@ function check_model_and_trace(
424424
# Perform checks before evaluating the model.
425425
issuccess = check_model_pre_evaluation(model)
426426

427-
# Force single-threaded execution.
428-
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
427+
# TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a
428+
# check on the merged accumulator, rather than checking it in the accumulate_assume
429+
# calls. That way we can also correctly support multi-threaded evaluation.
430+
_, varinfo = DynamicPPL.evaluate!!(model, varinfo)
429431

430432
# Perform checks after evaluating the model.
431433
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))
