Skip to content

Commit 10f559c

Browse files
committed
remove assign tracking in token order
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent 9edceaf commit 10f559c

File tree

1 file changed

+5
-24
lines changed

1 file changed

+5
-24
lines changed

src/cuda/tile/_passes/token_order.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from cuda.tile._exception import Loc, TileInternalError
1414
from cuda.tile._ir.ir import Block, IRContext, Var, Operation
1515
from cuda.tile._ir.ops import (
16-
Assign, Break, Continue, EndBranch, IfElse,
16+
Break, Continue, EndBranch, IfElse,
1717
JoinTokens, LoadMemoryOperation, Loop, MakeToken,
1818
MemoryOperation, StoreMemoryOperation, TileAtomicCAS, TileAtomicCASTokenOrdered,
1919
TileAtomicRMW, TileAtomicRMWTokenOrdered, LoadPointer, LoadPointerTokenOrdered,
@@ -106,7 +106,6 @@ class VarInfo:
106106
@dataclass(frozen=True)
107107
class TokenOrderContext:
108108
alias_result: AliasResult
109-
var_info: VarInfo
110109
block_memory_effects: Dict[Block, MemoryEffects]
111110

112111

@@ -123,8 +122,7 @@ class TokenOrderContext:
123122
def token_order_pass(root_block: Block, alias_result: AliasResult):
124123
block_memory_effects = {}
125124
_get_block_memory_effects(root_block, alias_result, block_memory_effects)
126-
var_info = _get_var_info(root_block)
127-
context = TokenOrderContext(alias_result, var_info, block_memory_effects)
125+
context = TokenOrderContext(alias_result, block_memory_effects)
128126

129127
root_tok = _make_token_var(root_block.ctx, root_block.loc)
130128
token_map = defaultdict(lambda: root_tok)
@@ -169,21 +167,6 @@ def get_memory_effects(cur_op):
169167
block_memory_effects[block] = blk_mem_effects
170168

171169

172-
# TODO: Assign ops should be gone at this point. Need to verify this and remove this logic.
173-
def _get_var_info(root_block: Block) -> VarInfo:
174-
root_var = dict()
175-
176-
def traverse(block: Block):
177-
for op in block.operations:
178-
if isinstance(op, Assign):
179-
root_var[op.result_var.name] = root_var.get(op.value.name, op.value.name)
180-
for block in op.nested_blocks:
181-
traverse(block)
182-
183-
traverse(root_block)
184-
return VarInfo(root_var)
185-
186-
187170
def _to_token_order_in_block(block: Block,
188171
context: TokenOrderContext,
189172
token_map: Dict[TokenKey, Var],
@@ -523,7 +506,7 @@ def _get_parallel_stores(
523506
tile_store_candidates.add(mem_ops[0])
524507

525508
# Filter in stores that have non-overlapping indices
526-
res = _filter_by_store_index(loop_op, tile_store_candidates, context.var_info)
509+
res = _filter_by_store_index(loop_op, tile_store_candidates)
527510
return res
528511

529512

@@ -539,13 +522,11 @@ def _get_nested_mem_effects(
539522

540523

541524
def _filter_by_store_index(loop_op: Loop,
542-
tile_store_candidates: Set[Operation],
543-
var_info: VarInfo) -> Set[Operation]:
525+
tile_store_candidates: Set[Operation]) -> Set[Operation]:
544526

545527
def is_idx_injective(idx_var: Var) -> bool:
546-
root_idx_var = var_info.root_var.get(idx_var.name, idx_var.name)
547528
# TODO: allow more complex injective check: j = i * 2 + 3
548-
return loop_op.is_for_loop and root_idx_var == loop_op.induction_var.name
529+
return loop_op.is_for_loop and idx_var.name == loop_op.induction_var.name
549530

550531
return set(store_op for store_op in tile_store_candidates
551532
if _get_input_var(store_op).get_type().elements_disjoint

0 commit comments

Comments
 (0)