Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/e2e-bubblesort-unittest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ jobs:
- name: Install dependencies (CLI)
run: |
uv sync
uv add timeout_decorator

- name: Run Codeflash to optimize code
id: optimize_code
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/e2e-futurehouse-structure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
COLUMNS: 110
MAX_RETRIES: 3
RETRY_DELAY: 5
EXPECTED_IMPROVEMENT_PCT: 10
EXPECTED_IMPROVEMENT_PCT: 5
CODEFLASH_END_TO_END: 1
steps:
- name: 🛎️ Checkout
Expand Down
2 changes: 0 additions & 2 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def parse_args() -> Namespace:
parser.add_argument(
"--tests-root", type=str, help="Path to the test directory of the project, where all the tests are located."
)
parser.add_argument("--test-framework", choices=["pytest", "unittest"], default="pytest")
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
parser.add_argument(
Expand Down Expand Up @@ -172,7 +171,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
"module_root",
"tests_root",
"benchmarks_root",
"test_framework",
"ignore_paths",
"pytest_cmd",
"formatter_cmds",
Expand Down
114 changes: 2 additions & 112 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import ast
import os
import re
import subprocess
Expand Down Expand Up @@ -59,7 +58,6 @@ class CLISetupInfo:
module_root: str
tests_root: str
benchmarks_root: Union[str, None]
test_framework: str
ignore_paths: list[str]
formatter: Union[str, list[str]]
git_remote: str
Expand All @@ -70,7 +68,6 @@ class CLISetupInfo:
class VsCodeSetupInfo:
module_root: str
tests_root: str
test_framework: str
formatter: Union[str, list[str]]


Expand Down Expand Up @@ -256,7 +253,6 @@ def __init__(self) -> None:
class CommonSections(Enum):
module_root = "module_root"
tests_root = "tests_root"
test_framework = "test_framework"
formatter_cmds = "formatter_cmds"

def get_toml_key(self) -> str:
Expand Down Expand Up @@ -292,9 +288,6 @@ def get_suggestions(section: str) -> tuple[list[str], Optional[str]]:
if section == CommonSections.tests_root:
default = "tests" if "tests" in valid_subdirs else None
return valid_subdirs, default
if section == CommonSections.test_framework:
auto_detected = detect_test_framework_from_config_files(Path.cwd())
return ["pytest", "unittest"], auto_detected
if section == CommonSections.formatter_cmds:
return ["disabled", "ruff", "black"], "disabled"
msg = f"Unknown section: {section}"
Expand Down Expand Up @@ -455,43 +448,6 @@ def collect_setup_info() -> CLISetupInfo:

ph("cli-tests-root-provided")

test_framework_choices, detected_framework = get_suggestions(CommonSections.test_framework)
autodetected_test_framework = detected_framework or detect_test_framework_from_test_files(tests_root)

framework_message = "⚗️ Let's configure your test framework.\n\n"
if autodetected_test_framework:
framework_message += f"I detected that you're using {autodetected_test_framework}. "
framework_message += "Please confirm or select a different one."

framework_panel = Panel(Text(framework_message, style="blue"), title="⚗️ Test Framework", border_style="bright_blue")
console.print(framework_panel)
console.print()

framework_choices = []
# add icons based on the detected framework
for choice in test_framework_choices:
if choice == "pytest":
framework_choices.append(("🧪 pytest", "pytest"))
elif choice == "unittest":
framework_choices.append(("🐍 unittest", "unittest"))

framework_questions = [
inquirer.List(
"test_framework",
message="Which test framework do you use?",
choices=framework_choices,
default=autodetected_test_framework or "pytest",
carousel=True,
)
]

framework_answers = inquirer.prompt(framework_questions, theme=CodeflashTheme())
if not framework_answers:
apologize_and_exit()
test_framework = framework_answers["test_framework"]

ph("cli-test-framework-provided", {"test_framework": test_framework})

benchmarks_root = None

# TODO: Implement other benchmark framework options
Expand Down Expand Up @@ -588,60 +544,13 @@ def collect_setup_info() -> CLISetupInfo:
module_root=str(module_root),
tests_root=str(tests_root),
benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
test_framework=cast("str", test_framework),
ignore_paths=ignore_paths,
formatter=cast("str", formatter),
git_remote=str(git_remote),
enable_telemetry=enable_telemetry,
)


