Skip to content

Commit 8fd505d

Browse files
committed
Update discover_unit_tests.py
1 parent 3aae8c2 commit 8fd505d

File tree

1 file changed

+101
-54
lines changed

1 file changed

+101
-54
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 101 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ class TestFunction:
3434
def discover_unit_tests(
3535
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
3636
) -> dict[str, list[FunctionCalledInTest]]:
37-
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
37+
framework_strategies: dict[str, Callable] = {
38+
"pytest": discover_tests_pytest,
39+
"unittest": discover_tests_unittest,
40+
}
3841
strategy = framework_strategies.get(cfg.test_framework, None)
3942
if not strategy:
4043
error_message = f"Unsupported test framework: {cfg.test_framework}"
@@ -82,7 +85,9 @@ def discover_tests_pytest(
8285
)
8386

8487
elif 0 <= exitcode <= 5:
85-
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}")
88+
logger.warning(
89+
f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}"
90+
)
8691
else:
8792
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
8893
console.rule()
@@ -105,7 +110,10 @@ def discover_tests_pytest(
105110
test_function=test["test_function"],
106111
test_type=test_type,
107112
)
108-
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
113+
if (
114+
discover_only_these_tests
115+
and test_obj.test_file not in discover_only_these_tests
116+
):
109117
continue
110118
file_to_test_map[test_obj.test_file].append(test_obj)
111119
# Within these test files, find the project functions they are referring to and return their names/locations
@@ -130,7 +138,8 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
130138
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
131139
_test_module_path = tests_root / _test_module_path
132140
if not _test_module_path.exists() or (
133-
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
141+
discover_only_these_tests
142+
and str(_test_module_path) not in discover_only_these_tests
134143
):
135144
return None
136145
if "__replay_test" in str(_test_module_path):
@@ -157,7 +166,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
157166
if not hasattr(test, "_testMethodName") and hasattr(test, "_tests"):
158167
for test_2 in test._tests:
159168
if not hasattr(test_2, "_testMethodName"):
160-
logger.warning(f"Didn't find tests for {test_2}") # it goes deeper?
169+
logger.warning(
170+
f"Didn't find tests for {test_2}"
171+
) # it goes deeper?
161172
continue
162173
details = get_test_details(test_2)
163174
if details is not None:
@@ -182,8 +193,9 @@ def process_test_files(
182193
) -> dict[str, list[FunctionCalledInTest]]:
183194
project_root_path = cfg.project_root_path
184195
test_framework = cfg.test_framework
185-
function_to_test_map = defaultdict(list)
196+
function_to_test_map = defaultdict(set)
186197
jedi_project = jedi.Project(path=project_root_path)
198+
goto_cache = {}
187199

