diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 744f76087..0ea380059 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -22,11 +22,11 @@ from typing import Annotated, Optional, cast from jedi.api.classes import Name -from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError +from pydantic import AfterValidator, BaseModel, ConfigDict, Field, PrivateAttr, ValidationError from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger -from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code +from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code from codeflash.code_utils.env_utils import is_end_to_end from codeflash.verification.comparator import comparator @@ -346,6 +346,73 @@ class OptimizationSet(BaseModel): experiment: Optional[list[OptimizedCandidate]] +@dataclass +class CandidateEvaluationContext: + """Holds tracking state during candidate evaluation in determine_best_candidate.""" + + speedup_ratios: dict[str, float | None] = Field(default_factory=dict) + optimized_runtimes: dict[str, float | None] = Field(default_factory=dict) + is_correct: dict[str, bool] = Field(default_factory=dict) + optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict) + ast_code_to_id: dict = Field(default_factory=dict) + optimizations_post: dict[str, str] = Field(default_factory=dict) + valid_optimizations: list = Field(default_factory=list) + + def record_failed_candidate(self, optimization_id: str) -> None: + """Record results for a failed candidate.""" + self.optimized_runtimes[optimization_id] = None + self.is_correct[optimization_id] = False + self.speedup_ratios[optimization_id] = None + + def record_successful_candidate(self, optimization_id: str, runtime: float, speedup: float) -> None: + """Record results for a successful candidate.""" + self.optimized_runtimes[optimization_id] = runtime + self.is_correct[optimization_id] = True + self.speedup_ratios[optimization_id] = speedup + + def record_line_profiler_result(self, optimization_id: str, result: str) -> None: + """Record line profiler results for a candidate.""" + self.optimized_line_profiler_results[optimization_id] = result + + def handle_duplicate_candidate( + self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext + ) -> None: + """Handle a candidate that has been seen before.""" + past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] + + # Copy results from the previous evaluation + self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios[past_opt_id] + self.is_correct[candidate.optimization_id] = self.is_correct[past_opt_id] + self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes[past_opt_id] + + # Line profiler results only available for successful runs + if past_opt_id in self.optimized_line_profiler_results: + self.optimized_line_profiler_results[candidate.optimization_id] = self.optimized_line_profiler_results[ + past_opt_id + ] + + self.optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ + "shorter_source_code" + ].markdown + self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown + + # Update to shorter code if this candidate has a shorter diff + new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) + if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]: + self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code + self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + + def register_new_candidate( + self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext + ) -> None: + """Register a new candidate that hasn't been seen before.""" + self.ast_code_to_id[normalized_code] = { + "optimization_id": candidate.optimization_id, + "shorter_source_code": candidate.source_code, + "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), + } + + @dataclass(frozen=True) class TestsInFile: test_file: Path diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index a95d0920e..4e659a73b 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -69,6 +69,7 @@ from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( BestOptimization, + CandidateEvaluationContext, CodeOptimizationContext, GeneratedTests, GeneratedTestsList, @@ -456,290 +457,147 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) - def determine_best_candidate( + def get_trace_id(self, exp_type: str) -> str: + """Get the trace ID for the current experiment type.""" + if self.experiment_id: + return self.function_trace_id[:-4] + exp_type + return self.function_trace_id + + def build_runtime_info_tree( self, + candidate_index: int, + candidate_result: OptimizedCandidateResult, + original_code_baseline: OriginalCodeBaseline, + perf_gain: float, *, - candidates: list[OptimizedCandidate], + is_successful_candidate: bool, + ) -> Tree: + """Build a Tree display for runtime information of a candidate.""" + tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛") + + is_async = original_code_baseline.async_throughput is not None and candidate_result.async_throughput is not None + + if is_successful_candidate: + if is_async: + throughput_gain_value = throughput_gain( + original_throughput=original_code_baseline.async_throughput, + optimized_throughput=candidate_result.async_throughput, + ) + tree.add("This candidate has better async throughput than the original code. 🚀") + tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions") + tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions") + tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%") + tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X") + else: + tree.add("This candidate is faster than the original code. 🚀") + tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") + tree.add( + f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " + f"(measured over {candidate_result.max_loop_count} " + f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" + ) + tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") + tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + # Not a successful optimization candidate + elif is_async: + throughput_gain_value = throughput_gain( + original_throughput=original_code_baseline.async_throughput, + optimized_throughput=candidate_result.async_throughput, + ) + tree.add(f"Async throughput: {candidate_result.async_throughput} executions") + tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%") + tree.add( + f"(Runtime for reference: {humanize_runtime(candidate_result.best_test_runtime)} over " + f"{candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})" + ) + else: + tree.add( + f"Summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " + f"(measured over {candidate_result.max_loop_count} " + f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" + ) + tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") + tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + + return tree + + def handle_successful_candidate( + self, + candidate: OptimizedCandidate, + candidate_result: OptimizedCandidateResult, code_context: CodeOptimizationContext, original_code_baseline: OriginalCodeBaseline, original_helper_code: dict[Path, str], - file_path_to_helper_classes: dict[Path, set[str]], - exp_type: str, - function_references: str, - ) -> BestOptimization | None: - best_optimization: BestOptimization | None = None - _best_runtime_until_now = original_code_baseline.runtime - - speedup_ratios: dict[str, float | None] = {} - optimized_runtimes: dict[str, float | None] = {} - is_correct = {} - optimized_line_profiler_results: dict[str, str] = {} - - logger.info( - f"Determining best optimization candidate (out of {len(candidates)}) for " - f"{self.function_to_optimize.qualified_name}…" + candidate_index: int, + eval_ctx: CandidateEvaluationContext, + ) -> tuple[BestOptimization, Tree | None]: + """Handle a successful optimization candidate. + + Returns the BestOptimization and optional benchmark tree. + """ + line_profile_test_results = self.line_profiler_step( + code_context=code_context, original_helper_code=original_helper_code, candidate_index=candidate_index ) - console.rule() + eval_ctx.record_line_profiler_result(candidate.optimization_id, line_profile_test_results["str_out"]) - future_all_refinements: list[concurrent.futures.Future] = [] - ast_code_to_id = {} - valid_optimizations = [] - optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated + replay_perf_gain = {} + benchmark_tree = None - # Start a new thread for AI service request - ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client - future_line_profile_results = self.executor.submit( - ai_service_client.optimize_python_code_line_profiler, - source_code=code_context.read_writable_code.markdown, - dependency_code=code_context.read_only_context_code, - trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, - line_profiler_results=original_code_baseline.line_profile_results["str_out"], - num_candidates=N_CANDIDATES_LP_EFFECTIVE, - experiment_metadata=ExperimentMetadata( - id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment" + if self.args.benchmark: + test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks( + self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root ) - if self.experiment_id - else None, - ) - - # Initialize candidate processor - processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements) - candidate_index = 0 - - # Process candidates using queue-based approach - while not processor.is_done(): - candidate = processor.get_next_candidate() - if candidate is None: - logger.debug("everything done, exiting") - break - - try: - candidate_index += 1 - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) - logger.info(f"h3|Optimization candidate {candidate_index}/{processor.candidate_len}:") - code_print( - candidate.source_code.flat, - file_name=f"candidate_{candidate_index}.py", - lsp_message_id=LSPMessageId.CANDIDATE.value, + if len(test_results_by_benchmark) > 0: + benchmark_tree = Tree("Speedup percentage on benchmarks:") + for benchmark_key, candidate_test_results in test_results_by_benchmark.items(): + original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[ + benchmark_key + ].total_passed_runtime() + candidate_replay_runtime = candidate_test_results.total_passed_runtime() + replay_perf_gain[benchmark_key] = performance_gain( + original_runtime_ns=original_code_replay_runtime, optimized_runtime_ns=candidate_replay_runtime ) - # map ast normalized code to diff len, unnormalized code - # map opt id to the shortest unnormalized code - try: - did_update = self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, - optimized_code=candidate.source_code, - original_helper_code=original_helper_code, - ) - if not did_update: - logger.warning( - "force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate." - ) - console.rule() - continue - except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: - logger.error(e) - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) - continue - # check if this code has been evaluated before by checking the ast normalized code string - normalized_code = normalize_code(candidate.source_code.flat.strip()) - if normalized_code in ast_code_to_id: - logger.info( - "Current candidate has been encountered before in testing, Skipping optimization candidate." - ) - past_opt_id = ast_code_to_id[normalized_code]["optimization_id"] - # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes - speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] - is_correct[candidate.optimization_id] = is_correct[past_opt_id] - optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] - # line profiler results only available for successful runs - if past_opt_id in optimized_line_profiler_results: - optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[ - past_opt_id - ] - optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][ - "shorter_source_code" - ].markdown - optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown - new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) - if ( - new_diff_len < ast_code_to_id[normalized_code]["diff_len"] - ): # new candidate has a shorter diff than the previously encountered one - ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code - ast_code_to_id[normalized_code]["diff_len"] = new_diff_len - continue - ast_code_to_id[normalized_code] = { - "optimization_id": candidate.optimization_id, - "shorter_source_code": candidate.source_code, - "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), - } - run_results = self.run_optimized_candidate( - optimization_candidate_index=candidate_index, - baseline_results=original_code_baseline, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, - ) - console.rule() - if not is_successful(run_results): - optimized_runtimes[candidate.optimization_id] = None - is_correct[candidate.optimization_id] = False - speedup_ratios[candidate.optimization_id] = None - else: - candidate_result: OptimizedCandidateResult = run_results.unwrap() - best_test_runtime = candidate_result.best_test_runtime - optimized_runtimes[candidate.optimization_id] = best_test_runtime - is_correct[candidate.optimization_id] = True - perf_gain = performance_gain( - original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime - ) - speedup_ratios[candidate.optimization_id] = perf_gain - - tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛") - benchmark_tree = None - if speedup_critic( - candidate_result, - original_code_baseline.runtime, - best_runtime_until_now=None, - original_async_throughput=original_code_baseline.async_throughput, - best_throughput_until_now=None, - ) and quantity_of_tests_critic(candidate_result): - # For async functions, prioritize throughput metrics over runtime - is_async = ( - original_code_baseline.async_throughput is not None - and candidate_result.async_throughput is not None - ) + benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%") - if is_async: - throughput_gain_value = throughput_gain( - original_throughput=original_code_baseline.async_throughput, - optimized_throughput=candidate_result.async_throughput, - ) - tree.add("This candidate has better async throughput than the original code. 🚀") - tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions") - tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions") - tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%") - tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X") - else: - tree.add("This candidate is faster than the original code. 🚀") - tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") - tree.add( - f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " - f"(measured over {candidate_result.max_loop_count} " - f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" - ) - tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") - tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") - line_profile_test_results = self.line_profiler_step( - code_context=code_context, - original_helper_code=original_helper_code, - candidate_index=candidate_index, - ) - optimized_line_profiler_results[candidate.optimization_id] = line_profile_test_results[ - "str_out" - ] - replay_perf_gain = {} - if self.args.benchmark: - test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks( - self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root - ) - if len(test_results_by_benchmark) > 0: - benchmark_tree = Tree("Speedup percentage on benchmarks:") - for benchmark_key, candidate_test_results in test_results_by_benchmark.items(): - original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[ - benchmark_key - ].total_passed_runtime() - candidate_replay_runtime = candidate_test_results.total_passed_runtime() - replay_perf_gain[benchmark_key] = performance_gain( - original_runtime_ns=original_code_replay_runtime, - optimized_runtime_ns=candidate_replay_runtime, - ) - benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%") - best_optimization = BestOptimization( - candidate=candidate, - helper_functions=code_context.helper_functions, - code_context=code_context, - runtime=best_test_runtime, - line_profiler_test_results=line_profile_test_results, - winning_behavior_test_results=candidate_result.behavior_test_results, - replay_performance_gain=replay_perf_gain if self.args.benchmark else None, - winning_benchmarking_test_results=candidate_result.benchmarking_test_results, - winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results, - async_throughput=candidate_result.async_throughput, - ) - valid_optimizations.append(best_optimization) - # queue corresponding refined optimization for best optimization - if not candidate.optimization_id.endswith("refi"): - future_all_refinements.append( - self.refine_optimizations( - valid_optimizations=[best_optimization], - original_code_baseline=original_code_baseline, - code_context=code_context, - trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - ai_service_client=ai_service_client, - executor=self.executor, - function_references=function_references, - ) - ) - else: - # For async functions, prioritize throughput metrics over runtime even for slow candidates - is_async = ( - original_code_baseline.async_throughput is not None - and candidate_result.async_throughput is not None - ) + best_optimization = BestOptimization( + candidate=candidate, + helper_functions=code_context.helper_functions, + code_context=code_context, + runtime=candidate_result.best_test_runtime, + line_profiler_test_results=line_profile_test_results, + winning_behavior_test_results=candidate_result.behavior_test_results, + replay_performance_gain=replay_perf_gain if self.args.benchmark else None, + winning_benchmarking_test_results=candidate_result.benchmarking_test_results, + winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results, + async_throughput=candidate_result.async_throughput, + ) - if is_async: - throughput_gain_value = throughput_gain( - original_throughput=original_code_baseline.async_throughput, - optimized_throughput=candidate_result.async_throughput, - ) - tree.add(f"Async throughput: {candidate_result.async_throughput} executions") - tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%") - tree.add( - f"(Runtime for reference: {humanize_runtime(best_test_runtime)} over " - f"{candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})" - ) - else: - tree.add( - f"Summed runtime: {humanize_runtime(best_test_runtime)} " - f"(measured over {candidate_result.max_loop_count} " - f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" - ) - tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") - tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + return best_optimization, benchmark_tree - if is_LSP_enabled(): - lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree))) - else: - console.print(tree) - if self.args.benchmark and benchmark_tree: - console.print(benchmark_tree) - console.rule() - except KeyboardInterrupt as e: - logger.exception(f"Optimization interrupted: {e}") - raise - finally: - # reset for the next candidate - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) - if not valid_optimizations: + def select_best_optimization( + self, + eval_ctx: CandidateEvaluationContext, + code_context: CodeOptimizationContext, + original_code_baseline: OriginalCodeBaseline, + ai_service_client: AiServiceClient, + exp_type: str, + function_references: str, + ) -> BestOptimization | None: + """Select the best optimization from valid candidates.""" + if not eval_ctx.valid_optimizations: return None - # need to figure out the best candidate here before we return best_optimization - # reassign the shorter code here + valid_candidates_with_shorter_code = [] diff_lens_list = [] # character level diff speedups_list = [] optimization_ids = [] diff_strs = [] runtimes_list = [] - for valid_opt in valid_optimizations: + + for valid_opt in eval_ctx.valid_optimizations: valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip()) new_candidate_with_shorter_code = OptimizedCandidate( - source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], + source_code=eval_ctx.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], optimization_id=valid_opt.candidate.optimization_id, explanation=valid_opt.candidate.explanation, ) @@ -758,7 +616,7 @@ def determine_best_candidate( valid_candidates_with_shorter_code.append(new_best_opt) diff_lens_list.append( diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat) - ) # char level diff + ) diff_strs.append( unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat) ) @@ -770,13 +628,14 @@ def determine_best_candidate( ) optimization_ids.append(new_best_opt.candidate.optimization_id) runtimes_list.append(new_best_opt.runtime) + if len(optimization_ids) > 1: future_ranking = self.executor.submit( ai_service_client.generate_ranking, diffs=diff_strs, optimization_ids=optimization_ids, speedups=speedups_list, - trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + trace_id=self.get_trace_id(exp_type), function_references=function_references, ) concurrent.futures.wait([future_ranking]) @@ -786,25 +645,262 @@ def determine_best_candidate( else: diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list) runtimes_ranking = create_rank_dictionary_compact(runtimes_list) - # TODO: better way to resolve conflicts with same min ranking overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking} min_key = min(overall_ranking, key=overall_ranking.get) elif len(optimization_ids) == 1: - min_key = 0 # only one candidate in valid _opts, already returns if there are no valid candidates - else: # 0? shouldn't happen, but it's there to escape potential bugs + min_key = 0 + else: return None - best_optimization = valid_candidates_with_shorter_code[min_key] - # reassign code string which is the shortest + + return valid_candidates_with_shorter_code[min_key] + + def log_evaluation_results( + self, + eval_ctx: CandidateEvaluationContext, + best_optimization: BestOptimization, + original_code_baseline: OriginalCodeBaseline, + ai_service_client: AiServiceClient, + exp_type: str, + ) -> None: + """Log evaluation results to the AI service.""" ai_service_client.log_results( - function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, - speedup_ratio=speedup_ratios, + function_trace_id=self.get_trace_id(exp_type), + speedup_ratio=eval_ctx.speedup_ratios, original_runtime=original_code_baseline.runtime, - optimized_runtime=optimized_runtimes, - is_correct=is_correct, - optimized_line_profiler_results=optimized_line_profiler_results, - optimizations_post=optimizations_post, + optimized_runtime=eval_ctx.optimized_runtimes, + is_correct=eval_ctx.is_correct, + optimized_line_profiler_results=eval_ctx.optimized_line_profiler_results, + optimizations_post=eval_ctx.optimizations_post, metadata={"best_optimization_id": best_optimization.candidate.optimization_id}, ) + + def process_single_candidate( + self, + candidate: OptimizedCandidate, + candidate_index: int, + total_candidates: int, + code_context: CodeOptimizationContext, + original_code_baseline: OriginalCodeBaseline, + original_helper_code: dict[Path, str], + file_path_to_helper_classes: dict[Path, set[str]], + eval_ctx: CandidateEvaluationContext, + future_all_refinements: list[concurrent.futures.Future], + ai_service_client: AiServiceClient, + exp_type: str, + function_references: str, + ) -> BestOptimization | None: + """Process a single optimization candidate. + + Returns the BestOptimization if the candidate is successful, None otherwise. + Updates eval_ctx with results and may append to future_all_refinements. + """ + # Cleanup temp files + get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) + get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) + + logger.info(f"h3|Optimization candidate {candidate_index}/{total_candidates}:") + code_print( + candidate.source_code.flat, + file_name=f"candidate_{candidate_index}.py", + lsp_message_id=LSPMessageId.CANDIDATE.value, + ) + + # Try to replace function with optimized code + try: + did_update = self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=candidate.source_code, + original_helper_code=original_helper_code, + ) + if not did_update: + logger.warning( + "force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate." + ) + console.rule() + return None + except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: + logger.error(e) + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + return None + + # Check for duplicate candidates + normalized_code = normalize_code(candidate.source_code.flat.strip()) + if normalized_code in eval_ctx.ast_code_to_id: + logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.") + eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context) + return None + + eval_ctx.register_new_candidate(normalized_code, candidate, code_context) + + # Run the optimized candidate + run_results = self.run_optimized_candidate( + optimization_candidate_index=candidate_index, + baseline_results=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + ) + console.rule() + + if not is_successful(run_results): + eval_ctx.record_failed_candidate(candidate.optimization_id) + return None + + candidate_result: OptimizedCandidateResult = run_results.unwrap() + perf_gain = performance_gain( + original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=candidate_result.best_test_runtime + ) + eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain) + + # Check if this is a successful optimization + is_successful_opt = speedup_critic( + candidate_result, + original_code_baseline.runtime, + best_runtime_until_now=None, + original_async_throughput=original_code_baseline.async_throughput, + best_throughput_until_now=None, + ) and quantity_of_tests_critic(candidate_result) + + tree = self.build_runtime_info_tree( + candidate_index=candidate_index, + candidate_result=candidate_result, + original_code_baseline=original_code_baseline, + perf_gain=perf_gain, + is_successful_candidate=is_successful_opt, + ) + + best_optimization = None + benchmark_tree = None + + if is_successful_opt: + best_optimization, benchmark_tree = self.handle_successful_candidate( + candidate=candidate, + candidate_result=candidate_result, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + candidate_index=candidate_index, + eval_ctx=eval_ctx, + ) + eval_ctx.valid_optimizations.append(best_optimization) + + # Queue refinement for non-refined candidates + if not candidate.optimization_id.endswith("refi"): + future_all_refinements.append( + self.refine_optimizations( + valid_optimizations=[best_optimization], + original_code_baseline=original_code_baseline, + code_context=code_context, + trace_id=self.get_trace_id(exp_type), + ai_service_client=ai_service_client, + executor=self.executor, + function_references=function_references, + ) + ) + + # Display runtime information + if is_LSP_enabled(): + lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree))) + else: + console.print(tree) + if self.args.benchmark and benchmark_tree: + console.print(benchmark_tree) + console.rule() + + return best_optimization + + def determine_best_candidate( + self, + *, + candidates: list[OptimizedCandidate], + code_context: CodeOptimizationContext, + original_code_baseline: OriginalCodeBaseline, + original_helper_code: dict[Path, str], + file_path_to_helper_classes: dict[Path, set[str]], + exp_type: str, + function_references: str, + ) -> BestOptimization | None: + """Determine the best optimization candidate from a list of candidates.""" + logger.info( + f"Determining best optimization candidate (out of {len(candidates)}) for " + f"{self.function_to_optimize.qualified_name}…" + ) + console.rule() + + # Initialize evaluation context and async tasks + eval_ctx = CandidateEvaluationContext() + future_all_refinements: list[concurrent.futures.Future] = [] + ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client + assert ai_service_client is not None, "AI service client must be set for optimization" + + future_line_profile_results = self.executor.submit( + ai_service_client.optimize_python_code_line_profiler, + source_code=code_context.read_writable_code.markdown, + dependency_code=code_context.read_only_context_code, + trace_id=self.get_trace_id(exp_type), + line_profiler_results=original_code_baseline.line_profile_results["str_out"], + num_candidates=N_CANDIDATES_LP_EFFECTIVE, + experiment_metadata=ExperimentMetadata( + id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment" + ) + if self.experiment_id + else None, + ) + + processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements) + candidate_index = 0 + + # Process candidates using queue-based approach + while not processor.is_done(): + candidate = processor.get_next_candidate() + if candidate is None: + logger.debug("everything done, exiting") + break + + try: + candidate_index += 1 + self.process_single_candidate( + candidate=candidate, + candidate_index=candidate_index, + total_candidates=processor.candidate_len, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + eval_ctx=eval_ctx, + future_all_refinements=future_all_refinements, + ai_service_client=ai_service_client, + exp_type=exp_type, + function_references=function_references, + ) + except KeyboardInterrupt as e: + logger.exception(f"Optimization interrupted: {e}") + raise + finally: + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + + # Select and return the best optimization + best_optimization = self.select_best_optimization( + eval_ctx=eval_ctx, + code_context=code_context, + original_code_baseline=original_code_baseline, + ai_service_client=ai_service_client, + exp_type=exp_type, + function_references=function_references, + ) + + if best_optimization: + self.log_evaluation_results( + eval_ctx=eval_ctx, + best_optimization=best_optimization, + original_code_baseline=original_code_baseline, + ai_service_client=ai_service_client, + exp_type=exp_type, + ) + return best_optimization def refine_optimizations(