From 11f9e9e6f7bed2938edd5a40165c44869d4feef6 Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Thu, 18 Dec 2025 16:50:06 -0800 Subject: [PATCH 1/4] add support for torch.dtype, and dict views --- codeflash/verification/comparator.py | 16 ++++ tests/test_comparator.py | 109 ++++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 2 deletions(-) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index b752a0af7..7b2af7834 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -130,6 +130,19 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return True + # Handle dict view types (dict_keys, dict_values, dict_items) + # Use type name checking since these are not directly importable types + type_name = type(orig).__name__ + if type_name == "dict_keys": + # dict_keys can be compared as sets (order doesn't matter) + return comparator(set(orig), set(new)) + if type_name == "dict_values": + # dict_values need element-wise comparison (order matters) + return comparator(list(orig), list(new)) + if type_name == "dict_items": + # dict_items can be compared as sets of tuples (order doesn't matter for items) + return comparator(list(orig), list(new)) + if HAS_NUMPY: import numpy as np # type: ignore # noqa: PGH003 @@ -208,6 +221,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return torch.allclose(orig, new, equal_nan=True) + if isinstance(orig, torch.dtype): + return orig == new + if HAS_PYRSISTENT: import pyrsistent # type: ignore # noqa: PGH003 diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 06d178f95..5164456b8 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -624,6 +624,43 @@ class TestClass(PClass): assert not comparator(v, x) +def test_torch_dtype(): + try: + import torch # type: ignore + except ImportError: + pytest.skip() + + # Test torch.dtype comparisons + a = torch.float32 + b = torch.float32 + c = torch.float64 + d = torch.int32 + assert comparator(a, b) + assert not comparator(a, c) + assert not comparator(a, d) + + # Test different dtype categories + e = torch.int64 + f = torch.int64 + g = torch.int32 + assert comparator(e, f) + assert not comparator(e, g) + + # Test complex dtypes + h = torch.complex64 + i = torch.complex64 + j = torch.complex128 + assert comparator(h, i) + assert not comparator(h, j) + + # Test bool dtype + k = torch.bool + l = torch.bool + m = torch.int8 + assert comparator(k, l) + assert not comparator(k, m) + + def test_torch(): try: import torch # type: ignore @@ -1763,6 +1800,74 @@ class ExtendedClass: minimal = MinimalClass("test", 42) extended = ExtendedClass("test", 42, "extra", {"key": "value"}, 1000.0) - + assert not comparator(minimal, extended) - \ No newline at end of file + + +def test_dict_views() -> None: + """Test comparator support for dict_keys, dict_values, and dict_items.""" + # Test dict_keys + d1 = {"a": 1, "b": 2, "c": 3} + d2 = {"a": 1, "b": 2, "c": 3} + d3 = {"a": 1, "b": 2, "d": 3} + d4 = {"a": 1, "b": 2} + + # dict_keys - same keys + assert comparator(d1.keys(), d2.keys()) + # dict_keys - different keys + assert not comparator(d1.keys(), d3.keys()) + # dict_keys - different length + assert not comparator(d1.keys(), d4.keys()) + + # Test dict_values + v1 = {"a": 1, "b": 2, "c": 3} + v2 = {"x": 1, "y": 2, "z": 3} # same values, different keys + v3 = {"a": 1, "b": 2, "c": 4} # different value + v4 = {"a": 1, "b": 2} # different length + + # dict_values - same values (order matters for values since they're iterable) + assert comparator(v1.values(), v2.values()) + # dict_values - different values + assert not comparator(v1.values(), v3.values()) + # dict_values - different length + assert not comparator(v1.values(), v4.values()) + + # Test dict_items + i1 = {"a": 1, "b": 2, "c": 3} + i2 = {"a": 1, "b": 2, "c": 3} + i3 = {"a": 1, "b": 2, "c": 4} # different value + i4 = {"a": 1, "b": 2, "d": 3} # different key + i5 = {"a": 1, "b": 2} # different length + + # dict_items - same items + assert comparator(i1.items(), i2.items()) + # dict_items - different value + assert not comparator(i1.items(), i3.items()) + # dict_items - different key + assert not comparator(i1.items(), i4.items()) + # dict_items - different length + assert not comparator(i1.items(), i5.items()) + + # Test empty dicts + empty1 = {} + empty2 = {} + assert comparator(empty1.keys(), empty2.keys()) + assert comparator(empty1.values(), empty2.values()) + assert comparator(empty1.items(), empty2.items()) + + # Test with nested values + nested1 = {"a": [1, 2, 3], "b": {"x": 1}} + nested2 = {"a": [1, 2, 3], "b": {"x": 1}} + nested3 = {"a": [1, 2, 4], "b": {"x": 1}} + + assert comparator(nested1.values(), nested2.values()) + assert not comparator(nested1.values(), nested3.values()) + assert comparator(nested1.items(), nested2.items()) + assert not comparator(nested1.items(), nested3.items()) + + # Test that dict views are not equal to lists/sets + d = {"a": 1, "b": 2} + assert not comparator(d.keys(), ["a", "b"]) + assert not comparator(d.keys(), {"a", "b"}) + assert not comparator(d.values(), [1, 2]) + assert not comparator(d.items(), [("a", 1), ("b", 2)]) \ No newline at end of file From 524750873227072b3eb908798690f7ba78f55670 Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Thu, 18 Dec 2025 17:15:20 -0800 Subject: [PATCH 2/4] add support for tensorflow --- codeflash/verification/comparator.py | 41 ++++ tests/test_comparator.py | 336 ++++++++++++++++++++++++++- 2 files changed, 376 insertions(+), 1 deletion(-) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 7b2af7834..1b934e777 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -23,6 +23,7 @@ HAS_TORCH = find_spec("torch") is not None HAS_JAX = find_spec("jax") is not None HAS_XARRAY = find_spec("xarray") is not None +HAS_TENSORFLOW = find_spec("tensorflow") is not None def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 @@ -97,6 +98,46 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if isinstance(orig, (xarray.Dataset, xarray.DataArray)): return orig.identical(new) + # Handle TensorFlow objects early to avoid boolean context errors + if HAS_TENSORFLOW: + import tensorflow as tf # type: ignore # noqa: PGH003 + + if isinstance(orig, tf.Tensor): + if orig.dtype != new.dtype: + return False + if orig.shape != new.shape: + return False + # Use numpy conversion for proper NaN handling + return comparator(orig.numpy(), new.numpy(), superset_obj) + + if isinstance(orig, tf.Variable): + if orig.dtype != new.dtype: + return False + if orig.shape != new.shape: + return False + return comparator(orig.numpy(), new.numpy(), superset_obj) + + if isinstance(orig, tf.dtypes.DType): + return orig == new + + if isinstance(orig, tf.TensorShape): + return orig == new + + if isinstance(orig, tf.SparseTensor): + if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj): + return False + return ( + comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and + comparator(orig.values.numpy(), new.values.numpy(), superset_obj) + ) + + if isinstance(orig, tf.RaggedTensor): + if orig.dtype != new.dtype: + return False + if orig.shape.rank != new.shape.rank: + return False + return comparator(orig.to_list(), new.to_list(), superset_obj) + if HAS_SQLALCHEMY: import sqlalchemy # type: ignore # noqa: PGH003 diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 5164456b8..e233b61d5 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -1870,4 +1870,338 @@ def test_dict_views() -> None: assert not comparator(d.keys(), ["a", "b"]) assert not comparator(d.keys(), {"a", "b"}) assert not comparator(d.values(), [1, 2]) - assert not comparator(d.items(), [("a", 1), ("b", 2)]) \ No newline at end of file + assert not comparator(d.items(), [("a", 1), ("b", 2)]) + + +def test_tensorflow_tensor() -> None: + """Test comparator support for TensorFlow Tensor objects.""" + try: + import tensorflow as tf + except ImportError: + pytest.skip("tensorflow required for this test") + + # Test basic 1D tensors + a = tf.constant([1, 2, 3]) + b = tf.constant([1, 2, 3]) + c = tf.constant([1, 2, 4]) + + assert comparator(a, b) + assert not comparator(a, c) + + # Test 2D tensors + d = tf.constant([[1, 2, 3], [4, 5, 6]]) + e = tf.constant([[1, 2, 3], [4, 5, 6]]) + f = tf.constant([[1, 2, 3], [4, 5, 7]]) + + assert comparator(d, e) + assert not comparator(d, f) + + # Test tensors with different shapes + g = tf.constant([1, 2, 3]) + h = tf.constant([[1, 2, 3]]) + + assert not comparator(g, h) + + # Test tensors with different dtypes + i = tf.constant([1, 2, 3], dtype=tf.float32) + j = tf.constant([1, 2, 3], dtype=tf.float32) + k = tf.constant([1, 2, 3], dtype=tf.int32) + + assert comparator(i, j) + assert not comparator(i, k) + + # Test 3D tensors + l = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + m = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + n = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]) + + assert comparator(l, m) + assert not comparator(l, n) + + # Test empty tensors + o = tf.constant([]) + p = tf.constant([]) + q = tf.constant([1.0]) + + assert comparator(o, p) + assert not comparator(o, q) + + # Test tensors with NaN values + r = tf.constant([1.0, float('nan'), 3.0]) + s = tf.constant([1.0, float('nan'), 3.0]) + t = tf.constant([1.0, 2.0, 3.0]) + + assert comparator(r, s) # NaN == NaN should be True + assert not comparator(r, t) + + # Test tensors with infinity values + u = tf.constant([1.0, float('inf'), 3.0]) + v = tf.constant([1.0, float('inf'), 3.0]) + w = tf.constant([1.0, float('-inf'), 3.0]) + + assert comparator(u, v) + assert not comparator(u, w) + + # Test complex tensors + x = tf.constant([1+2j, 3+4j]) + y = tf.constant([1+2j, 3+4j]) + z = tf.constant([1+2j, 3+5j]) + + assert comparator(x, y) + assert not comparator(x, z) + + # Test boolean tensors + aa = tf.constant([True, False, True]) + bb = tf.constant([True, False, True]) + cc = tf.constant([True, True, True]) + + assert comparator(aa, bb) + assert not comparator(aa, cc) + + # Test string tensors + dd = tf.constant(["hello", "world"]) + ee = tf.constant(["hello", "world"]) + ff = tf.constant(["hello", "there"]) + + assert comparator(dd, ee) + assert not comparator(dd, ff) + + +def test_tensorflow_dtype() -> None: + """Test comparator support for TensorFlow DType objects.""" + try: + import tensorflow as tf + except ImportError: + pytest.skip("tensorflow required for this test") + + # Test float dtypes + a = tf.float32 + b = tf.float32 + c = tf.float64 + + assert comparator(a, b) + assert not comparator(a, c) + + # Test integer dtypes + d = tf.int32 + e = tf.int32 + f = tf.int64 + + assert comparator(d, e) + assert not comparator(d, f) + + # Test unsigned integer dtypes + g = tf.uint8 + h = tf.uint8 + i = tf.uint16 + + assert comparator(g, h) + assert not comparator(g, i) + + # Test complex dtypes + j = tf.complex64 + k = tf.complex64 + l = tf.complex128 + + assert comparator(j, k) + assert not comparator(j, l) + + # Test bool dtype + m = tf.bool + n = tf.bool + o = tf.int8 + + assert comparator(m, n) + assert not comparator(m, o) + + # Test string dtype + p = tf.string + q = tf.string + r = tf.int32 + + assert comparator(p, q) + assert not comparator(p, r) + + +def test_tensorflow_variable() -> None: + """Test comparator support for TensorFlow Variable objects.""" + try: + import tensorflow as tf + except ImportError: + pytest.skip("tensorflow required for this test") + + # Test basic variables + a = tf.Variable([1, 2, 3], dtype=tf.float32) + b = tf.Variable([1, 2, 3], dtype=tf.float32) + c = tf.Variable([1, 2, 4], dtype=tf.float32) + + assert comparator(a, b) + assert not comparator(a, c) + + # Test variables with different dtypes + d = tf.Variable([1, 2, 3], dtype=tf.float32) + e = tf.Variable([1, 2, 3], dtype=tf.float64) + + assert not comparator(d, e) + + # Test 2D variables + f = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32) + g = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32) + h = tf.Variable([[1, 2], [3, 5]], dtype=tf.float32) + + assert comparator(f, g) + assert not comparator(f, h) + + # Test variables with different shapes + i = tf.Variable([1, 2, 3], dtype=tf.float32) + j = tf.Variable([[1, 2, 3]], dtype=tf.float32) + + assert not comparator(i, j) + + +def test_tensorflow_tensor_shape() -> None: + """Test comparator support for TensorFlow TensorShape objects.""" + try: + import tensorflow as tf + except ImportError: + pytest.skip("tensorflow required for this test") + + # Test equal shapes + a = tf.TensorShape([2, 3, 4]) + b = tf.TensorShape([2, 3, 4]) + c = tf.TensorShape([2, 3, 5]) + + assert comparator(a, b) + assert not comparator(a, c) + + # Test different ranks + d = tf.TensorShape([2, 3]) + e = tf.TensorShape([2, 3, 4]) + + assert not comparator(d, e) + + # Test scalar shapes + f = tf.TensorShape([]) + g = tf.TensorShape([]) + h = tf.TensorShape([1]) + + assert comparator(f, g) + assert not comparator(f, h) + + # Test shapes with None dimensions (unknown dimensions) + i = tf.TensorShape([None, 3, 4]) + j = tf.TensorShape([None, 3, 4]) + k = tf.TensorShape([2, 3, 4]) + + assert comparator(i, j) + assert not comparator(i, k) + + # Test fully unknown shapes + l = tf.TensorShape(None) + m = tf.TensorShape(None) + n = tf.TensorShape([1, 2]) + + assert comparator(l, m) + assert not comparator(l, n) + + +def test_tensorflow_sparse_tensor() -> None: + """Test comparator support for TensorFlow SparseTensor objects.""" + try: + import tensorflow as tf + except ImportError: + pytest.skip("tensorflow required for this test") + + # Test equal sparse tensors + a = tf.SparseTensor( + indices=[[0, 0], [1, 2]], + values=[1.0, 2.0], + dense_shape=[3, 4] + ) + b = tf.SparseTensor( + indices=[[0, 0], [1, 2]], + values=[1.0, 2.0], + dense_shape=[3, 4] + ) + c = tf.SparseTensor( + indices=[[0, 0], [1, 2]], + values=[1.0, 3.0], # Different value + dense_shape=[3, 4] + ) + + assert comparator(a, b) + assert not comparator(a, c) + + # Test sparse tensors with different indices + d = tf.SparseTensor( + indices=[[0, 0], [1, 3]], # Different index + values=[1.0, 2.0], + dense_shape=[3, 4] + ) + + assert not comparator(a, d) + + # Test sparse tensors with different shapes + e = tf.SparseTensor( + indices=[[0, 0], [1, 2]], + values=[1.0, 2.0], + dense_shape=[4, 5] # Different shape + ) + + assert not comparator(a, e) + + # Test empty sparse tensors + f = tf.SparseTensor( + indices=tf.zeros([0, 2], dtype=tf.int64), + values=[], + dense_shape=[3, 4] + ) + g = tf.SparseTensor( + indices=tf.zeros([0, 2], dtype=tf.int64), + values=[], + dense_shape=[3, 4] + ) + + assert comparator(f, g) + + +def test_tensorflow_ragged_tensor() -> None: + """Test comparator support for TensorFlow RaggedTensor objects.""" + try: + import tensorflow as tf + except ImportError: + pytest.skip("tensorflow required for this test") + + # Test equal ragged tensors + a = tf.ragged.constant([[1, 2], [3, 4, 5], [6]]) + b = tf.ragged.constant([[1, 2], [3, 4, 5], [6]]) + c = tf.ragged.constant([[1, 2], [3, 4, 6], [6]]) # Different value + + assert comparator(a, b) + assert not comparator(a, c) + + # Test ragged tensors with different row lengths + d = tf.ragged.constant([[1, 2, 3], [4, 5], [6]]) # Different structure + + assert not comparator(a, d) + + # Test ragged tensors with different dtypes + e = tf.ragged.constant([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]) + f = tf.ragged.constant([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]) + + assert comparator(e, f) + assert not comparator(a, e) # int vs float + + # Test nested ragged tensors + g = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]]) + h = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]]) + i = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 7]]]) + + assert comparator(g, h) + assert not comparator(g, i) + + # Test empty ragged tensors + j = tf.ragged.constant([[], [], []]) + k = tf.ragged.constant([[], [], []]) + + assert comparator(j, k) \ No newline at end of file From f863c7860d7a6a905bfbad5609ff43e861a9890a Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Thu, 18 Dec 2025 17:24:36 -0800 Subject: [PATCH 3/4] add support for slice, and more numpy types and errors --- codeflash/verification/comparator.py | 15 ++ tests/test_comparator.py | 210 ++++++++++++++++++++++++++- 2 files changed, 224 insertions(+), 1 deletion(-) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 1b934e777..5cf9a49c1 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -58,6 +58,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 enum.Enum, type, range, + slice, OrderedDict, ), ): @@ -187,11 +188,25 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if HAS_NUMPY: import numpy as np # type: ignore # noqa: PGH003 + if isinstance(orig, (np.datetime64, np.timedelta64)): + # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime + if np.isnat(orig) and np.isnat(new): + return True + if np.isnat(orig) or np.isnat(new): + return False + return orig == new + if isinstance(orig, np.ndarray): if orig.dtype != new.dtype: return False if orig.shape != new.shape: return False + # Handle 0-d arrays specially to avoid "iteration over a 0-d array" error + if orig.ndim == 0: + try: + return np.allclose(orig, new, equal_nan=True) + except Exception: + return bool(orig == new) try: return np.allclose(orig, new, equal_nan=True) except Exception: diff --git a/tests/test_comparator.py b/tests/test_comparator.py index e233b61d5..19344db76 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -2204,4 +2204,212 @@ def test_tensorflow_ragged_tensor() -> None: j = tf.ragged.constant([[], [], []]) k = tf.ragged.constant([[], [], []]) - assert comparator(j, k) \ No newline at end of file + assert comparator(j, k) + + +def test_slice() -> None: + """Test comparator support for slice objects.""" + # Test equal slices + a = slice(1, 10, 2) + b = slice(1, 10, 2) + assert comparator(a, b) + + # Test slices with different start + c = slice(2, 10, 2) + assert not comparator(a, c) + + # Test slices with different stop + d = slice(1, 11, 2) + assert not comparator(a, d) + + # Test slices with different step + e = slice(1, 10, 3) + assert not comparator(a, e) + + # Test slices with None values + f = slice(None, 10, 2) + g = slice(None, 10, 2) + h = slice(1, 10, 2) + assert comparator(f, g) + assert not comparator(f, h) + + # Test slices with all None (equivalent to [:]) + i = slice(None, None, None) + j = slice(None, None, None) + k = slice(None, None, 1) + assert comparator(i, j) + assert not comparator(i, k) + + # Test slices with only stop + l = slice(5) + m = slice(5) + n = slice(6) + assert comparator(l, m) + assert not comparator(l, n) + + # Test slices with negative values + o = slice(-5, -1, 1) + p = slice(-5, -1, 1) + q = slice(-5, -2, 1) + assert comparator(o, p) + assert not comparator(o, q) + + # Test slice is not equal to other types + r = slice(1, 10) + s = (1, 10) + assert not comparator(r, s) + + +def test_numpy_datetime64() -> None: + """Test comparator support for numpy datetime64 and timedelta64 types.""" + try: + import numpy as np + except ImportError: + pytest.skip("numpy required for this test") + + # Test datetime64 equality + a = np.datetime64('2021-01-01') + b = np.datetime64('2021-01-01') + c = np.datetime64('2021-01-02') + + assert comparator(a, b) + assert not comparator(a, c) + + # Test datetime64 with different units + d = np.datetime64('2021-01-01', 'D') + e = np.datetime64('2021-01-01', 'D') + f = np.datetime64('2021-01-01', 's') # Different unit (seconds) + + assert comparator(d, e) + # Note: datetime64 with different units but same moment may or may not be equal + # depending on numpy version behavior + + # Test datetime64 with time + g = np.datetime64('2021-01-01T12:00:00') + h = np.datetime64('2021-01-01T12:00:00') + i = np.datetime64('2021-01-01T12:00:01') + + assert comparator(g, h) + assert not comparator(g, i) + + # Test timedelta64 equality + j = np.timedelta64(1, 'D') + k = np.timedelta64(1, 'D') + l = np.timedelta64(2, 'D') + + assert comparator(j, k) + assert not comparator(j, l) + + # Test timedelta64 with different units + m = np.timedelta64(1, 'h') + n = np.timedelta64(1, 'h') + o = np.timedelta64(60, 'm') # Same duration, different unit + + assert comparator(m, n) + # 1 hour == 60 minutes, but they have different units + # numpy may treat them as equal or not depending on comparison + + # Test NaT (Not a Time) - numpy's equivalent of NaN for datetime + p = np.datetime64('NaT') + q = np.datetime64('NaT') + r = np.datetime64('2021-01-01') + + assert comparator(p, q) # NaT == NaT should be True + assert not comparator(p, r) + + # Test timedelta64 NaT + s = np.timedelta64('NaT') + t = np.timedelta64('NaT') + u = np.timedelta64(1, 'D') + + assert comparator(s, t) # NaT == NaT should be True + assert not comparator(s, u) + + # Test datetime64 is not equal to other types + v = np.datetime64('2021-01-01') + w = '2021-01-01' + assert not comparator(v, w) + + # Test arrays of datetime64 + x = np.array(['2021-01-01', '2021-01-02'], dtype='datetime64') + y = np.array(['2021-01-01', '2021-01-02'], dtype='datetime64') + z = np.array(['2021-01-01', '2021-01-03'], dtype='datetime64') + + assert comparator(x, y) + assert not comparator(x, z) + + +def test_numpy_0d_array() -> None: + """Test comparator handles 0-d numpy arrays without 'iteration over 0-d array' error.""" + try: + import numpy as np + except ImportError: + pytest.skip("numpy required for this test") + + # Test 0-d integer array + a = np.array(5) + b = np.array(5) + c = np.array(6) + + assert comparator(a, b) + assert not comparator(a, c) + + # Test 0-d float array + d = np.array(3.14) + e = np.array(3.14) + f = np.array(2.71) + + assert comparator(d, e) + assert not comparator(d, f) + + # Test 0-d complex array + g = np.array(1+2j) + h = np.array(1+2j) + i = np.array(1+3j) + + assert comparator(g, h) + assert not comparator(g, i) + + # Test 0-d string array + j = np.array('hello') + k = np.array('hello') + l = np.array('world') + + assert comparator(j, k) + assert not comparator(j, l) + + # Test 0-d boolean array + m = np.array(True) + n = np.array(True) + o = np.array(False) + + assert comparator(m, n) + assert not comparator(m, o) + + # Test 0-d array with NaN + p = np.array(np.nan) + q = np.array(np.nan) + r = np.array(1.0) + + assert comparator(p, q) # NaN == NaN should be True + assert not comparator(p, r) + + # Test 0-d datetime64 array + s = np.array(np.datetime64('2021-01-01')) + t = np.array(np.datetime64('2021-01-01')) + u = np.array(np.datetime64('2021-01-02')) + + assert comparator(s, t) + assert not comparator(s, u) + + # Test 0-d array vs scalar + v = np.array(5) + w = 5 + # 0-d array and scalar are different types + assert not comparator(v, w) + + # Test 0-d array vs 1-d array with one element + x = np.array(5) + y = np.array([5]) + # Different shapes + assert not comparator(x, y) \ No newline at end of file From 53027625467c55f75ab1ba5e19a97ccad2e6a6fd Mon Sep 17 00:00:00 2001 From: misrasaurabh1 Date: Thu, 18 Dec 2025 17:40:57 -0800 Subject: [PATCH 4/4] add a fix --- codeflash/verification/comparator.py | 9 ++++----- tests/test_comparator.py | 3 +++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 5cf9a49c1..7737900df 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -127,9 +127,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if isinstance(orig, tf.SparseTensor): if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj): return False - return ( - comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and - comparator(orig.values.numpy(), new.values.numpy(), superset_obj) + return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator( + orig.values.numpy(), new.values.numpy(), superset_obj ) if isinstance(orig, tf.RaggedTensor): @@ -182,8 +181,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 # dict_values need element-wise comparison (order matters) return comparator(list(orig), list(new)) if type_name == "dict_items": - # dict_items can be compared as sets of tuples (order doesn't matter for items) - return comparator(list(orig), list(new)) + # Convert to dict for order-insensitive comparison (handles unhashable values) + return comparator(dict(orig), dict(new), superset_obj) if HAS_NUMPY: import numpy as np # type: ignore # noqa: PGH003 diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 19344db76..3a7304508 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -1838,6 +1838,7 @@ def test_dict_views() -> None: i3 = {"a": 1, "b": 2, "c": 4} # different value i4 = {"a": 1, "b": 2, "d": 3} # different key i5 = {"a": 1, "b": 2} # different length + i6 = {"b": 2, "c": 3, "a": 1} # different order # dict_items - same items assert comparator(i1.items(), i2.items()) @@ -1848,6 +1849,8 @@ def test_dict_views() -> None: # dict_items - different length assert not comparator(i1.items(), i5.items()) + assert comparator(i1.items(), i6.items()) + # Test empty dicts empty1 = {} empty2 = {}