Skip to content

Commit aef82da

Browse files
committed
fixes
1 parent 8c8d598 commit aef82da

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
211211
key: self.function_to_tests.get(key, []) + function_to_concolic_tests.get(key, [])
212212
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
213213
}
214-
instrumented_unittests_created_for_function = self.instrument_existing_tests()
214+
instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests)
215215

216216
# Get a dict of file_path_to_classes of fto and helpers_of_fto
217217
file_path_to_helper_classes = defaultdict(set)
@@ -623,19 +623,19 @@ def cleanup_leftover_test_return_values() -> None:
623623
get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True)
624624
get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True)
625625

626-
def instrument_existing_tests(self) -> set[Path]:
626+
def instrument_existing_tests(self, function_to_all_tests: dict[str, list[FunctionCalledInTest]]) -> set[Path]:
627627
existing_test_files_count = 0
628628
replay_test_files_count = 0
629629
concolic_coverage_test_files_count = 0
630630
unique_instrumented_test_files = set()
631631

632632
func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
633-
if func_qualname not in self.function_to_tests:
633+
if func_qualname not in function_to_all_tests:
634634
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
635635
console.rule()
636636
else:
637637
test_file_invocation_positions = defaultdict(list[FunctionCalledInTest])
638-
for tests_in_file in self.function_to_tests.get(func_qualname):
638+
for tests_in_file in function_to_all_tests.get(func_qualname):
639639
test_file_invocation_positions[
640640
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
641641
].append(tests_in_file)

codeflash/optimization/optimizer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,17 @@ def create_function_optimizer(
5252
function_to_optimize: FunctionToOptimize,
5353
function_to_optimize_ast: ast.FunctionDef | None = None,
5454
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
55+
function_to_optimize_source_code: str | None = None,
5556
) -> FunctionOptimizer:
5657
return FunctionOptimizer(
57-
test_cfg=self.test_cfg,
58-
aiservice_client=self.aiservice_client,
5958
function_to_optimize=function_to_optimize,
60-
function_to_optimize_ast=function_to_optimize_ast,
59+
test_cfg=self.test_cfg,
60+
function_to_optimize_source_code=function_to_optimize_source_code,
6161
function_to_tests=function_to_tests,
62+
function_to_optimize_ast=function_to_optimize_ast,
63+
aiservice_client=self.aiservice_client,
6264
args=self.args,
65+
6366
)
6467

6568
def run(self) -> None:
@@ -161,7 +164,7 @@ def run(self) -> None:
161164
continue
162165

163166
function_optimizer = self.create_function_optimizer(
164-
function_to_optimize, function_to_optimize_ast, function_to_tests
167+
function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code
165168
)
166169
best_optimization = function_optimizer.optimize_function()
167170
self.test_files = TestFiles(test_files=[])

0 commit comments

Comments
 (0)