|
23 | 23 | HAS_TORCH = find_spec("torch") is not None |
24 | 24 | HAS_JAX = find_spec("jax") is not None |
25 | 25 | HAS_XARRAY = find_spec("xarray") is not None |
| 26 | +HAS_TENSORFLOW = find_spec("tensorflow") is not None |
26 | 27 |
|
27 | 28 |
|
28 | 29 | def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 |
@@ -57,6 +58,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 |
57 | 58 | enum.Enum, |
58 | 59 | type, |
59 | 60 | range, |
| 61 | + slice, |
60 | 62 | OrderedDict, |
61 | 63 | ), |
62 | 64 | ): |
@@ -97,6 +99,45 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 |
97 | 99 | if isinstance(orig, (xarray.Dataset, xarray.DataArray)): |
98 | 100 | return orig.identical(new) |
99 | 101 |
|
| 102 | + # Handle TensorFlow objects early to avoid boolean context errors |
| 103 | + if HAS_TENSORFLOW: |
| 104 | + import tensorflow as tf # type: ignore # noqa: PGH003 |
| 105 | + |
| 106 | + if isinstance(orig, tf.Tensor): |
| 107 | + if orig.dtype != new.dtype: |
| 108 | + return False |
| 109 | + if orig.shape != new.shape: |
| 110 | + return False |
| 111 | + # Use numpy conversion for proper NaN handling |
| 112 | + return comparator(orig.numpy(), new.numpy(), superset_obj) |
| 113 | + |
| 114 | + if isinstance(orig, tf.Variable): |
| 115 | + if orig.dtype != new.dtype: |
| 116 | + return False |
| 117 | + if orig.shape != new.shape: |
| 118 | + return False |
| 119 | + return comparator(orig.numpy(), new.numpy(), superset_obj) |
| 120 | + |
| 121 | + if isinstance(orig, tf.dtypes.DType): |
| 122 | + return orig == new |
| 123 | + |
| 124 | + if isinstance(orig, tf.TensorShape): |
| 125 | + return orig == new |
| 126 | + |
| 127 | + if isinstance(orig, tf.SparseTensor): |
| 128 | + if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj): |
| 129 | + return False |
| 130 | + return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator( |
| 131 | + orig.values.numpy(), new.values.numpy(), superset_obj |
| 132 | + ) |
| 133 | + |
| 134 | + if isinstance(orig, tf.RaggedTensor): |
| 135 | + if orig.dtype != new.dtype: |
| 136 | + return False |
| 137 | + if orig.shape.rank != new.shape.rank: |
| 138 | + return False |
| 139 | + return comparator(orig.to_list(), new.to_list(), superset_obj) |
| 140 | + |
100 | 141 | if HAS_SQLALCHEMY: |
101 | 142 | import sqlalchemy # type: ignore # noqa: PGH003 |
102 | 143 |
|
@@ -130,14 +171,41 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 |
130 | 171 | return False |
131 | 172 | return True |
132 | 173 |
|
| 174 | + # Handle dict view types (dict_keys, dict_values, dict_items) |
| 175 | + # Use type name checking since these are not directly importable types |
| 176 | + type_name = type(orig).__name__ |
| 177 | + if type_name == "dict_keys": |
| 178 | + # dict_keys can be compared as sets (order doesn't matter) |
| 179 | + return comparator(set(orig), set(new)) |
| 180 | + if type_name == "dict_values": |
| 181 | + # dict_values need element-wise comparison (order matters) |
| 182 | + return comparator(list(orig), list(new)) |
| 183 | + if type_name == "dict_items": |
| 184 | + # Convert to dict for order-insensitive comparison (handles unhashable values) |
| 185 | + return comparator(dict(orig), dict(new), superset_obj) |
| 186 | + |
133 | 187 | if HAS_NUMPY: |
134 | 188 | import numpy as np # type: ignore # noqa: PGH003 |
135 | 189 |
|
| 190 | + if isinstance(orig, (np.datetime64, np.timedelta64)): |
| 191 | + # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime |
| 192 | + if np.isnat(orig) and np.isnat(new): |
| 193 | + return True |
| 194 | + if np.isnat(orig) or np.isnat(new): |
| 195 | + return False |
| 196 | + return orig == new |
| 197 | + |
136 | 198 | if isinstance(orig, np.ndarray): |
137 | 199 | if orig.dtype != new.dtype: |
138 | 200 | return False |
139 | 201 | if orig.shape != new.shape: |
140 | 202 | return False |
| 203 | + # Handle 0-d arrays specially to avoid "iteration over a 0-d array" error |
| 204 | + if orig.ndim == 0: |
| 205 | + try: |
| 206 | + return np.allclose(orig, new, equal_nan=True) |
| 207 | + except Exception: |
| 208 | + return bool(orig == new) |
141 | 209 | try: |
142 | 210 | return np.allclose(orig, new, equal_nan=True) |
143 | 211 | except Exception: |
@@ -208,6 +276,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 |
208 | 276 | return False |
209 | 277 | return torch.allclose(orig, new, equal_nan=True) |
210 | 278 |
|
| 279 | + if isinstance(orig, torch.dtype): |
| 280 | + return orig == new |
| 281 | + |
211 | 282 | if HAS_PYRSISTENT: |
212 | 283 | import pyrsistent # type: ignore # noqa: PGH003 |
213 | 284 |
|
|
0 commit comments