
Conversation

@KRRT7 (Contributor) commented Dec 9, 2025

PR Type

Enhancement


Description

  • Refactor candidate evaluation into helper methods

  • Add CandidateEvaluationContext dataclass

  • Improve trace ID handling utility

  • Streamline ranking and logging flow


Diagram Walkthrough

flowchart LR
  A["determine_best_candidate"] -- "initializes" --> B["CandidateEvaluationContext"]
  A -- "submits" --> C["line profiler future"]
  A -- "processes via" --> D["process_single_candidate"]
  D -- "on success" --> E["handle_successful_candidate"]
  E -- "collects" --> F["valid_optimizations"]
  A -- "selects best" --> G["select_best_optimization"]
  A -- "logs results" --> H["log_evaluation_results"]

File Walkthrough

Relevant files

Enhancement

models.py (codeflash/models/models.py): Add evaluation context for candidate processing (+79/-2)

  • Import Field and diff_length
  • Add CandidateEvaluationContext dataclass
  • Provide methods for candidate tracking and deduping
  • Use diff_length for shorter-code selection

function_optimizer.py (codeflash/optimization/function_optimizer.py): Refactor candidate evaluation into modular helpers (+376/-276)

  • Extract helper methods for evaluation flow
  • Replace inline logic with process_single_candidate
  • Add get_trace_id, logging and selection helpers
  • Simplify ranking trace_id usage and output trees

github-actions bot commented Dec 9, 2025

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 No relevant tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Dataclass/Pydantic Mixing

The dataclass fields use pydantic Field, and some type hints (dict, list) lack type parameters; verify that runtime behavior is intended and that serialization/validation works, given the mix of pydantic.dataclasses and BaseModel elsewhere.

@dataclass
class CandidateEvaluationContext:
    """Holds tracking state during candidate evaluation in determine_best_candidate."""

    speedup_ratios: dict[str, float | None] = Field(default_factory=dict)
    optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
    is_correct: dict[str, bool] = Field(default_factory=dict)
    optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
    ast_code_to_id: dict = Field(default_factory=dict)
    optimizations_post: dict[str, str] = Field(default_factory=dict)
    valid_optimizations: list = Field(default_factory=list)

    def record_failed_candidate(self, optimization_id: str) -> None:
        """Record results for a failed candidate."""
        self.optimized_runtimes[optimization_id] = None
        self.is_correct[optimization_id] = False
        self.speedup_ratios[optimization_id] = None

    def record_successful_candidate(
        self, optimization_id: str, runtime: float, speedup: float
    ) -> None:
        """Record results for a successful candidate."""
        self.optimized_runtimes[optimization_id] = runtime
        self.is_correct[optimization_id] = True
        self.speedup_ratios[optimization_id] = speedup

    def record_line_profiler_result(self, optimization_id: str, result: str) -> None:
        """Record line profiler results for a candidate."""
        self.optimized_line_profiler_results[optimization_id] = result

    def handle_duplicate_candidate(
        self,
        candidate: OptimizedCandidate,
        normalized_code: str,
        code_context: CodeOptimizationContext,
    ) -> None:
        """Handle a candidate that has been seen before."""
        past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"]

        # Copy results from the previous evaluation
        self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios[past_opt_id]
        self.is_correct[candidate.optimization_id] = self.is_correct[past_opt_id]
        self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes[past_opt_id]

        # Line profiler results only available for successful runs
        if past_opt_id in self.optimized_line_profiler_results:
            self.optimized_line_profiler_results[candidate.optimization_id] = (
                self.optimized_line_profiler_results[past_opt_id]
            )

        self.optimizations_post[candidate.optimization_id] = (
            self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown
        )
        self.optimizations_post[past_opt_id] = (
            self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown
        )

        # Update to shorter code if this candidate has a shorter diff
        new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
        if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]:
            self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
            self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len

    def register_new_candidate(
        self,
        normalized_code: str,
        candidate: OptimizedCandidate,
        code_context: CodeOptimizationContext,
    ) -> None:
        """Register a new candidate that hasn't been seen before."""
        self.ast_code_to_id[normalized_code] = {
            "optimization_id": candidate.optimization_id,
            "shorter_source_code": candidate.source_code,
            "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
        }
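
To make the concern concrete, here is a minimal runnable sketch (not from the PR) of what happens when pydantic's Field is used under the stdlib @dataclass decorator instead of pydantic.dataclasses.dataclass; the exact failure mode varies by pydantic version, as the comments note:

from dataclasses import dataclass, field

from pydantic import Field

try:
    @dataclass
    class WithPydanticField:
        # pydantic's Field() returns a FieldInfo object that the stdlib
        # decorator does not recognize as a dataclasses.field(), so it is
        # treated as an ordinary class-level default
        data: dict = Field(default_factory=dict)

    print(type(WithPydanticField().data))  # a FieldInfo, not a dict
except ValueError as exc:
    # if FieldInfo is unhashable in the installed pydantic version, the
    # stdlib decorator rejects it as a mutable default at class creation
    print(f"class creation failed: {exc}")

@dataclass
class WithStdlibField:
    # the stdlib factory: every instance gets its own fresh dict
    data: dict = field(default_factory=dict)

print(type(WithStdlibField().data))  # <class 'dict'>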
Trace ID Logic

get_trace_id trims the last 4 characters of function_trace_id when an experiment is active; verify that all trace IDs follow that suffix pattern to avoid malformed IDs and collisions.

def get_trace_id(self, exp_type: str) -> str:
    """Get the trace ID for the current experiment type."""
    if self.experiment_id:
        return self.function_trace_id[:-4] + exp_type
    return self.function_trace_id
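
For illustration, a tiny sketch of the slicing behavior; the trace-ID format with a trailing 4-character experiment suffix is an assumption, not something the PR confirms:

# hypothetical trace id whose last 4 characters encode the experiment type
function_trace_id = "a1b2c3d4-trace-EXP0"
print(function_trace_id[:-4] + "EXP1")  # a1b2c3d4-trace-EXP1

# the reviewer's concern: an id without that suffix is silently mangled
print("short"[:-4] + "EXP1")  # sEXP1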
None Return Paths

select_best_optimization can return None silently when there are 0 candidates or ranking paths fail; ensure callers handle None and that logging communicates why no selection was made.

def select_best_optimization(
    self,
    eval_ctx: CandidateEvaluationContext,
    code_context: CodeOptimizationContext,
    original_code_baseline: OriginalCodeBaseline,
    ai_service_client: AiServiceClient,
    exp_type: str,
    function_references: str,
) -> BestOptimization | None:
    """Select the best optimization from valid candidates."""
    if not eval_ctx.valid_optimizations:
        return None

    valid_candidates_with_shorter_code = []
    diff_lens_list = []  # character level diff
    speedups_list = []
    optimization_ids = []
    diff_strs = []
    runtimes_list = []

    for valid_opt in eval_ctx.valid_optimizations:
        valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip())
        new_candidate_with_shorter_code = OptimizedCandidate(
            source_code=eval_ctx.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
            optimization_id=valid_opt.candidate.optimization_id,
            explanation=valid_opt.candidate.explanation,
        )
        new_best_opt = BestOptimization(
            candidate=new_candidate_with_shorter_code,
            helper_functions=valid_opt.helper_functions,
            code_context=valid_opt.code_context,
            runtime=valid_opt.runtime,
            line_profiler_test_results=valid_opt.line_profiler_test_results,
            winning_behavior_test_results=valid_opt.winning_behavior_test_results,
            replay_performance_gain=valid_opt.replay_performance_gain,
            winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
            winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
            async_throughput=valid_opt.async_throughput,
        )
        valid_candidates_with_shorter_code.append(new_best_opt)
        diff_lens_list.append(
            diff_length(new_best_opt.candidate.source_code.flat, code_context.read_writable_code.flat)
        )
        diff_strs.append(
            unified_diff_strings(code_context.read_writable_code.flat, new_best_opt.candidate.source_code.flat)
        )
        speedups_list.append(
            1
            + performance_gain(
                original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=new_best_opt.runtime
            )
        )
        optimization_ids.append(new_best_opt.candidate.optimization_id)
        runtimes_list.append(new_best_opt.runtime)

    if len(optimization_ids) > 1:
        future_ranking = self.executor.submit(
            ai_service_client.generate_ranking,
            diffs=diff_strs,
            optimization_ids=optimization_ids,
            speedups=speedups_list,
            trace_id=self.get_trace_id(exp_type),
            function_references=function_references,
        )
        concurrent.futures.wait([future_ranking])
        ranking = future_ranking.result()
        if ranking:
            min_key = ranking[0]
        else:
            diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
            runtimes_ranking = create_rank_dictionary_compact(runtimes_list)
            overall_ranking = {key: diff_lens_ranking[key] + runtimes_ranking[key] for key in diff_lens_ranking}
            min_key = min(overall_ranking, key=overall_ranking.get)
    elif len(optimization_ids) == 1:
        min_key = 0
    else:
        return None

    return valid_candidates_with_shorter_code[min_key]
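
A side note on the fallback branch above: when generate_ranking returns nothing, the code combines two rank dictionaries and picks the key with the smallest sum. Below is a runnable sketch using a hypothetical stand-in for create_rank_dictionary_compact (its real implementation is not shown in this PR):

def rank_compact(values: list[float]) -> dict[int, int]:
    # hypothetical stand-in: maps each index to its dense rank,
    # where 0 means the smallest (best) value
    order = sorted(set(values))
    return {i: order.index(v) for i, v in enumerate(values)}

diff_lens_list = [120, 45, 45]   # character-level diff sizes per candidate
runtimes_list = [9.0, 7.5, 8.0]  # optimized runtimes per candidate

diff_lens_ranking = rank_compact(diff_lens_list)
runtimes_ranking = rank_compact(runtimes_list)
overall_ranking = {k: diff_lens_ranking[k] + runtimes_ranking[k] for k in diff_lens_ranking}
min_key = min(overall_ranking, key=overall_ranking.get)
print(overall_ranking, min_key)  # {0: 3, 1: 0, 2: 1} -> candidate 1 wins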

github-actions bot commented Dec 9, 2025

PR Code Suggestions ✨

Explore these optional code suggestions:

Category: Possible issue

Use proper dataclass default factories

Remove the use of pydantic.Field in a plain @dataclass; Field does not take effect there and gives misleading defaults. Use standard dataclass defaults via
dataclasses.field(default_factory=...) to ensure correct initialization. This
prevents subtle bugs where attributes are not initialized as intended.

codeflash/models/models.py [349-360]

+from dataclasses import dataclass, field
+
 @dataclass
 class CandidateEvaluationContext:
     """Holds tracking state during candidate evaluation in determine_best_candidate."""
 
-    speedup_ratios: dict[str, float | None] = Field(default_factory=dict)
-    optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
-    is_correct: dict[str, bool] = Field(default_factory=dict)
-    optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
-    ast_code_to_id: dict = Field(default_factory=dict)
-    optimizations_post: dict[str, str] = Field(default_factory=dict)
-    valid_optimizations: list = Field(default_factory=list)
+    speedup_ratios: dict[str, float | None] = field(default_factory=dict)
+    optimized_runtimes: dict[str, float | None] = field(default_factory=dict)
+    is_correct: dict[str, bool] = field(default_factory=dict)
+    optimized_line_profiler_results: dict[str, str] = field(default_factory=dict)
+    ast_code_to_id: dict = field(default_factory=dict)
+    optimizations_post: dict[str, str] = field(default_factory=dict)
+    valid_optimizations: list = field(default_factory=list)
Suggestion importance[1-10]: 7


Why: Correct: pydantic.Field has no effect in a standard @dataclass; using dataclasses.field avoids misleading defaults and potential init bugs. Moderate impact as it improves correctness and maintainability of state tracking.

Impact: Medium

Category: General

Avoid hard assert crash path

Replace the hard assert with a controlled early return plus an error log to avoid crashing
production code. If the client is unexpectedly None, log an error and skip the line
profiler submission, preventing an exception that aborts the entire optimization
flow.

codeflash/optimization/function_optimizer.py [838-853]

-assert ai_service_client is not None, "AI service client must be set for optimization"
+if ai_service_client is None:
+    logger.error("AI service client is not available; skipping line profiler optimization submission.")
+    return None
 
 future_line_profile_results = self.executor.submit(
     ai_service_client.optimize_python_code_line_profiler,
     source_code=code_context.read_writable_code.markdown,
     dependency_code=code_context.read_only_context_code,
     trace_id=self.get_trace_id(exp_type),
     line_profiler_results=original_code_baseline.line_profile_results["str_out"],
     num_candidates=N_CANDIDATES_LP_EFFECTIVE,
     experiment_metadata=ExperimentMetadata(
         id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment"
     )
     if self.experiment_id
     else None,
 )
Suggestion importance[1-10]: 6


Why: Replacing a hard assert with a logged early return improves resiliency in production flows; however, it changes failure semantics and may hide configuration errors, so impact is moderate.

Impact: Low

Make trace ID slicing safe

Guard against a None or short function_trace_id when slicing;
self.function_trace_id[:-4] can produce incorrect IDs or empty strings. Use a safer
conditional that only slices when the suffix exists and falls back otherwise. This
avoids malformed trace IDs in logs and downstream services.

codeflash/optimization/function_optimizer.py [460-465]

 def get_trace_id(self, exp_type: str) -> str:
     """Get the trace ID for the current experiment type."""
-    if self.experiment_id:
-        return self.function_trace_id[:-4] + exp_type
-    return self.function_trace_id
+    base = self.function_trace_id or ""
+    if self.experiment_id and len(base) >= 4:
+        return base[:-4] + exp_type
+    return base
Suggestion importance[1-10]: 5


Why: Reasonable defensive change; while the code likely expects a suffix, guarding against short or None IDs prevents malformed trace IDs. It's a minor robustness improvement.

Impact: Low

@KRRT7 requested review from misrasaurabh1 and mohammedahmed18 and removed the request for mohammedahmed18, December 9, 2025 08:20
Inline review comment on the ranking fallback:

        min_key = ranking[0]
    else:
        diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list)
        runtimes_ranking = create_rank_dictionary_compact(runtimes_list)

Contributor:

let's not remove comments unless we have a good reason

@misrasaurabh1 (Contributor) commented:

it's a large diff - the logic seems sound. but ensure that you change no behavior whatsoever, otherwise we won't catch it until 2 months later. verify functional equivalence deeply by reading the previous and new versions and comparing them

@KRRT7 (Contributor, Author) commented Dec 9, 2025

> it's a large diff - the logic seems sound. but ensure that you change no behavior whatsoever, otherwise we won't catch it until 2 months later. verify functional equivalence deeply by reading the previous and new versions and comparing them

yeah agreed, I'm only doing these refactors in anticipation of the agentic workflows; the current implementation makes it very hard

@KRRT7 merged commit 9cdeec8 into main on Dec 9, 2025
21 of 23 checks passed
@KRRT7 deleted the extract-running-candidates branch on December 9, 2025 11:22