Skip to content

Commit 61ec0d0

Browse files
committed
Merge branch 'main' into ashraf/cf-894-vsc-init-no-way-to-create-tests-dir-when-no-test-dir-present
2 parents 57972ff + a879f11 commit 61ec0d0

32 files changed

+858
-818
lines changed

.github/workflows/e2e-bubblesort-unittest.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ jobs:
6565
- name: Install dependencies (CLI)
6666
run: |
6767
uv sync
68-
uv add timeout_decorator
6968
7069
- name: Run Codeflash to optimize code
7170
id: optimize_code

.github/workflows/e2e-futurehouse-structure.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
COLUMNS: 110
2525
MAX_RETRIES: 3
2626
RETRY_DELAY: 5
27-
EXPECTED_IMPROVEMENT_PCT: 10
27+
EXPECTED_IMPROVEMENT_PCT: 5
2828
CODEFLASH_END_TO_END: 1
2929
steps:
3030
- name: 🛎️ Checkout

codeflash/api/aiservice.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ def optimize_python_code( # noqa: D417
153153

154154
if response.status_code == 200:
155155
optimizations_json = response.json()["optimizations"]
156-
logger.info(f"!lsp|Generated {len(optimizations_json)} candidate optimizations.")
157156
console.rule()
158157
end_time = time.perf_counter()
159158
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")

codeflash/cli_cmds/cli.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def parse_args() -> Namespace:
7575
parser.add_argument(
7676
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
7777
)
78-
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
7978
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
8079
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
8180
parser.add_argument(
@@ -172,7 +171,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
172171
"module_root",
173172
"tests_root",
174173
"benchmarks_root",
175-
"test_framework",
176174
"ignore_paths",
177175
"pytest_cmd",
178176
"formatter_cmds",

codeflash/cli_cmds/cmd_init.py

Lines changed: 2 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import ast
43
import os
54
import re
65
import subprocess
@@ -60,7 +59,6 @@ class CLISetupInfo:
6059
module_root: str
6160
tests_root: str
6261
benchmarks_root: Union[str, None]
63-
test_framework: str
6462
ignore_paths: list[str]
6563
formatter: Union[str, list[str]]
6664
git_remote: str
@@ -71,7 +69,6 @@ class CLISetupInfo:
7169
class VsCodeSetupInfo:
7270
module_root: str
7371
tests_root: str
74-
test_framework: str
7572
formatter: Union[str, list[str]]
7673

7774

@@ -257,7 +254,6 @@ def __init__(self) -> None:
257254
class CommonSections(Enum):
258255
module_root = "module_root"
259256
tests_root = "tests_root"
260-
test_framework = "test_framework"
261257
formatter_cmds = "formatter_cmds"
262258

263259
def get_toml_key(self) -> str:
@@ -293,9 +289,6 @@ def get_suggestions(section: str) -> tuple[list[str], Optional[str]]:
293289
if section == CommonSections.tests_root:
294290
default = "tests" if "tests" in valid_subdirs else None
295291
return valid_subdirs, default
296-
if section == CommonSections.test_framework:
297-
auto_detected = detect_test_framework_from_config_files(Path.cwd())
298-
return ["pytest", "unittest"], auto_detected
299292
if section == CommonSections.formatter_cmds:
300293
return ["disabled", "ruff", "black"], "disabled"
301294
msg = f"Unknown section: {section}"
@@ -480,43 +473,6 @@ def collect_setup_info() -> CLISetupInfo:
480473

481474
ph("cli-tests-root-provided")
482475

483-
test_framework_choices, detected_framework = get_suggestions(CommonSections.test_framework)
484-
autodetected_test_framework = detected_framework or detect_test_framework_from_test_files(tests_root)
485-
486-
framework_message = "⚗️ Let's configure your test framework.\n\n"
487-
if autodetected_test_framework:
488-
framework_message += f"I detected that you're using {autodetected_test_framework}. "
489-
framework_message += "Please confirm or select a different one."
490-
491-
framework_panel = Panel(Text(framework_message, style="blue"), title="⚗️ Test Framework", border_style="bright_blue")
492-
console.print(framework_panel)
493-
console.print()
494-
495-
framework_choices = []
496-
# add icons based on the detected framework
497-
for choice in test_framework_choices:
498-
if choice == "pytest":
499-
framework_choices.append(("🧪 pytest", "pytest"))
500-
elif choice == "unittest":
501-
framework_choices.append(("🐍 unittest", "unittest"))
502-
503-
framework_questions = [
504-
inquirer.List(
505-
"test_framework",
506-
message="Which test framework do you use?",
507-
choices=framework_choices,
508-
default=autodetected_test_framework or "pytest",
509-
carousel=True,
510-
)
511-
]
512-
513-
framework_answers = inquirer.prompt(framework_questions, theme=CodeflashTheme())
514-
if not framework_answers:
515-
apologize_and_exit()
516-
test_framework = framework_answers["test_framework"]
517-
518-
ph("cli-test-framework-provided", {"test_framework": test_framework})
519-
520476
benchmarks_root = None
521477

522478
# TODO: Implement other benchmark framework options
@@ -613,60 +569,13 @@ def collect_setup_info() -> CLISetupInfo:
613569
module_root=str(module_root),
614570
tests_root=str(tests_root),
615571
benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
616-
test_framework=cast("str", test_framework),
617572
ignore_paths=ignore_paths,
618573
formatter=cast("str", formatter),
619574
git_remote=str(git_remote),
620575
enable_telemetry=enable_telemetry,
621576
)
622577

623578

624-
def detect_test_framework_from_config_files(curdir: Path) -> Optional[str]:
625-
test_framework = None
626-
pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"]
627-
pytest_config_patterns = {
628-
"pytest.ini": "[pytest]",
629-
"pyproject.toml": "[tool.pytest.ini_options]",
630-
"tox.ini": "[pytest]",
631-
"setup.cfg": "[tool:pytest]",
632-
}
633-
for pytest_file in pytest_files:
634-
file_path = curdir / pytest_file
635-
if file_path.exists():
636-
with file_path.open(encoding="utf8") as file:
637-
contents = file.read()
638-
if pytest_config_patterns[pytest_file] in contents:
639-
test_framework = "pytest"
640-
break
641-
test_framework = "pytest"
642-
return test_framework
643-
644-
645-
def detect_test_framework_from_test_files(tests_root: Path) -> Optional[str]:
646-
test_framework = None
647-
# Check if any python files contain a class that inherits from unittest.TestCase
648-
for filename in tests_root.iterdir():
649-
if filename.suffix == ".py":
650-
with filename.open(encoding="utf8") as file:
651-
contents = file.read()
652-
try:
653-
node = ast.parse(contents)
654-
except SyntaxError:
655-
continue
656-
if any(
657-
isinstance(item, ast.ClassDef)
658-
and any(
659-
(isinstance(base, ast.Attribute) and base.attr == "TestCase")
660-
or (isinstance(base, ast.Name) and base.id == "TestCase")
661-
for base in item.bases
662-
)
663-
for item in node.body
664-
):
665-
test_framework = "unittest"
666-
break
667-
return test_framework
668-
669-
670579
def check_for_toml_or_setup_file() -> str | None:
671580
click.echo()
672581
click.echo("Checking for pyproject.toml or setup.py…\r", nl=False)
@@ -1085,7 +994,6 @@ def configure_pyproject_toml(
1085994
else:
1086995
codeflash_section["module-root"] = setup_info.module_root
1087996
codeflash_section["tests-root"] = setup_info.tests_root
1088-
codeflash_section["test-framework"] = setup_info.test_framework
1089997
codeflash_section["ignore-paths"] = setup_info.ignore_paths
1090998
if not setup_info.enable_telemetry:
1091999
codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry
@@ -1350,26 +1258,8 @@ def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
13501258
arr[j + 1] = temp
13511259
return arr
13521260
"""
1353-
if args.test_framework == "unittest":
1354-
bubble_sort_test_content = f"""import unittest
1355-
from {os.path.basename(args.module_root)}.bubble_sort import sorter # Keep usage of os.path.basename to avoid pathlib potential incompatibility https://github.com/codeflash-ai/codeflash/pull/1066#discussion_r1801628022
1356-
1357-
class TestBubbleSort(unittest.TestCase):
1358-
def test_sort(self):
1359-
input = [5, 4, 3, 2, 1, 0]
1360-
output = sorter(input)
1361-
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
1362-
1363-
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
1364-
output = sorter(input)
1365-
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
1366-
1367-
input = list(reversed(range(100)))
1368-
output = sorter(input)
1369-
self.assertEqual(output, list(range(100)))
1370-
""" # noqa: PTH119
1371-
elif args.test_framework == "pytest":
1372-
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
1261+
# Always use pytest for tests
1262+
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
13731263
13741264
def test_sort():
13751265
input = [5, 4, 3, 2, 1, 0]

codeflash/code_utils/config_parser.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,6 @@ def parse_config_file(
149149
else:
150150
config[key] = []
151151

152-
if config.get("test-framework"):
153-
assert config["test-framework"] in {"pytest", "unittest"}, (
154-
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
155-
)
156152
# see if this is happening during GitHub actions setup
157153
if config.get("formatter-cmds") and len(config.get("formatter-cmds")) > 0 and not override_formatter_check:
158154
assert config.get("formatter-cmds")[0] != "your-formatter $file", (

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def __init__(
7070
self,
7171
function: FunctionToOptimize,
7272
module_path: str,
73-
test_framework: str,
7473
call_positions: list[CodePosition],
7574
mode: TestingMode = TestingMode.BEHAVIOR,
7675
) -> None:
@@ -79,7 +78,6 @@ def __init__(
7978
self.class_name = None
8079
self.only_function_name = function.function_name
8180
self.module_path = module_path
82-
self.test_framework = test_framework
8381
self.call_positions = call_positions
8482
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
8583
self.class_name = function.top_level_parent_name
@@ -475,7 +473,6 @@ def __init__(
475473
self,
476474
function: FunctionToOptimize,
477475
module_path: str,
478-
test_framework: str,
479476
call_positions: list[CodePosition],
480477
mode: TestingMode = TestingMode.BEHAVIOR,
481478
) -> None:
@@ -484,7 +481,6 @@ def __init__(
484481
self.class_name = None
485482
self.only_function_name = function.function_name
486483
self.module_path = module_path
487-
self.test_framework = test_framework
488484
self.call_positions = call_positions
489485
self.did_instrument = False
490486
# Track function call count per test function
@@ -639,7 +635,6 @@ def inject_async_profiling_into_existing_test(
639635
call_positions: list[CodePosition],
640636
function_to_optimize: FunctionToOptimize,
641637
tests_project_root: Path,
642-
test_framework: str,
643638
mode: TestingMode = TestingMode.BEHAVIOR,
644639
) -> tuple[bool, str | None]:
645640
"""Inject profiling for async function calls by setting environment variables before each call."""
@@ -657,7 +652,7 @@ def inject_async_profiling_into_existing_test(
657652
import_visitor.visit(tree)
658653
func = import_visitor.imported_as
659654

660-
async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
655+
async_instrumenter = AsyncCallInstrumenter(func, test_module_path, call_positions, mode=mode)
661656
tree = async_instrumenter.visit(tree)
662657

663658
if not async_instrumenter.did_instrument:
@@ -675,12 +670,11 @@ def inject_profiling_into_existing_test(
675670
call_positions: list[CodePosition],
676671
function_to_optimize: FunctionToOptimize,
677672
tests_project_root: Path,
678-
test_framework: str,
679673
mode: TestingMode = TestingMode.BEHAVIOR,
680674
) -> tuple[bool, str | None]:
681675
if function_to_optimize.is_async:
682676
return inject_async_profiling_into_existing_test(
683-
test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode
677+
test_path, call_positions, function_to_optimize, tests_project_root, mode
684678
)
685679

686680
with test_path.open(encoding="utf8") as f:
@@ -696,7 +690,7 @@ def inject_profiling_into_existing_test(
696690
import_visitor.visit(tree)
697691
func = import_visitor.imported_as
698692

699-
tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode).visit(tree)
693+
tree = InjectPerfOnly(func, test_module_path, call_positions, mode=mode).visit(tree)
700694
new_imports = [
701695
ast.Import(names=[ast.alias(name="time")]),
702696
ast.Import(names=[ast.alias(name="gc")]),

codeflash/lsp/beta.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ def get_config_value(key: str, default: str = "") -> str:
225225
setup_info = VsCodeSetupInfo(
226226
module_root=module_root,
227227
tests_root=tests_root,
228-
test_framework=get_config_value("test_framework", "pytest"),
229228
formatter=get_formatter_cmds(get_config_value("formatter_cmds", "disabled")),
230229
)
231230

@@ -241,7 +240,6 @@ def get_config_value(key: str, default: str = "") -> str:
241240
def get_config_suggestions(_params: any) -> dict[str, any]:
242241
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
243242
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
244-
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
245243
formatter_suggestions, default_formatter = get_suggestions(CommonSections.formatter_cmds)
246244
get_valid_subdirs.cache_clear()
247245

@@ -276,7 +274,6 @@ def get_config_suggestions(_params: any) -> dict[str, any]:
276274
return {
277275
"module_root": {"choices": module_root_suggestions, "default": default_module_root},
278276
"tests_root": {"choices": tests_root_suggestions, "default": default_tests_root},
279-
"test_framework": {"choices": test_framework_suggestions, "default": default_test_framework},
280277
"formatter_cmds": {"choices": formatter_suggestions, "default": default_formatter},
281278
}
282279

codeflash/lsp/features/perform_optimization.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3+
import concurrent.futures
34
import contextlib
5+
import contextvars
46
import os
57
from typing import TYPE_CHECKING
68

7-
from codeflash.cli_cmds.console import code_print
9+
from codeflash.cli_cmds.console import code_print, logger
810
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
911
from codeflash.either import is_successful
1012

@@ -44,24 +46,48 @@ def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: thr
4446
function_optimizer.function_to_tests = function_to_tests
4547

4648
abort_if_cancelled(cancel_event)
47-
test_setup_result = function_optimizer.generate_and_instrument_tests(
48-
code_context, should_run_experiment=should_run_experiment
49-
)
49+
50+
ctx_tests = contextvars.copy_context()
51+
ctx_opts = contextvars.copy_context()
52+
53+
def run_generate_tests(): # noqa: ANN202
54+
return function_optimizer.generate_and_instrument_tests(code_context)
55+
56+
def run_generate_optimizations(): # noqa: ANN202
57+
return function_optimizer.generate_optimizations(
58+
read_writable_code=code_context.read_writable_code,
59+
read_only_context_code=code_context.read_only_context_code,
60+
run_experiment=should_run_experiment,
61+
)
62+
63+
future_tests = function_optimizer.executor.submit(ctx_tests.run, run_generate_tests)
64+
future_optimizations = function_optimizer.executor.submit(ctx_opts.run, run_generate_optimizations)
65+
66+
logger.info(f"loading|Generating new tests and optimizations for function '{params.functionName}'...")
67+
concurrent.futures.wait([future_tests, future_optimizations])
68+
69+
test_setup_result = future_tests.result()
70+
optimization_result = future_optimizations.result()
71+
5072
abort_if_cancelled(cancel_event)
5173
if not is_successful(test_setup_result):
5274
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
75+
if not is_successful(optimization_result):
76+
return {"functionName": params.functionName, "status": "error", "message": optimization_result.failure()}
77+
5378
(
5479
generated_tests,
5580
function_to_concolic_tests,
5681
concolic_test_str,
57-
optimizations_set,
5882
generated_test_paths,
5983
generated_perf_test_paths,
6084
instrumented_unittests_created_for_function,
6185
original_conftest_content,
62-
function_references,
6386
) = test_setup_result.unwrap()
6487

88+
optimizations_set, function_references = optimization_result.unwrap()
89+
90+
logger.info(f"Generated '{len(optimizations_set.control)}' candidate optimizations.")
6591
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
6692
code_context=code_context,
6793
original_helper_code=original_helper_code,

0 commit comments

Comments
 (0)