Skip to content

Commit baa5176

Browse files
committed
pre-compile regex
1 parent 8fd505d commit baa5176

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ class TestFunction:
3131
test_type: TestType
3232

3333

34+
ERROR_PATTERN = re.compile(r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)")
35+
PYTEST_PARAMETERIZED_TEST_NAME_REGEX = re.compile(r"[\[\]]")
36+
UNITTEST_PARAMETERIZED_TEST_NAME_REGEX = re.compile(r"^test_\w+_\d+(?:_\w+)*")
37+
UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX = re.compile(r"_\d+(?:_\w+)*$")
38+
FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$")
39+
40+
3441
def discover_unit_tests(
3542
cfg: TestConfig, discover_only_these_tests: list[Path] | None = None
3643
) -> dict[str, list[FunctionCalledInTest]]:
@@ -76,8 +83,7 @@ def discover_tests_pytest(
7683
if exitcode != 0:
7784
if exitcode == 2 and "ERROR collecting" in result.stdout:
7885
# Pattern matches "===== ERRORS =====" (any number of =) and captures everything after
79-
error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)"
80-
match = re.search(error_pattern, result.stdout)
86+
match = ERROR_PATTERN.search(result.stdout)
8187
error_section = match.group(1) if match else result.stdout
8288

8389
logger.warning(
@@ -219,8 +225,12 @@ def process_test_files(
219225
if test_framework == "pytest":
220226
for function in functions:
221227
if "[" in function.test_function:
222-
function_name = re.split(r"[\[\]]", function.test_function)[0]
223-
parameters = re.split(r"[\[\]]", function.test_function)[1]
228+
function_name = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
229+
function.test_function
230+
)[0]
231+
parameters = PYTEST_PARAMETERIZED_TEST_NAME_REGEX.split(
232+
function.test_function
233+
)[1]
224234
if function_name in top_level_functions:
225235
test_functions.add(
226236
TestFunction(
@@ -239,10 +249,14 @@ def process_test_files(
239249
function.test_type,
240250
)
241251
)
242-
elif re.match(r"^test_\w+_\d+(?:_\w+)*", function.test_function):
252+
elif UNITTEST_PARAMETERIZED_TEST_NAME_REGEX.match(
253+
function.test_function
254+
):
243255
# Try to match parameterized unittest functions here, although we can't get the parameters.
244256
# Extract base name by removing the numbered suffix and any additional descriptions
245-
base_name = re.sub(r"_\d+(?:_\w+)*$", "", function.test_function)
257+
base_name = UNITTEST_STRIP_NUMBERED_SUFFIX_REGEX.sub(
258+
"", function.test_function
259+
)
246260
if base_name in top_level_functions:
247261
test_functions.add(
248262
TestFunction(
@@ -299,7 +313,7 @@ def process_test_files(
299313
for name in all_names:
300314
if name.full_name is None:
301315
continue
302-
m = re.search(r"([^.]+)\." + f"{name.name}$", name.full_name)
316+
m = FUNCTION_NAME_REGEX.search(name.full_name)
303317
if not m:
304318
continue
305319

0 commit comments

Comments
 (0)