Skip to content

Commit 9cdeec8

Browse files
authored
extract logic for processing a candidate into a method for separation of concerns (#959)
* first pass at refactor * Update models.py * extract logic for processing a candidate * formatting * formatting
1 parent 2e34d83 commit 9cdeec8

File tree

2 files changed

+440
-277
lines changed

2 files changed

+440
-277
lines changed

codeflash/models/models.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
from typing import Annotated, Optional, cast
2323

2424
from jedi.api.classes import Name
25-
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
25+
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, PrivateAttr, ValidationError
2626
from pydantic.dataclasses import dataclass
2727

2828
from codeflash.cli_cmds.console import console, logger
29-
from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code
29+
from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code
3030
from codeflash.code_utils.env_utils import is_end_to_end
3131
from codeflash.verification.comparator import comparator
3232

@@ -346,6 +346,73 @@ class OptimizationSet(BaseModel):
346346
experiment: Optional[list[OptimizedCandidate]]
347347

348348

349+
@dataclass
class CandidateEvaluationContext:
    """Tracking state accumulated while candidates are evaluated in determine_best_candidate.

    The per-candidate mappings (``speedup_ratios``, ``optimized_runtimes``,
    ``is_correct``, ``optimized_line_profiler_results``, ``optimizations_post``)
    are keyed by optimization id. ``ast_code_to_id`` is keyed by normalized code;
    each of its values is a dict with the keys ``"optimization_id"``,
    ``"shorter_source_code"`` and ``"diff_len"``.
    """

    speedup_ratios: dict[str, float | None] = Field(default_factory=dict)
    optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
    is_correct: dict[str, bool] = Field(default_factory=dict)
    optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
    ast_code_to_id: dict = Field(default_factory=dict)
    optimizations_post: dict[str, str] = Field(default_factory=dict)
    valid_optimizations: list = Field(default_factory=list)

    def record_failed_candidate(self, optimization_id: str) -> None:
        """Mark ``optimization_id`` as incorrect, with no runtime and no speedup."""
        self.is_correct[optimization_id] = False
        self.optimized_runtimes[optimization_id] = None
        self.speedup_ratios[optimization_id] = None

    def record_successful_candidate(self, optimization_id: str, runtime: float, speedup: float) -> None:
        """Mark ``optimization_id`` as correct and store its runtime and speedup."""
        self.is_correct[optimization_id] = True
        self.optimized_runtimes[optimization_id] = runtime
        self.speedup_ratios[optimization_id] = speedup

    def record_line_profiler_result(self, optimization_id: str, result: str) -> None:
        """Store the line profiler output for ``optimization_id``."""
        self.optimized_line_profiler_results[optimization_id] = result

    def handle_duplicate_candidate(
        self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext
    ) -> None:
        """Reuse earlier results for a candidate whose normalized code was already registered.

        Raises ``KeyError`` if ``normalized_code`` has not been registered, or if the
        earlier candidate has no recorded speedup, correctness or runtime.
        """
        seen = self.ast_code_to_id[normalized_code]
        previous_id = seen["optimization_id"]
        current_id = candidate.optimization_id

        # The duplicate inherits the outcome of the first evaluation.
        self.speedup_ratios[current_id] = self.speedup_ratios[previous_id]
        self.is_correct[current_id] = self.is_correct[previous_id]
        self.optimized_runtimes[current_id] = self.optimized_runtimes[previous_id]

        # Line profiler results are only available for successful runs.
        if previous_id in self.optimized_line_profiler_results:
            self.optimized_line_profiler_results[current_id] = self.optimized_line_profiler_results[previous_id]

        # Both ids point at the shortest source recorded so far. This happens before
        # the comparison below, so a shorter current candidate is not reflected here.
        for opt_id in (current_id, previous_id):
            self.optimizations_post[opt_id] = seen["shorter_source_code"].markdown

        # Remember this candidate's source if its diff is strictly shorter.
        candidate_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
        if candidate_diff_len < seen["diff_len"]:
            seen["shorter_source_code"] = candidate.source_code
            seen["diff_len"] = candidate_diff_len

    def register_new_candidate(
        self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext
    ) -> None:
        """Create the ``ast_code_to_id`` entry for ``normalized_code`` from ``candidate``.

        Any existing entry for the same normalized code is overwritten.
        """
        baseline = code_context.read_writable_code.flat
        entry = {
            "optimization_id": candidate.optimization_id,
            "shorter_source_code": candidate.source_code,
            "diff_len": diff_length(candidate.source_code.flat, baseline),
        }
        self.ast_code_to_id[normalized_code] = entry
414+
415+
349416
@dataclass(frozen=True)
350417
class TestsInFile:
351418
test_file: Path

0 commit comments

Comments
 (0)