55from dataclasses import dataclass , field
66from itertools import chain
77from pathlib import Path
8- from typing import TYPE_CHECKING , Optional
8+ from typing import TYPE_CHECKING , Optional , Union
99
1010import libcst as cst
1111
@@ -122,6 +122,8 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
122122class DependencyCollector (cst .CSTVisitor ):
123123 """Collects dependencies between definitions using the visitor pattern with depth tracking."""
124124
125+ METADATA_DEPENDENCIES = (cst .metadata .ParentNodeProvider ,)
126+
125127 def __init__ (self , definitions : dict [str , UsageInfo ]) -> None :
126128 super ().__init__ ()
127129 self .definitions = definitions
@@ -259,8 +261,12 @@ def visit_Name(self, node: cst.Name) -> None:
259261 if self .processing_variable and name in self .current_variable_names :
260262 return
261263
262- # Check if name is a top-level definition we're tracking
263264 if name in self .definitions and name != self .current_top_level_name :
265+ # skip if we are refrencing a class attribute and not a top-level definition
266+ if self .class_depth > 0 :
267+ parent = self .get_metadata (cst .metadata .ParentNodeProvider , node )
268+ if parent is not None and isinstance (parent , cst .Attribute ):
269+ return
264270 self .definitions [self .current_top_level_name ].dependencies .add (name )
265271
266272
@@ -293,13 +299,19 @@ def _expand_qualified_functions(self) -> set[str]:
293299
294300 def mark_used_definitions (self ) -> None :
295301 """Find all qualified functions and mark them and their dependencies as used."""
296- # First identify all specified functions (including expanded ones)
297- functions_to_mark = [name for name in self .expanded_qualified_functions if name in self .definitions ]
302+ # Avoid list comprehension for set intersection
303+ expanded_names = self .expanded_qualified_functions
304+ defs = self .definitions
305+ functions_to_mark = (
306+ expanded_names & defs .keys ()
307+ if isinstance (expanded_names , set )
308+ else [name for name in expanded_names if name in defs ]
309+ )
298310
299311 # For each specified function, mark it and all its dependencies as used
300312 for func_name in functions_to_mark :
301- self . definitions [func_name ].used_by_qualified_function = True
302- for dep in self . definitions [func_name ].dependencies :
313+ defs [func_name ].used_by_qualified_function = True
314+ for dep in defs [func_name ].dependencies :
303315 self .mark_as_used_recursively (dep )
304316
305317 def mark_as_used_recursively (self , name : str ) -> None :
@@ -457,7 +469,28 @@ def remove_unused_definitions_recursively( # noqa: PLR0911
457469 return node , False
458470
459471
460- def remove_unused_definitions_by_function_names (code : str , qualified_function_names : set [str ]) -> str :
472+ def collect_top_level_defs_with_usages (
473+ code : Union [str , cst .Module ], qualified_function_names : set [str ]
474+ ) -> dict [str , UsageInfo ]:
475+ """Collect all top level definitions (classes, variables or functions) and their usages."""
476+ module = code if isinstance (code , cst .Module ) else cst .parse_module (code )
477+ # Collect all definitions (top level classes, variables or function)
478+ definitions = collect_top_level_definitions (module )
479+
480+ # Collect dependencies between definitions using the visitor pattern
481+ wrapper = cst .MetadataWrapper (module )
482+ dependency_collector = DependencyCollector (definitions )
483+ wrapper .visit (dependency_collector )
484+
485+ # Mark definitions used by specified functions, and their dependencies recursively
486+ usage_marker = QualifiedFunctionUsageMarker (definitions , qualified_function_names )
487+ usage_marker .mark_used_definitions ()
488+ return definitions
489+
490+
491+ def remove_unused_definitions_by_function_names (
492+ code : str , qualified_function_names : set [str ]
493+ ) -> tuple [str , dict [str , UsageInfo ]]:
461494 """Analyze a file and remove top level definitions not used by specified functions.
462495
463496 Top level definitions, in this context, are only classes, variables or functions.
@@ -476,19 +509,10 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
476509 return code
477510
478511 try :
479- # Collect all definitions (top level classes, variables or function)
480- definitions = collect_top_level_definitions (module )
481-
482- # Collect dependencies between definitions using the visitor pattern
483- dependency_collector = DependencyCollector (definitions )
484- module .visit (dependency_collector )
485-
486- # Mark definitions used by specified functions, and their dependencies recursively
487- usage_marker = QualifiedFunctionUsageMarker (definitions , qualified_function_names )
488- usage_marker .mark_used_definitions ()
512+ defs_with_usages = collect_top_level_defs_with_usages (module , qualified_function_names )
489513
490514 # Apply the recursive removal transformation
491- modified_module , _ = remove_unused_definitions_recursively (module , definitions )
515+ modified_module , _ = remove_unused_definitions_recursively (module , defs_with_usages )
492516
493517 return modified_module .code if modified_module else "" # noqa: TRY300
494518 except Exception as e :
0 commit comments