Skip to content

Commit 3c02da4

Browse files
committed
Improve equality tests
1 parent c818bf8 commit 3c02da4

File tree

2 files changed

+70
-35
lines changed

2 files changed

+70
-35
lines changed

src/varnamedtuple.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,44 @@ function Base.:(==)(pa1::PartialArray, pa2::PartialArray)
208208
# TODO(mhauru) This could be optimised by not calling checkbounds on all elements
209209
# outside the size of an array, but not sure it's worth it.
210210
merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1))
211+
result = true
211212
for i in CartesianIndices(merge_size)
212213
m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false
213214
m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false
214215
if m1 != m2
215216
return false
216217
end
217-
if m1 && (pa1.data[i] != pa2.data[i])
218+
if m1
219+
elements_equal = pa1.data[i] == pa2.data[i]
220+
if elements_equal === false
221+
return false
222+
elseif elements_equal === missing
223+
# This branch can't short-circuit and just return missing, because some
224+
# later values may be straight-up not equal.
225+
result = missing
226+
end
227+
end
228+
end
229+
return result
230+
end
231+
232+
# Exactly as == above, except the comparison of the data elements uses isequal.
233+
function Base.isequal(pa1::PartialArray, pa2::PartialArray)
234+
if ndims(pa1) != ndims(pa2)
235+
return false
236+
end
237+
size1 = _internal_size(pa1)
238+
size2 = _internal_size(pa2)
239+
# TODO(mhauru) This could be optimised by not calling checkbounds on all elements
240+
# outside the size of an array, but not sure it's worth it.
241+
merge_size = ntuple(i -> max(size1[i], size2[i]), ndims(pa1))
242+
for i in CartesianIndices(merge_size)
243+
m1 = checkbounds(Bool, pa1.mask, Tuple(i)...) ? pa1.mask[i] : false
244+
m2 = checkbounds(Bool, pa2.mask, Tuple(i)...) ? pa2.mask[i] : false
245+
if m1 != m2
246+
return false
247+
end
248+
if m1 && !isequal(pa1.data[i], pa2.data[i])
218249
return false
219250
end
220251
end
@@ -497,6 +528,7 @@ end
497528
VarNamedTuple(; kwargs...) = VarNamedTuple((; kwargs...))
498529

499530
Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data
531+
Base.isequal(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = isequal(vnt1.data, vnt2.data)
500532
Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h)
501533

502534
function Base.show(io::IO, vnt::VarNamedTuple)

test/varnamedtuple.jl

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module VarNamedTupleTests
22

3+
using Combinatorics: Combinatorics
34
using Test: @inferred, @test, @test_throws, @testset
45
using DynamicPPL: DynamicPPL, @varname, VarNamedTuple
56
using DynamicPPL.VarNamedTuples: PartialArray
@@ -20,11 +21,13 @@ function test_invariants(vnt::VarNamedTuple)
2021
v = getindex(vnt, k)
2122
vnt2 = setindex!!(copy(vnt), v, k)
2223
@test vnt == vnt2
24+
@test isequal(vnt, vnt2)
2325
@test hash(vnt) == hash(vnt2)
2426
end
2527
# Check that the printed representation can be parsed back to an equal VarNamedTuple.
2628
vnt3 = eval(Meta.parse(repr(vnt)))
2729
@test vnt == vnt3
30+
@test isequal(vnt, vnt3)
2831
@test hash(vnt) == hash(vnt3)
2932
# Check that merge with an empty VarNamedTuple is a no-op.
3033
@test merge(vnt, VarNamedTuple()) == vnt
@@ -218,40 +221,40 @@ end
218221
test_invariants(vnt)
219222
end
220223

221-
@testset "equality" begin
222-
vnt1 = VarNamedTuple()
223-
vnt2 = VarNamedTuple()
224-
@test vnt1 == vnt2
225-
226-
vnt1 = setindex!!(vnt1, 1.0, @varname(a))
227-
@test vnt1 != vnt2
228-
229-
vnt2 = setindex!!(vnt2, 1.0, @varname(a))
230-
@test vnt1 == vnt2
231-
232-
vnt1 = setindex!!(vnt1, [1, 2], @varname(b))
233-
vnt2 = setindex!!(vnt2, [1, 2], @varname(b))
234-
@test vnt1 == vnt2
235-
236-
vnt2 = setindex!!(vnt2, [1, 3], @varname(b))
237-
@test vnt1 != vnt2
238-
vnt2 = setindex!!(vnt2, [1, 2], @varname(b))
239-
240-
# Try with index lenses too
241-
vnt1 = setindex!!(vnt1, 2, @varname(c[2]))
242-
vnt2 = setindex!!(vnt2, 2, @varname(c[2]))
243-
@test vnt1 == vnt2
244-
245-
vnt2 = setindex!!(vnt2, 3, @varname(c[2]))
246-
@test vnt1 != vnt2
247-
vnt2 = setindex!!(vnt2, 2, @varname(c[2]))
248-
249-
vnt1 = setindex!!(vnt1, ["a", "b"], @varname(d.e[1:2]))
250-
vnt2 = setindex!!(vnt2, ["a", "b"], @varname(d.e[1:2]))
251-
@test vnt1 == vnt2
252-
253-
vnt2 = setindex!!(vnt2, :b, @varname(d.e[2]))
254-
@test vnt1 != vnt2
224+
@testset "equality and hash" begin
225+
# Test all combinations of having or not having the below values set, and having
226+
# them set to any of the possible_values, and check that isequal and == return the
227+
# expected value.
228+
# NOTE: Be very careful adding new values to these sets. The below test has three
229+
# nested loops over Combinatorics.combinations, the run time can explode very, very
230+
# quickly.
231+
varnames = (@varname(b[1]), @varname(b[3]), @varname(c.d[2].e))
232+
possible_values = (missing, 1, -0.0, 0.0)
233+
for vn_set in Combinatorics.combinations(varnames)
234+
valuesets1 = Combinatorics.with_replacement_combinations(
235+
possible_values, length(vn_set)
236+
)
237+
valuesets2 = Combinatorics.with_replacement_combinations(
238+
possible_values, length(vn_set)
239+
)
240+
for vset1 in valuesets1, vset2 in valuesets2
241+
vnt1 = VarNamedTuple()
242+
vnt2 = VarNamedTuple()
243+
expected_isequal = true
244+
expected_doubleequal = true
245+
for (vn, v1, v2) in zip(vn_set, vset1, vset2)
246+
vnt1 = setindex!!(vnt1, v1, vn)
247+
vnt2 = setindex!!(vnt2, v2, vn)
248+
expected_isequal = expected_isequal & isequal(v1, v2)
249+
expected_doubleequal = expected_doubleequal & (v1 == v2)
250+
end
251+
@test isequal(vnt1, vnt2) == expected_isequal
252+
@test (vnt1 == vnt2) === expected_doubleequal
253+
if expected_isequal
254+
@test hash(vnt1) == hash(vnt2)
255+
end
256+
end
257+
end
255258
end
256259

257260
@testset "merge" begin

0 commit comments

Comments
 (0)