Skip to content

Commit 8bb3dea

Browse files
committed
follow up
1 parent 33437d3 commit 8bb3dea

File tree

5 files changed

+33
-134
lines changed

5 files changed

+33
-134
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import ast
43
import os
54
import re
65
import subprocess
@@ -59,7 +58,6 @@ class CLISetupInfo:
5958
module_root: str
6059
tests_root: str
6160
benchmarks_root: Union[str, None]
62-
test_framework: str
6361
ignore_paths: list[str]
6462
formatter: Union[str, list[str]]
6563
git_remote: str
@@ -70,7 +68,6 @@ class CLISetupInfo:
7068
class VsCodeSetupInfo:
7169
module_root: str
7270
tests_root: str
73-
test_framework: str
7471
formatter: Union[str, list[str]]
7572

7673

@@ -256,7 +253,6 @@ def __init__(self) -> None:
256253
class CommonSections(Enum):
257254
module_root = "module_root"
258255
tests_root = "tests_root"
259-
test_framework = "test_framework"
260256
formatter_cmds = "formatter_cmds"
261257

262258
def get_toml_key(self) -> str:
@@ -292,9 +288,6 @@ def get_suggestions(section: str) -> tuple[list[str], Optional[str]]:
292288
if section == CommonSections.tests_root:
293289
default = "tests" if "tests" in valid_subdirs else None
294290
return valid_subdirs, default
295-
if section == CommonSections.test_framework:
296-
auto_detected = detect_test_framework_from_config_files(Path.cwd())
297-
return ["pytest", "unittest"], auto_detected
298291
if section == CommonSections.formatter_cmds:
299292
return ["disabled", "ruff", "black"], "disabled"
300293
msg = f"Unknown section: {section}"
@@ -455,43 +448,6 @@ def collect_setup_info() -> CLISetupInfo:
455448

456449
ph("cli-tests-root-provided")
457450

458-
test_framework_choices, detected_framework = get_suggestions(CommonSections.test_framework)
459-
autodetected_test_framework = detected_framework or detect_test_framework_from_test_files(tests_root)
460-
461-
framework_message = "⚗️ Let's configure your test framework.\n\n"
462-
if autodetected_test_framework:
463-
framework_message += f"I detected that you're using {autodetected_test_framework}. "
464-
framework_message += "Please confirm or select a different one."
465-
466-
framework_panel = Panel(Text(framework_message, style="blue"), title="⚗️ Test Framework", border_style="bright_blue")
467-
console.print(framework_panel)
468-
console.print()
469-
470-
framework_choices = []
471-
# add icons based on the detected framework
472-
for choice in test_framework_choices:
473-
if choice == "pytest":
474-
framework_choices.append(("🧪 pytest", "pytest"))
475-
elif choice == "unittest":
476-
framework_choices.append(("🐍 unittest", "unittest"))
477-
478-
framework_questions = [
479-
inquirer.List(
480-
"test_framework",
481-
message="Which test framework do you use?",
482-
choices=framework_choices,
483-
default=autodetected_test_framework or "pytest",
484-
carousel=True,
485-
)
486-
]
487-
488-
framework_answers = inquirer.prompt(framework_questions, theme=CodeflashTheme())
489-
if not framework_answers:
490-
apologize_and_exit()
491-
test_framework = framework_answers["test_framework"]
492-
493-
ph("cli-test-framework-provided", {"test_framework": test_framework})
494-
495451
benchmarks_root = None
496452

497453
# TODO: Implement other benchmark framework options
@@ -588,60 +544,13 @@ def collect_setup_info() -> CLISetupInfo:
588544
module_root=str(module_root),
589545
tests_root=str(tests_root),
590546
benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
591-
test_framework=cast("str", test_framework),
592547
ignore_paths=ignore_paths,
593548
formatter=cast("str", formatter),
594549
git_remote=str(git_remote),
595550
enable_telemetry=enable_telemetry,
596551
)
597552

