Skip to content

Commit 65cba54

Browse files
author
Codeflash Bot
committed
Merge remote-tracking branch 'origin/main' into cf-773
2 parents e642720 + 630ca8a commit 65cba54

File tree

6 files changed

+205
-53
lines changed

6 files changed

+205
-53
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,33 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
7272
return True
7373

7474

75+
def find_insertion_index_after_imports(node: cst.Module) -> int:
    """Return the index in ``node.body`` immediately after the last leading import.

    A top-level statement counts as an import when it is either:

    * a simple statement line containing at least one ``import`` /
      ``from ... import`` statement, or
    * an ``if`` statement whose body consists solely of import statements
      (for example an ``if TYPE_CHECKING:`` guard), written either as an
      indented block or as a one-line suite (``if COND: import x``).

    Args:
        node: The parsed module whose top-level statements are scanned.

    Returns:
        The index at which new statements can be inserted so that they follow
        the imports found before the first class or function definition.
        ``0`` when no such import exists.
    """
    import_types = (cst.Import, cst.ImportFrom)

    def _all_imports(small_statements) -> bool:
        # True when every small statement in the sequence is an import.
        return all(isinstance(child, import_types) for child in small_statements)

    insert_index = 0
    for i, stmt in enumerate(node.body):
        is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
            isinstance(child, import_types) for child in stmt.body
        )

        is_conditional_import = False
        if isinstance(stmt, cst.If):
            suite = stmt.body
            if isinstance(suite, cst.SimpleStatementSuite):
                # One-line form: ``if COND: import x``. The suite holds the small
                # statements directly instead of wrapping them in statement lines.
                is_conditional_import = _all_imports(suite.body)
            else:
                is_conditional_import = all(
                    isinstance(inner, cst.SimpleStatementLine) and _all_imports(inner.body)
                    for inner in suite.body
                )

        if is_top_level_import or is_conditional_import:
            insert_index = i + 1

        # Stop scanning once we reach a class or function definition.
        # Imports are supposed to be at the top of the file, but they can technically
        # appear anywhere, even at the bottom of the file. Without this check, a stray
        # import later in the file would incorrectly shift our insertion index below
        # actual code definitions.
        if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
            break

    return insert_index
100+
101+
75102
class GlobalAssignmentTransformer(cst.CSTTransformer):
76103
"""Transforms global assignments in the original file with those from the new file."""
77104

@@ -122,32 +149,6 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
122149

123150
return updated_node
124151

125-
def _find_insertion_index(self, updated_node: cst.Module) -> int:
    """Compute the index in ``updated_node.body`` right after the leading imports.

    Returns 0 when no import is seen before the first class or function definition.
    """
    import_types = (cst.Import, cst.ImportFrom)
    position = 0

    for idx, statement in enumerate(updated_node.body):
        if isinstance(statement, cst.SimpleStatementLine):
            # Plain statement line: it counts if any of its parts is an import.
            counts_as_import = any(isinstance(part, import_types) for part in statement.body)
        elif isinstance(statement, cst.If):
            # ``if`` statement: it counts only if every line in its body is import-only.
            counts_as_import = all(
                isinstance(line, cst.SimpleStatementLine)
                and all(isinstance(part, import_types) for part in line.body)
                for line in statement.body.body
            )
        else:
            counts_as_import = False

        if counts_as_import:
            position = idx + 1

        # The first class or function definition ends the scan, so an import that
        # appears later in the file cannot move the insertion point below real
        # code definitions.
        if isinstance(statement, (cst.ClassDef, cst.FunctionDef)):
            break

    return position
150-
151152
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
152153
# Add any new assignments that weren't in the original file
153154
new_statements = list(updated_node.body)
@@ -161,7 +162,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
161162

162163
if assignments_to_append:
163164
# after last top-level imports
164-
insert_index = self._find_insertion_index(updated_node)
165+
insert_index = find_insertion_index_after_imports(updated_node)
165166

166167
assignment_lines = [
167168
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])

codeflash/code_utils/code_replacer.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33
import ast
44
from collections import defaultdict
55
from functools import lru_cache
6+
from itertools import chain
67
from typing import TYPE_CHECKING, Optional, TypeVar
78

89
import libcst as cst
910
from libcst.metadata import PositionProvider
1011

1112
from codeflash.cli_cmds.console import logger
12-
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
13+
from codeflash.code_utils.code_extractor import (
14+
add_global_assignments,
15+
add_needed_imports_from_module,
16+
find_insertion_index_after_imports,
17+
)
1318
from codeflash.code_utils.config_parser import find_conftest_files
1419
from codeflash.code_utils.formatter import sort_imports
1520
from codeflash.code_utils.line_profile_utils import ImportAdder
@@ -249,6 +254,7 @@ def __init__(
249254
] = {} # keys are (class_name, function_name)
250255
self.new_functions: list[cst.FunctionDef] = []
251256
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
257+
self.new_classes: list[cst.ClassDef] = []
252258
self.current_class = None
253259
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
254260

