@@ -29,60 +29,6 @@ using LogDensityProblems: LogDensityProblems
2929import DifferentiationInterface as DI
3030using Random: Random
3131
32- """
33- DynamicPPL.fast_evaluate!!(
34- [rng::Random.AbstractRNG,]
35- model::Model,
36- strategy::AbstractInitStrategy,
37- accs::AccumulatorTuple,
38- )
39-
40- Evaluate a model using parameters obtained via `strategy`, and only computing the results in
41- the provided accumulators.
42-
43- It is assumed that the accumulators passed in have been initialised to appropriate values,
44- as this function will not reset them. The default constructors for each accumulator will do
45- this for you correctly.
46-
47- Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
48- argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
49- in the function name.
50- """
51- @inline function fast_evaluate!! (
52- # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
53- # to extra allocations (even for trivial models) and much slower runtime.
54- rng:: Random.AbstractRNG ,
55- model:: Model ,
56- strategy:: AbstractInitStrategy ,
57- accs:: AccumulatorTuple ,
58- )
59- ctx = InitContext (rng, strategy)
60- model = DynamicPPL. setleafcontext (model, ctx)
61- # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
62- # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
63- # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
64- # here.
65- # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
66- # it _should_ do, but this is wrong regardless.
67- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
68- vi = if Threads. nthreads () > 1
69- param_eltype = DynamicPPL. get_param_eltype (strategy)
70- accs = map (accs) do acc
71- DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
72- end
73- ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
74- else
75- OnlyAccsVarInfo (accs)
76- end
77- return DynamicPPL. _evaluate!! (model, vi)
78- end
79- @inline function fast_evaluate!! (
80- model:: Model , strategy:: AbstractInitStrategy , accs:: AccumulatorTuple
81- )
82- # This `@inline` is also mandatory for performance
83- return fast_evaluate!! (Random. default_rng (), model, strategy, accs)
84- end
85-
8632"""
8733 DynamicPPL.LogDensityFunction(
8834 model::Model,
@@ -154,9 +100,9 @@ metadata can often be quite wasteful. In particular, it is very common that the
154100we care about from model evaluation are those which are stored in accumulators, such as log
155101probability densities, or `ValuesAsInModel`.
156102
157- To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains
158- accumulators . It implements enough of the `AbstractVarInfo` interface to not error during
159- model evaluation.
103+ To avoid this issue, instead of evaluating a model with a full ` VarInfo`, we use just an
104+ `OnlyAccsVarInfo` . It implements enough of the `AbstractVarInfo` interface to not error
105+ during model evaluation.
160106
161107Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with
162108it, it is mandatory that parameters are provided from outside the VarInfo, namely via
@@ -199,6 +145,7 @@ struct LogDensityFunction{
199145 F<: Function ,
200146 N<: NamedTuple ,
201147 ADP<: Union{Nothing,DI.GradientPrep} ,
148+ AT<: AccumulatorTuple ,
202149}
203150 model:: M
204151 adtype:: AD
@@ -207,13 +154,15 @@ struct LogDensityFunction{
207154 _varname_ranges:: Dict{VarName,RangeAndLinked}
208155 _adprep:: ADP
209156 _dim:: Int
157+ _accs:: AccumulatorTuple
210158
211159 function LogDensityFunction (
212160 model:: Model ,
213161 getlogdensity:: Function = getlogjoint_internal,
214162 varinfo:: AbstractVarInfo = VarInfo (model);
215163 adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
216164 )
165+ accs = fast_ldf_accs (getlogdensity)
217166 # Figure out which variable corresponds to which index, and
218167 # which variables are linked.
219168 all_iden_ranges, all_ranges = get_ranges_and_linked (varinfo)
@@ -226,7 +175,7 @@ struct LogDensityFunction{
226175 # Make backend-specific tweaks to the adtype
227176 adtype = DynamicPPL. tweak_adtype (adtype, model, varinfo)
228177 DI. prepare_gradient (
229- LogDensityAt (model, getlogdensity, all_iden_ranges, all_ranges),
178+ LogDensityAt (model, getlogdensity, all_iden_ranges, all_ranges, accs ),
230179 adtype,
231180 x,
232181 )
@@ -237,8 +186,9 @@ struct LogDensityFunction{
237186 typeof (getlogdensity),
238187 typeof (all_iden_ranges),
239188 typeof (prep),
189+ typeof (accs),
240190 }(
241- model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
191+ model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs
242192 )
243193 end
244194end
@@ -268,21 +218,25 @@ struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
268218 getlogdensity:: F
269219 iden_varname_ranges:: N
270220 varname_ranges:: Dict{VarName,RangeAndLinked}
221+ accs:: AccumulatorTuple
271222end
272223function (f:: LogDensityAt )(params:: AbstractVector{<:Real} )
273224 strategy = InitFromParams (
274225 VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
275226 )
276- accs = fast_ldf_accs (f. getlogdensity)
277- _, vi = DynamicPPL. fast_evaluate!! (f. model, strategy, accs)
227+ _, vi = DynamicPPL. init!! (f. model, OnlyAccsVarInfo (f. accs), strategy)
278228 return f. getlogdensity (vi)
279229end
280230
281231function LogDensityProblems. logdensity (
282232 ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
283233)
284234 return LogDensityAt (
285- ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
235+ ldf. model,
236+ ldf. _getlogdensity,
237+ ldf. _iden_varname_ranges,
238+ ldf. _varname_ranges,
239+ ldf. _accs,
286240 )(
287241 params
288242 )
@@ -293,7 +247,11 @@ function LogDensityProblems.logdensity_and_gradient(
293247)
294248 return DI. value_and_gradient (
295249 LogDensityAt (
296- ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
250+ ldf. model,
251+ ldf. _getlogdensity,
252+ ldf. _iden_varname_ranges,
253+ ldf. _varname_ranges,
254+ ldf. _accs,
297255 ),
298256 ldf. _adprep,
299257 ldf. adtype,
0 commit comments