Skip to content

Commit 6266f64

Browse files
authored
Remove vector_getranges and friends (#1175)
1 parent 5962903 commit 6266f64

File tree

6 files changed

+20
-125
lines changed

6 files changed

+20
-125
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.4
4+
5+
Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed).
6+
37
## 0.39.3
48

59
`DynamicPPL.TestUtils.AD.run_ad` now generates much prettier output.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.3"
3+
version = "0.39.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,23 @@ function DynamicPPL.marginalize(
101101
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
102102
kwargs...,
103103
)
104-
# Determine the indices for the variables to marginalise out.
105-
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames))
106104
# Construct the marginal log-density model.
107-
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
105+
ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
106+
# Determine the indices for the variables to marginalise out.
107+
varindices = mapreduce(vcat, marginalized_varnames) do vn
108+
if DynamicPPL.getoptic(vn) === identity
109+
ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range
110+
else
111+
ldf._varname_ranges[vn].range
112+
end
113+
end
108114
mld = MarginalLogDensities.MarginalLogDensity(
109-
LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs...
115+
LogDensityFunctionWrapper(ldf, varinfo),
116+
varinfo[:],
117+
varindices,
118+
(),
119+
method;
120+
kwargs...,
110121
)
111122
return mld
112123
end

src/threadsafe.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,6 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:
155155
end
156156

157157
vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo)
158-
vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn)
159-
function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
160-
return vector_getranges(vi.varinfo, vns)
161-
end
162158

163159
isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
164160
function BangBang.empty!!(vi::ThreadSafeVarInfo)

src/varinfo.jl

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -650,80 +650,6 @@ Set the index range of `vn` in the metadata of `vi` to `range`.
650650
setrange!(vi::VarInfo, vn::VarName, range) = setrange!(getmetadata(vi, vn), vn, range)
651651
setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range
652652

653-
"""
654-
getranges(vi::VarInfo, vns::Vector{<:VarName})
655-
656-
Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
657-
"""
658-
function getranges(vi::VarInfo, vns::Vector{<:VarName})
659-
return map(Base.Fix1(getrange, vi), vns)
660-
end
661-
662-
"""
663-
vector_getrange(varinfo::VarInfo, varname::VarName)
664-
665-
Return the range corresponding to `varname` in the vector representation of `varinfo`.
666-
"""
667-
vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn)
668-
function vector_getrange(vi::NTVarInfo, vn::VarName)
669-
offset = 0
670-
for md in values(vi.metadata)
671-
# First, we need to check if `vn` is in `md`.
672-
# In this case, we can just return the corresponding range + offset.
673-
haskey(md, vn) && return getrange(md, vn) .+ offset
674-
# Otherwise, we need to get the cumulative length of the ranges in `md`
675-
# and add it to the offset.
676-
offset += sum(length, md.ranges)
677-
end
678-
# If we reach this point, `vn` is not in `vi.metadata`.
679-
throw(KeyError(vn))
680-
end
681-
682-
"""
683-
vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})
684-
685-
Return the range corresponding to `varname` in the vector representation of `varinfo`.
686-
"""
687-
function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName})
688-
return map(Base.Fix1(vector_getrange, varinfo), varname)
689-
end
690-
# Specialized version for `NTVarInfo`.
691-
function vector_getranges(varinfo::NTVarInfo, vns::Vector{<:VarName})
692-
# TODO: Does it help if we _don't_ convert to a vector here?
693-
metadatas = collect(values(varinfo.metadata))
694-
# Extract the offsets.
695-
offsets = cumsum(map(vector_length, metadatas))
696-
# Extract the ranges from each metadata.
697-
ranges = Vector{UnitRange{Int}}(undef, length(vns))
698-
# Need to keep track of which ones we've seen.
699-
not_seen = fill(true, length(vns))
700-
for (i, metadata) in enumerate(metadatas)
701-
vns_metadata = filter(Base.Fix1(haskey, metadata), vns)
702-
# If none of the variables exist in the metadata, we return an empty array.
703-
isempty(vns_metadata) && continue
704-
# Otherwise, we extract the ranges.
705-
offset = i == 1 ? 0 : offsets[i - 1]
706-
for vn in vns_metadata
707-
r_vn = getrange(metadata, vn)
708-
# Get the index, so we return in the same order as `vns`.
709-
# NOTE: There might be duplicates in `vns`, so we need to handle that.
710-
indices = findall(==(vn), vns)
711-
for idx in indices
712-
not_seen[idx] = false
713-
ranges[idx] = r_vn .+ offset
714-
end
715-
end
716-
end
717-
# Raise key error if any of the variables were not found.
718-
if any(not_seen)
719-
inds = findall(not_seen)
720-
# Just use a `convert` to get the same type as the input; don't want to confuse by overly
721-
# specilizing the types in the error message.
722-
throw(KeyError(convert(typeof(vns), vns[inds])))
723-
end
724-
return ranges
725-
end
726-
727653
"""
728654
getdist(vi::VarInfo, vn::VarName)
729655

test/varinfo.jl

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -824,48 +824,6 @@ end
824824
@test merge(vi_double, vi_single)[vn] == 1.0
825825
end
826826

827-
# NOTE: It is not yet clear if this is something we want from all varinfo types.
828-
# Hence, we only test the `VarInfo` types here.
829-
@testset "vector_getranges for `VarInfo`" begin
830-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
831-
vns = DynamicPPL.TestUtils.varnames(model)
832-
nt = DynamicPPL.TestUtils.rand_prior_true(model)
833-
varinfos = DynamicPPL.TestUtils.setup_varinfos(
834-
model, nt, vns; include_threadsafe=true
835-
)
836-
# Only keep `VarInfo` types.
837-
varinfos = filter(
838-
Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos
839-
)
840-
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
841-
x = values_as(varinfo, Vector)
842-
843-
# Let's just check all the subsets of `vns`.
844-
@testset "$(convert(Vector{Any},vns_subset))" for vns_subset in
845-
combinations(vns)
846-
ranges = DynamicPPL.vector_getranges(varinfo, vns_subset)
847-
@test length(ranges) == length(vns_subset)
848-
for (r, vn) in zip(ranges, vns_subset)
849-
@test x[r] == DynamicPPL.tovec(varinfo[vn])
850-
end
851-
end
852-
853-
# Let's try some failure cases.
854-
@test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[]
855-
# Non-existent variables.
856-
@test_throws KeyError DynamicPPL.vector_getranges(
857-
varinfo, [VarName{gensym("vn")}()]
858-
)
859-
@test_throws KeyError DynamicPPL.vector_getranges(
860-
varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()]
861-
)
862-
# Duplicate variables.
863-
ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2))
864-
@test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2)
865-
end
866-
end
867-
end
868-
869827
@testset "issue #842" begin
870828
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
871829
varinfo = VarInfo(model)

0 commit comments

Comments
 (0)