Skip to content

Commit 5247508

Browse files
committed
add support for tensorflow
1 parent 11f9e9e commit 5247508

File tree

2 files changed

+376
-1
lines changed

2 files changed

+376
-1
lines changed

codeflash/verification/comparator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
HAS_TORCH = find_spec("torch") is not None
2424
HAS_JAX = find_spec("jax") is not None
2525
HAS_XARRAY = find_spec("xarray") is not None
26+
HAS_TENSORFLOW = find_spec("tensorflow") is not None
2627

2728

2829
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -97,6 +98,46 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
9798
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
9899
return orig.identical(new)
99100

101+
# Handle TensorFlow objects early to avoid boolean context errors
102+
if HAS_TENSORFLOW:
103+
import tensorflow as tf # type: ignore # noqa: PGH003
104+
105+
if isinstance(orig, tf.Tensor):
106+
if orig.dtype != new.dtype:
107+
return False
108+
if orig.shape != new.shape:
109+
return False
110+
# Use numpy conversion for proper NaN handling
111+
return comparator(orig.numpy(), new.numpy(), superset_obj)
112+
113+
if isinstance(orig, tf.Variable):
114+
if orig.dtype != new.dtype:
115+
return False
116+
if orig.shape != new.shape:
117+
return False
118+
return comparator(orig.numpy(), new.numpy(), superset_obj)
119+
120+
if isinstance(orig, tf.dtypes.DType):
121+
return orig == new
122+
123+
if isinstance(orig, tf.TensorShape):
124+
return orig == new
125+
126+
if isinstance(orig, tf.SparseTensor):
127+
if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
128+
return False
129+
return (
130+
comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and
131+
comparator(orig.values.numpy(), new.values.numpy(), superset_obj)
132+
)
133+
134+
if isinstance(orig, tf.RaggedTensor):
135+
if orig.dtype != new.dtype:
136+
return False
137+
if orig.shape.rank != new.shape.rank:
138+
return False
139+
return comparator(orig.to_list(), new.to_list(), superset_obj)
140+
100141
if HAS_SQLALCHEMY:
101142
import sqlalchemy # type: ignore # noqa: PGH003
102143

tests/test_comparator.py

