|
22 | 22 | from typing import Annotated, Optional, cast |
23 | 23 |
|
24 | 24 | from jedi.api.classes import Name |
25 | | -from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError |
| 25 | +from pydantic import AfterValidator, BaseModel, ConfigDict, Field, PrivateAttr, ValidationError |
26 | 26 | from pydantic.dataclasses import dataclass |
27 | 27 |
|
28 | 28 | from codeflash.cli_cmds.console import console, logger |
29 | | -from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code |
| 29 | +from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code |
30 | 30 | from codeflash.code_utils.env_utils import is_end_to_end |
31 | 31 | from codeflash.verification.comparator import comparator |
32 | 32 |
|
@@ -346,6 +346,73 @@ class OptimizationSet(BaseModel): |
346 | 346 | experiment: Optional[list[OptimizedCandidate]] |
347 | 347 |
|
348 | 348 |
|
| 349 | +@dataclass |
| 350 | +class CandidateEvaluationContext: |
| 351 | + """Holds tracking state during candidate evaluation in determine_best_candidate.""" |
| 352 | + |
| 353 | + speedup_ratios: dict[str, float | None] = Field(default_factory=dict) |
| 354 | + optimized_runtimes: dict[str, float | None] = Field(default_factory=dict) |
| 355 | + is_correct: dict[str, bool] = Field(default_factory=dict) |
| 356 | + optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict) |
| 357 | + ast_code_to_id: dict = Field(default_factory=dict) |
| 358 | + optimizations_post: dict[str, str] = Field(default_factory=dict) |
| 359 | + valid_optimizations: list = Field(default_factory=list) |
| 360 | + |
| 361 | + def record_failed_candidate(self, optimization_id: str) -> None: |
| 362 | + """Record results for a failed candidate.""" |
| 363 | + self.optimized_runtimes[optimization_id] = None |
| 364 | + self.is_correct[optimization_id] = False |
| 365 | + self.speedup_ratios[optimization_id] = None |
| 366 | + |
| 367 | + def record_successful_candidate(self, optimization_id: str, runtime: float, speedup: float) -> None: |
| 368 | + """Record results for a successful candidate.""" |
| 369 | + self.optimized_runtimes[optimization_id] = runtime |
| 370 | + self.is_correct[optimization_id] = True |
| 371 | + self.speedup_ratios[optimization_id] = speedup |
| 372 | + |
| 373 | + def record_line_profiler_result(self, optimization_id: str, result: str) -> None: |
| 374 | + """Record line profiler results for a candidate.""" |
| 375 | + self.optimized_line_profiler_results[optimization_id] = result |
| 376 | + |
| 377 | + def handle_duplicate_candidate( |
| 378 | + self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext |
| 379 | + ) -> None: |
| 380 | + """Handle a candidate that has been seen before.""" |
| 381 | + past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] |
| 382 | + |
| 383 | + # Copy results from the previous evaluation |
| 384 | + self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios[past_opt_id] |
| 385 | + self.is_correct[candidate.optimization_id] = self.is_correct[past_opt_id] |
| 386 | + self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes[past_opt_id] |
| 387 | + |
| 388 | + # Line profiler results only available for successful runs |
| 389 | + if past_opt_id in self.optimized_line_profiler_results: |
| 390 | + self.optimized_line_profiler_results[candidate.optimization_id] = self.optimized_line_profiler_results[ |
| 391 | + past_opt_id |
| 392 | + ] |
| 393 | + |
| 394 | + self.optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ |
| 395 | + "shorter_source_code" |
| 396 | + ].markdown |
| 397 | + self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown |
| 398 | + |
| 399 | + # Update to shorter code if this candidate has a shorter diff |
| 400 | + new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) |
| 401 | + if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]: |
| 402 | + self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code |
| 403 | + self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len |
| 404 | + |
| 405 | + def register_new_candidate( |
| 406 | + self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext |
| 407 | + ) -> None: |
| 408 | + """Register a new candidate that hasn't been seen before.""" |
| 409 | + self.ast_code_to_id[normalized_code] = { |
| 410 | + "optimization_id": candidate.optimization_id, |
| 411 | + "shorter_source_code": candidate.source_code, |
| 412 | + "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), |
| 413 | + } |
| 414 | + |
| 415 | + |
349 | 416 | @dataclass(frozen=True) |
350 | 417 | class TestsInFile: |
351 | 418 | test_file: Path |
|
0 commit comments