Skip to content

Commit acff586

Browse files
committed
first pass
1 parent 0e5916c commit acff586

File tree

6 files changed

+19
-39
lines changed

6 files changed

+19
-39
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def parse_args() -> Namespace:
7575
parser.add_argument(
7676
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
7777
)
78-
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
7978
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
8079
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
8180
parser.add_argument(
@@ -172,7 +171,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
172171
"module_root",
173172
"tests_root",
174173
"benchmarks_root",
175-
"test_framework",
176174
"ignore_paths",
177175
"pytest_cmd",
178176
"formatter_cmds",

codeflash/cli_cmds/cmd_init.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,26 +1233,8 @@ def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
12331233
arr[j + 1] = temp
12341234
return arr
12351235
"""
1236-
if args.test_framework == "unittest":
1237-
bubble_sort_test_content = f"""import unittest
1238-
from {os.path.basename(args.module_root)}.bubble_sort import sorter # Keep usage of os.path.basename to avoid pathlib potential incompatibility https://github.com/codeflash-ai/codeflash/pull/1066#discussion_r1801628022
1239-
1240-
class TestBubbleSort(unittest.TestCase):
1241-
def test_sort(self):
1242-
input = [5, 4, 3, 2, 1, 0]
1243-
output = sorter(input)
1244-
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
1245-
1246-
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
1247-
output = sorter(input)
1248-
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
1249-
1250-
input = list(reversed(range(100)))
1251-
output = sorter(input)
1252-
self.assertEqual(output, list(range(100)))
1253-
""" # noqa: PTH119
1254-
elif args.test_framework == "pytest":
1255-
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
1236+
# Always use pytest for tests
1237+
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
12561238
12571239
def test_sort():
12581240
input = [5, 4, 3, 2, 1, 0]

codeflash/optimization/function_optimizer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,7 +1053,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10531053
call_positions=[test.position for test in tests_in_file_list],
10541054
function_to_optimize=self.function_to_optimize,
10551055
tests_project_root=self.test_cfg.tests_project_rootdir,
1056-
test_framework=self.args.test_framework,
1056+
test_framework="pytest",
10571057
)
10581058
if not success:
10591059
continue
@@ -1063,7 +1063,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10631063
call_positions=[test.position for test in tests_in_file_list],
10641064
function_to_optimize=self.function_to_optimize,
10651065
tests_project_root=self.test_cfg.tests_project_rootdir,
1066-
test_framework=self.args.test_framework,
1066+
test_framework="pytest",
10671067
)
10681068
if not success:
10691069
continue
@@ -1271,7 +1271,7 @@ def setup_and_establish_baseline(
12711271

12721272
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
12731273
if isinstance(original_code_baseline, OriginalCodeBaseline) and (
1274-
not coverage_critic(original_code_baseline.coverage_results, self.args.test_framework)
1274+
not coverage_critic(original_code_baseline.coverage_results, "pytest")
12751275
or not quantity_of_tests_critic(original_code_baseline)
12761276
):
12771277
if self.args.override_fixtures:
@@ -1593,7 +1593,7 @@ def establish_original_code_baseline(
15931593
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
15941594
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
15951595
# For the original function - run the tests and get the runtime, plus coverage
1596-
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
1596+
test_framework = "pytest" # Always use pytest for all tests
15971597
success = True
15981598

15991599
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
@@ -1618,7 +1618,7 @@ def establish_original_code_baseline(
16181618
test_files=self.test_files,
16191619
optimization_iteration=0,
16201620
testing_time=total_looping_time,
1621-
enable_coverage=test_framework == "pytest",
1621+
enable_coverage=True, # Always enable coverage with pytest
16221622
code_context=code_context,
16231623
)
16241624
finally:
@@ -1632,7 +1632,7 @@ def establish_original_code_baseline(
16321632
)
16331633
console.rule()
16341634
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
1635-
if not coverage_critic(coverage_results, self.args.test_framework):
1635+
if not coverage_critic(coverage_results, "pytest"):
16361636
did_pass_all_tests = all(result.did_pass for result in behavioral_results)
16371637
if not did_pass_all_tests:
16381638
return Failure("Tests failed to pass for the original code.")

codeflash/optimization/optimizer.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def __init__(self, args: Namespace) -> None:
4444
tests_root=args.tests_root,
4545
tests_project_rootdir=args.test_project_root,
4646
project_root_path=args.project_root,
47-
test_framework=args.test_framework,
48-
pytest_cmd=args.pytest_cmd,
47+
pytest_cmd=args.pytest_cmd if hasattr(args, "pytest_cmd") else "pytest",
4948
benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None,
5049
)
5150

@@ -274,10 +273,9 @@ def run(self) -> None:
274273
file_to_funcs_to_optimize, num_optimizable_functions
275274
)
276275
optimizations_found: int = 0
277-
if self.args.test_framework == "pytest":
278-
self.test_cfg.concolic_test_root_dir = Path(
279-
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
280-
)
276+
self.test_cfg.concolic_test_root_dir = Path(
277+
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
278+
)
281279
try:
282280
ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
283281
if num_optimizable_functions == 0:

codeflash/verification/concolic_testing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ def generate_concolic_tests(
8181
tests_root=concolic_test_suite_dir,
8282
tests_project_rootdir=test_cfg.concolic_test_root_dir,
8383
project_root_path=args.project_root,
84-
test_framework=args.test_framework,
85-
pytest_cmd=args.pytest_cmd,
84+
pytest_cmd=args.pytest_cmd if hasattr(args, "pytest_cmd") else "pytest",
8685
)
8786
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg)
8887
logger.info(

codeflash/verification/verification_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,14 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
6969
class TestConfig:
7070
tests_root: Path
7171
project_root_path: Path
72-
test_framework: str
7372
tests_project_rootdir: Path
74-
# tests_project_rootdir corresponds to pytest rootdir,
75-
# or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path)
73+
# tests_project_rootdir corresponds to pytest rootdir
7674
concolic_test_root_dir: Optional[Path] = None
7775
pytest_cmd: str = "pytest"
7876
benchmark_tests_root: Optional[Path] = None
7977
use_cache: bool = True
78+
79+
@property
80+
def test_framework(self) -> str:
81+
"""Always returns 'pytest' as we use pytest for all tests."""
82+
return "pytest"

0 commit comments

Comments
 (0)