Skip to content

Commit d165a15

Browse files
committed
Merge remote-tracking branch 'origin/main' into feat/feedback-loop-for-unmatched-test-results
2 parents 46522d8 + 872ec28 commit d165a15

File tree

5 files changed

+132
-76
lines changed

5 files changed

+132
-76
lines changed

codeflash/api/aiservice.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -259,20 +259,18 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
259259
"original_source_code": opt.original_source_code,
260260
"read_only_dependency_code": opt.read_only_dependency_code,
261261
"original_line_profiler_results": opt.original_line_profiler_results,
262-
"original_code_runtime": opt.original_code_runtime,
262+
"original_code_runtime": humanize_runtime(opt.original_code_runtime),
263263
"optimized_source_code": opt.optimized_source_code,
264264
"optimized_explanation": opt.optimized_explanation,
265265
"optimized_line_profiler_results": opt.optimized_line_profiler_results,
266-
"optimized_code_runtime": opt.optimized_code_runtime,
266+
"optimized_code_runtime": humanize_runtime(opt.optimized_code_runtime),
267267
"speedup": opt.speedup,
268268
"trace_id": opt.trace_id,
269269
"function_references": opt.function_references,
270270
"python_version": platform.python_version(),
271271
}
272272
for opt in request
273273
]
274-
logger.debug(f"Refining {len(request)} optimizations…")
275-
console.rule()
276274
try:
277275
response = self.make_ai_service_request("/refinement", payload=payload, timeout=120)
278276
except requests.exceptions.RequestException as e:
@@ -282,8 +280,6 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
282280

283281
if response.status_code == 200:
284282
refined_optimizations = response.json()["refinements"]
285-
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
286-
console.rule()
287283

288284
return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE)
289285

codeflash/code_utils/code_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,63 @@ def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tof
4545
return "".join(diff)
4646

4747

