Skip to content

Commit 15d2027

Browse files
keep the referenced global definitions
1 parent ad09525 commit 15d2027

File tree

3 files changed

+218
-54
lines changed

3 files changed

+218
-54
lines changed

codeflash/context/code_context_extractor.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from codeflash.cli_cmds.console import logger
1313
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
1414
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
15-
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
15+
from codeflash.context.unused_definition_remover import (
16+
collect_top_level_defs_with_usages,
17+
extract_names_from_targets,
18+
remove_unused_definitions_by_function_names,
19+
)
1620
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
1721
from codeflash.models.models import (
1822
CodeContextType,
@@ -29,6 +33,8 @@
2933
from jedi.api.classes import Name
3034
from libcst import CSTNode
3135

36+
from codeflash.context.unused_definition_remover import UsageInfo
37+
3238

3339
def get_code_optimization_context(
3440
function_to_optimize: FunctionToOptimize,
@@ -498,8 +504,10 @@ def parse_code_and_prune_cst(
498504
) -> str:
499505
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
500506
module = cst.parse_module(code)
507+
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
508+
501509
if code_context_type == CodeContextType.READ_WRITABLE:
502-
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
510+
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages)
503511
elif code_context_type == CodeContextType.READ_ONLY:
504512
filtered_node, found_target = prune_cst_for_read_only_code(
505513
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
@@ -524,7 +532,7 @@ def parse_code_and_prune_cst(
524532

525533

526534
def prune_cst_for_read_writable_code( # noqa: PLR0911
527-
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
535+
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
528536
) -> tuple[cst.CSTNode | None, bool]:
529537
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
530538
@@ -569,6 +577,21 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
569577

570578
return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target
571579

580+
if isinstance(node, cst.Assign):
581+
for target in node.targets:
582+
names = extract_names_from_targets(target.target)
583+
for name in names:
584+
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
585+
return node, True
586+
return None, False
587+
588+
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
589+
names = extract_names_from_targets(node.target)
590+
for name in names:
591+
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
592+
return node, True
593+
return None, False
594+
572595
# For other nodes, we preserve them only if they contain target functions in their children.
573596
section_names = get_section_names(node)
574597
if not section_names:
@@ -583,7 +606,9 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583606
new_children = []
584607
section_found_target = False
585608
for child in original_content:
586-
filtered, found_target = prune_cst_for_read_writable_code(child, target_functions, prefix)
609+
filtered, found_target = prune_cst_for_read_writable_code(
610+
child, target_functions, defs_with_usages, prefix
611+
)
587612
if filtered:
588613
new_children.append(filtered)
589614
section_found_target |= found_target
@@ -592,15 +617,16 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
592617
found_any_target = True
593618
updates[section] = new_children
594619
elif original_content is not None:
595-
filtered, found_target = prune_cst_for_read_writable_code(original_content, target_functions, prefix)
620+
filtered, found_target = prune_cst_for_read_writable_code(
621+
original_content, target_functions, defs_with_usages, prefix
622+
)
596623
if found_target:
597624
found_any_target = True
598625
if filtered:
599626
updates[section] = filtered
600627

601628
if not found_any_target:
602629
return None, False
603-
604630
return (node.with_changes(**updates) if updates else node), True
605631

606632

codeflash/context/unused_definition_remover.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66
from itertools import chain
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Optional
8+
from typing import TYPE_CHECKING, Optional, Union
99

1010
import libcst as cst
1111

@@ -122,6 +122,8 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
122122
class DependencyCollector(cst.CSTVisitor):
123123
"""Collects dependencies between definitions using the visitor pattern with depth tracking."""
124124

125+
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
126+
125127
def __init__(self, definitions: dict[str, UsageInfo]) -> None:
126128
super().__init__()
127129
self.definitions = definitions
@@ -259,8 +261,12 @@ def visit_Name(self, node: cst.Name) -> None:
259261
if self.processing_variable and name in self.current_variable_names:
260262
return
261263

262-
# Check if name is a top-level definition we're tracking
263264
if name in self.definitions and name != self.current_top_level_name:
265+
# skip if we are referencing a class attribute and not a top-level definition
266+
if self.class_depth > 0:
267+
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
268+
if parent is not None and isinstance(parent, cst.Attribute):
269+
return
264270
self.definitions[self.current_top_level_name].dependencies.add(name)
265271

266272

@@ -293,13 +299,19 @@ def _expand_qualified_functions(self) -> set[str]:
293299

294300
def mark_used_definitions(self) -> None:
295301
"""Find all qualified functions and mark them and their dependencies as used."""
296-
# First identify all specified functions (including expanded ones)
297-
functions_to_mark = [name for name in self.expanded_qualified_functions if name in self.definitions]
302+
# Avoid list comprehension for set intersection
303+
expanded_names = self.expanded_qualified_functions
304+
defs = self.definitions
305+
functions_to_mark = (
306+
expanded_names & defs.keys()
307+
if isinstance(expanded_names, set)
308+
else [name for name in expanded_names if name in defs]
309+
)
298310

299311
# For each specified function, mark it and all its dependencies as used
300312
for func_name in functions_to_mark:
301-
self.definitions[func_name].used_by_qualified_function = True
302-
for dep in self.definitions[func_name].dependencies:
313+
defs[func_name].used_by_qualified_function = True
314+
for dep in defs[func_name].dependencies:
303315
self.mark_as_used_recursively(dep)
304316

305317
def mark_as_used_recursively(self, name: str) -> None:
@@ -457,7 +469,28 @@ def remove_unused_definitions_recursively( # noqa: PLR0911
457469
return node, False
458470

459471

460-
def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
472+
def collect_top_level_defs_with_usages(
473+
code: Union[str, cst.Module], qualified_function_names: set[str]
474+
) -> dict[str, UsageInfo]:
475+
"""Collect all top level definitions (classes, variables or functions) and their usages."""
476+
module = code if isinstance(code, cst.Module) else cst.parse_module(code)
477+
# Collect all definitions (top level classes, variables or functions)
478+
definitions = collect_top_level_definitions(module)
479+
480+
# Collect dependencies between definitions using the visitor pattern
481+
wrapper = cst.MetadataWrapper(module)
482+
dependency_collector = DependencyCollector(definitions)
483+
wrapper.visit(dependency_collector)
484+
485+
# Mark definitions used by specified functions, and their dependencies recursively
486+
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
487+
usage_marker.mark_used_definitions()
488+
return definitions
489+
490+
491+
def remove_unused_definitions_by_function_names(
492+
code: str, qualified_function_names: set[str]
493+
) -> tuple[str, dict[str, UsageInfo]]:
461494
"""Analyze a file and remove top level definitions not used by specified functions.
462495
463496
Top level definitions, in this context, are only classes, variables or functions.
@@ -476,19 +509,10 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
476509
return code
477510

478511
try:
479-
# Collect all definitions (top level classes, variables or function)
480-
definitions = collect_top_level_definitions(module)
481-
482-
# Collect dependencies between definitions using the visitor pattern
483-
dependency_collector = DependencyCollector(definitions)
484-
module.visit(dependency_collector)
485-
486-
# Mark definitions used by specified functions, and their dependencies recursively
487-
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
488-
usage_marker.mark_used_definitions()
512+
defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names)
489513

490514
# Apply the recursive removal transformation
491-
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
515+
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
492516

493517
return modified_module.code if modified_module else "" # noqa: TRY300
494518
except Exception as e:

0 commit comments

Comments
 (0)