diff --git a/docs/src/internals/varnamedtuple.md b/docs/src/internals/varnamedtuple.md index 47ff9c65e..63f4bb5b9 100644 --- a/docs/src/internals/varnamedtuple.md +++ b/docs/src/internals/varnamedtuple.md @@ -144,6 +144,29 @@ You can also set the elements with `vnt = setindex!!(vnt, @varname(a[1]), 3.0)`, At this point you can not set any new values in that array that would be outside of its range, with something like `vnt = setindex!!(vnt, @varname(a[5]), 5.0)`. The philosophy here is that once a `Base.Array` has been attached to a `VarName`, that takes precedence, and a `PartialArray` is only used as a fallback when we are told to store a value for `@varname(a[i])` without having any previous knowledge about what `@varname(a)` is. +## Non-Array blocks with `IndexLens`es + +The above is all that is needed for setting regular scalar values. +However, in DynamicPPL we also have a particular need for something slightly odd: +We sometimes need to do calls like `setindex!!(vnt, @varname(a[1:5]), val)` on a `val` that is _not_ an `AbstractArray`, or even iterable at all. +Normally this would error: As a scalar value with size `()`, `val` is the wrong size to be set with `@varname(a[1:5])`, which clearly wants something with size `(5,)`. +However, we want to allow this even if `val` is not an iterable, if it is some object for which `size` is well-defined, and `size(val) == (5,)`. +In DynamicPPL this comes up when storing e.g. the priors of a model, where a random variable like `@varname(a[1:5])` may be associated with a prior that is a 5-dimensional distribution. + +Internally, a `PartialArray` is just a regular `Array` with a mask saying which elements have been set. +Hence we can't store `val` directly in the same `PartialArray`: +We need it to take up a sub-block of the array, in our example case a sub-block of length 5. 
+To this end, internally, `PartialArray` uses a wrapper type called `ArrayLikeBlock`, that stores `val` together with the indices that are being used to set it. +The `PartialArray` has all its corresponding elements, in our example elements 1, 2, 3, 4, and 5, point to the same wrapper object. + +While such blocks can be stored using a wrapper like this, some care must be taken in indexing into these blocks. +For instance, after setting a block with `setindex!!(vnt, @varname(a[1:5]), val)`, we can't `getindex(vnt, @varname(a[1]))`, since we can't return "the first element of five in `val`", because `val` may not be indexable in any way. +Similarly, if next we set `setindex!!(vnt, @varname(a[1]), some_other_value)`, that should invalidate/delete the elements `@varname(a[2:5])`, since the block only makes sense as a whole. +For these reasons, setting and getting blocks of well-defined size like this is allowed with `VarNamedTuple`s, but _only by always using the full range_. +For instance, if `setindex!!(vnt, @varname(a[1:5]), val)` has been set, then the only valid `getindex` key to access `val` is `@varname(a[1:5])`; +not `@varname(a[1:10])`, nor `@varname(a[3])`, nor anything else that overlaps with `@varname(a[1:5])`. +`haskey` likewise only returns true for `@varname(a[1:5])`, and `keys(vnt)` only has that as an element. 
+ ## Limitations This design has a several of benefits, for performance and generality, but it also has limitations: diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index ffb5baf25..e28560872 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -1,6 +1,6 @@ module DynamicPPLMarginalLogDensitiesExt -using DynamicPPL: DynamicPPL, LogDensityProblems, VarName +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by @@ -105,11 +105,9 @@ function DynamicPPL.marginalize( ldf = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) # Determine the indices for the variables to marginalise out. varindices = mapreduce(vcat, marginalized_varnames) do vn - if DynamicPPL.getoptic(vn) === identity - ldf._iden_varname_ranges[DynamicPPL.getsym(vn)].range - else - ldf._varname_ranges[vn].range - end + # The type assertion helps in cases where the model is type unstable and thus + # `varname_ranges` may have an abstract element type. + (ldf._varname_ranges[vn]::RangeAndLinked).range end mld = MarginalLogDensities.MarginalLogDensity( LogDensityFunctionWrapper(ldf, varinfo), diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 90394a24c..dc811df85 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -206,13 +206,17 @@ an unlinked value. 
$(TYPEDFIELDS) """ -struct RangeAndLinked +struct RangeAndLinked{T<:Tuple} # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} # whether it's linked is_linked::Bool + # original size of the variable before vectorisation + original_size::T end +Base.size(ral::RangeAndLinked) = ral.original_size + """ VectorWithRanges{Tlink}( varname_ranges::VarNamedTuple, @@ -247,7 +251,12 @@ struct VectorWithRanges{Tlink,VNT<:VarNamedTuple,T<:AbstractVector{<:Real}} end function _get_range_and_linked(vr::VectorWithRanges, vn::VarName) - return vr.varname_ranges[vn] + # The type assertion does nothing if VectorWithRanges has concrete element types, as is + # the case for all type stable models. However, if the model is not type stable, + # vr.varname_ranges[vn] may infer to have type `Any`. In this case it is helpful to + # assert that it is a RangeAndLinked, because even though it remains non-concrete, + # it'll allow the compiler to infer the types of `range` and `is_linked`. 
+ return vr.varname_ranges[vn]::RangeAndLinked end function init( ::Random.AbstractRNG, diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 47b49a277..89e2b5989 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -330,7 +330,10 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) for (vn, idx) in md.idcs is_linked = md.is_transformed[idx] range = md.ranges[idx] .+ (start_offset - 1) - all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) + orig_size = varnamesize(vn) + all_ranges = BangBang.setindex!!( + all_ranges, RangeAndLinked(range, is_linked, orig_size), vn + ) offset += length(range) end return all_ranges, offset @@ -341,7 +344,10 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) for (vn, idx) in vnv.varname_to_index is_linked = vnv.is_unconstrained[idx] range = vnv.ranges[idx] .+ (start_offset - 1) - all_ranges = BangBang.setindex!!(all_ranges, RangeAndLinked(range, is_linked), vn) + orig_size = varnamesize(vn) + all_ranges = BangBang.setindex!!( + all_ranges, RangeAndLinked(range, is_linked, orig_size), vn + ) offset += length(range) end return all_ranges, offset diff --git a/src/varname.jl b/src/varname.jl index 3eb1f2460..7ffe9cc08 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -41,3 +41,28 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +# TODO(mhauru) This should probably be Base.size(::VarName) in AbstractPPL. +""" + varnamesize(vn::VarName) + +Return the size of the object referenced by this VarName. 
+ +```jldoctest +julia> varnamesize(@varname(a)) +() + +julia> varnamesize(@varname(b[1:3, 2])) +(3,) + +julia> varnamesize(@varname(c.d[4].e[3, 2:5, 2, 1:4, 1])) +(4, 4) +""" +function varnamesize(vn::VarName) + l = AbstractPPL._last(vn.optic) + if l isa Accessors.IndexLens + return reduce((x, y) -> tuple(x..., y...), map(size, l.indices); init=()) + else + return () + end +end diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index db2462e53..1340846a9 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -27,9 +27,30 @@ function _setindex!!(arr::AbstractArray, value, optic::IndexLens) end # Some utilities for checking what sort of indices we are dealing with. -_has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) -function _is_multiindex(::T) where {T<:Tuple} - return any(x <: UnitRange || x <: Colon for x in T.parameters) +# The non-generated function implementations of these would be +# _has_colon(::T) where {T<:Tuple} = any(x <: Colon for x in T.parameters) +# function _is_multiindex(::T) where {T<:Tuple} +# return any(x <: UnitRange || x <: Colon for x in T.parameters) +# end +# However, constant propagation sometimes fails if the index tuple is too big (e.g. length +# 4), so we play it safe and use generated functions. Constant propagating these is +# important, because many functions choose different paths based on their values, which +# would lead to type instability if they were only evaluated at runtime. 
+@generated function _has_colon(::T) where {T<:Tuple} + for x in T.parameters + if x <: Colon + return :(true) + end + end + return :(false) +end +@generated function _is_multiindex(::T) where {T<:Tuple} + for x in T.parameters + if x <: UnitRange || x <: Colon + return :(true) + end + end + return :(false) end """ @@ -55,6 +76,45 @@ const PARTIAL_ARRAY_DIM_GROWTH_FACTOR = 4 """A convenience for defining method argument type bounds.""" const INDEX_TYPES = Union{Integer,UnitRange,Colon} +""" + ArrayLikeBlock{T,I} + +A wrapper for non-array blocks stored in `PartialArray`s. + +When setting a value in a `PartialArray` over a range of indices, if the value being set +is not itself an `AbstractArray`, but has a well-defined size, we wrap it in an +`ArrayLikeBlock`, which records both the value and the indices it was set with. + +When getting values from a `PartialArray`, if any of the requested indices correspond to +an `ArrayLikeBlock`, we check that the requested indices match the ones used to set the +value. If they do, we return the underlying block, otherwise we throw an error. +""" +struct ArrayLikeBlock{T,I} + block::T + inds::I + + function ArrayLikeBlock(block::T, inds::I) where {T,I} + if !_is_multiindex(inds) + throw(ArgumentError("ArrayLikeBlock must be constructed with a multi-index")) + end + return new{T,I}(block, inds) + end +end + +function Base.show(io::IO, alb::ArrayLikeBlock) + # Note the distinction: The raw strings that form part of the structure of the print + # out are `print`ed, whereas the keys and values are `show`n. The latter ensures + # that strings are quoted, Symbols are prefixed with :, etc. 
+ print(io, "ArrayLikeBlock(") + show(io, alb.block) + print(io, ", ") + show(io, alb.inds) + print(io, ")") + return nothing +end + +_blocktype(::Type{ArrayLikeBlock{T}}) where {T} = T + """ PartialArray{ElType,numdims} @@ -89,6 +149,14 @@ Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known elem `ElType` and number of dimensions `numdims`. Indices into a `PartialArray` must have exactly `numdims` elements. +One can set values in a `PartialArray` either element-by-element, or with ranges like +`arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set +must either be an `AbstractArray` or otherwise something for which `size(value)` is defined, +and the size matches the range. If the value is an `AbstractArray`, the elements are copied +individually, but if it is not, the value is stored as a block, that takes up the whole +range, e.g. `[1:3,2]`, but is only a single object. Getting such a block-value must be done +with the exact same range of indices, otherwise an error is thrown. + If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check if, after the new value has been set, the element type can be made more concrete. If so, a new `PartialArray` with a more concrete element type is returned. Thus the element type @@ -105,6 +173,9 @@ means that the largest index set so far determines the memory usage of the `Part a few scattered values are set, a structure like `SparseArray` may be more appropriate. """ struct PartialArray{ElType,num_dims} + # TODO(mhauru) Consider trying FixedSizeArrays instead, see how it would change + # performance. We reallocate new Arrays every time when resizing anyway, except for + # Vectors, which can be extended in place. 
data::Array{ElType,num_dims} mask::Array{Bool,num_dims} @@ -281,7 +352,13 @@ function _concretise_eltype!!(pa::PartialArray) if isconcretetype(eltype(pa)) return pa end - new_et = promote_type((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...) + # We could use promote_type here, instead of typejoin. However, that would e.g. + # cause Ints to be converted to Float64s, since + # promote_type(Int, Float64) == Float64, which can cause problems. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. + # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing + # and Missing, rather than falling back on Any. However, it's not exported. + new_et = typejoin((typeof(pa.data[i]) for i in eachindex(pa.mask) if pa.mask[i])...) # TODO(mhauru) Should we check as below, or rather isconcretetype(new_et)? # In other words, does it help to be more concrete, even if we aren't fully concrete? if new_et === eltype(pa) @@ -377,15 +454,129 @@ end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) _check_index_validity(pa, inds) - if !_haskey(pa, inds) + if !(checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...)))) throw(BoundsError(pa, inds)) end - return getindex(pa.data, inds...) + val = getindex(pa.data, inds...) + + # If not for ArrayLikeBlocks, at this point we could just return val directly. However, + # we need to check if val contains any ArrayLikeBlocks, and if so, make sure that + # we are retrieving exactly that block and nothing else. + + # The error we'll throw if the retrieval is invalid. + err = ArgumentError(""" + A non-Array value set with a range of indices must be retrieved with the same + range of indices. + """) + if val isa ArrayLikeBlock + # Tried to get a single value, but it's an ArrayLikeBlock. 
+ throw(err) + elseif val isa Array && (eltype(val) <: ArrayLikeBlock || ArrayLikeBlock <: eltype(val)) + # Tried to get a range of values, and at least some of them may be ArrayLikeBlocks. + # The below isempty check is deliberately kept separate from the outer elseif, + # because the outer one can be resolved at compile time. + if isempty(val) + # We need to return an empty array, but for type stability, we want to unwrap + # any ArrayLikeBlock types in the element type. + return if eltype(val) <: ArrayLikeBlock + Array{_blocktype(eltype(val)),ndims(val)}() + else + val + end + end + first_elem = first(val) + if !(first_elem isa ArrayLikeBlock) + return any(x -> x isa ArrayLikeBlock, val) ? throw(err) : val + end + if inds != first_elem.inds + # The requested indices do not match the ones used to set the value. + throw(err) + end + # If _setindex!! works correctly, we should only be able to reach this point if all + # the elements in `val` are identical to first_elem. Thus we just return that one. + return first(val).block + else + return val + end end function _haskey(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) where {N} _check_index_validity(pa, inds) - return checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) + hasall = + checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...))) + + # If not for ArrayLikeBlocks, we could just return hasall directly. However, we need to + # check that if any ArrayLikeBlocks are included, they are fully included. + et = eltype(pa) + if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et) + # pa can't possibly hold any ArrayLikeBlocks, so nothing to do. + return hasall + end + + if !hasall + return false + end + # From this point on we can assume that all the requested elements are set, and the only + # thing to check is that we are not partially indexing into any ArrayLikeBlocks. + # We've already checked checkbounds at the top of the function, and returned if it + # wasn't true, so @inbounds is safe. 
+ subdata = @inbounds getindex(pa.data, inds...) + if !_is_multiindex(inds) + return !(subdata isa ArrayLikeBlock) + end + return !any(elem -> elem isa ArrayLikeBlock && elem.inds != inds, subdata) +end + +function BangBang.delete!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + _check_index_validity(pa, inds) + if _is_multiindex(inds) + pa.mask[inds...] .= false + else + pa.mask[inds...] = false + end + return pa +end + +_ensure_range(r::UnitRange) = r +_ensure_range(i::Integer) = i:i + +""" + _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + +Remove any ArrayLikeBlocks that overlap with the given indices from the PartialArray. + +Note that this removes the whole block, even the parts that are within `inds`, to avoid +partially indexing into ArrayLikeBlocks. +""" +function _remove_partial_blocks!!(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + et = eltype(pa) + if !(et <: ArrayLikeBlock || ArrayLikeBlock <: et) + # pa can't possibly hold any ArrayLikeBlocks, so nothing to do. + return pa + end + + for i in CartesianIndices(map(_ensure_range, inds)) + if pa.mask[i] + val = @inbounds pa.data[i] + if val isa ArrayLikeBlock + pa = delete!!(pa, val.inds...) + end + end + end + return pa +end + +""" + _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) + +Check if the given value needs to be wrapped in an `ArrayLikeBlock` when being set at inds. + +The value only depends on the types of the arguments, and should be constant propagated. +""" +function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) + return _is_multiindex(inds) && + !isa(value, AbstractArray) && + hasmethod(size, Tuple{typeof(value)}) end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) @@ -395,7 +586,29 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) else _resize_partialarray!!(pa, inds) end - new_data = setindex!!(pa.data, value, inds...) + pa = _remove_partial_blocks!!(pa, inds...) 
+ + new_data = pa.data + if _needs_arraylikeblock(value, inds...) + inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) + if size(value) != inds_size + throw( + DimensionMismatch( + "Assigned value has size $(size(value)), which does not match the " * + "size implied by the indices $(inds_size).", + ), + ) + end + # At this point we know we have a value that is not an AbstractArray, but it has + # some notion of size, and that size matches the indices that are being set. In this + # case we wrap the value in an ArrayLikeBlock, and set all the individual indices + # to point to that. + alb = ArrayLikeBlock(value, inds) + new_data = setindex!!(new_data, fill(alb, inds_size...), inds...) + else + new_data = setindex!!(new_data, value, inds...) + end + if _is_multiindex(inds) pa.mask[inds...] .= true else @@ -452,7 +665,14 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) result else # Neither is strictly bigger than the other. - et = promote_type(eltype(pa1), eltype(pa2)) + # We could use promote_type here, instead of typejoin. However, that would e.g. + # cause Ints to be converted to Float64s, since + # promote_type(Int, Float64) == Float64, which can cause problems. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. + # Base.promote_typejoin would be like typejoin, but creates Unions out of + # Nothing and Missing, rather than falling back on Any. However, it's not + # exported. + et = typejoin(eltype(pa1), eltype(pa2)) new_data = Array{et,num_dims}(undef, merge_size) new_mask = fill(false, merge_size) result = PartialArray(new_data, new_mask) @@ -478,6 +698,7 @@ function Base.keys(pa::PartialArray) inds = findall(pa.mask) lenses = map(x -> IndexLens(Tuple(x)), inds) ks = Any[] + alb_inds_seen = Set{Tuple}() for lens in lenses val = getindex(pa.data, lens.indices...) 
if val isa VarNamedTuple @@ -486,6 +707,11 @@ function Base.keys(pa::PartialArray) sublens = _varname_to_lens(vn) push!(ks, _compose_no_identity(sublens, lens)) end + elseif val isa ArrayLikeBlock + if !(val.inds in alb_inds_seen) + push!(ks, IndexLens(Tuple(val.inds))) + push!(alb_inds_seen, val.inds) + end else push!(ks, lens) end @@ -639,25 +865,18 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName) return _setindex!!(vnt, new_subdata, name) end -# TODO(mhauru) Should this return tuples, like it does now? That makes sense for -# VarNamedTuple itself, but if there is a nested PartialArray the tuple might get very big. -# Also, this is not very type stable, it fails even in basic cases. A generated function -# would help, but I failed to make one. Might be something to do with a recursive -# generated function. function Base.keys(vnt::VarNamedTuple) - result = () + result = VarName[] for sym in keys(vnt.data) subdata = vnt.data[sym] if subdata isa VarNamedTuple subkeys = keys(subdata) - result = ( - result..., (AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys)... - ) + append!(result, [AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys]) elseif subdata isa PartialArray subkeys = keys(subdata) - result = (result..., (VarName{sym}(lens) for lens in subkeys)...) + append!(result, [VarName{sym}(lens) for lens in subkeys]) else - result = (result..., VarName{sym}()) + push!(result, VarName{sym}()) end end return result @@ -716,9 +935,15 @@ end function make_leaf(value, optic::IndexLens) inds = optic.indices num_inds = length(inds) - # Check if any of the indices are ranges or colons. If yes, value needs to be an - # AbstractArray. Otherwise it needs to be an individual value. - et = _is_multiindex(inds) ? eltype(value) : typeof(value) + # The element type of the PartialArray depends on whether we are setting a single value + # or a range of values. + et = if !_is_multiindex(inds) + typeof(value) + elseif _needs_arraylikeblock(value, inds...) 
+ ArrayLikeBlock{typeof(value),typeof(inds)} + else + eltype(value) + end pa = PartialArray{et,num_inds}() return _setindex!!(pa, value, optic) end diff --git a/test/runtests.jl b/test/runtests.jl index 9649aebbb..e0b42904c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ include("test_util.jl") include("utils.jl") include("accumulators.jl") include("compiler.jl") + include("varnamedtuple.jl") include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 3beadebf8..8be72a184 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,7 +3,7 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple -using DynamicPPL.VarNamedTuples: PartialArray +using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock using AbstractPPL: VarName, prefix using BangBang: setindex!! @@ -19,12 +19,18 @@ function test_invariants(vnt::VarNamedTuple) for k in keys(vnt) @test haskey(vnt, k) v = getindex(vnt, k) + # ArrayLikeBlocks are an implementation detail, and should not be exposed through + # getindex. + @test !(v isa ArrayLikeBlock) vnt2 = setindex!!(copy(vnt), v, k) @test vnt == vnt2 @test isequal(vnt, vnt2) @test hash(vnt) == hash(vnt2) end # Check that the printed representation can be parsed back to an equal VarNamedTuple. + # The below eval test is a bit fragile: If any elements in vnt don't respect the same + # reconstructability-from-repr property, this will fail. Likewise if any element uses + # in its repr print out types that are not in scope in this module, it will fail. vnt3 = eval(Meta.parse(repr(vnt))) @test vnt == vnt3 @test isequal(vnt, vnt3) @@ -34,6 +40,12 @@ function test_invariants(vnt::VarNamedTuple) @test merge(VarNamedTuple(), vnt) == vnt end +""" A type that has a size but is not an Array. 
Used in ArrayLikeBlock tests.""" +struct SizedThing{T<:Tuple} + size::T +end +Base.size(st::SizedThing) = st.size + @testset "VarNamedTuple" begin @testset "Construction" begin vnt1 = VarNamedTuple() @@ -342,36 +354,36 @@ end @testset "keys" begin vnt = VarNamedTuple() - @test @inferred(keys(vnt)) == () + @test @inferred(keys(vnt)) == VarName[] vnt = setindex!!(vnt, 1.0, @varname(a)) # TODO(mhauru) that the below passes @inferred, but any of the later ones don't. # We should improve type stability of keys(). - @test @inferred(keys(vnt)) == (@varname(a),) + @test @inferred(keys(vnt)) == [@varname(a)] vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) - @test keys(vnt) == (@varname(a), @varname(b)) + @test keys(vnt) == [@varname(a), @varname(b)] vnt = setindex!!(vnt, 15, @varname(b[2])) - @test keys(vnt) == (@varname(a), @varname(b)) + @test keys(vnt) == [@varname(a), @varname(b)] vnt = setindex!!(vnt, [10], @varname(c.x.y)) - @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y)) + @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y)] vnt = setindex!!(vnt, -1.0, @varname(d[4])) - @test keys(vnt) == (@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])) + @test keys(vnt) == [@varname(a), @varname(b), @varname(c.x.y), @varname(d[4])] vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @varname(d[4]), @varname(e.f[3, 3].g.h[2, 4, 1].i), - ) + ] vnt = setindex!!(vnt, fill(1.0, 4), @varname(j[1:4])) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -381,10 +393,10 @@ end @varname(j[2]), @varname(j[3]), @varname(j[4]), - ) + ] vnt = setindex!!(vnt, 1.0, @varname(j[6])) - @test keys(vnt) == ( + @test keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -395,10 +407,10 @@ end @varname(j[3]), @varname(j[4]), @varname(j[6]), - ) + ] vnt = setindex!!(vnt, 1.0, @varname(n[2].a)) - @test keys(vnt) == ( + @test 
keys(vnt) == [ @varname(a), @varname(b), @varname(c.x.y), @@ -410,7 +422,23 @@ end @varname(j[4]), @varname(j[6]), @varname(n[2].a), - ) + ] + + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(o[2:4, 5:5, 11:14])) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + @varname(o[2:4, 5:5, 11:14]), + ] end @testset "printing" begin @@ -458,6 +486,107 @@ end VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0),),)),))""" end + + @testset "block variables" begin + # Tests for setting and getting block variables, i.e. variables that have a non-zero + # size in a PartialArray, but are not Arrays themselves. + expected_err = ArgumentError(""" + A non-Array value set with a range of indices must be retrieved with the same + range of indices. + """) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, SizedThing((3,)), @varname(x[2:4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(x[2:4])) + @test @inferred(getindex(vnt, @varname(x[2:4]))) == SizedThing((3,)) + @test !haskey(vnt, @varname(x[2:3])) + @test_throws expected_err getindex(vnt, @varname(x[2:3])) + @test !haskey(vnt, @varname(x[3])) + @test_throws expected_err getindex(vnt, @varname(x[3])) + @test !haskey(vnt, @varname(x[1])) + @test !haskey(vnt, @varname(x[5])) + vnt = setindex!!(vnt, 1.0, @varname(x[1])) + vnt = setindex!!(vnt, 1.0, @varname(x[5])) + test_invariants(vnt) + @test haskey(vnt, @varname(x[1])) + @test haskey(vnt, @varname(x[5])) + @test_throws expected_err getindex(vnt, @varname(x[1:4])) + @test_throws expected_err getindex(vnt, @varname(x[2:5])) + + # Setting any of these indices should remove the block variable x[2:4]. + @testset "index = $index" for index in (2, 3, 4, 2:3, 3:5) + # Test setting different types of values. 
+ vals = if index isa Int + (2.0,) + else + (fill(2.0, length(index)), SizedThing((length(index),))) + end + @testset "val = $val" for val in vals + vn = @varname(x[index]) + vnt2 = copy(vnt) + vnt2 = setindex!!(vnt2, val, vn) + test_invariants(vnt2) + @test !haskey(vnt2, @varname(x[2:4])) + @test_throws BoundsError getindex(vnt2, @varname(x[2:4])) + other_index = index in (2, 2:3) ? 4 : 2 + @test !haskey(vnt2, @varname(x[other_index])) + @test_throws BoundsError getindex(vnt2, @varname(x[other_index])) + @test haskey(vnt2, vn) + @test getindex(vnt2, vn) == val + @test haskey(vnt2, @varname(x[1])) + @test_throws BoundsError getindex(vnt2, @varname(x[1:4])) + end + end + + # Extra checks, mostly for type stability and to confirm that multidimensional + # blocks work too. + val = SizedThing((2, 2)) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[1:2, 1:2]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[1:2, 1:2])) + @test @inferred(getindex(vnt, @varname(y.z[1:2, 1:2]))) == val + @test !haskey(vnt, @varname(y.z[1, 1])) + @test_throws expected_err getindex(vnt, @varname(y.z[1, 1])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2:3, 2:3]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2:3, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val + @test !haskey(vnt, @varname(y.z[1:2, 1:2])) + @test_throws BoundsError getindex(vnt, @varname(y.z[1:2, 1:2])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[4:5, 2:3]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2:3, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[2:3, 2:3]))) == val + @test haskey(vnt, @varname(y.z[4:5, 2:3])) + @test @inferred(getindex(vnt, @varname(y.z[4:5, 2:3]))) == val + + # A lot like above, but with extra indices that are not ranges. 
+ val = SizedThing((2, 2)) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2, 1:2, 3, 1:2, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4]))) == val + @test !haskey(vnt, @varname(y.z[2, 1, 3, 1, 4])) + @test_throws expected_err getindex(vnt, @varname(y.z[2, 1, 3, 1, 4])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[2, 2:3, 3, 2:3, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val + @test !haskey(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + @test_throws BoundsError getindex(vnt, @varname(y.z[2, 1:2, 3, 1:2, 4])) + + vnt = @inferred(setindex!!(vnt, val, @varname(y.z[3, 2:3, 3, 2:3, 4]))) + test_invariants(vnt) + @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val + @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) + @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val + end end end