|
6 | 6 | import os |
7 | 7 | import sqlite3 |
8 | 8 | import time |
| 9 | +from pathlib import Path |
9 | 10 |
|
10 | 11 | import dill as pickle |
11 | 12 |
|
12 | 13 | from codeflash.verification.test_results import VerificationType |
13 | 14 |
|
14 | 15 |
|
15 | | -def get_test_info_from_stack() -> tuple[str, str | None, str, str]: |
16 | | - """Extract test information from the call stack.""" |
17 | | - stack = inspect.stack() |
18 | | - |
19 | | - # Default values |
| 16 | +def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str]: |
| 17 | + """Extract test information by walking the call stack from the current frame.""" |
20 | 18 | test_module_name = "" |
21 | | - test_class_name = None |
22 | | - test_name = None |
23 | | - line_id = "" # Note that the way this line_id is defined is from the line_id called in instrumentation |
24 | | - |
25 | | - # Search through stack for test information |
26 | | - for frame in stack: |
27 | | - if frame.function.startswith("test_"): # May need a more robust way to find the test file |
28 | | - test_name = frame.function |
29 | | - test_module_name = inspect.getmodule(frame[0]).__name__ |
30 | | - line_id = str(frame.lineno) |
| 19 | + test_class_name: str | None = None |
| 20 | + test_name: str | None = None |
| 21 | + line_id = "" |
| 22 | + |
| 23 | + # Get current frame and skip our own function's frame |
| 24 | + frame = inspect.currentframe() |
| 25 | + if frame is not None: |
| 26 | + frame = frame.f_back |
| 27 | + |
| 28 | + # Walk the stack |
| 29 | + while frame is not None: |
| 30 | + function_name = frame.f_code.co_name |
| 31 | + filename = frame.f_code.co_filename |
| 32 | + lineno = frame.f_lineno |
| 33 | + |
| 34 | + # Check if function name indicates a test (e.g., starts with "test_") |
| 35 | + if function_name.startswith("test_"): |
| 36 | + test_name = function_name |
| 37 | + test_module = inspect.getmodule(frame) |
| 38 | + if hasattr(test_module, "__name__"): |
| 39 | + test_module_name = test_module.__name__ |
| 40 | + line_id = str(lineno) |
| 41 | + |
31 | 42 | # Check if it's a method in a class |
32 | | - if "self" in frame.frame.f_locals: |
33 | | - test_class_name = frame.frame.f_locals["self"].__class__.__name__ |
| 43 | + if ( |
| 44 | + "self" in frame.f_locals |
| 45 | + and hasattr(frame.f_locals["self"], "__class__") |
| 46 | + and hasattr(frame.f_locals["self"].__class__, "__name__") |
| 47 | + ): |
| 48 | + test_class_name = frame.f_locals["self"].__class__.__name__ |
34 | 49 | break |
35 | | - # Check if module name starts with test |
36 | | - module_name = frame.frame.f_globals["__name__"] |
37 | | - if module_name and module_name.split(".")[-1].startswith("test_"): |
38 | | - test_module_name = module_name |
39 | | - line_id = str(frame.lineno) |
40 | | - if frame.function != "<module>": |
41 | | - test_name = frame.function # Technically not a test, but save the info since there is no test function |
42 | | - # Check if it's in a class |
43 | | - if "self" in frame.frame.f_locals: |
44 | | - test_class_name = frame.frame.f_locals["self"].__class__.__name__ |
| 50 | + |
| 51 | + # Check for instantiation on the module level |
| 52 | + if ( |
| 53 | + "__name__" in frame.f_globals |
| 54 | + and frame.f_globals["__name__"].split(".")[-1].startswith("test_") |
| 55 | + and Path(filename).resolve().is_relative_to(Path(tests_root)) |
| 56 | + and function_name == "<module>" |
| 57 | + ): |
| 58 | + test_module_name = frame.f_globals["__name__"] |
| 59 | + line_id = str(lineno) |
| 60 | + |
| 61 | + # # Check if it's a method in a class |
| 62 | + if ( |
| 63 | + "self" in frame.f_locals |
| 64 | + and hasattr(frame.f_locals["self"], "__class__") |
| 65 | + and hasattr(frame.f_locals["self"].__class__, "__name__") |
| 66 | + ): |
| 67 | + test_class_name = frame.f_locals["self"].__class__.__name__ |
45 | 68 | break |
46 | 69 |
|
| 70 | + # Go to the previous frame |
| 71 | + frame = frame.f_back |
| 72 | + |
47 | 73 | return test_module_name, test_class_name, test_name, line_id |
48 | 74 |
|
49 | 75 |
|
50 | | -def codeflash_capture(function_name: str, tmp_dir_path: str, is_fto: bool = False): |
| 76 | +def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is_fto: bool = False): |
51 | 77 | """Defines decorator to be instrumented onto the init function in the code. Collects info of the test that called this, and captures the state of the instance.""" |
52 | 78 |
|
53 | 79 | def decorator(wrapped): |
54 | 80 | @functools.wraps(wrapped) |
55 | 81 | def wrapper(*args, **kwargs): |
56 | 82 | # Dynamic information retrieved from stack |
57 | | - test_module_name, test_class_name, test_name, line_id = get_test_info_from_stack() |
| 83 | + test_module_name, test_class_name, test_name, line_id = get_test_info_from_stack(tests_root) |
58 | 84 |
|
59 | 85 | # Get env variables |
60 | 86 | loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) |
@@ -98,7 +124,12 @@ def wrapper(*args, **kwargs): |
98 | 124 | gc.enable() |
99 | 125 |
|
100 | 126 | # Capture instance state after initialization |
101 | | - instance_state = args[0].__dict__ # self is always the first argument |
| 127 | + if hasattr(args[0], "__dict__"): |
| 128 | + instance_state = args[ |
| 129 | + 0 |
| 130 | + ].__dict__ # self is always the first argument, this is ensured during instrumentation |
| 131 | + else: |
| 132 | + raise ValueError("Instance state could not be captured.") |
102 | 133 | codeflash_cur.execute( |
103 | 134 | "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)" |
104 | 135 | ) |
|
0 commit comments