11from __future__ import annotations
22
3- import ast
43import os
54import re
65import subprocess
@@ -59,7 +58,6 @@ class CLISetupInfo:
5958 module_root : str
6059 tests_root : str
6160 benchmarks_root : Union [str , None ]
62- test_framework : str
6361 ignore_paths : list [str ]
6462 formatter : Union [str , list [str ]]
6563 git_remote : str
@@ -70,7 +68,6 @@ class CLISetupInfo:
7068class VsCodeSetupInfo :
7169 module_root : str
7270 tests_root : str
73- test_framework : str
7471 formatter : Union [str , list [str ]]
7572
7673
@@ -256,7 +253,6 @@ def __init__(self) -> None:
256253class CommonSections (Enum ):
257254 module_root = "module_root"
258255 tests_root = "tests_root"
259- test_framework = "test_framework"
260256 formatter_cmds = "formatter_cmds"
261257
262258 def get_toml_key (self ) -> str :
@@ -292,9 +288,6 @@ def get_suggestions(section: str) -> tuple[list[str], Optional[str]]:
292288 if section == CommonSections .tests_root :
293289 default = "tests" if "tests" in valid_subdirs else None
294290 return valid_subdirs , default
295- if section == CommonSections .test_framework :
296- auto_detected = detect_test_framework_from_config_files (Path .cwd ())
297- return ["pytest" , "unittest" ], auto_detected
298291 if section == CommonSections .formatter_cmds :
299292 return ["disabled" , "ruff" , "black" ], "disabled"
300293 msg = f"Unknown section: { section } "
@@ -455,43 +448,6 @@ def collect_setup_info() -> CLISetupInfo:
455448
456449 ph ("cli-tests-root-provided" )
457450
458- test_framework_choices , detected_framework = get_suggestions (CommonSections .test_framework )
459- autodetected_test_framework = detected_framework or detect_test_framework_from_test_files (tests_root )
460-
461- framework_message = "⚗️ Let's configure your test framework.\n \n "
462- if autodetected_test_framework :
463- framework_message += f"I detected that you're using { autodetected_test_framework } . "
464- framework_message += "Please confirm or select a different one."
465-
466- framework_panel = Panel (Text (framework_message , style = "blue" ), title = "⚗️ Test Framework" , border_style = "bright_blue" )
467- console .print (framework_panel )
468- console .print ()
469-
470- framework_choices = []
471- # add icons based on the detected framework
472- for choice in test_framework_choices :
473- if choice == "pytest" :
474- framework_choices .append (("🧪 pytest" , "pytest" ))
475- elif choice == "unittest" :
476- framework_choices .append (("🐍 unittest" , "unittest" ))
477-
478- framework_questions = [
479- inquirer .List (
480- "test_framework" ,
481- message = "Which test framework do you use?" ,
482- choices = framework_choices ,
483- default = autodetected_test_framework or "pytest" ,
484- carousel = True ,
485- )
486- ]
487-
488- framework_answers = inquirer .prompt (framework_questions , theme = CodeflashTheme ())
489- if not framework_answers :
490- apologize_and_exit ()
491- test_framework = framework_answers ["test_framework" ]
492-
493- ph ("cli-test-framework-provided" , {"test_framework" : test_framework })
494-
495451 benchmarks_root = None
496452
497453 # TODO: Implement other benchmark framework options
@@ -588,60 +544,13 @@ def collect_setup_info() -> CLISetupInfo:
588544 module_root = str (module_root ),
589545 tests_root = str (tests_root ),
590546 benchmarks_root = str (benchmarks_root ) if benchmarks_root else None ,
591- test_framework = cast ("str" , test_framework ),
592547 ignore_paths = ignore_paths ,
593548 formatter = cast ("str" , formatter ),
594549 git_remote = str (git_remote ),
595550 enable_telemetry = enable_telemetry ,
596551 )
597552
598553
599- def detect_test_framework_from_config_files (curdir : Path ) -> Optional [str ]:
600- test_framework = None
601- pytest_files = ["pytest.ini" , "pyproject.toml" , "tox.ini" , "setup.cfg" ]
602- pytest_config_patterns = {
603- "pytest.ini" : "[pytest]" ,
604- "pyproject.toml" : "[tool.pytest.ini_options]" ,
605- "tox.ini" : "[pytest]" ,
606- "setup.cfg" : "[tool:pytest]" ,
607- }
608- for pytest_file in pytest_files :
609- file_path = curdir / pytest_file
610- if file_path .exists ():
611- with file_path .open (encoding = "utf8" ) as file :
612- contents = file .read ()
613- if pytest_config_patterns [pytest_file ] in contents :
614- test_framework = "pytest"
615- break
616- test_framework = "pytest"
617- return test_framework
618-
619-
620- def detect_test_framework_from_test_files (tests_root : Path ) -> Optional [str ]:
621- test_framework = None
622- # Check if any python files contain a class that inherits from unittest.TestCase
623- for filename in tests_root .iterdir ():
624- if filename .suffix == ".py" :
625- with filename .open (encoding = "utf8" ) as file :
626- contents = file .read ()
627- try :
628- node = ast .parse (contents )
629- except SyntaxError :
630- continue
631- if any (
632- isinstance (item , ast .ClassDef )
633- and any (
634- (isinstance (base , ast .Attribute ) and base .attr == "TestCase" )
635- or (isinstance (base , ast .Name ) and base .id == "TestCase" )
636- for base in item .bases
637- )
638- for item in node .body
639- ):
640- test_framework = "unittest"
641- break
642- return test_framework
643-
644-
645554def check_for_toml_or_setup_file () -> str | None :
646555 click .echo ()
647556 click .echo ("Checking for pyproject.toml or setup.py…\r " , nl = False )
@@ -1060,7 +969,6 @@ def configure_pyproject_toml(
1060969 else :
1061970 codeflash_section ["module-root" ] = setup_info .module_root
1062971 codeflash_section ["tests-root" ] = setup_info .tests_root
1063- codeflash_section ["test-framework" ] = setup_info .test_framework
1064972 codeflash_section ["ignore-paths" ] = setup_info .ignore_paths
1065973 if not setup_info .enable_telemetry :
1066974 codeflash_section ["disable-telemetry" ] = not setup_info .enable_telemetry
@@ -1325,26 +1233,8 @@ def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
13251233 arr[j + 1] = temp
13261234 return arr
13271235"""
1328- if args .test_framework == "unittest" :
1329- bubble_sort_test_content = f"""import unittest
1330- from { os .path .basename (args .module_root )} .bubble_sort import sorter # Keep usage of os.path.basename to avoid pathlib potential incompatibility https://github.com/codeflash-ai/codeflash/pull/1066#discussion_r1801628022
1331-
1332- class TestBubbleSort(unittest.TestCase):
1333- def test_sort(self):
1334- input = [5, 4, 3, 2, 1, 0]
1335- output = sorter(input)
1336- self.assertEqual(output, [0, 1, 2, 3, 4, 5])
1337-
1338- input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
1339- output = sorter(input)
1340- self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
1341-
1342- input = list(reversed(range(100)))
1343- output = sorter(input)
1344- self.assertEqual(output, list(range(100)))
1345- """ # noqa: PTH119
1346- elif args .test_framework == "pytest" :
1347- bubble_sort_test_content = f"""from { Path (args .module_root ).name } .bubble_sort import sorter
1236+ # Always use pytest for tests
1237+ bubble_sort_test_content = f"""from { Path (args .module_root ).name } .bubble_sort import sorter
13481238
13491239def test_sort():
13501240 input = [5, 4, 3, 2, 1, 0]
0 commit comments