Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from typing import Annotated, Optional, cast

from jedi.api.classes import Name
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, PrivateAttr, ValidationError
from pydantic.dataclasses import dataclass

from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code
from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code
from codeflash.code_utils.env_utils import is_end_to_end
from codeflash.verification.comparator import comparator

Expand Down Expand Up @@ -346,6 +346,73 @@ class OptimizationSet(BaseModel):
experiment: Optional[list[OptimizedCandidate]]


@dataclass
class CandidateEvaluationContext:
    """Bookkeeping shared across candidates while determine_best_candidate runs.

    The per-candidate mappings are keyed by optimization id. ``ast_code_to_id``
    is keyed by normalized code; each of its entries is a dict holding the
    ``optimization_id`` of the candidate that registered the code, the
    ``shorter_source_code`` (smallest diff seen so far) and that diff's
    ``diff_len``.
    """

    # Speedup per optimization id; None is stored for failed candidates.
    speedup_ratios: dict[str, float | None] = Field(default_factory=dict)
    # Runtime per optimization id; None is stored for failed candidates.
    optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
    # Whether each optimization id was recorded as successful.
    is_correct: dict[str, bool] = Field(default_factory=dict)
    # Line profiler output per optimization id, when one was recorded.
    optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
    # Normalized code -> {"optimization_id", "shorter_source_code", "diff_len"}.
    ast_code_to_id: dict = Field(default_factory=dict)
    # Markdown of the source code associated with each optimization id.
    optimizations_post: dict[str, str] = Field(default_factory=dict)
    # Not touched by the methods below; populated by the caller.
    valid_optimizations: list = Field(default_factory=list)

    def record_failed_candidate(self, optimization_id: str) -> None:
        """Mark a candidate as incorrect, with no runtime and no speedup."""
        self.is_correct[optimization_id] = False
        self.speedup_ratios[optimization_id] = None
        self.optimized_runtimes[optimization_id] = None

    def record_successful_candidate(self, optimization_id: str, runtime: float, speedup: float) -> None:
        """Mark a candidate as correct and store its runtime and speedup."""
        self.is_correct[optimization_id] = True
        self.speedup_ratios[optimization_id] = speedup
        self.optimized_runtimes[optimization_id] = runtime

    def record_line_profiler_result(self, optimization_id: str, result: str) -> None:
        """Store the line profiler output produced for a candidate."""
        self.optimized_line_profiler_results[optimization_id] = result

    def handle_duplicate_candidate(
        self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext
    ) -> None:
        """Reuse earlier results for a candidate whose normalized code is already registered.

        Raises KeyError if ``normalized_code`` was never registered, or if the
        earlier candidate has no recorded speedup, correctness or runtime.
        """
        known = self.ast_code_to_id[normalized_code]
        earlier_id = known["optimization_id"]
        duplicate_id = candidate.optimization_id

        # The duplicate inherits the earlier candidate's outcome.
        for results in (self.speedup_ratios, self.is_correct, self.optimized_runtimes):
            results[duplicate_id] = results[earlier_id]

        # Profiler output is copied only when the earlier candidate has an entry.
        profiler_results = self.optimized_line_profiler_results
        if earlier_id in profiler_results:
            profiler_results[duplicate_id] = profiler_results[earlier_id]

        # Both ids are pointed at the shortest source known *before* this
        # candidate's own diff length is taken into account.
        shortest_markdown = known["shorter_source_code"].markdown
        self.optimizations_post[duplicate_id] = shortest_markdown
        self.optimizations_post[earlier_id] = shortest_markdown

        # Keep this candidate's source if its diff is strictly shorter.
        candidate_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
        if candidate_diff_len < known["diff_len"]:
            known["shorter_source_code"] = candidate.source_code
            known["diff_len"] = candidate_diff_len

    def register_new_candidate(
        self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext
    ) -> None:
        """Create the ``ast_code_to_id`` entry for normalized code not seen before.

        The candidate's own source starts out as the shortest known version.
        """
        source = candidate.source_code
        self.ast_code_to_id[normalized_code] = {
            "optimization_id": candidate.optimization_id,
            "shorter_source_code": source,
            "diff_len": diff_length(source.flat, code_context.read_writable_code.flat),
        }


@dataclass(frozen=True)
class TestsInFile:
test_file: Path
Expand Down
Loading
Loading