Skip to content

Commit 4eb33e9

Browse files
committed
Fix issues with RangeAndLinked and VNT
1 parent ce9da19 commit 4eb33e9

File tree

5 files changed

+64
-15
lines changed

5 files changed

+64
-15
lines changed

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module DynamicPPLMarginalLogDensitiesExt
22

3-
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
3+
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked
44
using MarginalLogDensities: MarginalLogDensities
55

66
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
@@ -105,11 +105,9 @@ function DynamicPPL.marginalize(
105105
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
106106
# Determine the indices for the variables to marginalise out.
107107
varindices = mapreduce(vcat, marginalized_varnames) do vn
108-
if DynamicPPL.getoptic(vn) === identity
109-
ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range
110-
else
111-
ldf._varname_ranges[vn].range
112-
end
108+
# The type assertion helps in cases where the model is type unstable and thus
109+
# `_varname_ranges` may have an abstract element type.
110+
(ldf._varname_ranges[vn]::RangeAndLinked).range
113111
end
114112
mld = MarginalLogDensities.MarginalLogDensity(
115113
LogDensityFunctionWrapper(ldf, varinfo),

src/contexts/init.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,16 @@ an unlinked value.
206206
207207
$(TYPEDFIELDS)
208208
"""
209-
struct RangeAndLinked
209+
struct RangeAndLinked{T<:Tuple}
210210
# indices that the variable corresponds to in the vectorised parameter
211211
range::UnitRange{Int}
212212
# whether it's linked
213213
is_linked::Bool
214+
# original size of the variable before vectorisation
215+
original_size::T
214216
end
215217

216-
Base.size(ral::RangeAndLinked) = size(ral.range)
218+
Base.size(ral::RangeAndLinked) = ral.original_size
217219

218220
"""
219221
VectorWithRanges{Tlink}(
@@ -249,7 +251,12 @@ struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}}
249251
end
250252

251253
function _get_range_and_linked(vr::VectorWithRanges, vn::VarName)
252-
return vr.varname_ranges[vn]
254+
# The type assertion does nothing if VectorWithRanges has concrete element types, as is
255+
# the case for all type stable models. However, if the model is not type stable,
256+
# vr.varname_ranges[vn] may infer to have type `Any`. In this case it is helpful to
257+
# assert that it is a RangeAndLinked, because even though it remains non-concrete,
258+
# it'll allow the compiler to infer the types of `range` and `is_linked`.
259+
return vr.varname_ranges[vn]::RangeAndLinked
253260
end
254261
function init(
255262
::Random.AbstractRNG,

src/logdensityfunction.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,10 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
330330
for (vn, idx) in md.idcs
331331
is_linked = md.is_transformed[idx]
332332
range = md.ranges[idx] .+ (start_offset - 1)
333-
all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
333+
orig_size = varnamesize(vn)
334+
all_ranges = BangBang.setindex!!(
335+
all_ranges, RangeAndLinked(range, is_linked, orig_size), vn
336+
)
334337
offset += length(range)
335338
end
336339
return all_ranges, offset
@@ -341,7 +344,10 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
341344
for (vn, idx) in vnv.varname_to_index
342345
is_linked = vnv.is_unconstrained[idx]
343346
range = vnv.ranges[idx] .+ (start_offset - 1)
344-
all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn)
347+
orig_size = varnamesize(vn)
348+
all_ranges = BangBang.setindex!!(
349+
all_ranges, RangeAndLinked(range, is_linked, orig_size), vn
350+
)
345351
offset += length(range)
346352
end
347353
return all_ranges, offset

src/varname.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,28 @@ Possibly existing indices of `varname` are neglected.
4141
) where {s,missings,_F,_a,_T}
4242
return s in missings
4343
end
44+
45+
# TODO(mhauru) This should probably be Base.size(::VarName) in AbstractPPL.
46+
"""
47+
varnamesize(vn::VarName)
48+
49+
Return the size of the object referenced by this VarName.
50+
51+
```jldoctest
52+
julia> varnamesize(@varname(a))
53+
()
54+
55+
julia> varnamesize(@varname(b[1:3, 2]))
56+
(3,)
57+
58+
julia> varnamesize(@varname(c.d[4].e[3, 2:5, 2, 1:4, 1]))
59+
(4, 4)
60+
```
"""
61+
function varnamesize(vn::VarName)
62+
l = AbstractPPL._last(vn.optic)
63+
if l isa Accessors.IndexLens
64+
return reduce((x, y) -> tuple(x..., y...), map(size, l.indices))
65+
else
66+
return ()
67+
end
68+
end

src/varnamedtuple.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,13 @@ function _concretise_eltype!!(pa::PartialArray)
352352
if isconcretetype(eltype(pa))
353353
return pa
354354
end
355-
new_et = promote_type((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...)
355+
# We could use promote_type here, instead of typejoin. However, that would e.g.
356+
# cause Ints to be converted to Float64s, since
357+
# promote_type(Int, Float64) == Float64, which can cause problems. See
358+
# https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188.
359+
# Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing
360+
# and Missing, rather than falling back on Any. However, it's not exported.
361+
new_et = typejoin((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...)
356362
# TODO(mhauru) Should we check as below, or rather isconcretetype(new_et)?
357363
# In other words, does it help to be more concrete, even if we aren't fully concrete?
358364
if new_et === eltype(pa)
@@ -588,8 +594,8 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
588594
if size(value) != inds_size
589595
throw(
590596
DimensionMismatch(
591-
"Assigned value has size $(size(value)), which does not match the size " *
592-
"implied by the indices $(map(x -> _length_needed(x), inds)).",
597+
"Assigned value has size $(size(value)), which does not match the " *
598+
"size implied by the indices $(map(x -> _length_needed(x), inds)).",
593599
),
594600
)
595601
end
@@ -659,7 +665,14 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray)
659665
result
660666
else
661667
# Neither is strictly bigger than the other.
662-
et = promote_type(eltype(pa1), eltype(pa2))
668+
# We could use promote_type here, instead of typejoin. However, that would e.g.
669+
# cause Ints to be converted to Float64s, since
670+
# promote_type(Int, Float64) == Float64, which can cause problems. See
671+
# https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188.
672+
# Base.promote_typejoin would be like typejoin, but creates Unions out of
673+
# Nothing and Missing, rather than falling back on Any. However, it's not
674+
# exported.
675+
et = typejoin(eltype(pa1), eltype(pa2))
663676
new_data = Array{et,num_dims}(undef, merge_size)
664677
new_mask = fill(false, merge_size)
665678
result = PartialArray(new_data, new_mask)

0 commit comments

Comments (0)