Commit a6945cb

Fix: Preserve the DAG evaluation order even when a transitive dependency is not included (#5335)

1 parent 0583f96 commit a6945cb

3 files changed: +226 −15 lines changed

sqlmesh/core/scheduler.py

Lines changed: 42 additions & 15 deletions
@@ -446,7 +446,13 @@ def run_merged_intervals(
         if not selected_snapshots:
             selected_snapshots = list(merged_intervals)
 
-        snapshot_dag = snapshots_to_dag(selected_snapshots)
+        # Build the full DAG from all snapshots to preserve transitive dependencies
+        full_dag = snapshots_to_dag(self.snapshots.values())
+
+        # Create a subdag that includes the selected snapshots and all their upstream dependencies
+        # This ensures that transitive dependencies are preserved even when intermediate nodes are not selected
+        selected_snapshot_ids_set = {s.snapshot_id for s in selected_snapshots}
+        snapshot_dag = full_dag.subdag(*selected_snapshot_ids_set)
 
         batched_intervals = self.batch_intervals(
             merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag
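The comments in this hunk capture the whole fix: a DAG built from only the selected snapshots drops any edge whose path runs through an unselected snapshot, while a subdag of the full DAG keeps the upstream closure. A minimal sketch of the difference, where ToyDAG and its subdag are hypothetical stand-ins (not sqlmesh's DAG class) that assume only what the comment above states, that the subdag contains the requested nodes plus all their upstream dependencies:

# Toy illustration: ToyDAG is a hypothetical stand-in, not sqlmesh's DAG.
import typing as t


class ToyDAG:
    def __init__(self) -> None:
        self.graph: t.Dict[str, t.Set[str]] = {}

    def add(self, node: str, parents: t.Optional[t.Set[str]] = None) -> None:
        self.graph[node] = set(parents or ())

    def upstream(self, node: str) -> t.Set[str]:
        # All transitive ancestors of `node`.
        seen: t.Set[str] = set()
        stack = list(self.graph.get(node, ()))
        while stack:
            parent = stack.pop()
            if parent not in seen:
                seen.add(parent)
                stack.extend(self.graph.get(parent, ()))
        return seen

    def subdag(self, *nodes: str) -> "ToyDAG":
        # The requested nodes plus their full upstream closure.
        keep = set(nodes)
        for node in nodes:
            keep |= self.upstream(node)
        sub = ToyDAG()
        for node in keep:
            sub.add(node, self.graph[node] & keep)
        return sub


full = ToyDAG()
full.add("a")
full.add("b", {"a"})
full.add("c", {"b"})

# Subdag of the full graph for {a, c}: b is pulled in, so the a -> b -> c path survives.
assert full.subdag("a", "c").graph == {"a": set(), "b": {"a"}, "c": {"b"}}

# A DAG built from the selected snapshots alone (the old behavior) has no
# edge between a and c, so nothing forces c to be evaluated after a.
selected_only = ToyDAG()
selected_only.add("a")
selected_only.add("c")  # c's parent b is not among the selected nodes
assert selected_only.graph == {"a": set(), "c": set()}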
@@ -646,20 +652,11 @@ def _dag(
         upstream_dependencies: t.List[SchedulingUnit] = []
 
         for p_sid in snapshot.parents:
-            if p_sid in self.snapshots:
-                p_intervals = intervals_per_snapshot.get(p_sid.name, [])
-
-                if not p_intervals and p_sid in original_snapshots_to_create:
-                    upstream_dependencies.append(CreateNode(snapshot_name=p_sid.name))
-                elif len(p_intervals) > 1:
-                    upstream_dependencies.append(DummyNode(snapshot_name=p_sid.name))
-                else:
-                    for i, interval in enumerate(p_intervals):
-                        upstream_dependencies.append(
-                            EvaluateNode(
-                                snapshot_name=p_sid.name, interval=interval, batch_index=i
-                            )
-                        )
+            upstream_dependencies.extend(
+                self._find_upstream_dependencies(
+                    p_sid, intervals_per_snapshot, original_snapshots_to_create
+                )
+            )
 
         batch_concurrency = snapshot.node.batch_concurrency
         batch_size = snapshot.node.batch_size
@@ -703,6 +700,36 @@ def _dag(
             )
         return dag
 
+    def _find_upstream_dependencies(
+        self,
+        parent_sid: SnapshotId,
+        intervals_per_snapshot: t.Dict[str, Intervals],
+        snapshots_to_create: t.Set[SnapshotId],
+    ) -> t.List[SchedulingUnit]:
+        if parent_sid not in self.snapshots:
+            return []
+
+        p_intervals = intervals_per_snapshot.get(parent_sid.name, [])
+
+        if p_intervals:
+            if len(p_intervals) > 1:
+                return [DummyNode(snapshot_name=parent_sid.name)]
+            interval = p_intervals[0]
+            return [EvaluateNode(snapshot_name=parent_sid.name, interval=interval, batch_index=0)]
+        if parent_sid in snapshots_to_create:
+            return [CreateNode(snapshot_name=parent_sid.name)]
+        # This snapshot has no intervals and doesn't need creation which means
+        # that it can be a transitive dependency
+        transitive_deps: t.List[SchedulingUnit] = []
+        parent_snapshot = self.snapshots[parent_sid]
+        for grandparent_sid in parent_snapshot.parents:
+            transitive_deps.extend(
+                self._find_upstream_dependencies(
+                    grandparent_sid, intervals_per_snapshot, snapshots_to_create
+                )
+            )
+        return transitive_deps
+
     def _run_or_audit(
         self,
         environment: str | EnvironmentNamingInfo,
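The new _find_upstream_dependencies resolves each parent to concrete scheduling units: an EvaluateNode when the parent has a single interval to process, a DummyNode when its work is split across several batches, a CreateNode when it only needs to be created, and, when it has no work at all, the units of its own parents, recursively. A simplified, hypothetical mirror of that recursion (strings stand in for snapshots and scheduling units; find_upstream and its parameters are illustrative names, not sqlmesh API):

# Hypothetical, simplified mirror of _find_upstream_dependencies.
import typing as t


def find_upstream(
    parent: str,
    parents_of: t.Dict[str, t.List[str]],
    has_intervals: t.Set[str],
    to_create: t.Set[str],
) -> t.List[str]:
    if parent not in parents_of:
        # Unknown snapshot: contributes no dependency edge.
        return []
    if parent in has_intervals:
        # Real work is scheduled for the parent: depend on its evaluation.
        return [f"evaluate:{parent}"]
    if parent in to_create:
        # No intervals, but the table still has to be created first.
        return [f"create:{parent}"]
    # No work of its own: fall through to its parents, so the chain of
    # transitive dependencies is preserved in the scheduling DAG.
    deps: t.List[str] = []
    for grandparent in parents_of[parent]:
        deps.extend(find_upstream(grandparent, parents_of, has_intervals, to_create))
    return deps


# a <- b <- c with b excluded from the run: a's evaluation is still reached
# through b, so c's node ends up depending on a's node directly.
parents_of = {"a": [], "b": ["a"], "c": ["b"]}
assert find_upstream("b", parents_of, {"a"}, set()) == ["evaluate:a"]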

tests/core/test_integration.py

Lines changed: 77 additions & 0 deletions
@@ -1672,6 +1672,83 @@ def test_plan_ignore_cron(
     )
 
 
+@time_machine.travel("2023-01-08 15:00:00 UTC")
+def test_run_respects_excluded_transitive_dependencies(init_and_plan_context: t.Callable):
+    context, _ = init_and_plan_context("examples/sushi")
+
+    # Graph: C <- B <- A
+    # B is a transitive dependency linking A and C
+    # Note that the alphabetical ordering of the model names is intentional and helps
+    # surface the problem
+    expressions_a = d.parse(
+        f"""
+        MODEL (
+            name memory.sushi.test_model_c,
+            kind FULL,
+            allow_partials true,
+            cron '@hourly',
+        );
+
+        SELECT @execution_ts AS execution_ts
+        """
+    )
+    model_c = load_sql_based_model(expressions_a)
+    context.upsert_model(model_c)
+
+    # A VIEW model with no partials allowed and a daily cron instead of hourly.
+    expressions_b = d.parse(
+        f"""
+        MODEL (
+            name memory.sushi.test_model_b,
+            kind VIEW,
+            allow_partials false,
+            cron '@daily',
+        );
+
+        SELECT * FROM memory.sushi.test_model_c
+        """
+    )
+    model_b = load_sql_based_model(expressions_b)
+    context.upsert_model(model_b)
+
+    expressions_a = d.parse(
+        f"""
+        MODEL (
+            name memory.sushi.test_model_a,
+            kind FULL,
+            allow_partials true,
+            cron '@hourly',
+        );
+
+        SELECT * FROM memory.sushi.test_model_b
+        """
+    )
+    model_a = load_sql_based_model(expressions_a)
+    context.upsert_model(model_a)
+
+    context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True)
+    assert (
+        context.fetchdf("SELECT execution_ts FROM memory.sushi.test_model_c")["execution_ts"].iloc[
+            0
+        ]
+        == "2023-01-08 15:00:00"
+    )
+
+    with time_machine.travel("2023-01-08 17:00:00 UTC", tick=False):
+        context.run(
+            "prod",
+            select_models=["*test_model_c", "*test_model_a"],
+            no_auto_upstream=True,
+            ignore_cron=True,
+        )
+    assert (
+        context.fetchdf("SELECT execution_ts FROM memory.sushi.test_model_a")[
+            "execution_ts"
+        ].iloc[0]
+        == "2023-01-08 17:00:00"
+    )
+
+
 @time_machine.travel("2023-01-08 15:00:00 UTC")
 def test_run_with_select_models_no_auto_upstream(
     init_and_plan_context: t.Callable,
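The test's note about alphabetical ordering is what makes the regression observable: with the transitive edge through test_model_b dropped, both evaluation orders are topologically valid, so a scheduler that happens to fall back to name order runs test_model_a before its transitive upstream test_model_c. A small illustration with the standard library's topological sorter (hypothetical helper code, not part of the test):

# Hypothetical illustration of why name order and dependency order are made
# to disagree in this test.
from graphlib import TopologicalSorter

# With the transitive edge preserved (test_model_a depends on test_model_c
# through the excluded test_model_b), the only valid order runs c first.
with_edge = TopologicalSorter({"test_model_a": {"test_model_c"}})
assert list(with_edge.static_order()) == ["test_model_c", "test_model_a"]

# With the edge dropped, both orders are topologically valid, and a scheduler
# that happens to iterate alphabetically evaluates test_model_a before its
# transitive upstream test_model_c has been refreshed.
without_edge = TopologicalSorter({"test_model_a": set(), "test_model_c": set()})
print(list(without_edge.static_order()))  # may be ['test_model_a', 'test_model_c']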

tests/core/test_scheduler.py

Lines changed: 107 additions & 0 deletions
@@ -32,6 +32,7 @@
     SnapshotEvaluator,
     SnapshotChangeCategory,
     DeployabilityIndex,
+    snapshots_to_dag,
 )
 from sqlmesh.utils.date import to_datetime, to_timestamp, DatetimeRanges, TimeLike
 from sqlmesh.utils.errors import CircuitBreakerError, NodeAuditsErrors
@@ -1019,3 +1020,109 @@ def record_execute_environment_statements(*args, **kwargs):
     execute_env_idx = call_order.index("execute_environment_statements")
     snapshots_to_create_idx = call_order.index("get_snapshots_to_create")
     assert env_statements_idx < execute_env_idx < snapshots_to_create_idx
+
+
+def test_dag_transitive_deps(mocker: MockerFixture, make_snapshot):
+    # Create a simple dependency chain: A <- B <- C
+    snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id")))
+    snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT * FROM a")))
+    snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("SELECT * FROM b")))
+
+    snapshot_b = snapshot_b.model_copy(update={"parents": (snapshot_a.snapshot_id,)})
+    snapshot_c = snapshot_c.model_copy(update={"parents": (snapshot_b.snapshot_id,)})
+
+    snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
+    snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING)
+    snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING)
+
+    scheduler = Scheduler(
+        snapshots=[snapshot_a, snapshot_b, snapshot_c],
+        snapshot_evaluator=mocker.Mock(),
+        state_sync=mocker.Mock(),
+        default_catalog=None,
+    )
+
+    # Test scenario: select only A and C (skip B)
+    merged_intervals = {
+        snapshot_a: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
+        snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
+    }
+
+    deployability_index = DeployabilityIndex.create([snapshot_a, snapshot_b, snapshot_c])
+
+    full_dag = snapshots_to_dag([snapshot_a, snapshot_b, snapshot_c])
+
+    dag = scheduler._dag(merged_intervals, snapshot_dag=full_dag)
+    assert dag.graph == {
+        EvaluateNode(
+            snapshot_name='"a"',
+            interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
+            batch_index=0,
+        ): set(),
+        EvaluateNode(
+            snapshot_name='"c"',
+            interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
+            batch_index=0,
+        ): {
+            EvaluateNode(
+                snapshot_name='"a"',
+                interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
+                batch_index=0,
+            )
+        },
+    }
+
+
+def test_dag_multiple_chain_transitive_deps(mocker: MockerFixture, make_snapshot):
+    # Create a more complex dependency graph:
+    # A <- B <- D <- E
+    # A <- C <- D <- E
+    # Select A and E only
+    snapshots = {}
+    for name in ["a", "b", "c", "d", "e"]:
+        snapshots[name] = make_snapshot(SqlModel(name=name, query=parse_one("SELECT 1 as id")))
+        snapshots[name].categorize_as(SnapshotChangeCategory.BREAKING)
+
+    # Set up dependencies
+    snapshots["b"] = snapshots["b"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)})
+    snapshots["c"] = snapshots["c"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)})
+    snapshots["d"] = snapshots["d"].model_copy(
+        update={"parents": (snapshots["b"].snapshot_id, snapshots["c"].snapshot_id)}
+    )
+    snapshots["e"] = snapshots["e"].model_copy(update={"parents": (snapshots["d"].snapshot_id,)})
+
+    scheduler = Scheduler(
+        snapshots=list(snapshots.values()),
+        snapshot_evaluator=mocker.Mock(),
+        state_sync=mocker.Mock(),
+        default_catalog=None,
+    )
+
+    # Only provide intervals for A and E
+    batched_intervals = {
+        snapshots["a"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
+        snapshots["e"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
+    }
+
+    # Create subdag including transitive dependencies
+    full_dag = snapshots_to_dag(snapshots.values())
+
+    dag = scheduler._dag(batched_intervals, snapshot_dag=full_dag)
+    assert dag.graph == {
+        EvaluateNode(
+            snapshot_name='"a"',
+            interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
+            batch_index=0,
+        ): set(),
+        EvaluateNode(
+            snapshot_name='"e"',
+            interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
+            batch_index=0,
+        ): {
+            EvaluateNode(
+                snapshot_name='"a"',
+                interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
+                batch_index=0,
+            )
+        },
+    }
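In the diamond case both chains (through b and through c) resolve to the same upstream unit, so the recursion can report a more than once; the dependency sets in dag.graph collapse the duplicates, which is why the expected graph above shows e depending on a single EvaluateNode for a. Continuing the simplified, hypothetical find_upstream sketch from the scheduler section:

# Diamond a <- {b, c} <- d <- e, with only a and e selected.
parents_of = {"a": [], "b": ["a"], "c": ["a"], "d": ["b", "c"], "e": ["d"]}

# e's single parent d has no work of its own, so the recursion walks both
# chains and reaches a twice -- once through b and once through c.
deps = find_upstream("d", parents_of, {"a"}, set())
assert deps == ["evaluate:a", "evaluate:a"]

# The dependency sets in dag.graph deduplicate, so e's only upstream
# scheduling unit is a's EvaluateNode.
assert set(deps) == {"evaluate:a"}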
