Skip to content

Commit ad4c323

Browse files
authored
Merge pull request #1459 from codeflash-ai/comparator-ast-recursion-depth
Update test_comparator.py to check for ast & Ellipsis objects
2 parents 1fd7e6a + 5af82f0 commit ad4c323

File tree

2 files changed

+68
-41
lines changed

2 files changed

+68
-41
lines changed

codeflash/verification/comparator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import datetime
23
import decimal
34
import enum
@@ -61,6 +62,7 @@ def comparator(orig: Any, new: Any) -> bool:
6162
bool,
6263
complex,
6364
type(None),
65+
type(Ellipsis),
6466
decimal.Decimal,
6567
set,
6668
bytes,
@@ -77,8 +79,6 @@ def comparator(orig: Any, new: Any) -> bool:
7779
return True
7880
return math.isclose(orig, new)
7981
if isinstance(orig, BaseException):
80-
# if str(orig) != str(new):
81-
# return False
8282
# compare the attributes of the two exception objects to determine if they are equivalent.
8383
orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")}
8484
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
@@ -182,7 +182,6 @@ def comparator(orig: Any, new: Any) -> bool:
182182
return orig == new
183183
except Exception:
184184
pass
185-
186185
# For class objects
187186
if hasattr(orig, "__dict__") and hasattr(new, "__dict__"):
188187
orig_keys = orig.__dict__
@@ -196,13 +195,16 @@ def comparator(orig: Any, new: Any) -> bool:
196195
orig_keys = {k: v for k, v in orig_keys.items() if not k.startswith("__")}
197196
new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")}
198197

198+
if isinstance(orig, ast.AST):
199+
orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"}
200+
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
201+
199202
return comparator(orig_keys, new_keys)
200203

201204
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
202205
return new == orig
203206
if str(type(orig)) == "<class 'object'>":
204207
return True
205-
206208
# TODO : Add other types here
207209
logger.warning(f"Unknown comparator input type: {type(orig)}")
208210
return False

tests/test_comparator.py

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ast
2+
import copy
13
import dataclasses
24
import datetime
35
import decimal
@@ -6,13 +8,15 @@
68

79
import pydantic
810
import pytest
11+
from pathlib import Path
12+
913
from codeflash.either import Failure, Success
1014
from codeflash.verification.comparator import comparator
1115
from codeflash.verification.equivalence import compare_test_results
1216
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
1317

1418

15-
def test_basic_python_objects():
19+
def test_basic_python_objects() -> None:
1620
a = 5
1721
b = 5
1822
c = 6
@@ -120,40 +124,40 @@ def test_basic_python_objects():
120124
assert not comparator(a, c)
121125

122126

123-
def test_standard_python_library_objects():
124-
a = datetime.datetime(2020, 2, 2, 2, 2, 2)
125-
b = datetime.datetime(2020, 2, 2, 2, 2, 2)
126-
c = datetime.datetime(2020, 2, 2, 2, 2, 3)
127+
def test_standard_python_library_objects() -> None:
128+
a = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore
129+
b = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore
130+
c = datetime.datetime(2020, 2, 2, 2, 2, 3) # type: ignore
127131
assert comparator(a, b)
128132
assert not comparator(a, c)
129133

130-
a = datetime.date(2020, 2, 2)
131-
b = datetime.date(2020, 2, 2)
132-
c = datetime.date(2020, 2, 3)
134+
a = datetime.date(2020, 2, 2) # type: ignore
135+
b = datetime.date(2020, 2, 2) # type: ignore
136+
c = datetime.date(2020, 2, 3) # type: ignore
133137
assert comparator(a, b)
134138
assert not comparator(a, c)
135139

136-
a = datetime.timedelta(days=1)
137-
b = datetime.timedelta(days=1)
138-
c = datetime.timedelta(days=2)
140+
a = datetime.timedelta(days=1) # type: ignore
141+
b = datetime.timedelta(days=1) # type: ignore
142+
c = datetime.timedelta(days=2) # type: ignore
139143
assert comparator(a, b)
140144
assert not comparator(a, c)
141145

142-
a = datetime.time(2, 2, 2)
143-
b = datetime.time(2, 2, 2)
144-
c = datetime.time(2, 2, 3)
146+
a = datetime.time(2, 2, 2) # type: ignore
147+
b = datetime.time(2, 2, 2) # type: ignore
148+
c = datetime.time(2, 2, 3) # type: ignore
145149
assert comparator(a, b)
146150
assert not comparator(a, c)
147151

148-
a = datetime.timezone.utc
149-
b = datetime.timezone.utc
150-
c = datetime.timezone(datetime.timedelta(hours=1))
152+
a = datetime.timezone.utc # type: ignore
153+
b = datetime.timezone.utc # type: ignore
154+
c = datetime.timezone(datetime.timedelta(hours=1)) # type: ignore
151155
assert comparator(a, b)
152156
assert not comparator(a, c)
153157

