Skip to content

Commit 2e34d83

Browse files
authored
remove test_framework from pyproject.toml (#955)
* follow up
* remove requirement
* Delete uv.lock
* refresh uv-lock
* first pass
* cleanup test_framework here
* cleanup
* code_review
* cleanup tests
* fix for E2E
* fix tests dir missing
* one more cleanup
* cancel-in-progress
* Revert "cancel-in-progress" This reverts commit f4bb907.
* not needed here
* lower threshold and cleanup comments
* debug
* temp
* debug Revert "debug" This reverts commit fc36551. fix(discover): Fix pytest discovery for futurehouse structure Revert "fix(discover): Fix pytest discovery for futurehouse structure" This reverts commit 40c48882b7413f5876af0e2e08d8f17a65bab091. Reapply "debug" This reverts commit c8297e5. Revert "not needed here" This reverts commit dd2c5cd. Revert "lower threshold and cleanup comments" This reverts commit 0e2f57e. Reapply "lower threshold and cleanup comments" This reverts commit e3b24f4a2967551eca8a19f96bf6647b23acdbbc. Reapply "not needed here" This reverts commit aec32103c931ff6d57dfa0d012113c2cec5d37a7. Revert "Reapply "debug"" This reverts commit 77ab9f34f858a17fb29764c544769a0eb72ce7f0. Reapply "fix(discover): Fix pytest discovery for futurehouse structure" This reverts commit 506b94ab4fe17a7c8e0d458253812758cced3f22. feat(futurehouse): Make futurehouse structure pytest compatible
* Revert "debug" This reverts commit 271c5a3.
* Revert "temp" This reverts commit b363acd.
* Revert "debug" This reverts commit ac29b6b.
* just for now
1 parent b7b82ee commit 2e34d83

29 files changed

+277
-430
lines changed

.github/workflows/e2e-bubblesort-unittest.yaml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,6 @@ jobs:
6565
- name: Install dependencies (CLI)
6666
run: |
6767
uv sync
68-
uv add timeout_decorator
6968
7069
- name: Run Codeflash to optimize code
7170
id: optimize_code

.github/workflows/e2e-futurehouse-structure.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ jobs:
2424
COLUMNS: 110
2525
MAX_RETRIES: 3
2626
RETRY_DELAY: 5
27-
EXPECTED_IMPROVEMENT_PCT: 10
27+
EXPECTED_IMPROVEMENT_PCT: 5
2828
CODEFLASH_END_TO_END: 1
2929
steps:
3030
- name: 🛎️ Checkout

codeflash/cli_cmds/cli.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,6 @@ def parse_args() -> Namespace:
7575
parser.add_argument(
7676
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
7777
)
78-
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
7978
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
8079
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
8180
parser.add_argument(
@@ -172,7 +171,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
172171
"module_root",
173172
"tests_root",
174173
"benchmarks_root",
175-
"test_framework",
176174
"ignore_paths",
177175
"pytest_cmd",
178176
"formatter_cmds",

codeflash/cli_cmds/cmd_init.py

Lines changed: 2 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import ast
43
import os
54
import re
65
import subprocess
@@ -59,7 +58,6 @@ class CLISetupInfo:
5958
module_root: str
6059
tests_root: str
6160
benchmarks_root: Union[str, None]
62-
test_framework: str
6361
ignore_paths: list[str]
6462
formatter: Union[str, list[str]]
6563
git_remote: str
@@ -70,7 +68,6 @@ class CLISetupInfo:
7068
class VsCodeSetupInfo:
7169
module_root: str
7270
tests_root: str
73-
test_framework: str
7471
formatter: Union[str, list[str]]
7572

7673

@@ -256,7 +253,6 @@ def __init__(self) -> None:
256253
class CommonSections(Enum):
257254
module_root = "module_root"
258255
tests_root = "tests_root"
259-
test_framework = "test_framework"
260256
formatter_cmds = "formatter_cmds"
261257

262258
def get_toml_key(self) -> str:
@@ -292,9 +288,6 @@ def get_suggestions(section: str) -> tuple[list[str], Optional[str]]:
292288
if section == CommonSections.tests_root:
293289
default = "tests" if "tests" in valid_subdirs else None
294290
return valid_subdirs, default
295-
if section == CommonSections.test_framework:
296-
auto_detected = detect_test_framework_from_config_files(Path.cwd())
297-
return ["pytest", "unittest"], auto_detected
298291
if section == CommonSections.formatter_cmds:
299292
return ["disabled", "ruff", "black"], "disabled"
300293
msg = f"Unknown section: {section}"
@@ -455,43 +448,6 @@ def collect_setup_info() -> CLISetupInfo:
455448

456449
ph("cli-tests-root-provided")
457450

458-
test_framework_choices, detected_framework = get_suggestions(CommonSections.test_framework)
459-
autodetected_test_framework = detected_framework or detect_test_framework_from_test_files(tests_root)
460-
461-
framework_message = "⚗️ Let's configure your test framework.\n\n"
462-
if autodetected_test_framework:
463-
framework_message += f"I detected that you're using {autodetected_test_framework}. "
464-
framework_message += "Please confirm or select a different one."
465-
466-
framework_panel = Panel(Text(framework_message, style="blue"), title="⚗️ Test Framework", border_style="bright_blue")
467-
console.print(framework_panel)
468-
console.print()
469-
470-
framework_choices = []
471-
# add icons based on the detected framework
472-
for choice in test_framework_choices:
473-
if choice == "pytest":
474-
framework_choices.append(("🧪 pytest", "pytest"))
475-
elif choice == "unittest":
476-
framework_choices.append(("🐍 unittest", "unittest"))
477-
478-
framework_questions = [
479-
inquirer.List(
480-
"test_framework",
481-
message="Which test framework do you use?",
482-
choices=framework_choices,
483-
default=autodetected_test_framework or "pytest",
484-
carousel=True,
485-
)
486-
]
487-
488-
framework_answers = inquirer.prompt(framework_questions, theme=CodeflashTheme())
489-
if not framework_answers:
490-
apologize_and_exit()
491-
test_framework = framework_answers["test_framework"]
492-
493-
ph("cli-test-framework-provided", {"test_framework": test_framework})
494-
495451
benchmarks_root = None
496452

497453
# TODO: Implement other benchmark framework options
@@ -588,60 +544,13 @@ def collect_setup_info() -> CLISetupInfo:
588544
module_root=str(module_root),
589545
tests_root=str(tests_root),
590546
benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
591-
test_framework=cast("str", test_framework),
592547
ignore_paths=ignore_paths,
593548
formatter=cast("str", formatter),
594549
git_remote=str(git_remote),
595550
enable_telemetry=enable_telemetry,
596551
)
597552

598553

599-
def detect_test_framework_from_config_files(curdir: Path) -> Optional[str]:
600-
test_framework = None
601-
pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"]
602-
pytest_config_patterns = {
603-
"pytest.ini": "[pytest]",
604-
"pyproject.toml": "[tool.pytest.ini_options]",
605-
"tox.ini": "[pytest]",
606-
"setup.cfg": "[tool:pytest]",
607-
}
608-
for pytest_file in pytest_files:
609-
file_path = curdir / pytest_file
610-
if file_path.exists():
611-
with file_path.open(encoding="utf8") as file:
612-
contents = file.read()
613-
if pytest_config_patterns[pytest_file] in contents:
614-
test_framework = "pytest"
615-
break
616-
test_framework = "pytest"
617-
return test_framework
618-
619-
620-
def detect_test_framework_from_test_files(tests_root: Path) -> Optional[str]:
621-
test_framework = None
622-
# Check if any python files contain a class that inherits from unittest.TestCase
623-
for filename in tests_root.iterdir():
624-
if filename.suffix == ".py":
625-
with filename.open(encoding="utf8") as file:
626-
contents = file.read()
627-
try:
628-
node = ast.parse(contents)
629-
except SyntaxError:
630-
continue
631-
if any(
632-
isinstance(item, ast.ClassDef)
633-
and any(
634-
(isinstance(base, ast.Attribute) and base.attr == "TestCase")
635-
or (isinstance(base, ast.Name) and base.id == "TestCase")
636-
for base in item.bases
637-
)
638-
for item in node.body
639-
):
640-
test_framework = "unittest"
641-
break
642-
return test_framework
643-
644-
645554
def check_for_toml_or_setup_file() -> str | None:
646555
click.echo()
647556
click.echo("Checking for pyproject.toml or setup.py…\r", nl=False)
@@ -1060,7 +969,6 @@ def configure_pyproject_toml(
1060969
else:
1061970
codeflash_section["module-root"] = setup_info.module_root
1062971
codeflash_section["tests-root"] = setup_info.tests_root
1063-
codeflash_section["test-framework"] = setup_info.test_framework
1064972
codeflash_section["ignore-paths"] = setup_info.ignore_paths
1065973
if not setup_info.enable_telemetry:
1066974
codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry
@@ -1325,26 +1233,8 @@ def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
13251233
arr[j + 1] = temp
13261234
return arr
13271235
"""
1328-
if args.test_framework == "unittest":
1329-
bubble_sort_test_content = f"""import unittest
1330-
from {os.path.basename(args.module_root)}.bubble_sort import sorter # Keep usage of os.path.basename to avoid pathlib potential incompatibility https://github.com/codeflash-ai/codeflash/pull/1066#discussion_r1801628022
1331-
1332-
class TestBubbleSort(unittest.TestCase):
1333-
def test_sort(self):
1334-
input = [5, 4, 3, 2, 1, 0]
1335-
output = sorter(input)
1336-
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
1337-
1338-
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
1339-
output = sorter(input)
1340-
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
1341-
1342-
input = list(reversed(range(100)))
1343-
output = sorter(input)
1344-
self.assertEqual(output, list(range(100)))
1345-
""" # noqa: PTH119
1346-
elif args.test_framework == "pytest":
1347-
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
1236+
# Always use pytest for tests
1237+
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
13481238
13491239
def test_sort():
13501240
input = [5, 4, 3, 2, 1, 0]

codeflash/code_utils/config_parser.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -149,10 +149,6 @@ def parse_config_file(
149149
else:
150150
config[key] = []
151151

152-
if config.get("test-framework"):
153-
assert config["test-framework"] in {"pytest", "unittest"}, (
154-
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
155-
)
156152
# see if this is happening during GitHub actions setup
157153
if config.get("formatter-cmds") and len(config.get("formatter-cmds")) > 0 and not override_formatter_check:
158154
assert config.get("formatter-cmds")[0] != "your-formatter $file", (

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,6 @@ def __init__(
7070
self,
7171
function: FunctionToOptimize,
7272
module_path: str,
73-
test_framework: str,
7473
call_positions: list[CodePosition],
7574
mode: TestingMode = TestingMode.BEHAVIOR,
7675
) -> None:
@@ -79,7 +78,6 @@ def __init__(
7978
self.class_name = None
8079
self.only_function_name = function.function_name
8180
self.module_path = module_path
82-
self.test_framework = test_framework
8381
self.call_positions = call_positions
8482
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
8583
self.class_name = function.top_level_parent_name
@@ -475,7 +473,6 @@ def __init__(
475473
self,
476474
function: FunctionToOptimize,
477475
module_path: str,
478-
test_framework: str,
479476
call_positions: list[CodePosition],
480477
mode: TestingMode = TestingMode.BEHAVIOR,
481478
) -> None:
@@ -484,7 +481,6 @@ def __init__(
484481
self.class_name = None
485482
self.only_function_name = function.function_name
486483
self.module_path = module_path
487-
self.test_framework = test_framework
488484
self.call_positions = call_positions
489485
self.did_instrument = False
490486
# Track function call count per test function
@@ -639,7 +635,6 @@ def inject_async_profiling_into_existing_test(
639635
call_positions: list[CodePosition],
640636
function_to_optimize: FunctionToOptimize,
641637
tests_project_root: Path,
642-
test_framework: str,
643638
mode: TestingMode = TestingMode.BEHAVIOR,
644639
) -> tuple[bool, str | None]:
645640
"""Inject profiling for async function calls by setting environment variables before each call."""
@@ -657,7 +652,7 @@ def inject_async_profiling_into_existing_test(
657652
import_visitor.visit(tree)
658653
func = import_visitor.imported_as
659654

660-
async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
655+
async_instrumenter = AsyncCallInstrumenter(func, test_module_path, call_positions, mode=mode)
661656
tree = async_instrumenter.visit(tree)
662657

663658
if not async_instrumenter.did_instrument:
@@ -675,12 +670,11 @@ def inject_profiling_into_existing_test(
675670
call_positions: list[CodePosition],
676671
function_to_optimize: FunctionToOptimize,
677672
tests_project_root: Path,
678-
test_framework: str,
679673
mode: TestingMode = TestingMode.BEHAVIOR,
680674
) -> tuple[bool, str | None]:
681675
if function_to_optimize.is_async:
682676
return inject_async_profiling_into_existing_test(
683-
test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode
677+
test_path, call_positions, function_to_optimize, tests_project_root, mode
684678
)
685679

686680
with test_path.open(encoding="utf8") as f:
@@ -696,7 +690,7 @@ def inject_profiling_into_existing_test(
696690
import_visitor.visit(tree)
697691
func = import_visitor.imported_as
698692

699-
tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode).visit(tree)
693+
tree = InjectPerfOnly(func, test_module_path, call_positions, mode=mode).visit(tree)
700694
new_imports = [
701695
ast.Import(names=[ast.alias(name="time")]),
702696
ast.Import(names=[ast.alias(name="gc")]),

codeflash/lsp/beta.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,6 @@ def write_config(params: WriteConfigParams) -> dict[str, any]:
187187
setup_info = VsCodeSetupInfo(
188188
module_root=getattr(cfg, "module_root", ""),
189189
tests_root=getattr(cfg, "tests_root", ""),
190-
test_framework=getattr(cfg, "test_framework", "pytest"),
191190
formatter=get_formatter_cmds(getattr(cfg, "formatter_cmds", "disabled")),
192191
)
193192

@@ -203,7 +202,6 @@ def write_config(params: WriteConfigParams) -> dict[str, any]:
203202
def get_config_suggestions(_params: any) -> dict[str, any]:
204203
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
205204
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
206-
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
207205
formatter_suggestions, default_formatter = get_suggestions(CommonSections.formatter_cmds)
208206
get_valid_subdirs.cache_clear()
209207

@@ -238,7 +236,6 @@ def get_config_suggestions(_params: any) -> dict[str, any]:
238236
return {
239237
"module_root": {"choices": module_root_suggestions, "default": default_module_root},
240238
"tests_root": {"choices": tests_root_suggestions, "default": default_tests_root},
241-
"test_framework": {"choices": test_framework_suggestions, "default": default_test_framework},
242239
"formatter_cmds": {"choices": formatter_suggestions, "default": default_formatter},
243240
}
244241

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1051,7 +1051,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10511051
call_positions=[test.position for test in tests_in_file_list],
10521052
function_to_optimize=self.function_to_optimize,
10531053
tests_project_root=self.test_cfg.tests_project_rootdir,
1054-
test_framework=self.args.test_framework,
10551054
)
10561055
if not success:
10571056
continue
@@ -1061,7 +1060,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10611060
call_positions=[test.position for test in tests_in_file_list],
10621061
function_to_optimize=self.function_to_optimize,
10631062
tests_project_root=self.test_cfg.tests_project_rootdir,
1064-
test_framework=self.args.test_framework,
10651063
)
10661064
if not success:
10671065
continue
@@ -1275,7 +1273,7 @@ def setup_and_establish_baseline(
12751273

12761274
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
12771275
if isinstance(original_code_baseline, OriginalCodeBaseline) and (
1278-
not coverage_critic(original_code_baseline.coverage_results, self.args.test_framework)
1276+
not coverage_critic(original_code_baseline.coverage_results)
12791277
or not quantity_of_tests_critic(original_code_baseline)
12801278
):
12811279
if self.args.override_fixtures:
@@ -1597,7 +1595,6 @@ def establish_original_code_baseline(
15971595
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
15981596
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
15991597
# For the original function - run the tests and get the runtime, plus coverage
1600-
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
16011598
success = True
16021599

16031600
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
@@ -1622,7 +1619,7 @@ def establish_original_code_baseline(
16221619
test_files=self.test_files,
16231620
optimization_iteration=0,
16241621
testing_time=total_looping_time,
1625-
enable_coverage=test_framework == "pytest",
1622+
enable_coverage=True,
16261623
code_context=code_context,
16271624
)
16281625
finally:
@@ -1636,7 +1633,7 @@ def establish_original_code_baseline(
16361633
)
16371634
console.rule()
16381635
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
1639-
if not coverage_critic(coverage_results, self.args.test_framework):
1636+
if not coverage_critic(coverage_results):
16401637
did_pass_all_tests = all(result.did_pass for result in behavioral_results)
16411638
if not did_pass_all_tests:
16421639
return Failure("Tests failed to pass for the original code.")

codeflash/optimization/optimizer.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,7 @@ def __init__(self, args: Namespace) -> None:
4545
tests_root=args.tests_root,
4646
tests_project_rootdir=args.test_project_root,
4747
project_root_path=args.project_root,
48-
test_framework=args.test_framework,
49-
pytest_cmd=args.pytest_cmd,
48+
pytest_cmd=args.pytest_cmd if hasattr(args, "pytest_cmd") and args.pytest_cmd else "pytest",
5049
benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None,
5150
)
5251

@@ -285,10 +284,9 @@ def run(self) -> None:
285284
file_to_funcs_to_optimize, num_optimizable_functions
286285
)
287286
optimizations_found: int = 0
288-
if self.args.test_framework == "pytest":
289-
self.test_cfg.concolic_test_root_dir = Path(
290-
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
291-
)
287+
self.test_cfg.concolic_test_root_dir = Path(
288+
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
289+
)
292290
try:
293291
ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
294292
if num_optimizable_functions == 0:

0 commit comments

Comments
 (0)