Skip to content

Commit e3de6ee

Browse files
committed
Merge branch 'refs/heads/main' into init_caching
# Conflicts: # cli/codeflash/verification/comparator.py
2 parents 3d33f8b + ad4c323 commit e3de6ee

File tree

2 files changed

+68
-39
lines changed

2 files changed

+68
-39
lines changed

codeflash/verification/comparator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import datetime
23
import decimal
34
import enum
@@ -47,7 +48,6 @@
4748
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
4849
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
4950
try:
50-
# if not type_comparator(orig, new):
5151
if type(orig) is not type(new):
5252
type_obj = type(orig)
5353
new_type_obj = type(new)
@@ -66,6 +66,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
6666
bool,
6767
complex,
6868
type(None),
69+
type(Ellipsis),
6970
decimal.Decimal,
7071
set,
7172
bytes,
@@ -206,13 +207,16 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
206207
if superset_obj:
207208
# allow new object to be a superset of the original object
208209
return all(k in new_keys and comparator(v, new_keys[k], superset_obj) for k, v in orig_keys.items())
210+
211+
if isinstance(orig, ast.AST):
212+
orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"}
213+
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
209214
return comparator(orig_keys, new_keys, superset_obj)
210215

211216
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
212217
return new == orig
213218
if str(type(orig)) == "<class 'object'>":
214219
return True
215-
216220
# TODO : Add other types here
217221
logger.warning(f"Unknown comparator input type: {type(orig)}")
218222
return False

tests/test_comparator.py

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ast
2+
import copy
13
import dataclasses
24
import datetime
35
import decimal
@@ -8,13 +10,15 @@
810

911
import pydantic
1012
import pytest
13+
from pathlib import Path
14+
1115
from codeflash.either import Failure, Success
1216
from codeflash.verification.comparator import comparator
1317
from codeflash.verification.equivalence import compare_test_results
1418
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
1519

1620

17-
def test_basic_python_objects():
21+
def test_basic_python_objects() -> None:
1822
a = 5
1923
b = 5
2024
c = 6
@@ -122,40 +126,40 @@ def test_basic_python_objects():
122126
assert not comparator(a, c)
123127

124128

125-
def test_standard_python_library_objects():
126-
a = datetime.datetime(2020, 2, 2, 2, 2, 2)
127-
b = datetime.datetime(2020, 2, 2, 2, 2, 2)
128-
c = datetime.datetime(2020, 2, 2, 2, 2, 3)
129+
def test_standard_python_library_objects() -> None:
130+
a = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore
131+
b = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore
132+
c = datetime.datetime(2020, 2, 2, 2, 2, 3) # type: ignore
129133
assert comparator(a, b)
130134
assert not comparator(a, c)
131135

132-
a = datetime.date(2020, 2, 2)
133-
b = datetime.date(2020, 2, 2)
134-
c = datetime.date(2020, 2, 3)
136+
a = datetime.date(2020, 2, 2) # type: ignore
137+
b = datetime.date(2020, 2, 2) # type: ignore
138+
c = datetime.date(2020, 2, 3) # type: ignore
135139
assert comparator(a, b)
136140
assert not comparator(a, c)
137141

138-
a = datetime.timedelta(days=1)
139-
b = datetime.timedelta(days=1)
140-
c = datetime.timedelta(days=2)
142+
a = datetime.timedelta(days=1) # type: ignore
143+
b = datetime.timedelta(days=1) # type: ignore
144+
c = datetime.timedelta(days=2) # type: ignore
141145
assert comparator(a, b)
142146
assert not comparator(a, c)
143147

144-
a = datetime.time(2, 2, 2)
145-
b = datetime.time(2, 2, 2)
146-
c = datetime.time(2, 2, 3)
148+
a = datetime.time(2, 2, 2) # type: ignore
149+
b = datetime.time(2, 2, 2) # type: ignore
150+
c = datetime.time(2, 2, 3) # type: ignore
147151
assert comparator(a, b)
148152
assert not comparator(a, c)
149153

150-
a = datetime.timezone.utc
151-
b = datetime.timezone.utc
152-
c = datetime.timezone(datetime.timedelta(hours=1))
154+
a = datetime.timezone.utc # type: ignore
155+
b = datetime.timezone.utc # type: ignore
156+
c = datetime.timezone(datetime.timedelta(hours=1)) # type: ignore
153157
assert comparator(a, b)
154158
assert not comparator(a, c)
155159

156-
a = decimal.Decimal(3.14)
157-
b = decimal.Decimal(3.14)
158-
c = decimal.Decimal(3.15)
160+
a = decimal.Decimal(3.14) # type: ignore
161+
b = decimal.Decimal(3.14) # type: ignore
162+
c = decimal.Decimal(3.15) # type: ignore
159163
assert comparator(a, b)
160164
assert not comparator(a, c)
161165

