Skip to content

Commit b866d33

Browse files
authored
Fix: track shadowed jinja variable assignments correctly (#5503)
1 parent e1510ce commit b866d33

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

sqlmesh/utils/jinja.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[C
133133
vars_in_scope = vars_in_scope.copy()
134134
for child_node in node.iter_child_nodes():
135135
if "target" in child_node.fields:
136+
# For nodes with assignment targets (Assign, AssignBlock, For, Import),
137+
# the target name could shadow a reference in the right hand side.
138+
# So we need to process the RHS before adding the target to scope.
139+
# For example: {% set model = model.path %} should track model.path.
140+
yield from find_call_names(child_node, vars_in_scope)
141+
136142
target = getattr(child_node, "target")
137143
if isinstance(target, nodes.Name):
138144
vars_in_scope.add(target.name)
@@ -149,7 +155,9 @@ def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[C
149155
name = call_name(child_node)
150156
if name[0][0] != "'" and name[0] not in vars_in_scope:
151157
yield (name, child_node)
152-
yield from find_call_names(child_node, vars_in_scope)
158+
159+
if "target" not in child_node.fields:
160+
yield from find_call_names(child_node, vars_in_scope)
153161

154162

155163
def extract_call_names(

tests/dbt/test_manifest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,46 @@ def test_macro_depenency_none_str():
324324

325325
# "None" macro shouldn't raise a KeyError
326326
_macro_references(helper._manifest, node)
327+
328+
329+
@pytest.mark.xdist_group("dbt_manifest")
330+
def test_macro_assignment_shadowing(create_empty_project):
331+
project_name = "local"
332+
project_path, models_path = create_empty_project(project_name=project_name)
333+
334+
macros_path = project_path / "macros"
335+
macros_path.mkdir()
336+
337+
(macros_path / "model_path_macro.sql").write_text("""
338+
{% macro model_path_macro() %}
339+
{% if execute %}
340+
{% set model = model.path.split('/')[-1].replace('.sql', '') %}
341+
SELECT '{{ model }}' as model_name
342+
{% else %}
343+
SELECT 'placeholder' as placeholder
344+
{% endif %}
345+
{% endmacro %}
346+
""")
347+
348+
(models_path / "model_using_path_macro.sql").write_text("""
349+
{{ model_path_macro() }}
350+
""")
351+
352+
context = DbtContext(project_path)
353+
profile = Profile.load(context)
354+
355+
helper = ManifestHelper(
356+
project_path,
357+
project_path,
358+
project_name,
359+
profile.target,
360+
model_defaults=ModelDefaultsConfig(start="2020-01-01"),
361+
)
362+
363+
macros = helper.macros(project_name)
364+
assert "model_path_macro" in macros
365+
assert "path" in macros["model_path_macro"].dependencies.model_attrs.attrs
366+
367+
models = helper.models()
368+
assert "model_using_path_macro" in models
369+
assert "path" in models["model_using_path_macro"].dependencies.model_attrs.attrs

0 commit comments

Comments
 (0)