
Commit 00e6195

Use init!! instead of fast_evaluate!!
1 parent e165249 commit 00e6195

File tree: 5 files changed (+37 additions, -80 deletions)

docs/src/api.md

Lines changed: 2 additions & 3 deletions
@@ -66,11 +66,10 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
 LogDensityFunction
 ```

-Internally, this is accomplished using:
+Internally, this is accomplished using [`init!!`](@ref) on:

 ```@docs
 OnlyAccsVarInfo
-fast_evaluate!!
 ```

 ## Condition and decondition
@@ -517,7 +516,7 @@ The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
 It is really a thin wrapper around using `evaluate!!` with an `InitContext`.

 ```@docs
-DynamicPPL.init!!
+init!!
 ```

 To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
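
For reference, the `init!!` workflow that the updated docs point to looks roughly like this (a minimal sketch; the `demo` model and its variables are invented for illustration, while `init!!`, `VarInfo`, and `InitFromPrior` are the names used in this commit):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo()

# `init!!` overwrites the values in the given VarInfo according to the
# chosen initialisation strategy, and returns the model's return value
# together with the updated VarInfo.
retval, vi = init!!(model, VarInfo(model), InitFromPrior())
```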

src/DynamicPPL.jl

Lines changed: 4 additions & 3 deletions
@@ -92,10 +92,8 @@ export AbstractVarInfo,
     getargnames,
     extract_priors,
     values_as_in_model,
-    # LogDensityFunction and fasteval
+    # LogDensityFunction
     LogDensityFunction,
-    fast_evaluate!!,
-    OnlyAccsVarInfo,
     # Leaf contexts
     AbstractContext,
     contextualize,
@@ -110,6 +108,9 @@ export AbstractVarInfo,
     # Tilde pipeline
     tilde_assume!!,
     tilde_observe!!,
+    # Evaluation
+    evaluate!!,
+    init!!,
     # Initialisation
     AbstractInitStrategy,
     InitFromPrior,
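
The practical effect of this export change is sketched below (a hedged illustration with a made-up one-line model; the point is only which names need a `DynamicPPL.` prefix after this commit):

```julia
using DynamicPPL, Distributions

@model demo2() = x ~ Normal()
model = demo2()

# Newly exported: usable without the `DynamicPPL.` prefix.
_, vi = init!!(model, VarInfo(model), InitFromPrior())
_, vi = evaluate!!(model, vi)

# No longer exported: `fast_evaluate!!` is removed entirely, and
# `OnlyAccsVarInfo` must now be qualified as `DynamicPPL.OnlyAccsVarInfo`.
```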

src/chains.jl

Lines changed: 4 additions & 6 deletions
@@ -137,7 +137,7 @@ end
 """
     ParamsWithStats(
         param_vector::AbstractVector,
-        ldf::DynamicPPL.Experimental.FastLDF,
+        ldf::DynamicPPL.LogDensityFunction,
         stats::NamedTuple=NamedTuple();
         include_colon_eq::Bool=true,
         include_log_probs::Bool=true,
@@ -152,11 +152,11 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
 1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as
    otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent
    MCMC iterations).
-2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`.
+2. The re-evaluation is faster as it uses `OnlyAccsVarInfo` rather than a full VarInfo.
 """
 function ParamsWithStats(
     param_vector::AbstractVector,
-    ldf::DynamicPPL.Experimental.FastLDF,
+    ldf::DynamicPPL.LogDensityFunction,
     stats::NamedTuple=NamedTuple();
     include_colon_eq::Bool=true,
     include_log_probs::Bool=true,
@@ -174,9 +174,7 @@ function ParamsWithStats(
     else
         (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
     end
-    _, vi = DynamicPPL.Experimental.fast_evaluate!!(
-        ldf.model, strategy, AccumulatorTuple(accs)
-    )
+    _, vi = DynamicPPL.init!!(ldf.model, AccumulatorTuple(accs), strategy)
     params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
     if include_log_probs
         stats = merge(
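
A hedged usage sketch of the updated `ParamsWithStats` method (the model and parameter vector are invented for illustration; the positional and keyword arguments follow the docstring above):

```julia
using DynamicPPL, Distributions

@model function gauss()
    m ~ Normal()
    x ~ Normal(m, 1)
end

ldf = LogDensityFunction(gauss())

# Reconstruct parameters and statistics from a flattened parameter vector,
# re-evaluating with accumulators only: no full VarInfo, no deepcopy.
θ = [0.3, -1.2]  # hypothetical values for (m, x)
ps = DynamicPPL.ParamsWithStats(θ, ldf; include_log_probs=true)
```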

src/contexts/init.jl

Lines changed: 6 additions & 5 deletions
@@ -312,11 +312,12 @@ function tilde_assume!!(
     # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`.
     #
     # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which
-    # case this branch is never hit (since `in_varinfo` will always be false). It does
-    # mean that the combination of InitFromParams{<:VectorWithRanges} with a full,
-    # linked, VarInfo will be very slow. That should never really be used, though. So
-    # (at least for now) we can leave this branch in for full generality with other
-    # combinations of init strategies / VarInfo.
+    # case this method is never hit (since there's a special method for it, in
+    # `src/onlyaccs.jl`). It does mean that the combination of
+    # InitFromParams{<:VectorWithRanges} with a full, linked, VarInfo will be very slow.
+    # That should never really be used, though. So (at least for now) we can leave this
+    # branch in for full generality with other combinations of init strategies /
+    # VarInfo.
     #
     # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue
     # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`,

src/fasteval.jl

Lines changed: 21 additions & 63 deletions
@@ -29,60 +29,6 @@ using LogDensityProblems: LogDensityProblems
 import DifferentiationInterface as DI
 using Random: Random

-"""
-    DynamicPPL.fast_evaluate!!(
-        [rng::Random.AbstractRNG,]
-        model::Model,
-        strategy::AbstractInitStrategy,
-        accs::AccumulatorTuple,
-    )
-
-Evaluate a model using parameters obtained via `strategy`, and only computing the results in
-the provided accumulators.
-
-It is assumed that the accumulators passed in have been initialised to appropriate values,
-as this function will not reset them. The default constructors for each accumulator will do
-this for you correctly.
-
-Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
-argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
-in the function name.
-"""
-@inline function fast_evaluate!!(
-    # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
-    # to extra allocations (even for trivial models) and much slower runtime.
-    rng::Random.AbstractRNG,
-    model::Model,
-    strategy::AbstractInitStrategy,
-    accs::AccumulatorTuple,
-)
-    ctx = InitContext(rng, strategy)
-    model = DynamicPPL.setleafcontext(model, ctx)
-    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
-    # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
-    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
-    # here.
-    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
-    # it _should_ do, but this is wrong regardless.
-    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
-    vi = if Threads.nthreads() > 1
-        param_eltype = DynamicPPL.get_param_eltype(strategy)
-        accs = map(accs) do acc
-            DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
-        end
-        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
-    else
-        OnlyAccsVarInfo(accs)
-    end
-    return DynamicPPL._evaluate!!(model, vi)
-end
-@inline function fast_evaluate!!(
-    model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
-)
-    # This `@inline` is also mandatory for performance
-    return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
-end
-
 """
     DynamicPPL.LogDensityFunction(
         model::Model,
@@ -154,9 +100,9 @@ metadata can often be quite wasteful. In particular, it is very common that the
 we care about from model evaluation are those which are stored in accumulators, such as log
 probability densities, or `ValuesAsInModel`.

-To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains
-accumulators. It implements enough of the `AbstractVarInfo` interface to not error during
-model evaluation.
+To avoid this issue, instead of evaluating a model with a full `VarInfo`, we use just an
+`OnlyAccsVarInfo`. It implements enough of the `AbstractVarInfo` interface to not error
+during model evaluation.

 Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with
 it, it is mandatory that parameters are provided from outside the VarInfo, namely via
@@ -199,6 +145,7 @@ struct LogDensityFunction{
     F<:Function,
     N<:NamedTuple,
     ADP<:Union{Nothing,DI.GradientPrep},
+    AT<:AccumulatorTuple,
 }
     model::M
     adtype::AD
@@ -207,13 +154,15 @@ struct LogDensityFunction{
     _varname_ranges::Dict{VarName,RangeAndLinked}
     _adprep::ADP
     _dim::Int
+    _accs::AccumulatorTuple

     function LogDensityFunction(
         model::Model,
         getlogdensity::Function=getlogjoint_internal,
         varinfo::AbstractVarInfo=VarInfo(model);
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
     )
+        accs = fast_ldf_accs(getlogdensity)
         # Figure out which variable corresponds to which index, and
         # which variables are linked.
         all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
@@ -226,7 +175,7 @@ struct LogDensityFunction{
         # Make backend-specific tweaks to the adtype
         adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
         DI.prepare_gradient(
-            LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
+            LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges, accs),
             adtype,
             x,
         )
@@ -237,8 +186,9 @@ struct LogDensityFunction{
             typeof(getlogdensity),
             typeof(all_iden_ranges),
             typeof(prep),
+            typeof(accs),
         }(
-            model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
+            model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs
         )
     end
 end
@@ -268,21 +218,25 @@ struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
     getlogdensity::F
     iden_varname_ranges::N
     varname_ranges::Dict{VarName,RangeAndLinked}
+    accs::AccumulatorTuple
 end
 function (f::LogDensityAt)(params::AbstractVector{<:Real})
     strategy = InitFromParams(
         VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
     )
-    accs = fast_ldf_accs(f.getlogdensity)
-    _, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs)
+    _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(f.accs), strategy)
     return f.getlogdensity(vi)
 end

 function LogDensityProblems.logdensity(
     ldf::LogDensityFunction, params::AbstractVector{<:Real}
 )
     return LogDensityAt(
-        ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
+        ldf.model,
+        ldf._getlogdensity,
+        ldf._iden_varname_ranges,
+        ldf._varname_ranges,
+        ldf._accs,
     )(
         params
     )
@@ -293,7 +247,11 @@ function LogDensityProblems.logdensity_and_gradient(
 )
     return DI.value_and_gradient(
         LogDensityAt(
-            ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
+            ldf.model,
+            ldf._getlogdensity,
+            ldf._iden_varname_ranges,
+            ldf._varname_ranges,
+            ldf._accs,
         ),
         ldf._adprep,
         ldf.adtype,
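
Putting the pieces together, here is a hedged end-to-end sketch of the reworked evaluation path through the LogDensityProblems.jl interface (the model, parameter values, and AD backend choice are illustrative; the constructor and the `logdensity`/`logdensity_and_gradient` calls follow the signatures in this diff):

```julia
using DynamicPPL, Distributions
using LogDensityProblems
using ADTypes: AutoForwardDiff
using ForwardDiff  # load the AD backend that AutoForwardDiff() refers to

@model function gauss()
    m ~ Normal()
    x ~ Normal(m, 1)
end

# Accumulators are now built once at construction time (via
# `fast_ldf_accs(getlogdensity)`) and stored in the `_accs` field.
ldf = LogDensityFunction(gauss(); adtype=AutoForwardDiff())

θ = [0.3, -1.2]  # hypothetical flattened parameters for (m, x)

# Each call evaluates the model with `init!!` on an `OnlyAccsVarInfo`
# wrapping the stored accumulators, instead of the old `fast_evaluate!!`.
logp = LogDensityProblems.logdensity(ldf, θ)
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, θ)
```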