Lines changed: 335 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1870,4 +1870,338 @@ def test_dict_views() -> None:
18701870
assert not comparator(d.keys(), ["a", "b"])
18711871
assert not comparator(d.keys(), {"a", "b"})
18721872
assert not comparator(d.values(), [1, 2])
1873-
assert not comparator(d.items(), [("a", 1), ("b", 2)])
1873+
assert not comparator(d.items(), [("a", 1), ("b", 2)])
1874+
1875+
1876+
def test_tensorflow_tensor() -> None:
1877+
"""Test comparator support for TensorFlow Tensor objects."""
1878+
try:
1879+
import tensorflow as tf
1880+
except ImportError:
1881+
pytest.skip("tensorflow required for this test")
1882+
1883+
# Test basic 1D tensors
1884+
a = tf.constant([1, 2, 3])
1885+
b = tf.constant([1, 2, 3])
1886+
c = tf.constant([1, 2, 4])
1887+
1888+
assert comparator(a, b)
1889+
assert not comparator(a, c)
1890+
1891+
# Test 2D tensors
1892+
d = tf.constant([[1, 2, 3], [4, 5, 6]])
1893+
e = tf.constant([[1, 2, 3], [4, 5, 6]])
1894+
f = tf.constant([[1, 2, 3], [4, 5, 7]])
1895+
1896+
assert comparator(d, e)
1897+
assert not comparator(d, f)
1898+
1899+
# Test tensors with different shapes
1900+
g = tf.constant([1, 2, 3])
1901+
h = tf.constant([[1, 2, 3]])
1902+
1903+
assert not comparator(g, h)
1904+
1905+
# Test tensors with different dtypes
1906+
i = tf.constant([1, 2, 3], dtype=tf.float32)
1907+
j = tf.constant([1, 2, 3], dtype=tf.float32)
1908+
k = tf.constant([1, 2, 3], dtype=tf.int32)
1909+
1910+
assert comparator(i, j)
1911+
assert not comparator(i, k)
1912+
1913+
# Test 3D tensors
1914+
l = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
1915+
m = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
1916+
n = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 9]]])
1917+
1918+
assert comparator(l, m)
1919+
assert not comparator(l, n)
1920+
1921+
# Test empty tensors
1922+
o = tf.constant([])
1923+
p = tf.constant([])
1924+
q = tf.constant([1.0])
1925+
1926+
assert comparator(o, p)
1927+
assert not comparator(o, q)
1928+
1929+
# Test tensors with NaN values
1930+
r = tf.constant([1.0, float('nan'), 3.0])
1931+
s = tf.constant([1.0, float('nan'), 3.0])
1932+
t = tf.constant([1.0, 2.0, 3.0])
1933+
1934+
assert comparator(r, s) # NaN == NaN should be True
1935+
assert not comparator(r, t)
1936+
1937+
# Test tensors with infinity values
1938+
u = tf.constant([1.0, float('inf'), 3.0])
1939+
v = tf.constant([1.0, float('inf'), 3.0])
1940+
w = tf.constant([1.0, float('-inf'), 3.0])
1941+
1942+
assert comparator(u, v)
1943+
assert not comparator(u, w)
1944+
1945+
# Test complex tensors
1946+
x = tf.constant([1+2j, 3+4j])
1947+
y = tf.constant([1+2j, 3+4j])
1948+
z = tf.constant([1+2j, 3+5j])
1949+
1950+
assert comparator(x, y)
1951+
assert not comparator(x, z)
1952+
1953+
# Test boolean tensors
1954+
aa = tf.constant([True, False, True])
1955+
bb = tf.constant([True, False, True])
1956+
cc = tf.constant([True, True, True])
1957+
1958+
assert comparator(aa, bb)
1959+
assert not comparator(aa, cc)
1960+
1961+
# Test string tensors
1962+
dd = tf.constant(["hello", "world"])
1963+
ee = tf.constant(["hello", "world"])
1964+
ff = tf.constant(["hello", "there"])
1965+
1966+
assert comparator(dd, ee)
1967+
assert not comparator(dd, ff)
1968+
1969+
1970+
def test_tensorflow_dtype() -> None:
1971+
"""Test comparator support for TensorFlow DType objects."""
1972+
try:
1973+
import tensorflow as tf
1974+
except ImportError:
1975+
pytest.skip("tensorflow required for this test")
1976+
1977+
# Test float dtypes
1978+
a = tf.float32
1979+
b = tf.float32
1980+
c = tf.float64
1981+
1982+
assert comparator(a, b)
1983+
assert not comparator(a, c)
1984+
1985+
# Test integer dtypes
1986+
d = tf.int32
1987+
e = tf.int32
1988+
f = tf.int64
1989+
1990+
assert comparator(d, e)
1991+
assert not comparator(d, f)
1992+
1993+
# Test unsigned integer dtypes
1994+
g = tf.uint8
1995+
h = tf.uint8
1996+
i = tf.uint16
1997+
1998+
assert comparator(g, h)
1999+
assert not comparator(g, i)
2000+
2001+
# Test complex dtypes
2002+
j = tf.complex64
2003+
k = tf.complex64
2004+
l = tf.complex128
2005+
2006+
assert comparator(j, k)
2007+
assert not comparator(j, l)
2008+
2009+
# Test bool dtype
2010+
m = tf.bool
2011+
n = tf.bool
2012+
o = tf.int8
2013+
2014+
assert comparator(m, n)
2015+
assert not comparator(m, o)
2016+
2017+
# Test string dtype
2018+
p = tf.string
2019+
q = tf.string
2020+
r = tf.int32
2021+
2022+
assert comparator(p, q)
2023+
assert not comparator(p, r)
2024+
2025+
2026+
def test_tensorflow_variable() -> None:
2027+
"""Test comparator support for TensorFlow Variable objects."""
2028+
try:
2029+
import tensorflow as tf
2030+
except ImportError:
2031+
pytest.skip("tensorflow required for this test")
2032+
2033+
# Test basic variables
2034+
a = tf.Variable([1, 2, 3], dtype=tf.float32)
2035+
b = tf.Variable([1, 2, 3], dtype=tf.float32)
2036+
c = tf.Variable([1, 2, 4], dtype=tf.float32)
2037+
2038+
assert comparator(a, b)
2039+
assert not comparator(a, c)
2040+
2041+
# Test variables with different dtypes
2042+
d = tf.Variable([1, 2, 3], dtype=tf.float32)
2043+
e = tf.Variable([1, 2, 3], dtype=tf.float64)
2044+
2045+
assert not comparator(d, e)
2046+
2047+
# Test 2D variables
2048+
f = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32)
2049+
g = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32)
2050+
h = tf.Variable([[1, 2], [3, 5]], dtype=tf.float32)
2051+
2052+
assert comparator(f, g)
2053+
assert not comparator(f, h)
2054+
2055+
# Test variables with different shapes
2056+
i = tf.Variable([1, 2, 3], dtype=tf.float32)
2057+
j = tf.Variable([[1, 2, 3]], dtype=tf.float32)
2058+
2059+
assert not comparator(i, j)
2060+
2061+
2062+
def test_tensorflow_tensor_shape() -> None:
2063+
"""Test comparator support for TensorFlow TensorShape objects."""
2064+
try:
2065+
import tensorflow as tf
2066+
except ImportError:
2067+
pytest.skip("tensorflow required for this test")
2068+
2069+
# Test equal shapes
2070+
a = tf.TensorShape([2, 3, 4])
2071+
b = tf.TensorShape([2, 3, 4])
2072+
c = tf.TensorShape([2, 3, 5])
2073+
2074+
assert comparator(a, b)
2075+
assert not comparator(a, c)
2076+
2077+
# Test different ranks
2078+
d = tf.TensorShape([2, 3])
2079+
e = tf.TensorShape([2, 3, 4])
2080+
2081+
assert not comparator(d, e)
2082+
2083+
# Test scalar shapes
2084+
f = tf.TensorShape([])
2085+
g = tf.TensorShape([])
2086+
h = tf.TensorShape([1])
2087+
2088+
assert comparator(f, g)
2089+
assert not comparator(f, h)
2090+
2091+
# Test shapes with None dimensions (unknown dimensions)
2092+
i = tf.TensorShape([None, 3, 4])
2093+
j = tf.TensorShape([None, 3, 4])
2094+
k = tf.TensorShape([2, 3, 4])
2095+
2096+
assert comparator(i, j)
2097+
assert not comparator(i, k)
2098+
2099+
# Test fully unknown shapes
2100+
l = tf.TensorShape(None)
2101+
m = tf.TensorShape(None)
2102+
n = tf.TensorShape([1, 2])
2103+
2104+
assert comparator(l, m)
2105+
assert not comparator(l, n)
2106+
2107+
2108+
def test_tensorflow_sparse_tensor() -> None:
2109+
"""Test comparator support for TensorFlow SparseTensor objects."""
2110+
try:
2111+
import tensorflow as tf
2112+
except ImportError:
2113+
pytest.skip("tensorflow required for this test")
2114+
2115+
# Test equal sparse tensors
2116+
a = tf.SparseTensor(
2117+
indices=[[0, 0], [1, 2]],
2118+
values=[1.0, 2.0],
2119+
dense_shape=[3, 4]
2120+
)
2121+
b = tf.SparseTensor(
2122+
indices=[[0, 0], [1, 2]],
2123+
values=[1.0, 2.0],
2124+
dense_shape=[3, 4]
2125+
)
2126+
c = tf.SparseTensor(
2127+
indices=[[0, 0], [1, 2]],
2128+
values=[1.0, 3.0], # Different value
2129+
dense_shape=[3, 4]
2130+
)
2131+
2132+
assert comparator(a, b)
2133+
assert not comparator(a, c)
2134+
2135+
# Test sparse tensors with different indices
2136+
d = tf.SparseTensor(
2137+
indices=[[0, 0], [1, 3]], # Different index
2138+
values=[1.0, 2.0],
2139+
dense_shape=[3, 4]
2140+
)
2141+
2142+
assert not comparator(a, d)
2143+
2144+
# Test sparse tensors with different shapes
2145+
e = tf.SparseTensor(
2146+
indices=[[0, 0], [1, 2]],
2147+
values=[1.0, 2.0],
2148+
dense_shape=[4, 5] # Different shape
2149+
)
2150+
2151+
assert not comparator(a, e)
2152+
2153+
# Test empty sparse tensors
2154+
f = tf.SparseTensor(
2155+
indices=tf.zeros([0, 2], dtype=tf.int64),
2156+
values=[],
2157+
dense_shape=[3, 4]
2158+
)
2159+
g = tf.SparseTensor(
2160+
indices=tf.zeros([0, 2], dtype=tf.int64),
2161+
values=[],
2162+
dense_shape=[3, 4]
2163+
)
2164+
2165+
assert comparator(f, g)
2166+
2167+
2168+
def test_tensorflow_ragged_tensor() -> None:
2169+
"""Test comparator support for TensorFlow RaggedTensor objects."""
2170+
try:
2171+
import tensorflow as tf
2172+
except ImportError:
2173+
pytest.skip("tensorflow required for this test")
2174+
2175+
# Test equal ragged tensors
2176+
a = tf.ragged.constant([[1, 2], [3, 4, 5], [6]])
2177+
b = tf.ragged.constant([[1, 2], [3, 4, 5], [6]])
2178+
c = tf.ragged.constant([[1, 2], [3, 4, 6], [6]]) # Different value
2179+
2180+
assert comparator(a, b)
2181+
assert not comparator(a, c)
2182+
2183+
# Test ragged tensors with different row lengths
2184+
d = tf.ragged.constant([[1, 2, 3], [4, 5], [6]]) # Different structure
2185+
2186+
assert not comparator(a, d)
2187+
2188+
# Test ragged tensors with different dtypes
2189+
e = tf.ragged.constant([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]])
2190+
f = tf.ragged.constant([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]])
2191+
2192+
assert comparator(e, f)
2193+
assert not comparator(a, e) # int vs float
2194+
2195+
# Test nested ragged tensors
2196+
g = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]])
2197+
h = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]])
2198+
i = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 7]]])
2199+
2200+
assert comparator(g, h)
2201+
assert not comparator(g, i)
2202+
2203+
# Test empty ragged tensors
2204+
j = tf.ragged.constant([[], [], []])
2205+
k = tf.ragged.constant([[], [], []])
2206+
2207+
assert comparator(j, k)

0 commit comments

Comments
 (0)