Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.39.4

Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed).

## 0.39.3

`DynamicPPL.TestUtils.AD.run_ad` now generates much prettier output.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.39.3"
version = "0.39.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
19 changes: 15 additions & 4 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,23 @@ function DynamicPPL.marginalize(
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
kwargs...,
)
# Determine the indices for the variables to marginalise out.
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames))
# Construct the marginal log-density model.
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
# Determine the indices for the variables to marginalise out.
varindices = mapreduce(vcat, marginalized_varnames) do vn
if DynamicPPL.getoptic(vn) === identity
ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range
else
ldf._varname_ranges[vn].range
end
end
mld = MarginalLogDensities.MarginalLogDensity(
LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs...
LogDensityFunctionWrapper(ldf, varinfo),
varinfo[:],
varindices,
(),
method;
kwargs...,
)
return mld
end
Expand Down
4 changes: 0 additions & 4 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,6 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:
end

vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo)
vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn)
function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
return vector_getranges(vi.varinfo, vns)
end

isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
function BangBang.empty!!(vi::ThreadSafeVarInfo)
Expand Down
74 changes: 0 additions & 74 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,80 +650,6 @@ Set the index range of `vn` in the metadata of `vi` to `range`.
setrange!(vi::VarInfo, vn::VarName, range) = setrange!(getmetadata(vi, vn), vn, range)
setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range

"""
getranges(vi::VarInfo, vns::Vector{<:VarName})

Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
"""
function getranges(vi::VarInfo, vns::Vector{<:VarName})
return map(Base.Fix1(getrange, vi), vns)
end

"""
vector_getrange(varinfo::VarInfo, varname::VarName)

Return the range corresponding to `varname` in the vector representation of `varinfo`.
"""
vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn)
function vector_getrange(vi::NTVarInfo, vn::VarName)
offset = 0
for md in values(vi.metadata)
# First, we need to check if `vn` is in `md`.
# In this case, we can just return the corresponding range + offset.
haskey(md, vn) && return getrange(md, vn) .+ offset
# Otherwise, we need to get the cumulative length of the ranges in `md`
# and add it to the offset.
offset += sum(length, md.ranges)
end
# If we reach this point, `vn` is not in `vi.metadata`.
throw(KeyError(vn))
end

"""
vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})

Return the range corresponding to `varname` in the vector representation of `varinfo`.
"""
function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName})
return map(Base.Fix1(vector_getrange, varinfo), varname)
end
# Specialized version for `NTVarInfo`.
function vector_getranges(varinfo::NTVarInfo, vns::Vector{<:VarName})
# TODO: Does it help if we _don't_ convert to a vector here?
metadatas = collect(values(varinfo.metadata))
# Extract the offsets.
offsets = cumsum(map(vector_length, metadatas))
# Extract the ranges from each metadata.
ranges = Vector{UnitRange{Int}}(undef, length(vns))
# Need to keep track of which ones we've seen.
not_seen = fill(true, length(vns))
for (i, metadata) in enumerate(metadatas)
vns_metadata = filter(Base.Fix1(haskey, metadata), vns)
# If none of the variables exist in the metadata, we return an empty array.
isempty(vns_metadata) && continue
# Otherwise, we extract the ranges.
offset = i == 1 ? 0 : offsets[i - 1]
for vn in vns_metadata
r_vn = getrange(metadata, vn)
# Get the index, so we return in the same order as `vns`.
# NOTE: There might be duplicates in `vns`, so we need to handle that.
indices = findall(==(vn), vns)
for idx in indices
not_seen[idx] = false
ranges[idx] = r_vn .+ offset
end
end
end
# Raise key error if any of the variables were not found.
if any(not_seen)
inds = findall(not_seen)
# Just use a `convert` to get the same type as the input; don't want to confuse by overly
# specilizing the types in the error message.
throw(KeyError(convert(typeof(vns), vns[inds])))
end
return ranges
end

"""
getdist(vi::VarInfo, vn::VarName)

Expand Down
42 changes: 0 additions & 42 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -824,48 +824,6 @@ end
@test merge(vi_double, vi_single)[vn] == 1.0
end

# NOTE: It is not yet clear if this is something we want from all varinfo types.
# Hence, we only test the `VarInfo` types here.
@testset "vector_getranges for `VarInfo`" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
vns = DynamicPPL.TestUtils.varnames(model)
nt = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(
model, nt, vns; include_threadsafe=true
)
# Only keep `VarInfo` types.
varinfos = filter(
Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos
)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
x = values_as(varinfo, Vector)

# Let's just check all the subsets of `vns`.
@testset "$(convert(Vector{Any},vns_subset))" for vns_subset in
combinations(vns)
ranges = DynamicPPL.vector_getranges(varinfo, vns_subset)
@test length(ranges) == length(vns_subset)
for (r, vn) in zip(ranges, vns_subset)
@test x[r] == DynamicPPL.tovec(varinfo[vn])
end
end

# Let's try some failure cases.
@test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[]
# Non-existent variables.
@test_throws KeyError DynamicPPL.vector_getranges(
varinfo, [VarName{gensym("vn")}()]
)
@test_throws KeyError DynamicPPL.vector_getranges(
varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()]
)
# Duplicate variables.
ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2))
@test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2)
end
end
end

@testset "issue #842" begin
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
varinfo = VarInfo(model)
Expand Down