Skip to content

Commit 8902e6a

Browse files
committed
Make it more elegant
1 parent 5a976fd commit 8902e6a

File tree

1 file changed

+76
-59
lines changed

1 file changed

+76
-59
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 76 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module DynamicPPLMCMCChainsExt
22

3-
using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
3+
using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
44
using MCMCChains: MCMCChains
55

66
function getindex_varname(
@@ -118,6 +118,47 @@ function AbstractMCMC.to_samples(
118118
end
119119
end
120120

121+
"""
122+
reevaluate_with(
123+
rng::AbstractRNG,
124+
model::Model,
125+
chain::MCMCChains.Chains;
126+
fallback=nothing,
127+
)
128+
129+
Re-evaluate `model` for each sample in `chain`, returning an matrix of (retval, varinfo)
130+
tuples.
131+
132+
This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
133+
initialisation strategy when re-evaluating the model. For many usecases the fallback should
134+
not be provided (as we expect the chain to contain all necessary variables); but for
135+
`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating
136+
the posterior predictions).
137+
"""
138+
function reevaluate_with_chain(
139+
rng::Random.AbstractRNG,
140+
model::DynamicPPL.Model,
141+
chain::MCMCChains.Chains,
142+
accs::NTuple{N,DynamicPPL.AbstractAccumulator},
143+
fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
144+
) where {N}
145+
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
146+
return map(params_with_stats) do ps
147+
varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs))
148+
DynamicPPL.init!!(
149+
rng, model, varinfo, DynamicPPL.InitFromParams(ps.params, fallback)
150+
)
151+
end
152+
end
153+
function reevaluate_with_chain(
154+
model::DynamicPPL.Model,
155+
chain::MCMCChains.Chains,
156+
accs::NTuple{N,DynamicPPL.AbstractAccumulator},
157+
fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
158+
) where {N}
159+
return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback)
160+
end
161+
121162
"""
122163
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
123164
@@ -186,25 +227,18 @@ function DynamicPPL.predict(
186227
include_all=false,
187228
)
188229
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
189-
190-
params_and_stats = AbstractMCMC.to_samples(
191-
DynamicPPL.ParamsWithStats, parameter_only_chain
230+
accs = DynamicPPL.AccumulatorTuple(
231+
DynamicPPL.LogPriorAccumulator(),
232+
DynamicPPL.LogLikelihoodAccumulator(),
233+
DynamicPPL.ValuesAsInModelAccumulator(false),
234+
)
235+
predictions = map(
236+
DynamicPPL.ParamsWithStats last,
237+
reevaluate_with_chain(
238+
rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior()
239+
),
192240
)
193-
predictions = map(params_and_stats) do ps
194-
varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(
195-
DynamicPPL.AccumulatorTuple((
196-
DynamicPPL.LogPriorAccumulator(),
197-
DynamicPPL.LogLikelihoodAccumulator(),
198-
DynamicPPL.ValuesAsInModelAccumulator(false),
199-
)),
200-
)
201-
_, varinfo = DynamicPPL.init!!(
202-
rng, model, varinfo, DynamicPPL.InitFromParams(ps.params)
203-
)
204-
DynamicPPL.ParamsWithStats(varinfo)
205-
end
206241
chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions)
207-
208242
parameter_names = if include_all
209243
MCMCChains.names(chain_result, :parameters)
210244
else
@@ -284,13 +318,7 @@ julia> returned(model, chain)
284318
"""
285319
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
286320
chain = MCMCChains.get_sections(chain_full, :parameters)
287-
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
288-
return map(params_with_stats) do ps
289-
varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(()))
290-
first(
291-
DynamicPPL.init!!(model, varinfo, DynamicPPL.InitFromParams(ps.params, nothing))
292-
)
293-
end
321+
return map(first, reevaluate_with_chain(model, chain, (), nothing))
294322
end
295323

296324
"""
@@ -386,14 +414,10 @@ function DynamicPPL.pointwise_logdensities(
386414
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
387415
accname = DynamicPPL.accumulator_name(acc)
388416
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
389-
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
390-
pointwise_logps = map(params_with_stats) do ps
391-
accs = DynamicPPL.AccumulatorTuple((acc,))
392-
vi = DynamicPPL.Experimental.OnlyAccsVarInfo(accs)
393-
_, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing))
394-
DynamicPPL.getacc(vi, Val(accname)).logps
395-
end
396-
417+
pointwise_logps =
418+
map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi)
419+
DynamicPPL.getacc(vi, Val(accname)).logps
420+
end
397421
# pointwise_logps is a matrix of OrderedDicts
398422
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
399423
for d in pointwise_logps
@@ -480,16 +504,15 @@ julia> logjoint(demo_model([1., 2.]), chain)
480504
```
481505
"""
482506
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains)
483-
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
484-
return map(params_with_stats) do ps
485-
vi = DynamicPPL.Experimental.OnlyAccsVarInfo(
486-
DynamicPPL.AccumulatorTuple((
487-
DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()
488-
)),
489-
)
490-
_, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing))
491-
DynamicPPL.getlogjoint(vi)
492-
end
507+
return map(
508+
DynamicPPL.getlogjoint last,
509+
reevaluate_with_chain(
510+
model,
511+
chain,
512+
(DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()),
513+
nothing,
514+
),
515+
)
493516
end
494517

495518
"""
@@ -521,14 +544,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain)
521544
```
522545
"""
523546
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
524-
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
525-
return map(params_with_stats) do ps
526-
vi = DynamicPPL.Experimental.OnlyAccsVarInfo(
527-
DynamicPPL.AccumulatorTuple((DynamicPPL.LogLikelihoodAccumulator()))
528-
)
529-
_, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing))
530-
DynamicPPL.getloglikelihood(vi)
531-
end
547+
return map(
548+
DynamicPPL.getloglikelihood last,
549+
reevaluate_with_chain(
550+
model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing
551+
),
552+
)
532553
end
533554

534555
"""
@@ -561,14 +582,10 @@ julia> logprior(demo_model([1., 2.]), chain)
561582
```
562583
"""
563584
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
564-
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
565-
return map(params_with_stats) do ps
566-
vi = DynamicPPL.Experimental.OnlyAccsVarInfo(
567-
DynamicPPL.AccumulatorTuple((DynamicPPL.LogPriorAccumulator()))
568-
)
569-
_, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing))
570-
DynamicPPL.getlogprior(vi)
571-
end
585+
return map(
586+
DynamicPPL.getlogprior last,
587+
reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing),
588+
)
572589
end
573590

574591
end

0 commit comments

Comments
 (0)