Skip to content

Commit 872ec28

Browse files
Merge pull request #962 from codeflash-ai/limit-refined-candidates
[Enhancement] Use weighted ranking to cap refinement candidates (CF-931)
2 parents 3f416af + 0a5649f commit 872ec28

File tree

5 files changed

+132
-58
lines changed

5 files changed

+132
-58
lines changed

codeflash/api/aiservice.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,20 +248,18 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
248248
"original_source_code": opt.original_source_code,
249249
"read_only_dependency_code": opt.read_only_dependency_code,
250250
"original_line_profiler_results": opt.original_line_profiler_results,
251-
"original_code_runtime": opt.original_code_runtime,
251+
"original_code_runtime": humanize_runtime(opt.original_code_runtime),
252252
"optimized_source_code": opt.optimized_source_code,
253253
"optimized_explanation": opt.optimized_explanation,
254254
"optimized_line_profiler_results": opt.optimized_line_profiler_results,
255-
"optimized_code_runtime": opt.optimized_code_runtime,
255+
"optimized_code_runtime": humanize_runtime(opt.optimized_code_runtime),
256256
"speedup": opt.speedup,
257257
"trace_id": opt.trace_id,
258258
"function_references": opt.function_references,
259259
"python_version": platform.python_version(),
260260
}
261261
for opt in request
262262
]
263-
logger.debug(f"Refining {len(request)} optimizations…")
264-
console.rule()
265263
try:
266264
response = self.make_ai_service_request("/refinement", payload=payload, timeout=120)
267265
except requests.exceptions.RequestException as e:
@@ -271,8 +269,6 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
271269

272270
if response.status_code == 200:
273271
refined_optimizations = response.json()["refinements"]
274-
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
275-
console.rule()
276272