598553

599-
def detect_test_framework_from_config_files(curdir: Path) -> Optional[str]:
600-
test_framework = None
601-
pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"]
602-
pytest_config_patterns = {
603-
"pytest.ini": "[pytest]",
604-
"pyproject.toml": "[tool.pytest.ini_options]",
605-
"tox.ini": "[pytest]",
606-
"setup.cfg": "[tool:pytest]",
607-
}
608-
for pytest_file in pytest_files:
609-
file_path = curdir / pytest_file
610-
if file_path.exists():
611-
with file_path.open(encoding="utf8") as file:
612-
contents = file.read()
613-
if pytest_config_patterns[pytest_file] in contents:
614-
test_framework = "pytest"
615-
break
616-
test_framework = "pytest"
617-
return test_framework
618-
619-
620-
def detect_test_framework_from_test_files(tests_root: Path) -> Optional[str]:
621-
test_framework = None
622-
# Check if any python files contain a class that inherits from unittest.TestCase
623-
for filename in tests_root.iterdir():
624-
if filename.suffix == ".py":
625-
with filename.open(encoding="utf8") as file:
626-
contents = file.read()
627-
try:
628-
node = ast.parse(contents)
629-
except SyntaxError:
630-
continue
631-
if any(
632-
isinstance(item, ast.ClassDef)
633-
and any(
634-
(isinstance(base, ast.Attribute) and base.attr == "TestCase")
635-
or (isinstance(base, ast.Name) and base.id == "TestCase")
636-
for base in item.bases
637-
)
638-
for item in node.body
639-
):
640-
test_framework = "unittest"
641-
break
642-
return test_framework
643-
644-
645554
def check_for_toml_or_setup_file() -> str | None:
646555
click.echo()
647556
click.echo("Checking for pyproject.toml or setup.py…\r", nl=False)
@@ -1060,7 +969,6 @@ def configure_pyproject_toml(
1060969
else:
1061970
codeflash_section["module-root"] = setup_info.module_root
1062971
codeflash_section["tests-root"] = setup_info.tests_root
1063-
codeflash_section["test-framework"] = setup_info.test_framework
1064972
codeflash_section["ignore-paths"] = setup_info.ignore_paths
1065973
if not setup_info.enable_telemetry:
1066974
codeflash_section["disable-telemetry"] = not setup_info.enable_telemetry

codeflash/code_utils/config_parser.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,6 @@ def parse_config_file(
149149
else:
150150
config[key] = []
151151

152-
if config.get("test-framework"):
153-
assert config["test-framework"] in {"pytest", "unittest"}, (
154-
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
155-
)
156152
# see if this is happening during GitHub actions setup
157153
if config.get("formatter-cmds") and len(config.get("formatter-cmds")) > 0 and not override_formatter_check:
158154
assert config.get("formatter-cmds")[0] != "your-formatter $file", (

codeflash/lsp/beta.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def write_config(params: WriteConfigParams) -> dict[str, any]:
187187
setup_info = VsCodeSetupInfo(
188188
module_root=getattr(cfg, "module_root", ""),
189189
tests_root=getattr(cfg, "tests_root", ""),
190-
test_framework=getattr(cfg, "test_framework", "pytest"),
191190
formatter=get_formatter_cmds(getattr(cfg, "formatter_cmds", "disabled")),
192191
)
193192

@@ -203,7 +202,6 @@ def write_config(params: WriteConfigParams) -> dict[str, any]:
203202
def get_config_suggestions(_params: any) -> dict[str, any]:
204203
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
205204
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
206-
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
207205
formatter_suggestions, default_formatter = get_suggestions(CommonSections.formatter_cmds)
208206
get_valid_subdirs.cache_clear()
209207

@@ -238,7 +236,6 @@ def get_config_suggestions(_params: any) -> dict[str, any]:
238236
return {
239237
"module_root": {"choices": module_root_suggestions, "default": default_module_root},
240238
"tests_root": {"choices": tests_root_suggestions, "default": default_tests_root},
241-
"test_framework": {"choices": test_framework_suggestions, "default": default_test_framework},
242239
"formatter_cmds": {"choices": formatter_suggestions, "default": default_formatter},
243240
}
244241

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,6 @@ __version__ = "{version}"
303303
module-root = "codeflash"
304304
tests-root = "tests"
305305
benchmarks-root = "tests/benchmarks"
306-
test-framework = "pytest"
307306
formatter-cmds = [
308307
"uvx ruff check --exit-zero --fix $file",
309308
"uvx ruff format $file",

tests/test_cmd_init.py

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
import pytest
1+
import os
22
import tempfile
33
from pathlib import Path
4+
5+
import pytest
6+
47
from codeflash.cli_cmds.cmd_init import (
5-
is_valid_pyproject_toml,
6-
configure_pyproject_toml,
78
CLISetupInfo,
8-
get_formatter_cmds,
99
VsCodeSetupInfo,
10+
configure_pyproject_toml,
11+
get_formatter_cmds,
1012
get_valid_subdirs,
13+
is_valid_pyproject_toml,
1114
)
12-
import os
1315

1416

1517
@pytest.fixture
@@ -29,11 +31,12 @@ def test_is_valid_pyproject_toml_with_empty_config(temp_dir: Path) -> None:
2931
assert not valid
3032
assert _message == "Missing required field: 'module_root'"
3133

34+
3235
def test_is_valid_pyproject_toml_with_incorrect_module_root(temp_dir: Path) -> None:
3336
with (temp_dir / "pyproject.toml").open(mode="w") as f:
3437
wrong_module_root = temp_dir / "invalid_directory"
3538
f.write(
36-
f"""[tool.codeflash]
39+
"""[tool.codeflash]
3740
module-root = "invalid_directory"
3841
"""
3942
)
@@ -47,7 +50,7 @@ def test_is_valid_pyproject_toml_with_incorrect_tests_root(temp_dir: Path) -> No
4750
with (temp_dir / "pyproject.toml").open(mode="w") as f:
4851
wrong_tests_root = temp_dir / "incorrect_tests_root"
4952
f.write(
50-
f"""[tool.codeflash]
53+
"""[tool.codeflash]
5154
module-root = "."
5255
tests-root = "incorrect_tests_root"
5356
"""
@@ -65,21 +68,21 @@ def test_is_valid_pyproject_toml_with_valid_config(temp_dir: Path) -> None:
6568
"""[tool.codeflash]
6669
module-root = "."
6770
tests-root = "tests"
68-
test-framework = "pytest"
6971
"""
7072
)
7173
f.flush()
7274
valid, config, _message = is_valid_pyproject_toml(temp_dir / "pyproject.toml")
7375
assert valid
7476

77+
7578
def test_get_formatter_cmd(temp_dir: Path) -> None:
7679
assert get_formatter_cmds("black") == ["black $file"]
7780
assert get_formatter_cmds("ruff") == ["ruff check --exit-zero --fix $file", "ruff format $file"]
7881
assert get_formatter_cmds("disabled") == ["disabled"]
7982
assert get_formatter_cmds("don't use a formatter") == ["disabled"]
8083

81-
def test_configure_pyproject_toml_for_cli(temp_dir: Path) -> None:
8284

85+
def test_configure_pyproject_toml_for_cli(temp_dir: Path) -> None:
8386
pyproject_path = temp_dir / "pyproject.toml"
8487

8588
with (pyproject_path).open(mode="w") as f:
@@ -90,7 +93,6 @@ def test_configure_pyproject_toml_for_cli(temp_dir: Path) -> None:
9093
module_root=".",
9194
tests_root="tests",
9295
benchmarks_root=None,
93-
test_framework="pytest",
9496
ignore_paths=[],
9597
formatter="black",
9698
git_remote="origin",
@@ -101,81 +103,78 @@ def test_configure_pyproject_toml_for_cli(temp_dir: Path) -> None:
101103
assert success
102104

103105
config_content = pyproject_path.read_text()
104-
assert """[tool.codeflash]
106+
assert (
107+
config_content
108+
== """[tool.codeflash]
105109
# All paths are relative to this pyproject.toml's directory.
106110
module-root = "."
107111
tests-root = "tests"
108-
test-framework = "pytest"
109112
ignore-paths = []
110113
disable-telemetry = true
111114
formatter-cmds = ["black $file"]
112-
""" == config_content
115+
"""
116+
)
113117
valid, _, _ = is_valid_pyproject_toml(pyproject_path)
114118
assert valid
115119

116-
def test_configure_pyproject_toml_for_vscode_with_empty_config(temp_dir: Path) -> None:
117120

121+
def test_configure_pyproject_toml_for_vscode_with_empty_config(temp_dir: Path) -> None:
118122
pyproject_path = temp_dir / "pyproject.toml"
119123

120124
with (pyproject_path).open(mode="w") as f:
121125
f.write("")
122126
f.flush()
123127
os.mkdir(temp_dir / "tests")
124-
config = VsCodeSetupInfo(
125-
module_root=".",
126-
tests_root="tests",
127-
test_framework="pytest",
128-
formatter="black",
129-
)
128+
config = VsCodeSetupInfo(module_root=".", tests_root="tests", formatter="black")
130129

131130
success = configure_pyproject_toml(config, pyproject_path)
132131
assert success
133132

134133
config_content = pyproject_path.read_text()
135-
assert """[tool.codeflash]
134+
assert (
135+
config_content
136+
== """[tool.codeflash]
136137
module-root = "."
137138
tests-root = "tests"
138-
test-framework = "pytest"
139139
formatter-cmds = ["black $file"]
140-
""" == config_content
140+
"""
141+
)
141142
valid, _, _ = is_valid_pyproject_toml(pyproject_path)
142143
assert valid
143144

145+
144146
def test_configure_pyproject_toml_for_vscode_with_existing_config(temp_dir: Path) -> None:
145147
pyproject_path = temp_dir / "pyproject.toml"
146-
148+
147149
with (pyproject_path).open(mode="w") as f:
148150
f.write("""[tool.codeflash]
149151
module-root = "codeflash"
150152
tests-root = "tests"
151153
benchmarks-root = "tests/benchmarks"
152-
test-framework = "pytest"
153154
formatter-cmds = ["disabled"]
154155
""")
155156
f.flush()
156157
os.mkdir(temp_dir / "tests")
157-
config = VsCodeSetupInfo(
158-
module_root=".",
159-
tests_root="tests",
160-
test_framework="pytest",
161-
formatter="disabled",
162-
)
158+
config = VsCodeSetupInfo(module_root=".", tests_root="tests", formatter="disabled")
163159

164160
success = configure_pyproject_toml(config, pyproject_path)
165161
assert success
166162

167163
config_content = pyproject_path.read_text()
168164
# the benchmarks-root shouldn't get overwritten
169-
assert """[tool.codeflash]
165+
assert (
166+
config_content
167+
== """[tool.codeflash]
170168
module-root = "."
171169
tests-root = "tests"
172170
benchmarks-root = "tests/benchmarks"
173-
test-framework = "pytest"
174171
formatter-cmds = ["disabled"]
175-
""" == config_content
172+
"""
173+
)
176174
valid, _, _ = is_valid_pyproject_toml(pyproject_path)
177175
assert valid
178176

177+
179178
def test_get_valid_subdirs(temp_dir: Path) -> None:
180179
os.mkdir(temp_dir / "dir1")
181180
os.mkdir(temp_dir / "dir2")

0 commit comments

Comments
 (0)