154-
a = decimal.Decimal(3.14)
155-
b = decimal.Decimal(3.14)
156-
c = decimal.Decimal(3.15)
158+
a = decimal.Decimal(3.14) # type: ignore
159+
b = decimal.Decimal(3.14) # type: ignore
160+
c = decimal.Decimal(3.15) # type: ignore
157161
assert comparator(a, b)
158162
assert not comparator(a, c)
159163

@@ -167,15 +171,15 @@ class Color2(Enum):
167171
GREEN = auto()
168172
BLUE = auto()
169173

170-
a = Color.RED
171-
b = Color.RED
172-
c = Color.GREEN
174+
a = Color.RED # type: ignore
175+
b = Color.RED # type: ignore
176+
c = Color.GREEN # type: ignore
173177
assert comparator(a, b)
174178
assert not comparator(a, c)
175179

176-
a = Color2.RED
177-
b = Color2.RED
178-
c = Color2.GREEN
180+
a = Color2.RED # type: ignore
181+
b = Color2.RED # type: ignore
182+
c = Color2.GREEN # type: ignore
179183
assert comparator(a, b)
180184
assert not comparator(a, c)
181185

@@ -184,9 +188,9 @@ class Color4(IntFlag):
184188
GREEN = auto()
185189
BLUE = auto()
186190

187-
a = Color4.RED
188-
b = Color4.RED
189-
c = Color4.GREEN
191+
a = Color4.RED # type: ignore
192+
b = Color4.RED # type: ignore
193+
c = Color4.GREEN # type: ignore
190194
assert comparator(a, b)
191195
assert not comparator(a, c)
192196

@@ -296,7 +300,7 @@ def test_numpy():
296300

297301
def test_scipy():
298302
try:
299-
import scipy as sp
303+
import scipy as sp # type: ignore
300304
except ImportError:
301305
pytest.skip()
302306
a = sp.sparse.csr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]])
@@ -466,7 +470,7 @@ def test_pandas():
466470

467471
def test_pyrsistent():
468472
try:
469-
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector
473+
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector # type: ignore
470474
except ImportError:
471475
pytest.skip()
472476

@@ -678,7 +682,7 @@ def test_compare_results_fn():
678682
function_getting_tested="function_getting_tested",
679683
iteration_id="0",
680684
),
681-
file_name="file_name",
685+
file_name=Path("file_name"),
682686
did_pass=True,
683687
runtime=5,
684688
test_framework="unittest",
@@ -699,7 +703,7 @@ def test_compare_results_fn():
699703
function_getting_tested="function_getting_tested",
700704
iteration_id="0",
701705
),
702-
file_name="file_name",
706+
file_name=Path("file_name"),
703707
did_pass=True,
704708
runtime=10,
705709
test_framework="unittest",
@@ -722,7 +726,7 @@ def test_compare_results_fn():
722726
function_getting_tested="function_getting_tested",
723727
iteration_id="0",
724728
),
725-
file_name="file_name",
729+
file_name=Path("file_name"),
726730
did_pass=True,
727731
runtime=10,
728732
test_framework="unittest",
@@ -745,7 +749,7 @@ def test_compare_results_fn():
745749
function_getting_tested="function_getting_tested",
746750
iteration_id="0",
747751
),
748-
file_name="file_name",
752+
file_name=Path("file_name"),
749753
did_pass=True,
750754
runtime=10,
751755
test_framework="unittest",
@@ -764,7 +768,7 @@ def test_compare_results_fn():
764768
function_getting_tested="function_getting_tested",
765769
iteration_id="2",
766770
),
767-
file_name="file_name",
771+
file_name=Path("file_name"),
768772
did_pass=True,
769773
runtime=10,
770774
test_framework="unittest",
@@ -787,7 +791,7 @@ def test_compare_results_fn():
787791
function_getting_tested="function_getting_tested",
788792
iteration_id="0",
789793
),
790-
file_name="file_name",
794+
file_name=Path("file_name"),
791795
did_pass=False,
792796
runtime=5,
793797
test_framework="unittest",
@@ -941,3 +945,24 @@ def raise_specific_exception():
941945

942946
zero_division_exc3 = ZeroDivisionError("Different message")
943947
assert comparator(zero_division_exc1, zero_division_exc3)
948+
949+
assert comparator(..., ...)
950+
assert comparator(Ellipsis, Ellipsis)
951+
952+
assert not comparator(..., None)
953+
954+
assert not comparator(Ellipsis, None)
955+
956+
code7 = "a = 1 + 2"
957+
module7 = ast.parse(code7)
958+
for node in ast.walk(module7):
959+
for child in ast.iter_child_nodes(node):
960+
child.parent = node # type: ignore
961+
module8 = copy.deepcopy(module7)
962+
assert comparator(module7, module8)
963+
964+
code2 = "a = 1 + 3"
965+
966+
module2 = ast.parse(code2)
967+
968+
assert not comparator(module7, module2)

0 commit comments

Comments
 (0)