@@ -1053,7 +1053,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10531053 call_positions = [test .position for test in tests_in_file_list ],
10541054 function_to_optimize = self .function_to_optimize ,
10551055 tests_project_root = self .test_cfg .tests_project_rootdir ,
1056- test_framework = self . args . test_framework ,
1056+ test_framework = "pytest" ,
10571057 )
10581058 if not success :
10591059 continue
@@ -1063,7 +1063,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
10631063 call_positions = [test .position for test in tests_in_file_list ],
10641064 function_to_optimize = self .function_to_optimize ,
10651065 tests_project_root = self .test_cfg .tests_project_rootdir ,
1066- test_framework = self . args . test_framework ,
1066+ test_framework = "pytest" ,
10671067 )
10681068 if not success :
10691069 continue
@@ -1271,7 +1271,7 @@ def setup_and_establish_baseline(
12711271
12721272 original_code_baseline , test_functions_to_remove = baseline_result .unwrap ()
12731273 if isinstance (original_code_baseline , OriginalCodeBaseline ) and (
1274- not coverage_critic (original_code_baseline .coverage_results , self . args . test_framework )
1274+ not coverage_critic (original_code_baseline .coverage_results , "pytest" )
12751275 or not quantity_of_tests_critic (original_code_baseline )
12761276 ):
12771277 if self .args .override_fixtures :
@@ -1593,7 +1593,7 @@ def establish_original_code_baseline(
15931593 ) -> Result [tuple [OriginalCodeBaseline , list [str ]], str ]:
15941594 line_profile_results = {"timings" : {}, "unit" : 0 , "str_out" : "" }
15951595 # For the original function - run the tests and get the runtime, plus coverage
1596- assert ( test_framework := self . args . test_framework ) in { "pytest" , "unittest" } # noqa: RUF018
1596+ test_framework = "pytest" # Always use pytest for all tests
15971597 success = True
15981598
15991599 test_env = self .get_test_env (codeflash_loop_index = 0 , codeflash_test_iteration = 0 , codeflash_tracer_disable = 1 )
@@ -1618,7 +1618,7 @@ def establish_original_code_baseline(
16181618 test_files = self .test_files ,
16191619 optimization_iteration = 0 ,
16201620 testing_time = total_looping_time ,
1621- enable_coverage = test_framework == " pytest" ,
1621+ enable_coverage = True , # Always enable coverage with pytest
16221622 code_context = code_context ,
16231623 )
16241624 finally :
@@ -1632,7 +1632,7 @@ def establish_original_code_baseline(
16321632 )
16331633 console .rule ()
16341634 return Failure ("Failed to establish a baseline for the original code - bevhavioral tests failed." )
1635- if not coverage_critic (coverage_results , self . args . test_framework ):
1635+ if not coverage_critic (coverage_results , "pytest" ):
16361636 did_pass_all_tests = all (result .did_pass for result in behavioral_results )
16371637 if not did_pass_all_tests :
16381638 return Failure ("Tests failed to pass for the original code." )
0 commit comments