Skip to content

Commit b77b0af

Browse files
committed
Fix keys and some tests for PartialArray
1 parent e198fbb commit b77b0af

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

src/varnamedtuple.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ function Base.keys(pa::PartialArray)
639639
sublens = _varname_to_lens(vn)
640640
push!(ks, _compose_no_identity(sublens, lens))
641641
end
642+
elseif val isa ArrayLikeBlock
643+
push!(ks, IndexLens(Tuple(val.inds)))
642644
else
643645
push!(ks, lens)
644646
end

test/varnamedtuple.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@ module VarNamedTupleTests
22

33
using Combinatorics: Combinatorics
44
using Test: @inferred, @test, @test_throws, @testset
5-
using Distributions: Dirichlet
65
using DynamicPPL: DynamicPPL, @varname, VarNamedTuple
7-
using DynamicPPL.VarNamedTuples: PartialArray
6+
using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock
87
using AbstractPPL: VarName, prefix
98
using BangBang: setindex!!
109

@@ -20,12 +19,18 @@ function test_invariants(vnt::VarNamedTuple)
2019
for k in keys(vnt)
2120
@test haskey(vnt, k)
2221
v = getindex(vnt, k)
22+
# ArrayLikeBlocks are an implementation detail, and should not be exposed through
23+
# getindex.
24+
@test !(v isa ArrayLikeBlock)
2325
vnt2 = setindex!!(copy(vnt), v, k)
2426
@test vnt == vnt2
2527
@test isequal(vnt, vnt2)
2628
@test hash(vnt) == hash(vnt2)
2729
end
2830
# Check that the printed representation can be parsed back to an equal VarNamedTuple.
31+
# The below eval test is a bit fragile: If any elements in vnt don't respect the same
32+
# reconstructability-from-repr property, this will fail. Likewise if any element uses
33+
# in its repr print out types that are not in scope in this module, it will fail.
2934
vnt3 = eval(Meta.parse(repr(vnt)))
3035
@test vnt == vnt3
3136
@test isequal(vnt, vnt3)
@@ -461,17 +466,23 @@ end
461466
end
462467

463468
@testset "block variables" begin
469+
""" A type that has a size but is not an Array."""
470+
struct SizedThing
471+
size::Tuple
472+
end
473+
Base.size(st::SizedThing) = st.size
474+
464475
# Tests for setting and getting block variables, i.e. variables that have a non-zero
465476
# size in a PartialArray, but are not Arrays themselves.
466477
expected_err = ArgumentError("""
467478
A non-Array value set with a range of indices must be retrieved with the same
468479
range of indices.
469480
""")
470481
vnt = VarNamedTuple()
471-
vnt = @inferred(setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4])))
482+
vnt = @inferred(setindex!!(vnt, SizedThing((3,)), @varname(x[2:4])))
472483
test_invariants(vnt)
473484
@test haskey(vnt, @varname(x[2:4]))
474-
@test @inferred(getindex(vnt, @varname(x[2:4]))) == Dirichlet(3, 1.0)
485+
@test @inferred(getindex(vnt, @varname(x[2:4]))) == SizedThing((3,))
475486
@test !haskey(vnt, @varname(x[2:3]))
476487
@test_throws expected_err getindex(vnt, @varname(x[2:3]))
477488
@test !haskey(vnt, @varname(x[3]))
@@ -492,7 +503,7 @@ end
492503
vals = if index isa Int
493504
(2.0,)
494505
else
495-
(fill(2.0, length(index)), Dirichlet(length(index), 2.0))
506+
(fill(2.0, length(index)), SizedThing((length(index),)))
496507
end
497508
@testset "val = $val" for val in vals
498509
vn = @varname(x[index])
@@ -513,9 +524,7 @@ end
513524

514525
# Extra checks, mostly for type stability and to confirm that multidimensional
515526
# blocks work too.
516-
struct TwoByTwoBlock end
517-
Base.size(::TwoByTwoBlock) = (2, 2)
518-
val = TwoByTwoBlock()
527+
val = SizedThing((2, 2))
519528
vnt = VarNamedTuple()
520529
vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2])))
521530
test_invariants(vnt)

0 commit comments

Comments
 (0)