def detect_test_framework_from_config_files(curdir: Path) -> Optional[str]:
test_framework = None
pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"]
pytest_config_patterns = {
"pytest.ini": "[pytest]",
"pyproject.toml": "[tool.pytest.ini_options]",
"tox.ini": "[pytest]",
"setup.cfg": "[tool:pytest]",
}
for pytest_file in pytest_files:
file_path = curdir / pytest_file
if file_path.exists():
with file_path.open(encoding="utf8") as file:
contents = file.read()
if pytest_config_patterns[pytest_file] in contents:
test_framework = "pytest"
break
test_framework = "pytest"
return test_framework


def detect_test_framework_from_test_files(tests_root: Path) -> Optional[str]:
test_framework = None
# Check if any python files contain a class that inherits from unittest.TestCase
for filename in tests_root.iterdir():
if filename.suffix == ".py":
with filename.open(encoding="utf8") as file:
contents = file.read()
try:
node = ast.parse(contents)
except SyntaxError:
continue
if any(
isinstance(item, ast.ClassDef)
and any(
(isinstance(base, ast.Attribute) and base.attr == "TestCase")
or (isinstance(base, ast.Name) and base.id == "TestCase")
for base in item.bases
)
for item in node.body
):
test_framework = "unittest"
break
return test_framework


def check_for_toml_or_setup_file() -> str | None:
click.echo()
click.echo("Checking for pyproject.toml or setup.py…\r", nl=False)
Expand Down Expand Up @@ -1060,7 +969,6 @@ def configure_pyproject_toml(
else:
codeflash_section["module-root"] = setup_info.module_root
codeflash_section["tests-root"] = setup_info.tests_root
codeflash_section["test-framework"] = setup_info.test_framework
codeflash_section["ignore-paths"] = setup_info.ignore_paths
if not setup_info.enable_telemetry:
codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry
Expand Down Expand Up @@ -1325,26 +1233,8 @@ def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
arr[j + 1] = temp
return arr
"""
if args.test_framework == "unittest":
bubble_sort_test_content = f"""import unittest
from {os.path.basename(args.module_root)}.bubble_sort import sorter # Keep usage of os.path.basename to avoid pathlib potential incompatibility https://github.com/codeflash-ai/codeflash/pull/1066#discussion_r1801628022

class TestBubbleSort(unittest.TestCase):
def test_sort(self):
input = [5, 4, 3, 2, 1, 0]
output = sorter(input)
self.assertEqual(output, [0, 1, 2, 3, 4, 5])

input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
output = sorter(input)
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])

