11from __future__ import annotations
22
3- import ast
43import os
54import re
65import subprocess
@@ -60,7 +59,6 @@ class CLISetupInfo:
6059 module_root : str
6160 tests_root : str
6261 benchmarks_root : Union [str , None ]
63- test_framework : str
6462 ignore_paths : list [str ]
6563 formatter : Union [str , list [str ]]
6664 git_remote : str
@@ -71,7 +69,6 @@ class CLISetupInfo:
7169class VsCodeSetupInfo :
7270 module_root : str
7371 tests_root : str
74- test_framework : str
7572 formatter : Union [str , list [str ]]
7673
7774
@@ -257,7 +254,6 @@ def __init__(self) -> None:
257254class CommonSections (Enum ):
258255 module_root = "module_root"
259256 tests_root = "tests_root"
260- test_framework = "test_framework"
261257 formatter_cmds = "formatter_cmds"
262258
263259 def get_toml_key (self ) -> str :
@@ -293,9 +289,6 @@ def get_suggestions(section: str) -> tuple[list[str], Optional[str]]:
293289 if section == CommonSections .tests_root :
294290 default = "tests" if "tests" in valid_subdirs else None
295291 return valid_subdirs , default
296- if section == CommonSections .test_framework :
297- auto_detected = detect_test_framework_from_config_files (Path .cwd ())
298- return ["pytest" , "unittest" ], auto_detected
299292 if section == CommonSections .formatter_cmds :
300293 return ["disabled" , "ruff" , "black" ], "disabled"
301294 msg = f"Unknown section: { section } "
@@ -480,43 +473,6 @@ def collect_setup_info() -> CLISetupInfo:
480473
481474 ph ("cli-tests-root-provided" )
482475
483- test_framework_choices , detected_framework = get_suggestions (CommonSections .test_framework )
484- autodetected_test_framework = detected_framework or detect_test_framework_from_test_files (tests_root )
485-
486- framework_message = "⚗️ Let's configure your test framework.\n \n "
487- if autodetected_test_framework :
488- framework_message += f"I detected that you're using { autodetected_test_framework } . "
489- framework_message += "Please confirm or select a different one."
490-
491- framework_panel = Panel (Text (framework_message , style = "blue" ), title = "⚗️ Test Framework" , border_style = "bright_blue" )
492- console .print (framework_panel )
493- console .print ()
494-
495- framework_choices = []
496- # add icons based on the detected framework
497- for choice in test_framework_choices :
498- if choice == "pytest" :
499- framework_choices .append (("🧪 pytest" , "pytest" ))
500- elif choice == "unittest" :
501- framework_choices .append (("🐍 unittest" , "unittest" ))
502-
503- framework_questions = [
504- inquirer .List (
505- "test_framework" ,
506- message = "Which test framework do you use?" ,
507- choices = framework_choices ,
508- default = autodetected_test_framework or "pytest" ,
509- carousel = True ,
510- )
511- ]
512-
513- framework_answers = inquirer .prompt (framework_questions , theme = CodeflashTheme ())
514- if not framework_answers :
515- apologize_and_exit ()
516- test_framework = framework_answers ["test_framework" ]
517-
518- ph ("cli-test-framework-provided" , {"test_framework" : test_framework })
519-
520476 benchmarks_root = None
521477
522478 # TODO: Implement other benchmark framework options
@@ -613,60 +569,13 @@ def collect_setup_info() -> CLISetupInfo:
613569 module_root = str (module_root ),
614570 tests_root = str (tests_root ),
615571 benchmarks_root = str (benchmarks_root ) if benchmarks_root else None ,
616- test_framework = cast ("str" , test_framework ),
617572 ignore_paths = ignore_paths ,
618573 formatter = cast ("str" , formatter ),
619574 git_remote = str (git_remote ),
620575 enable_telemetry = enable_telemetry ,
621576 )
622577
623578
624- def detect_test_framework_from_config_files (curdir : Path ) -> Optional [str ]:
625- test_framework = None
626- pytest_files = ["pytest.ini" , "pyproject.toml" , "tox.ini" , "setup.cfg" ]
627- pytest_config_patterns = {
628- "pytest.ini" : "[pytest]" ,
629- "pyproject.toml" : "[tool.pytest.ini_options]" ,
630- "tox.ini" : "[pytest]" ,
631- "setup.cfg" : "[tool:pytest]" ,
632- }
633- for pytest_file in pytest_files :
634- file_path = curdir / pytest_file
635- if file_path .exists ():
636- with file_path .open (encoding = "utf8" ) as file :
637- contents = file .read ()
638- if pytest_config_patterns [pytest_file ] in contents :
639- test_framework = "pytest"
640- break
641- test_framework = "pytest"
642- return test_framework
643-
644-
645- def detect_test_framework_from_test_files (tests_root : Path ) -> Optional [str ]:
646- test_framework = None
647- # Check if any python files contain a class that inherits from unittest.TestCase
648- for filename in tests_root .iterdir ():
649- if filename .suffix == ".py" :
650- with filename .open (encoding = "utf8" ) as file :
651- contents = file .read ()
652- try :
653- node = ast .parse (contents )
654- except SyntaxError :
655- continue
656- if any (
657- isinstance (item , ast .ClassDef )
658- and any (
659- (isinstance (base , ast .Attribute ) and base .attr == "TestCase" )
660- or (isinstance (base , ast .Name ) and base .id == "TestCase" )
661- for base in item .bases
662- )
663- for item in node .body
664- ):
665- test_framework = "unittest"
666- break
667- return test_framework
668-
669-
670579def check_for_toml_or_setup_file () -> str | None :
671580 click .echo ()
672581 click .echo ("Checking for pyproject.toml or setup.py…\r " , nl = False )
@@ -1085,7 +994,6 @@ def configure_pyproject_toml(
1085994 else :
1086995 codeflash_section ["module-root" ] = setup_info .module_root
1087996 codeflash_section ["tests-root" ] = setup_info .tests_root
1088- codeflash_section ["test-framework" ] = setup_info .test_framework
1089997 codeflash_section ["ignore-paths" ] = setup_info .ignore_paths
1090998 if not setup_info .enable_telemetry :
1091999 codeflash_section ["disable-telemetry" ] = not setup_info .enable_telemetry
@@ -1350,26 +1258,8 @@ def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
13501258 arr[j + 1] = temp
13511259 return arr
13521260"""
1353- if args .test_framework == "unittest" :
1354- bubble_sort_test_content = f"""import unittest
1355- from { os .path .basename (args .module_root )} .bubble_sort import sorter # Keep usage of os.path.basename to avoid pathlib potential incompatibility https://github.com/codeflash-ai/codeflash/pull/1066#discussion_r1801628022
1356-
1357- class TestBubbleSort(unittest.TestCase):
1358- def test_sort(self):
1359- input = [5, 4, 3, 2, 1, 0]
1360- output = sorter(input)
1361- self.assertEqual(output, [0, 1, 2, 3, 4, 5])
1362-
1363- input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
1364- output = sorter(input)
1365- self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
1366-
1367- input = list(reversed(range(100)))
1368- output = sorter(input)
1369- self.assertEqual(output, list(range(100)))
1370- """ # noqa: PTH119
1371- elif args .test_framework == "pytest" :
1372- bubble_sort_test_content = f"""from { Path (args .module_root ).name } .bubble_sort import sorter
1261+ # Always use pytest for tests
1262+ bubble_sort_test_content = f"""from { Path (args .module_root ).name } .bubble_sort import sorter
13731263
13741264def test_sort():
13751265 input = [5, 4, 3, 2, 1, 0]
0 commit comments