Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DynamicPPLMarginalLogDensitiesExt

using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked
using MarginalLogDensities: MarginalLogDensities

# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
Expand Down Expand Up @@ -105,11 +105,9 @@ function DynamicPPL.marginalize(
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
# Determine the indices for the variables to marginalise out.
varindices = mapreduce(vcat, marginalized_varnames) do vn
if DynamicPPL.getoptic(vn) === identity
ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range
else
ldf._varname_ranges[vn].range
end
# The type assertion helps in cases where the model is type unstable and thus
# `varname_ranges` may have an abstract element type.
(ldf._varname_ranges[vn]::RangeAndLinked).range
end
mld = MarginalLogDensities.MarginalLogDensity(
LogDensityFunctionWrapper(ldf, varinfo),
Expand Down
13 changes: 11 additions & 2 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,17 @@ an unlinked value.

$(TYPEDFIELDS)
"""
struct RangeAndLinked
struct RangeAndLinked{T<:Tuple}
# indices that the variable corresponds to in the vectorised parameter
range::UnitRange{Int}
# whether it's linked
is_linked::Bool
# original size of the variable before vectorisation
original_size::T
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit unfortunate to have to create this field, because now you may end up with abstractly typed RangeAndLinked, which obscures the fact that the two fields you really care about, namely range and is_linked, are still concrete. The reason this is needed is that VNT needs to know how much "space" in a PartialArray an instance of RangeAndLinked takes. An alternative to this could be something like giving setindex!!(::VarNamedTuple, ...) an extra kwarg of something like ignore_size_checks.

end

Base.size(ral::RangeAndLinked) = ral.original_size

"""
VectorWithRanges{Tlink}(
varname_ranges::VarNamedTuple,
Expand Down Expand Up @@ -247,7 +251,12 @@ struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}}
end

function _get_range_and_linked(vr::VectorWithRanges, vn::VarName)
return vr.varname_ranges[vn]
# The type assertion does nothing if VectorWithRanges has concrete element types, as is
# the case for all type stable models. However, if the model is not type stable,
# vr.varname_ranges[vn] may infer to have type `Any`. In this case it is helpful to
# assert that it is a RangeAndLinked, because even though it remains non-concrete,
# it'll allow the compiler to infer the types of `range` and `is_linked`.
return vr.varname_ranges[vn]::RangeAndLinked
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this assertion, this model:

@model function demo_one_variable_multiple_constraints(
    ::Type{TV}=Vector{Float64}
) where {TV}
    x = TV(undef, 5)
    x[1] ~ Normal()
    x[2] ~ InverseGamma(2, 3)
    x[3] ~ truncated(Normal(), -5, 20)
    x[4:5] ~ Dirichlet([1.0, 2.0])
    return (x=x,)
end

fails the test that checks that the return type of LogDensityProblems.logdensity(ldf, x) can be inferred. It infers as Any. This happens because the VNT gets a PartialArray where three elements are RangeAndLinked{Tuple{}} and one is a ArrayLikeBlock{RangeAndLinked{Tuple{Int}}, Tuple{Int}}. That makes the element type Any, and _get_range_and_linked's return type infers as Any, and all is lost.

Now, I'm fine saying that a model that mixes x[i] and x[j:l] for the same x is pretty type unstable (note that the priors have to mix univariate and multivariate). However, I don't think that should be licence for even the logdensity return type to infer as Any, and it didn't before this PR. Hence this extra assertion, to have the type instability not be so egregious.

Note that this problem is independent of the above comment of introducing a type parameter to RangeAndLinked: Even without the type parameter, the wrapping in ArrayLikeBlock would force the PartialArray to have element type Any.

end
function init(
::Random.AbstractRNG,
Expand Down
10 changes: 8 additions & 2 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
for (vn, idx) in md.idcs
is_linked = md.is_transformed[idx]
range = md.ranges[idx] .+ (start_offset - 1)
all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
orig_size = varnamesize(vn)
all_ranges = BangBang.setindex!!(
all_ranges, RangeAndLinked(range, is_linked, orig_size), vn
)
offset += length(range)
end
return all_ranges, offset
Expand All @@ -341,7 +344,10 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
for (vn, idx) in vnv.varname_to_index
is_linked = vnv.is_unconstrained[idx]
range = vnv.ranges[idx] .+ (start_offset - 1)
all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
orig_size = varnamesize(vn)
all_ranges = BangBang.setindex!!(
all_ranges, RangeAndLinked(range, is_linked, orig_size), vn
)
offset += length(range)
end
return all_ranges, offset
Expand Down
25 changes: 25 additions & 0 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,28 @@ Possibly existing indices of `varname` are neglected.
) where {s,missings,_F,_a,_T}
return s in missings
end

# TODO(mhauru) This should probably be Base.size(::VarName) in AbstractPPL.
"""
varnamesize(vn::VarName)

Return the size of the object referenced by this VarName.

```jldoctest
julia> varnamesize(@varname(a))
()

julia> varnamesize(@varname(b[1:3, 2]))
(3,)

julia> varnamesize(@varname(c.d[4].e[3, 2:5, 2, 1:4, 1]))
(4, 4)
"""
function varnamesize(vn::VarName)
l = AbstractPPL._last(vn.optic)
if l isa Accessors.IndexLens
return reduce((x, y) -> tuple(x..., y...), map(size, l.indices))
else
return ()
end
end
Loading