input = list(reversed(range(100)))
output = sorter(input)
self.assertEqual(output, list(range(100)))
""" # noqa: PTH119
elif args.test_framework == "pytest":
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
# Always use pytest for tests
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter

def test_sort():
input = [5, 4, 3, 2, 1, 0]
Expand Down
4 changes: 0 additions & 4 deletions codeflash/code_utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ def parse_config_file(
else:
config[key] = []

if config.get("test-framework"):
assert config["test-framework"] in {"pytest", "unittest"}, (
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
)
# see if this is happening during GitHub actions setup
if config.get("formatter-cmds") and len(config.get("formatter-cmds")) > 0 and not override_formatter_check:
assert config.get("formatter-cmds")[0] != "your-formatter $file", (
Expand Down
12 changes: 3 additions & 9 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(
self,
function: FunctionToOptimize,
module_path: str,
test_framework: str,
call_positions: list[CodePosition],
mode: TestingMode = TestingMode.BEHAVIOR,
) -> None:
Expand All @@ -79,7 +78,6 @@ def __init__(
self.class_name = None
self.only_function_name = function.function_name
self.module_path = module_path
self.test_framework = test_framework
self.call_positions = call_positions
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
self.class_name = function.top_level_parent_name
Expand Down Expand Up @@ -475,7 +473,6 @@ def __init__(
self,
function: FunctionToOptimize,
module_path: str,
test_framework: str,
call_positions: list[CodePosition],
mode: TestingMode = TestingMode.BEHAVIOR,
) -> None:
Expand All @@ -484,7 +481,6 @@ def __init__(
self.class_name = None
self.only_function_name = function.function_name
self.module_path = module_path
self.test_framework = test_framework
self.call_positions = call_positions
self.did_instrument = False
# Track function call count per test function
Expand Down Expand Up @@ -639,7 +635,6 @@ def inject_async_profiling_into_existing_test(
call_positions: list[CodePosition],
function_to_optimize: FunctionToOptimize,
tests_project_root: Path,
test_framework: str,
mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
"""Inject profiling for async function calls by setting environment variables before each call."""
Expand All @@ -657,7 +652,7 @@ def inject_async_profiling_into_existing_test(
import_visitor.visit(tree)
func = import_visitor.imported_as

async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
async_instrumenter = AsyncCallInstrumenter(func, test_module_path, call_positions, mode=mode)
tree = async_instrumenter.visit(tree)

if not async_instrumenter.did_instrument:
Expand All @@ -675,12 +670,11 @@ def inject_profiling_into_existing_test(
call_positions: list[CodePosition],
function_to_optimize: FunctionToOptimize,
tests_project_root: Path,
test_framework: str,
mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode
test_path, call_positions, function_to_optimize, tests_project_root, mode
)

with test_path.open(encoding="utf8") as f:
Expand All @@ -696,7 +690,7 @@ def inject_profiling_into_existing_test(
import_visitor.visit(tree)
func = import_visitor.imported_as

tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode).visit(tree)
tree = InjectPerfOnly(func, test_module_path, call_positions, mode=mode).visit(tree)
new_imports = [
ast.Import(names=[ast.alias(name="time")]),
ast.Import(names=[ast.alias(name="gc")]),
Expand Down
3 changes: 0 additions & 3 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def write_config(params: WriteConfigParams) -> dict[str, any]:
setup_info = VsCodeSetupInfo(
module_root=getattr(cfg, "module_root", ""),
tests_root=getattr(cfg, "tests_root", ""),
test_framework=getattr(cfg, "test_framework", "pytest"),
formatter=get_formatter_cmds(getattr(cfg, "formatter_cmds", "disabled")),
)

Expand All @@ -203,7 +202,6 @@ def write_config(params: WriteConfigParams) -> dict[str, any]:
def get_config_suggestions(_params: any) -> dict[str, any]:
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
formatter_suggestions, default_formatter = get_suggestions(CommonSections.formatter_cmds)
get_valid_subdirs.cache_clear()

Expand Down Expand Up @@ -238,7 +236,6 @@ def get_config_suggestions(_params: any) -> dict[str, any]:
return {
"module_root": {"choices": module_root_suggestions, "default": default_module_root},
"tests_root": {"choices": tests_root_suggestions, "default": default_tests_root},
"test_framework": {"choices": test_framework_suggestions, "default": default_test_framework},
"formatter_cmds": {"choices": formatter_suggestions, "default": default_formatter},
}

Expand Down
9 changes: 3 additions & 6 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
test_framework=self.args.test_framework,
)
if not success:
continue
Expand All @@ -1061,7 +1060,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
test_framework=self.args.test_framework,
)
if not success:
continue
Expand Down Expand Up @@ -1275,7 +1273,7 @@ def setup_and_establish_baseline(

original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
if isinstance(original_code_baseline, OriginalCodeBaseline) and (
not coverage_critic(original_code_baseline.coverage_results, self.args.test_framework)
not coverage_critic(original_code_baseline.coverage_results)
or not quantity_of_tests_critic(original_code_baseline)
):
if self.args.override_fixtures:
Expand Down Expand Up @@ -1597,7 +1595,6 @@ def establish_original_code_baseline(
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
# For the original function - run the tests and get the runtime, plus coverage
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018
success = True

test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
Expand All @@ -1622,7 +1619,7 @@ def establish_original_code_baseline(
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=test_framework == "pytest",
enable_coverage=True,
code_context=code_context,
)
finally:
Expand All @@ -1636,7 +1633,7 @@ def establish_original_code_baseline(
)
console.rule()
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
if not coverage_critic(coverage_results, self.args.test_framework):
if not coverage_critic(coverage_results):
did_pass_all_tests = all(result.did_pass for result in behavioral_results)
if not did_pass_all_tests:
return Failure("Tests failed to pass for the original code.")
Expand Down
10 changes: 4 additions & 6 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def __init__(self, args: Namespace) -> None:
tests_root=args.tests_root,
tests_project_rootdir=args.test_project_root,
project_root_path=args.project_root,
test_framework=args.test_framework,
pytest_cmd=args.pytest_cmd,
pytest_cmd=args.pytest_cmd if hasattr(args, "pytest_cmd") and args.pytest_cmd else "pytest",
benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None,
)

Expand Down Expand Up @@ -285,10 +284,9 @@ def run(self) -> None:
file_to_funcs_to_optimize, num_optimizable_functions
)
optimizations_found: int = 0
if self.args.test_framework == "pytest":
self.test_cfg.concolic_test_root_dir = Path(
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
)
self.test_cfg.concolic_test_root_dir = Path(
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
)
try:
ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
if num_optimizable_functions == 0:
Expand Down
Loading
Loading