188200
for test_file, functions in file_to_test_map.items():
189201
try:
@@ -194,8 +206,12 @@ def process_test_files(
194206
all_defs = script.get_names(all_scopes=True, definitions=True)
195207
all_names_top = script.get_names(all_scopes=True)
196208

197-
top_level_functions = {name.name: name for name in all_names_top if name.type == "function"}
198-
top_level_classes = {name.name: name for name in all_names_top if name.type == "class"}
209+
top_level_functions = {
210+
name.name: name for name in all_names_top if name.type == "function"
211+
}
212+
top_level_classes = {
213+
name.name: name for name in all_names_top if name.type == "class"
214+
}
199215
except Exception as e:
200216
logger.debug(f"Failed to get jedi script for {test_file}: {e}")
201217
continue
@@ -207,11 +223,21 @@ def process_test_files(
207223
parameters = re.split(r"[\[\]]", function.test_function)[1]
208224
if function_name in top_level_functions:
209225
test_functions.add(
210-
TestFunction(function_name, function.test_class, parameters, function.test_type)
226+
TestFunction(
227+
function_name,
228+
function.test_class,
229+
parameters,
230+
function.test_type,
231+
)
211232
)
212233
elif function.test_function in top_level_functions:
213234
test_functions.add(
214-
TestFunction(function.test_function, function.test_class, None, function.test_type)
235+
TestFunction(
236+
function.test_function,
237+
function.test_class,
238+
None,
239+
function.test_type,
240+
)
215241
)
216242
elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function):
217243
# Try to match parameterized unittest functions here, although we can't get the parameters.
@@ -229,7 +255,7 @@ def process_test_files(
229255

230256
elif test_framework == "unittest":
231257
functions_to_search = [elem.test_function for elem in functions]
232-
test_suites = [elem.test_class for elem in functions]
258+
test_suites = {elem.test_class for elem in functions}
233259

234260
matching_names = test_suites & top_level_classes.keys()
235261
for matched_name in matching_names:
@@ -240,7 +266,9 @@ def process_test_files(
240266
and f".{matched_name}." in def_name.full_name
241267
):
242268
for function in functions_to_search:
243-
(is_parameterized, new_function, parameters) = discover_parameters_unittest(function)
269+
(is_parameterized, new_function, parameters) = (
270+
discover_parameters_unittest(function)
271+
)
244272

245273
if is_parameterized and new_function == def_name.name:
246274
test_functions.add(
@@ -264,53 +292,72 @@ def process_test_files(
264292
test_functions_list = list(test_functions)
265293
test_functions_raw = [elem.function_name for elem in test_functions_list]
266294

295+
test_functions_by_name = defaultdict(list)
296+
for i, func_name in enumerate(test_functions_raw):
297+
test_functions_by_name[func_name].append(i)
298+
267299
for name in all_names:
268300
if name.full_name is None:
269301
continue
270302
m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name)
271303
if not m:
272304
continue
305+
273306
scope = m.group(1)
274-
indices = [i for i, x in enumerate(test_functions_raw) if x == scope]
275-
for index in indices:
276-
scope_test_function = test_functions_list[index].function_name
277-
scope_test_class = test_functions_list[index].test_class
278-
scope_parameters = test_functions_list[index].parameters
279-
test_type = test_functions_list[index].test_type
280-
try:
281-
definition = name.goto(follow_imports=True, follow_builtin_imports=False)
282-
except Exception as e:
283-
logger.debug(str(e))
284-
continue
285-
if definition and definition[0].type == "function":
286-
definition_path = str(definition[0].module_path)
287-
# The definition is part of this project and not defined within the original function
288-
if (
289-
definition_path.startswith(str(project_root_path) + os.sep)
290-
and definition[0].module_name != name.module_name
291-
and definition[0].full_name is not None
292-
):
293-
if scope_parameters is not None:
294-
if test_framework == "pytest":
295-
scope_test_function += "[" + scope_parameters + "]"
296-
if test_framework == "unittest":
297-
scope_test_function += "_" + scope_parameters
298-
full_name_without_module_prefix = definition[0].full_name.replace(
299-
definition[0].module_name + ".", "", 1
300-
)
301-
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
302-
function_to_test_map[qualified_name_with_modules_from_root].append(
303-
FunctionCalledInTest(
304-
tests_in_file=TestsInFile(
305-
test_file=test_file,
306-
test_class=scope_test_class,
307-
test_function=scope_test_function,
308-
test_type=test_type,
309-
),
310-
position=CodePosition(line_no=name.line, col_no=name.column),
311-
)
307+
if scope not in test_functions_by_name:
308+
continue
309+
310+
cache_key = (name.full_name, name.module_name)
311+
try:
312+
if cache_key in goto_cache:
313+
definition = goto_cache[cache_key]
314+
else:
315+
definition = name.goto(
316+
follow_imports=True, follow_builtin_imports=False
317+
)
318+
goto_cache[cache_key] = definition
319+
except Exception as e:
320+
logger.debug(str(e))
321+
continue
322+
323+
if not definition or definition[0].type != "function":
324+
continue
325+
326+
definition_path = str(definition[0].module_path)
327+
if (
328+
definition_path.startswith(str(project_root_path) + os.sep)
329+
and definition[0].module_name != name.module_name
330+
and definition[0].full_name is not None
331+
):
332+
for index in test_functions_by_name[scope]:
333+
scope_test_function = test_functions_list[index].function_name
334+
scope_test_class = test_functions_list[index].test_class
335+
scope_parameters = test_functions_list[index].parameters
336+
test_type = test_functions_list[index].test_type
337+
338+
if scope_parameters is not None:
339+
if test_framework == "pytest":
340+
scope_test_function += "[" + scope_parameters + "]"
341+
if test_framework == "unittest":
342+
scope_test_function += "_" + scope_parameters
343+
344+
full_name_without_module_prefix = definition[0].full_name.replace(
345+
definition[0].module_name + ".", "", 1
346+
)
347+
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}"
348+
349+
function_to_test_map[qualified_name_with_modules_from_root].add(
350+
FunctionCalledInTest(
351+
tests_in_file=TestsInFile(
352+
test_file=test_file,
353+
test_class=scope_test_class,
354+
test_function=scope_test_function,
355+
test_type=test_type,
356+
),
357+
position=CodePosition(
358+
line_no=name.line, col_no=name.column
359+
),
312360
)
313-
deduped_function_to_test_map = {}
314-
for function, tests in function_to_test_map.items():
315-
deduped_function_to_test_map[function] = list(set(tests))
316-
return deduped_function_to_test_map
361+
)
362+
363+
return {function: list(tests) for function, tests in function_to_test_map.items()}

0 commit comments

Comments
 (0)