277273
refinements = self._get_valid_candidates(refined_optimizations)
278274
return [

codeflash/code_utils/code_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,63 @@ def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tof
4545
return "".join(diff)
4646

4747

48+
def choose_weights(**importance: float) -> list[float]:
49+
"""Choose normalized weights from relative importance values.
50+
51+
Example:
52+
choose_weights(runtime=3, diff=1)
53+
-> [0.75, 0.25]
54+
55+
Args:
56+
**importance: keyword args of metric=importance (relative numbers).
57+
58+
Returns:
59+
A list of weights in the same order as the arguments.
60+
61+
"""
62+
total = sum(importance.values())
63+
if total == 0:
64+
raise ValueError("At least one importance value must be > 0")
65+
66+
return [v / total for v in importance.values()]
67+
68+
69+
def normalize_by_max(values: list[float]) -> list[float]:
70+
mx = max(values)
71+
if mx == 0:
72+
return [0.0] * len(values)
73+
return [v / mx for v in values]
74+
75+
76+
def create_score_dictionary_from_metrics(weights: list[float], *metrics: list[float]) -> dict[int, int]:
77+
"""Combine multiple metrics into a single weighted score dictionary.
78+
79+
Each metric is a list of values (smaller = better).
80+
The total score for each index is the weighted sum of its values
81+
across all metrics:
82+
83+
score[index] = Σ (value * weight)
84+
85+
Args:
86+
weights: A list of weights, one per metric. Larger weight = more influence.
87+
*metrics: Lists of values (one list per metric, aligned by index).
88+
89+
Returns:
90+
A dictionary mapping each index to its combined weighted score.
91+
92+
"""
93+
if len(weights) != len(metrics):
94+
raise ValueError("Number of weights must match number of metrics")
95+
96+
combined: dict[int, float] = {}
97+
98+
for weight, metric in zip(weights, metrics):
99+
for idx, value in enumerate(metric):
100+
combined[idx] = combined.get(idx, 0) + value * weight
101+
102+
return combined
103+
104+
48105
def diff_length(a: str, b: str) -> int:
49106
"""Compute the length (in characters) of the unified diff between two strings.
50107

codeflash/code_utils/config_consts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
1515
N_CANDIDATES_LP = 6
1616

17+
# Refinement
18+
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations
19+
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2
20+
TOP_N_REFINEMENTS = 0.45 # top 45% of valid optimizations (based on the weighted score) are refined
21+
1722
# LSP-specific
1823
N_CANDIDATES_LSP = 3
1924
N_TESTS_TO_GENERATE_LSP = 2

codeflash/models/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ class AIServiceRefinerRequest:
3636
optimization_id: str
3737
original_source_code: str
3838
read_only_dependency_code: str
39-
original_code_runtime: str
39+
original_code_runtime: int
4040
optimized_source_code: str
4141
optimized_explanation: str
42-
optimized_code_runtime: str
42+
optimized_code_runtime: int
4343
speedup: str
4444
trace_id: str
4545
original_line_profiler_results: str

codeflash/optimization/function_optimizer.py

Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,16 @@
2929
replace_function_definitions_in_module,
3030
)
3131
from codeflash.code_utils.code_utils import (
32+
choose_weights,
3233
cleanup_paths,
3334
create_rank_dictionary_compact,
35+
create_score_dictionary_from_metrics,
3436
diff_length,
3537
extract_unique_errors,
3638
file_name_from_test_module_name,
3739
get_run_tmp_file,
3840
module_name_from_file_path,
41+
normalize_by_max,
3942
restore_conftest,
4043
unified_diff_strings,
4144
)
@@ -45,7 +48,10 @@
4548
N_CANDIDATES_EFFECTIVE,
4649
N_CANDIDATES_LP_EFFECTIVE,
4750
N_TESTS_TO_GENERATE_EFFECTIVE,
51+
REFINE_ALL_THRESHOLD,
52+
REFINED_CANDIDATE_RANKING_WEIGHTS,
4853
REPEAT_OPTIMIZATION_PROBABILITY,
54+
TOP_N_REFINEMENTS,
4955
TOTAL_LOOPING_TIME_EFFECTIVE,
5056
)
5157
from codeflash.code_utils.deduplicate_code import normalize_code
@@ -124,19 +130,23 @@ def __init__(
124130
self,
125131
initial_candidates: list,
126132
future_line_profile_results: concurrent.futures.Future,
127-
future_all_refinements: list,
133+
all_refinements_data: list[AIServiceRefinerRequest],
134+
ai_service_client: AiServiceClient,
135+
executor: concurrent.futures.ThreadPoolExecutor,
128136
) -> None:
129137
self.candidate_queue = queue.Queue()
130138
self.line_profiler_done = False
131139
self.refinement_done = False
132140
self.candidate_len = len(initial_candidates)
141+
self.ai_service_client = ai_service_client
142+
self.executor = executor
133143

134144
# Initialize queue with initial candidates
135145
for candidate in initial_candidates:
136146
self.candidate_queue.put(candidate)
137147

138148
self.future_line_profile_results = future_line_profile_results
139-
self.future_all_refinements = future_all_refinements
149+
self.all_refinements_data = all_refinements_data
140150

141151
def get_next_candidate(self) -> OptimizedCandidate | None:
142152
"""Get the next candidate from the queue, handling async results as needed."""
@@ -168,15 +178,45 @@ def _process_line_profiler_results(self) -> OptimizedCandidate | None:
168178

169179
return self.get_next_candidate()
170180

181+
def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future:
182+
return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request)
183+
171184
def _process_refinement_results(self) -> OptimizedCandidate | None:
172-
"""Process refinement results and add to queue."""
173-
if self.future_all_refinements:
185+
"""Process refinement results and add to queue. We generate a weighted ranking based on the runtime and diff lines and select the best (round of 45%) of valid optimizations to be refined."""
186+
future_refinements: list[concurrent.futures.Future] = []
187+
188+
if len(self.all_refinements_data) <= REFINE_ALL_THRESHOLD:
189+
for data in self.all_refinements_data:
190+
future_refinements.append(self.refine_optimizations([data])) # noqa: PERF401
191+
else:
192+
diff_lens_list = []
193+
runtimes_list = []
194+
for c in self.all_refinements_data:
195+
diff_lens_list.append(diff_length(c.original_source_code, c.optimized_source_code))
196+
runtimes_list.append(c.optimized_code_runtime)
197+
198+
runtime_w, diff_w = REFINED_CANDIDATE_RANKING_WEIGHTS
199+
weights = choose_weights(runtime=runtime_w, diff=diff_w)
200+
201+
runtime_norm = normalize_by_max(runtimes_list)
202+
diffs_norm = normalize_by_max(diff_lens_list)
203+
# the lower the better
204+
score_dict = create_score_dictionary_from_metrics(weights, runtime_norm, diffs_norm)
205+
top_n_candidates = int((TOP_N_REFINEMENTS * len(runtimes_list)) + 0.5)
206+
top_indecies = sorted(score_dict, key=score_dict.get)[:top_n_candidates]
207+
208+
for idx in top_indecies:
209+
data = self.all_refinements_data[idx]
210+
future_refinements.append(self.refine_optimizations([data]))
211+
212+
if future_refinements:
174213
logger.info("loading|Refining generated code for improved quality and performance...")
175-
concurrent.futures.wait(self.future_all_refinements)
214+
215+
concurrent.futures.wait(future_refinements)
176216
refinement_response = []
177217

178-
for future_refinement in self.future_all_refinements:
179-
possible_refinement = future_refinement.result()
218+
for f in future_refinements:
219+
possible_refinement = f.result()
180220
if len(possible_refinement) > 0:
181221
refinement_response.append(possible_refinement[0])
182222

@@ -686,15 +726,14 @@ def process_single_candidate(
686726
original_helper_code: dict[Path, str],
687727
file_path_to_helper_classes: dict[Path, set[str]],
688728
eval_ctx: CandidateEvaluationContext,
689-
future_all_refinements: list[concurrent.futures.Future],
690-
ai_service_client: AiServiceClient,
729+
all_refinements_data: list[AIServiceRefinerRequest],
691730
exp_type: str,
692731
function_references: str,
693732
) -> BestOptimization | None:
694733
"""Process a single optimization candidate.
695734
696735
Returns the BestOptimization if the candidate is successful, None otherwise.
697-
Updates eval_ctx with results and may append to future_all_refinements.
736+
Updates eval_ctx with results and may append to all_refinements_data.
698737
"""
699738
# Cleanup temp files
700739
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
@@ -789,14 +828,19 @@ def process_single_candidate(
789828

790829
# Queue refinement for non-refined candidates
791830
if not candidate.optimization_id.endswith("refi"):
792-
future_all_refinements.append(
793-
self.refine_optimizations(
794-
valid_optimizations=[best_optimization],
795-
original_code_baseline=original_code_baseline,
796-
code_context=code_context,
831+
all_refinements_data.append(
832+
AIServiceRefinerRequest(
833+
optimization_id=best_optimization.candidate.optimization_id,
834+
original_source_code=code_context.read_writable_code.markdown,
835+
read_only_dependency_code=code_context.read_only_context_code,
836+
original_code_runtime=original_code_baseline.runtime,
837+
optimized_source_code=best_optimization.candidate.source_code.markdown,
838+
optimized_explanation=best_optimization.candidate.explanation,
839+
optimized_code_runtime=best_optimization.runtime,
840+
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime) * 100)}%",
797841
trace_id=self.get_trace_id(exp_type),
798-
ai_service_client=ai_service_client,
799-
executor=self.executor,
842+
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
843+
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
800844
function_references=function_references,
801845
)
802846
)
@@ -832,7 +876,7 @@ def determine_best_candidate(
832876

833877
# Initialize evaluation context and async tasks
834878
eval_ctx = CandidateEvaluationContext()
835-
future_all_refinements: list[concurrent.futures.Future] = []
879+
all_refinements_data: list[AIServiceRefinerRequest] = []
836880
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
837881
assert ai_service_client is not None, "AI service client must be set for optimization"
838882

@@ -850,7 +894,9 @@ def determine_best_candidate(
850894
else None,
851895
)
852896

853-
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
897+
processor = CandidateProcessor(
898+
candidates, future_line_profile_results, all_refinements_data, self.aiservice_client, self.executor
899+
)
854900
candidate_index = 0
855901

856902
# Process candidates using queue-based approach
@@ -871,8 +917,7 @@ def determine_best_candidate(
871917
original_helper_code=original_helper_code,
872918
file_path_to_helper_classes=file_path_to_helper_classes,
873919
eval_ctx=eval_ctx,
874-
future_all_refinements=future_all_refinements,
875-
ai_service_client=ai_service_client,
920+
all_refinements_data=all_refinements_data,
876921
exp_type=exp_type,
877922
function_references=function_references,
878923
)
@@ -905,35 +950,6 @@ def determine_best_candidate(
905950

906951
return best_optimization
907952

908-
def refine_optimizations(
909-
self,
910-
valid_optimizations: list[BestOptimization],
911-
original_code_baseline: OriginalCodeBaseline,
912-
code_context: CodeOptimizationContext,
913-
trace_id: str,
914-
ai_service_client: AiServiceClient,
915-
executor: concurrent.futures.ThreadPoolExecutor,
916-
function_references: str | None = None,
917-
) -> concurrent.futures.Future:
918-
request = [
919-
AIServiceRefinerRequest(
920-
optimization_id=opt.candidate.optimization_id,
921-
original_source_code=code_context.read_writable_code.markdown,
922-
read_only_dependency_code=code_context.read_only_context_code,
923-
original_code_runtime=humanize_runtime(original_code_baseline.runtime),
924-
optimized_source_code=opt.candidate.source_code.markdown,
925-
optimized_explanation=opt.candidate.explanation,
926-
optimized_code_runtime=humanize_runtime(opt.runtime),
927-
speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%",
928-
trace_id=trace_id,
929-
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
930-
optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
931-
function_references=function_references,
932-
)
933-
for opt in valid_optimizations
934-
]
935-
return executor.submit(ai_service_client.optimize_python_code_refinement, request=request)
936-
937953
def log_successful_optimization(
938954
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
939955
) -> None:

0 commit comments

Comments
 (0)