Commit 33437d3

use pytest as the execution engine for all tests (#951)
* first pass
  restore
  restore this too
  Revert "first pass"
  This reverts commit b507770.
* continue
* Update uv.lock
* refresh lockfile
* bugfix
* temp
* fix these
* pytest changes
* formatting
* set up test env properly here too
* ruff
* make ruff happy
* Update e2e-bubblesort-unittest.yaml
* with pytest
* bugfix
* oops
1 parent c9e1483 commit 33437d3
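The title states the core change: pytest becomes the execution engine for all tests, and the separate unittest code paths below are deleted. This works presumably because pytest can collect and run plain unittest.TestCase tests directly, so one pytest invocation can serve both frameworks. A minimal illustration (hypothetical file name, not from this repo):

# test_bubblesort.py: a unittest-style test that pytest collects and runs as-is
# via `pytest test_bubblesort.py`, with no unittest-specific runner or loop code.
import unittest

class TestBubbleSort(unittest.TestCase):
    def test_sorts_ascending(self) -> None:
        self.assertEqual(sorted([3, 1, 2]), [1, 2, 3])

if __name__ == "__main__":
    unittest.main()  # still runnable under plain unittest too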

File tree

10 files changed: +833 -774 lines changed

.github/workflows/e2e-bubblesort-unittest.yaml

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ jobs:
       - name: Install dependencies (CLI)
         run: |
           uv sync
+          uv add timeout_decorator
 
       - name: Run Codeflash to optimize code
         id: optimize_code
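For context on the dependency added above: timeout_decorator is the package whose decorator the instrumentation removed in the next file used to generate, and it aborts a test that runs too long. A small sketch of its use, assuming the package is installed; this is not code from the repo:

# timeout_decorator.timeout relies on SIGALRM by default, which is why the
# removed instrumentation skipped Windows.
import timeout_decorator

@timeout_decorator.timeout(15)  # raise if the test takes longer than 15 seconds
def test_finishes_quickly():
    assert sum(range(1000)) == 499500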

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 0 additions & 54 deletions
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import ast
-import platform
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -329,17 +328,6 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
     def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
         if node.name.startswith("test_"):
             did_update = False
-            if self.test_framework == "unittest" and platform.system() != "Windows":
-                # Only add timeout decorator on non-Windows platforms
-                # Windows doesn't support SIGALRM signal required by timeout_decorator
-
-                node.decorator_list.append(
-                    ast.Call(
-                        func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
-                        args=[ast.Constant(value=15)],
-                        keywords=[],
-                    )
-                )
             i = len(node.body) - 1
             while i >= 0:
                 line_node = node.body[i]
@@ -505,25 +493,6 @@ def __init__(
         self.class_name = function.top_level_parent_name
 
     def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
-        # Add timeout decorator for unittest test classes if needed
-        if self.test_framework == "unittest":
-            timeout_decorator = ast.Call(
-                func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
-                args=[ast.Constant(value=15)],
-                keywords=[],
-            )
-            for item in node.body:
-                if (
-                    isinstance(item, ast.FunctionDef)
-                    and item.name.startswith("test_")
-                    and not any(
-                        isinstance(d, ast.Call)
-                        and isinstance(d.func, ast.Name)
-                        and d.func.id == "timeout_decorator.timeout"
-                        for d in item.decorator_list
-                    )
-                ):
-                    item.decorator_list.append(timeout_decorator)
         return self.generic_visit(node)
 
     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
@@ -542,25 +511,6 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
     def _process_test_function(
         self, node: ast.AsyncFunctionDef | ast.FunctionDef
     ) -> ast.AsyncFunctionDef | ast.FunctionDef:
-        # Optimize the search for decorator presence
-        if self.test_framework == "unittest":
-            found_timeout = False
-            for d in node.decorator_list:
-                # Avoid isinstance(d.func, ast.Name) if d is not ast.Call
-                if isinstance(d, ast.Call):
-                    f = d.func
-                    # Avoid attribute lookup if f is not ast.Name
-                    if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
-                        found_timeout = True
-                        break
-            if not found_timeout:
-                timeout_decorator = ast.Call(
-                    func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
-                    args=[ast.Constant(value=15)],
-                    keywords=[],
-                )
-                node.decorator_list.append(timeout_decorator)
-
         # Initialize counter for this test function
         if node.name not in self.async_call_counter:
             self.async_call_counter[node.name] = 0
@@ -715,8 +665,6 @@ def inject_async_profiling_into_existing_test(
 
     # Add necessary imports
     new_imports = [ast.Import(names=[ast.alias(name="os")])]
-    if test_framework == "unittest":
-        new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
 
     tree.body = [*new_imports, *tree.body]
     return True, sort_imports(ast.unparse(tree), float_to_top=True)
@@ -762,8 +710,6 @@ def inject_profiling_into_existing_test(
             ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
         ]
     )
-    if test_framework == "unittest" and platform.system() != "Windows":
-        new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
    additional_functions = [create_wrapper_function(mode)]
 
     tree.body = [*new_imports, *additional_functions, *tree.body]
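The deletions above strip out an AST transform that appended @timeout_decorator.timeout(15) to unittest test methods. For readers unfamiliar with the technique, here is a minimal self-contained sketch of that kind of transform (illustrative only, not the Codeflash implementation):

import ast

class TimeoutDecoratorAdder(ast.NodeTransformer):
    """Append @timeout_decorator.timeout(15) to every test_* function."""

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        if node.name.startswith("test_"):
            node.decorator_list.append(
                ast.Call(
                    # Like the removed code, use ast.Name with a dotted id; it
                    # unparses to the dotted call even though attribute access
                    # is normally represented as an ast.Attribute node.
                    func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
                    args=[ast.Constant(value=15)],
                    keywords=[],
                )
            )
        return self.generic_visit(node)

source = "def test_example():\n    assert 1 + 1 == 2\n"
tree = TimeoutDecoratorAdder().visit(ast.parse(source))
print(ast.unparse(ast.fix_missing_locations(tree)))
# @timeout_decorator.timeout(15)
# def test_example():
#     assert 1 + 1 == 2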

codeflash/optimization/function_optimizer.py

Lines changed: 49 additions & 104 deletions
@@ -6,7 +6,6 @@
 import queue
 import random
 import subprocess
-import time
 import uuid
 from collections import defaultdict
 from pathlib import Path
@@ -1641,57 +1640,34 @@ def establish_original_code_baseline(
                 f"Test coverage is {coverage_results.coverage}%, which is below the required threshold of {COVERAGE_THRESHOLD}%."
             )
 
-        if test_framework == "pytest":
-            with progress_bar("Running line profiler to identify performance bottlenecks..."):
-                line_profile_results = self.line_profiler_step(
-                    code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
-                )
-            console.rule()
-            with progress_bar("Running performance benchmarks..."):
-                if self.function_to_optimize.is_async:
-                    from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
+        with progress_bar("Running line profiler to identify performance bottlenecks..."):
+            line_profile_results = self.line_profiler_step(
+                code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
+            )
+        console.rule()
+        with progress_bar("Running performance benchmarks..."):
+            if self.function_to_optimize.is_async:
+                from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
 
-                    add_async_decorator_to_function(
-                        self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
-                    )
+                add_async_decorator_to_function(
+                    self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
+                )
 
-                try:
-                    benchmarking_results, _ = self.run_and_parse_tests(
-                        testing_type=TestingMode.PERFORMANCE,
-                        test_env=test_env,
-                        test_files=self.test_files,
-                        optimization_iteration=0,
-                        testing_time=total_looping_time,
-                        enable_coverage=False,
-                        code_context=code_context,
-                    )
-                finally:
-                    if self.function_to_optimize.is_async:
-                        self.write_code_and_helpers(
-                            self.function_to_optimize_source_code,
-                            original_helper_code,
-                            self.function_to_optimize.file_path,
-                        )
-        else:
-            benchmarking_results = TestResults()
-            start_time: float = time.time()
-            for i in range(100):
-                if i >= 5 and time.time() - start_time >= total_looping_time * 1.5:
-                    # * 1.5 to give unittest a bit more time to run
-                    break
-                test_env["CODEFLASH_LOOP_INDEX"] = str(i + 1)
-                with progress_bar("Running performance benchmarks..."):
-                    unittest_loop_results, _ = self.run_and_parse_tests(
-                        testing_type=TestingMode.PERFORMANCE,
-                        test_env=test_env,
-                        test_files=self.test_files,
-                        optimization_iteration=0,
-                        testing_time=total_looping_time,
-                        enable_coverage=False,
-                        code_context=code_context,
-                        unittest_loop_index=i + 1,
+            try:
+                benchmarking_results, _ = self.run_and_parse_tests(
+                    testing_type=TestingMode.PERFORMANCE,
+                    test_env=test_env,
+                    test_files=self.test_files,
+                    optimization_iteration=0,
+                    testing_time=total_looping_time,
+                    enable_coverage=False,
+                    code_context=code_context,
+                )
+            finally:
+                if self.function_to_optimize.is_async:
+                    self.write_code_and_helpers(
+                        self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
                     )
-                    benchmarking_results.merge(unittest_loop_results)
 
         console.print(
             TestResults.report_to_tree(
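The surviving branch keeps an instrument, run, restore shape: async functions get a decorator written into their source, and a finally block writes the original code back even if benchmarking raises. A generic hedged sketch of that pattern (the helper names here are hypothetical, not the Codeflash API):

from pathlib import Path

def benchmark_with_instrumentation(path: Path, instrument, run_benchmarks):
    """Instrument a source file, run benchmarks, and always restore the file."""
    original_source = path.read_text()
    instrument(path)  # e.g. add a profiling decorator at the definition site
    try:
        return run_benchmarks()
    finally:
        path.write_text(original_source)  # restore even when benchmarking fails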
@@ -1760,8 +1736,6 @@ def run_optimized_candidate(
         original_helper_code: dict[Path, str],
         file_path_to_helper_classes: dict[Path, set[str]],
     ) -> Result[OptimizedCandidateResult, str]:
-        assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}  # noqa: RUF018
-
         with progress_bar("Testing optimization candidate"):
             test_env = self.get_test_env(
                 codeflash_loop_index=0,
@@ -1818,59 +1792,34 @@
 
         logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
 
-        if test_framework == "pytest":
-            # For async functions, instrument at definition site for performance benchmarking
-            if self.function_to_optimize.is_async:
-                from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
-
-                add_async_decorator_to_function(
-                    self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
-                )
+        # For async functions, instrument at definition site for performance benchmarking
+        if self.function_to_optimize.is_async:
+            from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
 
-            try:
-                candidate_benchmarking_results, _ = self.run_and_parse_tests(
-                    testing_type=TestingMode.PERFORMANCE,
-                    test_env=test_env,
-                    test_files=self.test_files,
-                    optimization_iteration=optimization_candidate_index,
-                    testing_time=total_looping_time,
-                    enable_coverage=False,
-                )
-            finally:
-                # Restore original source if we instrumented it
-                if self.function_to_optimize.is_async:
-                    self.write_code_and_helpers(
-                        candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
-                    )
-            loop_count = (
-                max(all_loop_indices)
-                if (
-                    all_loop_indices := {
-                        result.loop_index for result in candidate_benchmarking_results.test_results
-                    }
-                )
-                else 0
+            add_async_decorator_to_function(
+                self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
             )
 
-        else:
-            candidate_benchmarking_results = TestResults()
-            start_time: float = time.time()
-            loop_count = 0
-            for i in range(100):
-                if i >= 5 and time.time() - start_time >= TOTAL_LOOPING_TIME_EFFECTIVE * 1.5:
-                    # * 1.5 to give unittest a bit more time to run
-                    break
-                test_env["CODEFLASH_LOOP_INDEX"] = str(i + 1)
-                unittest_loop_results, _cov = self.run_and_parse_tests(
-                    testing_type=TestingMode.PERFORMANCE,
-                    test_env=test_env,
-                    test_files=self.test_files,
-                    optimization_iteration=optimization_candidate_index,
-                    testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
-                    unittest_loop_index=i + 1,
+        try:
+            candidate_benchmarking_results, _ = self.run_and_parse_tests(
+                testing_type=TestingMode.PERFORMANCE,
+                test_env=test_env,
+                test_files=self.test_files,
+                optimization_iteration=optimization_candidate_index,
+                testing_time=total_looping_time,
+                enable_coverage=False,
+            )
+        finally:
+            # Restore original source if we instrumented it
+            if self.function_to_optimize.is_async:
+                self.write_code_and_helpers(
+                    candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
                 )
-                loop_count = i + 1
-                candidate_benchmarking_results.merge(unittest_loop_results)
+        loop_count = (
+            max(all_loop_indices)
+            if (all_loop_indices := {result.loop_index for result in candidate_benchmarking_results.test_results})
+            else 0
+        )
 
         if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0:
             logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
@@ -1920,7 +1869,6 @@ def run_and_parse_tests(
         pytest_min_loops: int = 5,
         pytest_max_loops: int = 250,
         code_context: CodeOptimizationContext | None = None,
-        unittest_loop_index: int | None = None,
         line_profiler_output_file: Path | None = None,
     ) -> tuple[TestResults | dict, CoverageData | None]:
         coverage_database_file = None
@@ -1933,7 +1881,6 @@
                 cwd=self.project_root,
                 test_env=test_env,
                 pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
-                verbose=True,
                 enable_coverage=enable_coverage,
             )
         elif testing_type == TestingMode.LINE_PROFILE:
@@ -1947,7 +1894,6 @@
                 pytest_min_loops=1,
                 pytest_max_loops=1,
                 test_framework=self.test_cfg.test_framework,
-                line_profiler_output_file=line_profiler_output_file,
             )
         elif testing_type == TestingMode.PERFORMANCE:
             result_file_path, run_result = run_benchmarking_tests(
@@ -1996,7 +1942,6 @@
                 test_config=self.test_cfg,
                 optimization_iteration=optimization_iteration,
                 run_result=run_result,
-                unittest_loop_index=unittest_loop_index,
                 function_name=self.function_to_optimize.function_name,
                 source_file=self.function_to_optimize.file_path,
                 code_context=code_context,