Commit b7b82ee

Split test generation and optimizations generation into separate methods (#957)

1 parent 786d3f7
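In short: FunctionOptimizer's combined generate_tests_and_optimizations is split into two methods, generate_tests and generate_optimizations, and both callers (optimize_function and the LSP's sync_perform_optimization) now submit the two halves to the optimizer's thread pool and wait on them together, so test generation and optimization generation still run concurrently but can be checked and unwrapped independently.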

File tree

3 files changed: +135 −106 lines

  codeflash/api/aiservice.py
  codeflash/lsp/features/perform_optimization.py
  codeflash/optimization/function_optimizer.py

codeflash/api/aiservice.py

Lines changed: 0 additions & 1 deletion

@@ -153,7 +153,6 @@ def optimize_python_code( # noqa: D417

         if response.status_code == 200:
             optimizations_json = response.json()["optimizations"]
-            logger.info(f"!lsp|Generated {len(optimizations_json)} candidate optimizations.")
             console.rule()
             end_time = time.perf_counter()
             logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
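(The candidate-count log removed here is not lost: the new generate_optimizations method in function_optimizer.py below logs `!lsp|Generated '{len(candidates)}' candidate optimizations.`, so the count is reported once rather than in both the API client and the optimizer.)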

codeflash/lsp/features/perform_optimization.py

Lines changed: 32 additions & 6 deletions

@@ -1,10 +1,12 @@
 from __future__ import annotations

+import concurrent.futures
 import contextlib
+import contextvars
 import os
 from typing import TYPE_CHECKING

-from codeflash.cli_cmds.console import code_print
+from codeflash.cli_cmds.console import code_print, logger
 from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
 from codeflash.either import is_successful

@@ -44,24 +46,48 @@ def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: thr
     function_optimizer.function_to_tests = function_to_tests

     abort_if_cancelled(cancel_event)
-    test_setup_result = function_optimizer.generate_and_instrument_tests(
-        code_context, should_run_experiment=should_run_experiment
-    )
+
+    ctx_tests = contextvars.copy_context()
+    ctx_opts = contextvars.copy_context()
+
+    def run_generate_tests():  # noqa: ANN202
+        return function_optimizer.generate_and_instrument_tests(code_context)
+
+    def run_generate_optimizations():  # noqa: ANN202
+        return function_optimizer.generate_optimizations(
+            read_writable_code=code_context.read_writable_code,
+            read_only_context_code=code_context.read_only_context_code,
+            run_experiment=should_run_experiment,
+        )
+
+    future_tests = function_optimizer.executor.submit(ctx_tests.run, run_generate_tests)
+    future_optimizations = function_optimizer.executor.submit(ctx_opts.run, run_generate_optimizations)
+
+    logger.info(f"loading|Generating new tests and optimizations for function '{params.functionName}'...")
+    concurrent.futures.wait([future_tests, future_optimizations])
+
+    test_setup_result = future_tests.result()
+    optimization_result = future_optimizations.result()
+
     abort_if_cancelled(cancel_event)
     if not is_successful(test_setup_result):
         return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
+    if not is_successful(optimization_result):
+        return {"functionName": params.functionName, "status": "error", "message": optimization_result.failure()}
+
     (
         generated_tests,
         function_to_concolic_tests,
         concolic_test_str,
-        optimizations_set,
         generated_test_paths,
         generated_perf_test_paths,
         instrumented_unittests_created_for_function,
         original_conftest_content,
-        function_references,
     ) = test_setup_result.unwrap()

+    optimizations_set, function_references = optimization_result.unwrap()
+
+    logger.info(f"Generated '{len(optimizations_set.control)}' candidate optimizations.")
     baseline_setup_result = function_optimizer.setup_and_establish_baseline(
         code_context=code_context,
         original_helper_code=original_helper_code,
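The LSP path wraps each job in contextvars.copy_context().run(...) so that context-local state set on the request-handling thread stays visible inside the pool's worker threads; ContextVar values do not propagate across threads on their own. A minimal standalone sketch of the pattern (request_id, work, and main are illustrative names, not from the codebase):

import concurrent.futures
import contextvars

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")

def work() -> str:
    # Reads the ContextVar that was set on the submitting thread.
    return f"handled {request_id.get()}"

def main() -> None:
    request_id.set("req-42")
    # Snapshot the current context so the worker thread sees the same values.
    ctx = contextvars.copy_context()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Without ctx.run, work() would raise LookupError in the worker thread,
        # because ContextVar values set here are invisible to other threads.
        future = executor.submit(ctx.run, work)
        print(future.result())  # -> handled req-42

if __name__ == "__main__":
    main()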

codeflash/optimization/function_optimizer.py

Lines changed: 103 additions & 99 deletions

@@ -238,9 +238,6 @@ def __init__(
         self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
         self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
         self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
-        self.generate_and_instrument_tests_results: (
-            tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet] | None
-        ) = None
         n_tests = N_TESTS_TO_GENERATE_EFFECTIVE
         self.executor = concurrent.futures.ThreadPoolExecutor(
             max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
@@ -275,21 +272,20 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P
         return Success((should_run_experiment, code_context, original_helper_code))

     def generate_and_instrument_tests(
-        self, code_context: CodeOptimizationContext, *, should_run_experiment: bool
+        self, code_context: CodeOptimizationContext
     ) -> Result[
         tuple[
             GeneratedTestsList,
             dict[str, set[FunctionCalledInTest]],
             str,
-            OptimizationSet,
             list[Path],
             list[Path],
             set[Path],
             dict | None,
-            str,
-        ]
+        ],
+        str,
     ]:
-        """Generate and instrument tests, returning all necessary data for optimization."""
+        """Generate and instrument tests for the function."""
         n_tests = N_TESTS_TO_GENERATE_EFFECTIVE
         generated_test_paths = [
             get_test_file_path(
@@ -304,34 +300,17 @@ def generate_and_instrument_tests(
             for test_index in range(n_tests)
         ]

-        with progress_bar(
-            f"Generating new tests and optimizations for function '{self.function_to_optimize.function_name}'",
-            transient=True,
-            revert_to_print=bool(get_pr_number()),
-        ):
-            generated_results = self.generate_tests_and_optimizations(
-                testgen_context=code_context.testgen_context,
-                read_writable_code=code_context.read_writable_code,
-                read_only_context_code=code_context.read_only_context_code,
-                helper_functions=code_context.helper_functions,
-                generated_test_paths=generated_test_paths,
-                generated_perf_test_paths=generated_perf_test_paths,
-                run_experiment=should_run_experiment,
-            )
+        test_results = self.generate_tests(
+            testgen_context=code_context.testgen_context,
+            helper_functions=code_context.helper_functions,
+            generated_test_paths=generated_test_paths,
+            generated_perf_test_paths=generated_perf_test_paths,
+        )

-        if not is_successful(generated_results):
-            return Failure(generated_results.failure())
+        if not is_successful(test_results):
+            return Failure(test_results.failure())

-        generated_tests: GeneratedTestsList
-        optimizations_set: OptimizationSet
-        (
-            count_tests,
-            generated_tests,
-            function_to_concolic_tests,
-            concolic_test_str,
-            optimizations_set,
-            function_references,
-        ) = generated_results.unwrap()
+        count_tests, generated_tests, function_to_concolic_tests, concolic_test_str = test_results.unwrap()

         for i, generated_test in enumerate(generated_tests.generated_tests):
             with generated_test.behavior_file_path.open("w", encoding="utf8") as f:
@@ -372,12 +351,10 @@ def generate_and_instrument_tests(
                 generated_tests,
                 function_to_concolic_tests,
                 concolic_test_str,
-                optimizations_set,
                 generated_test_paths,
                 generated_perf_test_paths,
                 instrumented_unittests_created_for_function,
                 original_conftest_content,
-                function_references,
             )
         )

@@ -395,24 +372,45 @@ def optimize_function(self) -> Result[BestOptimization, str]:
                 function_name=self.function_to_optimize.function_name,
             )

-        test_setup_result = self.generate_and_instrument_tests(  # also generates optimizations
-            code_context, should_run_experiment=should_run_experiment
-        )
+        with progress_bar(
+            f"Generating new tests and optimizations for function '{self.function_to_optimize.function_name}'",
+            transient=True,
+            revert_to_print=bool(get_pr_number()),
+        ):
+            console.rule()
+            # Generate tests and optimizations in parallel
+            future_tests = self.executor.submit(self.generate_and_instrument_tests, code_context)
+            future_optimizations = self.executor.submit(
+                self.generate_optimizations,
+                read_writable_code=code_context.read_writable_code,
+                read_only_context_code=code_context.read_only_context_code,
+                run_experiment=should_run_experiment,
+            )
+
+            concurrent.futures.wait([future_tests, future_optimizations])
+
+            test_setup_result = future_tests.result()
+            optimization_result = future_optimizations.result()
+            console.rule()
+
         if not is_successful(test_setup_result):
             return Failure(test_setup_result.failure())

+        if not is_successful(optimization_result):
+            return Failure(optimization_result.failure())
+
         (
             generated_tests,
             function_to_concolic_tests,
             concolic_test_str,
-            optimizations_set,
             generated_test_paths,
             generated_perf_test_paths,
             instrumented_unittests_created_for_function,
             original_conftest_content,
-            function_references,
         ) = test_setup_result.unwrap()

+        optimizations_set, function_references = optimization_result.unwrap()
+
         baseline_setup_result = self.setup_and_establish_baseline(
             code_context=code_context,
             original_helper_code=original_helper_code,
@@ -1109,28 +1107,78 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
         console.rule()
         return unique_instrumented_test_files

-    def generate_tests_and_optimizations(
+    def generate_tests(
         self,
         testgen_context: CodeStringsMarkdown,
-        read_writable_code: CodeStringsMarkdown,
-        read_only_context_code: str,
         helper_functions: list[FunctionSource],
         generated_test_paths: list[Path],
         generated_perf_test_paths: list[Path],
-        run_experiment: bool = False,  # noqa: FBT001, FBT002
-    ) -> Result[tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet], str, str]:
+    ) -> Result[tuple[int, GeneratedTestsList, dict[str, set[FunctionCalledInTest]], str], str]:
+        """Generate unit tests and concolic tests for the function."""
         n_tests = N_TESTS_TO_GENERATE_EFFECTIVE
         assert len(generated_test_paths) == n_tests
-        console.rule()
-        # Submit the test generation task as future
+
+        # Submit test generation tasks
         future_tests = self.submit_test_generation_tasks(
             self.executor,
             testgen_context.markdown,
             [definition.fully_qualified_name for definition in helper_functions],
             generated_test_paths,
             generated_perf_test_paths,
         )
+
+        future_concolic_tests = self.executor.submit(
+            generate_concolic_tests, self.test_cfg, self.args, self.function_to_optimize, self.function_to_optimize_ast
+        )
+
+        # Wait for test futures to complete
+        concurrent.futures.wait([*future_tests, future_concolic_tests])
+
+        # Process test generation results
+        tests: list[GeneratedTests] = []
+        for future in future_tests:
+            res = future.result()
+            if res:
+                (
+                    generated_test_source,
+                    instrumented_behavior_test_source,
+                    instrumented_perf_test_source,
+                    test_behavior_path,
+                    test_perf_path,
+                ) = res
+                tests.append(
+                    GeneratedTests(
+                        generated_original_test_source=generated_test_source,
+                        instrumented_behavior_test_source=instrumented_behavior_test_source,
+                        instrumented_perf_test_source=instrumented_perf_test_source,
+                        behavior_file_path=test_behavior_path,
+                        perf_file_path=test_perf_path,
+                    )
+                )
+
+        if not tests:
+            logger.warning(f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}")
+            return Failure(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}")
+
+        function_to_concolic_tests, concolic_test_str = future_concolic_tests.result()
+        count_tests = len(tests)
+        if concolic_test_str:
+            count_tests += 1
+
+        logger.info(f"!lsp|Generated '{count_tests}' tests for '{self.function_to_optimize.function_name}'")
+
+        generated_tests = GeneratedTestsList(generated_tests=tests)
+        return Success((count_tests, generated_tests, function_to_concolic_tests, concolic_test_str))
+
+    def generate_optimizations(
+        self,
+        read_writable_code: CodeStringsMarkdown,
+        read_only_context_code: str,
+        run_experiment: bool = False,  # noqa: FBT001, FBT002
+    ) -> Result[tuple[OptimizationSet, str], str]:
+        """Generate optimization candidates for the function."""
         n_candidates = N_CANDIDATES_EFFECTIVE
+
         future_optimization_candidates = self.executor.submit(
             self.aiservice_client.optimize_python_code,
             read_writable_code.markdown,
@@ -1140,11 +1188,7 @@ def generate_tests_and_optimizations(
             ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
             is_async=self.function_to_optimize.is_async,
         )
-        future_candidates_exp = None

-        future_concolic_tests = self.executor.submit(
-            generate_concolic_tests, self.test_cfg, self.args, self.function_to_optimize, self.function_to_optimize_ast
-        )
         future_references = self.executor.submit(
             get_opt_review_metrics,
             self.function_to_optimize_source_code,
@@ -1153,7 +1197,10 @@ def generate_tests_and_optimizations(
             self.project_root,
             self.test_cfg.tests_root,
         )
-        futures = [*future_tests, future_optimization_candidates, future_concolic_tests, future_references]
+
+        futures = [future_optimization_candidates, future_references]
+        future_candidates_exp = None
+
         if run_experiment:
             future_candidates_exp = self.executor.submit(
                 self.local_aiservice_client.optimize_python_code,
@@ -1166,63 +1213,20 @@
             )
             futures.append(future_candidates_exp)

-        # Wait for all futures to complete
+        # Wait for optimization futures to complete
         concurrent.futures.wait(futures)

         # Retrieve results
         candidates: list[OptimizedCandidate] = future_optimization_candidates.result()
-        logger.info(f"lsp|Generated '{len(candidates)}' candidate optimizations.")
-        console.rule()
+        logger.info(f"!lsp|Generated '{len(candidates)}' candidate optimizations.")

         if not candidates:
             return Failure(f"/!\\ NO OPTIMIZATIONS GENERATED for {self.function_to_optimize.function_name}")

         candidates_experiment = future_candidates_exp.result() if future_candidates_exp else None
-
-        # Process test generation results
-
-        tests: list[GeneratedTests] = []
-        for future in future_tests:
-            res = future.result()
-            if res:
-                (
-                    generated_test_source,
-                    instrumented_behavior_test_source,
-                    instrumented_perf_test_source,
-                    test_behavior_path,
-                    test_perf_path,
-                ) = res
-                tests.append(
-                    GeneratedTests(
-                        generated_original_test_source=generated_test_source,
-                        instrumented_behavior_test_source=instrumented_behavior_test_source,
-                        instrumented_perf_test_source=instrumented_perf_test_source,
-                        behavior_file_path=test_behavior_path,
-                        perf_file_path=test_perf_path,
-                    )
-                )
-        if not tests:
-            logger.warning(f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}")
-            return Failure(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}")
-        function_to_concolic_tests, concolic_test_str = future_concolic_tests.result()
         function_references = future_references.result()
-        count_tests = len(tests)
-        if concolic_test_str:
-            count_tests += 1

-        logger.info(f"Generated '{count_tests}' tests for '{self.function_to_optimize.function_name}'")
-        console.rule()
-        generated_tests = GeneratedTestsList(generated_tests=tests)
-        result = (
-            count_tests,
-            generated_tests,
-            function_to_concolic_tests,
-            concolic_test_str,
-            OptimizationSet(control=candidates, experiment=candidates_experiment),
-            function_references,
-        )
-        self.generate_and_instrument_tests_results = result
-        return Success(result)
+        return Success((OptimizationSet(control=candidates, experiment=candidates_experiment), function_references))

     def setup_and_establish_baseline(
         self,
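The net effect in function_optimizer.py is a fan-out/fan-in shape: two independent generation jobs submitted to one ThreadPoolExecutor, a single wait, then a fail-fast check of each result before unwrapping. A minimal runnable sketch of that shape, with Success/Failure as simplified stand-ins for the codebase's Result type (codeflash.either) and the job names invented for illustration:

import concurrent.futures
from dataclasses import dataclass

@dataclass
class Success:
    value: object

@dataclass
class Failure:
    error: str

def generate_tests_job() -> Success | Failure:
    return Success(["test_a", "test_b"])  # placeholder payload

def generate_optimizations_job() -> Success | Failure:
    return Success(["candidate_1"])  # placeholder payload

def run_pipeline() -> Success | Failure:
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        # Fan out: both jobs run concurrently on the shared pool.
        future_tests = executor.submit(generate_tests_job)
        future_opts = executor.submit(generate_optimizations_job)
        # Fan in: block until both finish, whichever order they complete in.
        concurrent.futures.wait([future_tests, future_opts])

    tests, opts = future_tests.result(), future_opts.result()
    # Fail fast if either side failed, mirroring the two is_successful checks.
    for result in (tests, opts):
        if isinstance(result, Failure):
            return result
    return Success((tests.value, opts.value))

print(run_pipeline())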
