Conversation

@misrasaurabh1
Contributor

misrasaurabh1 commented Dec 19, 2025

PR Type

Enhancement, Tests


Description

  • Add TensorFlow object comparisons

  • Support dict view comparisons

  • Handle NumPy 0-d arrays and datetime64/timedelta64

  • Compare torch.dtype and slice objects


Diagram Walkthrough

flowchart LR
  C["comparator(): core comparator"] --> TF["TensorFlow types support"]
  C --> DV["dict view (keys/values/items) support"]
  C --> NP["NumPy 0-d, datetime64/timedelta64"]
  C --> TORCHDT["torch.dtype equality"]
  C --> SL["slice handling"]
  TF -- "tensor/variable/sparse/ragged/shape/dtype" --> C

File Walkthrough

Relevant files
Enhancement
comparator.py
Expand comparator support for TF, dict views, numpy, torch, slice

codeflash/verification/comparator.py

  • Detect TensorFlow availability
  • Add comparisons for tf Tensor, Variable, DType, TensorShape,
    SparseTensor, RaggedTensor
  • Add dict view handling: keys, values, items
  • Add numpy datetime64/timedelta64, 0-d array handling
  • Add torch.dtype equality
  • Include slice in simple comparable types
+72/-0   
Tests
test_comparator.py
Add comprehensive tests for new comparator types                 

tests/test_comparator.py

  • Add tests for torch.dtype equality
  • Add tests for dict_keys/values/items comparisons
  • Add comprehensive TensorFlow tests: tensors, dtypes, variables,
    shapes, sparse, ragged
  • Add tests for slice comparisons
  • Add tests for numpy datetime64/timedelta64 and 0-d arrays
+649/-2 

@github-actions

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 3 🔵🔵🔵⚪⚪
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Possible Issue

The TensorFlow branches dispatch on the type of orig only; they never verify that new is the same TF type before reading .dtype or .shape or calling .numpy(). If new is a non-TF object, the comparison may raise AttributeError instead of returning False. Consider adding a paired isinstance check on new in each branch.

if HAS_TENSORFLOW:
    import tensorflow as tf  # type: ignore  # noqa: PGH003

    if isinstance(orig, tf.Tensor):
        if orig.dtype != new.dtype:
            return False
        if orig.shape != new.shape:
            return False
        # Use numpy conversion for proper NaN handling
        return comparator(orig.numpy(), new.numpy(), superset_obj)

    if isinstance(orig, tf.Variable):
        if orig.dtype != new.dtype:
            return False
        if orig.shape != new.shape:
            return False
        return comparator(orig.numpy(), new.numpy(), superset_obj)

    if isinstance(orig, tf.dtypes.DType):
        return orig == new

    if isinstance(orig, tf.TensorShape):
        return orig == new

    if isinstance(orig, tf.SparseTensor):
        if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
            return False
        return (
            comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and
            comparator(orig.values.numpy(), new.values.numpy(), superset_obj)
        )

    if isinstance(orig, tf.RaggedTensor):
        if orig.dtype != new.dtype:
            return False
        if orig.shape.rank != new.shape.rank:
            return False
        return comparator(orig.to_list(), new.to_list(), superset_obj)
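The failure mode described above can be reproduced without TensorFlow installed. The following is a minimal sketch, assuming a hypothetical `FakeTensor` class as a stand-in for `tf.Tensor`; `compare_unguarded` and `compare_guarded` are illustrative names and are not functions from comparator.py.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for tf.Tensor; only dtype, shape and data matter here."""
    dtype: str
    shape: tuple
    data: list


def compare_unguarded(orig, new):
    # Mirrors the flagged pattern: only `orig` is type-checked.
    if isinstance(orig, FakeTensor):
        if orig.dtype != new.dtype:  # AttributeError if `new` has no .dtype
            return False
        if orig.shape != new.shape:
            return False
        return orig.data == new.data
    return orig == new


def compare_guarded(orig, new):
    # Paired isinstance check: a type mismatch is reported as "not equal".
    if isinstance(orig, FakeTensor):
        if not isinstance(new, FakeTensor):
            return False
        return (orig.dtype, orig.shape, orig.data) == (new.dtype, new.shape, new.data)
    return orig == new
```

Calling `compare_unguarded` with a `FakeTensor` and a plain list raises AttributeError, while `compare_guarded` returns False for the same inputs.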
Consistency

The dict_items branch compares the two views as lists, which is order-sensitive, even though its comment describes set-like behavior. Two views holding the same items in a different order will therefore compare as unequal (a false negative). Consider comparing them as sets of tuples, as is done for dict_keys.

# Handle dict view types (dict_keys, dict_values, dict_items)
# Use type name checking since these are not directly importable types
type_name = type(orig).__name__
if type_name == "dict_keys":
    # dict_keys can be compared as sets (order doesn't matter)
    return comparator(set(orig), set(new))
if type_name == "dict_values":
    # dict_values need element-wise comparison (order matters)
    return comparator(list(orig), list(new))
if type_name == "dict_items":
    # dict_items can be compared as sets of tuples (order doesn't matter for items)
    return comparator(list(orig), list(new))
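The order sensitivity is easy to see with two dicts that hold the same pairs in a different insertion order. The sketch below is illustrative only: `items_equal_unordered` and its `eq` parameter are hypothetical and not part of comparator.py. It goes through `dict` rather than `set` because `set(view)` would raise TypeError when a value is unhashable (for example, a list).

```python
def items_equal_unordered(orig_items, new_items, eq=lambda a, b: a == b):
    """Hypothetical helper: compare two dict_items views ignoring order."""
    orig_map, new_map = dict(orig_items), dict(new_items)
    if orig_map.keys() != new_map.keys():
        return False
    # `eq` stands in for whatever value comparison the caller wants to apply.
    return all(eq(value, new_map[key]) for key, value in orig_map.items())


a = {"x": 1, "y": [2, 3]}
b = {"y": [2, 3], "x": 1}  # same pairs, different insertion order

ordered_result = list(a.items()) == list(b.items())  # list comparison depends on order
unordered_result = items_equal_unordered(a.items(), b.items())
```

Here `ordered_result` is False and `unordered_result` is True, which is the discrepancy the reviewer note points at.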
Edge Case

NumPy datetime64/timedelta64 equality relies on direct ==, which can be unit-sensitive across versions. Consider normalizing units or converting to integers with a common unit when possible, or documenting behavior.

if isinstance(orig, (np.datetime64, np.timedelta64)):
    # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
    if np.isnat(orig) and np.isnat(new):
        return True
    if np.isnat(orig) or np.isnat(new):
        return False
    return orig == new

@github-actions

PR Code Suggestions ✨

Explore these optional code suggestions:

Category | Suggestion | Impact
Possible issue
Guard TensorFlow type mismatches

Add type checks to ensure new is the same TensorFlow type before accessing its
attributes. Without this, comparing a TF object with a non-TF object can raise
AttributeError. Return False early if types mismatch.

codeflash/verification/comparator.py [103-141]

 if HAS_TENSORFLOW:
     import tensorflow as tf  # type: ignore  # noqa: PGH003
 
     if isinstance(orig, tf.Tensor):
+        if not isinstance(new, tf.Tensor):
+            return False
         if orig.dtype != new.dtype:
             return False
         if orig.shape != new.shape:
             return False
-        # Use numpy conversion for proper NaN handling
         return comparator(orig.numpy(), new.numpy(), superset_obj)
 
     if isinstance(orig, tf.Variable):
+        if not isinstance(new, tf.Variable):
+            return False
         if orig.dtype != new.dtype:
             return False
         if orig.shape != new.shape:
             return False
         return comparator(orig.numpy(), new.numpy(), superset_obj)
 
     if isinstance(orig, tf.dtypes.DType):
-        return orig == new
+        return isinstance(new, tf.dtypes.DType) and (orig == new)
 
     if isinstance(orig, tf.TensorShape):
-        return orig == new
+        return isinstance(new, tf.TensorShape) and (orig == new)
 
     if isinstance(orig, tf.SparseTensor):
+        if not isinstance(new, tf.SparseTensor):
+            return False
         if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
             return False
         return (
-            comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and
-            comparator(orig.values.numpy(), new.values.numpy(), superset_obj)
+            comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj)
+            and comparator(orig.values.numpy(), new.values.numpy(), superset_obj)
         )
 
     if isinstance(orig, tf.RaggedTensor):
+        if not isinstance(new, tf.RaggedTensor):
+            return False
         if orig.dtype != new.dtype:
             return False
         if orig.shape.rank != new.shape.rank:
             return False
         return comparator(orig.to_list(), new.to_list(), superset_obj)
Suggestion importance[1-10]: 8

Why: Correctly adds symmetric type checks for TensorFlow objects to prevent attribute access on mismatched types and aligns with the comparator's strict-type behavior; impactful for correctness and safety.

Medium
Guard numpy type mismatches

Add type checks ensuring new is a compatible numpy type before attribute access.
Accessing new.dtype/new.shape or using np.isnat(new) when new isn't a numpy object
can raise AttributeError or TypeError; return False early on mismatch.

codeflash/verification/comparator.py [189-214]

 if HAS_NUMPY:
     import numpy as np  # type: ignore  # noqa: PGH003
 
     if isinstance(orig, (np.datetime64, np.timedelta64)):
-        # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
+        if not isinstance(new, (np.datetime64, np.timedelta64)):
+            return False
         if np.isnat(orig) and np.isnat(new):
             return True
         if np.isnat(orig) or np.isnat(new):
             return False
-        return orig == new
+        return bool(orig == new)
 
     if isinstance(orig, np.ndarray):
+        if not isinstance(new, np.ndarray):
+            return False
         if orig.dtype != new.dtype:
             return False
         if orig.shape != new.shape:
             return False
-        # Handle 0-d arrays specially to avoid "iteration over a 0-d array" error
         if orig.ndim == 0:
             try:
-                return np.allclose(orig, new, equal_nan=True)
+                return bool(np.allclose(orig, new, equal_nan=True))
             except Exception:
                 return bool(orig == new)
         try:
-            return np.allclose(orig, new, equal_nan=True)
+            return bool(np.allclose(orig, new, equal_nan=True))
         except Exception:
             # fails at "ufunc 'isfinite' not supported for the input types"
+            ...
Suggestion importance[1-10]: 8

Why: Adds necessary type checks before numpy-specific attribute/method access, avoiding exceptions and enforcing strict comparability; accurate and materially improves robustness.

Medium
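The guard-before-probe ordering in the suggestion above can be illustrated with a standard-library analogue. This is only a sketch under a stated assumption: floats and `math.isnan` stand in for datetime64/timedelta64 and `np.isnat`, since both probes raise TypeError on input of the wrong kind. `nan_aware_equal` is a hypothetical name, not code from the PR.

```python
import math


def nan_aware_equal(orig, new):
    """Hypothetical float analogue of the datetime64/NaT branch."""
    if isinstance(orig, float):
        # Guard first: math.isnan raises TypeError for non-numeric input.
        if not isinstance(new, float):
            return False
        if math.isnan(orig) and math.isnan(new):
            return True  # both "missing" values count as equal
        if math.isnan(orig) or math.isnan(new):
            return False
        return orig == new
    return orig == new
```

Without the `isinstance(new, float)` guard, a call such as `nan_aware_equal(math.nan, "NaT")` would reach `math.isnan("NaT")` and raise instead of returning False.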
Validate dict view type symmetry

Ensure the right-hand operand is a matching dict view before converting. As written,
comparing a dict view to an arbitrary iterable may incorrectly return True or raise
unexpected errors. Return False when new is not the same view type.

codeflash/verification/comparator.py [176-186]

 # Handle dict view types (dict_keys, dict_values, dict_items)
-# Use type name checking since these are not directly importable types
 type_name = type(orig).__name__
 if type_name == "dict_keys":
-    # dict_keys can be compared as sets (order doesn't matter)
+    if type(new).__name__ != "dict_keys":
+        return False
     return comparator(set(orig), set(new))
 if type_name == "dict_values":
-    # dict_values need element-wise comparison (order matters)
+    if type(new).__name__ != "dict_values":
+        return False
     return comparator(list(orig), list(new))
 if type_name == "dict_items":
-    # dict_items can be compared as sets of tuples (order doesn't matter for items)
+    if type(new).__name__ != "dict_items":
+        return False
     return comparator(list(orig), list(new))
Suggestion importance[1-10]: 7

Why: Ensures dict view comparisons don't silently coerce arbitrary iterables, preventing false positives and potential errors; a solid maintainability and correctness improvement.

Medium
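As an alternative to matching on `type(...).__name__`, the view kinds can be checked against the abstract base classes in `collections.abc`. The following is a sketch, not the approach taken in the PR; `same_view_kind` is a hypothetical helper. One difference worth noting: the views of an `OrderedDict` have the type name `odict_keys`, so a name check against "dict_keys" would not match them, whereas the ABC check does.

```python
from collections import OrderedDict
from collections.abc import ItemsView, KeysView, ValuesView

# The concrete view classes have no importable name,
# but they can be recovered from an instance when needed.
DictKeys = type({}.keys())
DictValues = type({}.values())
DictItems = type({}.items())


def same_view_kind(orig, new):
    """Hypothetical helper: True when both operands are the same kind of mapping view."""
    for view_abc in (KeysView, ValuesView, ItemsView):
        if isinstance(orig, view_abc):
            return isinstance(new, view_abc)
    return False  # `orig` is not a mapping view at all
```

Whether views from other mapping types should count as the same kind as plain dict views is a design decision; the stricter name-based check in the suggestion above rejects them.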

misrasaurabh1 merged commit 3f416af into main on Dec 19, 2025
22 checks passed
misrasaurabh1 deleted the comparator-fixes branch on December 19, 2025 at 02:12

3 participants