Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@ You should not need to use these directly, please use `AbstractPPL.condition` an

Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.

The unexported functions `supports_varname_indexing(chain)`, `getindex_varname(chain)`, and `varnames(chain)` have been removed.

The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).

The family of functions `returned(model, chain)`, along with the same signatures of `pointwise_logdensities`, `logjoint`, `loglikelihood`, and `logprior`, have been changed such that if the chain does not contain all variables in the model, an error is thrown.
Previously the behaviour would have been to sample missing variables.

## 0.38.9

Remove warning when using Enzyme as the AD backend.
Expand Down
184 changes: 80 additions & 104 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,19 @@
module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random
using MCMCChains: MCMCChains

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names

function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("This `Chains` object does not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
function getindex_varname(
c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
)
_check_varname_indexing(c)
return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx]
end
function DynamicPPL.varnames(c::MCMCChains.Chains)
_check_varname_indexing(c)
function get_varnames(c::MCMCChains.Chains)
haskey(c.info, :varname_to_symbol) ||
error("This `Chains` object does not support indexing using `VarName`s.")
return keys(c.info.varname_to_symbol)
end

function chain_sample_to_varname_dict(
c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
) where {Tval}
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Tval}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end

"""
AbstractMCMC.from_samples(
::Type{MCMCChains.Chains},
Expand Down Expand Up @@ -118,8 +96,8 @@ function AbstractMCMC.to_samples(
# Get parameters
params_matrix = map(idxs) do (sample_idx, chain_idx)
d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(chain)
d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
for vn in get_varnames(chain)
d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx)
end
d
end
Expand Down Expand Up @@ -177,6 +155,46 @@ function AbstractMCMC.bundle_samples(
return sort_chain ? sort(chain) : chain
end

"""
reevaluate_with_chain(
Comment on lines +158 to +159
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fundamentally, all the functions in this extension really just use this under the hood.

FWIW the FlexiChains extension has a very similar structure and I believe these can be unified pretty much immediately after this PR. To be specific, DynamicPPL.reevaluate_with_chain should be implemented by each chain type in the most performant manner (FlexiChains doesn't use InitFromParams), but once that's done, the definitions of returned, logjoint, ..., pointwise_logdensities, ... can be shared.

predict can't yet be shared unfortunately, because the include_all keyword argument forces custom MCMCChains / FlexiChains code. That would require an extension of the AbstractChains API to support a subset-like operation.

rng::AbstractRNG,
model::Model,
chain::MCMCChains.Chains
accs::NTuple{N,AbstractAccumulator};
fallback=nothing,
)

Re-evaluate `model` for each sample in `chain` using the accumulators provided in `at`,
returning an matrix of `(retval, updated_at)` tuples.

This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the
initialisation strategy when re-evaluating the model. For many usecases the fallback should
not be provided (as we expect the chain to contain all necessary variables); but for
`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating
the posterior predictions).
"""
function reevaluate_with_chain(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
accs::NTuple{N,DynamicPPL.AbstractAccumulator},
fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
) where {N}
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs))
return map(params_with_stats) do ps
DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromParams(ps.params, fallback))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fallback being nothing is a change in default behavior, right? Might worth noting in changelog if true

end
end
function reevaluate_with_chain(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
accs::NTuple{N,DynamicPPL.AbstractAccumulator},
fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing,
) where {N}
return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback)
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -245,30 +263,18 @@ function DynamicPPL.predict(
include_all=false,
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)

# Set up a VarInfo with the right accumulators
varinfo = DynamicPPL.setaccs!!(
DynamicPPL.VarInfo(),
(
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(false),
),
accs = (
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(false),
)
_, varinfo = DynamicPPL.init!!(model, varinfo)
varinfo = DynamicPPL.typed_varinfo(varinfo)

params_and_stats = AbstractMCMC.to_samples(
DynamicPPL.ParamsWithStats, parameter_only_chain
predictions = map(
DynamicPPL.ParamsWithStats ∘ last,
reevaluate_with_chain(
rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior()
),
)
predictions = map(params_and_stats) do ps
_, varinfo = DynamicPPL.init!!(
rng, model, varinfo, DynamicPPL.InitFromParams(ps.params)
)
DynamicPPL.ParamsWithStats(varinfo)
end
chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions)

parameter_names = if include_all
MCMCChains.names(chain_result, :parameters)
else
Expand Down Expand Up @@ -348,18 +354,7 @@ julia> returned(model, chain)
"""
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
return map(params_with_stats) do ps
first(
DynamicPPL.init!!(
model,
varinfo,
DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()),
),
)
end
return map(first, reevaluate_with_chain(model, chain, (), nothing))
end

"""
Expand Down Expand Up @@ -452,24 +447,13 @@ function DynamicPPL.pointwise_logdensities(
::Type{Tout}=MCMCChains.Chains,
::Val{whichlogprob}=Val(:both),
) where {whichlogprob,Tout}
vi = DynamicPPL.VarInfo(model)
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
accname = DynamicPPL.accumulator_name(acc)
vi = DynamicPPL.setaccs!!(vi, (acc,))
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
pointwise_logps = map(iters) do (sample_idx, chain_idx)
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Re-evaluate the model
_, vi = DynamicPPL.init!!(
model,
vi,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
DynamicPPL.getacc(vi, Val(accname)).logps
end

pointwise_logps =
map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi)
DynamicPPL.getacc(vi, Val(accname)).logps
end
# pointwise_logps is a matrix of OrderedDicts
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
for d in pointwise_logps
Expand Down Expand Up @@ -556,15 +540,15 @@ julia> logjoint(demo_model([1., 2.]), chain)
```
"""
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logjoint(model, argvals_dict)
end
return map(
DynamicPPL.getlogjoint ∘ last,
reevaluate_with_chain(
model,
chain,
(DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()),
nothing,
),
)
end

"""
Expand Down Expand Up @@ -596,15 +580,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain)
```
"""
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.loglikelihood(model, argvals_dict)
end
return map(
DynamicPPL.getloglikelihood ∘ last,
reevaluate_with_chain(
model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing
),
)
end

"""
Expand Down Expand Up @@ -637,15 +618,10 @@ julia> logprior(demo_model([1., 2.]), chain)
```
"""
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
vn_parent => DynamicPPL.values_from_chain(
var_info, vn_parent, chain, chain_idx, iteration_idx
) for vn_parent in keys(var_info)
)
DynamicPPL.logprior(model, argvals_dict)
end
return map(
DynamicPPL.getlogprior ∘ last,
reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing),
)
end

end
26 changes: 0 additions & 26 deletions src/chains.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,3 @@
"""
supports_varname_indexing(chain::AbstractChains)

