Skip to content

Commit 11f9e9e

Browse files
committed
add support for torch.dtype, and dict views
1 parent 45cdc62 commit 11f9e9e

File tree

2 files changed

+123
-2
lines changed

2 files changed

+123
-2
lines changed

codeflash/verification/comparator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,19 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
130130
return False
131131
return True
132132

133+
# Handle dict view types (dict_keys, dict_values, dict_items)
134+
# Use type name checking since these are not directly importable types
135+
type_name = type(orig).__name__
136+
if type_name == "dict_keys":
137+
# dict_keys can be compared as sets (order doesn't matter)
138+
return comparator(set(orig), set(new))
139+
if type_name == "dict_values":
140+
# dict_values need element-wise comparison (order matters)
141+
return comparator(list(orig), list(new))
142+
if type_name == "dict_items":
143+
# dict_items can be compared as sets of tuples (order doesn't matter for items)
144+
return comparator(list(orig), list(new))
145+
133146
if HAS_NUMPY:
134147
import numpy as np # type: ignore # noqa: PGH003
135148

@@ -208,6 +221,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
208221
return False
209222
return torch.allclose(orig, new, equal_nan=True)
210223

224+
if isinstance(orig, torch.dtype):
225+
return orig == new
226+
211227
if HAS_PYRSISTENT:
212228
import pyrsistent # type: ignore # noqa: PGH003
213229

tests/test_comparator.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,43 @@ class TestClass(PClass):
624624
assert not comparator(v, x)
625625

626626

627+
def test_torch_dtype():
628+
try:
629+
import torch # type: ignore
630+
except ImportError:
631+
pytest.skip()
632+
633+
# Test torch.dtype comparisons
634+
a = torch.float32
635+
b = torch.float32
636+
c = torch.float64
637+
d = torch.int32
638+
assert comparator(a, b)
639+
assert not comparator(a, c)
640+
assert not comparator(a, d)
641+
642+
# Test different dtype categories
643+
e = torch.int64
644+
f = torch.int64
645+
g = torch.int32
646+
assert comparator(e, f)
647+
assert not comparator(e, g)
648+
649+
# Test complex dtypes
650+
h = torch.complex64
651+
i = torch.complex64
652+
j = torch.complex128
653+
assert comparator(h, i)
654+
assert not comparator(h, j)
655+
656+
# Test bool dtype
657+
k = torch.bool
658+
l = torch.bool
659+
m = torch.int8
660+
assert comparator(k, l)
661+
assert not comparator(k, m)
662+
663+
627664
def test_torch():
628665
try:
629666
import torch # type: ignore
@@ -1763,6 +1800,74 @@ class ExtendedClass:
17631800

17641801
minimal = MinimalClass("test", 42)
17651802
extended = ExtendedClass("test", 42, "extra", {"key": "value"}, 1000.0)
1766-
1803+
17671804
assert not comparator(minimal, extended)
1768-
1805+
1806+
1807+
def test_dict_views() -> None:
1808+
"""Test comparator support for dict_keys, dict_values, and dict_items."""
1809+
# Test dict_keys
1810+
d1 = {"a": 1, "b": 2, "c": 3}
1811+
d2 = {"a": 1, "b": 2, "c": 3}
1812+
d3 = {"a": 1, "b": 2, "d": 3}
1813+
d4 = {"a": 1, "b": 2}
1814+
1815+
# dict_keys - same keys
1816+
assert comparator(d1.keys(), d2.keys())
1817+
# dict_keys - different keys
1818+
assert not comparator(d1.keys(), d3.keys())
1819+
# dict_keys - different length
1820+
assert not comparator(d1.keys(), d4.keys())
1821+
1822+
# Test dict_values
1823+
v1 = {"a": 1, "b": 2, "c": 3}
1824+
v2 = {"x": 1, "y": 2, "z": 3} # same values, different keys
1825+
v3 = {"a": 1, "b": 2, "c": 4} # different value
1826+
v4 = {"a": 1, "b": 2} # different length
1827+
1828+
# dict_values - same values (order matters for values since they're iterable)
1829+
assert comparator(v1.values(), v2.values())
1830+
# dict_values - different values
1831+
assert not comparator(v1.values(), v3.values())
1832+
# dict_values - different length
1833+
assert not comparator(v1.values(), v4.values())
1834+
1835+
# Test dict_items
1836+
i1 = {"a": 1, "b": 2, "c": 3}
1837+
i2 = {"a": 1, "b": 2, "c": 3}
1838+
i3 = {"a": 1, "b": 2, "c": 4} # different value
1839+
i4 = {"a": 1, "b": 2, "d": 3} # different key
1840+
i5 = {"a": 1, "b": 2} # different length
1841+
1842+
# dict_items - same items
1843+
assert comparator(i1.items(), i2.items())
1844+
# dict_items - different value
1845+
assert not comparator(i1.items(), i3.items())
1846+
# dict_items - different key
1847+
assert not comparator(i1.items(), i4.items())
1848+
# dict_items - different length
1849+
assert not comparator(i1.items(), i5.items())
1850+
1851+
# Test empty dicts
1852+
empty1 = {}
1853+
empty2 = {}
1854+
assert comparator(empty1.keys(), empty2.keys())
1855+
assert comparator(empty1.values(), empty2.values())
1856+
assert comparator(empty1.items(), empty2.items())
1857+
1858+
# Test with nested values
1859+
nested1 = {"a": [1, 2, 3], "b": {"x": 1}}
1860+
nested2 = {"a": [1, 2, 3], "b": {"x": 1}}
1861+
nested3 = {"a": [1, 2, 4], "b": {"x": 1}}
1862+
1863+
assert comparator(nested1.values(), nested2.values())
1864+
assert not comparator(nested1.values(), nested3.values())
1865+
assert comparator(nested1.items(), nested2.items())
1866+
assert not comparator(nested1.items(), nested3.items())
1867+
1868+
# Test that dict views are not equal to lists/sets
1869+
d = {"a": 1, "b": 2}
1870+
assert not comparator(d.keys(), ["a", "b"])
1871+
assert not comparator(d.keys(), {"a", "b"})
1872+
assert not comparator(d.values(), [1, 2])
1873+
assert not comparator(d.items(), [("a", 1), ("b", 2)])

0 commit comments

Comments
 (0)