Skip to content

Commit f863c78

Browse files
committed
add support for slice, and more numpy types and errors
1 parent 5247508 commit f863c78

File tree

2 files changed

+224
-1
lines changed

2 files changed

+224
-1
lines changed

codeflash/verification/comparator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
5858
enum.Enum,
5959
type,
6060
range,
61+
slice,
6162
OrderedDict,
6263
),
6364
):
@@ -187,11 +188,25 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
187188
if HAS_NUMPY:
188189
import numpy as np # type: ignore # noqa: PGH003
189190

191+
if isinstance(orig, (np.datetime64, np.timedelta64)):
192+
# Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
193+
if np.isnat(orig) and np.isnat(new):
194+
return True
195+
if np.isnat(orig) or np.isnat(new):
196+
return False
197+
return orig == new
198+
190199
if isinstance(orig, np.ndarray):
191200
if orig.dtype != new.dtype:
192201
return False
193202
if orig.shape != new.shape:
194203
return False
204+
# Handle 0-d arrays specially to avoid "iteration over a 0-d array" error
205+
if orig.ndim == 0:
206+
try:
207+
return np.allclose(orig, new, equal_nan=True)
208+
except Exception:
209+
return bool(orig == new)
195210
try:
196211
return np.allclose(orig, new, equal_nan=True)
197212
except Exception:

tests/test_comparator.py

Lines changed: 209 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2204,4 +2204,212 @@ def test_tensorflow_ragged_tensor() -> None:
22042204
j = tf.ragged.constant([[], [], []])
22052205
k = tf.ragged.constant([[], [], []])
22062206

