Skip to content

Commit a6a5578

Browse files
fixes
1 parent 41de7be commit a6a5578

File tree

3 files changed

+14
-23
lines changed

3 files changed

+14
-23
lines changed

codeflash/api/aiservice.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,15 +296,15 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
296296
console.rule()
297297
return []
298298

299-
def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
300-
"""Optimize the given python code for performance by making a request to the Django endpoint.
299+
def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
300+
"""Repair the optimization candidate that is not matching the test result of the original code.
301301
302302
Args:
303-
request: optimization candidate details for refinement
303+
request: candidate details for repair
304304
305305
Returns:
306306
-------
307-
- OptimizationCandidate: new fixed candidate.
307+
- OptimizedCandidate: new fixed candidate.
308308
309309
"""
310310
console.rule()

codeflash/optimization/function_optimizer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,7 @@ def determine_best_candidate(
872872
eval_ctx = CandidateEvaluationContext()
873873
self.future_all_refinements.clear()
874874
self.future_all_code_repair.clear()
875+
self.repair_counter = 0
875876

876877
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
877878
assert ai_service_client is not None, "AI service client must be set for optimization"
@@ -992,7 +993,7 @@ def repair_optimization(
992993
test_diffs=test_diffs,
993994
trace_id=trace_id,
994995
)
995-
return executor.submit(ai_service_client.optimize_python_code_repair, request=request)
996+
return executor.submit(ai_service_client.code_repair, request=request)
996997

997998
def log_successful_optimization(
998999
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
@@ -1903,8 +1904,9 @@ def repair_if_possible(
19031904
if self.repair_counter >= MAX_REPAIRS_PER_TRACE:
19041905
logger.debug(f"Repair counter reached {MAX_REPAIRS_PER_TRACE}, skipping repair")
19051906
return
1906-
if candidate.source == OptimizedCandidateSource.REPAIR:
1907-
logger.debug("Candidate is already a repair, skipping repair")
1907+
if candidate.source not in (OptimizedCandidateSource.OPTIMIZE, OptimizedCandidateSource.OPTIMIZE_LP):
1908+
# only repair the first pass of the candidates for now
1909+
logger.debug(f"Candidate is a result of {candidate.source.value}, skipping repair")
19081910
return
19091911
if not diffs:
19101912
logger.debug("No diffs found, skipping repair")

codeflash/verification/equivalence.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
INCREASED_RECURSION_LIMIT = 5000
1616

17+
reprlib_repr = reprlib.Repr(maxstring=1500)
18+
test_diff_repr = reprlib_repr.repr
19+
1720

1821
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
1922
# This is meant to be only called with test results for the first loop index
@@ -68,27 +71,13 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
6871
)
6972
if original_pytest_error:
7073
original_pytest_error = shorten_pytest_error(original_pytest_error)
71-
test_src_code = original_test_result.id.get_src_code(original_test_result.file_name)
72-
test_diff = TestDiff(
73-
scope=TestDiffScope.RETURN_VALUE,
74-
original_value=reprlib.repr(original_test_result.return_value),
75-
candidate_value=reprlib.repr(cdd_test_result.return_value),
76-
test_src_code=test_src_code,
77-
candidate_pytest_error=cdd_pytest_error,
78-
original_pass=original_test_result.did_pass,
79-
candidate_pass=cdd_test_result.did_pass,
80-
original_pytest_error=original_pytest_error,
81-
)
82-
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
83-
test_diff.scope = TestDiffScope.RETURN_VALUE
84-
test_diffs.append(test_diff)
8574

8675
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
8776
test_diffs.append(
8877
TestDiff(
8978
scope=TestDiffScope.RETURN_VALUE,
90-
original_value=repr(original_test_result.return_value),
91-
candidate_value=repr(cdd_test_result.return_value),
79+
original_value=test_diff_repr(repr(original_test_result.return_value)),
80+
candidate_value=test_diff_repr(repr(cdd_test_result.return_value)),
9281
test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
9382
candidate_pytest_error=cdd_pytest_error,
9483
original_pass=original_test_result.did_pass,

0 commit comments

Comments
 (0)