Skip to content

Commit 4253e9b

Browse files
committed
ArrayLikeBlock WIP2
1 parent 35c3e20 commit 4253e9b

File tree

2 files changed

+155
-13
lines changed

2 files changed

+155
-13
lines changed

src/varnamedtuple.jl

Lines changed: 105 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ const INDEX_TYPES = Union{Integer,UnitRange,Colon}
5858
struct ArrayLikeBlock{T,I}
5959
block::T
6060
inds::I
61+
62+
function ArrayLikeBlock(block::T, inds::I) where {T,I}
63+
if !_is_multiindex(inds)
64+
throw(ArgumentError("ArrayLikeBlock must be constructed with a multi-index"))
65+
end
66+
return new{T,I}(block, inds)
67+
end
6168
end
6269

6370
"""
@@ -385,15 +392,102 @@ end
385392

386393
function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES})
387394
_check_index_validity(pa, inds)
388-
if !_haskey(pa, inds)
395+
if !(checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))))
389396
throw(BoundsError(pa, inds))
390397
end
391-
return getindex(pa.data, inds...)
398+
val = getindex(pa.data, inds...)
399+
400+
# If not for ArrayLikeBlocks, at this point we could just return val directly. However,
401+
# we need to check if val contains any ArrayLikeBlocks, and if so, make sure that that
402+
# we are retrieving exactly that block and nothing else.
403+
404+
# The error we'll throw if the retrieval is invalid.
405+
err = ArgumentError("""
406+
A non-Array value set with a range of indices must be retrieved with the same
407+
range of indices.
408+
""")
409+
if val isa ArrayLikeBlock
410+
# Tried to get a single value, but it's an ArrayLikeBlock.
411+
throw(err)
412+
elseif val isa Array && (eltype(val) <: ArrayLikeBlock || ArrayLikeBlock <: eltype(val))
413+
# Tried to get a range of values, and at least some of them may be ArrayLikeBlocks.
414+
# The below isempty check is deliberately kept separate from the outer elseif,
415+
# because the outer one can be resolved at compile time.
416+
if isempty(val)
417+
return val
418+
end
419+
first_elem = first(val)
420+
if !(first_elem isa ArrayLikeBlock)
421+
throw(err)
422+
end
423+
if inds != first_elem.inds
424+
# The requested indices do not match the ones used to set the value.
425+
throw(err)
426+
end
427+
# If _setindex!! works correctly, we should only be able to reach this point if all
428+
# the elements in `val` are identical to first_elem. Thus we just return that one.
429+
return first(val).block
430+
else
431+
return val
432+
end
392433
end
393434

394435
function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N}
395436
_check_index_validity(pa, inds)
396-
return checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...)))
437+
hasall =
438+
checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...)))
439+
440+
# If not for ArrayLikeBlocks, we could just return hasall directly. However, we need to
441+
# check that if any ArrayLikeBlocks are included, they are fully included.
442+
et = eltype(pa)
443+
if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et)
444+
# pa can't possibly hold any ArrayLikeBlocks, so nothing to do.
445+
return hasall
446+
end
447+
448+
if !hasall
449+
return false
450+
end
451+
# From this point on we can assume that all the requested elements are set, and the only
452+
# thing to check is that we are not partially indexing into any ArrayLikeBlocks.
453+
# We've already checked checkbounds at the top of the function, and returned if it
454+
# wasn't true, so @inbounds is safe.
455+
subdata = @inbounds getindex(pa.data, inds...)
456+
if !_is_multiindex(inds)
457+
return !(subdata isa ArrayLikeBlock)
458+
end
459+
return !any(elem -> elem isa ArrayLikeBlock && elem.inds != inds, subdata)
460+
end
461+
462+
function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
463+
_check_index_validity(pa, inds)
464+
if _is_multiindex(inds)
465+
pa.mask[inds...] .= false
466+
else
467+
pa.mask[inds...] = false
468+
end
469+
return _concretise_eltype!!(pa)
470+
end
471+
472+
_ensure_range(r::UnitRange) = r
473+
_ensure_range(i::Integer) = i:i
474+
475+
function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES})
476+
et = eltype(pa)
477+
if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et)
478+
# pa can't possibly hold any ArrayLikeBlocks, so nothing to do.
479+
return pa
480+
end
481+
482+
for i in CartesianIndices(map(_ensure_range, inds))
483+
if pa.mask[i]
484+
val = @inbounds pa.data[i]
485+
if val isa ArrayLikeBlock
486+
pa = delete!!(pa, val.inds...)
487+
end
488+
end
489+
end
490+
return pa
397491
end
398492

399493
function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
@@ -403,13 +497,15 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
403497
else
404498
_resize_partialarray!!(pa, inds)
405499
end
500+
pa = _remove_partial_blocks!!(pa, inds...)
406501

407502
new_data = pa.data
408503
if _is_multiindex(inds) && !(isa(value, AbstractArray))
409-
if !hasmethod(size, value)
504+
if !hasmethod(size, Tuple{typeof(value)})
410505
throw(ArgumentError("Cannot assign a scalar value to a range."))
411506
end
412-
if size(value) != map(x -> _length_needed(x), inds)
507+
inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds))
508+
if size(value) != inds_size
413509
throw(
414510
DimensionMismatch(
415511
"Assigned value has size $(size(value)), which does not match the size " *
@@ -419,14 +515,10 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
419515
end
420516
# At this point we know we have a value that is not an AbstractArray, but it has
421517
# some notion of size, and that size matches the indices that are being set. In this
422-
# case we wrap the value in a ArrayLikeBlock, and set all the individual indices
423-
# point to that, with the right subindices.
424-
first_index = first.(inds)
425-
# Iterate over all the subindices of inds.
426-
for ind in CartesianIndices(map(x -> _length_needed(x), inds))
427-
subinds = ntuple(i -> first_index[i] + ind[i] - 1, length(inds))
428-
new_data = _setindex!!(new_data, ArrayLikeBlock(value, Tuple(ind)), subinds...)
429-
end
518+
# case we wrap the value in an ArrayLikeBlock, and set all the individual indices
519+
# point to that.
520+
alb = ArrayLikeBlock(value, inds)
521+
new_data = setindex!!(new_data, fill(alb, inds_size...), inds...)
430522
else
431523
new_data = setindex!!(new_data, value, inds...)
432524
end

test/varnamedtuple.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module VarNamedTupleTests
22

33
using Combinatorics: Combinatorics
44
using Test: @inferred, @test, @test_throws, @testset
5+
using Distributions: Dirichlet
56
using DynamicPPL: DynamicPPL, @varname, VarNamedTuple
67
using DynamicPPL.VarNamedTuples: PartialArray
78
using AbstractPPL: VarName, prefix
@@ -458,6 +459,55 @@ end
458459
VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \
459460
(2,) => 17.0),),)),))"""
460461
end
462+
463+
@testset "block variables" begin
464+
# Tests for setting and getting block variables, i.e. variables that have a non-zero
465+
# size in a PartialArray, but are not Arrays themselves.
466+
expected_err = ArgumentError("""
467+
A non-Array value set with a range of indices must be retrieved with the same
468+
range of indices.
469+
""")
470+
vnt = VarNamedTuple()
471+
vnt = setindex!!(vnt, Dirichlet(3, 1.0), @varname(x[2:4]))
472+
@test haskey(vnt, @varname(x[2:4]))
473+
@test getindex(vnt, @varname(x[2:4])) == Dirichlet(3, 1.0)
474+
@test !haskey(vnt, @varname(x[2:3]))
475+
@test_throws expected_err getindex(vnt, @varname(x[2:3]))
476+
@test !haskey(vnt, @varname(x[3]))
477+
@test_throws expected_err getindex(vnt, @varname(x[3]))
478+
@test !haskey(vnt, @varname(x[1]))
479+
@test !haskey(vnt, @varname(x[5]))
480+
vnt = setindex!!(vnt, 1.0, @varname(x[1]))
481+
vnt = setindex!!(vnt, 1.0, @varname(x[5]))
482+
@test haskey(vnt, @varname(x[1]))
483+
@test haskey(vnt, @varname(x[5]))
484+
@test_throws expected_err getindex(vnt, @varname(x[1:4]))
485+
@test_throws expected_err getindex(vnt, @varname(x[2:5]))
486+
487+
# Setting any of these indices should remove the block variable x[2:4].
488+
@testset "index = $index" for index in (2, 3, 4, 2:3, 3:5)
489+
# Test setting different types of values.
490+
vals = if index isa Int
491+
(2.0,)
492+
else
493+
(fill(2.0, length(index)), Dirichlet(length(index), 2.0))
494+
end
495+
@testset "val = $val" for val in vals
496+
vn = @varname(x[index])
497+
vnt2 = copy(vnt)
498+
vnt2 = setindex!!(vnt2, val, vn)
499+
@test !haskey(vnt2, @varname(x[2:4]))
500+
@test_throws BoundsError getindex(vnt2, @varname(x[2:4]))
501+
other_index = index in (2, 2:3) ? 4 : 2
502+
@test !haskey(vnt2, @varname(x[other_index]))
503+
@test_throws BoundsError getindex(vnt2, @varname(x[other_index]))
504+
@test haskey(vnt2, vn)
505+
@test getindex(vnt2, vn) == val
506+
@test haskey(vnt2, @varname(x[1]))
507+
@test_throws BoundsError getindex(vnt2, @varname(x[1:4]))
508+
end
509+
end
510+
end
461511
end
462512

463513
end

0 commit comments

Comments
 (0)