Skip to content

Commit ae080d0

Browse files
enhancements
1 parent 4976d5d commit ae080d0

File tree

3 files changed

+78
-52
lines changed

3 files changed

+78
-52
lines changed

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
N_CANDIDATES_LP_LSP = 3
2222

2323
# Code repair
24-
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.35 # if the percentage of unmatched tests is greater than this, we won't fix it
24+
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4 # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair more stricted)
2525
MAX_REPAIRS_PER_TRACE = 3 # maximum number of repairs we will do for each function
2626

2727
MAX_N_CANDIDATES = 5

codeflash/optimization/function_optimizer.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,6 +1854,46 @@ def get_results_not_matched_error(self) -> Failure:
18541854
console.rule()
18551855
return Failure("Test results did not match the test results of the original code.")
18561856

1857+
def repair_if_possible(
1858+
self,
1859+
candidate: OptimizedCandidate,
1860+
diffs: list[TestDiff],
1861+
code_context: CodeOptimizationContext,
1862+
test_results_count: int,
1863+
exp_type: str,
1864+
) -> None:
1865+
if self.repair_counter >= MAX_REPAIRS_PER_TRACE:
1866+
logger.debug(f"Repair counter reached {MAX_REPAIRS_PER_TRACE}, skipping repair")
1867+
return
1868+
if candidate.source == OptimizedCandidateSource.REPAIR:
1869+
logger.debug("Candidate is already a repair, skipping repair")
1870+
return
1871+
if not diffs:
1872+
logger.debug("No diffs found, skipping repair")
1873+
return
1874+
result_unmatched_perc = len(diffs) / test_results_count
1875+
if result_unmatched_perc > REPAIR_UNMATCHED_PERCENTAGE_LIMIT:
1876+
logger.debug(f"Result unmatched percentage is {result_unmatched_perc * 100}%, skipping repair")
1877+
return
1878+
1879+
logger.debug(
1880+
f"Adding a candidate for repair, with {len(diffs)} diffs, ({result_unmatched_perc * 100}% unmatched)"
1881+
)
1882+
# start repairing
1883+
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
1884+
self.repair_counter += 1
1885+
self.future_all_code_repair.append(
1886+
self.repair_optimization(
1887+
original_source_code=code_context.read_writable_code.markdown,
1888+
modified_source_code=candidate.source_code.markdown,
1889+
test_diffs=diffs,
1890+
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
1891+
ai_service_client=ai_service_client,
1892+
optimization_id=candidate.optimization_id,
1893+
executor=self.executor,
1894+
)
1895+
)
1896+
18571897
def run_optimized_candidate(
18581898
self,
18591899
*,
@@ -1919,36 +1959,7 @@ def run_optimized_candidate(
19191959
logger.info("h3|Test results matched ✅")
19201960
console.rule()
19211961
else:
1922-
1923-
def repair_if_possible() -> None:
1924-
if self.repair_counter >= MAX_REPAIRS_PER_TRACE:
1925-
return
1926-
1927-
result_unmatched_perc = len(diffs) / len(candidate_behavior_results)
1928-
if (
1929-
candidate.source == OptimizedCandidateSource.REPAIR
1930-
or result_unmatched_perc > REPAIR_UNMATCHED_PERCENTAGE_LIMIT
1931-
):
1932-
return
1933-
1934-
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
1935-
logger.info("Adding this to the repair queue")
1936-
self.repair_counter += 1
1937-
self.future_all_code_repair.append(
1938-
self.repair_optimization(
1939-
original_source_code=code_context.read_writable_code.markdown,
1940-
modified_source_code=candidate.source_code.markdown,
1941-
test_diffs=diffs,
1942-
trace_id=self.function_trace_id[:-4] + exp_type
1943-
if self.experiment_id
1944-
else self.function_trace_id,
1945-
ai_service_client=ai_service_client,
1946-
optimization_id=candidate.optimization_id,
1947-
executor=self.executor,
1948-
)
1949-
)
1950-
1951-
repair_if_possible()
1962+
self.repair_if_possible(candidate, diffs, code_context, len(candidate_behavior_results), exp_type)
19521963
return self.get_results_not_matched_error()
19531964

19541965
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")

codeflash/verification/equivalence.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,19 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
6363
else ""
6464
)
6565

66-
test_src_code = original_test_result.id.get_src_code(original_test_result.file_name)
67-
test_diff = TestDiff(
68-
scope=TestDiffScope.RETURN_VALUE,
69-
original_value=repr(original_test_result.return_value),
70-
candidate_value=repr(cdd_test_result.return_value),
71-
test_src_code=test_src_code,
72-
candidate_pytest_error=cdd_pytest_error,
73-
original_pass=original_test_result.did_pass,
74-
candidate_pass=cdd_test_result.did_pass,
75-
original_pytest_error=original_pytest_error,
76-
)
7766
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
78-
test_diff.scope = TestDiffScope.RETURN_VALUE
79-
test_diffs.append(test_diff)
67+
test_diffs.append(
68+
TestDiff(
69+
scope=TestDiffScope.RETURN_VALUE,
70+
original_value=repr(original_test_result.return_value),
71+
candidate_value=repr(cdd_test_result.return_value),
72+
test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
73+
candidate_pytest_error=cdd_pytest_error,
74+
original_pass=original_test_result.did_pass,
75+
candidate_pass=cdd_test_result.did_pass,
76+
original_pytest_error=original_pytest_error,
77+
)
78+
)
8079

8180
try:
8281
logger.debug(
@@ -92,21 +91,37 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
9291
elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
9392
original_test_result.stdout, cdd_test_result.stdout
9493
):
95-
test_diff.scope = TestDiffScope.STDOUT
96-
test_diff.original_value = str(original_test_result.stdout)
97-
test_diff.candidate_value = str(cdd_test_result.stdout)
98-
test_diffs.append(test_diff)
94+
test_diffs.append(
95+
TestDiff(
96+
scope=TestDiffScope.STDOUT,
97+
original_value=str(original_test_result.stdout),
98+
candidate_value=str(cdd_test_result.stdout),
99+
test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
100+
candidate_pytest_error=cdd_pytest_error,
101+
original_pass=original_test_result.did_pass,
102+
candidate_pass=cdd_test_result.did_pass,
103+
original_pytest_error=original_pytest_error,
104+
)
105+
)
99106

100107
elif original_test_result.test_type in {
101108
TestType.EXISTING_UNIT_TEST,
102109
TestType.CONCOLIC_COVERAGE_TEST,
103110
TestType.GENERATED_REGRESSION,
104111
TestType.REPLAY_TEST,
105112
} and (cdd_test_result.did_pass != original_test_result.did_pass):
106-
test_diff.scope = TestDiffScope.DID_PASS
107-
test_diff.original_value = str(original_test_result.did_pass)
108-
test_diff.candidate_value = str(cdd_test_result.did_pass)
109-
test_diffs.append(test_diff)
113+
test_diffs.append(
114+
TestDiff(
115+
scope=TestDiffScope.DID_PASS,
116+
original_value=str(original_test_result.did_pass),
117+
candidate_value=str(cdd_test_result.did_pass),
118+
test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
119+
candidate_pytest_error=cdd_pytest_error,
120+
original_pass=original_test_result.did_pass,
121+
candidate_pass=cdd_test_result.did_pass,
122+
original_pytest_error=original_pytest_error,
123+
)
124+
)
110125

111126
sys.setrecursionlimit(original_recursion_limit)
112127
if did_all_timeout:

0 commit comments

Comments
 (0)