Skip to content

Commit f5f136b

Browse files
committed
changes after code review, made logic more robust, added tests_root arg to codeflash capture, added recursive tests, renamed some files / functions
1 parent e3de6ee commit f5f136b

12 files changed

+464
-202
lines changed

codeflash/optimization/optimizer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from codeflash.verification.bayesian_analysis import compare_function_runtime_distributions
7070
from codeflash.verification.concolic_testing import generate_concolic_tests
7171
from codeflash.verification.equivalence import compare_test_results
72-
from codeflash.verification.instrument_code import instrument_code
72+
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
7373
from codeflash.verification.parse_test_output import parse_test_results
7474
from codeflash.verification.test_results import TestResults, TestType
7575
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
@@ -1029,7 +1029,9 @@ def establish_original_code_baseline(
10291029
original_fto_code = function_file_path.read_text("utf-8")
10301030
# Instrument codeflash capture
10311031
try:
1032-
instrument_code(function_to_optimize, file_path_to_helper_classes)
1032+
instrument_codeflash_capture(
1033+
function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
1034+
)
10331035
behavioral_results, coverage_results = self.run_and_parse_tests(
10341036
testing_type=TestingMode.BEHAVIOR,
10351037
test_env=test_env,
@@ -1175,7 +1177,9 @@ def run_optimized_candidate(
11751177
for module_abspath in original_helper_code:
11761178
candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
11771179
try:
1178-
instrument_code(function_to_optimize, file_path_to_helper_classes)
1180+
instrument_codeflash_capture(
1181+
function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
1182+
)
11791183

11801184
candidate_behavior_results, _ = self.run_and_parse_tests(
11811185
testing_type=TestingMode.BEHAVIOR,

codeflash/verification/codeflash_capture.py

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,81 @@
66
import os
77
import sqlite3
88
import time
9+
from pathlib import Path
910

1011
import dill as pickle
1112

1213
from codeflash.verification.test_results import VerificationType
1314

1415

15-
def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
16-
"""Extract test information from the call stack."""
17-
stack = inspect.stack()
18-
19-
# Default values
16+
def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str]:
17+
"""Extract test information by walking the call stack from the current frame."""
2018
test_module_name = ""
21-
test_class_name = None
22-
test_name = None
23-
line_id = "" # Note that the way this line_id is defined is from the line_id called in instrumentation
24-
25-
# Search through stack for test information
26-
for frame in stack:
27-
if frame.function.startswith("test_"): # May need a more robust way to find the test file
28-
test_name = frame.function
29-
test_module_name = inspect.getmodule(frame[0]).__name__
30-
line_id = str(frame.lineno)
19+
test_class_name: str | None = None
20+
test_name: str | None = None
21+
line_id = ""
22+
23+
# Get current frame and skip our own function's frame
24+
frame = inspect.currentframe()
25+
if frame is not None:
26+
frame = frame.f_back
27+
28+
# Walk the stack
29+
while frame is not None:
30+
function_name = frame.f_code.co_name
31+
filename = frame.f_code.co_filename
32+
lineno = frame.f_lineno
33+
34+
# Check if function name indicates a test (e.g., starts with "test_")
35+
if function_name.startswith("test_"):
36+
test_name = function_name
37+
test_module = inspect.getmodule(frame)
38+
if hasattr(test_module, "__name__"):
39+
test_module_name = test_module.__name__
40+
line_id = str(lineno)
41+
3142
# Check if it's a method in a class
32-
if "self" in frame.frame.f_locals:
33-
test_class_name = frame.frame.f_locals["self"].__class__.__name__
43+
if (
44+
"self" in frame.f_locals
45+
and hasattr(frame.f_locals["self"], "__class__")
46+
and hasattr(frame.f_locals["self"].__class__, "__name__")
47+
):
48+
test_class_name = frame.f_locals["self"].__class__.__name__
3449
break
35-
# Check if module name starts with test
36-
module_name = frame.frame.f_globals["__name__"]
37-
if module_name and module_name.split(".")[-1].startswith("test_"):
38-
test_module_name = module_name
39-
line_id = str(frame.lineno)
40-
if frame.function != "<module>":
41-
test_name = frame.function # Technically not a test, but save the info since there is no test function
42-
# Check if it's in a class
43-
if "self" in frame.frame.f_locals:
44-
test_class_name = frame.frame.f_locals["self"].__class__.__name__
50+
51+
# Check for instantiation at the module level
52+
if (
53+
"__name__" in frame.f_globals
54+
and frame.f_globals["__name__"].split(".")[-1].startswith("test_")
55+
and Path(filename).resolve().is_relative_to(Path(tests_root))
56+
and function_name == "<module>"
57+
):
58+
test_module_name = frame.f_globals["__name__"]
59+
line_id = str(lineno)
60+
61+
# Check if it's a method in a class
62+
if (
63+
"self" in frame.f_locals
64+
and hasattr(frame.f_locals["self"], "__class__")
65+
and hasattr(frame.f_locals["self"].__class__, "__name__")
66+
):
67+
test_class_name = frame.f_locals["self"].__class__.__name__
4568
break
4669

70+
# Go to the previous frame
71+
frame = frame.f_back
72+
4773
return test_module_name, test_class_name, test_name, line_id
4874

4975

50-
def codeflash_capture(function_name: str, tmp_dir_path: str, is_fto: bool = False):
76+
def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is_fto: bool = False):
5177
"""Defines decorator to be instrumented onto the init function in the code. Collects info of the test that called this, and captures the state of the instance."""
5278

5379
def decorator(wrapped):
5480
@functools.wraps(wrapped)
5581
def wrapper(*args, **kwargs):
5682
# Dynamic information retrieved from stack
57-
test_module_name, test_class_name, test_name, line_id = get_test_info_from_stack()
83+
test_module_name, test_class_name, test_name, line_id = get_test_info_from_stack(tests_root)
5884

5985
# Get env variables
6086
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
@@ -98,7 +124,12 @@ def wrapper(*args, **kwargs):
98124
gc.enable()
99125

100126
# Capture instance state after initialization
101-
instance_state = args[0].__dict__ # self is always the first argument
127+
if hasattr(args[0], "__dict__"):
128+
instance_state = args[
129+
0
130+
].__dict__ # self is always the first argument, this is ensured during instrumentation
131+
else:
132+
raise ValueError("Instance state could not be captured.")
102133
codeflash_cur.execute(
103134
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
104135
)

codeflash/verification/comparator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
5151
if type(orig) is not type(new):
5252
type_obj = type(orig)
5353
new_type_obj = type(new)
54+
# distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names
5455
if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__:
5556
return False
5657
if isinstance(orig, (list, tuple)):

codeflash/verification/equivalence.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,19 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
4646
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
4747
are_equal = False
4848
logger.debug(
49-
f"""
50-
File Name: {original_test_result.file_name}
51-
Test Type: {original_test_result.test_type}
52-
Verification Type: {original_test_result.verification_type}
53-
Invocation ID: {original_test_result.id}
54-
Original return value: {original_test_result.return_value}
55-
CDD return value: {cdd_test_result.return_value}
56-
-------------------"""
49+
"File Name: %s\n"
50+
"Test Type: %s\n"
51+
"Verification Type: %s\n"
52+
"Invocation ID: %s\n"
53+
"Original return value: %s\n"
54+
"Candidate return value: %s\n"
55+
"-------------------",
56+
original_test_result.file_name,
57+
original_test_result.test_type,
58+
original_test_result.verification_type,
59+
original_test_result.id,
60+
original_test_result.return_value,
61+
cdd_test_result.return_value,
5762
)
5863
break
5964
if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and (

codeflash/verification/instrument_code.py renamed to codeflash/verification/instrument_codeflash_capture.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1010

1111

12-
def instrument_code(function_to_optimize: FunctionToOptimize, file_path_to_helper_class: dict[Path, set[str]]) -> None:
12+
def instrument_codeflash_capture(
13+
function_to_optimize: FunctionToOptimize, file_path_to_helper_class: dict[Path, set[str]], tests_root: Path
14+
) -> None:
1315
"""Instrument __init__ function with codeflash_capture decorator if it's in a class."""
1416
# Find the class parent
1517
if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef":
@@ -23,42 +25,38 @@ def instrument_code(function_to_optimize: FunctionToOptimize, file_path_to_helpe
2325
):
2426
file_path_to_helper_class[function_to_optimize.file_path].remove(class_parent.name)
2527
# Instrument fto class
26-
with open(function_to_optimize.file_path) as f:
27-
original_code = f.read()
28-
28+
original_code = function_to_optimize.file_path.read_text(encoding="utf-8")
2929
# Add decorator to init
3030
modified_code = add_codeflash_capture_to_init(
3131
target_classes={class_parent.name},
3232
fto_name=function_to_optimize.function_name,
3333
tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
3434
code=original_code,
35+
tests_root=tests_root,
3536
is_fto=True,
3637
)
37-
38-
with open(function_to_optimize.file_path, "w") as f:
39-
f.write(modified_code)
38+
function_to_optimize.file_path.write_text(modified_code, encoding="utf-8")
4039

4140
# Instrument helper classes
4241
for file_path, helper_classes in file_path_to_helper_class.items():
43-
with open(file_path) as f:
44-
original_code = f.read()
42+
original_code = file_path.read_text(encoding="utf-8")
4543
modified_code = add_codeflash_capture_to_init(
4644
target_classes=helper_classes,
4745
fto_name=function_to_optimize.function_name,
4846
tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
4947
code=original_code,
48+
tests_root=tests_root,
5049
is_fto=False,
5150
)
52-
with open(file_path, "w") as f:
53-
f.write(modified_code)
51+
file_path.write_text(modified_code, encoding="utf-8")
5452

5553

5654
def add_codeflash_capture_to_init(
57-
target_classes: set[str], fto_name: str, tmp_dir_path: str, code: str, is_fto: bool = False
55+
target_classes: set[str], fto_name: str, tmp_dir_path: str, code: str, tests_root: Path, is_fto: bool = False
5856
) -> str:
5957
"""Add codeflash_capture decorator to __init__ function in the specified class."""
6058
tree = ast.parse(code)
61-
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, is_fto)
59+
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, tests_root, is_fto)
6260
modified_tree = transformer.visit(tree)
6361
if transformer.inserted_decorator:
6462
ast.fix_missing_locations(modified_tree)
@@ -70,12 +68,15 @@ def add_codeflash_capture_to_init(
7068
class InitDecorator(ast.NodeTransformer):
7169
"""AST transformer that adds codeflash_capture decorator to specific class's __init__."""
7270

73-
def __init__(self, target_classes: set[str], fto_name: str, tmp_dir_path: str, is_fto=False) -> None:
71+
def __init__(
72+
self, target_classes: set[str], fto_name: str, tmp_dir_path: str, tests_root: Path, is_fto=False
73+
) -> None:
7474
self.target_classes = target_classes
7575
self.fto_name = fto_name
7676
self.tmp_dir_path = tmp_dir_path
7777
self.is_fto = is_fto
7878
self.has_import = False
79+
self.tests_root = tests_root
7980
self.inserted_decorator = False
8081

8182
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
@@ -110,12 +111,19 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
110111
keywords=[
111112
ast.keyword(arg="function_name", value=ast.Constant(value=".".join([node.name, "__init__"]))),
112113
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
114+
ast.keyword(arg="tests_root", value=ast.Constant(value=str(self.tests_root))),
113115
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
114116
],
115117
)
116118

117119
for item in node.body:
118-
if isinstance(item, ast.FunctionDef) and item.name == "__init__":
120+
if (
121+
isinstance(item, ast.FunctionDef)
122+
and item.name == "__init__"
123+
and item.args.args
124+
and isinstance(item.args.args[0], ast.arg)
125+
and item.args.args[0].arg == "self"
126+
):
119127
has_init = True
120128

121129
# Add decorator at the start of the list if not already present

codeflash/verification/parse_test_output.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,9 @@ def merge_test_results(
412412
test_type=xml_result.test_type,
413413
return_value=result_bin.return_value,
414414
timed_out=xml_result.timed_out,
415-
verification_type=VerificationType(result_bin.verification_type),
415+
verification_type=VerificationType(result_bin.verification_type)
416+
if result_bin.verification_type
417+
else None,
416418
)
417419
)
418420
elif xml_results.test_results[0].id.iteration_id is not None:
@@ -439,7 +441,9 @@ def merge_test_results(
439441
timed_out=xml_result.timed_out
440442
if bin_result.runtime is None
441443
else False, # If runtime was measured in the bin file, then the testcase did not time out
442-
verification_type=VerificationType(bin_result.verification_type),
444+
verification_type=VerificationType(bin_result.verification_type)
445+
if bin_result.verification_type
446+
else None,
443447
)
444448
)
445449
else:
@@ -463,7 +467,9 @@ def merge_test_results(
463467
test_type=bin_result.test_type,
464468
return_value=bin_result.return_value,
465469
timed_out=xml_result.timed_out, # only the xml gets the timed_out flag
466-
verification_type=VerificationType(bin_result.verification_type),
470+
verification_type=VerificationType(bin_result.verification_type)
471+
if bin_result.verification_type
472+
else None,
467473
)
468474
)
469475

0 commit comments

Comments
 (0)