diff --git a/HISTORY.md b/HISTORY.md
index ff28349d8..5dcb008d1 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -9,12 +9,49 @@
 This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation and automatic differentiation.
 Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
 
-For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
+For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/logdensityfunction.jl` file, which contains extensive comments.
 
 As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
 In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
 If you were previously relying on this behaviour, you will need to store a VarInfo separately.
 
+#### Threadsafe evaluation
+
+DynamicPPL models have traditionally supported running some probabilistic statements (e.g. tilde-statements or `@addlogprob!`) in parallel.
+Prior to DynamicPPL 0.39, thread safety for such models was enabled by default whenever Julia was launched with more than one thread.
+
+In DynamicPPL 0.39, **thread-safe evaluation is now disabled by default**.
+If you need it (see below for a discussion of when you _do_ need it), you **must** now mark the model as threadsafe manually, using:
+
+```julia
+@model f() = ...
+model = f()
+model = setthreadsafe(model, true)
+```
+
+The problem with the previous on-by-default behaviour was that it could sacrifice a huge amount of performance when thread safety was not needed.
+This is especially true when running Julia in a notebook, where multiple threads are often enabled by default.
+Furthermore, it was not actually the correct criterion: just because Julia has multiple threads does not mean that a particular model requires threadsafe evaluation.
+
+**A model requires threadsafe evaluation if, and only if, the VarInfo object used inside the model is manipulated in parallel.**
+This can occur if any of the following appear inside `Threads.@threads` or other concurrency functions/macros:
+
+  - tilde-statements
+  - calls to `@addlogprob!`
+  - any direct manipulation of the special `__varinfo__` variable
+
+If none of these appear inside threaded blocks, then you do not need to mark your model as threadsafe.
+**Notably, the following do not require threadsafe evaluation:**
+
+  - Using threading for any computation that does not involve the VarInfo. For example, you can calculate a log-probability in parallel, and then add it using `@addlogprob!` outside of the threaded block.
+  - Sampling with `AbstractMCMC.MCMCThreads()`.
+
+For more information about threadsafe evaluation, please see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/).
+
+When threadsafe evaluation is enabled for a model, an internal flag is set on the model.
+The value of this flag can be queried using `DynamicPPL.requires_threadsafe(model)`, which returns a boolean.
+This function is newly exported in this version of DynamicPPL.
+
 #### Parent and leaf contexts
 
 The `DynamicPPL.NodeTrait` function has been removed.
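The distinction drawn in the changelog above can be made concrete with a short sketch (the model names here are illustrative only, not part of the diff):

```julia
using DynamicPPL, Distributions

# Requires the flag: tilde-statements run inside Threads.@threads, so the
# internal VarInfo is mutated in parallel.
@model function needs_threadsafe(y)
    x ~ Normal()
    Threads.@threads for i in eachindex(y)
        y[i] ~ Normal(x)
    end
end
model = setthreadsafe(needs_threadsafe(zeros(10)), true)

# Does not require the flag: the threaded loop never touches the VarInfo;
# the total is added with a single @addlogprob! outside the parallel region.
@model function no_threadsafe(y)
    x ~ Normal()
    lps = zeros(length(y))  # Float64 assumed for simplicity
    Threads.@threads for i in eachindex(y)
        lps[i] = logpdf(Normal(x), y[i])
    end
    @addlogprob! sum(lps)
end
model = no_threadsafe(zeros(10))
```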
diff --git a/docs/src/api.md b/docs/src/api.md
index adb476db5..193a6ce4c 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -42,6 +42,14 @@ The context of a model can be set using [`contextualize`](@ref):
 contextualize
 ```
 
+Some models require threadsafe evaluation (see [the Turing docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more information on when this is necessary).
+If so, threadsafe evaluation must be enabled for the model:
+
+```@docs
+setthreadsafe
+requires_threadsafe
+```
+
 ## Evaluation
 
 With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index a885f6a96..fda428eaa 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -90,6 +90,8 @@ export AbstractVarInfo,
     Model,
     getmissings,
     getargnames,
+    setthreadsafe,
+    requires_threadsafe,
     extract_priors,
     values_as_in_model,
     # evaluation
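An illustrative REPL session for the two newly exported functions (the behaviour shown follows the tests later in this diff):

```julia
julia> using DynamicPPL, Distributions

julia> @model f() = x ~ Normal();

julia> model = f();

julia> requires_threadsafe(model)
false

julia> model = setthreadsafe(model, true);

julia> requires_threadsafe(model)
true
```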
diff --git a/src/compiler.jl b/src/compiler.jl
index 3324780ca..1b4260121 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -301,7 +301,7 @@ function model(mod, linenumbernode, expr, warn)
     modeldef = build_model_definition(expr)
 
     # Generate main body
-    modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn)
+    modeldef[:body] = generate_mainbody(mod, modeldef[:body], warn, true)
 
     return build_output(modeldef, linenumbernode)
 end
@@ -346,10 +346,11 @@ Generate the body of the main evaluation function from expression `expr` and arg
 If `warn` is true, a warning is displayed if internal variables are used in the model definition.
 """
-generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
+generate_mainbody(mod, expr, warn, warn_threads) =
+    generate_mainbody!(mod, Symbol[], expr, warn, warn_threads)
 
-generate_mainbody!(mod, found, x, warn) = x
-function generate_mainbody!(mod, found, sym::Symbol, warn)
+generate_mainbody!(mod, found, x, warn, warn_threads) = x
+function generate_mainbody!(mod, found, sym::Symbol, warn, warn_threads)
     if warn && sym in INTERNALNAMES && sym ∉ found
         @warn "you are using the internal variable `$sym`"
         push!(found, sym)
@@ -357,17 +358,39 @@ function generate_mainbody!(mod, found, sym::Symbol, warn)
     return sym
 end
 
-function generate_mainbody!(mod, found, expr::Expr, warn)
+function generate_mainbody!(mod, found, expr::Expr, warn, warn_threads)
     # Do not touch interpolated expressions
     expr.head === :$ && return expr.args[1]
 
+    # `warn_threads` is a flag that determines whether we should still warn about uses of
+    # `Threads.@threads`. Note that this detection is not fully correct: we can only
+    # detect the presence of a macro whose name is the symbol `@threads` or
+    # `Threads.@threads`; we can't check whether it *actually is* `Threads.@threads` from
+    # Base.Threads.
+
     # We don't want escaped expressions because we unfortunately
     # escape the entire body afterwards.
-    Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
+    Meta.isexpr(expr, :escape) &&
+        return generate_mainbody(mod, found, expr.args[1], warn, warn_threads)
 
     # If it's a macro, we expand it
     if Meta.isexpr(expr, :macrocall)
-        return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
+        if warn_threads && (
+            expr.args[1] == Symbol("@threads") ||
+            expr.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads")))
+        )
+            warn_threads = false
+            @warn (
+                "It looks like you are using `Threads.@threads` in your model definition." *
+                "\n\nNote that since version 0.39 of DynamicPPL, threadsafe evaluation of models is disabled by default." *
+                " If you need it, you will need to explicitly enable it by creating the model, and then running `model = setthreadsafe(model, true)`." *
+                "\n\nThreadsafe model evaluation is only needed when parallelising tilde-statements (not arbitrary Julia code), and avoiding it can often lead to significant performance improvements." *
+                "\n\nPlease see https://turinglang.org/docs/usage/threadsafe-evaluation/ for more details of when threadsafe evaluation is actually required."
            )
+        end
+        return generate_mainbody!(
+            mod, found, macroexpand(mod, expr; recursive=true), warn, warn_threads
+        )
     end
 
     # Modify dotted tilde operators.
@@ -375,7 +398,7 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
     if args_dottilde !== nothing
         L, R = args_dottilde
         return generate_mainbody!(
-            mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn
+            mod, found, Base.remove_linenums!(generate_dot_tilde(L, R)), warn, warn_threads
         )
     end
 
@@ -385,8 +408,8 @@
         L, R = args_tilde
         return Base.remove_linenums!(
             generate_tilde(
-                generate_mainbody!(mod, found, L, warn),
-                generate_mainbody!(mod, found, R, warn),
+                generate_mainbody!(mod, found, L, warn, warn_threads),
+                generate_mainbody!(mod, found, R, warn, warn_threads),
             ),
         )
     end
@@ -397,13 +420,16 @@
         L, R = args_assign
         return Base.remove_linenums!(
             generate_assign(
-                generate_mainbody!(mod, found, L, warn),
-                generate_mainbody!(mod, found, R, warn),
+                generate_mainbody!(mod, found, L, warn, warn_threads),
+                generate_mainbody!(mod, found, R, warn, warn_threads),
            ),
        )
    end
 
-    return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
+    return Expr(
+        expr.head,
+        map(x -> generate_mainbody!(mod, found, x, warn, warn_threads), expr.args)...,
+    )
 end
 
 function generate_assign(left, right)
@@ -699,7 +725,7 @@ function build_output(modeldef, linenumbernode)
     # to the call site
     modeldef[:body] = MacroTools.@q begin
         $(linenumbernode)
-        return $(DynamicPPL.Model)($name, $args_nt; $(kwargs_inclusion...))
+        return $(DynamicPPL.Model){false}($name, $args_nt; $(kwargs_inclusion...))
     end
 
     return MacroTools.@q begin
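The detection above is purely syntactic: it matches any macrocall whose name is spelled `@threads` or `Threads.@threads`, and nothing more. A standalone sketch of the same check (the helper name is hypothetical):

```julia
is_threads_macrocall(ex) =
    Meta.isexpr(ex, :macrocall) && (
        ex.args[1] == Symbol("@threads") ||
        ex.args[1] == Expr(:., :Threads, QuoteNode(Symbol("@threads")))
    )

is_threads_macrocall(:(Threads.@threads for i in 1:10 end))       # true
is_threads_macrocall(:(@threads for i in 1:10 end))               # true
# Spelled differently, so it escapes detection even though it is the same macro:
is_threads_macrocall(:(Base.Threads.@threads for i in 1:10 end))  # false
```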
diff --git a/src/debug_utils.jl b/src/debug_utils.jl
index e8b50a0b7..8810b9819 100644
--- a/src/debug_utils.jl
+++ b/src/debug_utils.jl
@@ -424,8 +424,10 @@ function check_model_and_trace(
     # Perform checks before evaluating the model.
     issuccess = check_model_pre_evaluation(model)
 
-    # Force single-threaded execution.
-    _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
+    # TODO(penelopeysm): Implement merge, etc. for DebugAccumulator, and then perform a
+    # check on the merged accumulator, rather than checking it in the accumulate_assume
+    # calls. That way we can also correctly support multi-threaded evaluation.
+    _, varinfo = DynamicPPL.evaluate!!(model, varinfo)
 
     # Perform checks after evaluating the model.
     debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))
diff --git a/src/model.jl b/src/model.jl
index 7d5bbf2fb..e82fdc60c 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,5 +1,5 @@
 """
-    struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
+    struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded}
         f::F
         args::NamedTuple{argnames,Targs}
         defaults::NamedTuple{defaultnames,Tdefaults}
@@ -17,6 +17,10 @@ An argument with a type of `Missing` will be in `missings` by default.
 However, in non-traditional use-cases, `missings` can be defined differently.
 All variables in `missings` are treated as random variables rather than observations.
 
+The `Threaded` type parameter indicates whether the model requires threadsafe evaluation
+(i.e., whether the model contains statements that modify the internal VarInfo and are
+executed in parallel). By default, this is set to `false`.
+
 The default arguments are used internally when constructing instances of the same model
 with different arguments.
 
@@ -33,26 +37,27 @@
 julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings
 Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
 ```
 """
-struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
-       AbstractProbabilisticProgram
+struct Model{
+    F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Threaded
+} <: AbstractProbabilisticProgram
     f::F
     args::NamedTuple{argnames,Targs}
     defaults::NamedTuple{defaultnames,Tdefaults}
     context::Ctx
 
     @doc """
-        Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
+        Model{Threaded,missings}(f, args::NamedTuple, defaults::NamedTuple)
 
     Create a model with evaluation function `f` and missing arguments overwritten by
    `missings`.
    """
-    function Model{missings}(
+    function Model{Threaded,missings}(
        f::F,
        args::NamedTuple{argnames,Targs},
        defaults::NamedTuple{defaultnames,Tdefaults},
        context::Ctx=DefaultContext(),
-    ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
-        return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
+    ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Threaded}
+        return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Threaded}(
            f, args, defaults, context
        )
    end
@@ -66,23 +71,39 @@ Create a model with evaluation function `f` and missing arguments deduced from `
 
 Default arguments `defaults` are used internally when constructing instances of the same
 model with different arguments.
 """
-@generated function Model(
+@generated function Model{Threaded}(
     f::F,
     args::NamedTuple{argnames,Targs},
     defaults::NamedTuple{kwargnames,Tkwargs},
     context::AbstractContext=DefaultContext(),
-) where {F,argnames,Targs,kwargnames,Tkwargs}
+) where {Threaded,F,argnames,Targs,kwargnames,Tkwargs}
     missing_args = Tuple(
         name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
     )
     missing_kwargs = Tuple(
         name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
     )
-    return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
+    return :(Model{Threaded,$(missing_args..., missing_kwargs...)}(
+        f, args, defaults, context
+    ))
+end
+
+function Model{Threaded}(
+    f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...
+) where {Threaded}
+    return Model{Threaded}(f, args, NamedTuple(kwargs), context)
 end
 
-function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
-    return Model(f, args, NamedTuple(kwargs), context)
+"""
+    requires_threadsafe(model::Model)
+
+Return whether `model` has been marked as needing threadsafe evaluation (using
+`setthreadsafe`).
+"""
+function requires_threadsafe(
+    ::Model{F,A,D,M,Ta,Td,Ctx,Threaded}
+) where {F,A,D,M,Ta,Td,Ctx,Threaded}
+    return Threaded
 end
 
 """
@@ -92,7 +113,7 @@
 Return a new `Model` with the same evaluation function and other arguments, but with its
 underlying context set to `context`.
""" function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model{requires_threadsafe(model)}(model.f, model.args, model.defaults, context) end """ @@ -105,6 +126,33 @@ function setleafcontext(model::Model, context::AbstractContext) return contextualize(model, setleafcontext(model.context, context)) end +""" + setthreadsafe(model::Model, threadsafe::Bool) + +Returns a new `Model` with its threadsafe flag set to `threadsafe`. + +Threadsafe evaluation ensures correctness when executing model statements that mutate the +internal `VarInfo` object in parallel. For example, this is needed if tilde-statements are +nested inside `Threads.@threads` or similar constructs. + +It is not needed for generic multithreaded operations that don't involve VarInfo. For +example, calculating a log-likelihood term in parallel and then calling `@addlogprob!` +outside of the parallel region is safe without needing to set `threadsafe=true`. + +It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`. + +Setting `threadsafe` to `true` increases the overhead in evaluating the model. Please see +[the Turing.jl docs](https://turinglang.org/docs/usage/threadsafe-evaluation/) for more +details. +""" +function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M} + return if requires_threadsafe(model) == threadsafe + model + else + Model{threadsafe,M}(model.f, model.args, model.defaults, model.context) + end +end + """ model | (x = 1.0, ...) @@ -863,16 +911,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(init!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ init!!( [rng::Random.AbstractRNG,] @@ -889,10 +927,7 @@ If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -@inline function init!!( - # Note that this `@inline` is mandatory for performance, especially for - # LogDensityFunction. If it's not inlined, it leads to extra allocations (even for - # trivial models) and much slower runtime. +function init!!( rng::Random.AbstractRNG, model::Model, vi::AbstractVarInfo, @@ -900,36 +935,11 @@ Returns a tuple of the model's return value, plus the updated `varinfo` object. ) ctx = InitContext(rng, strategy) model = DynamicPPL.setleafcontext(model, ctx) - # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what - # it _should_ do, but this is wrong regardless. - # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - return if Threads.nthreads() > 1 - # TODO(penelopeysm): The logic for setting eltype of accs is very similar to that - # used in `unflatten`. The reason why we need it here is because the VarInfo `vi` - # won't have been filled with parameters prior to `init!!` being called. - # - # Note that this eltype promotion is only needed for threadsafe evaluation. In an - # ideal world, this code should be handled inside `evaluate_threadsafe!!` or a - # similar method. In other words, it should not be here, and it should not be inside - # `unflatten` either. 
@@ -938,55 +948,42 @@ end
 
 Evaluate the `model` with the given `varinfo`.
 
-If multiple threads are available, the varinfo provided will be wrapped in a
-`ThreadSafeVarInfo` before evaluation.
+If the model has been marked as requiring threadsafe evaluation, the varinfo provided
+will be wrapped in a `ThreadSafeVarInfo` before evaluation.
 
 Returns a tuple of the model's return value, plus the updated `varinfo`
 (unwrapped if necessary).
 """
 function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
-    return if use_threadsafe_eval(model.context, varinfo)
-        evaluate_threadsafe!!(model, varinfo)
+    return if requires_threadsafe(model)
+        # Use of float_type_with_fallback(param_eltype) is necessary to deal with cases
+        # where the parameters are a gradient type of some AD backend.
+        # TODO(mhauru) How could we do this more cleanly? The problem case is
+        # map_accumulator!! for ThreadSafeVarInfo. In that one, if the map produces e.g. a
+        # ForwardDiff.Dual, but the accumulators in the VarInfo are plain floats, we error
+        # since we can't change the element type of ThreadSafeVarInfo.accs_by_thread.
+        # However, doing this conversion here messes with cases like using Float32 for
+        # logprobs and Float64 for the parameters. Also, this is just plain ugly and hacky.
+        # The below line is finicky for type stability. For instance, assigning the eltype
+        # to convert to into an intermediate variable makes this unstable (constant
+        # propagation fails). Take care when editing.
+        param_eltype = DynamicPPL.get_param_eltype(varinfo, model.context)
+        accs = map(DynamicPPL.getaccs(varinfo)) do acc
+            DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
+        end
+        varinfo = DynamicPPL.setaccs!!(varinfo, accs)
+        wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
+        result, wrapper_new = _evaluate!!(model, wrapper)
+        # TODO(penelopeysm): It seems that if you pass a TSVI to this method, it
+        # will return the underlying VI, which is a bit counterintuitive (because
+        # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
+        # again).
+        return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
     else
-        evaluate_threadunsafe!!(model, varinfo)
+        _evaluate!!(model, resetaccs!!(varinfo))
     end
 end
 
-"""
-    evaluate_threadunsafe!!(model, varinfo)
-
-Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
-
-If the `model` makes use of Julia's multithreading this will lead to undefined behaviour.
-This method is not exposed and supposed to be used only internally in DynamicPPL.
-
-See also: [`evaluate_threadsafe!!`](@ref)
-"""
-function evaluate_threadunsafe!!(model, varinfo)
-    return _evaluate!!(model, resetaccs!!(varinfo))
-end
-
-"""
-    evaluate_threadsafe!!(model, varinfo, context)
-
-Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`.
-
-With the wrapper, Julia's multithreading can be used for observe statements in the `model`
-but parallel sampling will lead to undefined behaviour.
-This method is not exposed and supposed to be used only internally in DynamicPPL.
-
-See also: [`evaluate_threadunsafe!!`](@ref)
-"""
-function evaluate_threadsafe!!(model, varinfo)
-    wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
-    result, wrapper_new = _evaluate!!(model, wrapper)
-    # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it
-    # will return the underlying VI, which is a bit counterintuitive (because
-    # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
-    # again).
-    return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
-end
-
 """
     _evaluate!!(model::Model, varinfo)
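Because `Threaded` is a type parameter rather than a field, `requires_threadsafe` is a compile-time constant, so the branch in `evaluate!!` above can be resolved statically and costs nothing on the non-threadsafe path. An illustrative check (the exact printed form may differ):

```julia
julia> @model f() = x ~ Normal();

julia> model = f();

julia> typeof(model).parameters[end]  # the new Threaded parameter
false

julia> typeof(setthreadsafe(model, true)).parameters[end]
true
```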
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
index 434480be6..9d3fb1925 100644
--- a/src/simple_varinfo.jl
+++ b/src/simple_varinfo.jl
@@ -278,15 +278,7 @@ end
 
 function unflatten(svi::SimpleVarInfo, x::AbstractVector)
     vals = unflatten(svi.values, x)
-    # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is
-    # required but undesireable.
-    # The below line is finicky for type stability. For instance, assigning the eltype to
-    # convert to into an intermediate variable makes this unstable (constant propagation)
-    # fails. Take care when editing.
-    accs = map(
-        acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi)
-    )
-    return SimpleVarInfo(vals, accs, svi.transformation)
+    return SimpleVarInfo(vals, svi.accs, svi.transformation)
 end
 
 function BangBang.empty!!(vi::SimpleVarInfo)
diff --git a/src/threadsafe.jl b/src/threadsafe.jl
index 89877f385..0e906b6ca 100644
--- a/src/threadsafe.jl
+++ b/src/threadsafe.jl
@@ -13,12 +13,7 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo)
     # fields. This is not good practice --- see
     # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full
     # explanation --- but it has worked okay so far.
-    # The use of nthreads()*2 here ensures that threadid() doesn't exceed
-    # the length of the logps array. Ideally, we would use maxthreadid(),
-    # but Mooncake can't differentiate through that. Empirically, nthreads()*2
-    # seems to provide an upper bound to maxthreadid(), so we use that here.
-    # See https://github.com/TuringLang/DynamicPPL.jl/pull/936
-    accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)]
+    accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.maxthreadid()]
     return ThreadSafeVarInfo(vi, accs_by_thread)
 end
 ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 486d24191..14e08515c 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -367,21 +367,7 @@ vector_length(md::Metadata) = sum(length, md.ranges)
 
 function unflatten(vi::VarInfo, x::AbstractVector)
     md = unflatten_metadata(vi.metadata, x)
-    # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is
-    # a gradient type of some AD backend.
-    # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
-    # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but
-    # the accumulators in the VarInfo are plain floats, we error since we can't change the
-    # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
-    # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
-    # plain ugly and hacky.
-    # The below line is finicky for type stability. For instance, assigning the eltype to
-    # convert to into an intermediate variable makes this unstable (constant propagation
-    # fails). Take care when editing.
-    accs = map(
-        acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi))
-    )
-    return VarInfo(md, accs)
+    return VarInfo(md, vi.accs)
 end
 
 # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in
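For readers unfamiliar with `ThreadSafeVarInfo`: `accs_by_thread` gives each thread its own accumulator slot, indexed by `threadid()`, and the slots are combined after evaluation. A minimal sketch of the idea (hypothetical function, not DynamicPPL API; as the comment retained above notes, indexing by `threadid()` is not good practice in general, see issue #924):

```julia
function threaded_logp_sum(lps)
    # One slot per possible thread id, mirroring accs_by_thread above.
    accs = zeros(Threads.maxthreadid())
    Threads.@threads for i in eachindex(lps)
        accs[Threads.threadid()] += lps[i]
    end
    return sum(accs)  # combine the per-thread partial sums
end
```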
diff --git a/test/compiler.jl b/test/compiler.jl
index b1309254e..9056f666a 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -606,12 +606,7 @@ module Issue537 end
         @model demo() = return __varinfo__
         retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo())
         @test svi == SimpleVarInfo()
-        if Threads.nthreads() > 1
-            @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo}
-            @test retval.varinfo == svi
-        else
-            @test retval == svi
-        end
+        @test retval == svi
 
         # We should not be altering return-values other than at top-level.
         @model function demo()
@@ -793,4 +788,39 @@ module Issue537 end
         res = model()
         @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}())
     end
+
+    @testset "Threads.@threads detection" begin
+        # Check that the compiler detects when `Threads.@threads` is used inside a model
+
+        e1 = quote
+            @model function f1()
+                Threads.@threads for i in 1:10
+                    x[i] ~ Normal()
+                end
+            end
+        end
+        @test_logs (:warn, r"threadsafe evaluation") eval(e1)
+
+        e2 = quote
+            @model function f2()
+                for j in 1:10
+                    Threads.@threads for i in 1:10
+                        x[i] ~ Normal()
+                    end
+                end
+            end
+        end
+        @test_logs (:warn, r"threadsafe evaluation") eval(e2)
+
+        e3 = quote
+            @model function f3()
+                begin
+                    Threads.@threads for i in 1:10
+                        x[i] ~ Normal()
+                    end
+                end
+            end
+        end
+        @test_logs (:warn, r"threadsafe evaluation") eval(e3)
+    end
 end
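Since the compiler check is syntactic, a renamed import will not trigger the warning; a hypothetical example of a model that is genuinely threaded but compiles silently:

```julia
using DynamicPPL, Distributions
using Base.Threads: @threads as @par  # alias evades the syntactic check

@model function g(x)
    @par for i in eachindex(x)  # really Threads.@threads, but no warning is issued
        x[i] ~ Normal()
    end
end

# The model still needs the flag in order to be evaluated safely:
model = setthreadsafe(g(zeros(10)), true)
```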
diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl
index f43ed45a4..1d609a013 100644
--- a/test/logdensityfunction.jl
+++ b/test/logdensityfunction.jl
@@ -51,21 +51,19 @@ using Mooncake: Mooncake
     end
 
     @testset "Threaded observe" begin
-        if Threads.nthreads() > 1
-            @model function threaded(y)
-                x ~ Normal()
-                Threads.@threads for i in eachindex(y)
-                    y[i] ~ Normal(x)
-                end
+        @model function threaded(y)
+            x ~ Normal()
+            Threads.@threads for i in eachindex(y)
+                y[i] ~ Normal(x)
             end
-            N = 100
-            model = threaded(zeros(N))
-            ldf = DynamicPPL.LogDensityFunction(model)
-
-            xs = [1.0]
-            @test LogDensityProblems.logdensity(ldf, xs) ≈
-                logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0)
         end
+        N = 100
+        model = setthreadsafe(threaded(zeros(N)), true)
+        ldf = DynamicPPL.LogDensityFunction(model)
+
+        xs = [1.0]
+        @test LogDensityProblems.logdensity(ldf, xs) ≈
+            logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0)
     end
 end
 
@@ -125,34 +123,32 @@ end
 end
 
 @testset "LogDensityFunction: performance" begin
-    if Threads.nthreads() == 1
-        # Evaluating these three models should not lead to any allocations (but only when
-        # not using TSVI).
-        @model function f()
-            x ~ Normal()
-            return 1.0 ~ Normal(x)
-        end
-        @model function submodel_inner()
-            m ~ Normal(0, 1)
-            s ~ Exponential()
-            return (m=m, s=s)
-        end
-        # Note that for the allocation tests to work on this one, `inner` has
-        # to be passed as an argument to `submodel_outer`, instead of just
-        # being called inside the model function itself
-        @model function submodel_outer(inner)
-            params ~ to_submodel(inner)
-            y ~ Normal(params.m, params.s)
-            return 1.0 ~ Normal(y)
-        end
-        @testset for model in
-                     (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner()))
-            vi = VarInfo(model)
-            ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi)
-            x = vi[:]
-            bench = median(@be LogDensityProblems.logdensity(ldf, x))
-            @test iszero(bench.allocs)
-        end
-    end
+    # Evaluating these three models should not lead to any allocations (but only when
+    # not using TSVI).
+    @model function f()
+        x ~ Normal()
+        return 1.0 ~ Normal(x)
+    end
+    @model function submodel_inner()
+        m ~ Normal(0, 1)
+        s ~ Exponential()
+        return (m=m, s=s)
+    end
+    # Note that for the allocation tests to work on this one, `inner` has
+    # to be passed as an argument to `submodel_outer`, instead of just
+    # being called inside the model function itself
+    @model function submodel_outer(inner)
+        params ~ to_submodel(inner)
+        y ~ Normal(params.m, params.s)
+        return 1.0 ~ Normal(y)
+    end
+    @testset for model in
+                 (f(), submodel_inner() | (; s=0.0), submodel_outer(submodel_inner()))
+        vi = VarInfo(model)
+        ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi)
+        x = vi[:]
+        bench = median(@be LogDensityProblems.logdensity($ldf, $x))
+        @test iszero(bench.allocs)
+    end
 end
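For reference, the quantity checked in the "Threaded observe" test above is the joint density of the model, with $x = \texttt{xs}[1]$ and all $N$ observations equal to zero:

$$\log p(x, y) = \log \mathcal{N}(x \mid 0, 1) + \sum_{i=1}^{N} \log \mathcal{N}(0 \mid x, 1),$$

which is exactly `logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0)`.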
diff --git a/test/threadsafe.jl b/test/threadsafe.jl
index 522730566..879e936d6 100644
--- a/test/threadsafe.jl
+++ b/test/threadsafe.jl
@@ -5,13 +5,23 @@
         @test threadsafe_vi.varinfo === vi
         @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple}
-        @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2
+        @test length(threadsafe_vi.accs_by_thread) == Threads.maxthreadid()
         expected_accs = DynamicPPL.AccumulatorTuple(
             (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))...
         )
         @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread)
     end
 
+    @testset "setthreadsafe" begin
+        @model f() = x ~ Normal()
+        model = f()
+        @test !DynamicPPL.requires_threadsafe(model)
+        model = setthreadsafe(model, true)
+        @test DynamicPPL.requires_threadsafe(model)
+        model = setthreadsafe(model, false)
+        @test !DynamicPPL.requires_threadsafe(model)
+    end
+
     # TODO: Add more tests of the public API
     @testset "API" begin
         vi = VarInfo(gdemo_default)
@@ -41,8 +51,6 @@
     end
 
     @testset "model" begin
-        println("Peforming threading tests with $(Threads.nthreads()) threads")
-
         x = rand(10_000)
 
         @model function wthreads(x)
@@ -52,63 +60,24 @@
                 x[i] ~ Normal(x[i - 1], 1)
             end
         end
-        model = wthreads(x)
-
-        vi = VarInfo()
-        model(vi)
-        lp_w_threads = getlogjoint(vi)
-        if Threads.nthreads() == 1
-            @test vi_ isa VarInfo
-        else
-            @test vi_ isa DynamicPPL.ThreadSafeVarInfo
-        end
-
-        println("With `@threads`:")
-        println("  default:")
-        @time model(vi)
-
-        # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements.
-        DynamicPPL.evaluate_threadsafe!!(model, vi)
-        @test getlogjoint(vi) ≈ lp_w_threads
-        # check that it's wrapped during the model evaluation
-        @test vi_ isa DynamicPPL.ThreadSafeVarInfo
-        # ensure that it's unwrapped after evaluation finishes
-        @test vi isa VarInfo
+        model = setthreadsafe(wthreads(x), true)
 
-        println("  evaluate_threadsafe!!:")
-        @time DynamicPPL.evaluate_threadsafe!!(model, vi)
-
-        @model function wothreads(x)
-            global vi_ = __varinfo__
-            x[1] ~ Normal(0, 1)
+        function correct_lp(x)
+            lp = logpdf(Normal(0, 1), x[1])
             for i in 2:length(x)
-                x[i] ~ Normal(x[i - 1], 1)
+                lp += logpdf(Normal(x[i - 1], 1), x[i])
             end
+            return lp
         end
-        model = wothreads(x)
 
         vi = VarInfo()
-        model(vi)
-        lp_wo_threads = getlogjoint(vi)
-        if Threads.nthreads() == 1
-            @test vi_ isa VarInfo
-        else
-            @test vi_ isa DynamicPPL.ThreadSafeVarInfo
-        end
+        _, vi = DynamicPPL.evaluate!!(model, vi)
 
-        println("Without `@threads`:")
-        println("  default:")
-        @time model(vi)
-
-        @test lp_w_threads ≈ lp_wo_threads
-
-        # Ensure that we use `VarInfo`.
-        DynamicPPL.evaluate_threadunsafe!!(model, vi)
-        @test getlogjoint(vi) ≈ lp_w_threads
-        @test vi_ isa VarInfo
+        # check that logp is correct
+        @test getlogjoint(vi) ≈ correct_lp(x)
+        # check that varinfo was wrapped during the model evaluation
+        @test vi_ isa DynamicPPL.ThreadSafeVarInfo
+        # ensure that it's unwrapped after evaluation finishes
         @test vi isa VarInfo
-
-        println("  evaluate_threadunsafe!!:")
-        @time DynamicPPL.evaluate_threadunsafe!!(model, vi)
     end
 end