@@ -99,16 +99,20 @@ function Model(
9999 return Model (f, args, NamedTuple (kwargs), context, threadsafe)
100100end
101101
102+ function _requires_threadsafe (
103+ :: Model{F,A,D,M,Ta,Td,Ctx,Threaded}
104+ ) where {F,A,D,M,Ta,Td,Ctx,Threaded}
105+ return Threaded
106+ end
107+
102108"""
103109 contextualize(model::Model, context::AbstractContext)
104110
105111Return a new `Model` with the same evaluation function and other arguments, but
106112with its underlying context set to `context`.
107113"""
108- function contextualize (
109- model:: Model{F,A,D,M,Ta,Td,Ctx,Threaded} , context:: AbstractContext
110- ) where {F,A,D,M,Ta,Td,Ctx,Threaded}
111- return Model (model. f, model. args, model. defaults, context, Threaded)
114+ function contextualize (model:: Model , context:: AbstractContext )
115+ return Model (model. f, model. args, model. defaults, context, _requires_threadsafe (model))
112116end
113117
114118"""
@@ -136,10 +140,8 @@ outside of the parallel region is safe without needing to set `threadsafe=true`.
136140
137141It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`.
138142"""
139- function setthreadsafe (
140- model:: Model{F,A,D,M,Ta,Td,Ctx,Threaded} , threadsafe:: Bool
141- ) where {F,A,D,M,Ta,Td,Ctx,Threaded}
142- return if Threaded == threadsafe
143+ function setthreadsafe (model:: Model{F,A,D,M} , threadsafe:: Bool ) where {F,A,D,M}
144+ return if _requires_threadsafe (model) == threadsafe
143145 model
144146 else
145147 Model {M,threadsafe} (model. f, model. args, model. defaults, model. context)
@@ -969,27 +971,24 @@ end
969971
970972Evaluate the `model` with the given `varinfo`.
971973
972- If multiple threads are available, the varinfo provided will be wrapped in a
973- `ThreadSafeVarInfo` before evaluation.
974+ If the model has been marked as requiring threadsafe evaluation, are available, the varinfo
975+ provided will be wrapped in a `ThreadSafeVarInfo` before evaluation.
974976
975977Returns a tuple of the model's return value, plus the updated `varinfo`
976978(unwrapped if necessary).
977979"""
978- function AbstractPPL. evaluate!! (
979- model:: Model{F,A,D,M,Ta,Td,Ctx,false} , varinfo:: AbstractVarInfo
980- ) where {F,A,D,M,Ta,Td,Ctx}
981- return _evaluate!! (model, resetaccs!! (varinfo))
982- end
983- function AbstractPPL. evaluate!! (
984- model:: Model{F,A,D,M,Ta,Td,Ctx,true} , varinfo:: AbstractVarInfo
985- ) where {F,A,D,M,Ta,Td,Ctx}
986- wrapper = ThreadSafeVarInfo (resetaccs!! (varinfo))
987- result, wrapper_new = _evaluate!! (model, wrapper)
988- # TODO (penelopeysm): If seems that if you pass a TSVI to this method, it
989- # will return the underlying VI, which is a bit counterintuitive (because
990- # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
991- # again).
992- return result, setaccs!! (wrapper_new. varinfo, getaccs (wrapper_new))
980+ function AbstractPPL. evaluate!! (model:: Model , varinfo:: AbstractVarInfo )
981+ return if _requires_threadsafe (model)
982+ wrapper = ThreadSafeVarInfo (resetaccs!! (varinfo))
983+ result, wrapper_new = _evaluate!! (model, wrapper)
984+ # TODO (penelopeysm): If seems that if you pass a TSVI to this method, it
985+ # will return the underlying VI, which is a bit counterintuitive (because
986+ # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
987+ # again).
988+ return result, setaccs!! (wrapper_new. varinfo, getaccs (wrapper_new))
989+ else
990+ _evaluate!! (model, resetaccs!! (varinfo))
991+ end
993992end
994993
995994"""
0 commit comments