Skip to content

Commit 3f416af

Browse files
Merge pull request #979 from codeflash-ai/comparator-fixes
Comparator support for tensorflow, dict views, slices, 0-d numpy arrays, torch dtypes
2 parents 45cdc62 + 5302762 commit 3f416af

File tree

2 files changed

+723
-2
lines changed

2 files changed

+723
-2
lines changed

codeflash/verification/comparator.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
HAS_TORCH = find_spec("torch") is not None
2424
HAS_JAX = find_spec("jax") is not None
2525
HAS_XARRAY = find_spec("xarray") is not None
26+
HAS_TENSORFLOW = find_spec("tensorflow") is not None
2627

2728

2829
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -57,6 +58,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
5758
enum.Enum,
5859
type,
5960
range,
61+
slice,
6062
OrderedDict,
6163
),
6264
):
@@ -97,6 +99,45 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
9799
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
98100
return orig.identical(new)
99101

102+
# Handle TensorFlow objects early to avoid boolean context errors
103+
if HAS_TENSORFLOW:
104+
import tensorflow as tf # type: ignore # noqa: PGH003
105+
106+
if isinstance(orig, tf.Tensor):
107+
if orig.dtype != new.dtype:
108+
return False
109+
if orig.shape != new.shape:
110+
return False
111+
# Use numpy conversion for proper NaN handling
112+
return comparator(orig.numpy(), new.numpy(), superset_obj)
113+
114+
if isinstance(orig, tf.Variable):
115+
if orig.dtype != new.dtype:
116+
return False
117+
if orig.shape != new.shape:
118+
return False
119+
return comparator(orig.numpy(), new.numpy(), superset_obj)
120+
121+
if isinstance(orig, tf.dtypes.DType):
122+
return orig == new
123+
124+
if isinstance(orig, tf.TensorShape):
125+
return orig == new
126+
127+
if isinstance(orig, tf.SparseTensor):
128+
if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
129+
return False
130+
return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator(
131+
orig.values.numpy(), new.values.numpy(), superset_obj
132+
)
133+
134+
if isinstance(orig, tf.RaggedTensor):
135+
if orig.dtype != new.dtype:
136+
return False
137+
if orig.shape.rank != new.shape.rank:
138+
return False
139+
return comparator(orig.to_list(), new.to_list(), superset_obj)
140+
100141
if HAS_SQLALCHEMY:
101142
import sqlalchemy # type: ignore # noqa: PGH003
102143

@@ -130,14 +171,41 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
130171
return False
131172
return True
132173

174+
# Handle dict view types (dict_keys, dict_values, dict_items)
175+
# Use type name checking since these are not directly importable types
176+
type_name = type(orig).__name__
177+
if type_name == "dict_keys":
178+
# dict_keys can be compared as sets (order doesn't matter)
179+
return comparator(set(orig), set(new))
180+
if type_name == "dict_values":
181+
# dict_values need element-wise comparison (order matters)
182+
return comparator(list(orig), list(new))
183+
if type_name == "dict_items":
184+
# Convert to dict for order-insensitive comparison (handles unhashable values)
185+
return comparator(dict(orig), dict(new), superset_obj)
186+
133187
if HAS_NUMPY:
134188
import numpy as np # type: ignore # noqa: PGH003
135189

190+
if isinstance(orig, (np.datetime64, np.timedelta64)):
191+
# Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
192+
if np.isnat(orig) and np.isnat(new):
193+
return True
194+
if np.isnat(orig) or np.isnat(new):
195+
return False
196+
return orig == new
197+
136198
if isinstance(orig, np.ndarray):
137199
if orig.dtype != new.dtype:
138200
return False
139201
if orig.shape != new.shape:
140202
return False
203+
# Handle 0-d arrays specially to avoid "iteration over a 0-d array" error
204+
if orig.ndim == 0:
205+
try:
206+
return np.allclose(orig, new, equal_nan=True)
207+
except Exception:
208+
return bool(orig == new)
141209
try:
142210
return np.allclose(orig, new, equal_nan=True)
143211
except Exception:
@@ -208,6 +276,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
208276
return False
209277
return torch.allclose(orig, new, equal_nan=True)
210278

279+
if isinstance(orig, torch.dtype):
280+
return orig == new
281+
211282
if HAS_PYRSISTENT:
212283
import pyrsistent # type: ignore # noqa: PGH003
213284

0 commit comments

Comments
 (0)