From e0cdf631864d95eee8709dd5dd033535627cd96a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 12 Dec 2025 15:26:34 +0000 Subject: [PATCH] Remove `vector_getranges` and friends --- HISTORY.md | 4 ++ Project.toml | 2 +- ext/DynamicPPLMarginalLogDensitiesExt.jl | 19 ++++-- src/threadsafe.jl | 4 -- src/varinfo.jl | 74 ------------------------ test/varinfo.jl | 42 -------------- 6 files changed, 20 insertions(+), 125 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index d3650d30c..5d7892d8f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.39.4 + +Removed the internal functions `DynamicPPL.getranges`, `DynamicPPL.vector_getrange`, and `DynamicPPL.vector_getranges` (the new LogDensityFunction construction does exactly the same thing, so this specialised function was not needed). + ## 0.39.3 `DynamicPPL.TestUtils.AD.run_ad` now generates much prettier output. diff --git a/Project.toml b/Project.toml index f2ffa2c63..240e154bb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.39.3" +version = "0.39.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 8b3040757..ffb5baf25 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -101,12 +101,23 @@ function DynamicPPL.marginalize( method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), kwargs..., ) - # Determine the indices for the variables to marginalise out. - varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames)) # Construct the marginal log-density model. - f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) + ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) + # Determine the indices for the variables to marginalise out. + varindices = mapreduce(vcat, marginalized_varnames) do vn + if DynamicPPL.getoptic(vn) === identity + ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range + else + ldf._varname_ranges[vn].range + end + end mld = MarginalLogDensities.MarginalLogDensity( - LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs... + LogDensityFunctionWrapper(ldf, varinfo), + varinfo[:], + varindices, + (), + method; + kwargs..., ) return mld end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 0e906b6ca..6d3acce6c 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -155,10 +155,6 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<: end vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo) -vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn) -function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) - return vector_getranges(vi.varinfo, vns) -end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) function BangBang.empty!!(vi::ThreadSafeVarInfo) diff --git a/src/varinfo.jl b/src/varinfo.jl index 14e08515c..d1ea7dae3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -650,80 +650,6 @@ Set the index range of `vn` in the metadata of `vi` to `range`. setrange!(vi::VarInfo, vn::VarName, range) = setrange!(getmetadata(vi, vn), vn, range) setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range -""" - getranges(vi::VarInfo, vns::Vector{<:VarName}) - -Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. -""" -function getranges(vi::VarInfo, vns::Vector{<:VarName}) - return map(Base.Fix1(getrange, vi), vns) -end - -""" - vector_getrange(varinfo::VarInfo, varname::VarName) - -Return the range corresponding to `varname` in the vector representation of `varinfo`. -""" -vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn) -function vector_getrange(vi::NTVarInfo, vn::VarName) - offset = 0 - for md in values(vi.metadata) - # First, we need to check if `vn` is in `md`. - # In this case, we can just return the corresponding range + offset. - haskey(md, vn) && return getrange(md, vn) .+ offset - # Otherwise, we need to get the cumulative length of the ranges in `md` - # and add it to the offset. - offset += sum(length, md.ranges) - end - # If we reach this point, `vn` is not in `vi.metadata`. - throw(KeyError(vn)) -end - -""" - vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName}) - -Return the range corresponding to `varname` in the vector representation of `varinfo`. -""" -function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName}) - return map(Base.Fix1(vector_getrange, varinfo), varname) -end -# Specialized version for `NTVarInfo`. -function vector_getranges(varinfo::NTVarInfo, vns::Vector{<:VarName}) - # TODO: Does it help if we _don't_ convert to a vector here? - metadatas = collect(values(varinfo.metadata)) - # Extract the offsets. - offsets = cumsum(map(vector_length, metadatas)) - # Extract the ranges from each metadata. - ranges = Vector{UnitRange{Int}}(undef, length(vns)) - # Need to keep track of which ones we've seen. - not_seen = fill(true, length(vns)) - for (i, metadata) in enumerate(metadatas) - vns_metadata = filter(Base.Fix1(haskey, metadata), vns) - # If none of the variables exist in the metadata, we return an empty array. - isempty(vns_metadata) && continue - # Otherwise, we extract the ranges. - offset = i == 1 ? 0 : offsets[i - 1] - for vn in vns_metadata - r_vn = getrange(metadata, vn) - # Get the index, so we return in the same order as `vns`. - # NOTE: There might be duplicates in `vns`, so we need to handle that. - indices = findall(==(vn), vns) - for idx in indices - not_seen[idx] = false - ranges[idx] = r_vn .+ offset - end - end - end - # Raise key error if any of the variables were not found. - if any(not_seen) - inds = findall(not_seen) - # Just use a `convert` to get the same type as the input; don't want to confuse by overly - # specilizing the types in the error message. - throw(KeyError(convert(typeof(vns), vns[inds]))) - end - return ranges -end - """ getdist(vi::VarInfo, vn::VarName) diff --git a/test/varinfo.jl b/test/varinfo.jl index a1a1b370f..0d0ddc15d 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -824,48 +824,6 @@ end @test merge(vi_double, vi_single)[vn] == 1.0 end - # NOTE: It is not yet clear if this is something we want from all varinfo types. - # Hence, we only test the `VarInfo` types here. - @testset "vector_getranges for `VarInfo`" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - nt = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, nt, vns; include_threadsafe=true - ) - # Only keep `VarInfo` types. - varinfos = filter( - Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos - ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - x = values_as(varinfo, Vector) - - # Let's just check all the subsets of `vns`. - @testset "$(convert(Vector{Any},vns_subset))" for vns_subset in - combinations(vns) - ranges = DynamicPPL.vector_getranges(varinfo, vns_subset) - @test length(ranges) == length(vns_subset) - for (r, vn) in zip(ranges, vns_subset) - @test x[r] == DynamicPPL.tovec(varinfo[vn]) - end - end - - # Let's try some failure cases. - @test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[] - # Non-existent variables. - @test_throws KeyError DynamicPPL.vector_getranges( - varinfo, [VarName{gensym("vn")}()] - ) - @test_throws KeyError DynamicPPL.vector_getranges( - varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()] - ) - # Duplicate variables. - ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2)) - @test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2) - end - end - end - @testset "issue #842" begin model = DynamicPPL.TestUtils.DEMO_MODELS[1] varinfo = VarInfo(model)