@@ -169,15 +173,15 @@ class Color2(Enum):
169173
GREEN = auto()
170174
BLUE = auto()
171175

172-
a = Color.RED
173-
b = Color.RED
174-
c = Color.GREEN
176+
a = Color.RED # type: ignore
177+
b = Color.RED # type: ignore
178+
c = Color.GREEN # type: ignore
175179
assert comparator(a, b)
176180
assert not comparator(a, c)
177181

178-
a = Color2.RED
179-
b = Color2.RED
180-
c = Color2.GREEN
182+
a = Color2.RED # type: ignore
183+
b = Color2.RED # type: ignore
184+
c = Color2.GREEN # type: ignore
181185
assert comparator(a, b)
182186
assert not comparator(a, c)
183187

@@ -186,9 +190,9 @@ class Color4(IntFlag):
186190
GREEN = auto()
187191
BLUE = auto()
188192

189-
a = Color4.RED
190-
b = Color4.RED
191-
c = Color4.GREEN
193+
a = Color4.RED # type: ignore
194+
b = Color4.RED # type: ignore
195+
c = Color4.GREEN # type: ignore
192196
assert comparator(a, b)
193197
assert not comparator(a, c)
194198

@@ -298,7 +302,7 @@ def test_numpy():
298302

299303
def test_scipy():
300304
try:
301-
import scipy as sp
305+
import scipy as sp # type: ignore
302306
except ImportError:
303307
pytest.skip()
304308
a = sp.sparse.csr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]])
@@ -468,7 +472,7 @@ def test_pandas():
468472

469473
def test_pyrsistent():
470474
try:
471-
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector
475+
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector # type: ignore
472476
except ImportError:
473477
pytest.skip()
474478

@@ -766,7 +770,7 @@ def test_compare_results_fn():
766770
function_getting_tested="function_getting_tested",
767771
iteration_id="0",
768772
),
769-
file_name="file_name",
773+
file_name=Path("file_name"),
770774
did_pass=True,
771775
runtime=5,
772776
test_framework="unittest",
@@ -787,7 +791,7 @@ def test_compare_results_fn():
787791
function_getting_tested="function_getting_tested",
788792
iteration_id="0",
789793
),
790-
file_name="file_name",
794+
file_name=Path("file_name"),
791795
did_pass=True,
792796
runtime=10,
793797
test_framework="unittest",
@@ -810,7 +814,7 @@ def test_compare_results_fn():
810814
function_getting_tested="function_getting_tested",
811815
iteration_id="0",
812816
),
813-
file_name="file_name",
817+
file_name=Path("file_name"),
814818
did_pass=True,
815819
runtime=10,
816820
test_framework="unittest",
@@ -833,7 +837,7 @@ def test_compare_results_fn():
833837
function_getting_tested="function_getting_tested",
834838
iteration_id="0",
835839
),
836-
file_name="file_name",
840+
file_name=Path("file_name"),
837841
did_pass=True,
838842
runtime=10,
839843
test_framework="unittest",
@@ -852,7 +856,7 @@ def test_compare_results_fn():
852856
function_getting_tested="function_getting_tested",
853857
iteration_id="2",
854858
),
855-
file_name="file_name",
859+
file_name=Path("file_name"),
856860
did_pass=True,
857861
runtime=10,
858862
test_framework="unittest",
@@ -875,7 +879,7 @@ def test_compare_results_fn():
875879
function_getting_tested="function_getting_tested",
876880
iteration_id="0",
877881
),
878-
file_name="file_name",
882+
file_name=Path("file_name"),
879883
did_pass=False,
880884
runtime=5,
881885
test_framework="unittest",
@@ -1029,3 +1033,24 @@ def raise_specific_exception():
10291033

10301034
zero_division_exc3 = ZeroDivisionError("Different message")
10311035
assert comparator(zero_division_exc1, zero_division_exc3)
1036+
1037+
assert comparator(..., ...)
1038+
assert comparator(Ellipsis, Ellipsis)
1039+
1040+
assert not comparator(..., None)
1041+
1042+
assert not comparator(Ellipsis, None)
1043+
1044+
code7 = "a = 1 + 2"
1045+
module7 = ast.parse(code7)
1046+
for node in ast.walk(module7):
1047+
for child in ast.iter_child_nodes(node):
1048+
child.parent = node # type: ignore
1049+
module8 = copy.deepcopy(module7)
1050+
assert comparator(module7, module8)
1051+
1052+
code2 = "a = 1 + 3"
1053+
1054+
module2 = ast.parse(code2)
1055+
1056+
assert not comparator(module7, module2)

0 commit comments

Comments
 (0)