@@ -51,10 +51,13 @@ def __init__(
5151 self .new_functions : list [cst .FunctionDef ] = []
5252 self .new_class_functions : dict [str , list [cst .FunctionDef ]] = defaultdict (list )
5353 self .current_class = None
54+ self .modified_init_functions : dict [str , cst .FunctionDef ] = {}
5455
5556 def visit_FunctionDef (self , node : cst .FunctionDef ) -> bool :
5657 if (self .current_class , node .name .value ) in self .function_names :
5758 self .modified_functions [(self .current_class , node .name .value )] = node
59+ elif self .current_class and node .name .value == "__init__" :
60+ self .modified_init_functions [self .current_class ] = node
5861 elif (
5962 self .preexisting_objects
6063 and (node .name .value , []) not in self .preexisting_objects
@@ -76,6 +79,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
7679 and (child_node .name .value , parents ) not in self .preexisting_objects
7780 ):
7881 self .new_class_functions [node .name .value ].append (child_node )
82+
7983 return True
8084
8185 def leave_ClassDef (self , node : cst .ClassDef ) -> None :
@@ -89,11 +93,15 @@ def __init__(
8993 modified_functions : dict [tuple [str | None , str ], cst .FunctionDef ] = None ,
9094 new_functions : list [cst .FunctionDef ] = None ,
9195 new_class_functions : dict [str , list [cst .FunctionDef ]] = None ,
96+ modified_init_functions : dict [str , cst .FunctionDef ] = None ,
9297 ) -> None :
9398 super ().__init__ ()
9499 self .modified_functions = modified_functions if modified_functions is not None else {}
95100 self .new_functions = new_functions if new_functions is not None else []
96101 self .new_class_functions = new_class_functions if new_class_functions is not None else defaultdict (list )
102+ self .modified_init_functions : dict [str , cst .FunctionDef ] = (
103+ modified_init_functions if modified_init_functions is not None else {}
104+ )
97105 self .current_class = None
98106
99107 def visit_FunctionDef (self , node : cst .FunctionDef ) -> bool :
@@ -103,6 +111,8 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
103111 if (self .current_class , original_node .name .value ) in self .modified_functions :
104112 node = self .modified_functions [(self .current_class , original_node .name .value )]
105113 return updated_node .with_changes (body = node .body , decorators = node .decorators )
114+ if original_node .name .value == "__init__" and self .current_class in self .modified_init_functions :
115+ return self .modified_init_functions [self .current_class ]
106116
107117 return updated_node
108118
@@ -173,6 +183,7 @@ def replace_functions_in_file(
173183 modified_functions = visitor .modified_functions ,
174184 new_functions = visitor .new_functions ,
175185 new_class_functions = visitor .new_class_functions ,
186+ modified_init_functions = visitor .modified_init_functions ,
176187 )
177188 original_module = cst .parse_module (source_code )
178189 modified_tree = original_module .visit (transformer )
@@ -183,15 +194,14 @@ def replace_functions_and_add_imports(
183194 source_code : str ,
184195 function_names : list [str ],
185196 optimized_code : str ,
186- file_path_of_module_with_function_to_optimize : Path ,
187197 module_abspath : Path ,
188198 preexisting_objects : list [tuple [str , list [FunctionParent ]]],
189199 project_root_path : Path ,
190200) -> str :
191201 return add_needed_imports_from_module (
192202 optimized_code ,
193203 replace_functions_in_file (source_code , function_names , optimized_code , preexisting_objects ),
194- file_path_of_module_with_function_to_optimize ,
204+ module_abspath ,
195205 module_abspath ,
196206 project_root_path ,
197207 )
@@ -200,20 +210,13 @@ def replace_functions_and_add_imports(
200210def replace_function_definitions_in_module (
201211 function_names : list [str ],
202212 optimized_code : str ,
203- file_path_of_module_with_function_to_optimize : Path ,
204213 module_abspath : Path ,
205214 preexisting_objects : list [tuple [str , list [FunctionParent ]]],
206215 project_root_path : Path ,
207216) -> bool :
208217 source_code : str = module_abspath .read_text (encoding = "utf8" )
209218 new_code : str = replace_functions_and_add_imports (
210- source_code ,
211- function_names ,
212- optimized_code ,
213- file_path_of_module_with_function_to_optimize ,
214- module_abspath ,
215- preexisting_objects ,
216- project_root_path ,
219+ source_code , function_names , optimized_code , module_abspath , preexisting_objects , project_root_path
217220 )
218221 if is_zero_diff (source_code , new_code ):
219222 return False
0 commit comments