33import ast
44from collections import defaultdict
55from functools import lru_cache
6+ from itertools import chain
67from typing import TYPE_CHECKING , Optional , TypeVar
78
89import libcst as cst
910from libcst .metadata import PositionProvider
1011
1112from codeflash .cli_cmds .console import logger
12- from codeflash .code_utils .code_extractor import add_global_assignments , add_needed_imports_from_module
13+ from codeflash .code_utils .code_extractor import (
14+ add_global_assignments ,
15+ add_needed_imports_from_module ,
16+ find_insertion_index_after_imports ,
17+ )
1318from codeflash .code_utils .config_parser import find_conftest_files
1419from codeflash .code_utils .formatter import sort_imports
1520from codeflash .code_utils .line_profile_utils import ImportAdder
@@ -249,6 +254,7 @@ def __init__(
249254 ] = {} # keys are (class_name, function_name)
250255 self .new_functions : list [cst .FunctionDef ] = []
251256 self .new_class_functions : dict [str , list [cst .FunctionDef ]] = defaultdict (list )
257+ self .new_classes : list [cst .ClassDef ] = []
252258 self .current_class = None
253259 self .modified_init_functions : dict [str , cst .FunctionDef ] = {}
254260
@@ -271,6 +277,10 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
271277 self .current_class = node .name .value
272278
273279 parents = (FunctionParent (name = node .name .value , type = "ClassDef" ),)
280+
281+ if (node .name .value , ()) not in self .preexisting_objects :
282+ self .new_classes .append (node )
283+
274284 for child_node in node .body .body :
275285 if (
276286 self .preexisting_objects
@@ -290,13 +300,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
290300 def __init__ (
291301 self ,
292302 modified_functions : Optional [dict [tuple [str | None , str ], cst .FunctionDef ]] = None ,
303+ new_classes : Optional [list [cst .ClassDef ]] = None ,
293304 new_functions : Optional [list [cst .FunctionDef ]] = None ,
294305 new_class_functions : Optional [dict [str , list [cst .FunctionDef ]]] = None ,
295306 modified_init_functions : Optional [dict [str , cst .FunctionDef ]] = None ,
296307 ) -> None :
297308 super ().__init__ ()
298309 self .modified_functions = modified_functions if modified_functions is not None else {}
299310 self .new_functions = new_functions if new_functions is not None else []
311+ self .new_classes = new_classes if new_classes is not None else []
300312 self .new_class_functions = new_class_functions if new_class_functions is not None else defaultdict (list )
301313 self .modified_init_functions : dict [str , cst .FunctionDef ] = (
302314 modified_init_functions if modified_init_functions is not None else {}
@@ -335,19 +347,33 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
335347 def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module : # noqa: ARG002
336348 node = updated_node
337349 max_function_index = None
338- class_index = None
350+ max_class_index = None
339351 for index , _node in enumerate (node .body ):
340352 if isinstance (_node , cst .FunctionDef ):
341353 max_function_index = index
342354 if isinstance (_node , cst .ClassDef ):
343- class_index = index
355+ max_class_index = index
356+
357+ if self .new_classes :
358+ existing_class_names = {_node .name .value for _node in node .body if isinstance (_node , cst .ClassDef )}
359+
360+ unique_classes = [
361+ new_class for new_class in self .new_classes if new_class .name .value not in existing_class_names
362+ ]
363+ if unique_classes :
364+ new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports (node )
365+ new_body = list (
366+ chain (node .body [:new_classes_insertion_idx ], unique_classes , node .body [new_classes_insertion_idx :])
367+ )
368+ node = node .with_changes (body = new_body )
369+
344370 if max_function_index is not None :
345371 node = node .with_changes (
346372 body = (* node .body [: max_function_index + 1 ], * self .new_functions , * node .body [max_function_index + 1 :])
347373 )
348- elif class_index is not None :
374+ elif max_class_index is not None :
349375 node = node .with_changes (
350- body = (* node .body [: class_index + 1 ], * self .new_functions , * node .body [class_index + 1 :])
376+ body = (* node .body [: max_class_index + 1 ], * self .new_functions , * node .body [max_class_index + 1 :])
351377 )
352378 else :
353379 node = node .with_changes (body = (* self .new_functions , * node .body ))
@@ -373,18 +399,20 @@ def replace_functions_in_file(
373399 parsed_function_names .append ((class_name , function_name ))
374400
375401 # Collect functions we want to modify from the optimized code
376- module = cst .metadata .MetadataWrapper (cst .parse_module (optimized_code ))
402+ optimized_module = cst .metadata .MetadataWrapper (cst .parse_module (optimized_code ))
403+ original_module = cst .parse_module (source_code )
404+
377405 visitor = OptimFunctionCollector (preexisting_objects , set (parsed_function_names ))
378- module .visit (visitor )
406+ optimized_module .visit (visitor )
379407
380408 # Replace these functions in the original code
381409 transformer = OptimFunctionReplacer (
382410 modified_functions = visitor .modified_functions ,
411+ new_classes = visitor .new_classes ,
383412 new_functions = visitor .new_functions ,
384413 new_class_functions = visitor .new_class_functions ,
385414 modified_init_functions = visitor .modified_init_functions ,
386415 )
387- original_module = cst .parse_module (source_code )
388416 modified_tree = original_module .visit (transformer )
389417 return modified_tree .code
390418
0 commit comments