Commit d519efd

Introduce canonical Wave region formats and annotated pass adapters (iree-org#1120)
## Summary

This is a step towards a canonical form for Wave FX graphs. This change only tackles region structure for now, not all graph structure more broadly.

The canonical region form chosen here is an isolated region representation: nested regions do not directly reference outer values, and all captures are represented explicitly as lifted `Placeholder`s at the start of the region signature. This gives us one stable structural form to target for verification, printing, and roundtrips.

We are introducing this gradually. Not all existing passes operate on that canonical region form yet, and this change does not require rewriting them all at once. Instead, passes can now declare which region form they expect, and the pass boundary will temporarily adapt the graph into that form before the pass runs and return to canonical form afterwards in the normal pipeline.

This change introduces four region forms:

- `ISOLATED`: the canonical form, with explicit lifted placeholder captures
- `LEGACY_PLACEHOLDERS`: the older placeholder-based capture form expected by several existing passes
- `DIRECT_OUTER_REF`: a legacy form where nested regions directly reference outer values
- `SCHEDULE_SIGNATURE_PLACEHOLDERS`: a hybrid form that keeps placeholders only for schedule-signature sources

## What this PR does

- introduce canonical region-capture handling and verification for nested Wave regions
- let passes declare their required region form directly via annotation through a decorator
- adapt the pipeline to canonicalize at pass boundaries while still supporting legacy region forms where needed
- update lit tests to request raw post-pass output explicitly at the call site

## Additional Change

`minimize_shared_allocs` now skips dead shared allocs instead of assuming every alloc still has a live first/last use. This became necessary once the new canonical-region migration flow exposed intermediate graphs where an alloc may survive temporarily after its uses are gone, making an old pass assumption explicit.

---------

Signed-off-by: Martin Lücke <martin.luecke@amd.com>
1 parent 14a0ddd commit d519efd

61 files changed

Lines changed: 1792 additions & 233 deletions

docs/wave/canonical_ir_format.md

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
# Canonical Region IR Format

This document describes the canonical region structure used by the Wave FX pipeline.

It is intentionally limited in scope: it only covers nested region interfaces
and capture structure. It does not attempt to define a full canonical form for
all FX graph details such as node ordering, `vector_shapes`, or write
dependencies.

## Goal

Wave is moving toward a single canonical region form so that:

- nested regions have one stable structural representation in Python
- FX <-> MLIR roundtrips preserve region structure in a form that later Python
  passes can continue to process
- for migration purposes, passes may declare which temporary non-canonical region view they need
  without forcing the whole pipeline to stay in that view

The canonical form is represented by `RegionFormat.ISOLATED` in
`wave_lang/kernel/wave/region_canonicalization.py`.

## Terms

- Outer source: a node defined outside a nested region but used by that region
- Local capture: the region-local representative of an outer source
- Capture signature: the ordered list stored on the parent `NestedRegionOp` in
  `implicit_captures`
- Direct outer reference: a region node operand that points straight to a node
  in the outer graph

## Canonical Form: `ISOLATED`

`ISOLATED` is the canonical/default region form.

Structural invariants:

- nested region bodies do not directly reference outer graph nodes
- every captured outer value used inside the region is represented by a region-local
  `Placeholder`
- these capture placeholders form the leading non-`IterArg` input prefix of the
  subgraph
- the parent `NestedRegionOp.implicit_captures` list is the authoritative
  ordered capture signature: it defines *which* outer values are captured and in
  *what order*
- each local capture placeholder carries a `meta["lifted"]` link to its outer
  source. This per-placeholder metadata is derived from `implicit_captures` and
  must stay consistent with it (the verifier checks this)

In other words, the region interface is explicit and isolated from above.

Schematic shape:

```text
root graph:
  %outer_a
  %outer_b
  %region = iterate(..., implicit_captures=[%outer_a, %outer_b], ...)

region subgraph:
  %iter_arg0 = placeholder(iter arg)
  %outer_a = placeholder(lifted from outer)
  %outer_b = placeholder(lifted from outer)
  ...
```
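
The `ISOLATED` invariants can be illustrated with a small self-contained sketch. `ToyPlaceholder`, `ToyRegion`, and `check_isolated` are hypothetical stand-ins invented for this example; they are not the Wave FX node classes or the verifier in `region_canonicalization.py`. Only the `implicit_captures` and `meta["lifted"]` concepts come from this document, and outer sources are modeled as plain strings for brevity.

```python
from dataclasses import dataclass, field


@dataclass
class ToyPlaceholder:
    """Hypothetical region input; not a real Wave FX node."""
    name: str
    is_iter_arg: bool = False
    meta: dict = field(default_factory=dict)  # may hold meta["lifted"]


@dataclass
class ToyRegion:
    """Hypothetical nested region plus the capture list of its parent op."""
    implicit_captures: list[str]  # ordered outer sources
    inputs: list[ToyPlaceholder]  # region-local placeholders, in order
    body_operands: list[str] = field(default_factory=list)  # values used by body ops
    local_defs: set[str] = field(default_factory=set)  # values defined by body ops


def check_isolated(region: ToyRegion) -> list[str]:
    """Return a list of violations of the ISOLATED invariants (empty if none)."""
    errors = []
    # Capture placeholders must form the leading non-iter-arg input prefix,
    # and their meta["lifted"] links must match implicit_captures in order.
    non_iter = [p for p in region.inputs if not p.is_iter_arg]
    prefix = non_iter[: len(region.implicit_captures)]
    lifted = [p.meta.get("lifted") for p in prefix]
    if lifted != region.implicit_captures:
        errors.append(
            f"capture prefix {lifted} != implicit_captures {region.implicit_captures}"
        )
    # Body operands may only name region inputs or region-local definitions.
    visible = {p.name for p in region.inputs} | region.local_defs
    for operand in region.body_operands:
        if operand not in visible:
            errors.append(f"direct outer reference to {operand}")
    return errors


ok_region = ToyRegion(
    implicit_captures=["outer_a", "outer_b"],
    inputs=[
        ToyPlaceholder("iter_arg0", is_iter_arg=True),
        ToyPlaceholder("a_local", meta={"lifted": "outer_a"}),
        ToyPlaceholder("b_local", meta={"lifted": "outer_b"}),
    ],
    body_operands=["iter_arg0", "a_local", "b_local"],
)
bad_region = ToyRegion(
    implicit_captures=["outer_a"],
    inputs=[ToyPlaceholder("iter_arg0", is_iter_arg=True)],
    body_operands=["iter_arg0", "outer_a"],  # reaches straight into the outer graph
)
```

Here `check_isolated(ok_region)` reports no violations, while `bad_region` fails twice: its capture has no lifted placeholder, and its body references the outer value directly.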
## Temporary Non-Canonical Forms

Not all existing passes operate on `ISOLATED` yet. To migrate incrementally,
passes may request one of several temporary region views. The pass boundary
adapts into that form before the pass runs and, in the normal pipeline,
canonicalizes back to `ISOLATED` afterwards.

### `LEGACY_PLACEHOLDERS`

This is the older placeholder-based capture form still expected by some
pre-existing passes.

Structural properties:

- captured outer values are represented by region-local placeholders
- the mapping from a local placeholder back to its outer source may still be
  recovered with ad-hoc conventions that pre-existing passes relied on: name
  matching and positional fallback within the capture prefix (codified in
  `_try_resolve_legacy_capture_source` in `region_canonicalization.py`)
- unlike `ISOLATED`, this form does not require `implicit_captures` plus
  `meta["lifted"]` to be the sole authoritative description of the capture
  interface

This is a weaker contract than `ISOLATED`. A pass marked
`LEGACY_PLACEHOLDERS` may still reason about captures through placeholder
layout or legacy lookup behavior instead of relying only on the canonical
capture interface. An already-canonical region may also satisfy this weaker
contract, so the adapter can be a no-op on some graphs.

This mode exists to support passes that still expect legacy placeholder
structure while the pipeline as a whole moves toward explicit canonical
captures.
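
As a rough illustration of "name matching and positional fallback", the sketch below resolves a legacy placeholder to an outer source. `toy_resolve_legacy_source` is a hypothetical helper written for this example; it does not reproduce `_try_resolve_legacy_capture_source`, whose exact behavior is not shown here. Trying the name match first is an assumption based on the word "fallback".

```python
from typing import Optional


def toy_resolve_legacy_source(
    placeholder_name: str,
    prefix_index: int,
    outer_captures: list[str],
) -> Optional[str]:
    """Return the outer source for a legacy capture placeholder, or None."""
    # 1. Name matching: a placeholder named like an outer capture maps to it.
    if placeholder_name in outer_captures:
        return placeholder_name
    # 2. Positional fallback: use the capture at the same position in the prefix.
    if 0 <= prefix_index < len(outer_captures):
        return outer_captures[prefix_index]
    # Neither convention applies, so the source cannot be recovered.
    return None


captures = ["a", "b"]
by_name = toy_resolve_legacy_source("b", 0, captures)        # name match wins
by_position = toy_resolve_legacy_source("tmp", 1, captures)  # falls back to index 1
unresolved = toy_resolve_legacy_source("tmp", 5, captures)   # out of range
```

Because both conventions are heuristics, a lookup like this can fail or pick the wrong source, which is one reason the canonical form relies on the explicit `meta["lifted"]` link instead.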
### `DIRECT_OUTER_REF`

This is a legacy form where region bodies directly reference outer graph nodes.

Structural properties:

- operands inside the region may point directly to outer nodes
- capture placeholders are removed or bypassed where possible
- the parent capture signature may still track those outer values, but the body
  itself is not isolated from above

This form is convenient for passes that want to inspect or mutate the original
outer values directly, especially around captured memory operands.

### `SCHEDULE_SIGNATURE_PLACEHOLDERS`

This is a hybrid legacy form used by scheduling-related passes.

Structural properties (schedule-signature sources are the outer values that
define the region boundary from the scheduler's point of view, namely
outer-graph `Placeholder`s, i.e. kernel arguments, and `NewRegister`s):

- placeholders are kept only for those schedule-signature sources
- non-signature captures are rewritten back to direct outer references
- the region mixes explicit placeholders for signature-defining values with
  direct outer references for everything else

In practice, the schedule-signature sources are the outer values that define
the region boundary from the scheduler's point of view, namely values such as
root placeholders and `NewRegister`s.
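
The split between signature and non-signature captures can be sketched as a simple partition. `ToyOuterValue`, `SIGNATURE_KINDS`, and `split_captures` are hypothetical names for this example, and classifying values by a string `kind` is an assumption made for brevity; the real adapter presumably inspects the actual node types.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyOuterValue:
    """Hypothetical outer source, tagged with the kind of op that defines it."""
    name: str
    kind: str


# Kinds treated as schedule-signature sources, following the text above.
SIGNATURE_KINDS = {"Placeholder", "NewRegister"}


def split_captures(captures):
    """Partition captures into (kept as placeholders, rewritten to direct refs)."""
    keep_as_placeholder = [c for c in captures if c.kind in SIGNATURE_KINDS]
    direct_refs = [c for c in captures if c.kind not in SIGNATURE_KINDS]
    return keep_as_placeholder, direct_refs


captures = [
    ToyOuterValue("a", "Placeholder"),      # kernel argument
    ToyOuterValue("acc", "NewRegister"),
    ToyOuterValue("tmp", "SomeOtherOp"),    # illustrative non-signature source
]
kept, direct = split_captures(captures)
```

In this toy input, `a` and `acc` stay behind placeholders while `tmp` becomes a direct outer reference, which matches the hybrid shape described above.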
## Why There Are Multiple Forms

The long-term goal is for passes to converge on `ISOLATED`.

The intermediate forms exist because rewriting every pass at once would be too
large and too risky. Instead:

1. the pipeline keeps one canonical form
2. each pass declares the temporary form it currently expects
3. pass-boundary adapters convert into that form just for the duration of the
   pass
4. the normal pipeline returns to canonical form afterwards

This makes the migration incremental while still establishing one structural
source of truth.

## Pass Contract

Passes declare their required region form with `@requires_region_format(...)`.

The important contract is:

- if a pass does not declare a region form, it is assumed to operate on the
  canonical `ISOLATED` form
- in the normal pipeline, pass outputs are canonicalized back to `ISOLATED`
- white-box tests that want to inspect a temporary legacy form must request
  `canonicalize_output=False` explicitly at the call site

This keeps the default pipeline principled while still allowing tests to inspect
the raw intermediate structure when needed.
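
A minimal sketch of how such a pass-boundary decorator could behave is shown below. Everything prefixed with `toy_` or `Toy` is hypothetical and written for this example; it is not the implementation of `requires_region_format`. The graph is modeled as a plain dict and the adapters only record which form the graph is in, so the sketch shows the control flow rather than any real graph rewriting.

```python
import functools
from enum import Enum, auto


class ToyRegionFormat(Enum):
    """Stand-in for the four region forms described in this document."""
    ISOLATED = auto()
    LEGACY_PLACEHOLDERS = auto()
    DIRECT_OUTER_REF = auto()
    SCHEDULE_SIGNATURE_PLACEHOLDERS = auto()


def toy_convert(graph: dict, target: ToyRegionFormat) -> None:
    """Stand-in for the real adapters: only records the current form."""
    graph["format"] = target
    graph.setdefault("history", []).append(target.name)


def toy_requires_region_format(required: ToyRegionFormat):
    """Hypothetical decorator: adapt before the pass, canonicalize afterwards."""
    def decorator(pass_fn):
        @functools.wraps(pass_fn)
        def wrapper(graph, *args, canonicalize_output: bool = True, **kwargs):
            toy_convert(graph, required)  # adapt into the declared form
            result = pass_fn(graph, *args, **kwargs)
            if canonicalize_output:
                # Normal pipeline behavior: return to the canonical form.
                toy_convert(graph, ToyRegionFormat.ISOLATED)
            return result
        return wrapper
    return decorator


@toy_requires_region_format(ToyRegionFormat.DIRECT_OUTER_REF)
def toy_pass(graph):
    # The pass body sees the graph in the form it declared.
    return graph["format"].name


pipeline_graph = {}
seen_by_pass = toy_pass(pipeline_graph)  # canonicalized again after the pass

raw_graph = {}
toy_pass(raw_graph, canonicalize_output=False)  # white-box test style call
```

After the first call `pipeline_graph` is back in the canonical form, while `raw_graph` stays in the form the pass requested, mirroring how the lit tests in this commit pass `canonicalize_output=False` to inspect raw post-pass output.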
## Non-Goals

This document does not define:

- a canonical ordering for all FX nodes
- a canonical form for `GetResult` materialization beyond what is required by
  region structure
- a canonical form for `vector_shapes`, write dependencies, or other downstream
  analysis state

Those are separate concerns that may later build on top of this region-level
structural baseline.

lit_tests/kernel/wave/barrier_strategies.py

Lines changed: 11 additions & 5 deletions
@@ -179,7 +179,7 @@ def test_gemm():
 set_post_expansion_indices(trace, constraints)
 tweak_index(graph)
 hoist_loop_invariant_ops(trace, constraints)
-add_shared_memory_barriers(trace)
+add_shared_memory_barriers(trace, canonicalize_output=False)
 print_trace(trace, False)

 # CHECK-LABEL: test_gemm
@@ -231,8 +231,14 @@ def test_split_barriers():
 tweak_index(graph)
 hoist_loop_invariant_ops(trace, constraints)
 schedule_graph(trace, constraints, True, enable_scheduling)
-schedule_reordering(trace, constraints, enable_scheduling, use_global_to_shared)
-add_shared_memory_barriers(trace, target="gfx1201")
+schedule_reordering(
+    trace,
+    constraints,
+    enable_scheduling,
+    use_global_to_shared,
+    canonicalize_output=False,
+)
+add_shared_memory_barriers(trace, target="gfx1201", canonicalize_output=False)
 print_trace(trace, False)

 # CHECK-LABEL: test_split_barriers
@@ -363,7 +369,7 @@ def test_existing_barrier_not_duplicated():

 # Now run barrier placement - it should detect the existing barrier
 # and NOT insert duplicates
-add_shared_memory_barriers(trace)
+add_shared_memory_barriers(trace, canonicalize_output=False)

 # Count barriers after
 barriers_after = count_barriers(graph)
@@ -570,7 +576,7 @@ def test_memory_counter_wait_barrier_prevents_redundant_barrier():

 # Now run barrier placement - it should detect the existing MemoryCounterWaitBarrier
 # and NOT insert an additional SharedMemoryBarrier (since synchronization is already provided)
-add_shared_memory_barriers(trace)
+add_shared_memory_barriers(trace, canonicalize_output=False)

 # Count barriers after
 mcw_barriers_after = count_memory_counter_wait_barriers(graph)

lit_tests/kernel/wave/barriers.py

Lines changed: 10 additions & 4 deletions
@@ -114,7 +114,7 @@ def test_read_write_equal_sizes():
 expand_graph(trace, constraints)
 set_post_expansion_indices(trace, constraints)
 tweak_index(graph)
-add_shared_memory_barriers(trace)
+add_shared_memory_barriers(trace, canonicalize_output=False)
 print_trace(trace, False)
 # CHECK: %a
 # CHECK-NEXT: %c
@@ -205,7 +205,7 @@ def test_gemm():
 set_post_expansion_indices(trace, constraints)
 tweak_index(graph)
 hoist_loop_invariant_ops(trace, constraints)
-add_shared_memory_barriers(trace)
+add_shared_memory_barriers(trace, canonicalize_output=False)
 print_trace(trace, False)
 # Root graph:
 # CHECK: %a
@@ -330,8 +330,14 @@ def test_split_barriers():
 tweak_index(graph)
 hoist_loop_invariant_ops(trace, constraints)
 schedule_graph(trace, constraints, True, enable_scheduling)
-schedule_reordering(trace, constraints, enable_scheduling, use_global_to_shared)
-add_shared_memory_barriers(trace, target="gfx1201")
+schedule_reordering(
+    trace,
+    constraints,
+    enable_scheduling,
+    use_global_to_shared,
+    canonicalize_output=False,
+)
+add_shared_memory_barriers(trace, target="gfx1201", canonicalize_output=False)
 print_trace(trace, False)

 # Note: In pipelined loops, signal/wait pairs may have operations between them

lit_tests/kernel/wave/codegen.py

Lines changed: 1 addition & 0 deletions
@@ -900,6 +900,7 @@ def test(
 ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
 },
 canonicalize=True,
+compile_to_mlir=True,
 )
 options = set_default_compile_config(options)
 test = wave_compile(options, test)