@@ -271,6 +277,10 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
271277
self.current_class = node.name.value
272278

273279
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
280+
281+
if (node.name.value, ()) not in self.preexisting_objects:
282+
self.new_classes.append(node)
283+
274284
for child_node in node.body.body:
275285
if (
276286
self.preexisting_objects
@@ -290,13 +300,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
290300
def __init__(
291301
self,
292302
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
303+
new_classes: Optional[list[cst.ClassDef]] = None,
293304
new_functions: Optional[list[cst.FunctionDef]] = None,
294305
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
295306
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
296307
) -> None:
297308
super().__init__()
298309
self.modified_functions = modified_functions if modified_functions is not None else {}
299310
self.new_functions = new_functions if new_functions is not None else []
311+
self.new_classes = new_classes if new_classes is not None else []
300312
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
301313
self.modified_init_functions: dict[str, cst.FunctionDef] = (
302314
modified_init_functions if modified_init_functions is not None else {}
@@ -335,19 +347,33 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
335347
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
336348
node = updated_node
337349
max_function_index = None
338-
class_index = None
350+
max_class_index = None
339351
for index, _node in enumerate(node.body):
340352
if isinstance(_node, cst.FunctionDef):
341353
max_function_index = index
342354
if isinstance(_node, cst.ClassDef):
343-
class_index = index
355+
max_class_index = index
356+
357+
if self.new_classes:
358+
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}
359+
360+
unique_classes = [
361+
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
362+
]
363+
if unique_classes:
364+
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
365+
new_body = list(
366+
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
367+
)
368+
node = node.with_changes(body=new_body)
369+
344370
if max_function_index is not None:
345371
node = node.with_changes(
346372
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
347373
)
348-
elif class_index is not None:
374+
elif max_class_index is not None:
349375
node = node.with_changes(
350-
body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :])
376+
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
351377
)
352378
else:
353379
node = node.with_changes(body=(*self.new_functions, *node.body))
@@ -373,18 +399,20 @@ def replace_functions_in_file(
373399
parsed_function_names.append((class_name, function_name))
374400

375401
# Collect functions we want to modify from the optimized code
376-
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
402+
optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
403+
original_module = cst.parse_module(source_code)
404+
377405
visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
378-
module.visit(visitor)
406+
optimized_module.visit(visitor)
379407

380408
# Replace these functions in the original code
381409
transformer = OptimFunctionReplacer(
382410
modified_functions=visitor.modified_functions,
411+
new_classes=visitor.new_classes,
383412
new_functions=visitor.new_functions,
384413
new_class_functions=visitor.new_class_functions,
385414
modified_init_functions=visitor.modified_init_functions,
386415
)
387-
original_module = cst.parse_module(source_code)
388416
modified_tree = original_module.visit(transformer)
389417
return modified_tree.code
390418

codeflash/code_utils/formatter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import subprocess
99
import tempfile
1010
from pathlib import Path
11-
from typing import Optional, Union
11+
from typing import Any, Optional, Union
1212

1313
import isort
1414

@@ -178,10 +178,10 @@ def format_code(
178178
return formatted_code
179179

180180

181-
def sort_imports(code: str, *, float_to_top: bool = False) -> str:
181+
def sort_imports(code: str, **kwargs: Any) -> str: # noqa : ANN401
182182
try:
183183
# Deduplicate and sort imports, modify the code in memory, not on disk
184-
sorted_code = isort.code(code=code, float_to_top=float_to_top)
184+
sorted_code = isort.code(code, **kwargs)
185185
except Exception: # this will also catch the FileSkipComment exception, use this fn everywhere
186186
logger.exception("Failed to sort imports with isort.")
187187
return code # Fall back to original code if isort fails

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pathlib import Path
77
from typing import TYPE_CHECKING
88

9-
import isort
109
import libcst as cst
1110

1211
from codeflash.cli_cmds.console import logger
@@ -741,7 +740,7 @@ def inject_async_profiling_into_existing_test(
741740
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
742741

743742
tree.body = [*new_imports, *tree.body]
744-
return True, isort.code(ast.unparse(tree), float_to_top=True)
743+
return True, sort_imports(ast.unparse(tree), float_to_top=True)
745744

746745

747746
def inject_profiling_into_existing_test(
@@ -789,7 +788,7 @@ def inject_profiling_into_existing_test(
789788
additional_functions = [create_wrapper_function(mode)]
790789

791790
tree.body = [*new_imports, *additional_functions, *tree.body]
792-
return True, isort.code(ast.unparse(tree), float_to_top=True)
791+
return True, sort_imports(ast.unparse(tree), float_to_top=True)
793792

794793

795794
def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef:

0 commit comments

Comments
 (0)