Return `true` if `chain` supports indexing using `VarName` in place of the
variable name index.
"""
supports_varname_indexing(::AbstractChains) = false

"""
getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx)

Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`.

Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function getindex_varname end

"""
varnames(chains::AbstractChains)

Return an iterator over the varnames present in `chains`.

Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function varnames end

Comment on lines -1 to -26
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no objection on getting rid of these, not sure if we should add an entry in changelog though. the likelihood that any of these are used somewhere is low

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you're completely correct, and I just totally neglected the changelog in this PR. My bad.

"""
ParamsWithStats

Expand Down
16 changes: 8 additions & 8 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,21 +193,21 @@ end
# LogDensityProblems.jl interface #
###################################
"""
fast_ldf_accs(getlogdensity::Function)
ldf_accs(getlogdensity::Function)

Determine which accumulators are needed for fast evaluation with the given
`getlogdensity` function.
"""
fast_ldf_accs(::Function) = default_accumulators()
fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators()
function fast_ldf_accs(::typeof(getlogjoint))
ldf_accs(::Function) = default_accumulators()
ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators()
function ldf_accs(::typeof(getlogjoint))
return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
end
function fast_ldf_accs(::typeof(getlogprior_internal))
function ldf_accs(::typeof(getlogprior_internal))
return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator()))
end
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))

struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
model::M
Expand All @@ -219,7 +219,7 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real})
strategy = InitFromParams(
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
)
accs = fast_ldf_accs(f.getlogdensity)
accs = ldf_accs(f.getlogdensity)
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
return f.getlogdensity(vi)
end
Expand Down
15 changes: 9 additions & 6 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1181,12 +1181,15 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
```
"""
function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
vi = DynamicPPL.setaccs!!(VarInfo(), ())
# Note: we can't use `fix(model, parameters)` because
# https://github.com/TuringLang/DynamicPPL.jl/issues/1097
# Use `nothing` as the fallback to ensure that any missing parameters cause an error
ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing))
new_model = setleafcontext(model, ctx)
# We can't use new_model() because that overwrites it with an InitContext of its own.
return first(evaluate!!(new_model, vi))
return first(
init!!(
model,
DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple()),
# Use `nothing` as the fallback to ensure that any missing parameters cause an
# error
InitFromParams(parameters, nothing),
),
)
end
Loading