48+
def choose_weights(**importance: float) -> list[float]:
49+
"""Choose normalized weights from relative importance values.
50+
51+
Example:
52+
choose_weights(runtime=3, diff=1)
53+
-> [0.75, 0.25]
54+
55+
Args:
56+
**importance: keyword args of metric=importance (relative numbers).
57+
58+
Returns:
59+
A list of weights in the same order as the arguments.
60+
61+
"""
62+
total = sum(importance.values())
63+
if total == 0:
64+
raise ValueError("At least one importance value must be > 0")
65+
66+
return [v / total for v in importance.values()]
67+
68+
69+
def normalize_by_max(values: list[float]) -> list[float]:
70+
mx = max(values)
71+
if mx == 0:
72+
return [0.0] * len(values)
73+
return [v / mx for v in values]
74+
75+
76+
def create_score_dictionary_from_metrics(weights: list[float], *metrics: list[float]) -> dict[int, float]:
77+
"""Combine multiple metrics into a single weighted score dictionary.
78+
79+
Each metric is a list of values (smaller = better).
80+
The total score for each index is the weighted sum of its values
81+
across all metrics:
82+
83+
score[index] = Σ (value * weight)
84+
85+
Args:
86+
weights: A list of weights, one per metric. Larger weight = more influence.
87+
*metrics: Lists of values (one list per metric, aligned by index).
88+
89+
Returns:
90+
A dictionary mapping each index to its combined weighted score.
91+
92+
"""
93+
if len(weights) != len(metrics):
94+
raise ValueError("Number of weights must match number of metrics")
95+
96+
combined: dict[int, float] = {}
97+
98+
for weight, metric in zip(weights, metrics):
99+
for idx, value in enumerate(metric):
100+
combined[idx] = combined.get(idx, 0) + value * weight
101+
102+
return combined
103+
104+
48105
def diff_length(a: str, b: str) -> int:
49106
"""Compute the length (in characters) of the unified diff between two strings.
50107

codeflash/code_utils/config_consts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
1515
N_CANDIDATES_LP = 6
1616

17+
# Refinement
18+
REFINE_ALL_THRESHOLD = 2 # when the number of valid optimizations is 2 or fewer, refine all of them
19+
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2
20+
TOP_N_REFINEMENTS = 0.45 # top 45% of valid optimizations (based on the weighted score) are refined
21+
1722
# LSP-specific
1823
N_CANDIDATES_LSP = 3
1924
N_TESTS_TO_GENERATE_LSP = 2

codeflash/models/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ class AIServiceRefinerRequest:
3737
optimization_id: str
3838
original_source_code: str
3939
read_only_dependency_code: str
40-
original_code_runtime: str
40+
original_code_runtime: int
4141
optimized_source_code: str
4242
optimized_explanation: str
43-
optimized_code_runtime: str
43+
optimized_code_runtime: int
4444
speedup: str
4545
trace_id: str
4646
original_line_profiler_results: str

codeflash/optimization/function_optimizer.py

Lines changed: 66 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@
2929
replace_function_definitions_in_module,
3030
)
3131
from codeflash.code_utils.code_utils import (
32+
choose_weights,
3233
cleanup_paths,
3334
create_rank_dictionary_compact,
35+
create_score_dictionary_from_metrics,
3436
diff_length,
3537
extract_unique_errors,
3638
file_name_from_test_module_name,
3739
get_run_tmp_file,
3840
module_name_from_file_path,
41+
normalize_by_max,
3942
restore_conftest,
4043
unified_diff_strings,
4144
)
@@ -46,8 +49,11 @@
4649
N_CANDIDATES_EFFECTIVE,
4750
N_CANDIDATES_LP_EFFECTIVE,
4851
N_TESTS_TO_GENERATE_EFFECTIVE,
52+
REFINE_ALL_THRESHOLD,
53+
REFINED_CANDIDATE_RANKING_WEIGHTS,
4954
REPAIR_UNMATCHED_PERCENTAGE_LIMIT,
5055
REPEAT_OPTIMIZATION_PROBABILITY,
56+
TOP_N_REFINEMENTS,
5157
TOTAL_LOOPING_TIME_EFFECTIVE,
5258
)
5359
from codeflash.code_utils.deduplicate_code import normalize_code
@@ -129,20 +135,24 @@ def __init__(
129135
self,
130136
initial_candidates: list,
131137
future_line_profile_results: concurrent.futures.Future,
132-
future_all_refinements: list[concurrent.futures.Future],
138+
all_refinements_data: list[AIServiceRefinerRequest],
139+
ai_service_client: AiServiceClient,
140+
executor: concurrent.futures.ThreadPoolExecutor,
133141
future_all_code_repair: list[concurrent.futures.Future],
134142
) -> None:
135143
self.candidate_queue = queue.Queue()
136144
self.line_profiler_done = False
137145
self.refinement_done = False
138146
self.candidate_len = len(initial_candidates)
147+
self.ai_service_client = ai_service_client
148+
self.executor = executor
139149

140150
# Initialize queue with initial candidates
141151
for candidate in initial_candidates:
142152
self.candidate_queue.put(candidate)
143153

144154
self.future_line_profile_results = future_line_profile_results
145-
self.future_all_refinements = future_all_refinements
155+
self.all_refinements_data = all_refinements_data
146156
self.future_all_code_repair = future_all_code_repair
147157

148158
def get_next_candidate(self) -> OptimizedCandidate | None:
@@ -177,15 +187,45 @@ def _process_line_profiler_results(self) -> OptimizedCandidate | None:
177187

178188
return self.get_next_candidate()
179189

190+
def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future:
191+
return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request)
192+
180193
def _process_refinement_results(self) -> OptimizedCandidate | None:
181-
"""Process refinement results and add to queue."""
182-
if self.future_all_refinements:
194+
"""Process refinement results and add to queue. We generate a weighted ranking based on the runtime and diff lines and select the best (round of 45%) of valid optimizations to be refined."""
195+
future_refinements: list[concurrent.futures.Future] = []
196+
197+
if len(self.all_refinements_data) <= REFINE_ALL_THRESHOLD:
198+
for data in self.all_refinements_data:
199+
future_refinements.append(self.refine_optimizations([data])) # noqa: PERF401
200+
else:
201+
diff_lens_list = []
202+
runtimes_list = []
203+
for c in self.all_refinements_data:
204+
diff_lens_list.append(diff_length(c.original_source_code, c.optimized_source_code))
205+
runtimes_list.append(c.optimized_code_runtime)
206+
207+
runtime_w, diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
208+
weights = choose_weights(runtime=runtime_w, diff=diff_w)
209+
210+
runtime_norm = normalize_by_max(runtimes_list)
211+
diffs_norm = normalize_by_max(diff_lens_list)
212+
# the lower the better
213+
score_dict = create_score_dictionary_from_metrics(weights, runtime_norm, diffs_norm)
214+
top_n_candidates = int((TOP_N_REFINEMENTS * len(runtimes_list)) + 0.5)
215+
top_indices = sorted(score_dict, key=score_dict.get)[:top_n_candidates]
216+
217+
for idx in top_indices:
218+
data = self.all_refinements_data[idx]
219+
future_refinements.append(self.refine_optimizations([data]))
220+
221+
if future_refinements:
183222
logger.info("loading|Refining generated code for improved quality and performance...")
184-
concurrent.futures.wait(self.future_all_refinements)
223+
224+
concurrent.futures.wait(future_refinements)
185225
refinement_response = []
186226

187-
for future_refinement in self.future_all_refinements:
188-
possible_refinement = future_refinement.result()
227+
for f in future_refinements:
228+
possible_refinement = f.result()
189229
if len(possible_refinement) > 0:
190230
refinement_response.append(possible_refinement[0])
191231

@@ -197,7 +237,6 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
197237
logger.info(
198238
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
199239
)
200-
self.future_all_refinements = []
201240
self.refinement_done = True
202241

203242
return self.get_next_candidate()
@@ -278,7 +317,6 @@ def __init__(
278317
max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
279318
)
280319
self.optimization_review = ""
281-
self.future_all_refinements: list[concurrent.futures.Future] = []
282320
self.future_all_code_repair: list[concurrent.futures.Future] = []
283321
self.repair_counter = 0 # track how many repairs we did for each function
284322

@@ -724,14 +762,15 @@ def process_single_candidate(
724762
original_helper_code: dict[Path, str],
725763
file_path_to_helper_classes: dict[Path, set[str]],
726764
eval_ctx: CandidateEvaluationContext,
765+
all_refinements_data: list[AIServiceRefinerRequest],
727766
ai_service_client: AiServiceClient,
728767
exp_type: str,
729768
function_references: str,
730769
) -> BestOptimization | None:
731770
"""Process a single optimization candidate.
732771
733772
Returns the BestOptimization if the candidate is successful, None otherwise.
734-
Updates eval_ctx with results and may append to future_all_refinements.
773+
Updates eval_ctx with results and may append to all_refinements_data.
735774
"""
736775
# Cleanup temp files
737776
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
@@ -829,17 +868,23 @@ def process_single_candidate(
829868

830869
# Queue refinement for non-refined candidates
831870
if candidate.source != OptimizedCandidateSource.REFINE:
832-
self.future_all_refinements.append(
833-
self.refine_optimizations(
834-
valid_optimizations=[best_optimization],
835-
original_code_baseline=original_code_baseline,
836-
code_context=code_context,
871+
all_refinements_data.append(
872+
AIServiceRefinerRequest(
873+
optimization_id=best_optimization.candidate.optimization_id,
874+
original_source_code=code_context.read_writable_code.markdown,
875+
read_only_dependency_code=code_context.read_only_context_code,
876+
original_code_runtime=original_code_baseline.runtime,
877+
optimized_source_code=best_optimization.candidate.source_code.markdown,
878+
optimized_explanation=best_optimization.candidate.explanation,
879+
optimized_code_runtime=best_optimization.runtime,
880+
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
837881
trace_id=self.get_trace_id(exp_type),
838-
ai_service_client=ai_service_client,
839-
executor=self.executor,
882+
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
883+
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
840884
function_references=function_references,
841885
)
842-
)
843888

844889
# Display runtime information
845890
if is_LSP_enabled():
@@ -872,7 +917,7 @@ def determine_best_candidate(
872917

873918
# Initialize evaluation context and async tasks
874919
eval_ctx = CandidateEvaluationContext()
875-
self.future_all_refinements.clear()
920+
all_refinements_data: list[AIServiceRefinerRequest] = []
876921
self.future_all_code_repair.clear()
877922
self.repair_counter = 0
878923

@@ -894,7 +939,7 @@ def determine_best_candidate(
894939
)
895940

896941
processor = CandidateProcessor(
897-
candidates, future_line_profile_results, self.future_all_refinements, self.future_all_code_repair
942+
candidates, future_line_profile_results, all_refinements_data, self.future_all_code_repair, self.aiservice_client, self.executor
898943
)
899944
candidate_index = 0
900945

@@ -916,6 +961,7 @@ def determine_best_candidate(
916961
original_helper_code=original_helper_code,
917962
file_path_to_helper_classes=file_path_to_helper_classes,
918963
eval_ctx=eval_ctx,
964+
all_refinements_data=all_refinements_data,
919965
ai_service_client=ai_service_client,
920966
exp_type=exp_type,
921967
function_references=function_references,
@@ -949,54 +995,6 @@ def determine_best_candidate(
949995

950996
return best_optimization
951997

952-
def refine_optimizations(
953-
self,
954-
valid_optimizations: list[BestOptimization],
955-
original_code_baseline: OriginalCodeBaseline,
956-
code_context: CodeOptimizationContext,
957-
trace_id: str,
958-
ai_service_client: AiServiceClient,
959-
executor: concurrent.futures.ThreadPoolExecutor,
960-
function_references: str | None = None,
961-
) -> concurrent.futures.Future:
962-
request = [
963-
AIServiceRefinerRequest(
964-
optimization_id=opt.candidate.optimization_id,
965-
original_source_code=code_context.read_writable_code.markdown,
966-
read_only_dependency_code=code_context.read_only_context_code,
967-
original_code_runtime=humanize_runtime(original_code_baseline.runtime),
968-
optimized_source_code=opt.candidate.source_code.markdown,
969-
optimized_explanation=opt.candidate.explanation,
970-
optimized_code_runtime=humanize_runtime(opt.runtime),
971-
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%",
972-
trace_id=trace_id,
973-
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
974-
optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
975-
function_references=function_references,
976-
)
977-
for opt in valid_optimizations
978-
]
979-
return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
980-
981-
def repair_optimization(
982-
self,
983-
original_source_code: str,
984-
modified_source_code: str,
985-
test_diffs: list[TestDiff],
986-
trace_id: str,
987-
optimization_id: str,
988-
ai_service_client: AiServiceClient,
989-
executor: concurrent.futures.ThreadPoolExecutor,
990-
) -> concurrent.futures.Future[OptimizedCandidate | None]:
991-
request = AIServiceCodeRepairRequest(
992-
optimization_id=optimization_id,
993-
original_source_code=original_source_code,
994-
modified_source_code=modified_source_code,
995-
test_diffs=test_diffs,
996-
trace_id=trace_id,
997-
)
998-
return executor.submit(ai_service_client.code_repair, request=request)
999-
1000998
def log_successful_optimization(
1001999
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
10021000
) -> None:

0 commit comments

Comments
 (0)