From 80be1413a1cd2dc5a68a0e72c9ae2c10b169d8ab Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 9 Dec 2025 01:49:35 -0600 Subject: [PATCH 1/5] first pass at refactor --- codeflash/models/models.py | 81 ++- codeflash/optimization/function_optimizer.py | 496 +++++++++++-------- 2 files changed, 355 insertions(+), 222 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 744f76087..8b7a8acd5 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -22,11 +22,11 @@ from typing import Annotated, Optional, cast from jedi.api.classes import Name -from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError +from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError, field from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger -from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code +from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code from codeflash.code_utils.env_utils import is_end_to_end from codeflash.verification.comparator import comparator @@ -346,6 +346,83 @@ class OptimizationSet(BaseModel): experiment: Optional[list[OptimizedCandidate]] +@dataclass +class CandidateEvaluationContext: + """Holds tracking state during candidate evaluation in determine_best_candidate.""" + + speedup_ratios: dict[str, float | None] = field(default_factory=dict) + optimized_runtimes: dict[str, float | None] = field(default_factory=dict) + is_correct: dict[str, bool] = field(default_factory=dict) + optimized_line_profiler_results: dict[str, str] = field(default_factory=dict) + ast_code_to_id: dict = field(default_factory=dict) + optimizations_post: dict[str, str] = field(default_factory=dict) + valid_optimizations: list = field(default_factory=list) + + def record_failed_candidate(self, optimization_id: str) -> None: + """Record results for a failed candidate.""" + self.optimized_runtimes[optimization_id] = None + self.is_correct[optimization_id] = False + self.speedup_ratios[optimization_id] = None + + def record_successful_candidate( + self, optimization_id: str, runtime: float, speedup: float + ) -> None: + """Record results for a successful candidate.""" + self.optimized_runtimes[optimization_id] = runtime + self.is_correct[optimization_id] = True + self.speedup_ratios[optimization_id] = speedup + + def record_line_profiler_result(self, optimization_id: str, result: str) -> None: + """Record line profiler results for a candidate.""" + self.optimized_line_profiler_results[optimization_id] = result + + def handle_duplicate_candidate( + self, + candidate: OptimizedCandidate, + normalized_code: str, + code_context: CodeOptimizationContext, + ) -> None: + """Handle a candidate that has been seen before.""" + past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] + + # Copy results from the previous evaluation + self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios[past_opt_id] + self.is_correct[candidate.optimization_id] = self.is_correct[past_opt_id] + self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes[past_opt_id] + + # Line profiler results only available for successful runs + if past_opt_id in self.optimized_line_profiler_results: + self.optimized_line_profiler_results[candidate.optimization_id] = ( + self.optimized_line_profiler_results[past_opt_id] + ) + + 
self.optimizations_post[candidate.optimization_id] = ( + self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown + ) + self.optimizations_post[past_opt_id] = ( + self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown + ) + + # Update to shorter code if this candidate has a shorter diff + new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) + if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]: + self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code + self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + + def register_new_candidate( + self, + normalized_code: str, + candidate: OptimizedCandidate, + code_context: CodeOptimizationContext, + ) -> None: + """Register a new candidate that hasn't been seen before.""" + self.ast_code_to_id[normalized_code] = { + "optimization_id": candidate.optimization_id, + "shorter_source_code": candidate.source_code, + "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), + } + + @dataclass(frozen=True) class TestsInFile: test_file: Path diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 058c84dfc..4eef8f674 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -8,6 +8,7 @@ import subprocess import uuid from collections import defaultdict + from pathlib import Path from typing import TYPE_CHECKING @@ -69,6 +70,7 @@ from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( BestOptimization, + CandidateEvaluationContext, CodeOptimizationContext, GeneratedTests, GeneratedTestsList, @@ -456,6 +458,217 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) + def build_runtime_info_tree( + self, + candidate_index: int, + candidate_result: OptimizedCandidateResult, + original_code_baseline: OriginalCodeBaseline, + perf_gain: float, + is_successful_candidate: bool, + ) -> Tree: + """Build a Tree display for runtime information of a candidate.""" + tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛") + + is_async = original_code_baseline.async_throughput is not None and candidate_result.async_throughput is not None + + if is_successful_candidate: + if is_async: + throughput_gain_value = throughput_gain( + original_throughput=original_code_baseline.async_throughput, + optimized_throughput=candidate_result.async_throughput, + ) + tree.add("This candidate has better async throughput than the original code. 🚀") + tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions") + tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions") + tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%") + tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X") + else: + tree.add("This candidate is faster than the original code. 
🚀") + tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") + tree.add( + f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " + f"(measured over {candidate_result.max_loop_count} " + f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" + ) + tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") + tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + # Not a successful optimization candidate + elif is_async: + throughput_gain_value = throughput_gain( + original_throughput=original_code_baseline.async_throughput, + optimized_throughput=candidate_result.async_throughput, + ) + tree.add(f"Async throughput: {candidate_result.async_throughput} executions") + tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%") + tree.add( + f"(Runtime for reference: {humanize_runtime(candidate_result.best_test_runtime)} over " + f"{candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})" + ) + else: + tree.add( + f"Summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " + f"(measured over {candidate_result.max_loop_count} " + f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" + ) + tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") + tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + + return tree + + def handle_successful_candidate( + self, + candidate: OptimizedCandidate, + candidate_result: OptimizedCandidateResult, + code_context: CodeOptimizationContext, + original_code_baseline: OriginalCodeBaseline, + original_helper_code: dict[Path, str], + candidate_index: int, + eval_ctx: CandidateEvaluationContext, + ) -> tuple[BestOptimization, Tree | None]: + """ + Handle a successful optimization candidate. + + Returns the BestOptimization and optional benchmark tree. 
+ """ + line_profile_test_results = self.line_profiler_step( + code_context=code_context, original_helper_code=original_helper_code, candidate_index=candidate_index + ) + eval_ctx.record_line_profiler_result(candidate.optimization_id, line_profile_test_results["str_out"]) + + replay_perf_gain = {} + benchmark_tree = None + + if self.args.benchmark: + test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks( + self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root + ) + if len(test_results_by_benchmark) > 0: + benchmark_tree = Tree("Speedup percentage on benchmarks:") + for benchmark_key, candidate_test_results in test_results_by_benchmark.items(): + original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[ + benchmark_key + ].total_passed_runtime() + candidate_replay_runtime = candidate_test_results.total_passed_runtime() + replay_perf_gain[benchmark_key] = performance_gain( + original_runtime_ns=original_code_replay_runtime, optimized_runtime_ns=candidate_replay_runtime + ) + benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%") + + best_optimization = BestOptimization( + candidate=candidate, + helper_functions=code_context.helper_functions, + code_context=code_context, + runtime=candidate_result.best_test_runtime, + line_profiler_test_results=line_profile_test_results, + winning_behavior_test_results=candidate_result.behavior_test_results, + replay_performance_gain=replay_perf_gain if self.args.benchmark else None, + winning_benchmarking_test_results=candidate_result.benchmarking_test_results, + winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results, + async_throughput=candidate_result.async_throughput, + ) + + return best_optimization, benchmark_tree + + def select_best_optimization( + self, + eval_ctx: CandidateEvaluationContext, + code_context: CodeOptimizationContext, + original_code_baseline: OriginalCodeBaseline, + ai_service_client: AiServiceClient, + exp_type: str, + function_references: str, + ) -> BestOptimization | None: + """Select the best optimization from valid candidates.""" + if not eval_ctx.valid_optimizations: + return None + + valid_candidates_with_shorter_code = [] + diff_lens_list = [] # character level diff + speedups_list = [] + optimization_ids = [] + diff_strs = [] + runtimes_list = [] + + for valid_opt in eval_ctx.valid_optimizations: + valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip()) + new_candidate_with_shorter_code = OptimizedCandidate( + source_code=eval_ctx.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], + optimization_id=valid_opt.candidate.optimization_id, + explanation=valid_opt.candidate.explanation, + ) + new_best_opt = BestOptimization( + candidate=new_candidate_with_shorter_code, + helper_functions=valid_opt.helper_functions, + code_context=valid_opt.code_context, + runtime=valid_opt.runtime, + line_profiler_test_results=valid_opt.line_profiler_test_results, + winning_behavior_test_results=valid_opt.winning_behavior_test_results, + replay_performance_gain=valid_opt.replay_performance_gain, + winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results, + winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results, + async_throughput=valid_opt.async_throughput, + ) + valid_candidates_with_shorter_code.append(new_best_opt) + diff_lens_list.append( + diff_length(new_best_opt.candidate.source_code.flat, 
code_context.read_writable_code.flat) + ) + diff_strs.append( + unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat) + ) + speedups_list.append( + 1 + + performance_gain( + original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=new_best_opt.runtime + ) + ) + optimization_ids.append(new_best_opt.candidate.optimization_id) + runtimes_list.append(new_best_opt.runtime) + + if len(optimization_ids) > 1: + future_ranking = self.executor.submit( + ai_service_client.generate_ranking, + diffs=diff_strs, + optimization_ids=optimization_ids, + speedups=speedups_list, + trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + function_references=function_references, + ) + concurrent.futures.wait([future_ranking]) + ranking = future_ranking.result() + if ranking: + min_key = ranking[0] + else: + diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list) + runtimes_ranking = create_rank_dictionary_compact(runtimes_list) + overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking} + min_key = min(overall_ranking, key=overall_ranking.get) + elif len(optimization_ids) == 1: + min_key = 0 + else: + return None + + return valid_candidates_with_shorter_code[min_key] + + def log_evaluation_results( + self, + eval_ctx: CandidateEvaluationContext, + best_optimization: BestOptimization, + original_code_baseline: OriginalCodeBaseline, + ai_service_client: AiServiceClient, + exp_type: str, + ) -> None: + """Log evaluation results to the AI service.""" + ai_service_client.log_results( + function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + speedup_ratio=eval_ctx.speedup_ratios, + original_runtime=original_code_baseline.runtime, + optimized_runtime=eval_ctx.optimized_runtimes, + is_correct=eval_ctx.is_correct, + optimized_line_profiler_results=eval_ctx.optimized_line_profiler_results, + optimizations_post=eval_ctx.optimizations_post, + metadata={"best_optimization_id": best_optimization.candidate.optimization_id}, + ) + def determine_best_candidate( self, *, @@ -467,27 +680,18 @@ def determine_best_candidate( exp_type: str, function_references: str, ) -> BestOptimization | None: - best_optimization: BestOptimization | None = None - _best_runtime_until_now = original_code_baseline.runtime - - speedup_ratios: dict[str, float | None] = {} - optimized_runtimes: dict[str, float | None] = {} - is_correct = {} - optimized_line_profiler_results: dict[str, str] = {} - + """Determine the best optimization candidate from a list of candidates.""" logger.info( f"Determining best optimization candidate (out of {len(candidates)}) for " f"{self.function_to_optimize.qualified_name}…" ) console.rule() + # Initialize evaluation context and async tasks + eval_ctx = CandidateEvaluationContext() future_all_refinements: list[concurrent.futures.Future] = [] - ast_code_to_id = {} - valid_optimizations = [] - optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated - - # Start a new thread for AI service request ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client + future_line_profile_results = self.executor.submit( ai_service_client.optimize_python_code_line_profiler, source_code=code_context.read_writable_code.markdown, @@ -502,7 +706,6 @@ def determine_best_candidate( else 
None, ) - # Initialize candidate processor processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements) candidate_index = 0 @@ -523,8 +726,8 @@ def determine_best_candidate( file_name=f"candidate_{candidate_index}.py", lsp_message_id=LSPMessageId.CANDIDATE.value, ) - # map ast normalized code to diff len, unnormalized code - # map opt id to the shortest unnormalized code + + # Try to replace function with optimized code try: did_update = self.replace_function_and_helpers_with_optimized_code( code_context=code_context, @@ -543,38 +746,19 @@ def determine_best_candidate( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) continue - # check if this code has been evaluated before by checking the ast normalized code string + + # Check for duplicate candidates normalized_code = normalize_code(candidate.source_code.flat.strip()) - if normalized_code in ast_code_to_id: + if normalized_code in eval_ctx.ast_code_to_id: logger.info( "Current candidate has been encountered before in testing, Skipping optimization candidate." ) - past_opt_id = ast_code_to_id[normalized_code]["optimization_id"] - # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes - speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] - is_correct[candidate.optimization_id] = is_correct[past_opt_id] - optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] - # line profiler results only available for successful runs - if past_opt_id in optimized_line_profiler_results: - optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[ - past_opt_id - ] - optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][ - "shorter_source_code" - ].markdown - optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown - new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) - if ( - new_diff_len < ast_code_to_id[normalized_code]["diff_len"] - ): # new candidate has a shorter diff than the previously encountered one - ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code - ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context) continue - ast_code_to_id[normalized_code] = { - "optimization_id": candidate.optimization_id, - "shorter_source_code": candidate.source_code, - "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), - } + + eval_ctx.register_new_candidate(normalized_code, candidate, code_context) + + # Run the optimized candidate run_results = self.run_optimized_candidate( optimization_candidate_index=candidate_index, baseline_results=original_code_baseline, @@ -582,94 +766,50 @@ def determine_best_candidate( file_path_to_helper_classes=file_path_to_helper_classes, ) console.rule() + if not is_successful(run_results): - optimized_runtimes[candidate.optimization_id] = None - is_correct[candidate.optimization_id] = False - speedup_ratios[candidate.optimization_id] = None + eval_ctx.record_failed_candidate(candidate.optimization_id) else: candidate_result: OptimizedCandidateResult = run_results.unwrap() - best_test_runtime = candidate_result.best_test_runtime - optimized_runtimes[candidate.optimization_id] = best_test_runtime - is_correct[candidate.optimization_id] = True perf_gain = 
performance_gain( - original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime + original_runtime_ns=original_code_baseline.runtime, + optimized_runtime_ns=candidate_result.best_test_runtime, + ) + eval_ctx.record_successful_candidate( + candidate.optimization_id, candidate_result.best_test_runtime, perf_gain ) - speedup_ratios[candidate.optimization_id] = perf_gain - tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛") - benchmark_tree = None - if speedup_critic( + # Check if this is a successful optimization + is_successful_opt = speedup_critic( candidate_result, original_code_baseline.runtime, best_runtime_until_now=None, original_async_throughput=original_code_baseline.async_throughput, best_throughput_until_now=None, - ) and quantity_of_tests_critic(candidate_result): - # For async functions, prioritize throughput metrics over runtime - is_async = ( - original_code_baseline.async_throughput is not None - and candidate_result.async_throughput is not None - ) + ) and quantity_of_tests_critic(candidate_result) + + tree = self.build_runtime_info_tree( + candidate_index=candidate_index, + candidate_result=candidate_result, + original_code_baseline=original_code_baseline, + perf_gain=perf_gain, + is_successful_candidate=is_successful_opt, + ) - if is_async: - throughput_gain_value = throughput_gain( - original_throughput=original_code_baseline.async_throughput, - optimized_throughput=candidate_result.async_throughput, - ) - tree.add("This candidate has better async throughput than the original code. 🚀") - tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions") - tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions") - tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%") - tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X") - else: - tree.add("This candidate is faster than the original code. 
🚀") - tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") - tree.add( - f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " - f"(measured over {candidate_result.max_loop_count} " - f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" - ) - tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") - tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") - line_profile_test_results = self.line_profiler_step( + benchmark_tree = None + if is_successful_opt: + best_optimization, benchmark_tree = self.handle_successful_candidate( + candidate=candidate, + candidate_result=candidate_result, code_context=code_context, + original_code_baseline=original_code_baseline, original_helper_code=original_helper_code, candidate_index=candidate_index, + eval_ctx=eval_ctx, ) - optimized_line_profiler_results[candidate.optimization_id] = line_profile_test_results[ - "str_out" - ] - replay_perf_gain = {} - if self.args.benchmark: - test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks( - self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root - ) - if len(test_results_by_benchmark) > 0: - benchmark_tree = Tree("Speedup percentage on benchmarks:") - for benchmark_key, candidate_test_results in test_results_by_benchmark.items(): - original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[ - benchmark_key - ].total_passed_runtime() - candidate_replay_runtime = candidate_test_results.total_passed_runtime() - replay_perf_gain[benchmark_key] = performance_gain( - original_runtime_ns=original_code_replay_runtime, - optimized_runtime_ns=candidate_replay_runtime, - ) - benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%") - best_optimization = BestOptimization( - candidate=candidate, - helper_functions=code_context.helper_functions, - code_context=code_context, - runtime=best_test_runtime, - line_profiler_test_results=line_profile_test_results, - winning_behavior_test_results=candidate_result.behavior_test_results, - replay_performance_gain=replay_perf_gain if self.args.benchmark else None, - winning_benchmarking_test_results=candidate_result.benchmarking_test_results, - winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results, - async_throughput=candidate_result.async_throughput, - ) - valid_optimizations.append(best_optimization) - # queue corresponding refined optimization for best optimization + eval_ctx.valid_optimizations.append(best_optimization) + + # Queue refinement for non-refined candidates if not candidate.optimization_id.endswith("refi"): future_all_refinements.append( self.refine_optimizations( @@ -684,33 +824,8 @@ def determine_best_candidate( function_references=function_references, ) ) - else: - # For async functions, prioritize throughput metrics over runtime even for slow candidates - is_async = ( - original_code_baseline.async_throughput is not None - and candidate_result.async_throughput is not None - ) - - if is_async: - throughput_gain_value = throughput_gain( - original_throughput=original_code_baseline.async_throughput, - optimized_throughput=candidate_result.async_throughput, - ) - tree.add(f"Async throughput: {candidate_result.async_throughput} executions") - tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%") - tree.add( - f"(Runtime for reference: {humanize_runtime(best_test_runtime)} over " - f"{candidate_result.max_loop_count} loop{'s' if 
candidate_result.max_loop_count > 1 else ''})" - ) - else: - tree.add( - f"Summed runtime: {humanize_runtime(best_test_runtime)} " - f"(measured over {candidate_result.max_loop_count} " - f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" - ) - tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") - tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + # Display runtime information if is_LSP_enabled(): lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree))) else: @@ -718,93 +833,34 @@ def determine_best_candidate( if self.args.benchmark and benchmark_tree: console.print(benchmark_tree) console.rule() + except KeyboardInterrupt as e: logger.exception(f"Optimization interrupted: {e}") raise finally: - # reset for the next candidate self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) - if not valid_optimizations: - return None - # need to figure out the best candidate here before we return best_optimization - # reassign the shorter code here - valid_candidates_with_shorter_code = [] - diff_lens_list = [] # character level diff - speedups_list = [] - optimization_ids = [] - diff_strs = [] - runtimes_list = [] - for valid_opt in valid_optimizations: - valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip()) - new_candidate_with_shorter_code = OptimizedCandidate( - source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], - optimization_id=valid_opt.candidate.optimization_id, - explanation=valid_opt.candidate.explanation, - ) - new_best_opt = BestOptimization( - candidate=new_candidate_with_shorter_code, - helper_functions=valid_opt.helper_functions, - code_context=valid_opt.code_context, - runtime=valid_opt.runtime, - line_profiler_test_results=valid_opt.line_profiler_test_results, - winning_behavior_test_results=valid_opt.winning_behavior_test_results, - replay_performance_gain=valid_opt.replay_performance_gain, - winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results, - winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results, - async_throughput=valid_opt.async_throughput, - ) - valid_candidates_with_shorter_code.append(new_best_opt) - diff_lens_list.append( - diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat) - ) # char level diff - diff_strs.append( - unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat) - ) - speedups_list.append( - 1 - + performance_gain( - original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=new_best_opt.runtime - ) - ) - optimization_ids.append(new_best_opt.candidate.optimization_id) - runtimes_list.append(new_best_opt.runtime) - if len(optimization_ids) > 1: - future_ranking = self.executor.submit( - ai_service_client.generate_ranking, - diffs=diff_strs, - optimization_ids=optimization_ids, - speedups=speedups_list, - trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, - function_references=function_references, - ) - concurrent.futures.wait([future_ranking]) - ranking = future_ranking.result() - if ranking: - min_key = ranking[0] - else: - diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list) - runtimes_ranking = create_rank_dictionary_compact(runtimes_list) - # TODO: better way to resolve conflicts with same min ranking - overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in 
diff_lens_ranking} - min_key = min(overall_ranking, key=overall_ranking.get) - elif len(optimization_ids) == 1: - min_key = 0 # only one candidate in valid _opts, already returns if there are no valid candidates - else: # 0? shouldn't happen, but it's there to escape potential bugs - return None - best_optimization = valid_candidates_with_shorter_code[min_key] - # reassign code string which is the shortest - ai_service_client.log_results( - function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, - speedup_ratio=speedup_ratios, - original_runtime=original_code_baseline.runtime, - optimized_runtime=optimized_runtimes, - is_correct=is_correct, - optimized_line_profiler_results=optimized_line_profiler_results, - optimizations_post=optimizations_post, - metadata={"best_optimization_id": best_optimization.candidate.optimization_id}, + + # Select and return the best optimization + best_optimization = self.select_best_optimization( + eval_ctx=eval_ctx, + code_context=code_context, + original_code_baseline=original_code_baseline, + ai_service_client=ai_service_client, + exp_type=exp_type, + function_references=function_references, ) + + if best_optimization: + self.log_evaluation_results( + eval_ctx=eval_ctx, + best_optimization=best_optimization, + original_code_baseline=original_code_baseline, + ai_service_client=ai_service_client, + exp_type=exp_type, + ) + return best_optimization def refine_optimizations( From a497599893d4e1df617e717e4d5380e0e73b25f4 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 9 Dec 2025 02:01:52 -0600 Subject: [PATCH 2/5] Update models.py --- codeflash/models/models.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 8b7a8acd5..d88996ee2 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -22,7 +22,7 @@ from typing import Annotated, Optional, cast from jedi.api.classes import Name -from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError, field +from pydantic import AfterValidator, BaseModel, ConfigDict, Field, PrivateAttr, ValidationError from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger @@ -350,13 +350,13 @@ class OptimizationSet(BaseModel): class CandidateEvaluationContext: """Holds tracking state during candidate evaluation in determine_best_candidate.""" - speedup_ratios: dict[str, float | None] = field(default_factory=dict) - optimized_runtimes: dict[str, float | None] = field(default_factory=dict) - is_correct: dict[str, bool] = field(default_factory=dict) - optimized_line_profiler_results: dict[str, str] = field(default_factory=dict) - ast_code_to_id: dict = field(default_factory=dict) - optimizations_post: dict[str, str] = field(default_factory=dict) - valid_optimizations: list = field(default_factory=list) + speedup_ratios: dict[str, float | None] = Field(default_factory=dict) + optimized_runtimes: dict[str, float | None] = Field(default_factory=dict) + is_correct: dict[str, bool] = Field(default_factory=dict) + optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict) + ast_code_to_id: dict = Field(default_factory=dict) + optimizations_post: dict[str, str] = Field(default_factory=dict) + valid_optimizations: list = Field(default_factory=list) def record_failed_candidate(self, optimization_id: str) -> None: """Record results for a failed candidate.""" From 
76e7253563e95d513077d8263124aaed192ccfe2 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 9 Dec 2025 02:12:49 -0600 Subject: [PATCH 3/5] extract logic for processing a candidate --- codeflash/optimization/function_optimizer.py | 282 +++++++++++-------- 1 file changed, 163 insertions(+), 119 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 4eef8f674..59e03a7fc 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -8,7 +8,6 @@ import subprocess import uuid from collections import defaultdict - from pathlib import Path from typing import TYPE_CHECKING @@ -458,6 +457,12 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) + def get_trace_id(self, exp_type: str) -> str: + """Get the trace ID for the current experiment type.""" + if self.experiment_id: + return self.function_trace_id[:-4] + exp_type + return self.function_trace_id + def build_runtime_info_tree( self, candidate_index: int, @@ -525,8 +530,7 @@ def handle_successful_candidate( candidate_index: int, eval_ctx: CandidateEvaluationContext, ) -> tuple[BestOptimization, Tree | None]: - """ - Handle a successful optimization candidate. + """Handle a successful optimization candidate. Returns the BestOptimization and optional benchmark tree. """ @@ -630,7 +634,7 @@ def select_best_optimization( diffs=diff_strs, optimization_ids=optimization_ids, speedups=speedups_list, - trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + trace_id=self.get_trace_id(exp_type), function_references=function_references, ) concurrent.futures.wait([future_ranking]) @@ -659,7 +663,7 @@ def log_evaluation_results( ) -> None: """Log evaluation results to the AI service.""" ai_service_client.log_results( - function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + function_trace_id=self.get_trace_id(exp_type), speedup_ratio=eval_ctx.speedup_ratios, original_runtime=original_code_baseline.runtime, optimized_runtime=eval_ctx.optimized_runtimes, @@ -669,6 +673,147 @@ def log_evaluation_results( metadata={"best_optimization_id": best_optimization.candidate.optimization_id}, ) + def process_single_candidate( + self, + candidate: OptimizedCandidate, + candidate_index: int, + total_candidates: int, + code_context: CodeOptimizationContext, + original_code_baseline: OriginalCodeBaseline, + original_helper_code: dict[Path, str], + file_path_to_helper_classes: dict[Path, set[str]], + eval_ctx: CandidateEvaluationContext, + future_all_refinements: list[concurrent.futures.Future], + ai_service_client: AiServiceClient, + exp_type: str, + function_references: str, + ) -> BestOptimization | None: + """Process a single optimization candidate. + + Returns the BestOptimization if the candidate is successful, None otherwise. + Updates eval_ctx with results and may append to future_all_refinements. 
+ """ + # Cleanup temp files + get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) + get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) + + logger.info(f"h3|Optimization candidate {candidate_index}/{total_candidates}:") + code_print( + candidate.source_code.flat, + file_name=f"candidate_{candidate_index}.py", + lsp_message_id=LSPMessageId.CANDIDATE.value, + ) + + # Try to replace function with optimized code + try: + did_update = self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=candidate.source_code, + original_helper_code=original_helper_code, + ) + if not did_update: + logger.warning( + "force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate." + ) + console.rule() + return None + except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: + logger.error(e) + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + return None + + # Check for duplicate candidates + normalized_code = normalize_code(candidate.source_code.flat.strip()) + if normalized_code in eval_ctx.ast_code_to_id: + logger.info( + "Current candidate has been encountered before in testing, Skipping optimization candidate." + ) + eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context) + return None + + eval_ctx.register_new_candidate(normalized_code, candidate, code_context) + + # Run the optimized candidate + run_results = self.run_optimized_candidate( + optimization_candidate_index=candidate_index, + baseline_results=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + ) + console.rule() + + if not is_successful(run_results): + eval_ctx.record_failed_candidate(candidate.optimization_id) + return None + + candidate_result: OptimizedCandidateResult = run_results.unwrap() + perf_gain = performance_gain( + original_runtime_ns=original_code_baseline.runtime, + optimized_runtime_ns=candidate_result.best_test_runtime, + ) + eval_ctx.record_successful_candidate( + candidate.optimization_id, candidate_result.best_test_runtime, perf_gain + ) + + # Check if this is a successful optimization + is_successful_opt = speedup_critic( + candidate_result, + original_code_baseline.runtime, + best_runtime_until_now=None, + original_async_throughput=original_code_baseline.async_throughput, + best_throughput_until_now=None, + ) and quantity_of_tests_critic(candidate_result) + + tree = self.build_runtime_info_tree( + candidate_index=candidate_index, + candidate_result=candidate_result, + original_code_baseline=original_code_baseline, + perf_gain=perf_gain, + is_successful_candidate=is_successful_opt, + ) + + best_optimization = None + benchmark_tree = None + + if is_successful_opt: + best_optimization, benchmark_tree = self.handle_successful_candidate( + candidate=candidate, + candidate_result=candidate_result, + code_context=code_context, + original_code_baseline=original_code_baseline, + original_helper_code=original_helper_code, + candidate_index=candidate_index, + eval_ctx=eval_ctx, + ) + eval_ctx.valid_optimizations.append(best_optimization) + + # Queue refinement for non-refined candidates + if not candidate.optimization_id.endswith("refi"): + future_all_refinements.append( + self.refine_optimizations( + valid_optimizations=[best_optimization], + 
original_code_baseline=original_code_baseline, + code_context=code_context, + trace_id=self.get_trace_id(exp_type), + ai_service_client=ai_service_client, + executor=self.executor, + function_references=function_references, + ) + ) + + # Display runtime information + if is_LSP_enabled(): + lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree))) + else: + console.print(tree) + if self.args.benchmark and benchmark_tree: + console.print(benchmark_tree) + console.rule() + + return best_optimization + def determine_best_candidate( self, *, @@ -691,12 +836,13 @@ def determine_best_candidate( eval_ctx = CandidateEvaluationContext() future_all_refinements: list[concurrent.futures.Future] = [] ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client + assert ai_service_client is not None, "AI service client must be set for optimization" future_line_profile_results = self.executor.submit( ai_service_client.optimize_python_code_line_profiler, source_code=code_context.read_writable_code.markdown, dependency_code=code_context.read_only_context_code, - trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + trace_id=self.get_trace_id(exp_type), line_profiler_results=original_code_baseline.line_profile_results["str_out"], num_candidates=N_CANDIDATES_LP_EFFECTIVE, experiment_metadata=ExperimentMetadata( @@ -718,122 +864,20 @@ def determine_best_candidate( try: candidate_index += 1 - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) - logger.info(f"h3|Optimization candidate {candidate_index}/{processor.candidate_len}:") - code_print( - candidate.source_code.flat, - file_name=f"candidate_{candidate_index}.py", - lsp_message_id=LSPMessageId.CANDIDATE.value, - ) - - # Try to replace function with optimized code - try: - did_update = self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, - optimized_code=candidate.source_code, - original_helper_code=original_helper_code, - ) - if not did_update: - logger.warning( - "force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate." - ) - console.rule() - continue - except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: - logger.error(e) - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) - continue - - # Check for duplicate candidates - normalized_code = normalize_code(candidate.source_code.flat.strip()) - if normalized_code in eval_ctx.ast_code_to_id: - logger.info( - "Current candidate has been encountered before in testing, Skipping optimization candidate." 
- ) - eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context) - continue - - eval_ctx.register_new_candidate(normalized_code, candidate, code_context) - - # Run the optimized candidate - run_results = self.run_optimized_candidate( - optimization_candidate_index=candidate_index, - baseline_results=original_code_baseline, + self.process_single_candidate( + candidate=candidate, + candidate_index=candidate_index, + total_candidates=processor.candidate_len, + code_context=code_context, + original_code_baseline=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, + eval_ctx=eval_ctx, + future_all_refinements=future_all_refinements, + ai_service_client=ai_service_client, + exp_type=exp_type, + function_references=function_references, ) - console.rule() - - if not is_successful(run_results): - eval_ctx.record_failed_candidate(candidate.optimization_id) - else: - candidate_result: OptimizedCandidateResult = run_results.unwrap() - perf_gain = performance_gain( - original_runtime_ns=original_code_baseline.runtime, - optimized_runtime_ns=candidate_result.best_test_runtime, - ) - eval_ctx.record_successful_candidate( - candidate.optimization_id, candidate_result.best_test_runtime, perf_gain - ) - - # Check if this is a successful optimization - is_successful_opt = speedup_critic( - candidate_result, - original_code_baseline.runtime, - best_runtime_until_now=None, - original_async_throughput=original_code_baseline.async_throughput, - best_throughput_until_now=None, - ) and quantity_of_tests_critic(candidate_result) - - tree = self.build_runtime_info_tree( - candidate_index=candidate_index, - candidate_result=candidate_result, - original_code_baseline=original_code_baseline, - perf_gain=perf_gain, - is_successful_candidate=is_successful_opt, - ) - - benchmark_tree = None - if is_successful_opt: - best_optimization, benchmark_tree = self.handle_successful_candidate( - candidate=candidate, - candidate_result=candidate_result, - code_context=code_context, - original_code_baseline=original_code_baseline, - original_helper_code=original_helper_code, - candidate_index=candidate_index, - eval_ctx=eval_ctx, - ) - eval_ctx.valid_optimizations.append(best_optimization) - - # Queue refinement for non-refined candidates - if not candidate.optimization_id.endswith("refi"): - future_all_refinements.append( - self.refine_optimizations( - valid_optimizations=[best_optimization], - original_code_baseline=original_code_baseline, - code_context=code_context, - trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - ai_service_client=ai_service_client, - executor=self.executor, - function_references=function_references, - ) - ) - - # Display runtime information - if is_LSP_enabled(): - lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree))) - else: - console.print(tree) - if self.args.benchmark and benchmark_tree: - console.print(benchmark_tree) - console.rule() - except KeyboardInterrupt as e: logger.exception(f"Optimization interrupted: {e}") raise From 6345119840538b828431fa772d7c30bbd902703d Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 9 Dec 2025 02:15:12 -0600 Subject: [PATCH 4/5] formatting --- codeflash/models/models.py | 30 +++++++------------- codeflash/optimization/function_optimizer.py | 11 ++----- 2 files changed, 13 insertions(+), 28 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index d88996ee2..0ea380059 100644 --- 
a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -364,9 +364,7 @@ def record_failed_candidate(self, optimization_id: str) -> None: self.is_correct[optimization_id] = False self.speedup_ratios[optimization_id] = None - def record_successful_candidate( - self, optimization_id: str, runtime: float, speedup: float - ) -> None: + def record_successful_candidate(self, optimization_id: str, runtime: float, speedup: float) -> None: """Record results for a successful candidate.""" self.optimized_runtimes[optimization_id] = runtime self.is_correct[optimization_id] = True @@ -377,10 +375,7 @@ def record_line_profiler_result(self, optimization_id: str, result: str) -> None self.optimized_line_profiler_results[optimization_id] = result def handle_duplicate_candidate( - self, - candidate: OptimizedCandidate, - normalized_code: str, - code_context: CodeOptimizationContext, + self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext ) -> None: """Handle a candidate that has been seen before.""" past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] @@ -392,16 +387,14 @@ def handle_duplicate_candidate( # Line profiler results only available for successful runs if past_opt_id in self.optimized_line_profiler_results: - self.optimized_line_profiler_results[candidate.optimization_id] = ( - self.optimized_line_profiler_results[past_opt_id] - ) + self.optimized_line_profiler_results[candidate.optimization_id] = self.optimized_line_profiler_results[ + past_opt_id + ] - self.optimizations_post[candidate.optimization_id] = ( - self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown - ) - self.optimizations_post[past_opt_id] = ( - self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown - ) + self.optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ + "shorter_source_code" + ].markdown + self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown # Update to shorter code if this candidate has a shorter diff new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) @@ -410,10 +403,7 @@ def handle_duplicate_candidate( self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len def register_new_candidate( - self, - normalized_code: str, - candidate: OptimizedCandidate, - code_context: CodeOptimizationContext, + self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext ) -> None: """Register a new candidate that hasn't been seen before.""" self.ast_code_to_id[normalized_code] = { diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 59e03a7fc..a2bbf600b 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -727,9 +727,7 @@ def process_single_candidate( # Check for duplicate candidates normalized_code = normalize_code(candidate.source_code.flat.strip()) if normalized_code in eval_ctx.ast_code_to_id: - logger.info( - "Current candidate has been encountered before in testing, Skipping optimization candidate." 
- ) + logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.") eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context) return None @@ -750,12 +748,9 @@ def process_single_candidate( candidate_result: OptimizedCandidateResult = run_results.unwrap() perf_gain = performance_gain( - original_runtime_ns=original_code_baseline.runtime, - optimized_runtime_ns=candidate_result.best_test_runtime, - ) - eval_ctx.record_successful_candidate( - candidate.optimization_id, candidate_result.best_test_runtime, perf_gain + original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=candidate_result.best_test_runtime ) + eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain) # Check if this is a successful optimization is_successful_opt = speedup_critic( From 2be5af5ffe12558ee1f424f161f8a7abd0cb6986 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Tue, 9 Dec 2025 02:17:39 -0600 Subject: [PATCH 5/5] formatting --- codeflash/optimization/function_optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index a2bbf600b..6459b7f39 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -469,6 +469,7 @@ def build_runtime_info_tree( candidate_result: OptimizedCandidateResult, original_code_baseline: OriginalCodeBaseline, perf_gain: float, + *, is_successful_candidate: bool, ) -> Tree: """Build a Tree display for runtime information of a candidate."""
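
Usage sketch (illustrative only, not part of the patch series): the central pattern these commits introduce is the CandidateEvaluationContext pydantic dataclass, which replaces the loose per-candidate dicts that determine_best_candidate used to juggle inline. The minimal sketch below shows the same idea — pydantic dataclass fields declared with Field(default_factory=...) plus small record_* helpers keyed by optimization_id. It assumes pydantic v2 and Python 3.9+; the EvalContext name and the "opt-1"/"opt-2" ids are hypothetical, and the real class in codeflash/models/models.py carries additional state (the AST-normalized dedup map, line-profiler output, and the valid_optimizations list).

from typing import Optional

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class EvalContext:
    """Hypothetical, trimmed-down analogue of CandidateEvaluationContext from the patch."""

    # Results are keyed by optimization_id; mutable defaults must use default_factory.
    speedup_ratios: dict[str, Optional[float]] = Field(default_factory=dict)
    optimized_runtimes: dict[str, Optional[float]] = Field(default_factory=dict)
    is_correct: dict[str, bool] = Field(default_factory=dict)

    def record_failed_candidate(self, optimization_id: str) -> None:
        # A failed run contributes no runtime or speedup and is marked incorrect.
        self.optimized_runtimes[optimization_id] = None
        self.is_correct[optimization_id] = False
        self.speedup_ratios[optimization_id] = None

    def record_successful_candidate(self, optimization_id: str, runtime: float, speedup: float) -> None:
        # A successful run stores its best measured runtime and the computed speedup ratio.
        self.optimized_runtimes[optimization_id] = runtime
        self.is_correct[optimization_id] = True
        self.speedup_ratios[optimization_id] = speedup


if __name__ == "__main__":
    ctx = EvalContext()
    ctx.record_successful_candidate("opt-1", runtime=1.2e6, speedup=0.35)  # example ids are made up
    ctx.record_failed_candidate("opt-2")
    assert ctx.is_correct == {"opt-1": True, "opt-2": False}

One design point worth noting: keeping this state on a pydantic dataclass (rather than module-level dicts) is what lets patch 3 extract process_single_candidate as a method that receives eval_ctx explicitly, so duplicate handling, failure recording, and final ranking all read from a single object.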