diff --git a/HISTORY.md b/HISTORY.md index 91306c219..ff28349d8 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -32,9 +32,14 @@ You should not need to use these directly, please use `AbstractPPL.condition` an Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. +The unexported functions `supports_varname_indexing(chain)`, `getindex_varname(chain)`, and `varnames(chain)` have been removed. + The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). +The family of functions `returned(model, chain)`, along with the same signatures of `pointwise_logdensities`, `logjoint`, `loglikelihood`, and `logprior`, have been changed such that if the chain does not contain all variables in the model, an error is thrown. +Previously the behaviour would have been to sample missing variables. + ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index e74f0b8a9..8ad828648 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,41 +1,19 @@ module DynamicPPLMCMCChainsExt -using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC +using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random using MCMCChains: MCMCChains -_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names - -function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) - return _has_varname_to_symbol(chain.info) -end - -function _check_varname_indexing(c::MCMCChains.Chains) - return DynamicPPL.supports_varname_indexing(c) || - error("This `Chains` object does not support indexing using `VarName`s.") -end - -function DynamicPPL.getindex_varname( +function getindex_varname( c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx ) - _check_varname_indexing(c) return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx] end -function DynamicPPL.varnames(c::MCMCChains.Chains) - _check_varname_indexing(c) +function get_varnames(c::MCMCChains.Chains) + haskey(c.info, :varname_to_symbol) || + error("This `Chains` object does not support indexing using `VarName`s.") return keys(c.info.varname_to_symbol) end -function chain_sample_to_varname_dict( - c::MCMCChains.Chains{Tval}, sample_idx, chain_idx -) where {Tval} - _check_varname_indexing(c) - d = Dict{DynamicPPL.VarName,Tval}() - for vn in DynamicPPL.varnames(c) - d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) - end - return d -end - """ AbstractMCMC.from_samples( ::Type{MCMCChains.Chains}, @@ -118,8 +96,8 @@ function AbstractMCMC.to_samples( # Get parameters params_matrix = map(idxs) do (sample_idx, chain_idx) d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}() - for vn in DynamicPPL.varnames(chain) - d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx) + for vn in get_varnames(chain) + d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx) end d end @@ -177,6 +155,46 @@ function AbstractMCMC.bundle_samples( return sort_chain ? sort(chain) : chain end +""" + reevaluate_with_chain( + rng::AbstractRNG, + model::Model, + chain::MCMCChains.Chains + accs::NTuple{N,AbstractAccumulator}; + fallback=nothing, + ) + +Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`, +returning an matrix of `(retval, updated_at)` tuples. + +This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the +initialisation strategy when re-evaluating the model. For many usecases the fallback should +not be provided (as we expect the chain to contain all necessary variables); but for +`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating +the posterior predictions). +""" +function reevaluate_with_chain( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) + return map(params_with_stats) do ps + DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromParams(ps.params, fallback)) + end +end +function reevaluate_with_chain( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -245,30 +263,18 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - - # Set up a VarInfo with the right accumulators - varinfo = DynamicPPL.setaccs!!( - DynamicPPL.VarInfo(), - ( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.ValuesAsInModelAccumulator(false), - ), + accs = ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(false), ) - _, varinfo = DynamicPPL.init!!(model, varinfo) - varinfo = DynamicPPL.typed_varinfo(varinfo) - - params_and_stats = AbstractMCMC.to_samples( - DynamicPPL.ParamsWithStats, parameter_only_chain + predictions = map( + DynamicPPL.ParamsWithStats ∘ last, + reevaluate_with_chain( + rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior() + ), ) - predictions = map(params_and_stats) do ps - _, varinfo = DynamicPPL.init!!( - rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) - ) - DynamicPPL.ParamsWithStats(varinfo) - end chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) - parameter_names = if include_all MCMCChains.names(chain_result, :parameters) else @@ -348,18 +354,7 @@ julia> returned(model, chain) """ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) - varinfo = DynamicPPL.VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - return map(params_with_stats) do ps - first( - DynamicPPL.init!!( - model, - varinfo, - DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()), - ), - ) - end + return map(first, reevaluate_with_chain(model, chain, (), nothing)) end """ @@ -452,24 +447,13 @@ function DynamicPPL.pointwise_logdensities( ::Type{Tout}=MCMCChains.Chains, ::Val{whichlogprob}=Val(:both), ) where {whichlogprob,Tout} - vi = DynamicPPL.VarInfo(model) acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() accname = DynamicPPL.accumulator_name(acc) - vi = DynamicPPL.setaccs!!(vi, (acc,)) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - pointwise_logps = map(iters) do (sample_idx, chain_idx) - # Extract values from the chain - values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) - # Re-evaluate the model - _, vi = DynamicPPL.init!!( - model, - vi, - DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), - ) - DynamicPPL.getacc(vi, Val(accname)).logps - end - + pointwise_logps = + map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) + DynamicPPL.getacc(vi, Val(accname)).logps + end # pointwise_logps is a matrix of OrderedDicts all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() for d in pointwise_logps @@ -556,15 +540,15 @@ julia> logjoint(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.logjoint(model, argvals_dict) - end + return map( + DynamicPPL.getlogjoint ∘ last, + reevaluate_with_chain( + model, + chain, + (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()), + nothing, + ), + ) end """ @@ -596,15 +580,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.loglikelihood(model, argvals_dict) - end + return map( + DynamicPPL.getloglikelihood ∘ last, + reevaluate_with_chain( + model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing + ), + ) end """ @@ -637,15 +618,10 @@ julia> logprior(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) - ) - DynamicPPL.logprior(model, argvals_dict) - end + return map( + DynamicPPL.getlogprior ∘ last, + reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing), + ) end end diff --git a/src/chains.jl b/src/chains.jl index f176b8e68..2fcd4e713 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -1,29 +1,3 @@ -""" - supports_varname_indexing(chain::AbstractChains) - -Return `true` if `chain` supports indexing using `VarName` in place of the -variable name index. -""" -supports_varname_indexing(::AbstractChains) = false - -""" - getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx) - -Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`. - -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function getindex_varname end - -""" - varnames(chains::AbstractChains) - -Return an iterator over the varnames present in `chains`. - -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function varnames end - """ ParamsWithStats diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 65eab448e..bcdd0bb25 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -193,21 +193,21 @@ end # LogDensityProblems.jl interface # ################################### """ - fast_ldf_accs(getlogdensity::Function) + ldf_accs(getlogdensity::Function) Determine which accumulators are needed for fast evaluation with the given `getlogdensity` function. """ -fast_ldf_accs(::Function) = default_accumulators() -fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() -function fast_ldf_accs(::typeof(getlogjoint)) +ldf_accs(::Function) = default_accumulators() +ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function ldf_accs(::typeof(getlogjoint)) return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end -function fast_ldf_accs(::typeof(getlogprior_internal)) +function ldf_accs(::typeof(getlogprior_internal)) return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end -fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) -fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) +ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} model::M @@ -219,7 +219,7 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) strategy = InitFromParams( VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing ) - accs = fast_ldf_accs(f.getlogdensity) + accs = ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) return f.getlogdensity(vi) end diff --git a/src/model.jl b/src/model.jl index 9029318b1..7d5bbf2fb 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1181,12 +1181,15 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) ``` """ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) - vi = DynamicPPL.setaccs!!(VarInfo(), ()) # Note: we can't use `fix(model, parameters)` because # https://github.com/TuringLang/DynamicPPL.jl/issues/1097 - # Use `nothing` as the fallback to ensure that any missing parameters cause an error - ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing)) - new_model = setleafcontext(model, ctx) - # We can't use new_model() because that overwrites it with an InitContext of its own. - return first(evaluate!!(new_model, vi)) + return first( + init!!( + model, + DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple()), + # Use `nothing` as the fallback to ensure that any missing parameters cause an + # error + InitFromParams(parameters, nothing), + ), + ) end