Skip to content

Commit d77458c

Browse files
Generalize AST import detection through shared symbol helper
Co-authored-by: Shri Sukhani <shrisukhani@users.noreply.github.com>
1 parent f5ab921 commit d77458c

File tree

2 files changed

+49
-12
lines changed

2 files changed

+49
-12
lines changed

tests/ast_import_utils.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
import ast
22

33

4-
def imports_collect_function_sources(module_text: str) -> bool:
4+
def imports_symbol_from_module(module_text: str, module: str, symbol: str) -> bool:
55
module_ast = ast.parse(module_text)
66
for node in module_ast.body:
77
if not isinstance(node, ast.ImportFrom):
88
continue
9-
if node.module != "tests.ast_function_source_utils":
9+
if node.module != module:
1010
continue
11-
if any(alias.name == "collect_function_sources" for alias in node.names):
11+
if any(alias.name == symbol for alias in node.names):
1212
return True
1313
return False
1414

1515

16+
def imports_collect_function_sources(module_text: str) -> bool:
17+
return imports_symbol_from_module(
18+
module_text,
19+
module="tests.ast_function_source_utils",
20+
symbol="collect_function_sources",
21+
)
22+
23+
1624
def imports_imports_collect_function_sources(module_text: str) -> bool:
17-
module_ast = ast.parse(module_text)
18-
for node in module_ast.body:
19-
if not isinstance(node, ast.ImportFrom):
20-
continue
21-
if node.module != "tests.ast_import_utils":
22-
continue
23-
if any(alias.name == "imports_collect_function_sources" for alias in node.names):
24-
return True
25-
return False
25+
return imports_symbol_from_module(
26+
module_text,
27+
module="tests.ast_import_utils",
28+
symbol="imports_collect_function_sources",
29+
)

tests/test_ast_import_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
from tests.ast_import_utils import (
4+
imports_symbol_from_module,
45
imports_collect_function_sources,
56
imports_imports_collect_function_sources,
67
)
@@ -26,6 +27,38 @@ def test_imports_collect_function_sources_ignores_non_matching_imports():
2627
assert imports_collect_function_sources(module_text) is False
2728

2829

30+
def test_imports_symbol_from_module_detects_expected_symbol():
31+
module_text = (
32+
"from tests.ast_import_utils import imports_collect_function_sources\n"
33+
"imports_collect_function_sources('dummy')\n"
34+
)
35+
36+
assert (
37+
imports_symbol_from_module(
38+
module_text,
39+
module="tests.ast_import_utils",
40+
symbol="imports_collect_function_sources",
41+
)
42+
is True
43+
)
44+
45+
46+
def test_imports_symbol_from_module_ignores_unrelated_symbols():
47+
module_text = (
48+
"from tests.ast_import_utils import imports_collect_function_sources\n"
49+
"from tests.ast_import_utils import imports_imports_collect_function_sources\n"
50+
)
51+
52+
assert (
53+
imports_symbol_from_module(
54+
module_text,
55+
module="tests.ast_import_utils",
56+
symbol="missing_symbol",
57+
)
58+
is False
59+
)
60+
61+
2962
def test_imports_collect_function_sources_supports_aliased_import():
3063
module_text = (
3164
"from tests.ast_function_source_utils import collect_function_sources as cfs\n"

0 commit comments

Comments
 (0)