@@ -34,7 +34,10 @@ class TestFunction:
3434def discover_unit_tests (
3535 cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None
3636) -> dict [str , list [FunctionCalledInTest ]]:
37- framework_strategies : dict [str , Callable ] = {"pytest" : discover_tests_pytest , "unittest" : discover_tests_unittest }
37+ framework_strategies : dict [str , Callable ] = {
38+ "pytest" : discover_tests_pytest ,
39+ "unittest" : discover_tests_unittest ,
40+ }
3841 strategy = framework_strategies .get (cfg .test_framework , None )
3942 if not strategy :
4043 error_message = f"Unsupported test framework: { cfg .test_framework } "
@@ -82,7 +85,9 @@ def discover_tests_pytest(
8285 )
8386
8487 elif 0 <= exitcode <= 5 :
85- logger .warning (f"Failed to collect tests. Pytest Exit code: { exitcode } ={ ExitCode (exitcode ).name } " )
88+ logger .warning (
89+ f"Failed to collect tests. Pytest Exit code: { exitcode } ={ ExitCode (exitcode ).name } "
90+ )
8691 else :
8792 logger .warning (f"Failed to collect tests. Pytest Exit code: { exitcode } " )
8893 console .rule ()
@@ -105,7 +110,10 @@ def discover_tests_pytest(
105110 test_function = test ["test_function" ],
106111 test_type = test_type ,
107112 )
108- if discover_only_these_tests and test_obj .test_file not in discover_only_these_tests :
113+ if (
114+ discover_only_these_tests
115+ and test_obj .test_file not in discover_only_these_tests
116+ ):
109117 continue
110118 file_to_test_map [test_obj .test_file ].append (test_obj )
111119 # Within these test files, find the project functions they are referring to and return their names/locations
@@ -130,7 +138,8 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
130138 _test_module_path = Path (_test_module .replace ("." , os .sep )).with_suffix (".py" )
131139 _test_module_path = tests_root / _test_module_path
132140 if not _test_module_path .exists () or (
133- discover_only_these_tests and str (_test_module_path ) not in discover_only_these_tests
141+ discover_only_these_tests
142+ and str (_test_module_path ) not in discover_only_these_tests
134143 ):
135144 return None
136145 if "__replay_test" in str (_test_module_path ):
@@ -157,7 +166,9 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
157166 if not hasattr (test , "_testMethodName" ) and hasattr (test , "_tests" ):
158167 for test_2 in test ._tests :
159168 if not hasattr (test_2 , "_testMethodName" ):
160- logger .warning (f"Didn't find tests for { test_2 } " ) # it goes deeper?
169+ logger .warning (
170+ f"Didn't find tests for { test_2 } "
171+ ) # it goes deeper?
161172 continue
162173 details = get_test_details (test_2 )
163174 if details is not None :
@@ -182,8 +193,9 @@ def process_test_files(
182193) -> dict [str , list [FunctionCalledInTest ]]:
183194 project_root_path = cfg .project_root_path
184195 test_framework = cfg .test_framework
185- function_to_test_map = defaultdict (list )
196+ function_to_test_map = defaultdict (set )
186197 jedi_project = jedi .Project (path = project_root_path )
198+ goto_cache = {}
187199
188200 for test_file , functions in file_to_test_map .items ():
189201 try :
@@ -194,8 +206,12 @@ def process_test_files(
194206 all_defs = script .get_names (all_scopes = True , definitions = True )
195207 all_names_top = script .get_names (all_scopes = True )
196208
197- top_level_functions = {name .name : name for name in all_names_top if name .type == "function" }
198- top_level_classes = {name .name : name for name in all_names_top if name .type == "class" }
209+ top_level_functions = {
210+ name .name : name for name in all_names_top if name .type == "function"
211+ }
212+ top_level_classes = {
213+ name .name : name for name in all_names_top if name .type == "class"
214+ }
199215 except Exception as e :
200216 logger .debug (f"Failed to get jedi script for { test_file } : { e } " )
201217 continue
@@ -207,11 +223,21 @@ def process_test_files(
207223 parameters = re .split (r"[\[\]]" , function .test_function )[1 ]
208224 if function_name in top_level_functions :
209225 test_functions .add (
210- TestFunction (function_name , function .test_class , parameters , function .test_type )
226+ TestFunction (
227+ function_name ,
228+ function .test_class ,
229+ parameters ,
230+ function .test_type ,
231+ )
211232 )
212233 elif function .test_function in top_level_functions :
213234 test_functions .add (
214- TestFunction (function .test_function , function .test_class , None , function .test_type )
235+ TestFunction (
236+ function .test_function ,
237+ function .test_class ,
238+ None ,
239+ function .test_type ,
240+ )
215241 )
216242 elif re .match (r"^test_\w+_\d+(?:_\w+)*" , function .test_function ):
217243 # Try to match parameterized unittest functions here, although we can't get the parameters.
@@ -229,7 +255,7 @@ def process_test_files(
229255
230256 elif test_framework == "unittest" :
231257 functions_to_search = [elem .test_function for elem in functions ]
232- test_suites = [ elem .test_class for elem in functions ]
258+ test_suites = { elem .test_class for elem in functions }
233259
234260 matching_names = test_suites & top_level_classes .keys ()
235261 for matched_name in matching_names :
@@ -240,7 +266,9 @@ def process_test_files(
240266 and f".{ matched_name } ." in def_name .full_name
241267 ):
242268 for function in functions_to_search :
243- (is_parameterized , new_function , parameters ) = discover_parameters_unittest (function )
269+ (is_parameterized , new_function , parameters ) = (
270+ discover_parameters_unittest (function )
271+ )
244272
245273 if is_parameterized and new_function == def_name .name :
246274 test_functions .add (
@@ -264,53 +292,72 @@ def process_test_files(
264292 test_functions_list = list (test_functions )
265293 test_functions_raw = [elem .function_name for elem in test_functions_list ]
266294
295+ test_functions_by_name = defaultdict (list )
296+ for i , func_name in enumerate (test_functions_raw ):
297+ test_functions_by_name [func_name ].append (i )
298+
267299 for name in all_names :
268300 if name .full_name is None :
269301 continue
270302 m = re .search (r"([^.]+)\." + f"{ name .name } $" , name .full_name )
271303 if not m :
272304 continue
305+
273306 scope = m .group (1 )
274- indices = [i for i , x in enumerate (test_functions_raw ) if x == scope ]
275- for index in indices :
276- scope_test_function = test_functions_list [index ].function_name
277- scope_test_class = test_functions_list [index ].test_class
278- scope_parameters = test_functions_list [index ].parameters
279- test_type = test_functions_list [index ].test_type
280- try :
281- definition = name .goto (follow_imports = True , follow_builtin_imports = False )
282- except Exception as e :
283- logger .debug (str (e ))
284- continue
285- if definition and definition [0 ].type == "function" :
286- definition_path = str (definition [0 ].module_path )
287- # The definition is part of this project and not defined within the original function
288- if (
289- definition_path .startswith (str (project_root_path ) + os .sep )
290- and definition [0 ].module_name != name .module_name
291- and definition [0 ].full_name is not None
292- ):
293- if scope_parameters is not None :
294- if test_framework == "pytest" :
295- scope_test_function += "[" + scope_parameters + "]"
296- if test_framework == "unittest" :
297- scope_test_function += "_" + scope_parameters
298- full_name_without_module_prefix = definition [0 ].full_name .replace (
299- definition [0 ].module_name + "." , "" , 1
300- )
301- qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
302- function_to_test_map [qualified_name_with_modules_from_root ].append (
303- FunctionCalledInTest (
304- tests_in_file = TestsInFile (
305- test_file = test_file ,
306- test_class = scope_test_class ,
307- test_function = scope_test_function ,
308- test_type = test_type ,
309- ),
310- position = CodePosition (line_no = name .line , col_no = name .column ),
311- )
307+ if scope not in test_functions_by_name :
308+ continue
309+
310+ cache_key = (name .full_name , name .module_name )
311+ try :
312+ if cache_key in goto_cache :
313+ definition = goto_cache [cache_key ]
314+ else :
315+ definition = name .goto (
316+ follow_imports = True , follow_builtin_imports = False
317+ )
318+ goto_cache [cache_key ] = definition
319+ except Exception as e :
320+ logger .debug (str (e ))
321+ continue
322+
323+ if not definition or definition [0 ].type != "function" :
324+ continue
325+
326+ definition_path = str (definition [0 ].module_path )
327+ if (
328+ definition_path .startswith (str (project_root_path ) + os .sep )
329+ and definition [0 ].module_name != name .module_name
330+ and definition [0 ].full_name is not None
331+ ):
332+ for index in test_functions_by_name [scope ]:
333+ scope_test_function = test_functions_list [index ].function_name
334+ scope_test_class = test_functions_list [index ].test_class
335+ scope_parameters = test_functions_list [index ].parameters
336+ test_type = test_functions_list [index ].test_type
337+
338+ if scope_parameters is not None :
339+ if test_framework == "pytest" :
340+ scope_test_function += "[" + scope_parameters + "]"
341+ if test_framework == "unittest" :
342+ scope_test_function += "_" + scope_parameters
343+
344+ full_name_without_module_prefix = definition [0 ].full_name .replace (
345+ definition [0 ].module_name + "." , "" , 1
346+ )
347+ qualified_name_with_modules_from_root = f"{ module_name_from_file_path (definition [0 ].module_path , project_root_path )} .{ full_name_without_module_prefix } "
348+
349+ function_to_test_map [qualified_name_with_modules_from_root ].add (
350+ FunctionCalledInTest (
351+ tests_in_file = TestsInFile (
352+ test_file = test_file ,
353+ test_class = scope_test_class ,
354+ test_function = scope_test_function ,
355+ test_type = test_type ,
356+ ),
357+ position = CodePosition (
358+ line_no = name .line , col_no = name .column
359+ ),
312360 )
313- deduped_function_to_test_map = {}
314- for function , tests in function_to_test_map .items ():
315- deduped_function_to_test_map [function ] = list (set (tests ))
316- return deduped_function_to_test_map
361+ )
362+
363+ return {function : list (tests ) for function , tests in function_to_test_map .items ()}
0 commit comments