-
Notifications
You must be signed in to change notification settings - Fork 37
Implement predict, returned, logjoint, ... with OnlyAccsVarInfo
#1130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,41 +1,19 @@ | ||
| module DynamicPPLMCMCChainsExt | ||
|
|
||
| using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC | ||
| using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random | ||
| using MCMCChains: MCMCChains | ||
|
|
||
| _has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names | ||
|
|
||
| function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) | ||
| return _has_varname_to_symbol(chain.info) | ||
| end | ||
|
|
||
| function _check_varname_indexing(c::MCMCChains.Chains) | ||
| return DynamicPPL.supports_varname_indexing(c) || | ||
| error("This `Chains` object does not support indexing using `VarName`s.") | ||
| end | ||
|
|
||
| function DynamicPPL.getindex_varname( | ||
| function getindex_varname( | ||
| c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx | ||
| ) | ||
| _check_varname_indexing(c) | ||
| return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx] | ||
| end | ||
| function DynamicPPL.varnames(c::MCMCChains.Chains) | ||
| _check_varname_indexing(c) | ||
| function get_varnames(c::MCMCChains.Chains) | ||
| haskey(c.info, :varname_to_symbol) || | ||
| error("This `Chains` object does not support indexing using `VarName`s.") | ||
| return keys(c.info.varname_to_symbol) | ||
| end | ||
|
|
||
| function chain_sample_to_varname_dict( | ||
| c::MCMCChains.Chains{Tval}, sample_idx, chain_idx | ||
| ) where {Tval} | ||
| _check_varname_indexing(c) | ||
| d = Dict{DynamicPPL.VarName,Tval}() | ||
| for vn in DynamicPPL.varnames(c) | ||
| d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) | ||
| end | ||
| return d | ||
| end | ||
|
|
||
| """ | ||
| AbstractMCMC.from_samples( | ||
| ::Type{MCMCChains.Chains}, | ||
|
|
@@ -118,8 +96,8 @@ function AbstractMCMC.to_samples( | |
| # Get parameters | ||
| params_matrix = map(idxs) do (sample_idx, chain_idx) | ||
| d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}() | ||
| for vn in DynamicPPL.varnames(chain) | ||
| d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx) | ||
| for vn in get_varnames(chain) | ||
| d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx) | ||
| end | ||
| d | ||
| end | ||
|
|
@@ -177,6 +155,46 @@ function AbstractMCMC.bundle_samples( | |
| return sort_chain ? sort(chain) : chain | ||
| end | ||
|
|
||
| """ | ||
| reevaluate_with_chain( | ||
| rng::AbstractRNG, | ||
| model::Model, | ||
| chain::MCMCChains.Chains | ||
| accs::NTuple{N,AbstractAccumulator}; | ||
| fallback=nothing, | ||
| ) | ||
|
|
||
| Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`, | ||
| returning an matrix of `(retval, updated_at)` tuples. | ||
|
|
||
| This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the | ||
| initialisation strategy when re-evaluating the model. For many usecases the fallback should | ||
| not be provided (as we expect the chain to contain all necessary variables); but for | ||
| `predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating | ||
| the posterior predictions). | ||
| """ | ||
| function reevaluate_with_chain( | ||
| rng::Random.AbstractRNG, | ||
| model::DynamicPPL.Model, | ||
| chain::MCMCChains.Chains, | ||
| accs::NTuple{N,DynamicPPL.AbstractAccumulator}, | ||
| fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, | ||
| ) where {N} | ||
| params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) | ||
| vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) | ||
| return map(params_with_stats) do ps | ||
| DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromParams(ps.params, fallback)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| end | ||
| end | ||
| function reevaluate_with_chain( | ||
| model::DynamicPPL.Model, | ||
| chain::MCMCChains.Chains, | ||
| accs::NTuple{N,DynamicPPL.AbstractAccumulator}, | ||
| fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, | ||
| ) where {N} | ||
| return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) | ||
| end | ||
|
|
||
| """ | ||
| predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) | ||
|
|
||
|
|
@@ -245,30 +263,18 @@ function DynamicPPL.predict( | |
| include_all=false, | ||
| ) | ||
| parameter_only_chain = MCMCChains.get_sections(chain, :parameters) | ||
|
|
||
| # Set up a VarInfo with the right accumulators | ||
| varinfo = DynamicPPL.setaccs!!( | ||
| DynamicPPL.VarInfo(), | ||
| ( | ||
| DynamicPPL.LogPriorAccumulator(), | ||
| DynamicPPL.LogLikelihoodAccumulator(), | ||
| DynamicPPL.ValuesAsInModelAccumulator(false), | ||
| ), | ||
| accs = ( | ||
| DynamicPPL.LogPriorAccumulator(), | ||
| DynamicPPL.LogLikelihoodAccumulator(), | ||
| DynamicPPL.ValuesAsInModelAccumulator(false), | ||
| ) | ||
| _, varinfo = DynamicPPL.init!!(model, varinfo) | ||
| varinfo = DynamicPPL.typed_varinfo(varinfo) | ||
|
|
||
| params_and_stats = AbstractMCMC.to_samples( | ||
| DynamicPPL.ParamsWithStats, parameter_only_chain | ||
| predictions = map( | ||
| DynamicPPL.ParamsWithStats ∘ last, | ||
| reevaluate_with_chain( | ||
| rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior() | ||
| ), | ||
| ) | ||
| predictions = map(params_and_stats) do ps | ||
| _, varinfo = DynamicPPL.init!!( | ||
| rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) | ||
| ) | ||
| DynamicPPL.ParamsWithStats(varinfo) | ||
| end | ||
| chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) | ||
|
|
||
| parameter_names = if include_all | ||
| MCMCChains.names(chain_result, :parameters) | ||
| else | ||
|
|
@@ -348,18 +354,7 @@ julia> returned(model, chain) | |
| """ | ||
| function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) | ||
| chain = MCMCChains.get_sections(chain_full, :parameters) | ||
| varinfo = DynamicPPL.VarInfo(model) | ||
| iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) | ||
| params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) | ||
| return map(params_with_stats) do ps | ||
| first( | ||
| DynamicPPL.init!!( | ||
| model, | ||
| varinfo, | ||
| DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()), | ||
| ), | ||
| ) | ||
| end | ||
| return map(first, reevaluate_with_chain(model, chain, (), nothing)) | ||
| end | ||
|
|
||
| """ | ||
|
|
@@ -452,24 +447,13 @@ function DynamicPPL.pointwise_logdensities( | |
| ::Type{Tout}=MCMCChains.Chains, | ||
| ::Val{whichlogprob}=Val(:both), | ||
| ) where {whichlogprob,Tout} | ||
| vi = DynamicPPL.VarInfo(model) | ||
| acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() | ||
| accname = DynamicPPL.accumulator_name(acc) | ||
| vi = DynamicPPL.setaccs!!(vi, (acc,)) | ||
| parameter_only_chain = MCMCChains.get_sections(chain, :parameters) | ||
| iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) | ||
| pointwise_logps = map(iters) do (sample_idx, chain_idx) | ||
| # Extract values from the chain | ||
| values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) | ||
| # Re-evaluate the model | ||
| _, vi = DynamicPPL.init!!( | ||
| model, | ||
| vi, | ||
| DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), | ||
| ) | ||
| DynamicPPL.getacc(vi, Val(accname)).logps | ||
| end | ||
|
|
||
| pointwise_logps = | ||
| map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) | ||
| DynamicPPL.getacc(vi, Val(accname)).logps | ||
| end | ||
| # pointwise_logps is a matrix of OrderedDicts | ||
| all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() | ||
| for d in pointwise_logps | ||
|
|
@@ -556,15 +540,15 @@ julia> logjoint(demo_model([1., 2.]), chain) | |
| ``` | ||
| """ | ||
| function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) | ||
| var_info = DynamicPPL.VarInfo(model) # extract variables info from the model | ||
| map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) | ||
| argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( | ||
| vn_parent => DynamicPPL.values_from_chain( | ||
| var_info, vn_parent, chain, chain_idx, iteration_idx | ||
| ) for vn_parent in keys(var_info) | ||
| ) | ||
| DynamicPPL.logjoint(model, argvals_dict) | ||
| end | ||
| return map( | ||
| DynamicPPL.getlogjoint ∘ last, | ||
| reevaluate_with_chain( | ||
| model, | ||
| chain, | ||
| (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()), | ||
| nothing, | ||
| ), | ||
| ) | ||
| end | ||
|
|
||
| """ | ||
|
|
@@ -596,15 +580,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain) | |
| ``` | ||
| """ | ||
| function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) | ||
| var_info = DynamicPPL.VarInfo(model) # extract variables info from the model | ||
| map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) | ||
| argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( | ||
| vn_parent => DynamicPPL.values_from_chain( | ||
| var_info, vn_parent, chain, chain_idx, iteration_idx | ||
| ) for vn_parent in keys(var_info) | ||
| ) | ||
| DynamicPPL.loglikelihood(model, argvals_dict) | ||
| end | ||
| return map( | ||
| DynamicPPL.getloglikelihood ∘ last, | ||
| reevaluate_with_chain( | ||
| model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing | ||
| ), | ||
| ) | ||
| end | ||
|
|
||
| """ | ||
|
|
@@ -637,15 +618,10 @@ julia> logprior(demo_model([1., 2.]), chain) | |
| ``` | ||
| """ | ||
| function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) | ||
| var_info = DynamicPPL.VarInfo(model) # extract variables info from the model | ||
| map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) | ||
| argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( | ||
| vn_parent => DynamicPPL.values_from_chain( | ||
| var_info, vn_parent, chain, chain_idx, iteration_idx | ||
| ) for vn_parent in keys(var_info) | ||
| ) | ||
| DynamicPPL.logprior(model, argvals_dict) | ||
| end | ||
| return map( | ||
| DynamicPPL.getlogprior ∘ last, | ||
| reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing), | ||
| ) | ||
| end | ||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,29 +1,3 @@ | ||
| """ | ||
| supports_varname_indexing(chain::AbstractChains) | ||
|
|
||
| Return `true` if `chain` supports indexing using `VarName` in place of the | ||
| variable name index. | ||
| """ | ||
| supports_varname_indexing(::AbstractChains) = false | ||
|
|
||
| """ | ||
| getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx) | ||
|
|
||
| Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`. | ||
|
|
||
| Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). | ||
| """ | ||
| function getindex_varname end | ||
|
|
||
| """ | ||
| varnames(chains::AbstractChains) | ||
|
|
||
| Return an iterator over the varnames present in `chains`. | ||
|
|
||
| Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). | ||
| """ | ||
| function varnames end | ||
|
|
||
|
Comment on lines
-1
to
-26
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no objection on getting rid of these, not sure if we should add an entry in changelog though. the likelihood that any of these are used somewhere is low
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, you're completely correct, and I just totally neglected the changelog in this PR. My bad. |
||
| """ | ||
| ParamsWithStats | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fundamentally, all the functions in this extension really just use this under the hood.
FWIW the FlexiChains extension has a very similar structure and I believe these can be unified pretty much immediately after this PR. To be specific,
DynamicPPL.reevaluate_with_chainshould be implemented by each chain type in the most performant manner (FlexiChains doesn't useInitFromParams), but once that's done, the definitions ofreturned,logjoint, ...,pointwise_logdensities, ... can be shared.predictcan't yet be shared unfortunately, because theinclude_allkeyword argument forces custom MCMCChains / FlexiChains code. That would require an extension of the AbstractChains API to support asubset-like operation.