Skip to content

Commit d013e61

Browse files
committed
Reduce number of type parameters in methods
1 parent 41a99ef commit d013e61

File tree

1 file changed

+24
-25
lines changed

1 file changed

+24
-25
lines changed

src/model.jl

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,20 @@ function Model(
9999
return Model(f, args, NamedTuple(kwargs), context, threadsafe)
100100
end
101101

102+
function _requires_threadsafe(
103+
::Model{F,A,D,M,Ta,Td,Ctx,Threaded}
104+
) where {F,A,D,M,Ta,Td,Ctx,Threaded}
105+
return Threaded
106+
end
107+
102108
"""
103109
contextualize(model::Model, context::AbstractContext)
104110
105111
Return a new `Model` with the same evaluation function and other arguments, but
106112
with its underlying context set to `context`.
107113
"""
108-
function contextualize(
109-
model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, context::AbstractContext
110-
) where {F,A,D,M,Ta,Td,Ctx,Threaded}
111-
return Model(model.f, model.args, model.defaults, context, Threaded)
114+
function contextualize(model::Model, context::AbstractContext)
115+
return Model(model.f, model.args, model.defaults, context, _requires_threadsafe(model))
112116
end
113117

114118
"""
@@ -136,10 +140,8 @@ outside of the parallel region is safe without needing to set `threadsafe=true`.
136140
137141
It is also not needed for multithreaded sampling with AbstractMCMC's `MCMCThreads()`.
138142
"""
139-
function setthreadsafe(
140-
model::Model{F,A,D,M,Ta,Td,Ctx,Threaded}, threadsafe::Bool
141-
) where {F,A,D,M,Ta,Td,Ctx,Threaded}
142-
return if Threaded == threadsafe
143+
function setthreadsafe(model::Model{F,A,D,M}, threadsafe::Bool) where {F,A,D,M}
144+
return if _requires_threadsafe(model) == threadsafe
143145
model
144146
else
145147
Model{M,threadsafe}(model.f, model.args, model.defaults, model.context)
@@ -969,27 +971,24 @@ end
969971
970972
Evaluate the `model` with the given `varinfo`.
971973
972-
If multiple threads are available, the varinfo provided will be wrapped in a
973-
`ThreadSafeVarInfo` before evaluation.
974+
If the model has been marked as requiring threadsafe evaluation, are available, the varinfo
975+
provided will be wrapped in a `ThreadSafeVarInfo` before evaluation.
974976
975977
Returns a tuple of the model's return value, plus the updated `varinfo`
976978
(unwrapped if necessary).
977979
"""
978-
function AbstractPPL.evaluate!!(
979-
model::Model{F,A,D,M,Ta,Td,Ctx,false}, varinfo::AbstractVarInfo
980-
) where {F,A,D,M,Ta,Td,Ctx}
981-
return _evaluate!!(model, resetaccs!!(varinfo))
982-
end
983-
function AbstractPPL.evaluate!!(
984-
model::Model{F,A,D,M,Ta,Td,Ctx,true}, varinfo::AbstractVarInfo
985-
) where {F,A,D,M,Ta,Td,Ctx}
986-
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
987-
result, wrapper_new = _evaluate!!(model, wrapper)
988-
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it
989-
# will return the underlying VI, which is a bit counterintuitive (because
990-
# calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
991-
# again).
992-
return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
980+
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
981+
return if _requires_threadsafe(model)
982+
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
983+
result, wrapper_new = _evaluate!!(model, wrapper)
984+
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it
985+
# will return the underlying VI, which is a bit counterintuitive (because
986+
# calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
987+
# again).
988+
return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
989+
else
990+
_evaluate!!(model, resetaccs!!(varinfo))
991+
end
993992
end
994993

995994
"""

0 commit comments

Comments
 (0)