|
1 | 1 | module DynamicPPLMCMCChainsExt |
2 | 2 |
|
3 | | -using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC |
| 3 | +using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random |
4 | 4 | using MCMCChains: MCMCChains |
5 | 5 |
|
6 | 6 | function getindex_varname( |
@@ -118,6 +118,47 @@ function AbstractMCMC.to_samples( |
118 | 118 | end |
119 | 119 | end |
120 | 120 |
|
| 121 | +""" |
| 122 | + reevaluate_with( |
| 123 | + rng::AbstractRNG, |
| 124 | + model::Model, |
| 125 | + chain::MCMCChains.Chains; |
| 126 | + fallback=nothing, |
| 127 | + ) |
| 128 | +
|
| 129 | +Re-evaluate `model` for each sample in `chain`, returning an matrix of (retval, varinfo) |
| 130 | +tuples. |
| 131 | +
|
| 132 | +This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the |
| 133 | +initialisation strategy when re-evaluating the model. For many usecases the fallback should |
| 134 | +not be provided (as we expect the chain to contain all necessary variables); but for |
| 135 | +`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating |
| 136 | +the posterior predictions). |
| 137 | +""" |
| 138 | +function reevaluate_with_chain( |
| 139 | + rng::Random.AbstractRNG, |
| 140 | + model::DynamicPPL.Model, |
| 141 | + chain::MCMCChains.Chains, |
| 142 | + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, |
| 143 | + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, |
| 144 | +) where {N} |
| 145 | + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) |
| 146 | + return map(params_with_stats) do ps |
| 147 | + varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) |
| 148 | + DynamicPPL.init!!( |
| 149 | + rng, model, varinfo, DynamicPPL.InitFromParams(ps.params, fallback) |
| 150 | + ) |
| 151 | + end |
| 152 | +end |
| 153 | +function reevaluate_with_chain( |
| 154 | + model::DynamicPPL.Model, |
| 155 | + chain::MCMCChains.Chains, |
| 156 | + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, |
| 157 | + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, |
| 158 | +) where {N} |
| 159 | + return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) |
| 160 | +end |
| 161 | + |
121 | 162 | """ |
122 | 163 | predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) |
123 | 164 |
|
@@ -186,25 +227,18 @@ function DynamicPPL.predict( |
186 | 227 | include_all=false, |
187 | 228 | ) |
188 | 229 | parameter_only_chain = MCMCChains.get_sections(chain, :parameters) |
189 | | - |
190 | | - params_and_stats = AbstractMCMC.to_samples( |
191 | | - DynamicPPL.ParamsWithStats, parameter_only_chain |
| 230 | + accs = DynamicPPL.AccumulatorTuple( |
| 231 | + DynamicPPL.LogPriorAccumulator(), |
| 232 | + DynamicPPL.LogLikelihoodAccumulator(), |
| 233 | + DynamicPPL.ValuesAsInModelAccumulator(false), |
| 234 | + ) |
| 235 | + predictions = map( |
| 236 | + DynamicPPL.ParamsWithStats ∘ last, |
| 237 | + reevaluate_with_chain( |
| 238 | + rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior() |
| 239 | + ), |
192 | 240 | ) |
193 | | - predictions = map(params_and_stats) do ps |
194 | | - varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo( |
195 | | - DynamicPPL.AccumulatorTuple(( |
196 | | - DynamicPPL.LogPriorAccumulator(), |
197 | | - DynamicPPL.LogLikelihoodAccumulator(), |
198 | | - DynamicPPL.ValuesAsInModelAccumulator(false), |
199 | | - )), |
200 | | - ) |
201 | | - _, varinfo = DynamicPPL.init!!( |
202 | | - rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) |
203 | | - ) |
204 | | - DynamicPPL.ParamsWithStats(varinfo) |
205 | | - end |
206 | 241 | chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) |
207 | | - |
208 | 242 | parameter_names = if include_all |
209 | 243 | MCMCChains.names(chain_result, :parameters) |
210 | 244 | else |
@@ -284,13 +318,7 @@ julia> returned(model, chain) |
284 | 318 | """ |
285 | 319 | function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) |
286 | 320 | chain = MCMCChains.get_sections(chain_full, :parameters) |
287 | | - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) |
288 | | - return map(params_with_stats) do ps |
289 | | - varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(())) |
290 | | - first( |
291 | | - DynamicPPL.init!!(model, varinfo, DynamicPPL.InitFromParams(ps.params, nothing)) |
292 | | - ) |
293 | | - end |
| 321 | + return map(first, reevaluate_with_chain(model, chain, (), nothing)) |
294 | 322 | end |
295 | 323 |
|
296 | 324 | """ |
@@ -386,14 +414,10 @@ function DynamicPPL.pointwise_logdensities( |
386 | 414 | acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() |
387 | 415 | accname = DynamicPPL.accumulator_name(acc) |
388 | 416 | parameter_only_chain = MCMCChains.get_sections(chain, :parameters) |
389 | | - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) |
390 | | - pointwise_logps = map(params_with_stats) do ps |
391 | | - accs = DynamicPPL.AccumulatorTuple((acc,)) |
392 | | - vi = DynamicPPL.Experimental.OnlyAccsVarInfo(accs) |
393 | | - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) |
394 | | - DynamicPPL.getacc(vi, Val(accname)).logps |
395 | | - end |
396 | | - |
| 417 | + pointwise_logps = |
| 418 | + map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) |
| 419 | + DynamicPPL.getacc(vi, Val(accname)).logps |
| 420 | + end |
397 | 421 | # pointwise_logps is a matrix of OrderedDicts |
398 | 422 | all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() |
399 | 423 | for d in pointwise_logps |
@@ -480,16 +504,15 @@ julia> logjoint(demo_model([1., 2.]), chain) |
480 | 504 | ``` |
481 | 505 | """ |
482 | 506 | function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) |
483 | | - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) |
484 | | - return map(params_with_stats) do ps |
485 | | - vi = DynamicPPL.Experimental.OnlyAccsVarInfo( |
486 | | - DynamicPPL.AccumulatorTuple(( |
487 | | - DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator() |
488 | | - )), |
489 | | - ) |
490 | | - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) |
491 | | - DynamicPPL.getlogjoint(vi) |
492 | | - end |
| 507 | + return map( |
| 508 | + DynamicPPL.getlogjoint ∘ last, |
| 509 | + reevaluate_with_chain( |
| 510 | + model, |
| 511 | + chain, |
| 512 | + (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()), |
| 513 | + nothing, |
| 514 | + ), |
| 515 | + ) |
493 | 516 | end |
494 | 517 |
|
495 | 518 | """ |
@@ -521,14 +544,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain) |
521 | 544 | ``` |
522 | 545 | """ |
523 | 546 | function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) |
524 | | - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) |
525 | | - return map(params_with_stats) do ps |
526 | | - vi = DynamicPPL.Experimental.OnlyAccsVarInfo( |
527 | | - DynamicPPL.AccumulatorTuple((DynamicPPL.LogLikelihoodAccumulator())) |
528 | | - ) |
529 | | - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) |
530 | | - DynamicPPL.getloglikelihood(vi) |
531 | | - end |
| 547 | + return map( |
| 548 | + DynamicPPL.getloglikelihood ∘ last, |
| 549 | + reevaluate_with_chain( |
| 550 | + model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing |
| 551 | + ), |
| 552 | + ) |
532 | 553 | end |
533 | 554 |
|
534 | 555 | """ |
@@ -561,14 +582,10 @@ julia> logprior(demo_model([1., 2.]), chain) |
561 | 582 | ``` |
562 | 583 | """ |
563 | 584 | function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) |
564 | | - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) |
565 | | - return map(params_with_stats) do ps |
566 | | - vi = DynamicPPL.Experimental.OnlyAccsVarInfo( |
567 | | - DynamicPPL.AccumulatorTuple((DynamicPPL.LogPriorAccumulator())) |
568 | | - ) |
569 | | - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) |
570 | | - DynamicPPL.getlogprior(vi) |
571 | | - end |
| 585 | + return map( |
| 586 | + DynamicPPL.getlogprior ∘ last, |
| 587 | + reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing), |
| 588 | + ) |
572 | 589 | end |
573 | 590 |
|
574 | 591 | end |
0 commit comments