2207-
assert comparator(j, k)
2207+
assert comparator(j, k)
2208+
2209+
2210+
def test_slice() -> None:
2211+
"""Test comparator support for slice objects."""
2212+
# Test equal slices
2213+
a = slice(1, 10, 2)
2214+
b = slice(1, 10, 2)
2215+
assert comparator(a, b)
2216+
2217+
# Test slices with different start
2218+
c = slice(2, 10, 2)
2219+
assert not comparator(a, c)
2220+
2221+
# Test slices with different stop
2222+
d = slice(1, 11, 2)
2223+
assert not comparator(a, d)
2224+
2225+
# Test slices with different step
2226+
e = slice(1, 10, 3)
2227+
assert not comparator(a, e)
2228+
2229+
# Test slices with None values
2230+
f = slice(None, 10, 2)
2231+
g = slice(None, 10, 2)
2232+
h = slice(1, 10, 2)
2233+
assert comparator(f, g)
2234+
assert not comparator(f, h)
2235+
2236+
# Test slices with all None (equivalent to [:])
2237+
i = slice(None, None, None)
2238+
j = slice(None, None, None)
2239+
k = slice(None, None, 1)
2240+
assert comparator(i, j)
2241+
assert not comparator(i, k)
2242+
2243+
# Test slices with only stop
2244+
l = slice(5)
2245+
m = slice(5)
2246+
n = slice(6)
2247+
assert comparator(l, m)
2248+
assert not comparator(l, n)
2249+
2250+
# Test slices with negative values
2251+
o = slice(-5, -1, 1)
2252+
p = slice(-5, -1, 1)
2253+
q = slice(-5, -2, 1)
2254+
assert comparator(o, p)
2255+
assert not comparator(o, q)
2256+
2257+
# Test slice is not equal to other types
2258+
r = slice(1, 10)
2259+
s = (1, 10)
2260+
assert not comparator(r, s)
2261+
2262+
2263+
def test_numpy_datetime64() -> None:
2264+
"""Test comparator support for numpy datetime64 and timedelta64 types."""
2265+
try:
2266+
import numpy as np
2267+
except ImportError:
2268+
pytest.skip("numpy required for this test")
2269+
2270+
# Test datetime64 equality
2271+
a = np.datetime64('2021-01-01')
2272+
b = np.datetime64('2021-01-01')
2273+
c = np.datetime64('2021-01-02')
2274+
2275+
assert comparator(a, b)
2276+
assert not comparator(a, c)
2277+
2278+
# Test datetime64 with different units
2279+
d = np.datetime64('2021-01-01', 'D')
2280+
e = np.datetime64('2021-01-01', 'D')
2281+
f = np.datetime64('2021-01-01', 's') # Different unit (seconds)
2282+
2283+
assert comparator(d, e)
2284+
# Note: datetime64 with different units but same moment may or may not be equal
2285+
# depending on numpy version behavior
2286+
2287+
# Test datetime64 with time
2288+
g = np.datetime64('2021-01-01T12:00:00')
2289+
h = np.datetime64('2021-01-01T12:00:00')
2290+
i = np.datetime64('2021-01-01T12:00:01')
2291+
2292+
assert comparator(g, h)
2293+
assert not comparator(g, i)
2294+
2295+
# Test timedelta64 equality
2296+
j = np.timedelta64(1, 'D')
2297+
k = np.timedelta64(1, 'D')
2298+
l = np.timedelta64(2, 'D')
2299+
2300+
assert comparator(j, k)
2301+
assert not comparator(j, l)
2302+
2303+
# Test timedelta64 with different units
2304+
m = np.timedelta64(1, 'h')
2305+
n = np.timedelta64(1, 'h')
2306+
o = np.timedelta64(60, 'm') # Same duration, different unit
2307+
2308+
assert comparator(m, n)
2309+
# 1 hour == 60 minutes, but they have different units
2310+
# numpy may treat them as equal or not depending on comparison
2311+
2312+
# Test NaT (Not a Time) - numpy's equivalent of NaN for datetime
2313+
p = np.datetime64('NaT')
2314+
q = np.datetime64('NaT')
2315+
r = np.datetime64('2021-01-01')
2316+
2317+
assert comparator(p, q) # NaT == NaT should be True
2318+
assert not comparator(p, r)
2319+
2320+
# Test timedelta64 NaT
2321+
s = np.timedelta64('NaT')
2322+
t = np.timedelta64('NaT')
2323+
u = np.timedelta64(1, 'D')
2324+
2325+
assert comparator(s, t) # NaT == NaT should be True
2326+
assert not comparator(s, u)
2327+
2328+
# Test datetime64 is not equal to other types
2329+
v = np.datetime64('2021-01-01')
2330+
w = '2021-01-01'
2331+
assert not comparator(v, w)
2332+
2333+
# Test arrays of datetime64
2334+
x = np.array(['2021-01-01', '2021-01-02'], dtype='datetime64')
2335+
y = np.array(['2021-01-01', '2021-01-02'], dtype='datetime64')
2336+
z = np.array(['2021-01-01', '2021-01-03'], dtype='datetime64')
2337+
2338+
assert comparator(x, y)
2339+
assert not comparator(x, z)
2340+
2341+
2342+
def test_numpy_0d_array() -> None:
2343+
"""Test comparator handles 0-d numpy arrays without 'iteration over 0-d array' error."""
2344+
try:
2345+
import numpy as np
2346+
except ImportError:
2347+
pytest.skip("numpy required for this test")
2348+
2349+
# Test 0-d integer array
2350+
a = np.array(5)
2351+
b = np.array(5)
2352+
c = np.array(6)
2353+
2354+
assert comparator(a, b)
2355+
assert not comparator(a, c)
2356+
2357+
# Test 0-d float array
2358+
d = np.array(3.14)
2359+
e = np.array(3.14)
2360+
f = np.array(2.71)
2361+
2362+
assert comparator(d, e)
2363+
assert not comparator(d, f)
2364+
2365+
# Test 0-d complex array
2366+
g = np.array(1+2j)
2367+
h = np.array(1+2j)
2368+
i = np.array(1+3j)
2369+
2370+
assert comparator(g, h)
2371+
assert not comparator(g, i)
2372+
2373+
# Test 0-d string array
2374+
j = np.array('hello')
2375+
k = np.array('hello')
2376+
l = np.array('world')
2377+
2378+
assert comparator(j, k)
2379+
assert not comparator(j, l)
2380+
2381+
# Test 0-d boolean array
2382+
m = np.array(True)
2383+
n = np.array(True)
2384+
o = np.array(False)
2385+
2386+
assert comparator(m, n)
2387+
assert not comparator(m, o)
2388+
2389+
# Test 0-d array with NaN
2390+
p = np.array(np.nan)
2391+
q = np.array(np.nan)
2392+
r = np.array(1.0)
2393+
2394+
assert comparator(p, q) # NaN == NaN should be True
2395+
assert not comparator(p, r)
2396+
2397+
# Test 0-d datetime64 array
2398+
s = np.array(np.datetime64('2021-01-01'))
2399+
t = np.array(np.datetime64('2021-01-01'))
2400+
u = np.array(np.datetime64('2021-01-02'))
2401+
2402+
assert comparator(s, t)
2403+
assert not comparator(s, u)
2404+
2405+
# Test 0-d array vs scalar
2406+
v = np.array(5)
2407+
w = 5
2408+
# 0-d array and scalar are different types
2409+
assert not comparator(v, w)
2410+
2411+
# Test 0-d array vs 1-d array with one element
2412+
x = np.array(5)
2413+
y = np.array([5])
2414+
# Different shapes
2415+
assert not comparator(x, y)

0 commit comments

Comments
 (0)