     replace_function_definitions_in_module,
 )
 from codeflash.code_utils.code_utils import (
+    choose_weights,
     cleanup_paths,
     create_rank_dictionary_compact,
+    create_score_dictionary_from_metrics,
     diff_length,
     extract_unique_errors,
     file_name_from_test_module_name,
     get_run_tmp_file,
     module_name_from_file_path,
+    normalize_by_max,
     restore_conftest,
     unified_diff_strings,
 )
     N_CANDIDATES_EFFECTIVE,
     N_CANDIDATES_LP_EFFECTIVE,
     N_TESTS_TO_GENERATE_EFFECTIVE,
+    REFINE_ALL_THRESHOLD,
+    REFINED_CANDIDATE_RANKING_WEIGHTS,
     REPAIR_UNMATCHED_PERCENTAGE_LIMIT,
     REPEAT_OPTIMIZATION_PROBABILITY,
+    TOP_N_REFINEMENTS,
     TOTAL_LOOPING_TIME_EFFECTIVE,
 )
 from codeflash.code_utils.deduplicate_code import normalize_code
@@ -129,20 +135,24 @@ def __init__(
         self,
         initial_candidates: list,
         future_line_profile_results: concurrent.futures.Future,
-        future_all_refinements: list[concurrent.futures.Future],
+        all_refinements_data: list[AIServiceRefinerRequest],
+        ai_service_client: AiServiceClient,
+        executor: concurrent.futures.ThreadPoolExecutor,
         future_all_code_repair: list[concurrent.futures.Future],
     ) -> None:
         self.candidate_queue = queue.Queue()
         self.line_profiler_done = False
         self.refinement_done = False
         self.candidate_len = len(initial_candidates)
+        self.ai_service_client = ai_service_client
+        self.executor = executor

         # Initialize queue with initial candidates
         for candidate in initial_candidates:
             self.candidate_queue.put(candidate)

         self.future_line_profile_results = future_line_profile_results
-        self.future_all_refinements = future_all_refinements
+        self.all_refinements_data = all_refinements_data
         self.future_all_code_repair = future_all_code_repair

     def get_next_candidate(self) -> OptimizedCandidate | None:
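The constructor change above swaps pre-submitted refinement futures for raw request data plus the AI-service client and executor, so refinement jobs are only submitted once the processor has decided which candidates are worth refining. Below is a minimal standalone sketch of that lazy-submission pattern; `FakeRefinerRequest` and `FakeClient` are hypothetical stand-ins for the codeflash types, not the real API.

```python
import concurrent.futures
from dataclasses import dataclass


@dataclass
class FakeRefinerRequest:  # hypothetical stand-in for AIServiceRefinerRequest
    optimization_id: str


class FakeClient:  # hypothetical stand-in for AiServiceClient
    def optimize_python_code_refinement(self, request):
        # Pretend each request yields one refined candidate.
        return [f"refined-{r.optimization_id}" for r in request]


client = FakeClient()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
pending = [FakeRefinerRequest("opt-1"), FakeRefinerRequest("opt-2")]

# Nothing is sent to the service until we explicitly submit, mirroring how the
# processor now holds requests and submits them itself in refine_optimizations().
futures = [executor.submit(client.optimize_python_code_refinement, request=[r]) for r in pending]
concurrent.futures.wait(futures)
print([f.result()[0] for f in futures])  # ['refined-opt-1', 'refined-opt-2']
```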
@@ -177,15 +187,45 @@ def _process_line_profiler_results(self) -> OptimizedCandidate | None:

         return self.get_next_candidate()

+    def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future:
+        return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request)
+
     def _process_refinement_results(self) -> OptimizedCandidate | None:
-        """Process refinement results and add to queue."""
-        if self.future_all_refinements:
+        """Process refinement results and add to queue. Candidates are ranked by a weighted score of runtime and diff length; roughly the top 45% of valid optimizations are selected for refinement."""
+        future_refinements: list[concurrent.futures.Future] = []
+
+        if len(self.all_refinements_data) <= REFINE_ALL_THRESHOLD:
+            for data in self.all_refinements_data:
+                future_refinements.append(self.refine_optimizations([data]))  # noqa: PERF401
+        else:
+            diff_lens_list = []
+            runtimes_list = []
+            for c in self.all_refinements_data:
+                diff_lens_list.append(diff_length(c.original_source_code, c.optimized_source_code))
+                runtimes_list.append(c.optimized_code_runtime)
+
+            runtime_w, diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
+            weights = choose_weights(runtime=runtime_w, diff=diff_w)
+
+            runtime_norm = normalize_by_max(runtimes_list)
+            diffs_norm = normalize_by_max(diff_lens_list)
+            # the lower the score, the better
+            score_dict = create_score_dictionary_from_metrics(weights, runtime_norm, diffs_norm)
+            top_n_candidates = int((TOP_N_REFINEMENTS * len(runtimes_list)) + 0.5)
+            top_indices = sorted(score_dict, key=score_dict.get)[:top_n_candidates]
+
+            for idx in top_indices:
+                data = self.all_refinements_data[idx]
+                future_refinements.append(self.refine_optimizations([data]))
+
+        if future_refinements:
             logger.info("loading|Refining generated code for improved quality and performance...")
-            concurrent.futures.wait(self.future_all_refinements)
+
+            concurrent.futures.wait(future_refinements)
             refinement_response = []

-            for future_refinement in self.future_all_refinements:
-                possible_refinement = future_refinement.result()
+            for f in future_refinements:
+                possible_refinement = f.result()
                 if len(possible_refinement) > 0:
                     refinement_response.append(possible_refinement[0])

@@ -197,7 +237,6 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
             logger.info(
                 f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
            )
-            self.future_all_refinements = []
        self.refinement_done = True

        return self.get_next_candidate()
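For reference, the selection logic above amounts to a weighted score over max-normalized metrics. The sketch below re-implements it standalone; the 0.7/0.3 weight split, `TOP_N_REFINEMENTS = 0.45`, and the behavior of `normalize_by_max` / `create_score_dictionary_from_metrics` are assumptions inferred from this diff, not the actual codeflash helpers.

```python
# Standalone sketch of the refinement-candidate selection (not the real helpers):
# scores are a weighted sum of max-normalized runtime and diff length, lower is
# better, and roughly the top 45% of candidates are kept for refinement.
REFINED_CANDIDATE_RANKING_WEIGHTS = (0.7, 0.3)  # assumed runtime/diff split
TOP_N_REFINEMENTS = 0.45                        # assumed selection ratio


def normalize_by_max(values: list[float]) -> list[float]:
    # Assumed behavior: scale so the largest value becomes 1.0.
    max_v = max(values) or 1.0
    return [v / max_v for v in values]


def select_top_refinements(runtimes: list[int], diff_lens: list[int]) -> list[int]:
    runtime_w, diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
    runtime_norm = normalize_by_max(runtimes)
    diffs_norm = normalize_by_max(diff_lens)
    # Lower score is better: fast candidates with small diffs rank first.
    scores = {i: runtime_w * r + diff_w * d for i, (r, d) in enumerate(zip(runtime_norm, diffs_norm))}
    top_n = int(TOP_N_REFINEMENTS * len(runtimes) + 0.5)  # round half up
    return sorted(scores, key=scores.get)[:top_n]


# Example: with 5 candidates, int(0.45 * 5 + 0.5) = 2 indices are sent for refinement.
print(select_top_refinements([120, 80, 200, 95, 150], [10, 40, 5, 60, 20]))
```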
@@ -278,7 +317,6 @@ def __init__(
             max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
         )
         self.optimization_review = ""
-        self.future_all_refinements: list[concurrent.futures.Future] = []
         self.future_all_code_repair: list[concurrent.futures.Future] = []
         self.repair_counter = 0  # track how many repairs we did for each function

@@ -724,14 +762,15 @@ def process_single_candidate(
         original_helper_code: dict[Path, str],
         file_path_to_helper_classes: dict[Path, set[str]],
         eval_ctx: CandidateEvaluationContext,
+        all_refinements_data: list[AIServiceRefinerRequest],
         ai_service_client: AiServiceClient,
         exp_type: str,
         function_references: str,
     ) -> BestOptimization | None:
         """Process a single optimization candidate.

         Returns the BestOptimization if the candidate is successful, None otherwise.
-        Updates eval_ctx with results and may append to future_all_refinements.
+        Updates eval_ctx with results and may append to all_refinements_data.
         """
         # Cleanup temp files
         get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
@@ -829,17 +868,23 @@ def process_single_candidate(

         # Queue refinement for non-refined candidates
         if candidate.source != OptimizedCandidateSource.REFINE:
-            self.future_all_refinements.append(
-                self.refine_optimizations(
-                    valid_optimizations=[best_optimization],
-                    original_code_baseline=original_code_baseline,
-                    code_context=code_context,
+            all_refinements_data.append(
+                AIServiceRefinerRequest(
+                    optimization_id=best_optimization.candidate.optimization_id,
+                    original_source_code=code_context.read_writable_code.markdown,
+                    read_only_dependency_code=code_context.read_only_context_code,
+                    original_code_runtime=original_code_baseline.runtime,
+                    optimized_source_code=best_optimization.candidate.source_code.markdown,
+                    optimized_explanation=best_optimization.candidate.explanation,
+                    optimized_code_runtime=best_optimization.runtime,
+                    speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
                     trace_id=self.get_trace_id(exp_type),
-                    ai_service_client=ai_service_client,
-                    executor=self.executor,
+                    original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
+                    optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
                     function_references=function_references,
                 )
-            )
+            )
+

         # Display runtime information
         if is_LSP_enabled():
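The `speedup` string built in the request above relies on `performance_gain`. Assuming it returns the relative improvement `(original - optimized) / optimized` as a fraction (an assumption, not confirmed by this diff), the formatting works out as in this small sketch; the local `performance_gain` here is an illustrative re-implementation.

```python
# Hedged sketch of the speedup-string computation, assuming performance_gain
# returns (original - optimized) / optimized as a fraction.
def performance_gain(original_runtime_ns: int, optimized_runtime_ns: int) -> float:
    return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns


# Example: a 3,000,000 ns baseline against a 1,200,000 ns optimized run.
speedup = f"{int(performance_gain(original_runtime_ns=3_000_000, optimized_runtime_ns=1_200_000) * 100)}%"
print(speedup)  # "150%"
```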
@@ -872,7 +917,7 @@ def determine_best_candidate(

         # Initialize evaluation context and async tasks
         eval_ctx = CandidateEvaluationContext()
-        self.future_all_refinements.clear()
+        all_refinements_data: list[AIServiceRefinerRequest] = []
         self.future_all_code_repair.clear()
         self.repair_counter = 0

@@ -894,7 +939,7 @@ def determine_best_candidate(
         )

         processor = CandidateProcessor(
-            candidates, future_line_profile_results, self.future_all_refinements, self.future_all_code_repair
+            candidates, future_line_profile_results, all_refinements_data, self.aiservice_client, self.executor, self.future_all_code_repair
         )
         candidate_index = 0

@@ -916,6 +961,7 @@ def determine_best_candidate(
                 original_helper_code=original_helper_code,
                 file_path_to_helper_classes=file_path_to_helper_classes,
                 eval_ctx=eval_ctx,
+                all_refinements_data=all_refinements_data,
                 ai_service_client=ai_service_client,
                 exp_type=exp_type,
                 function_references=function_references,
@@ -949,54 +995,6 @@ def determine_best_candidate(

         return best_optimization

-    def refine_optimizations(
-        self,
-        valid_optimizations: list[BestOptimization],
-        original_code_baseline: OriginalCodeBaseline,
-        code_context: CodeOptimizationContext,
-        trace_id: str,
-        ai_service_client: AiServiceClient,
-        executor: concurrent.futures.ThreadPoolExecutor,
-        function_references: str | None = None,
-    ) -> concurrent.futures.Future:
-        request = [
-            AIServiceRefinerRequest(
-                optimization_id=opt.candidate.optimization_id,
-                original_source_code=code_context.read_writable_code.markdown,
-                read_only_dependency_code=code_context.read_only_context_code,
-                original_code_runtime=humanize_runtime(original_code_baseline.runtime),
-                optimized_source_code=opt.candidate.source_code.markdown,
-                optimized_explanation=opt.candidate.explanation,
-                optimized_code_runtime=humanize_runtime(opt.runtime),
-                speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%",
-                trace_id=trace_id,
-                original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
-                optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
-                function_references=function_references,
-            )
-            for opt in valid_optimizations
-        ]
-        return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
-
-    def repair_optimization(
-        self,
-        original_source_code: str,
-        modified_source_code: str,
-        test_diffs: list[TestDiff],
-        trace_id: str,
-        optimization_id: str,
-        ai_service_client: AiServiceClient,
-        executor: concurrent.futures.ThreadPoolExecutor,
-    ) -> concurrent.futures.Future[OptimizedCandidate | None]:
-        request = AIServiceCodeRepairRequest(
-            optimization_id=optimization_id,
-            original_source_code=original_source_code,
-            modified_source_code=modified_source_code,
-            test_diffs=test_diffs,
-            trace_id=trace_id,
-        )
-        return executor.submit(ai_service_client.code_repair, request=request)
-
     def log_successful_optimization(
         self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
     ) -> None: