Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# Optional-backend availability flags. Each package is probed with find_spec,
# which locates the module without importing it, so a missing library only
# yields a False flag here.
HAS_TORCH, HAS_JAX, HAS_XARRAY, HAS_TENSORFLOW = (
    find_spec(package) is not None
    for package in ("torch", "jax", "xarray", "tensorflow")
)


def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
Expand Down Expand Up @@ -57,6 +58,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
enum.Enum,
type,
range,
slice,
OrderedDict,
),
):
Expand Down Expand Up @@ -97,6 +99,45 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
return orig.identical(new)

# Handle TensorFlow objects early to avoid boolean context errors
if HAS_TENSORFLOW:
import tensorflow as tf # type: ignore # noqa: PGH003

if isinstance(orig, tf.Tensor):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
# Use numpy conversion for proper NaN handling
return comparator(orig.numpy(), new.numpy(), superset_obj)

if isinstance(orig, tf.Variable):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
return comparator(orig.numpy(), new.numpy(), superset_obj)

if isinstance(orig, tf.dtypes.DType):
return orig == new

if isinstance(orig, tf.TensorShape):
return orig == new

if isinstance(orig, tf.SparseTensor):
if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
return False
return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator(
orig.values.numpy(), new.values.numpy(), superset_obj
)

if isinstance(orig, tf.RaggedTensor):
if orig.dtype != new.dtype:
return False
if orig.shape.rank != new.shape.rank:
return False
return comparator(orig.to_list(), new.to_list(), superset_obj)

if HAS_SQLALCHEMY:
import sqlalchemy # type: ignore # noqa: PGH003

Expand Down Expand Up @@ -130,14 +171,41 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return False
return True

# Handle dict view types (dict_keys, dict_values, dict_items)
# Use type name checking since these are not directly importable types
type_name = type(orig).__name__
if type_name == "dict_keys":
# dict_keys can be compared as sets (order doesn't matter)
return comparator(set(orig), set(new))
if type_name == "dict_values":
# dict_values need element-wise comparison (order matters)
return comparator(list(orig), list(new))
if type_name == "dict_items":
# Convert to dict for order-insensitive comparison (handles unhashable values)
return comparator(dict(orig), dict(new), superset_obj)

if HAS_NUMPY:
import numpy as np # type: ignore # noqa: PGH003

if isinstance(orig, (np.datetime64, np.timedelta64)):
# Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
if np.isnat(orig) and np.isnat(new):
return True
if np.isnat(orig) or np.isnat(new):
return False
return orig == new

if isinstance(orig, np.ndarray):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
# Handle 0-d arrays specially to avoid "iteration over a 0-d array" error
if orig.ndim == 0:
try:
return np.allclose(orig, new, equal_nan=True)
except Exception:
return bool(orig == new)
try:
return np.allclose(orig, new, equal_nan=True)
except Exception:
Expand Down Expand Up @@ -208,6 +276,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return False
return torch.allclose(orig, new, equal_nan=True)

if isinstance(orig, torch.dtype):
return orig == new

if HAS_PYRSISTENT:
import pyrsistent # type: ignore # noqa: PGH003

Expand Down
Loading
Loading