From 8a797b1e15aa6597c92cc7b18a79701b803849b8 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Mon, 30 Mar 2026 07:29:35 +0000 Subject: [PATCH 01/31] [FLYDSL]: if_dispatch dynamic process --- python/flydsl/compiler/ast_rewriter.py | 421 ++++++++++++++++++++++--- python/flydsl/expr/numeric.py | 34 ++ tests/unit/test_if_dispatch_paths.py | 186 +++++++++++ 3 files changed, 604 insertions(+), 37 deletions(-) create mode 100644 tests/unit/test_if_dispatch_paths.py diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 5bf62901..fc27fc8e 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -11,6 +11,7 @@ from .._mlir import ir from .._mlir.dialects import arith, scf from ..expr import const_expr +from ..expr.numeric import _unwrap_value, _wrap_like from ..utils import env, log @@ -232,32 +233,155 @@ def _is_dynamic(cond): @staticmethod def _to_i1(cond): - if hasattr(cond, "ir_value"): - return cond.ir_value() - return cond + return _unwrap_value(cond) @staticmethod - def scf_if_dispatch(cond, then_fn, else_fn=None): + def _normalize_state_values(state_names, state_values): + state_names = tuple(state_names or ()) + state_values = tuple(state_values or ()) + if len(state_names) != len(state_values): + raise ValueError( + "state_names and state_values must have the same length, " + f"got {len(state_names)} and {len(state_values)}" + ) + return state_names, state_values + + @staticmethod + def _normalize_branch_result(branch_result, state_names, state_map, branch_label): + if not state_names: + return [] + + if isinstance(branch_result, dict): + result_map = dict(branch_result) + elif branch_result is None: + result_map = {} + elif len(state_names) == 1 and not isinstance(branch_result, (list, tuple)): + result_map = {state_names[0]: branch_result} + elif isinstance(branch_result, (list, tuple)) and len(branch_result) == len(state_names): + result_map = dict(zip(state_names, branch_result)) + else: + raise TypeError( + f"{branch_label} must return dict/tuple/list for stateful dispatch; got {type(branch_result).__name__}" + ) + + values = [] + for name in state_names: + if name in result_map: + values.append(result_map[name]) + elif name in state_map: + values.append(state_map[name]) + else: + raise NameError( + f"variable '{name}' is not available before if/else and is not assigned in {branch_label}" + ) + return values + + @staticmethod + def _unwrap_mlir_values(values, state_names, branch_label): + raw_values = [] + for name, value in zip(state_names, values): + raw = _unwrap_value(value) + if not isinstance(raw, ir.Value): + raise TypeError( + f"if/else variable '{name}' in {branch_label} is {type(raw).__name__}, " + "not an MLIR Value. Only MLIR Values can be yielded from dynamic if/else branches." + ) + raw_values.append(raw) + return raw_values + + @staticmethod + def _pack_dispatch_results(results, state_values): + if not results: + return None + wrapped = [_wrap_like(v, exemplar) for v, exemplar in zip(results, state_values)] + if len(wrapped) == 1: + return wrapped[0] + return tuple(wrapped) + + @staticmethod + def scf_if_dispatch(cond, then_fn, else_fn=None, *, state_names=(), state_values=()): + state_names, state_values = ReplaceIfWithDispatch._normalize_state_values(state_names, state_values) + state_map = {name: value for name, value in zip(state_names, state_values)} + if not ReplaceIfWithDispatch._is_dynamic(cond): - # compile-time evaluation - if cond: + taken = then_fn if cond else else_fn + if taken is None: + return None + result = taken(*state_values) + if not state_names: + return None + values = ReplaceIfWithDispatch._normalize_branch_result( + result, state_names, state_map, "selected branch" + ) + if len(values) == 1: + return values[0] + return tuple(values) + + cond_i1 = ReplaceIfWithDispatch._to_i1(cond) + if not isinstance(cond_i1, ir.Value): + raise TypeError(f"dynamic if condition must lower to ir.Value, got {type(cond_i1).__name__}") + + if not state_names: + has_else = else_fn is not None + if_op = scf.IfOp(cond_i1, [], has_else=has_else, loc=ir.Location.unknown()) + with ir.InsertionPoint(if_op.regions[0].blocks[0]): then_fn() - elif else_fn is not None: - else_fn() - return + scf.YieldOp([]) + if has_else: + if len(if_op.regions[1].blocks) == 0: + if_op.regions[1].blocks.append(*[]) + with ir.InsertionPoint(if_op.regions[1].blocks[0]): + else_fn() + scf.YieldOp([]) + return None + + if else_fn is None: + else_fn = lambda *args: {} + + state_raw = [] + for name, value in zip(state_names, state_values): + raw = _unwrap_value(value) + if not isinstance(raw, ir.Value): + raise TypeError( + f"state variable '{name}' is {type(raw).__name__}, not an MLIR Value; " + "stateful dynamic if requires MLIR-backed values." + ) + state_raw.append(raw) + + result_types = [v.type for v in state_raw] + if_op = scf.IfOp(cond_i1, result_types, has_else=True, loc=ir.Location.unknown()) - has_else = else_fn is not None - loc = ir.Location.unknown() - if_op = scf.IfOp(ReplaceIfWithDispatch._to_i1(cond), [], has_else=has_else, loc=loc) with ir.InsertionPoint(if_op.regions[0].blocks[0]): - then_fn() - scf.YieldOp([]) - if has_else: - if len(if_op.regions[1].blocks) == 0: - if_op.regions[1].blocks.append(*[]) - with ir.InsertionPoint(if_op.regions[1].blocks[0]): - else_fn() - scf.YieldOp([]) + then_result = then_fn(*state_values) + then_values = ReplaceIfWithDispatch._normalize_branch_result( + then_result, state_names, state_map, "then-branch" + ) + then_raw = ReplaceIfWithDispatch._unwrap_mlir_values(then_values, state_names, "then-branch") + for name, expect_ty, got in zip(state_names, result_types, then_raw): + if got.type != expect_ty: + raise TypeError( + f"if/else variable '{name}' type mismatch in then-branch: " + f"expected {expect_ty}, got {got.type}" + ) + scf.YieldOp(then_raw) + + if len(if_op.regions[1].blocks) == 0: + if_op.regions[1].blocks.append(*[]) + with ir.InsertionPoint(if_op.regions[1].blocks[0]): + else_result = else_fn(*state_values) + else_values = ReplaceIfWithDispatch._normalize_branch_result( + else_result, state_names, state_map, "else-branch" + ) + else_raw = ReplaceIfWithDispatch._unwrap_mlir_values(else_values, state_names, "else-branch") + for name, expect_ty, got in zip(state_names, result_types, else_raw): + if got.type != expect_ty: + raise TypeError( + f"if/else variable '{name}' type mismatch in else-branch: " + f"expected {expect_ty}, got {got.type}" + ) + scf.YieldOp(else_raw) + + return ReplaceIfWithDispatch._pack_dispatch_results(list(if_op.results), state_values) @classmethod def rewrite_globals(cls): @@ -274,17 +398,171 @@ def rewrite_globals(cls): def _could_be_dynamic(test_node): """Check if an if-condition AST could produce an MLIR Value at runtime. - Calls to RewriteBoolOps helpers (dsl_not_, dsl_and_, dsl_or_) and - Python builtins are NOT considered dynamic — they just wrap Python-level - boolean logic. Only calls to user/MLIR functions can produce Values. + Layer-by-layer recursive check: + 1) classify current node if possible, + 2) otherwise recurse into direct children, + 3) use a conservative fallback for unresolved expression nodes. """ - for child in ast.walk(test_node): - if isinstance(child, ast.Call): - func = child.func + def _is_literal_expr(node): + if isinstance(node, ast.Constant): + return True + if isinstance(node, (ast.Tuple, ast.List, ast.Set)): + return all(_is_literal_expr(e) for e in node.elts) + if isinstance(node, ast.Dict): + return all( + (k is None or _is_literal_expr(k)) and _is_literal_expr(v) + for k, v in zip(node.keys, node.values) + ) + return False + + def _classify_current(node): + if _is_literal_expr(node): + return False + if isinstance(node, ast.Name): + # Plain names can be static symbols (constexpr params, local bools, etc.). + return None + if isinstance(node, ast.Call): + func = node.func if isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES: - continue + return None return True - return False + return None + + def _visit(node): + current = _classify_current(node) + if current is True: + return True + if current is False: + return False + + for child in ast.iter_child_nodes(node): + if _visit(child): + return True + + # If this expression cannot be proven static from itself or children, + # keep dynamic handling conservative. + if isinstance(node, ast.expr) and not isinstance(node, ast.Name): + return True + return False + + return _visit(test_node) + + @staticmethod + def _collect_assigned_vars(stmts): + assigned = [] + + def add_name(name): + if isinstance(name, str) and name not in assigned: + assigned.append(name) + + class AssignCollector(ast.NodeVisitor): + def _collect_target(self, target): + if isinstance(target, ast.Name): + add_name(target.id) + elif isinstance(target, (ast.Tuple, ast.List)): + for elt in target.elts: + self._collect_target(elt) + elif isinstance(target, ast.Starred): + self._collect_target(target.value) + elif isinstance(target, ast.Subscript): + self._collect_target(target.value) + elif isinstance(target, ast.Attribute): + if isinstance(target.value, ast.Name): + add_name(target.value.id) + + def visit_Assign(self, node): + for target in node.targets: + self._collect_target(target) + self.visit(node.value) + + def visit_AugAssign(self, node): + self._collect_target(node.target) + self.visit(node.value) + + def visit_AnnAssign(self, node): + self._collect_target(node.target) + if node.value is not None: + self.visit(node.value) + + def visit_For(self, node): + self._collect_target(node.target) + self.generic_visit(node) + + def visit_AsyncFor(self, node): + self._collect_target(node.target) + self.generic_visit(node) + + def visit_With(self, node): + for item in node.items: + if item.optional_vars is not None: + self._collect_target(item.optional_vars) + self.generic_visit(node) + + def visit_AsyncWith(self, node): + for item in node.items: + if item.optional_vars is not None: + self._collect_target(item.optional_vars) + self.generic_visit(node) + + def visit_ExceptHandler(self, node): + if node.name: + add_name(node.name) + self.generic_visit(node) + + def visit_NamedExpr(self, node): + self._collect_target(node.target) + self.visit(node.value) + + def visit_MatchAs(self, node): + if node.name: + add_name(node.name) + self.generic_visit(node) + + def visit_MatchStar(self, node): + if node.name: + add_name(node.name) + self.generic_visit(node) + + def visit_MatchMapping(self, node): + if node.rest: + add_name(node.rest) + self.generic_visit(node) + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + # Only treat typical mutating methods as writes. + if node.func.attr in { + "append", + "extend", + "insert", + "pop", + "remove", + "clear", + "sort", + "reverse", + "add", + "discard", + "update", + "setdefault", + }: + add_name(node.func.value.id) + self.generic_visit(node) + + collector = AssignCollector() + collector.visit(ast.Module(body=stmts, type_ignores=[])) + return assigned + + @staticmethod + def _state_value_expr(name): + return ast.Call( + func=ast.Attribute( + value=ast.Call(func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]), + attr="get", + ctx=ast.Load(), + ), + args=[ast.Constant(value=name), ast.Constant(value=None)], + keywords=[], + ) def visit_If(self, node: ast.If) -> List[ast.AST]: if _is_constexpr(node.test): @@ -299,10 +577,25 @@ def visit_If(self, node: ast.If) -> List[ast.AST]: ReplaceIfWithDispatch._counter += 1 then_name = f"__then_{uid}" + then_vars = self._collect_assigned_vars(node.body) + else_vars = self._collect_assigned_vars(node.orelse) + state_names = list(then_vars) + for var in else_vars: + if var not in state_names: + state_names.append(var) + + fn_args = [ast.arg(arg=v, annotation=None) for v in state_names] + def _state_return_node(): + return ast.Return( + value=ast.Dict( + keys=[ast.Constant(value=v) for v in state_names], + values=[ast.Name(id=v, ctx=ast.Load()) for v in state_names], + ) + ) then_func = ast.FunctionDef( name=then_name, - args=ast.arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), - body=node.body, + args=ast.arguments(posonlyargs=[], args=fn_args, kwonlyargs=[], kw_defaults=[], defaults=[]), + body=list(node.body) + ([_state_return_node()] if state_names else []), decorator_list=[], type_params=[], ) @@ -311,14 +604,58 @@ def visit_If(self, node: ast.If) -> List[ast.AST]: then_func = ast.fix_missing_locations(then_func) dispatch_args = [node.test, ast.Name(then_name, ctx=ast.Load())] + dispatch_keywords = [] + if state_names: + dispatch_keywords.extend( + [ + ast.keyword( + arg="state_names", + value=ast.Tuple(elts=[ast.Constant(value=v) for v in state_names], ctx=ast.Load()), + ), + ast.keyword( + arg="state_values", + value=ast.Tuple( + elts=[self._state_value_expr(v) for v in state_names], + ctx=ast.Load(), + ), + ), + ] + ) result = [then_func] + else_name = None if node.orelse: else_name = f"__else_{uid}" else_func = ast.FunctionDef( name=else_name, - args=ast.arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), - body=node.orelse, + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg=v, annotation=None) for v in state_names], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=list(node.orelse) + ([_state_return_node()] if state_names else []), + decorator_list=[], + type_params=[], + ) + setattr(else_func, _ASTREWRITE_MARKER, True) + else_func = ast.copy_location(else_func, node) + else_func = ast.fix_missing_locations(else_func) + dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) + result.append(else_func) + elif state_names: + else_name = f"__else_{uid}" + else_func = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg=v, annotation=None) for v in state_names], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[_state_return_node()], decorator_list=[], type_params=[], ) @@ -328,12 +665,22 @@ def visit_If(self, node: ast.If) -> List[ast.AST]: dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) result.append(else_func) - dispatch_call = ast.Expr( - value=ast.Call(func=ast.Name("scf_if_dispatch", ctx=ast.Load()), args=dispatch_args, keywords=[]) + dispatch_value = ast.Call( + func=ast.Name("scf_if_dispatch", ctx=ast.Load()), + args=dispatch_args, + keywords=dispatch_keywords, ) - dispatch_call = ast.copy_location(dispatch_call, node) - dispatch_call = ast.fix_missing_locations(dispatch_call) - result.append(dispatch_call) + if state_names: + if len(state_names) == 1: + target = ast.Name(id=state_names[0], ctx=ast.Store()) + else: + target = ast.Tuple(elts=[ast.Name(id=v, ctx=ast.Store()) for v in state_names], ctx=ast.Store()) + dispatch_stmt = ast.Assign(targets=[target], value=dispatch_value) + else: + dispatch_stmt = ast.Expr(value=dispatch_value) + dispatch_stmt = ast.copy_location(dispatch_stmt, node) + dispatch_stmt = ast.fix_missing_locations(dispatch_stmt) + result.append(dispatch_stmt) return result diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index e33edddf..44742217 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -213,6 +213,40 @@ def _extract_arith(val, signed): return v.with_signedness(signed) if isinstance(v, ArithValue) else v +def _unwrap_value(value): + """Convert FlyDSL wrappers to raw MLIR values when possible.""" + if isinstance(value, ir.Value): + return value + if hasattr(value, "__fly_values__"): + values = value.__fly_values__() + if len(values) == 1: + return values[0] + if hasattr(value, "ir_value"): + return value.ir_value() + return value + + +def _wrap_like(value, exemplar=None): + """Wrap an MLIR value back to a FlyDSL wrapper when possible.""" + if not isinstance(value, ir.Value): + return value + + if exemplar is not None: + if isinstance(exemplar, Numeric): + return type(exemplar)(value) + ctor = getattr(type(exemplar), "__fly_construct__", None) + if ctor is not None: + try: + return ctor([value]) + except Exception: + pass + + try: + return Numeric.from_ir_type(value.type)(value) + except Exception: + return value + + def _make_binop(op, promote=True, widen_bool=False, swap=False): """Create a binary-operator closure for Numeric subclasses.""" def _apply(lhs, rhs, *, loc=None, ip=None): diff --git a/tests/unit/test_if_dispatch_paths.py b/tests/unit/test_if_dispatch_paths.py new file mode 100644 index 00000000..4ec396ff --- /dev/null +++ b/tests/unit/test_if_dispatch_paths.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +import ast + +import pytest + +from flydsl._mlir.ir import Context, FunctionType, InsertionPoint, IntegerType, Location, Module +from flydsl._mlir.dialects import arith, func +from flydsl.compiler.ast_rewriter import ASTRewriter, ReplaceIfWithDispatch +from flydsl.expr.numeric import Int32 + + +def test_collect_assigned_vars_supports_tuple_and_augassign(): + code = """ +a, (b, c) = foo() +d += 1 +""" + stmts = ast.parse(code).body + assigned = ReplaceIfWithDispatch._collect_assigned_vars(stmts) + assert assigned == ["a", "b", "c", "d"] + + +def test_collect_assigned_vars_supports_annassign_walrus_with_except_for(): + code = """ +x: int = 1 +for i in range(4): + y = i +with ctx() as w: + z = w +try: + pass +except Exception as e: + err = e +if (n := foo()): + out = n +""" + stmts = ast.parse(code).body + assigned = ReplaceIfWithDispatch._collect_assigned_vars(stmts) + assert assigned == ["x", "i", "y", "w", "z", "e", "err", "n", "out"] + + +def test_scf_if_dispatch_static_with_states_no_ifop(): + with Context(), Location.unknown(): + module = Module.create() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + f = func.FuncOp("test_static_states", FunctionType.get([], [i32])) + entry = f.add_entry_block() + with InsertionPoint(entry): + x = Int32(1) + + def then_fn(x): + return {"x": Int32(42)} + + def else_fn(x): + return {"x": Int32(99)} + + out = ReplaceIfWithDispatch.scf_if_dispatch( + True, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + func.ReturnOp([out.ir_value()]) + + assert module.operation.verify() + assert "scf.if" not in str(module) + + +def test_scf_if_dispatch_dynamic_with_states_build_ifop(): + with Context(), Location.unknown(): + module = Module.create() + i1 = IntegerType.get_signless(1) + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + f = func.FuncOp("test_dynamic_states", FunctionType.get([i1], [i32])) + entry = f.add_entry_block() + with InsertionPoint(entry): + cond = entry.arguments[0] + x = Int32(arith.ConstantOp(i32, 1).result) + + def then_fn(x): + return {"x": Int32(arith.ConstantOp(i32, 42).result)} + + def else_fn(x): + return {"x": Int32(arith.ConstantOp(i32, 99).result)} + + out = ReplaceIfWithDispatch.scf_if_dispatch( + cond, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + assert isinstance(out, Int32) + func.ReturnOp([out.ir_value()]) + + assert module.operation.verify() + ir_text = str(module) + assert "scf.if" in ir_text + assert "-> (i32)" in ir_text + + +def test_scf_if_dispatch_dynamic_type_mismatch_has_clear_error(): + with Context(), Location.unknown(): + module = Module.create() + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + i1 = IntegerType.get_signless(1) + with InsertionPoint(module.body): + f = func.FuncOp("test_dynamic_type_mismatch", FunctionType.get([i1], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + cond = entry.arguments[0] + x = Int32(arith.ConstantOp(i32, 1).result) + + def then_fn(x): + return {"x": arith.ConstantOp(i32, 2).result} + + def else_fn(x): + return {"x": arith.ConstantOp(i64, 3).result} + + with pytest.raises(TypeError, match="type mismatch|mismatched types"): + ReplaceIfWithDispatch.scf_if_dispatch( + cond, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + + +def test_scf_if_dispatch_dynamic_non_mlir_value_has_clear_error(): + with Context(), Location.unknown(): + module = Module.create() + i32 = IntegerType.get_signless(32) + i1 = IntegerType.get_signless(1) + with InsertionPoint(module.body): + f = func.FuncOp("test_dynamic_non_mlir", FunctionType.get([i1], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + cond = entry.arguments[0] + x = Int32(arith.ConstantOp(i32, 1).result) + + def then_fn(x): + return {"x": 7} + + def else_fn(x): + return {"x": arith.ConstantOp(i32, 3).result} + + with pytest.raises(TypeError, match="not an MLIR Value"): + ReplaceIfWithDispatch.scf_if_dispatch( + cond, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + + +def test_ast_rewrite_keeps_semantics_for_static_bool(): + called = {"n": 0} + + def sample(flag): + x = 1 + if flag: + x = 2 + else: + x = 3 + return x + + ASTRewriter.transform(sample) + original_dispatch = sample.__globals__["scf_if_dispatch"] + + def traced_dispatch(*args, **kwargs): + called["n"] += 1 + return original_dispatch(*args, **kwargs) + + sample.__globals__["scf_if_dispatch"] = traced_dispatch + assert sample(True) == 2 + assert sample(False) == 3 + assert called["n"] in (0, 2) From 5a2d2aae257ca2b4b77f8f149a67ce59b961b6e0 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Mon, 30 Mar 2026 10:36:34 +0000 Subject: [PATCH 02/31] [FLYDSL]: Fix for local variable issues in if/else statements --- python/flydsl/compiler/ast_rewriter.py | 307 ++++++++++++++++++++----- tests/unit/test_if_dispatch_paths.py | 24 ++ 2 files changed, 272 insertions(+), 59 deletions(-) diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index fc27fc8e..b3cf9709 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -223,6 +223,10 @@ def visit_UnaryOp(self, node: ast.UnaryOp): class ReplaceIfWithDispatch(Transformer): _counter = 0 + def __init__(self, context, first_lineno): + super().__init__(context=context, first_lineno=first_lineno) + self._visible_name_stack = [] + @staticmethod def _is_dynamic(cond): if isinstance(cond, ir.Value): @@ -236,15 +240,15 @@ def _to_i1(cond): return _unwrap_value(cond) @staticmethod - def _normalize_state_values(state_names, state_values): - state_names = tuple(state_names or ()) - state_values = tuple(state_values or ()) - if len(state_names) != len(state_values): + def _normalize_named_values(names, values, names_label="names", values_label="values"): + names = tuple(names or ()) + values = tuple(values or ()) + if len(names) != len(values): raise ValueError( - "state_names and state_values must have the same length, " - f"got {len(state_names)} and {len(state_values)}" + f"{names_label} and {values_label} must have the same length, " + f"got {len(names)} and {len(values)}" ) - return state_names, state_values + return names, values @staticmethod def _normalize_branch_result(branch_result, state_names, state_map, branch_label): @@ -299,47 +303,116 @@ def _pack_dispatch_results(results, state_values): return tuple(wrapped) @staticmethod - def scf_if_dispatch(cond, then_fn, else_fn=None, *, state_names=(), state_values=()): - state_names, state_values = ReplaceIfWithDispatch._normalize_state_values(state_names, state_values) - state_map = {name: value for name, value in zip(state_names, state_values)} + def _collect_result_dict(result_names, local_vars): + return {name: local_vars[name] for name in result_names} + + @staticmethod + def _pack_named_values(names, values): + if not names: + return None + if len(names) == 1: + return values[0] + return tuple(values) + + @staticmethod + def _merge_partial_results(base_names, base_values, part_names, part_values): + merged = {name: value for name, value in zip(base_names, base_values)} + merged.update({name: value for name, value in zip(part_names, part_values)}) + return [merged[name] for name in base_names] + + @staticmethod + def _call_branch(fn, result_names, state_values): + sig = inspect.signature(fn) + params = list(sig.parameters.values()) + pos_params = [ + p + for p in params + if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + has_varargs = any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params) + if has_varargs or len(pos_params) >= len(state_values) + 1: + return fn(result_names, *state_values) + return fn(*state_values) + + @staticmethod + def scf_if_dispatch( + cond, + then_fn, + else_fn=None, + *, + arg_names=(), + arg_values=(), + result_names=(), + result_values=(), + state_names=(), + state_values=(), + auto_else=False, + ): + # Backward compatibility: old call-sites pass state_* only. + if not arg_names and state_names: + arg_names = state_names + if not arg_values and state_values: + arg_values = state_values + if not result_names: + result_names = arg_names + if not result_values and result_names: + arg_map = {name: value for name, value in zip(arg_names, arg_values)} + result_values = tuple(arg_map.get(name, None) for name in result_names) + + arg_names, arg_values = ReplaceIfWithDispatch._normalize_named_values( + arg_names, arg_values, "arg_names", "arg_values" + ) + result_names, result_values = ReplaceIfWithDispatch._normalize_named_values( + result_names, result_values, "result_names", "result_values" + ) + # Only variables with an incoming value can be scf.if results/yields. + effective_result_pairs = [ + (name, value) + for name, value in zip(result_names, result_values) + if _unwrap_value(value) is not None + ] + effective_result_names = tuple(name for name, _ in effective_result_pairs) + effective_result_values = tuple(value for _, value in effective_result_pairs) + effective_result_map = {name: value for name, value in effective_result_pairs} if not ReplaceIfWithDispatch._is_dynamic(cond): taken = then_fn if cond else else_fn if taken is None: - return None - result = taken(*state_values) - if not state_names: - return None - values = ReplaceIfWithDispatch._normalize_branch_result( - result, state_names, state_map, "selected branch" + return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) + result = ReplaceIfWithDispatch._call_branch(taken, effective_result_names, arg_values) + if not effective_result_names: + return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) + partial_values = ReplaceIfWithDispatch._normalize_branch_result( + result, effective_result_names, effective_result_map, "selected branch" ) - if len(values) == 1: - return values[0] - return tuple(values) + merged_values = ReplaceIfWithDispatch._merge_partial_results( + result_names, result_values, effective_result_names, partial_values + ) + return ReplaceIfWithDispatch._pack_named_values(result_names, merged_values) cond_i1 = ReplaceIfWithDispatch._to_i1(cond) if not isinstance(cond_i1, ir.Value): raise TypeError(f"dynamic if condition must lower to ir.Value, got {type(cond_i1).__name__}") - if not state_names: + if not effective_result_names: has_else = else_fn is not None if_op = scf.IfOp(cond_i1, [], has_else=has_else, loc=ir.Location.unknown()) with ir.InsertionPoint(if_op.regions[0].blocks[0]): - then_fn() + ReplaceIfWithDispatch._call_branch(then_fn, effective_result_names, arg_values) scf.YieldOp([]) if has_else: if len(if_op.regions[1].blocks) == 0: if_op.regions[1].blocks.append(*[]) with ir.InsertionPoint(if_op.regions[1].blocks[0]): - else_fn() + ReplaceIfWithDispatch._call_branch(else_fn, effective_result_names, arg_values) scf.YieldOp([]) - return None + return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) if else_fn is None: else_fn = lambda *args: {} state_raw = [] - for name, value in zip(state_names, state_values): + for name, value in zip(effective_result_names, effective_result_values): raw = _unwrap_value(value) if not isinstance(raw, ir.Value): raise TypeError( @@ -352,12 +425,12 @@ def scf_if_dispatch(cond, then_fn, else_fn=None, *, state_names=(), state_values if_op = scf.IfOp(cond_i1, result_types, has_else=True, loc=ir.Location.unknown()) with ir.InsertionPoint(if_op.regions[0].blocks[0]): - then_result = then_fn(*state_values) + then_result = ReplaceIfWithDispatch._call_branch(then_fn, effective_result_names, arg_values) then_values = ReplaceIfWithDispatch._normalize_branch_result( - then_result, state_names, state_map, "then-branch" + then_result, effective_result_names, effective_result_map, "then-branch" ) - then_raw = ReplaceIfWithDispatch._unwrap_mlir_values(then_values, state_names, "then-branch") - for name, expect_ty, got in zip(state_names, result_types, then_raw): + then_raw = ReplaceIfWithDispatch._unwrap_mlir_values(then_values, effective_result_names, "then-branch") + for name, expect_ty, got in zip(effective_result_names, result_types, then_raw): if got.type != expect_ty: raise TypeError( f"if/else variable '{name}' type mismatch in then-branch: " @@ -368,12 +441,12 @@ def scf_if_dispatch(cond, then_fn, else_fn=None, *, state_names=(), state_values if len(if_op.regions[1].blocks) == 0: if_op.regions[1].blocks.append(*[]) with ir.InsertionPoint(if_op.regions[1].blocks[0]): - else_result = else_fn(*state_values) + else_result = ReplaceIfWithDispatch._call_branch(else_fn, effective_result_names, arg_values) else_values = ReplaceIfWithDispatch._normalize_branch_result( - else_result, state_names, state_map, "else-branch" + else_result, effective_result_names, effective_result_map, "else-branch" ) - else_raw = ReplaceIfWithDispatch._unwrap_mlir_values(else_values, state_names, "else-branch") - for name, expect_ty, got in zip(state_names, result_types, else_raw): + else_raw = ReplaceIfWithDispatch._unwrap_mlir_values(else_values, effective_result_names, "else-branch") + for name, expect_ty, got in zip(effective_result_names, result_types, else_raw): if got.type != expect_ty: raise TypeError( f"if/else variable '{name}' type mismatch in else-branch: " @@ -381,13 +454,24 @@ def scf_if_dispatch(cond, then_fn, else_fn=None, *, state_names=(), state_values ) scf.YieldOp(else_raw) - return ReplaceIfWithDispatch._pack_dispatch_results(list(if_op.results), state_values) + partial_wrapped = ReplaceIfWithDispatch._pack_dispatch_results( + list(if_op.results), effective_result_values + ) + if len(effective_result_names) == 1: + partial_values = [partial_wrapped] + else: + partial_values = list(partial_wrapped) + merged_values = ReplaceIfWithDispatch._merge_partial_results( + result_names, result_values, effective_result_names, partial_values + ) + return ReplaceIfWithDispatch._pack_named_values(result_names, merged_values) @classmethod def rewrite_globals(cls): return { "const_expr": const_expr, "scf_if_dispatch": cls.scf_if_dispatch, + "scf_if_collect_results": cls._collect_result_dict, } _REWRITE_HELPER_NAMES = {"dsl_not_", "dsl_and_", "dsl_or_", @@ -401,7 +485,7 @@ def _could_be_dynamic(test_node): Layer-by-layer recursive check: 1) classify current node if possible, 2) otherwise recurse into direct children, - 3) use a conservative fallback for unresolved expression nodes. + 3) unresolved nodes default to static (no forced rewrite). """ def _is_literal_expr(node): if isinstance(node, ast.Constant): @@ -439,10 +523,8 @@ def _visit(node): if _visit(child): return True - # If this expression cannot be proven static from itself or children, - # keep dynamic handling conservative. - if isinstance(node, ast.expr) and not isinstance(node, ast.Name): - return True + # If this expression cannot be proven dynamic from itself or children, + # keep it static to avoid over-rewriting unrelated Python control flow. return False return _visit(test_node) @@ -564,6 +646,82 @@ def _state_value_expr(name): keywords=[], ) + @staticmethod + def _collect_defined_names(stmt): + defined = [] + + def add_name(name): + if isinstance(name, str) and name not in defined: + defined.append(name) + + class DefCollector(ast.NodeVisitor): + def _collect_target(self, target): + if isinstance(target, ast.Name): + add_name(target.id) + elif isinstance(target, (ast.Tuple, ast.List)): + for elt in target.elts: + self._collect_target(elt) + elif isinstance(target, ast.Starred): + self._collect_target(target.value) + + def visit_Assign(self, node): + for target in node.targets: + self._collect_target(target) + + def visit_AugAssign(self, node): + self._collect_target(node.target) + + def visit_AnnAssign(self, node): + self._collect_target(node.target) + + def visit_For(self, node): + self._collect_target(node.target) + + def visit_AsyncFor(self, node): + self._collect_target(node.target) + + def visit_With(self, node): + for item in node.items: + if item.optional_vars is not None: + self._collect_target(item.optional_vars) + + def visit_AsyncWith(self, node): + for item in node.items: + if item.optional_vars is not None: + self._collect_target(item.optional_vars) + + def visit_ExceptHandler(self, node): + if node.name: + add_name(node.name) + + def visit_NamedExpr(self, node): + self._collect_target(node.target) + + DefCollector().visit(stmt) + return defined + + def visit_FunctionDef(self, node: ast.FunctionDef): + if getattr(node, _ASTREWRITE_MARKER, False): + return node + + visible = {arg.arg for arg in node.args.posonlyargs + node.args.args + node.args.kwonlyargs} + self._visible_name_stack.append(visible) + new_body = [] + try: + for stmt in node.body: + transformed = self.visit(stmt) + if isinstance(transformed, list): + new_body.extend(transformed) + else: + new_body.append(transformed) + for name in self._collect_defined_names(stmt): + visible.add(name) + finally: + self._visible_name_stack.pop() + + node.body = new_body + return node + def visit_If(self, node: ast.If) -> List[ast.AST]: if _is_constexpr(node.test): node.test = _unwrap_constexpr(node.test) @@ -579,23 +737,33 @@ def visit_If(self, node: ast.If) -> List[ast.AST]: then_name = f"__then_{uid}" then_vars = self._collect_assigned_vars(node.body) else_vars = self._collect_assigned_vars(node.orelse) - state_names = list(then_vars) + arg_names = list(then_vars) for var in else_vars: - if var not in state_names: - state_names.append(var) + if var not in arg_names: + arg_names.append(var) + + # result_names: only carry variables visible before this if. + result_names = list(arg_names) + if self._visible_name_stack: + visible = self._visible_name_stack[-1] + result_names = [name for name in result_names if name in visible] - fn_args = [ast.arg(arg=v, annotation=None) for v in state_names] + fn_args = [ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in arg_names] def _state_return_node(): return ast.Return( - value=ast.Dict( - keys=[ast.Constant(value=v) for v in state_names], - values=[ast.Name(id=v, ctx=ast.Load()) for v in state_names], + value=ast.Call( + func=ast.Name(id="scf_if_collect_results", ctx=ast.Load()), + args=[ + ast.Name(id="__ret_names", ctx=ast.Load()), + ast.Call(func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]), + ], + keywords=[], ) ) then_func = ast.FunctionDef( name=then_name, args=ast.arguments(posonlyargs=[], args=fn_args, kwonlyargs=[], kw_defaults=[], defaults=[]), - body=list(node.body) + ([_state_return_node()] if state_names else []), + body=list(node.body) + ([_state_return_node()] if result_names else []), decorator_list=[], type_params=[], ) @@ -605,17 +773,33 @@ def _state_return_node(): dispatch_args = [node.test, ast.Name(then_name, ctx=ast.Load())] dispatch_keywords = [] - if state_names: + if arg_names: + dispatch_keywords.extend( + [ + ast.keyword( + arg="arg_names", + value=ast.Tuple(elts=[ast.Constant(value=v) for v in arg_names], ctx=ast.Load()), + ), + ast.keyword( + arg="arg_values", + value=ast.Tuple( + elts=[self._state_value_expr(v) for v in arg_names], + ctx=ast.Load(), + ), + ), + ] + ) + if result_names: dispatch_keywords.extend( [ ast.keyword( - arg="state_names", - value=ast.Tuple(elts=[ast.Constant(value=v) for v in state_names], ctx=ast.Load()), + arg="result_names", + value=ast.Tuple(elts=[ast.Constant(value=v) for v in result_names], ctx=ast.Load()), ), ast.keyword( - arg="state_values", + arg="result_values", value=ast.Tuple( - elts=[self._state_value_expr(v) for v in state_names], + elts=[self._state_value_expr(v) for v in result_names], ctx=ast.Load(), ), ), @@ -624,18 +808,19 @@ def _state_return_node(): result = [then_func] else_name = None + synthesized_else = False if node.orelse: else_name = f"__else_{uid}" else_func = ast.FunctionDef( name=else_name, args=ast.arguments( posonlyargs=[], - args=[ast.arg(arg=v, annotation=None) for v in state_names], + args=[ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in arg_names], kwonlyargs=[], kw_defaults=[], defaults=[], ), - body=list(node.orelse) + ([_state_return_node()] if state_names else []), + body=list(node.orelse) + ([_state_return_node()] if result_names else []), decorator_list=[], type_params=[], ) @@ -644,13 +829,14 @@ def _state_return_node(): else_func = ast.fix_missing_locations(else_func) dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) result.append(else_func) - elif state_names: + elif result_names: else_name = f"__else_{uid}" + synthesized_else = True else_func = ast.FunctionDef( name=else_name, args=ast.arguments( posonlyargs=[], - args=[ast.arg(arg=v, annotation=None) for v in state_names], + args=[ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in arg_names], kwonlyargs=[], kw_defaults=[], defaults=[], @@ -665,16 +851,19 @@ def _state_return_node(): dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) result.append(else_func) + if synthesized_else: + dispatch_keywords.append(ast.keyword(arg="auto_else", value=ast.Constant(value=True))) + dispatch_value = ast.Call( func=ast.Name("scf_if_dispatch", ctx=ast.Load()), args=dispatch_args, keywords=dispatch_keywords, ) - if state_names: - if len(state_names) == 1: - target = ast.Name(id=state_names[0], ctx=ast.Store()) + if result_names and else_name is not None: + if len(result_names) == 1: + target = ast.Name(id=result_names[0], ctx=ast.Store()) else: - target = ast.Tuple(elts=[ast.Name(id=v, ctx=ast.Store()) for v in state_names], ctx=ast.Store()) + target = ast.Tuple(elts=[ast.Name(id=v, ctx=ast.Store()) for v in result_names], ctx=ast.Store()) dispatch_stmt = ast.Assign(targets=[target], value=dispatch_value) else: dispatch_stmt = ast.Expr(value=dispatch_value) diff --git a/tests/unit/test_if_dispatch_paths.py b/tests/unit/test_if_dispatch_paths.py index 4ec396ff..675395f4 100644 --- a/tests/unit/test_if_dispatch_paths.py +++ b/tests/unit/test_if_dispatch_paths.py @@ -184,3 +184,27 @@ def traced_dispatch(*args, **kwargs): assert sample(True) == 2 assert sample(False) == 3 assert called["n"] in (0, 2) + + +def test_ast_rewrite_does_not_rewrite_static_string_compare(): + called = {"n": 0} + + def sample(dtype_str): + out = 0 + if dtype_str == "f32": + out = 1 + else: + out = 2 + return out + + ASTRewriter.transform(sample) + original_dispatch = sample.__globals__["scf_if_dispatch"] + + def traced_dispatch(*args, **kwargs): + called["n"] += 1 + return original_dispatch(*args, **kwargs) + + sample.__globals__["scf_if_dispatch"] = traced_dispatch + assert sample("f32") == 1 + assert sample("bf16") == 2 + assert called["n"] == 0 From 4e075dcb90d6b376aa664e315f898672b5f51168 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Thu, 2 Apr 2026 11:49:21 +0000 Subject: [PATCH 03/31] [FLYDSL]: Derivation of dynamic if/else results --- python/flydsl/compiler/ast_rewriter.py | 711 +++++++++++----------- python/flydsl/expr/numeric.py | 5 + scripts/run_tests.sh | 5 +- tests/system/test_control_flow_compile.py | 24 + tests/unit/test_if_dispatch_paths.py | 30 +- 5 files changed, 397 insertions(+), 378 deletions(-) create mode 100644 tests/system/test_control_flow_compile.py diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index b3cf9709..da072039 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -2,6 +2,7 @@ # Copyright (c) 2025 FlyDSL Project Contributors import ast +import contextlib import difflib import inspect import types @@ -72,7 +73,7 @@ def transform(cls, f): orig_code = ast.unparse(module) if env.debug.ast_diff else None func_node = module.body[0] rewriter = transformer_ctor(context=context, first_lineno=f.__code__.co_firstlineno - 1) - func_node = rewriter.generic_visit(func_node) + func_node = rewriter.visit(func_node) if env.debug.ast_diff: new_code = ast.unparse(func_node) diff = list( @@ -126,18 +127,136 @@ def transform(cls, f): _ASTREWRITE_MARKER = "_flydsl_ast_rewriter_generated_" +class SymbolScopeTracker: + def __init__(self): + self.scopes = [] + self.callables = [] + + def record_symbol(self, name: str): + if not self.scopes: + return + if name == "_": + return + self.scopes[-1].add(name) + + def record_callable(self, name: str): + if not self.callables: + return + self.callables[-1].add(name) + + def snapshot_symbol_scopes(self): + return self.scopes.copy() + + def snapshot_callable_scopes(self): + return self.callables.copy() + + @contextlib.contextmanager + def function_scope(self): + self.scopes.append(set()) + self.callables.append(set()) + try: + yield + finally: + self.scopes.pop() + self.callables.pop() + + @contextlib.contextmanager + def control_flow_scope(self): + self.scopes.append(set()) + try: + yield + finally: + self.scopes.pop() + + class Transformer(ast.NodeTransformer): def __init__(self, context, first_lineno): super().__init__() self.context = context self.first_lineno = first_lineno + self.symbol_scopes = SymbolScopeTracker() + + def _record_target_symbols(self, target): + if isinstance(target, ast.Name): + self.symbol_scopes.record_symbol(target.id) + elif isinstance(target, (ast.Tuple, ast.List)): + for t in target.elts: + self._record_target_symbols(t) + elif isinstance(target, ast.Starred): + self._record_target_symbols(target.value) + + def _visit_stmt_block(self, stmts): + new_stmts = [] + for stmt in stmts: + transformed = self.visit(stmt) + if isinstance(transformed, list): + new_stmts.extend(transformed) + else: + new_stmts.append(transformed) + return new_stmts def visit_FunctionDef(self, node: ast.FunctionDef): if getattr(node, _ASTREWRITE_MARKER, False): return node - node = self.generic_visit(node) + + with self.symbol_scopes.function_scope(): + for arg in node.args.posonlyargs: + self.symbol_scopes.record_symbol(arg.arg) + for arg in node.args.args: + self.symbol_scopes.record_symbol(arg.arg) + for arg in node.args.kwonlyargs: + self.symbol_scopes.record_symbol(arg.arg) + node = self.generic_visit(node) + return node + def visit_Assign(self, node: ast.Assign): + for target in node.targets: + self._record_target_symbols(target) + return self.generic_visit(node) + + def visit_AugAssign(self, node: ast.AugAssign): + self._record_target_symbols(node.target) + return self.generic_visit(node) + + def visit_AnnAssign(self, node: ast.AnnAssign): + self._record_target_symbols(node.target) + return self.generic_visit(node) + + def visit_For(self, node: ast.For): + self._record_target_symbols(node.target) + node.iter = self.visit(node.iter) + with self.symbol_scopes.control_flow_scope(): + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) + return node + + def visit_If(self, node: ast.If): + node.test = self.visit(node.test) + with self.symbol_scopes.control_flow_scope(): + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) + return node + + def visit_While(self, node: ast.While): + node.test = self.visit(node.test) + with self.symbol_scopes.control_flow_scope(): + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) + return node + + def visit_With(self, node: ast.With): + for item in node.items: + if item.optional_vars is not None: + self._record_target_symbols(item.optional_vars) + return self.generic_visit(node) + @ASTRewriter.register class RewriteBoolOps(Transformer): @@ -223,10 +342,6 @@ def visit_UnaryOp(self, node: ast.UnaryOp): class ReplaceIfWithDispatch(Transformer): _counter = 0 - def __init__(self, context, first_lineno): - super().__init__(context=context, first_lineno=first_lineno) - self._visible_name_stack = [] - @staticmethod def _is_dynamic(cond): if isinstance(cond, ir.Value): @@ -340,8 +455,6 @@ def scf_if_dispatch( then_fn, else_fn=None, *, - arg_names=(), - arg_values=(), result_names=(), result_values=(), state_names=(), @@ -349,19 +462,10 @@ def scf_if_dispatch( auto_else=False, ): # Backward compatibility: old call-sites pass state_* only. - if not arg_names and state_names: - arg_names = state_names - if not arg_values and state_values: - arg_values = state_values - if not result_names: - result_names = arg_names - if not result_values and result_names: - arg_map = {name: value for name, value in zip(arg_names, arg_values)} - result_values = tuple(arg_map.get(name, None) for name in result_names) - - arg_names, arg_values = ReplaceIfWithDispatch._normalize_named_values( - arg_names, arg_values, "arg_names", "arg_values" - ) + if not result_names and state_names: + result_names = state_names + if not result_values and state_values: + result_values = state_values result_names, result_values = ReplaceIfWithDispatch._normalize_named_values( result_names, result_values, "result_names", "result_values" ) @@ -379,7 +483,7 @@ def scf_if_dispatch( taken = then_fn if cond else else_fn if taken is None: return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) - result = ReplaceIfWithDispatch._call_branch(taken, effective_result_names, arg_values) + result = ReplaceIfWithDispatch._call_branch(taken, effective_result_names, result_values) if not effective_result_names: return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) partial_values = ReplaceIfWithDispatch._normalize_branch_result( @@ -398,13 +502,13 @@ def scf_if_dispatch( has_else = else_fn is not None if_op = scf.IfOp(cond_i1, [], has_else=has_else, loc=ir.Location.unknown()) with ir.InsertionPoint(if_op.regions[0].blocks[0]): - ReplaceIfWithDispatch._call_branch(then_fn, effective_result_names, arg_values) + ReplaceIfWithDispatch._call_branch(then_fn, effective_result_names, result_values) scf.YieldOp([]) if has_else: if len(if_op.regions[1].blocks) == 0: if_op.regions[1].blocks.append(*[]) with ir.InsertionPoint(if_op.regions[1].blocks[0]): - ReplaceIfWithDispatch._call_branch(else_fn, effective_result_names, arg_values) + ReplaceIfWithDispatch._call_branch(else_fn, effective_result_names, result_values) scf.YieldOp([]) return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) @@ -425,7 +529,7 @@ def scf_if_dispatch( if_op = scf.IfOp(cond_i1, result_types, has_else=True, loc=ir.Location.unknown()) with ir.InsertionPoint(if_op.regions[0].blocks[0]): - then_result = ReplaceIfWithDispatch._call_branch(then_fn, effective_result_names, arg_values) + then_result = ReplaceIfWithDispatch._call_branch(then_fn, effective_result_names, result_values) then_values = ReplaceIfWithDispatch._normalize_branch_result( then_result, effective_result_names, effective_result_map, "then-branch" ) @@ -441,7 +545,7 @@ def scf_if_dispatch( if len(if_op.regions[1].blocks) == 0: if_op.regions[1].blocks.append(*[]) with ir.InsertionPoint(if_op.regions[1].blocks[0]): - else_result = ReplaceIfWithDispatch._call_branch(else_fn, effective_result_names, arg_values) + else_result = ReplaceIfWithDispatch._call_branch(else_fn, effective_result_names, result_values) else_values = ReplaceIfWithDispatch._normalize_branch_result( else_result, effective_result_names, effective_result_map, "else-branch" ) @@ -505,6 +609,8 @@ def _classify_current(node): if isinstance(node, ast.Name): # Plain names can be static symbols (constexpr params, local bools, etc.). return None + if isinstance(node, ast.Compare): + return True if isinstance(node, ast.Call): func = node.func if isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES: @@ -530,109 +636,74 @@ def _visit(node): return _visit(test_node) @staticmethod - def _collect_assigned_vars(stmts): - assigned = [] - - def add_name(name): - if isinstance(name, str) and name not in assigned: - assigned.append(name) - - class AssignCollector(ast.NodeVisitor): - def _collect_target(self, target): - if isinstance(target, ast.Name): - add_name(target.id) - elif isinstance(target, (ast.Tuple, ast.List)): - for elt in target.elts: - self._collect_target(elt) - elif isinstance(target, ast.Starred): - self._collect_target(target.value) - elif isinstance(target, ast.Subscript): - self._collect_target(target.value) - elif isinstance(target, ast.Attribute): - if isinstance(target.value, ast.Name): - add_name(target.value.id) + def _collect_assigned_vars(node: ast.If, active_symbols): + write_args = [] + invoked_args = [] + + def add_unique(items, name): + if isinstance(name, str) and name not in items: + items.append(name) + + def in_active_symbols(name): + return any(name in symbol_scope for symbol_scope in active_symbols) + + class RegionAnalyzer(ast.NodeVisitor): + force_store = False + + @staticmethod + def _get_call_base(func_node): + if isinstance(func_node, ast.Attribute): + if isinstance(func_node.value, ast.Attribute): + return RegionAnalyzer._get_call_base(func_node.value) + if isinstance(func_node.value, ast.Name): + return func_node.value.id + return None + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Store) or self.force_store: + add_unique(write_args, node.id) + + def visit_Subscript(self, node): + if isinstance(node.ctx, ast.Store): + self.force_store = True + self.visit(node.value) + self.force_store = False + self.visit(node.slice) + else: + self.generic_visit(node) def visit_Assign(self, node): + self.force_store = True for target in node.targets: - self._collect_target(target) + self.visit(target) + self.force_store = False self.visit(node.value) def visit_AugAssign(self, node): - self._collect_target(node.target) + self.force_store = True + self.visit(node.target) + self.force_store = False self.visit(node.value) - def visit_AnnAssign(self, node): - self._collect_target(node.target) - if node.value is not None: - self.visit(node.value) - - def visit_For(self, node): - self._collect_target(node.target) - self.generic_visit(node) - - def visit_AsyncFor(self, node): - self._collect_target(node.target) - self.generic_visit(node) - - def visit_With(self, node): - for item in node.items: - if item.optional_vars is not None: - self._collect_target(item.optional_vars) - self.generic_visit(node) - - def visit_AsyncWith(self, node): - for item in node.items: - if item.optional_vars is not None: - self._collect_target(item.optional_vars) - self.generic_visit(node) - - def visit_ExceptHandler(self, node): - if node.name: - add_name(node.name) - self.generic_visit(node) - - def visit_NamedExpr(self, node): - self._collect_target(node.target) - self.visit(node.value) - - def visit_MatchAs(self, node): - if node.name: - add_name(node.name) - self.generic_visit(node) - - def visit_MatchStar(self, node): - if node.name: - add_name(node.name) - self.generic_visit(node) + def visit_Call(self, node): + base_name = RegionAnalyzer._get_call_base(node.func) + if base_name is not None and base_name != "self": + add_unique(invoked_args, base_name) - def visit_MatchMapping(self, node): - if node.rest: - add_name(node.rest) self.generic_visit(node) - def visit_Call(self, node): - if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): - # Only treat typical mutating methods as writes. - if node.func.attr in { - "append", - "extend", - "insert", - "pop", - "remove", - "clear", - "sort", - "reverse", - "add", - "discard", - "update", - "setdefault", - }: - add_name(node.func.value.id) - self.generic_visit(node) + analyzer = RegionAnalyzer() + analyzer.visit(ast.Module(body=node.body, type_ignores=[])) + if node.orelse: + analyzer.visit(ast.Module(body=node.orelse, type_ignores=[])) - collector = AssignCollector() - collector.visit(ast.Module(body=stmts, type_ignores=[])) - return assigned + invoked_args = [name for name in invoked_args if name not in write_args] + write_args = [name for name in write_args if in_active_symbols(name)] + invoked_args = [name for name in invoked_args if in_active_symbols(name)] + print(f"write_args: {write_args}") + print(f"invoked_args: {invoked_args}") + print(f"active_symbols: {active_symbols}") + return write_args + invoked_args @staticmethod def _state_value_expr(name): @@ -646,232 +717,138 @@ def _state_value_expr(name): keywords=[], ) - @staticmethod - def _collect_defined_names(stmt): - defined = [] - - def add_name(name): - if isinstance(name, str) and name not in defined: - defined.append(name) - - class DefCollector(ast.NodeVisitor): - def _collect_target(self, target): - if isinstance(target, ast.Name): - add_name(target.id) - elif isinstance(target, (ast.Tuple, ast.List)): - for elt in target.elts: - self._collect_target(elt) - elif isinstance(target, ast.Starred): - self._collect_target(target.value) - - def visit_Assign(self, node): - for target in node.targets: - self._collect_target(target) - - def visit_AugAssign(self, node): - self._collect_target(node.target) - - def visit_AnnAssign(self, node): - self._collect_target(node.target) - - def visit_For(self, node): - self._collect_target(node.target) - - def visit_AsyncFor(self, node): - self._collect_target(node.target) - - def visit_With(self, node): - for item in node.items: - if item.optional_vars is not None: - self._collect_target(item.optional_vars) - - def visit_AsyncWith(self, node): - for item in node.items: - if item.optional_vars is not None: - self._collect_target(item.optional_vars) - - def visit_ExceptHandler(self, node): - if node.name: - add_name(node.name) - - def visit_NamedExpr(self, node): - self._collect_target(node.target) - - DefCollector().visit(stmt) - return defined - - def visit_FunctionDef(self, node: ast.FunctionDef): - if getattr(node, _ASTREWRITE_MARKER, False): - return node - - visible = {arg.arg for arg in node.args.posonlyargs + node.args.args + node.args.kwonlyargs} - self._visible_name_stack.append(visible) - new_body = [] - try: - for stmt in node.body: - transformed = self.visit(stmt) - if isinstance(transformed, list): - new_body.extend(transformed) - else: - new_body.append(transformed) - for name in self._collect_defined_names(stmt): - visible.add(name) - finally: - self._visible_name_stack.pop() - - node.body = new_body - return node - def visit_If(self, node: ast.If) -> List[ast.AST]: + active_symbols_before_if = self.symbol_scopes.snapshot_symbol_scopes() if _is_constexpr(node.test): node.test = _unwrap_constexpr(node.test) - node = self.generic_visit(node) + node = super().visit_If(node) return node if not self._could_be_dynamic(node.test): - node = self.generic_visit(node) + node = super().visit_If(node) return node - node = self.generic_visit(node) - uid = ReplaceIfWithDispatch._counter - ReplaceIfWithDispatch._counter += 1 - - then_name = f"__then_{uid}" - then_vars = self._collect_assigned_vars(node.body) - else_vars = self._collect_assigned_vars(node.orelse) - arg_names = list(then_vars) - for var in else_vars: - if var not in arg_names: - arg_names.append(var) - - # result_names: only carry variables visible before this if. - result_names = list(arg_names) - if self._visible_name_stack: - visible = self._visible_name_stack[-1] - result_names = [name for name in result_names if name in visible] - - fn_args = [ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in arg_names] - def _state_return_node(): - return ast.Return( - value=ast.Call( - func=ast.Name(id="scf_if_collect_results", ctx=ast.Load()), - args=[ - ast.Name(id="__ret_names", ctx=ast.Load()), - ast.Call(func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]), - ], - keywords=[], + with self.symbol_scopes.control_flow_scope(): + node.test = self.visit(node.test) + with self.symbol_scopes.control_flow_scope(): + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) + uid = ReplaceIfWithDispatch._counter + ReplaceIfWithDispatch._counter += 1 + + then_name = f"__then_{uid}" + result_names = self._collect_assigned_vars(node, active_symbols_before_if) + + fn_args = [ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in result_names] + + def _state_return_node(): + return ast.Return( + value=ast.Call( + func=ast.Name(id="scf_if_collect_results", ctx=ast.Load()), + args=[ + ast.Name(id="__ret_names", ctx=ast.Load()), + ast.Call(func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]), + ], + keywords=[], + ) ) - ) - then_func = ast.FunctionDef( - name=then_name, - args=ast.arguments(posonlyargs=[], args=fn_args, kwonlyargs=[], kw_defaults=[], defaults=[]), - body=list(node.body) + ([_state_return_node()] if result_names else []), - decorator_list=[], - type_params=[], - ) - setattr(then_func, _ASTREWRITE_MARKER, True) - then_func = ast.copy_location(then_func, node) - then_func = ast.fix_missing_locations(then_func) - - dispatch_args = [node.test, ast.Name(then_name, ctx=ast.Load())] - dispatch_keywords = [] - if arg_names: - dispatch_keywords.extend( - [ - ast.keyword( - arg="arg_names", - value=ast.Tuple(elts=[ast.Constant(value=v) for v in arg_names], ctx=ast.Load()), - ), - ast.keyword( - arg="arg_values", - value=ast.Tuple( - elts=[self._state_value_expr(v) for v in arg_names], - ctx=ast.Load(), - ), - ), - ] - ) - if result_names: - dispatch_keywords.extend( - [ - ast.keyword( - arg="result_names", - value=ast.Tuple(elts=[ast.Constant(value=v) for v in result_names], ctx=ast.Load()), - ), - ast.keyword( - arg="result_values", - value=ast.Tuple( - elts=[self._state_value_expr(v) for v in result_names], - ctx=ast.Load(), - ), - ), - ] - ) - result = [then_func] - else_name = None - synthesized_else = False - if node.orelse: - else_name = f"__else_{uid}" - else_func = ast.FunctionDef( - name=else_name, - args=ast.arguments( - posonlyargs=[], - args=[ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in arg_names], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=list(node.orelse) + ([_state_return_node()] if result_names else []), + then_func = ast.FunctionDef( + name=then_name, + args=ast.arguments(posonlyargs=[], args=fn_args, kwonlyargs=[], kw_defaults=[], defaults=[]), + body=list(node.body) + ([_state_return_node()] if result_names else []), decorator_list=[], type_params=[], ) - setattr(else_func, _ASTREWRITE_MARKER, True) - else_func = ast.copy_location(else_func, node) - else_func = ast.fix_missing_locations(else_func) - dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) - result.append(else_func) - elif result_names: - else_name = f"__else_{uid}" - synthesized_else = True - else_func = ast.FunctionDef( - name=else_name, - args=ast.arguments( - posonlyargs=[], - args=[ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in arg_names], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=[_state_return_node()], - decorator_list=[], - type_params=[], + setattr(then_func, _ASTREWRITE_MARKER, True) + then_func = ast.copy_location(then_func, node) + then_func = ast.fix_missing_locations(then_func) + + dispatch_args = [node.test, ast.Name(then_name, ctx=ast.Load())] + dispatch_keywords = [] + if result_names: + dispatch_keywords.extend( + [ + ast.keyword( + arg="result_names", + value=ast.Tuple(elts=[ast.Constant(value=v) for v in result_names], ctx=ast.Load()), + ), + ast.keyword( + arg="result_values", + value=ast.Tuple( + elts=[self._state_value_expr(v) for v in result_names], + ctx=ast.Load(), + ), + ), + ] + ) + result = [then_func] + + else_name = None + synthesized_else = False + if node.orelse: + else_name = f"__else_{uid}" + else_func = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in result_names], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=list(node.orelse) + ([_state_return_node()] if result_names else []), + decorator_list=[], + type_params=[], + ) + setattr(else_func, _ASTREWRITE_MARKER, True) + else_func = ast.copy_location(else_func, node) + else_func = ast.fix_missing_locations(else_func) + dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) + result.append(else_func) + elif result_names: + else_name = f"__else_{uid}" + synthesized_else = True + else_func = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in result_names], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[_state_return_node()], + decorator_list=[], + type_params=[], + ) + setattr(else_func, _ASTREWRITE_MARKER, True) + else_func = ast.copy_location(else_func, node) + else_func = ast.fix_missing_locations(else_func) + dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) + result.append(else_func) + + if synthesized_else: + dispatch_keywords.append(ast.keyword(arg="auto_else", value=ast.Constant(value=True))) + + dispatch_value = ast.Call( + func=ast.Name("scf_if_dispatch", ctx=ast.Load()), + args=dispatch_args, + keywords=dispatch_keywords, ) - setattr(else_func, _ASTREWRITE_MARKER, True) - else_func = ast.copy_location(else_func, node) - else_func = ast.fix_missing_locations(else_func) - dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) - result.append(else_func) - - if synthesized_else: - dispatch_keywords.append(ast.keyword(arg="auto_else", value=ast.Constant(value=True))) - - dispatch_value = ast.Call( - func=ast.Name("scf_if_dispatch", ctx=ast.Load()), - args=dispatch_args, - keywords=dispatch_keywords, - ) - if result_names and else_name is not None: - if len(result_names) == 1: - target = ast.Name(id=result_names[0], ctx=ast.Store()) + if result_names and else_name is not None: + if len(result_names) == 1: + target = ast.Name(id=result_names[0], ctx=ast.Store()) + else: + target = ast.Tuple(elts=[ast.Name(id=v, ctx=ast.Store()) for v in result_names], ctx=ast.Store()) + dispatch_stmt = ast.Assign(targets=[target], value=dispatch_value) else: - target = ast.Tuple(elts=[ast.Name(id=v, ctx=ast.Store()) for v in result_names], ctx=ast.Store()) - dispatch_stmt = ast.Assign(targets=[target], value=dispatch_value) - else: - dispatch_stmt = ast.Expr(value=dispatch_value) - dispatch_stmt = ast.copy_location(dispatch_stmt, node) - dispatch_stmt = ast.fix_missing_locations(dispatch_stmt) - result.append(dispatch_stmt) + dispatch_stmt = ast.Expr(value=dispatch_value) + dispatch_stmt = ast.copy_location(dispatch_stmt, node) + dispatch_stmt = ast.fix_missing_locations(dispatch_stmt) + result.append(dispatch_stmt) - return result + return result @ASTRewriter.register @@ -949,9 +926,16 @@ def visit_For(self, node: ast.For) -> ast.For: node.iter.func = ast.Name(id="scf_range", ctx=ast.Load()) line = ast.dump(node.iter) if "for_" in line or "scf.for_" in line or "scf_range" in line: - node = self.generic_visit(node) + node.iter = self.visit(node.iter) + with self.symbol_scopes.control_flow_scope(): + if isinstance(node.target, ast.Name): + self.symbol_scopes.record_symbol(node.target.id) + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) new_yield = ast.Expr(ast.Yield(value=None)) - if not self._is_yield(node.body[-1]): + if node.body and not self._is_yield(node.body[-1]): last_statement = node.body[-1] assert last_statement.end_lineno is not None, ( f"last_statement {ast.unparse(last_statement)} must have end_lineno" @@ -1045,37 +1029,38 @@ def rewrite_globals(cls): def visit_While(self, node: ast.While) -> List[ast.AST]: if _is_constexpr(node.test): node.test = _unwrap_constexpr(node.test) - node = self.generic_visit(node) + node = super().visit_While(node) return node - node = self.generic_visit(node) - if isinstance(node.test, ast.NamedExpr): - test = node.test.value - else: - test = node.test - w = ast.Call(func=ast.Name("scf_while_gen", ctx=ast.Load()), args=[test], keywords=[]) - w = ast.copy_location(w, node) - assign = ast.Assign( - targets=[ast.Name(f"w_{node.lineno}", ctx=ast.Store())], - value=w, - ) - assign = ast.fix_missing_locations(ast.copy_location(assign, node)) - - next_ = ast.Call( - func=ast.Name("next", ctx=ast.Load()), - args=[ - ast.Name(f"w_{node.lineno}", ctx=ast.Load()), - ast.Constant(False, kind="bool"), - ], - keywords=[], - ) - next_ = ast.fix_missing_locations(ast.copy_location(next_, node)) - if isinstance(node.test, ast.NamedExpr): - node.test.value = next_ - else: - new_test = ast.NamedExpr(target=ast.Name(f"__init__{node.lineno}", ctx=ast.Store()), value=next_) - new_test = ast.copy_location(new_test, node) - node.test = new_test + with self.symbol_scopes.control_flow_scope(): + node = super().visit_While(node) + if isinstance(node.test, ast.NamedExpr): + test = node.test.value + else: + test = node.test + w = ast.Call(func=ast.Name("scf_while_gen", ctx=ast.Load()), args=[test], keywords=[]) + w = ast.copy_location(w, node) + assign = ast.Assign( + targets=[ast.Name(f"w_{node.lineno}", ctx=ast.Store())], + value=w, + ) + assign = ast.fix_missing_locations(ast.copy_location(assign, node)) + + next_ = ast.Call( + func=ast.Name("next", ctx=ast.Load()), + args=[ + ast.Name(f"w_{node.lineno}", ctx=ast.Load()), + ast.Constant(False, kind="bool"), + ], + keywords=[], + ) + next_ = ast.fix_missing_locations(ast.copy_location(next_, node)) + if isinstance(node.test, ast.NamedExpr): + node.test.value = next_ + else: + new_test = ast.NamedExpr(target=ast.Name(f"__init__{node.lineno}", ctx=ast.Store()), value=next_) + new_test = ast.copy_location(new_test, node) + node.test = new_test - node = ast.fix_missing_locations(node) - assign = ast.fix_missing_locations(assign) - return [assign, node] + node = ast.fix_missing_locations(node) + assign = ast.fix_missing_locations(assign) + return [assign, node] diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index 44742217..e3754054 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -217,6 +217,11 @@ def _unwrap_value(value): """Convert FlyDSL wrappers to raw MLIR values when possible.""" if isinstance(value, ir.Value): return value + if isinstance(value, (bool, int, float)): + try: + return as_numeric(value).ir_value() + except Exception: + return value if hasattr(value, "__fly_values__"): values = value.__fly_values__() if len(values) == 1: diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index cdaf4bb1..897ff34d 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -42,15 +42,16 @@ if [ "${RUN_TESTS_FULL:-0}" != "1" ]; then fi # --------------------------------------------------------------------------- -# 1. All pytest-based tests (kernels + unit + examples) +# 1. All pytest-based tests (kernels + unit + system + examples) # --------------------------------------------------------------------------- echo "========================================================================" -echo "Pytest: kernels + unit + examples" +echo "Pytest: kernels + unit + system + examples" echo "========================================================================" python3 -m pytest \ tests/kernels/ \ tests/unit/ \ + tests/system/ \ tests/python/examples/ \ "${pytest_args[@]}" diff --git a/tests/system/test_control_flow_compile.py b/tests/system/test_control_flow_compile.py new file mode 100644 index 00000000..db2bf084 --- /dev/null +++ b/tests/system/test_control_flow_compile.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +import flydsl.compiler as flyc +import flydsl.expr as fx + + +def test_control_flow_kernel_snippet_compiles_without_error(): + @flyc.kernel + def vecAbsKernel( + A: fx.Tensor, + C: fx.Tensor, + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + print_debug: fx.Constexpr[bool] = True, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + if print_debug and bid == 0 and tid <= 2: + fx.printf("[kernel] bid={}, tid={}", bid, tid) + + assert vecAbsKernel is not None diff --git a/tests/unit/test_if_dispatch_paths.py b/tests/unit/test_if_dispatch_paths.py index 675395f4..fbc1aceb 100644 --- a/tests/unit/test_if_dispatch_paths.py +++ b/tests/unit/test_if_dispatch_paths.py @@ -19,7 +19,9 @@ def test_collect_assigned_vars_supports_tuple_and_augassign(): d += 1 """ stmts = ast.parse(code).body - assigned = ReplaceIfWithDispatch._collect_assigned_vars(stmts) + node = ast.If(test=ast.Constant(value=True), body=stmts, orelse=[]) + active_symbols = [{"a", "b", "c", "d"}] + assigned = ReplaceIfWithDispatch._collect_assigned_vars(node, active_symbols) assert assigned == ["a", "b", "c", "d"] @@ -38,8 +40,10 @@ def test_collect_assigned_vars_supports_annassign_walrus_with_except_for(): out = n """ stmts = ast.parse(code).body - assigned = ReplaceIfWithDispatch._collect_assigned_vars(stmts) - assert assigned == ["x", "i", "y", "w", "z", "e", "err", "n", "out"] + node = ast.If(test=ast.Constant(value=True), body=stmts, orelse=[]) + active_symbols = [{"x", "i", "y", "w", "z", "e", "err", "n", "out"}] + assigned = ReplaceIfWithDispatch._collect_assigned_vars(node, active_symbols) + assert assigned == ["x", "i", "y", "w", "z", "err", "n", "out"] def test_scf_if_dispatch_static_with_states_no_ifop(): @@ -134,7 +138,7 @@ def else_fn(x): ) -def test_scf_if_dispatch_dynamic_non_mlir_value_has_clear_error(): +def test_scf_if_dispatch_dynamic_non_mlir_value_is_promoted(): with Context(), Location.unknown(): module = Module.create() i32 = IntegerType.get_signless(32) @@ -152,14 +156,14 @@ def then_fn(x): def else_fn(x): return {"x": arith.ConstantOp(i32, 3).result} - with pytest.raises(TypeError, match="not an MLIR Value"): - ReplaceIfWithDispatch.scf_if_dispatch( - cond, - then_fn, - else_fn, - state_names=("x",), - state_values=(x,), - ) + out = ReplaceIfWithDispatch.scf_if_dispatch( + cond, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + assert isinstance(out, Int32) def test_ast_rewrite_keeps_semantics_for_static_bool(): @@ -207,4 +211,4 @@ def traced_dispatch(*args, **kwargs): sample.__globals__["scf_if_dispatch"] = traced_dispatch assert sample("f32") == 1 assert sample("bf16") == 2 - assert called["n"] == 0 + assert called["n"] == 2 From aab6097e4e77979076e3cf87fb19edda7e4056bb Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 3 Apr 2026 10:21:05 +0000 Subject: [PATCH 04/31] [FLYDSL]: _could_be_dynamic Reconstruction --- python/flydsl/compiler/ast_rewriter.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index da072039..87809928 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -603,27 +603,17 @@ def _is_literal_expr(node): ) return False - def _classify_current(node): + def _visit(node): if _is_literal_expr(node): return False - if isinstance(node, ast.Name): - # Plain names can be static symbols (constexpr params, local bools, etc.). - return None if isinstance(node, ast.Compare): return True if isinstance(node, ast.Call): func = node.func - if isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES: - return None - return True - return None - - def _visit(node): - current = _classify_current(node) - if current is True: - return True - if current is False: - return False + if not (isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES): + return True + # Plain names can be static symbols (constexpr params, local bools, etc.), + # and unknown nodes keep recursing into children. for child in ast.iter_child_nodes(node): if _visit(child): From 938992eaba8e73697864efc6163bf34409862223 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 3 Apr 2026 10:52:44 +0000 Subject: [PATCH 05/31] [FLYDSL]: Only test and verification --- kernels/blockscale_preshuffle_gemm.py | 8 ++++---- kernels/layernorm_kernel.py | 22 +++++++++++----------- kernels/rmsnorm_kernel.py | 16 +++++++++------- kernels/softmax_kernel.py | 20 +++++++++++--------- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index 2c9e5ca0..bcadf24b 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -553,6 +553,9 @@ def compute_tile_blockscale( for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0, a1 = lds_load_packs_k64( + curr_row_a_lds, col_base, lds_buffer + ) if ( a0_prefetch is not None @@ -561,10 +564,6 @@ def compute_tile_blockscale( and mi == 0 ): a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64( - curr_row_a_lds, col_base, lds_buffer - ) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni @@ -732,6 +731,7 @@ def _load_a_to_lds(base_k, lds_buffer): a0_prefetch_pong = prefetch_a0_pack(lds_a_pong) num_tiles = K // tile_k + final_accs = global_accs if (num_tiles % 2) == 1: for k_iv in range_constexpr(0, K - tile_k, tile_k * 2): diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 8850d9f3..b08a12e6 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -112,16 +112,16 @@ def block_reduce_add2(val0, val1): if lane == fx.Int32(0): wave_idx = arith.index_cast(T.index, wave) - s_sum.store(w0, [wave_idx]) - s_sumsq.store(w1, [wave_idx]) + SmemPtr.store(s_sum, w0, [wave_idx]) + SmemPtr.store(s_sumsq, w1, [wave_idx]) gpu.barrier() if wave == fx.Int32(0): in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, fx.Int32(0)) lane_safe_idx = arith.index_cast(T.index, lane_safe) - v0 = s_sum.load([lane_safe_idx]) - v1 = s_sumsq.load([lane_safe_idx]) + v0 = SmemPtr.load(s_sum, [lane_safe_idx]) + v1 = SmemPtr.load(s_sumsq, [lane_safe_idx]) z = fx.Float32(0.0) ww0 = in_range.select(v0, z) ww1 = in_range.select(v1, z) @@ -130,12 +130,12 @@ def block_reduce_add2(val0, val1): if lane == fx.Int32(0): c0_idx = fx.Index(0) - s_sum.store(ww0, [c0_idx]) - s_sumsq.store(ww1, [c0_idx]) + SmemPtr.store(s_sum, ww0, [c0_idx]) + SmemPtr.store(s_sumsq, ww1, [c0_idx]) gpu.barrier() c0_idx = fx.Index(0) - return s_sum.load([c0_idx]), s_sumsq.load([c0_idx]) + return SmemPtr.load(s_sum, [c0_idx]), SmemPtr.load(s_sumsq, [c0_idx]) def compute_mean_rstd(sum_val, sumsq_val): from flydsl.expr.arith import ArithValue @@ -250,6 +250,8 @@ def _store_vec_buf(data, rsrc, col_byte_off, soff=None): # ── Pass 2: normalize + affine + store ─────────────────────── for tile_i in range_constexpr(num_tiles_py): + g_next = g_cur + b_next = b_cur if tile_i + 1 < num_tiles_py: next_col_bytes = ArithValue(thr_col_bytes) + ((tile_i + 1) * tile_cols * elem_bytes) g_e_next = _load_vec_buf(gamma_rsrc, next_col_bytes) @@ -264,10 +266,6 @@ def _store_vec_buf(data, rsrc, col_byte_off, soff=None): if dtype_str == "f32" else b_e_next.extf(vec_type_c) ) - else: - g_next = g_cur - b_next = b_cur - x = in_local[tile_i] if cache_as_elem: x = x.extf(vec_type_c) @@ -278,6 +276,7 @@ def _store_vec_buf(data, rsrc, col_byte_off, soff=None): y = (x_av - mean_splat_av) * rstd_splat_av y = (y * g_av) + b_av y_val = y + out_e = y_val if dtype_str == "bf16": if USE_HW_CVT_PK_BF16_F32: @@ -403,6 +402,7 @@ def _store_scalar(divided_tensor, index, val): norm = diff * ArithValue(rstd) scaled = norm * ArithValue(g) y = scaled + ArithValue(b) + y_e = y if dtype_str == "bf16": y_e = y.truncf(elem_type) elif dtype_str == "f32": diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index d354dbb7..08fce27c 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -109,16 +109,16 @@ def block_reduce_add2(val0, val1): if lane == fx.Int32(0): wave_idx = arith.index_cast(T.index, wave) - s_red.store(w0, [wave_idx]) - s_red2.store(w1, [wave_idx]) + SmemPtr.store(s_red, w0, [wave_idx]) + SmemPtr.store(s_red2, w1, [wave_idx]) gpu.barrier() if wave == fx.Int32(0): in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, fx.Int32(0)) lane_safe_idx = arith.index_cast(T.index, lane_safe) - v0 = s_red.load([lane_safe_idx]) - v1 = s_red2.load([lane_safe_idx]) + v0 = SmemPtr.load(s_red, [lane_safe_idx]) + v1 = SmemPtr.load(s_red2, [lane_safe_idx]) z = fx.Float32(0.0) ww0 = in_range.select(v0, z) ww1 = in_range.select(v1, z) @@ -127,12 +127,12 @@ def block_reduce_add2(val0, val1): if lane == fx.Int32(0): c0_idx = fx.Index(0) - s_red.store(ww0, [c0_idx]) - s_red2.store(ww1, [c0_idx]) + SmemPtr.store(s_red, ww0, [c0_idx]) + SmemPtr.store(s_red2, ww1, [c0_idx]) gpu.barrier() c0_idx = fx.Index(0) - return s_red.load([c0_idx]), s_red2.load([c0_idx]) + return SmemPtr.load(s_red, [c0_idx]), SmemPtr.load(s_red2, [c0_idx]) # ================================================================== # Fast path: N is a multiple of tile_cols @@ -210,6 +210,7 @@ def _store_vec(data, rsrc, col_byte_off, soff=None): g_av = ArithValue(g) y = (x_av * rrms_splat_av) * g_av y_val = y + out_e = y_val if dtype_str == "bf16": if USE_HW_CVT_PK_BF16_F32: @@ -306,6 +307,7 @@ def _store_scalar(divided_tensor, index, val): g = g_e if dtype_str == "f32" else g_e.extf(compute_type) norm = ArithValue(x) * ArithValue(rrms) y = norm * ArithValue(g) + y_e = y if dtype_str == "f32": y_e = y elif dtype_str == "bf16": diff --git a/kernels/softmax_kernel.py b/kernels/softmax_kernel.py index 832fcffa..272c64b1 100644 --- a/kernels/softmax_kernel.py +++ b/kernels/softmax_kernel.py @@ -96,7 +96,7 @@ def wave_reduce(x, mode): w = w.addf(peer, fastmath=fm_fast) return w - def block_reduce(val, mode): + def block_reduce(val, mode, s_red_buffer): if RED_SLOTS == 1: return wave_reduce(val, mode) @@ -108,25 +108,25 @@ def block_reduce(val, mode): if lane == fx.Int32(0): wave_idx = arith.index_cast(T.index, wave) - s_red.store(w, [wave_idx]) + SmemPtr.store(s_red_buffer, w, [wave_idx]) gpu.barrier() if wave == fx.Int32(0): in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, fx.Int32(0)) lane_safe_idx = arith.index_cast(T.index, lane_safe) - v = s_red.load([lane_safe_idx]) + v = SmemPtr.load(s_red_buffer, [lane_safe_idx]) z = neutral ww = in_range.select(v, z) ww = wave_reduce(ww, mode) if lane == fx.Int32(0): c0_idx = fx.Index(0) - s_red.store(ww, [c0_idx]) + SmemPtr.store(s_red_buffer, ww, [c0_idx]) gpu.barrier() c0_idx = fx.Index(0) - return s_red.load([c0_idx]) + return SmemPtr.load(s_red_buffer, [c0_idx]) # ================================================================== # Fast path: N is a multiple of tile_cols @@ -170,7 +170,7 @@ def _store_vec(data, rsrc, col_byte_off, soff=None): red_max = vector.reduction(compute_type, vector.CombiningKind.MAXNUMF, x) thread_max = thread_max.maximumf(red_max) - global_max = block_reduce(thread_max, "max") + global_max = block_reduce(thread_max, "max", s_red) # 2. Exp + local sum g_max_splat = vector.broadcast(vec_type_c, global_max) @@ -186,7 +186,7 @@ def _store_vec(data, rsrc, col_byte_off, soff=None): red_sum = vector.reduction(compute_type, vector.CombiningKind.ADD, exp_val, fastmath=fm_fast) thread_sum = thread_sum + red_sum - global_sum = block_reduce(thread_sum, "sum") + global_sum = block_reduce(thread_sum, "sum", s_red) # 3. Normalize + store c_one = arith.constant(1.0, type=compute_type) @@ -196,6 +196,7 @@ def _store_vec(data, rsrc, col_byte_off, soff=None): for tile_i in range_constexpr(num_tiles): exp_vec = row_buffer[tile_i] norm_vec = ArithValue(exp_vec) * ArithValue(inv_sum_splat) + out_e = norm_vec if dtype_str == "f32": out_e = norm_vec @@ -253,7 +254,7 @@ def _store_scalar(divided, index, val): row_buffer.append((safe_val, is_valid)) thread_max = thread_max.maximumf(safe_val) - global_max = block_reduce(thread_max, "max") + global_max = block_reduce(thread_max, "max", s_red) # 2. Exp + sum thread_sum = c_zero_f @@ -266,7 +267,7 @@ def _store_scalar(divided, index, val): thread_sum = thread_sum + safe_exp new_buffer.append((exp_val, is_valid)) - global_sum = block_reduce(thread_sum, "sum") + global_sum = block_reduce(thread_sum, "sum", s_red) c_one = arith.constant(1.0, type=compute_type) inv_sum = c_one / ArithValue(global_sum) @@ -278,6 +279,7 @@ def _store_scalar(divided, index, val): buf_idx += 1 if arith.cmpi(arith.CmpIPredicate.ult, idx, Int32(N)): norm_val = ArithValue(exp_val) * inv_sum + out_e = norm_val if dtype_str == "f32": out_e = norm_val else: From bc653911585123a1af17d257651393fa4bff5384 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Tue, 7 Apr 2026 04:23:07 +0000 Subject: [PATCH 06/31] [FLYDSL]: Standardized use cases --- kernels/moe_blockscale_2stage.py | 20 ++++++++++++++++++-- kernels/moe_gemm_2stage.py | 27 +++++++++++++++------------ kernels/preshuffle_gemm.py | 3 +++ tests/kernels/test_quant.py | 8 ++++---- 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index 23cb9948..67c670f5 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -336,6 +336,7 @@ def silu(x): ) # fp16 path ignores scales completely (implicit scale=1.0). + x_load_bytes = 16 if is_f16: sx_rsrc = None sw_rsrc = None @@ -364,6 +365,7 @@ def silu(x): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. + x_load_bytes = 16 if is_f16: if bytes_per_thread_x % 16 != 0: raise ValueError( @@ -695,6 +697,8 @@ def load_scales_s1(k_base): _sw_shared_n = (n_per_wave <= 128) s_w_gate_vals = [] s_w_up_vals = [] + s_w_gate = fx.Float32(1.0) + s_w_up = fx.Float32(1.0) for ni in range_constexpr(num_acc_n): if ni == 0 or not _sw_shared_n: sw_gate_idx = _pre_n_block_gate[ni] * c_nblk_k_w1 + kb @@ -811,10 +815,14 @@ def mfma_k64(acc_in, a0, a1, b0, b1): b_up_packs0, b_up_packs1 = b_up_tile_in[ku] ki64 = arith.index(ku * 64) col_base = col_offset_base_bytes + ki64 + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (sb == 0) and (ku_local == 0) and (mi == 0): a0, a1 = a0_prefetch else: - a0, a1 = lds_load_packs_k64(row_a_lds + arith.index(mi * 16), col_base, lds_base) + a0, a1 = lds_load_packs_k64( + row_a_lds + arith.index(mi * 16), col_base, lds_base + ) blk_g = mfma_k64(blk_g, a0, a1, b_gate_packs0[ni], b_gate_packs1[ni]) blk_u = mfma_k64(blk_u, a0, a1, b_up_packs0[ni], b_up_packs1[ni]) s_wg_bc = vector.broadcast(T.f32x4, s_w_gate_vals[ni]) @@ -894,6 +902,8 @@ def mfma_k64(acc_in, a0, a1, b0, b1): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -1956,10 +1966,14 @@ def mfma_k64(acc0, a0, a1, b0, b1): b_packs0, b_packs1 = b_tile_in[ku] ki64 = arith.index(ku * 64) col_base = col_offset_base_bytes + ki64 + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (sb == 0) and (ku_local == 0) and (mi == 0): a0, a1 = a0_prefetch else: - a0, a1 = lds_load_packs_k64(row_a_lds + arith.index(mi * 16), col_base, lds_base) + a0, a1 = lds_load_packs_k64( + row_a_lds + arith.index(mi * 16), col_base, lds_base + ) blk = mfma_k64(blk, a0, a1, b_packs0[ni], b_packs1[ni]) s_w_bc = vector.broadcast(T.f32x4, s_w_vals[ni]) scale = ArithValue(s_a_vec4_list[mi]) * ArithValue(s_w_bc) @@ -2031,6 +2045,8 @@ def mfma_k64(acc0, a0, a1, b0, b1): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch else: diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 9eefa4ed..390a33b6 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -372,6 +372,7 @@ def silu(x): ) # scale_x: fp16/bf16 path ignores (implicit scale=1.0); int4_bf16 also uses 1.0. + x_load_bytes = 16 if is_f16_or_bf16: sx_rsrc = None else: @@ -796,10 +797,9 @@ def mfma_k64(acc_in, a0, a1, b0, b1): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni @@ -974,12 +974,12 @@ def hot_loop_scheduler(): topk_i32_v = topk_i32 inter_i32_v = fx.Int32(inter_dim) mask24_i32 = fx.Int32(0xFFFFFF) + sw_gate_vals = [] + sw_up_vals = [] if epilogue_pf is not None: sw_gate_vals, sw_up_vals = epilogue_pf else: - sw_gate_vals = [] - sw_up_vals = [] for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] row_gate_idx = expert_off + col_g @@ -1618,6 +1618,7 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16/bf16 we require 16B. + x_load_bytes = 16 if is_f16_or_bf16: if bytes_per_thread_x % 16 != 0: raise ValueError( @@ -1975,10 +1976,9 @@ def mfma_k64(acc0, a0, a1, b0, b1): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni @@ -2603,15 +2603,17 @@ def moe_reduction_kernel( ) else: v = buffer_ops.buffer_load(x_rsrc, x_idx_i32, vec_width=1, dtype=elem_type()) + v_cast = v if dtype_str in ("f16", "bf16"): - v = arith.extf(compute_type(), v) - a = a + v + v_cast = arith.extf(compute_type(), v) + a = a + v_cast v = a + out_v = v if dtype_str in ("f16", "bf16"): - v = arith.trunc_f(elem_type(), v) + out_v = arith.trunc_f(elem_type(), v) y_idx = token_idx * c_model_dim + col y_idx_i32 = arith.index_cast(i32_type(), y_idx) - buffer_ops.buffer_store(v, y_rsrc, y_idx_i32) + buffer_ops.buffer_store(out_v, y_rsrc, y_idx_i32) with _if_else(_if_full): # Tail path: scalar load/store per lane. @@ -2646,9 +2648,10 @@ def moe_reduction_kernel( v = buffer_ops.buffer_load( x_rsrc, x_idx_i32, vec_width=1, dtype=elem_type() ) + v_cast = v if dtype_str in ("f16", "bf16"): - v = arith.extf(compute_type(), v) - a = a + v + v_cast = arith.extf(compute_type(), v) + a = a + v_cast out = a if dtype_str in ("f16", "bf16"): diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 1e6d38ed..c740e5ed 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -359,6 +359,9 @@ def kernel_gemm( base_ptr_pong = allocator_pong.get_base() base_ptr_ping = allocator_ping.get_base() + lds_a_pong = None + lds_a_ping = None + lds_out = None if lds_stage == 2: lds_a_pong = SmemPtr( base_ptr_pong, lds_pong_offset, _elem_type(), shape=(tile_m * tile_k,) diff --git a/tests/kernels/test_quant.py b/tests/kernels/test_quant.py index 6f8d1adb..c3f4750e 100644 --- a/tests/kernels/test_quant.py +++ b/tests/kernels/test_quant.py @@ -131,24 +131,24 @@ def block_reduce_max(val): if arith.cmpi(arith.CmpIPredicate.eq, lane, Int32(0)): wave_idx = arith.index_cast(T.index, wave) - s_red.store(w, [wave_idx]) + SmemPtr.store(s_red, w, [wave_idx]) gpu.barrier() if arith.cmpi(arith.CmpIPredicate.eq, wave, Int32(0)): in_range = lane < RED_SLOTS lane_safe = arith.select(in_range, lane, Int32(0)) lane_safe_idx = arith.index_cast(T.index, lane_safe) - v = s_red.load([lane_safe_idx]) + v = SmemPtr.load(s_red, [lane_safe_idx]) ww = arith.select(in_range, v, c_zero_f) ww = wave_reduce_max(ww) if arith.cmpi(arith.CmpIPredicate.eq, lane, Int32(0)): c0_idx = arith.constant(0, index=True) - s_red.store(ww, [c0_idx]) + SmemPtr.store(s_red, ww, [c0_idx]) gpu.barrier() c0_idx = arith.constant(0, index=True) - return s_red.load([c0_idx]) + return SmemPtr.load(s_red, [c0_idx]) # ── Buffer resources ───────────────────────────────────────────── in_rsrc = buffer_ops.create_buffer_resource(Input, max_size=True) From 652fcc302390bb81fa1047ef7d5f6790fd3ef382 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Tue, 7 Apr 2026 16:08:44 +0000 Subject: [PATCH 07/31] [FLYDSL]: Supplement the static condition of ast.Compare --- kernels/blockscale_preshuffle_gemm.py | 54 ++++++++------- kernels/moe_blockscale_2stage.py | 64 +++++++++--------- kernels/moe_gemm_2stage.py | 91 +++++++++++++++----------- kernels/preshuffle_gemm.py | 80 ++++++++++++++-------- python/flydsl/compiler/ast_rewriter.py | 56 ++++++++++++++-- 5 files changed, 217 insertions(+), 128 deletions(-) diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index bcadf24b..c6223fc0 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -321,33 +321,33 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): c_chunk_a = fx.Index(chunk_i32_a) tx_i32_base = tx * c_chunk_a - def load_a(idx_i32): - if a_load_bytes == 16: + def load_a(idx_i32, a_load_bytes_v): + if a_load_bytes_v == 16: return buffer_copy_gmem16_dwordx4( buffer_ops, vector, elem_type=T.f8, idx_i32=idx_i32, rsrc=a_rsrc, vec_elems=16, elem_bytes=elem_bytes, ) - if a_load_bytes == 8: + if a_load_bytes_v == 8: return buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=1, dtype=T.i32) - def a_tile_chunk_coord_i32(i: int): + def a_tile_chunk_coord_i32(i: int, tx_i32_base_v, chunk_i32_a_v): return tile_chunk_coord_i32( - arith, tx_i32_base=tx_i32_base, i=i, + arith, tx_i32_base=tx_i32_base_v, i=i, total_threads=total_threads, layout_tile_div4=layout_a_tile_div4, - chunk_i32=chunk_i32_a, + chunk_i32=chunk_i32_a_v, ) - def load_a_tile(base_k_div4): + def load_a_tile(base_k_div4, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): parts = [] for i in range_constexpr(num_a_loads): - row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i) + row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i, tx_i32_base_v, chunk_i32_a_v) row_a_global = bx_m + row_a_local idx_i32 = row_a_global * _k_div4_factor + (base_k_div4 + col_a_local_i32) - a_vec = load_a(idx_i32) - if a_load_bytes == 16: + a_vec = load_a(idx_i32, a_load_bytes_v) + if a_load_bytes_v == 16: parts.append(vector.bitcast(T.i32x4, a_vec)) else: parts.append(a_vec) @@ -355,10 +355,10 @@ def load_a_tile(base_k_div4): c4_bytes = fx.Index(4) # bytes per dword (always 4, used for LDS byte addressing) - def store_a_tile_to_lds(vec_a_parts, lds_buffer): + def store_a_tile_to_lds(vec_a_parts, lds_buffer, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): for i in range_constexpr(num_a_loads): - row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i) - if a_load_bytes == 16: + row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i, tx_i32_base_v, chunk_i32_a_v) + if a_load_bytes_v == 16: lds_store_16b_xor16( arith, vector, lds_memref=lds_buffer, vec16_ty=T.f8x16, @@ -368,7 +368,7 @@ def store_a_tile_to_lds(vec_a_parts, lds_buffer): lds_base=fx.Index(0), vec_part_i32x4=vec_a_parts[i], elem_bytes=elem_bytes, ) - elif a_load_bytes == 8: + elif a_load_bytes_v == 8: lds_store_8b_xor16( arith, vector, lds_memref=lds_buffer, vec8_ty=T.f8x8, @@ -431,9 +431,9 @@ def prefetch_a_to_lds(base_k, lds_buffer): base_k_div4 = base_k // 4 dma_a_tile_to_lds(base_k_div4, lds_buffer) - def prefetch_a_tile(base_k): + def prefetch_a_tile(base_k, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): base_k_div4 = base_k // 4 - return load_a_tile(base_k_div4) + return load_a_tile(base_k_div4, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v) def prefetch_b_tile(base_k): return load_b_tile(base_k) @@ -714,17 +714,23 @@ def hot_loop_scheduler(): def prefetch_a0_pack(lds_buffer): return lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_buffer) - def _load_a_to_lds(base_k, lds_buffer): + def _load_a_to_lds(base_k, lds_buffer, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): if use_async_copy: prefetch_a_to_lds(base_k, lds_buffer) else: - store_a_tile_to_lds(prefetch_a_tile(base_k), lds_buffer) + store_a_tile_to_lds( + prefetch_a_tile(base_k, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v), + lds_buffer, + a_load_bytes_v, + tx_i32_base_v, + chunk_i32_a_v, + ) # ── Main pipeline: prologue ─────────────────────────────────────── k0 = fx.Index(0) b_tile_pong = prefetch_b_tile(k0) scales_pong = load_scales_for_tile(k0) - _load_a_to_lds(k0, lds_a_pong) + _load_a_to_lds(k0, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a) gpu.barrier() global_accs = [acc_init] * (num_acc_n * m_repeat) @@ -737,7 +743,7 @@ def _load_a_to_lds(base_k, lds_buffer): for k_iv in range_constexpr(0, K - tile_k, tile_k * 2): _k = fx.Index(k_iv) next_k1 = _k + tile_k - _load_a_to_lds(next_k1, lds_a_ping) + _load_a_to_lds(next_k1, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_ping = prefetch_b_tile(next_k1) scales_ping = load_scales_for_tile(next_k1) @@ -754,7 +760,7 @@ def _load_a_to_lds(base_k, lds_buffer): a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) next_k2 = _k + tile_k * 2 - _load_a_to_lds(next_k2, lds_a_pong) + _load_a_to_lds(next_k2, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_pong = prefetch_b_tile(next_k2) scales_pong = load_scales_for_tile(next_k2) @@ -779,7 +785,7 @@ def _load_a_to_lds(base_k, lds_buffer): for k_iv in range_constexpr(0, K - tile_k * 3, tile_k * 2): _k = fx.Index(k_iv) next_k1 = _k + tile_k - _load_a_to_lds(next_k1, lds_a_ping) + _load_a_to_lds(next_k1, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_ping = prefetch_b_tile(next_k1) scales_ping = load_scales_for_tile(next_k1) @@ -796,7 +802,7 @@ def _load_a_to_lds(base_k, lds_buffer): a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) next_k2 = _k + tile_k * 2 - _load_a_to_lds(next_k2, lds_a_pong) + _load_a_to_lds(next_k2, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_pong = prefetch_b_tile(next_k2) scales_pong = load_scales_for_tile(next_k2) @@ -815,7 +821,7 @@ def _load_a_to_lds(base_k, lds_buffer): last_k = arith.index(K - tile_k) second_last_k = arith.index(K - tile_k * 2) - _load_a_to_lds(last_k, lds_a_ping) + _load_a_to_lds(last_k, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_ping = prefetch_b_tile(last_k) scales_ping = load_scales_for_tile(last_k) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index 67c670f5..98a5567f 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -444,12 +444,12 @@ def x_tile_chunk_coord_i32(i: int): vec2_i32 = T.vec(2, T.i32) vec4_x = T.vec(4, x_elem) - def load_x(idx_i32): + def load_x(idx_i32, x_load_bytes_v): """Load `x_load_bytes` bytes from X (gmem) into regs. For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ - if x_load_bytes == 16: + if x_load_bytes_v == 16: idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -460,20 +460,20 @@ def load_x(idx_i32): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes == 8: + if x_load_bytes_v == 8: return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) - def load_x_tile(base_k): + def load_x_tile(base_k, x_load_bytes_v): """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" base_k_div4 = (base_k * arith.index(int(elem_bytes))) // fx.Index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - if x_load_bytes == 16: + x_vec = load_x(idx_i32, x_load_bytes_v) + if x_load_bytes_v == 16: parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes == 8: + elif x_load_bytes_v == 8: parts.append(x_vec) else: parts.append(x_vec) @@ -583,11 +583,11 @@ def load_b_tile(base_k, blk_list, intra_list): acc_up = [arith.constant_vector(0.0, T.f32x4)] * (num_acc_n * m_repeat) # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if x_load_bytes_v == 16: lds_store_16b_xor16( arith, vector, @@ -602,7 +602,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif x_load_bytes_v == 8: lds_store_8b_xor16( arith, vector, @@ -951,7 +951,7 @@ def do_one_stage(acc_gate_in, acc_up_in, k_compute, k_next, """One pipeline stage: load next tile data, compute current tile, store X to LDS.""" scale_fn = load_scales_s1 pre_scales = scale_fn(k_compute) - x_regs_next = load_x_tile(k_next) + x_regs_next = load_x_tile(k_next, x_load_bytes) b_gate_cur = load_b_tile(k_compute, n_blk_gate, n_intra_gate) b_up_cur = load_b_tile(k_compute, n_blk_up, n_intra_up) @@ -959,15 +959,15 @@ def do_one_stage(acc_gate_in, acc_up_in, k_compute, k_next, acc_gate_in, acc_up_in, b_gate_cur, b_up_cur, lds_compute, pre_scales) - store_x_tile_to_lds(x_regs_next, lds_store) + store_x_tile_to_lds(x_regs_next, lds_store, x_load_bytes) hot_loop_scheduler() gpu.barrier() return ag, au # Prologue: prefetch tile0 X into LDS, sync. k0 = fx.Index(0) - x_regs0 = load_x_tile(k0) - store_x_tile_to_lds(x_regs0, lds_base_cur) + x_regs0 = load_x_tile(k0, x_load_bytes) + store_x_tile_to_lds(x_regs0, lds_base_cur, x_load_bytes) gpu.barrier() lds_base_pong = lds_base_cur @@ -1612,8 +1612,8 @@ def x_tile_chunk_coord_i32(i: int): vec2_i32 = T.vec(2, T.i32) vec4_x = T.vec(4, x_elem) - def load_x(idx_i32): - if x_load_bytes == 16: + def load_x(idx_i32, x_load_bytes_v): + if x_load_bytes_v == 16: idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -1624,7 +1624,7 @@ def load_x(idx_i32): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes == 8: + if x_load_bytes_v == 8: return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) @@ -1653,15 +1653,15 @@ def load_x(idx_i32): # Base row offset in dword units: row_ts_idx * (k_in/4) x_row_base_div4.append(row_ts_idx * c_k_div4) - def load_x_tile(base_k): + def load_x_tile(base_k, x_load_bytes_v): base_k_div4 = (base_k * arith.index(int(elem_bytes))) // fx.Index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - if x_load_bytes == 16: + x_vec = load_x(idx_i32, x_load_bytes_v) + if x_load_bytes_v == 16: parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes == 8: + elif x_load_bytes_v == 8: parts.append(x_vec) else: parts.append(x_vec) @@ -1754,11 +1754,11 @@ def load_b_tile(base_k): return b_tile # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if x_load_bytes_v == 16: lds_store_16b_xor16( arith, vector, @@ -1773,7 +1773,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif x_load_bytes_v == 8: lds_store_8b_xor16( arith, vector, @@ -2117,9 +2117,9 @@ def hot_loop_scheduler(): # Prologue. k0 = fx.Index(0) - x_regs0 = load_x_tile(k0) + x_regs0 = load_x_tile(k0, x_load_bytes) b_cur = load_b_tile(k0) - store_x_tile_to_lds(x_regs0, lds_base_cur) + store_x_tile_to_lds(x_regs0, lds_base_cur, x_load_bytes) gpu.barrier() acc = [arith.constant_vector(0.0, T.f32x4)] * (num_acc_n * m_repeat) @@ -2149,12 +2149,12 @@ def hot_loop_scheduler(): # Issue scale loads FIRST so their latency hides behind heavy tile VMEM. pre_scales_pong = load_scales_s2(k_iv) next_k1 = k_iv + tile_k - x_regs_ping = load_x_tile(next_k1) + x_regs_ping = load_x_tile(next_k1, x_load_bytes) b_ping = load_b_tile(next_k1) acc = compute_tile_bs_s2(acc, b_cur, lds_base_pong, pre_scales_pong, a0_prefetch=a0_prefetch_pong) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -2164,12 +2164,12 @@ def hot_loop_scheduler(): # Issue scale loads FIRST so their latency hides behind heavy tile VMEM. pre_scales_ping = load_scales_s2(next_k1) next_k2 = k_iv + c2_tile_k - x_regs_pong = load_x_tile(next_k2) + x_regs_pong = load_x_tile(next_k2, x_load_bytes) b_next = load_b_tile(next_k2) acc = compute_tile_bs_s2(acc, b_ping, lds_base_ping, pre_scales_ping, a0_prefetch=a0_prefetch_ping) a0_prefetch_ping = None - store_x_tile_to_lds(x_regs_pong, lds_base_pong) + store_x_tile_to_lds(x_regs_pong, lds_base_pong, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -2195,12 +2195,12 @@ def hot_loop_scheduler(): k_tail1 = k_in - tile_k # Issue scale loads FIRST so their latency hides behind heavy tile VMEM. pre_scales_tail0 = load_scales_s2(k_tail0) - x_regs_ping = load_x_tile(k_tail1) + x_regs_ping = load_x_tile(k_tail1, x_load_bytes) b_ping = load_b_tile(k_tail1) acc = compute_tile_bs_s2(acc, b_cur, lds_base_pong, pre_scales_tail0, a0_prefetch=a0_prefetch_pong) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 390a33b6..106e2b2a 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -481,13 +481,13 @@ def x_tile_chunk_coord_i32(i: int): vec4_x = T.vec(4, x_elem) - def load_x(idx_i32): - """Load `x_load_bytes` bytes from X (gmem) into regs. + def load_x(idx_i32, x_load_bytes_v): + """Load `x_load_bytes_v` bytes from X (gmem) into regs. For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. idx_i32 is in dword units; convert to element index for _buffer_load_vec. """ - if x_load_bytes == 16: + if x_load_bytes_v == 16: idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -499,20 +499,21 @@ def load_x(idx_i32): elem_bytes=elem_bytes, ) # For 8B/4B, load raw i32 dwords directly. - if x_load_bytes == 8: + if x_load_bytes_v == 8: return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) - def load_x_tile(base_k): + def load_x_tile(base_k, x_load_bytes_v): """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" base_k_div4 = (base_k * arith.index(int(elem_bytes))) // fx.Index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - if x_load_bytes == 16: + x_vec = load_x(idx_i32, x_load_bytes_v) + print(f"x_load_bytes_v: {x_load_bytes_v}") + if x_load_bytes_v == 16: parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes == 8: + elif x_load_bytes_v == 8: parts.append(x_vec) else: parts.append(x_vec) @@ -651,11 +652,11 @@ def load_b_tile(base_k, blk_list, intra_list): acc_up = [acc_init] * (num_acc_n * m_repeat) # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if x_load_bytes_v == 16: lds_store_16b_xor16( arith, vector, @@ -670,7 +671,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif x_load_bytes_v == 8: lds_store_8b_xor16( arith, vector, @@ -797,9 +798,12 @@ def mfma_k64(acc_in, a0, a1, b0, b1): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni @@ -857,10 +861,10 @@ def hot_loop_scheduler(): # Prologue: prefetch tile0, store to LDS(cur), sync. k0 = fx.Index(0) - x_regs0 = load_x_tile(k0) + x_regs0 = load_x_tile(k0, x_load_bytes) b_gate_cur = load_b_tile(k0, n_blk_gate, n_intra_gate) b_up_cur = load_b_tile(k0, n_blk_up, n_intra_up) - store_x_tile_to_lds(x_regs0, lds_base_cur) + store_x_tile_to_lds(x_regs0, lds_base_cur, x_load_bytes) gpu.barrier() # Loop-carried ping/pong state. @@ -882,7 +886,7 @@ def hot_loop_scheduler(): k_iv = arith.index(pair_i * (tile_k * 2)) # ---- stage 0: prefetch+store ping, compute pong ---- next_k1 = k_iv + tile_k - x_regs_ping = load_x_tile(next_k1) + x_regs_ping = load_x_tile(next_k1, x_load_bytes) b_gate_ping = load_b_tile(next_k1, n_blk_gate, n_intra_gate) b_up_ping = load_b_tile(next_k1, n_blk_up, n_intra_up) @@ -895,7 +899,7 @@ def hot_loop_scheduler(): a0_prefetch=a0_prefetch_pong, ) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -904,7 +908,7 @@ def hot_loop_scheduler(): # ---- stage 1: prefetch+store pong, compute ping ---- next_k2 = k_iv + c2_tile_k - x_regs_pong = load_x_tile(next_k2) + x_regs_pong = load_x_tile(next_k2, x_load_bytes) b_gate_next = load_b_tile(next_k2, n_blk_gate, n_intra_gate) b_up_next = load_b_tile(next_k2, n_blk_up, n_intra_up) @@ -917,7 +921,7 @@ def hot_loop_scheduler(): a0_prefetch=a0_prefetch_ping, ) a0_prefetch_ping = None - store_x_tile_to_lds(x_regs_pong, lds_base_pong) + store_x_tile_to_lds(x_regs_pong, lds_base_pong, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -936,7 +940,7 @@ def hot_loop_scheduler(): b_up_cur = load_b_tile(k_tail0, n_blk_up, n_intra_up) a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) k_tail1 = k_in - tile_k - x_regs_ping = load_x_tile(k_tail1) + x_regs_ping = load_x_tile(k_tail1, x_load_bytes) b_gate_ping = load_b_tile(k_tail1, n_blk_gate, n_intra_gate) b_up_ping = load_b_tile(k_tail1, n_blk_up, n_intra_up) @@ -949,7 +953,7 @@ def hot_loop_scheduler(): a0_prefetch=a0_prefetch_pong, ) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -1052,6 +1056,7 @@ def write_row_to_lds( ) # Sorted weight aligned with `row` (matches aiter moe_sorting output). + tw = fx.Float32(1.0) if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) @@ -1168,6 +1173,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): idx0 = (t2 * topk_i32_v + s2) * inter_i32_local # Sorted weight aligned with `row` (matches aiter moe_sorting output). + tw = fx.Float32(1.0) if doweight_stage1: tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) @@ -1664,8 +1670,8 @@ def x_tile_chunk_coord_i32(i: int): vec4_x = T.vec(4, x_elem) - def load_x(idx_i32): - if x_load_bytes == 16: + def load_x(idx_i32, x_load_bytes_v): + if x_load_bytes_v == 16: idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -1676,7 +1682,7 @@ def load_x(idx_i32): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes == 8: + if x_load_bytes_v == 8: return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) @@ -1705,15 +1711,15 @@ def load_x(idx_i32): # Base row offset in dword units: row_ts_idx * (k_in/4) x_row_base_div4.append(row_ts_idx * c_k_div4) - def load_x_tile(base_k): + def load_x_tile(base_k, x_load_bytes_v): base_k_div4 = (base_k * arith.index(int(elem_bytes))) // fx.Index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - if x_load_bytes == 16: + x_vec = load_x(idx_i32, x_load_bytes_v) + if x_load_bytes_v == 16: parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes == 8: + elif x_load_bytes_v == 8: parts.append(vector.bitcast(T.vec(2, T.i32), x_vec)) else: parts.append(vector.bitcast(T.vec(1, T.i32), x_vec)) @@ -1833,11 +1839,11 @@ def load_b_tile(base_k): return b_tile # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if x_load_bytes_v == 16: lds_store_16b_xor16( arith, vector, @@ -1852,7 +1858,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif x_load_bytes_v == 8: lds_store_8b_xor16( arith, vector, @@ -1976,9 +1982,12 @@ def mfma_k64(acc0, a0, a1, b0, b1): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni @@ -2078,9 +2087,9 @@ def hot_loop_scheduler(): rocdl.sched_barrier(0) # Prologue. k0 = fx.Index(0) - x_regs0 = load_x_tile(k0) + x_regs0 = load_x_tile(k0, x_load_bytes) b_cur = load_b_tile(k0) - store_x_tile_to_lds(x_regs0, lds_base_cur) + store_x_tile_to_lds(x_regs0, lds_base_cur, x_load_bytes) gpu.barrier() acc = [acc_init] * (num_acc_n * m_repeat) @@ -2108,12 +2117,12 @@ def hot_loop_scheduler(): for pair_i in range_constexpr(pair_iters): k_iv = arith.index(pair_i * (tile_k * 2)) next_k1 = k_iv + tile_k - x_regs_ping = load_x_tile(next_k1) + x_regs_ping = load_x_tile(next_k1, x_load_bytes) b_ping = load_b_tile(next_k1) acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -2121,12 +2130,12 @@ def hot_loop_scheduler(): a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) next_k2 = k_iv + c2_tile_k - x_regs_pong = load_x_tile(next_k2) + x_regs_pong = load_x_tile(next_k2, x_load_bytes) b_next = load_b_tile(next_k2) acc, _ = compute_tile(acc, b_ping, lds_base_ping, a0_prefetch=a0_prefetch_ping) a0_prefetch_ping = None - store_x_tile_to_lds(x_regs_pong, lds_base_pong) + store_x_tile_to_lds(x_regs_pong, lds_base_pong, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -2147,12 +2156,12 @@ def hot_loop_scheduler(): else: # Tail: 2 remaining tiles. k_tail1 = k_in - tile_k - x_regs_ping = load_x_tile(k_tail1) + x_regs_ping = load_x_tile(k_tail1, x_load_bytes) b_ping = load_b_tile(k_tail1) acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -2190,10 +2199,10 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): sw_pf, tw_pf = epilogue_pf # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). + sw_vals = [] if sw_pf is not None: sw_vals = sw_pf else: - sw_vals = [] for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] row_w_idx = expert_off + col_g @@ -2239,6 +2248,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): ) ) + tw = fx.Float32(1.0) if doweight_stage2: tw_idx = (mi * 4) + ii if tw_pf is not None: @@ -2319,6 +2329,7 @@ def write_row_to_lds( ) ) + tw = fx.Float32(1.0) if doweight_stage2: tw_idx = (mi * 4) + ii if tw_pf is not None: diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index c740e5ed..611215a0 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -288,7 +288,10 @@ def _out_elem(): lds_tile_bytes = int(tile_m) * int(lds_stride_bytes) // a_elem_vec_pack lds_out_bytes = 2 * int(tile_m) * int(tile_n) if use_cshuffle_epilog else 0 - if int(lds_stage) == 2: + is_2stage = int(lds_stage) == 2 + is_num_tiles_odd = ((int(K) // int(tile_k)) % 2) == 1 + fp4_tilek128_mode = bool(is_fp4 and int(tile_k) == 128) + if is_2stage: assert lds_out_bytes % 2 == 0, "lds_out_bytes should be multiple of 2" buffer_size_bytes = max(lds_tile_bytes, lds_out_bytes // lds_stage) buffer_size_elems = buffer_size_bytes if elem_bytes == 1 else (buffer_size_bytes // 2) @@ -298,12 +301,16 @@ def _out_elem(): lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) allocator_ping.ptr = lds_ping_offset + buffer_size_elems * elem_bytes + lds_default_offset = lds_pong_offset + lds_default_elems = buffer_size_elems else: lds_total_bytes = max(lds_tile_bytes, lds_out_bytes) lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) lds_alloc_offset = allocator_pong._align(allocator_pong.ptr, 16) allocator_pong.ptr = lds_alloc_offset + lds_total_elems * elem_bytes + lds_default_offset = lds_alloc_offset + lds_default_elems = lds_total_elems # ── Kernel function ──────────────────────────────────────────────────── @flyc.kernel @@ -359,10 +366,18 @@ def kernel_gemm( base_ptr_pong = allocator_pong.get_base() base_ptr_ping = allocator_ping.get_base() - lds_a_pong = None - lds_a_ping = None - lds_out = None - if lds_stage == 2: + # Initialize with valid memrefs to avoid None flowing into rewritten branches. + lds_a_ptr = SmemPtr( + base_ptr_pong, lds_default_offset, _elem_type(), shape=(lds_default_elems,) + ) + lds_a_pong = lds_a_ptr.get() + lds_a_ping = lds_a_pong + lds_out = ( + SmemPtr(base_ptr_pong, lds_default_offset, _out_elem(), shape=(tile_m * tile_n,)).get() + if use_cshuffle_epilog + else None + ) + if is_2stage: lds_a_pong = SmemPtr( base_ptr_pong, lds_pong_offset, _elem_type(), shape=(tile_m * tile_k,) ).get() @@ -377,16 +392,7 @@ def kernel_gemm( else: lds_out = None else: - lds_a_ptr = SmemPtr( - base_ptr_pong, lds_alloc_offset, _elem_type(), shape=(lds_total_elems,) - ) - lds_a_pong = lds_a_ptr.get() - lds_a_ping = lds_a_pong - lds_out = ( - SmemPtr(base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(tile_m * tile_n,)).get() - if use_cshuffle_epilog - else None - ) + pass # ---- Buffer resources (runtime byte sizes for OOB protection) ---- _a_nrec = arith.index_cast(T.i64, c_m * (K * elem_bytes // a_elem_vec_pack)) @@ -637,6 +643,12 @@ def dma_a_tile_to_lds(base_k_div4, lds_buffer): ), ) + lds_base = memref_dialect.extract_aligned_pointer_as_index(lds_buffer) + lds_ptr_base = buffer_ops.create_llvm_ptr( + arith.index_cast(T.i64, lds_base), address_space=3 + ) + lds_ptr = buffer_ops.get_element_ptr(lds_ptr_base, wave_offset) + for i in range_constexpr(num_a_async_loads): row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32_async(i) col_a_local_sw = swizzle_xor16(row_a_local, col_a_local_i32 * c4, k_blocks16) @@ -645,8 +657,6 @@ def dma_a_tile_to_lds(base_k_div4, lds_buffer): global_offset = arith.index_cast(T.i32, global_byte_idx) if i == 0: - lds_base = memref_dialect.extract_aligned_pointer_as_index(lds_buffer) - lds_ptr_base = buffer_ops.create_llvm_ptr(arith.index_cast(T.i64, lds_base), address_space=3) lds_ptr = buffer_ops.get_element_ptr(lds_ptr_base, wave_offset) else: lds_ptr = buffer_ops.get_element_ptr( @@ -749,9 +759,17 @@ def load_fp4_scales(base_k_scale_idx): def load_fp4_scale_chunk(base_k): return load_fp4_scales(base_k // fx.Index(_fp4_scale_chunk_k)) + def make_unit_scales(): + one_f32 = arith.constant(1.0, type=T.f32) + one_vec4 = vector.from_elements(T.f32x4, [one_f32, one_f32, one_f32, one_f32]) + return { + "s_b_vals": [one_f32 for _ in range_constexpr(num_acc_n)], + "s_a_vecs": [one_vec4 for _ in range_constexpr(m_repeat)], + } + # ── Compute tile (MFMA) ─────────────────────────────────────────── def compute_tile(accs_in, b_tile_in, lds_buffer, *, is_last_tile=False, a0_prefetch=None, fp4_scales=None, fp4_scale_half=0): - scales_pf = {} + scales_pf = make_unit_scales() if is_last_tile and (not is_f16_or_bf16): s_b_vals = [] for ni in range_constexpr(num_acc_n): @@ -820,6 +838,8 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for imxdl in range_constexpr(_fp4_pack_M): mi_idx = mi_p * _fp4_pack_M + imxdl curr_row_a_lds = row_a_lds + (mi_idx * 16) + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): a0, a1 = a0_prefetch else: @@ -851,6 +871,8 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku0 == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -894,6 +916,8 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): col_base = col_offset_base_bytes + ki64 for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -910,12 +934,8 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): vec1_out = T.vec(1, _out_elem()) def store_output(final_accs, scales): - if is_f16_or_bf16 or is_fp4: - s_b_vals = None - s_a_vecs = None - else: - s_b_vals = scales["s_b_vals"] - s_a_vecs = scales["s_a_vecs"] + s_b_vals = scales["s_b_vals"] + s_a_vecs = scales["s_a_vecs"] if use_cshuffle_epilog: if lds_out is None: @@ -934,6 +954,7 @@ def write_row_to_lds(*, mi, ii, row_in_tile, row, row_base_lds, val = vector.extract(acc, static_position=[ii], dynamic_position=[]) if is_int8: val = arith.sitofp(T.f32, val) + val_s = val if is_f16_or_bf16 or is_fp4: val_s = val elif _needs_per_token_scale: @@ -982,6 +1003,7 @@ def body_row(*, mi, ii, row_in_tile, row): val = vector.extract(acc, static_position=[ii], dynamic_position=[]) if is_int8: val = arith.sitofp(T.f32, val) + val_s = val if is_f16_or_bf16 or is_fp4: val_s = val elif _needs_per_token_scale: @@ -1243,7 +1265,7 @@ def _build_pingpong_body(k_iv, inner_state): return _pack_state(accs_in, _flatten_b_tile(b_tile_pong_new), a0_prefetch_pong_new, _sc_pong) - if lds_stage == 2: + if is_2stage: def prefetch_a0_pack(lds_buffer): return lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_buffer) @@ -1259,8 +1281,10 @@ def prefetch_a0_pack(lds_buffer): fp4_scales0 = load_fp4_scale_chunk(fx.Index(0)) if is_fp4 else None num_tiles = K // tile_k - if _fp4_tilek128: - if (num_tiles % 2) == 1: + final_accs = accs + scales = make_unit_scales() + if fp4_tilek128_mode: + if is_num_tiles_odd: c_k_main = K - tile_k init_state = _pack_state(accs, _flatten_b_tile(b_tile0), a0_prefetch_pong, fp4_scales0) @@ -1304,7 +1328,7 @@ def prefetch_a0_pack(lds_buffer): is_last_tile=not is_fp4, a0_prefetch=a0_prefetch_ping, fp4_scales=fp4_scales_ep, fp4_scale_half=1, ) - elif (num_tiles % 2) == 1: + elif is_num_tiles_odd: c_k_main = K - tile_k init_state = _pack_state(accs, _flatten_b_tile(b_tile0), a0_prefetch_pong, fp4_scales0) diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 87809928..d88c97e8 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -603,11 +603,59 @@ def _is_literal_expr(node): ) return False + def _try_static_value(node): + if not _is_literal_expr(node): + return False, None + if isinstance(node, ast.Constant): + return True, node.value + try: + return True, ast.literal_eval(node) + except Exception: + return False, None + + def _eval_static_compare_pair(lhs, op, rhs): + try: + if isinstance(op, ast.Eq): + return lhs == rhs + if isinstance(op, ast.NotEq): + return lhs != rhs + if isinstance(op, ast.Lt): + return lhs < rhs + if isinstance(op, ast.LtE): + return lhs <= rhs + if isinstance(op, ast.Gt): + return lhs > rhs + if isinstance(op, ast.GtE): + return lhs >= rhs + if isinstance(op, ast.Is): + return lhs is rhs + if isinstance(op, ast.IsNot): + return lhs is not rhs + if isinstance(op, ast.In): + return lhs in rhs + if isinstance(op, ast.NotIn): + return lhs not in rhs + except Exception: + return None + return None + def _visit(node): if _is_literal_expr(node): return False if isinstance(node, ast.Compare): - return True + compare_parts = [node.left, *node.comparators] + # Early static-false short-circuit for chain compare pairs like: + # left < c0 < c1 ... where any known static pair is False. + for i, op in enumerate(node.ops): + lhs_node = compare_parts[i] + rhs_node = compare_parts[i + 1] + lhs_ok, lhs_val = _try_static_value(lhs_node) + rhs_ok, rhs_val = _try_static_value(rhs_node) + if lhs_ok and rhs_ok: + pair_result = _eval_static_compare_pair(lhs_val, op, rhs_val) + if pair_result is False: + return False + return any(_visit(part) for part in compare_parts) if isinstance(node, ast.Call): func = node.func if not (isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES): @@ -690,9 +738,9 @@ def visit_Call(self, node): invoked_args = [name for name in invoked_args if name not in write_args] write_args = [name for name in write_args if in_active_symbols(name)] invoked_args = [name for name in invoked_args if in_active_symbols(name)] - print(f"write_args: {write_args}") - print(f"invoked_args: {invoked_args}") - print(f"active_symbols: {active_symbols}") + # print(f"write_args: {write_args}") + # print(f"invoked_args: {invoked_args}") + # print(f"active_symbols: {active_symbols}") return write_args + invoked_args @staticmethod From 7f1cfa11ef3732f2089f25473cc786a522c9e636 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 8 Apr 2026 13:56:43 +0000 Subject: [PATCH 08/31] [FLYDSL]: dsl_not_/dsl_and_/dsl_or_ is dynamic. --- python/flydsl/compiler/ast_rewriter.py | 17 +++++++----- tests/system/test_control_flow_compile.py | 32 +++++++++++++++++++++-- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index d88c97e8..bb37076e 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -578,9 +578,13 @@ def rewrite_globals(cls): "scf_if_collect_results": cls._collect_result_dict, } - _REWRITE_HELPER_NAMES = {"dsl_not_", "dsl_and_", "dsl_or_", - "scf_if_dispatch", "const_expr", "type", - "bool", "isinstance", "hasattr"} + _REWRITE_HELPER_NAMES = { + "const_expr", + "type", + "bool", + "isinstance", + "hasattr", + } @staticmethod def _could_be_dynamic(test_node): @@ -644,8 +648,6 @@ def _visit(node): return False if isinstance(node, ast.Compare): compare_parts = [node.left, *node.comparators] - # Early static-false short-circuit for chain compare pairs like: - # left < c0 < c1 ... where any known static pair is False. for i, op in enumerate(node.ops): lhs_node = compare_parts[i] rhs_node = compare_parts[i + 1] @@ -657,11 +659,12 @@ def _visit(node): return False return any(_visit(part) for part in compare_parts) if isinstance(node, ast.Call): + print(f"_visit Call node: {ast.unparse(node)}") func = node.func if not (isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES): return True - # Plain names can be static symbols (constexpr params, local bools, etc.), - # and unknown nodes keep recursing into children. + if isinstance(node, ast.Name): + return True for child in ast.iter_child_nodes(node): if _visit(child): diff --git a/tests/system/test_control_flow_compile.py b/tests/system/test_control_flow_compile.py index db2bf084..ce240e24 100644 --- a/tests/system/test_control_flow_compile.py +++ b/tests/system/test_control_flow_compile.py @@ -5,9 +5,14 @@ import flydsl.compiler as flyc import flydsl.expr as fx +import pytest +import torch -def test_control_flow_kernel_snippet_compiles_without_error(): +def test_control_flow_kernel_snippet_compiles_without_error(monkeypatch): + if not torch.cuda.is_available(): + pytest.skip("CUDA device is required for control-flow compile coverage test") + @flyc.kernel def vecAbsKernel( A: fx.Tensor, @@ -21,4 +26,27 @@ def vecAbsKernel( if print_debug and bid == 0 and tid <= 2: fx.printf("[kernel] bid={}, tid={}", bid, tid) - assert vecAbsKernel is not None + @flyc.jit + def vecAbs( + A: fx.Tensor, + C, + n: fx.Int32, + const_n: fx.Constexpr[int], + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + stream: fx.Stream = fx.Stream(None), + ): + tile_elems = block_dim * vec_width + grid_x = (n + tile_elems - 1) // tile_elems + vecAbsKernel(A, C, block_dim, vec_width).launch( + grid=(grid_x, 1, 1), block=(block_dim, 1, 1), stream=stream + ) + + monkeypatch.setenv("FLYDSL_COMPILE_ONLY", "1") + threads = 64 + vec = 4 + size = threads * vec + a = torch.randn(size, device="cuda", dtype=torch.float32) + c = torch.empty_like(a) + t_a = flyc.from_dlpack(a).mark_layout_dynamic(leading_dim=0, divisibility=vec) + vecAbs(t_a, c, size, size, threads, vec) From 7faf35f2f270472b155407ba7532a929fec3c0f4 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 8 Apr 2026 16:12:25 +0000 Subject: [PATCH 09/31] [FLYDSL]: Fixed the issue of type merging in Python's multi-type return values --- kernels/blockscale_preshuffle_gemm.py | 22 ++++++++++++++-------- python/flydsl/compiler/ast_rewriter.py | 1 - 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index c6223fc0..700a5e33 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -322,13 +322,13 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): tx_i32_base = tx * c_chunk_a def load_a(idx_i32, a_load_bytes_v): - if a_load_bytes_v == 16: + if const_expr(a_load_bytes_v == 16): return buffer_copy_gmem16_dwordx4( buffer_ops, vector, elem_type=T.f8, idx_i32=idx_i32, rsrc=a_rsrc, vec_elems=16, elem_bytes=elem_bytes, ) - if a_load_bytes_v == 8: + if const_expr(a_load_bytes_v == 8): return buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=1, dtype=T.i32) @@ -347,7 +347,7 @@ def load_a_tile(base_k_div4, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): row_a_global = bx_m + row_a_local idx_i32 = row_a_global * _k_div4_factor + (base_k_div4 + col_a_local_i32) a_vec = load_a(idx_i32, a_load_bytes_v) - if a_load_bytes_v == 16: + if const_expr(a_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, a_vec)) else: parts.append(a_vec) @@ -440,6 +440,11 @@ def prefetch_b_tile(base_k): # ── MFMA ────────────────────────────────────────────────────────── mfma_res_ty = T.f32x4 + + def _mfma_fn_placeholder(*args, **kwargs): + raise RuntimeError("mfma_fn placeholder should be overwritten before use") + + mfma_fn = _mfma_fn_placeholder if _is_gfx950: c0_i64 = arith.constant(0, type=T.i64) @@ -450,12 +455,12 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): else: mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 - def mfma_step(acc_in, a, b): - return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) + def mfma_step(acc_in, a, b): + return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) - def mfma_k64_bytes(acc_in, a0, a1, b0, b1): - acc_mid = mfma_step(acc_in, a0, b0) - return mfma_step(acc_mid, a1, b1) + def mfma_k64_bytes(acc_in, a0, a1, b0, b1): + acc_mid = mfma_step(acc_in, a0, b0) + return mfma_step(acc_mid, a1, b1) # ── Blockscale compute tile ─────────────────────────────────────── from flydsl._mlir.dialects import math as math_dialect @@ -672,6 +677,7 @@ def body_row(*, mi, ii, row_in_tile, row): def hot_loop_scheduler(): mfma_group = num_acc_n + mfma_total = -1 if _is_gfx950: mfma_total = sb_per_tile * m_repeat * mfma_group else: diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index bb37076e..eb329940 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -659,7 +659,6 @@ def _visit(node): return False return any(_visit(part) for part in compare_parts) if isinstance(node, ast.Call): - print(f"_visit Call node: {ast.unparse(node)}") func = node.func if not (isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES): return True From c8fbb434e69c79398451c5421c4499fbd674b19e Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Thu, 9 Apr 2026 03:46:05 +0000 Subject: [PATCH 10/31] [FLYDSL]: Refactor the bitcast code --- kernels/blockscale_preshuffle_gemm.py | 4 +- kernels/mixed_moe_gemm_2stage.py | 10 ++--- kernels/moe_blockscale_2stage.py | 18 ++++---- kernels/moe_gemm_2stage.py | 63 +++++++++++++-------------- 4 files changed, 47 insertions(+), 48 deletions(-) diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index 700a5e33..cd3cef07 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -358,7 +358,7 @@ def load_a_tile(base_k_div4, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): def store_a_tile_to_lds(vec_a_parts, lds_buffer, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): for i in range_constexpr(num_a_loads): row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i, tx_i32_base_v, chunk_i32_a_v) - if a_load_bytes_v == 16: + if const_expr(a_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, lds_memref=lds_buffer, vec16_ty=T.f8x16, @@ -368,7 +368,7 @@ def store_a_tile_to_lds(vec_a_parts, lds_buffer, a_load_bytes_v, tx_i32_base_v, lds_base=fx.Index(0), vec_part_i32x4=vec_a_parts[i], elem_bytes=elem_bytes, ) - elif a_load_bytes_v == 8: + elif const_expr(a_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, lds_memref=lds_buffer, vec8_ty=T.f8x8, diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index eb23631a..53d28b3f 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -1971,7 +1971,7 @@ def load_x(idx_i32): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): idx_elem = ( idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) ) @@ -2034,9 +2034,9 @@ def load_x_tile(base_k): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32) - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): parts.append(vector.bitcast(vec4_i32, x_vec)) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes == 8): parts.append(vector.bitcast(vec2_i32, x_vec)) else: parts.append(vector.bitcast(vec1_i32, x_vec)) @@ -2185,7 +2185,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): lds_store_16b_xor16( arith, vector, @@ -2200,7 +2200,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes == 8): lds_store_8b_xor16( arith, vector, diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index 7fc739e3..eabe540d 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -449,7 +449,7 @@ def load_x(idx_i32, x_load_bytes_v): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -460,7 +460,7 @@ def load_x(idx_i32, x_load_bytes_v): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes_v == 8: + if const_expr(x_load_bytes_v == 8): return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) @@ -471,11 +471,11 @@ def load_x_tile(base_k, x_load_bytes_v): for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32, x_load_bytes_v) - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes_v == 8: + elif const_expr(x_load_bytes_v == 8): parts.append(x_vec) - else: + if const_expr(x_load_bytes_v == 4): parts.append(x_vec) return parts @@ -1617,7 +1617,7 @@ def x_tile_chunk_coord_i32(i: int): vec4_x = T.vec(4, x_elem) def load_x(idx_i32, x_load_bytes_v): - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -1628,7 +1628,7 @@ def load_x(idx_i32, x_load_bytes_v): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes_v == 8: + if const_expr(x_load_bytes_v == 8): return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) @@ -1663,9 +1663,9 @@ def load_x_tile(base_k, x_load_bytes_v): for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32, x_load_bytes_v) - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes_v == 8: + elif const_expr(x_load_bytes_v == 8): parts.append(x_vec) else: parts.append(x_vec) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 2f015e9d..bd9c6e6d 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -373,18 +373,16 @@ def silu(x): # scale_x: fp16/bf16 path ignores (implicit scale=1.0); int4_bf16 also uses 1.0. x_load_bytes = 16 - if is_f16_or_bf16: - sx_rsrc = None - else: + sx_rsrc = -1 + if not is_f16_or_bf16: sx_rows = tokens_in * (c_topk if x_is_token_slot else fx.Index(1)) sx_nbytes_idx = sx_rows * fx.Index(4) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_idx ) # scale_w: fp16/bf16 (non-int4) path ignores; int4_bf16 needs dequant scale. - if not needs_scale_w: - sw_rsrc = None - else: + sw_rsrc = -1 + if needs_scale_w: sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False) @@ -487,7 +485,7 @@ def load_x(idx_i32, x_load_bytes_v): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. idx_i32 is in dword units; convert to element index for _buffer_load_vec. """ - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -499,9 +497,11 @@ def load_x(idx_i32, x_load_bytes_v): elem_bytes=elem_bytes, ) # For 8B/4B, load raw i32 dwords directly. - if x_load_bytes_v == 8: + if const_expr(x_load_bytes_v == 8): return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) - return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) + if const_expr(x_load_bytes_v == 4): + return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) + raise ValueError(f"Invalid x_load_bytes_v: {x_load_bytes_v}") def load_x_tile(base_k, x_load_bytes_v): """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" @@ -510,12 +510,11 @@ def load_x_tile(base_k, x_load_bytes_v): for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32, x_load_bytes_v) - print(f"x_load_bytes_v: {x_load_bytes_v}") - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes_v == 8: + if const_expr(x_load_bytes_v == 8): parts.append(x_vec) - else: + if const_expr(x_load_bytes_v == 4): parts.append(x_vec) return parts @@ -656,7 +655,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, @@ -671,7 +670,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes_v == 8: + elif const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, @@ -741,7 +740,7 @@ def compute_tile( # Optional: prefetch epilogue scales while we are about to run the last MFMA tile, # matching the preshuffle GEMM pattern of overlapping scale loads with MFMA. - epilogue_pf = None + epilogue_pf = [] if prefetch_epilogue: expert_off_pf = expert_off_idx sw_gate_pf = [] @@ -1144,6 +1143,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) # Do NOT rely on buffer OOB semantics for scale loads; explicitly mask. + sx0 = fx.Float32(1.0) if x_is_token_slot: # slot-major: slot*tokens + token ts2 = s2 * tokens_i32_v + t2 @@ -1570,18 +1570,16 @@ def moe_gemm2( arg_out, max_size=False, num_records_bytes=out_nbytes_idx ) # scale_x: fp16/bf16 path ignores (implicit scale=1.0); int4_bf16 also uses 1.0. - if is_f16_or_bf16: - sx_rsrc = None - else: + sx_rsrc = -1 + if not is_f16_or_bf16: # scale_x (A2 scale): [tokens*topk] f32 -> bytes = tokens*topk*4 sx_nbytes_idx = (tokens_in * c_topk) * fx.Index(4) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_idx ) # scale_w: fp16/bf16 (non-int4) path ignores; int4_bf16 needs dequant scale. - if not needs_scale_w: - sw_rsrc = None - else: + sw_rsrc = -1 + if needs_scale_w: # scale_w: [experts*model_dim] f32 (static shape in practice) sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) @@ -1675,7 +1673,7 @@ def x_tile_chunk_coord_i32(i: int): vec4_x = T.vec(4, x_elem) def load_x(idx_i32, x_load_bytes_v): - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -1686,7 +1684,7 @@ def load_x(idx_i32, x_load_bytes_v): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes_v == 8: + if const_expr(x_load_bytes_v == 8): return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) @@ -1721,11 +1719,11 @@ def load_x_tile(base_k, x_load_bytes_v): for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32, x_load_bytes_v) - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes_v == 8: + if const_expr(x_load_bytes_v == 8): parts.append(vector.bitcast(T.vec(2, T.i32), x_vec)) - else: + if const_expr(x_load_bytes_v == 4): parts.append(vector.bitcast(T.vec(1, T.i32), x_vec)) return parts @@ -1847,7 +1845,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, @@ -1862,7 +1860,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes_v == 8: + if const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, @@ -1876,7 +1874,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): lds_base=lds_base, vec_part_i32x2=vec_x_in_parts[i], ) - else: + if const_expr(x_load_bytes_v == 4): lds_store_4b_xor16( arith, vector, @@ -1907,6 +1905,8 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) return a0, a1 + epilogue_pf = [] + def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False, a0_prefetch=None): acc_list = list(acc_in) mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 @@ -1933,9 +1933,8 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False else buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=T.f32) ) # Also prefetch per-row routed/topk weights (sorted_weights) when enabled. - tw_pf = None + tw_pf = [] if doweight_stage2: - tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * fx.Index(4) ii_idx_list_pf = [fx.Index(ii) for ii in range(4)] for mi in range_constexpr(m_repeat): From acc4b9899fab2c60017a36806ae9b4dbd343724c Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Thu, 9 Apr 2026 11:11:29 +0000 Subject: [PATCH 11/31] [FLYDSL]: moe_blockscale_2stage.py variable definition --- kernels/moe_blockscale_2stage.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index eabe540d..2fa5aab5 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -337,10 +337,10 @@ def silu(x): # fp16 path ignores scales completely (implicit scale=1.0). x_load_bytes = 16 - if is_f16: - sx_rsrc = None - sw_rsrc = None - else: + + sx_rsrc = -1 + sw_rsrc = -1 + if not is_f16: # scale_x: [nblk_k_w1, tokens] f32 transposed -> total = nblk_k_w1 * tokens sx_nbytes_idx = arith.index(nblk_k_w1) * tokens_in * fx.Index(4) sx_rsrc = buffer_ops.create_buffer_resource( @@ -1515,10 +1515,9 @@ def moe_blockscale_gemm2( arg_out, max_size=False, num_records_bytes=arith.index_cast(T.i64, out_nbytes_idx) ) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16: - sx_rsrc = None - sw_rsrc = None - else: + sx_rsrc = -1 + sw_rsrc = -1 + if not is_f16: # scale_x (A2 scale): [nblk_k_w2, tokens*topk] f32 transposed -> total = nblk_k_w2 * tokens * topk sx_nbytes_idx = arith.index(nblk_k_w2) * (tokens_in * c_topk) * fx.Index(4) sx_rsrc = buffer_ops.create_buffer_resource( @@ -1569,6 +1568,7 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. + x_load_bytes = 0 if is_f16: if bytes_per_thread_x % 16 != 0: raise ValueError( @@ -1877,6 +1877,7 @@ def load_scales_s2(k_base): _sw_shared_n_s2 = (n_per_wave <= 128) s_w_vals = [] + s_w = arith.constant(1.0, type=T.f32) for ni in range_constexpr(num_acc_n): if ni == 0 or not _sw_shared_n_s2: sw_idx = _pre_n_block_s2[ni] * c_nblk_k_w2 + kb @@ -2306,6 +2307,7 @@ def write_row_to_lds( lds_out, ): # Blockscale: dequant already done in compute_tile_bs_s2. + tw = arith.constant(1.0, type=T.f32) if doweight_stage2: tw = buffer_ops.buffer_load( sorted_w_rsrc, row, vec_width=1, dtype=T.f32 From 0e034a5987367e010bedc1ce11c24ccdb34bacb3 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Thu, 9 Apr 2026 16:56:10 +0000 Subject: [PATCH 12/31] [FLYDSL]: kernels/moe_gemm_2stage.py refactor --- kernels/moe_gemm_2stage.py | 8 ++-- kernels/preshuffle_gemm.py | 79 +++++++++++++------------------------- 2 files changed, 30 insertions(+), 57 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index bd9c6e6d..1c675384 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -2611,6 +2611,7 @@ def moe_reduction_kernel( x_div = fx.logical_divide(x_tiled[None, tile_i32], fx.make_layout(VEC_WIDTH, 1)) x_thread = x_div[None, tid_i32] + mv_ok = True if use_mask: m_idx_i32 = fx.Int32(token_idx * c_topk + fx.Index(k)) mv = buffer_ops.buffer_load(mask_rsrc, m_idx_i32, vec_width=1, dtype=i8_type()) @@ -2628,10 +2629,7 @@ def moe_reduction_kernel( zero_e = vector.broadcast(vec_type_e, arith.constant(0.0, type=elem_type())) vec_e = mv_ok.select(vec_e, zero_e) - if elem_bits < 32: - vec_c = vec_e.extf(vec_type_c) - else: - vec_c = vec_e + vec_c = vec_e.extf(vec_type_c) if elem_bits < 32 else vec_e acc_vecs[si] = acc_vecs[si] + vec_c # ── Store results ── @@ -2646,6 +2644,8 @@ def moe_reduction_kernel( if elem_bits < 32: out_vec = out_vec.truncf(vec_type_e) + # Placeholder init to avoid unbound name before branch assignment. + dst = fx.make_layout(1, 1) if n_sub > 1: dst = y_inner[None, fx.Int32(si)] else: diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 611215a0..61ee3377 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -288,10 +288,7 @@ def _out_elem(): lds_tile_bytes = int(tile_m) * int(lds_stride_bytes) // a_elem_vec_pack lds_out_bytes = 2 * int(tile_m) * int(tile_n) if use_cshuffle_epilog else 0 - is_2stage = int(lds_stage) == 2 - is_num_tiles_odd = ((int(K) // int(tile_k)) % 2) == 1 - fp4_tilek128_mode = bool(is_fp4 and int(tile_k) == 128) - if is_2stage: + if int(lds_stage) == 2: assert lds_out_bytes % 2 == 0, "lds_out_bytes should be multiple of 2" buffer_size_bytes = max(lds_tile_bytes, lds_out_bytes // lds_stage) buffer_size_elems = buffer_size_bytes if elem_bytes == 1 else (buffer_size_bytes // 2) @@ -301,16 +298,12 @@ def _out_elem(): lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) allocator_ping.ptr = lds_ping_offset + buffer_size_elems * elem_bytes - lds_default_offset = lds_pong_offset - lds_default_elems = buffer_size_elems else: lds_total_bytes = max(lds_tile_bytes, lds_out_bytes) lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) lds_alloc_offset = allocator_pong._align(allocator_pong.ptr, 16) allocator_pong.ptr = lds_alloc_offset + lds_total_elems * elem_bytes - lds_default_offset = lds_alloc_offset - lds_default_elems = lds_total_elems # ── Kernel function ──────────────────────────────────────────────────── @flyc.kernel @@ -366,18 +359,7 @@ def kernel_gemm( base_ptr_pong = allocator_pong.get_base() base_ptr_ping = allocator_ping.get_base() - # Initialize with valid memrefs to avoid None flowing into rewritten branches. - lds_a_ptr = SmemPtr( - base_ptr_pong, lds_default_offset, _elem_type(), shape=(lds_default_elems,) - ) - lds_a_pong = lds_a_ptr.get() - lds_a_ping = lds_a_pong - lds_out = ( - SmemPtr(base_ptr_pong, lds_default_offset, _out_elem(), shape=(tile_m * tile_n,)).get() - if use_cshuffle_epilog - else None - ) - if is_2stage: + if lds_stage == 2: lds_a_pong = SmemPtr( base_ptr_pong, lds_pong_offset, _elem_type(), shape=(tile_m * tile_k,) ).get() @@ -392,7 +374,16 @@ def kernel_gemm( else: lds_out = None else: - pass + lds_a_ptr = SmemPtr( + base_ptr_pong, lds_alloc_offset, _elem_type(), shape=(lds_total_elems,) + ) + lds_a_pong = lds_a_ptr.get() + lds_a_ping = lds_a_pong + lds_out = ( + SmemPtr(base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(tile_m * tile_n,)).get() + if use_cshuffle_epilog + else None + ) # ---- Buffer resources (runtime byte sizes for OOB protection) ---- _a_nrec = arith.index_cast(T.i64, c_m * (K * elem_bytes // a_elem_vec_pack)) @@ -643,12 +634,6 @@ def dma_a_tile_to_lds(base_k_div4, lds_buffer): ), ) - lds_base = memref_dialect.extract_aligned_pointer_as_index(lds_buffer) - lds_ptr_base = buffer_ops.create_llvm_ptr( - arith.index_cast(T.i64, lds_base), address_space=3 - ) - lds_ptr = buffer_ops.get_element_ptr(lds_ptr_base, wave_offset) - for i in range_constexpr(num_a_async_loads): row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32_async(i) col_a_local_sw = swizzle_xor16(row_a_local, col_a_local_i32 * c4, k_blocks16) @@ -657,6 +642,8 @@ def dma_a_tile_to_lds(base_k_div4, lds_buffer): global_offset = arith.index_cast(T.i32, global_byte_idx) if i == 0: + lds_base = memref_dialect.extract_aligned_pointer_as_index(lds_buffer) + lds_ptr_base = buffer_ops.create_llvm_ptr(arith.index_cast(T.i64, lds_base), address_space=3) lds_ptr = buffer_ops.get_element_ptr(lds_ptr_base, wave_offset) else: lds_ptr = buffer_ops.get_element_ptr( @@ -759,17 +746,9 @@ def load_fp4_scales(base_k_scale_idx): def load_fp4_scale_chunk(base_k): return load_fp4_scales(base_k // fx.Index(_fp4_scale_chunk_k)) - def make_unit_scales(): - one_f32 = arith.constant(1.0, type=T.f32) - one_vec4 = vector.from_elements(T.f32x4, [one_f32, one_f32, one_f32, one_f32]) - return { - "s_b_vals": [one_f32 for _ in range_constexpr(num_acc_n)], - "s_a_vecs": [one_vec4 for _ in range_constexpr(m_repeat)], - } - # ── Compute tile (MFMA) ─────────────────────────────────────────── def compute_tile(accs_in, b_tile_in, lds_buffer, *, is_last_tile=False, a0_prefetch=None, fp4_scales=None, fp4_scale_half=0): - scales_pf = make_unit_scales() + scales_pf = {} if is_last_tile and (not is_f16_or_bf16): s_b_vals = [] for ni in range_constexpr(num_acc_n): @@ -838,8 +817,6 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for imxdl in range_constexpr(_fp4_pack_M): mi_idx = mi_p * _fp4_pack_M + imxdl curr_row_a_lds = row_a_lds + (mi_idx * 16) - a0 = arith.constant(-1, type=T.i64) - a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): a0, a1 = a0_prefetch else: @@ -871,8 +848,6 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) - a0 = arith.constant(-1, type=T.i64) - a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku0 == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -916,8 +891,6 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): col_base = col_offset_base_bytes + ki64 for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) - a0 = arith.constant(-1, type=T.i64) - a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -934,8 +907,12 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): vec1_out = T.vec(1, _out_elem()) def store_output(final_accs, scales): - s_b_vals = scales["s_b_vals"] - s_a_vecs = scales["s_a_vecs"] + if is_f16_or_bf16 or is_fp4: + s_b_vals = None + s_a_vecs = None + else: + s_b_vals = scales["s_b_vals"] + s_a_vecs = scales["s_a_vecs"] if use_cshuffle_epilog: if lds_out is None: @@ -954,7 +931,6 @@ def write_row_to_lds(*, mi, ii, row_in_tile, row, row_base_lds, val = vector.extract(acc, static_position=[ii], dynamic_position=[]) if is_int8: val = arith.sitofp(T.f32, val) - val_s = val if is_f16_or_bf16 or is_fp4: val_s = val elif _needs_per_token_scale: @@ -1003,7 +979,6 @@ def body_row(*, mi, ii, row_in_tile, row): val = vector.extract(acc, static_position=[ii], dynamic_position=[]) if is_int8: val = arith.sitofp(T.f32, val) - val_s = val if is_f16_or_bf16 or is_fp4: val_s = val elif _needs_per_token_scale: @@ -1265,7 +1240,7 @@ def _build_pingpong_body(k_iv, inner_state): return _pack_state(accs_in, _flatten_b_tile(b_tile_pong_new), a0_prefetch_pong_new, _sc_pong) - if is_2stage: + if lds_stage == 2: def prefetch_a0_pack(lds_buffer): return lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_buffer) @@ -1281,10 +1256,8 @@ def prefetch_a0_pack(lds_buffer): fp4_scales0 = load_fp4_scale_chunk(fx.Index(0)) if is_fp4 else None num_tiles = K // tile_k - final_accs = accs - scales = make_unit_scales() - if fp4_tilek128_mode: - if is_num_tiles_odd: + if _fp4_tilek128: + if (num_tiles % 2) == 1: c_k_main = K - tile_k init_state = _pack_state(accs, _flatten_b_tile(b_tile0), a0_prefetch_pong, fp4_scales0) @@ -1328,7 +1301,7 @@ def prefetch_a0_pack(lds_buffer): is_last_tile=not is_fp4, a0_prefetch=a0_prefetch_ping, fp4_scales=fp4_scales_ep, fp4_scale_half=1, ) - elif is_num_tiles_odd: + elif (num_tiles % 2) == 1: c_k_main = K - tile_k init_state = _pack_state(accs, _flatten_b_tile(b_tile0), a0_prefetch_pong, fp4_scales0) @@ -1480,4 +1453,4 @@ def compile_preshuffle_gemm_w4( return inner -__all__ = ["compile_preshuffle_gemm_a8", "compile_preshuffle_gemm_w4"] +__all__ = ["compile_preshuffle_gemm_a8", "compile_preshuffle_gemm_w4"] \ No newline at end of file From 30dc632d44734a97f7e3d0aa7a4332f7d4f6ff30 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 04:59:46 +0000 Subject: [PATCH 13/31] [FLYDSL]: kernels/preshuffle_gemm.py refactor --- kernels/preshuffle_gemm.py | 596 +++++++++++++++++++++++++++++-------- 1 file changed, 471 insertions(+), 125 deletions(-) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 61ee3377..c2df5998 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -258,7 +258,10 @@ def _elem_type(): return T.f16 if is_bf16: return T.bf16 - if is_fp4: + def load_fp4_scale_chunk(_base_k): + raise RuntimeError("load_fp4_scale_chunk called when is_fp4=False") + + if const_expr(is_fp4): return T.i8 return T.i8 if is_int8 else T.f8 @@ -288,6 +291,9 @@ def _out_elem(): lds_tile_bytes = int(tile_m) * int(lds_stride_bytes) // a_elem_vec_pack lds_out_bytes = 2 * int(tile_m) * int(tile_n) if use_cshuffle_epilog else 0 + lds_pong_offset = 0 + lds_ping_offset = 0 + lds_alloc_offset = 0 if int(lds_stage) == 2: assert lds_out_bytes % 2 == 0, "lds_out_bytes should be multiple of 2" buffer_size_bytes = max(lds_tile_bytes, lds_out_bytes // lds_stage) @@ -327,6 +333,7 @@ def kernel_gemm( ) # ---- Layouts ---- + _k_div4_factor = (K * elem_bytes) // 4 // a_elem_vec_pack kpack_bytes = 8 if is_int4 else 16 @@ -359,31 +366,39 @@ def kernel_gemm( base_ptr_pong = allocator_pong.get_base() base_ptr_ping = allocator_ping.get_base() - if lds_stage == 2: - lds_a_pong = SmemPtr( + lds_a_pong_ptr = SmemPtr(base_ptr_pong, lds_alloc_offset, _elem_type(), shape=(1,)) + lds_a_ping_ptr = lds_a_pong_ptr + lds_out_ptr = SmemPtr(base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(1,)) + + if const_expr(lds_stage == 2): + lds_a_pong_ptr = SmemPtr( base_ptr_pong, lds_pong_offset, _elem_type(), shape=(tile_m * tile_k,) - ).get() - lds_a_ping = SmemPtr( + ) + lds_a_ping_ptr = SmemPtr( base_ptr_ping, lds_ping_offset, _elem_type(), shape=(tile_m * tile_k,) - ).get() + ) if use_cshuffle_epilog: - lds_out = SmemPtr( + lds_out_ptr = SmemPtr( base_ptr_pong, lds_pong_offset, _out_elem(), shape=(tile_m * tile_n,) - ).get() + ) else: - lds_out = None + lds_out_ptr = SmemPtr(base_ptr_pong, lds_pong_offset, _out_elem(), shape=(1,)) else: - lds_a_ptr = SmemPtr( + lds_a_pong_ptr = SmemPtr( base_ptr_pong, lds_alloc_offset, _elem_type(), shape=(lds_total_elems,) ) - lds_a_pong = lds_a_ptr.get() - lds_a_ping = lds_a_pong - lds_out = ( - SmemPtr(base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(tile_m * tile_n,)).get() - if use_cshuffle_epilog - else None - ) + lds_a_ping_ptr = lds_a_pong_ptr + if use_cshuffle_epilog: + lds_out_ptr = SmemPtr( + base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(tile_m * tile_n,) + ) + else: + lds_out_ptr = SmemPtr(base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(1,)) + + lds_a_pong = lds_a_pong_ptr.get() + lds_a_ping = lds_a_ping_ptr.get() + lds_out = lds_out_ptr.get() # ---- Buffer resources (runtime byte sizes for OOB protection) ---- _a_nrec = arith.index_cast(T.i64, c_m * (K * elem_bytes // a_elem_vec_pack)) @@ -476,11 +491,11 @@ def _extract_b_packs(b16): b_i64x2 = vector.bitcast(T.i64x2, b16) b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) - if not is_f16_or_bf16: + if const_expr(not is_f16_or_bf16): return b0_i64, b1_i64 b0_v1 = vector.from_elements(T.vec(1, T.i64), [b0_i64]) b1_v1 = vector.from_elements(T.vec(1, T.i64), [b1_i64]) - if is_f16: + if const_expr(is_f16): return vector.bitcast(T.f16x4, b0_v1), vector.bitcast(T.f16x4, b1_v1) return vector.bitcast(T.i16x4, b0_v1), vector.bitcast(T.i16x4, b1_v1) @@ -494,7 +509,7 @@ def _load_b_single(k_dword_offset, ni): return _extract_b_packs(b16) def load_b_packs_k64(base_k, ku: int, ni: int): - if is_int4: + if const_expr(is_int4): ki0 = (ku * 2) + 0 ki1 = (ku * 2) + 1 return load_b_pack(base_k, ki0, ni), load_b_pack(base_k, ki1, ni) @@ -511,7 +526,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): return _extract_b_packs(b16) def load_b_tile(base_k): - if not is_int4 and not is_f16_or_bf16: + if const_expr((not is_int4) and (not is_f16_or_bf16)): base_k_bytes = base_k * elem_bytes k0_base = base_k_bytes // c64_b k_dwords = [] @@ -558,12 +573,12 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): a0_i64 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) a1_i64 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) - if not is_f16_or_bf16: + if const_expr(not is_f16_or_bf16): return a0_i64, a1_i64 a0_v1 = vector.from_elements(T.vec(1, T.i64), [a0_i64]) a1_v1 = vector.from_elements(T.vec(1, T.i64), [a1_i64]) - if is_f16: + if const_expr(is_f16): return vector.bitcast(T.f16x4, a0_v1), vector.bitcast(T.f16x4, a1_v1) return vector.bitcast(T.i16x4, a0_v1), vector.bitcast(T.i16x4, a1_v1) @@ -623,41 +638,54 @@ def a_tile_chunk_coord_i32_async(i: int): chunk_i32=a_async_load_dword, ) - def dma_a_tile_to_lds(base_k_div4, lds_buffer): + def dma_a_tile_to_lds( + base_k_div4, + lds_buffer, + *, + wave_id_v, + wave_size_v, + dma_bytes_v, + num_a_async_loads_v, + a_tile_chunk_coord_i32_async_fn, + c4_v, + k_blocks16_v, + bx_m_v, + k_bytes_factor_v, + total_threads_v, + a_rsrc_v, + ): from flydsl._mlir.dialects import memref as memref_dialect - dma_bytes = a_async_load_bytes wave_offset = rocdl.readfirstlane( T.i64, arith.index_cast( - T.i64, wave_id * arith.constant(wave_size * dma_bytes, index=True) + T.i64, wave_id_v * arith.constant(wave_size_v * dma_bytes_v, index=True) ), ) - - for i in range_constexpr(num_a_async_loads): - row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32_async(i) - col_a_local_sw = swizzle_xor16(row_a_local, col_a_local_i32 * c4, k_blocks16) - row_a_global = bx_m + row_a_local - global_byte_idx = row_a_global * k_bytes_factor + (base_k_div4 * c4 + col_a_local_sw) + lds_base = memref_dialect.extract_aligned_pointer_as_index(lds_buffer) + lds_ptr_base = buffer_ops.create_llvm_ptr(arith.index_cast(T.i64, lds_base), address_space=3) + lds_ptr = buffer_ops.get_element_ptr(lds_ptr_base, wave_offset) + + for i in range_constexpr(num_a_async_loads_v): + row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32_async_fn(i) + col_a_local_sw = swizzle_xor16(row_a_local, col_a_local_i32 * c4_v, k_blocks16_v) + row_a_global = bx_m_v + row_a_local + global_byte_idx = row_a_global * k_bytes_factor_v + (base_k_div4 * c4_v + col_a_local_sw) global_offset = arith.index_cast(T.i32, global_byte_idx) - if i == 0: - lds_base = memref_dialect.extract_aligned_pointer_as_index(lds_buffer) - lds_ptr_base = buffer_ops.create_llvm_ptr(arith.index_cast(T.i64, lds_base), address_space=3) - lds_ptr = buffer_ops.get_element_ptr(lds_ptr_base, wave_offset) - else: + if const_expr(i > 0): lds_ptr = buffer_ops.get_element_ptr( lds_ptr, - static_byte_offset=total_threads * dma_bytes, + static_byte_offset=total_threads_v * dma_bytes_v, ) - size_i32 = arith.constant(dma_bytes, type=T.i32) + size_i32 = arith.constant(dma_bytes_v, type=T.i32) soffset = arith.constant(0, type=T.i32) offset_imm = arith.constant(0, type=T.i32) aux = arith.constant(1, type=T.i32) rocdl.raw_ptr_buffer_load_lds( - a_rsrc, + a_rsrc_v, lds_ptr, size_i32, global_offset, @@ -666,9 +694,23 @@ def dma_a_tile_to_lds(base_k_div4, lds_buffer): aux, ) - def prefetch_a_to_lds(base_k, lds_buffer): - base_k_div4 = base_k // 4 // a_elem_vec_pack - dma_a_tile_to_lds(base_k_div4, lds_buffer) + def prefetch_a_to_lds(base_k, lds_buffer, *, a_elem_vec_pack_v, dma_a_tile_to_lds_fn): + base_k_div4 = base_k // 4 // a_elem_vec_pack_v + dma_a_tile_to_lds_fn( + base_k_div4, + lds_buffer, + wave_id_v=wave_id, + wave_size_v=wave_size, + dma_bytes_v=a_async_load_bytes, + num_a_async_loads_v=num_a_async_loads, + a_tile_chunk_coord_i32_async_fn=a_tile_chunk_coord_i32_async, + c4_v=c4, + k_blocks16_v=k_blocks16, + bx_m_v=bx_m, + k_bytes_factor_v=k_bytes_factor, + total_threads_v=total_threads, + a_rsrc_v=a_rsrc, + ) def prefetch_a_tile(base_k): base_k_bytes = base_k * elem_bytes // a_elem_vec_pack @@ -686,7 +728,11 @@ def prefetch_ab_tile(base_k): # ── FP4 scale pre-fetch (outside compute_tile for latency hiding) ── _fp4_tilek128 = False - if is_fp4: + + def load_fp4_scale_chunk(_base_k): + raise RuntimeError("load_fp4_scale_chunk called when is_fp4=False") + + if const_expr(is_fp4): _fp4_pack_M_outer = 2 _fp4_pack_N_outer = 2 _fp4_pack_K_outer = 2 @@ -869,11 +915,16 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): return current_accs_list, scales_pf mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 - if is_int8: + + def _mfma_fn_placeholder(_res_ty, _ops): + raise RuntimeError("mfma_fn is not selected for current dtype path") + + mfma_fn = _mfma_fn_placeholder + if const_expr(is_int8): mfma_fn = mfma_i32_k32 - elif is_f16: + elif const_expr(is_f16): mfma_fn = rocdl.mfma_f32_16x16x16f16 - elif is_bf16: + elif const_expr(is_bf16): mfma_fn = rocdl.mfma_f32_16x16x16bf16_1k else: mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 @@ -891,7 +942,7 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): col_base = col_offset_base_bytes + ki64 for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): a0, a1 = a0_prefetch else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer) @@ -907,10 +958,9 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): vec1_out = T.vec(1, _out_elem()) def store_output(final_accs, scales): - if is_f16_or_bf16 or is_fp4: - s_b_vals = None - s_a_vecs = None - else: + s_b_vals = [] + s_a_vecs = [] + if const_expr(not (is_f16_or_bf16 or is_fp4)): s_b_vals = scales["s_b_vals"] s_a_vecs = scales["s_a_vecs"] @@ -921,7 +971,8 @@ def store_output(final_accs, scales): def write_row_to_lds(*, mi, ii, row_in_tile, row, row_base_lds, col_base_local, num_acc_n, lds_out): - if _needs_per_token_scale: + s_a = arith.constant(1.0, type=T.f32) + if const_expr(_needs_per_token_scale): s_a_vec4 = s_a_vecs[mi] s_a = vector.extract(s_a_vec4, static_position=[ii], dynamic_position=[]) for ni in range_constexpr(num_acc_n): @@ -929,11 +980,11 @@ def write_row_to_lds(*, mi, ii, row_in_tile, row, row_base_lds, acc_idx = mi * num_acc_n + ni acc = final_accs[acc_idx] val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - if is_int8: + if const_expr(is_int8): val = arith.sitofp(T.f32, val) - if is_f16_or_bf16 or is_fp4: + if const_expr(is_f16_or_bf16 or is_fp4): val_s = val - elif _needs_per_token_scale: + elif const_expr(_needs_per_token_scale): val_s = (val * s_a) * s_b_vals[ni] else: val_s = val @@ -968,7 +1019,8 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): return def body_row(*, mi, ii, row_in_tile, row): - if _needs_per_token_scale: + s_a = arith.constant(1.0, type=T.f32) + if const_expr(_needs_per_token_scale): s_a_vec4 = s_a_vecs[mi] s_a = vector.extract(s_a_vec4, static_position=[ii], dynamic_position=[]) col_base = by_n + n_tile_base + lane_mod_16 @@ -977,11 +1029,11 @@ def body_row(*, mi, ii, row_in_tile, row): acc_idx = mi * num_acc_n + ni acc = final_accs[acc_idx] val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - if is_int8: + if const_expr(is_int8): val = arith.sitofp(T.f32, val) - if is_f16_or_bf16 or is_fp4: + if const_expr(is_f16_or_bf16 or is_fp4): val_s = val - elif _needs_per_token_scale: + elif const_expr(_needs_per_token_scale): val_s = (val * s_a) * s_b_vals[ni] else: val_s = val @@ -1137,38 +1189,85 @@ def _unflatten_b_tile(flat): n_accs = num_acc_n * m_repeat n_btile = k_unroll * 2 * num_acc_n n_a0pf = 2 + n_fp4_asc = 0 + n_fp4_bsc = 0 if is_fp4: n_fp4_asc = _k_unroll_packed_outer * _m_repeat_packed_outer n_fp4_bsc = _k_unroll_packed_outer * _num_acc_n_packed_outer - def _pack_state(accs_l, bt_flat, a0pf, fp4_scales=None): + def _pack_state(accs_l, bt_flat, a0pf, fp4_scales=None, *, is_fp4_v): state = list(accs_l) + list(bt_flat) + [a0pf[0], a0pf[1]] - if is_fp4: + if is_fp4_v: a_scales, b_scales = fp4_scales state.extend(a_scales) state.extend(b_scales) return state - def _unpack_state(vals): - accs_l = list(vals[:n_accs]) - bt_flat = list(vals[n_accs:n_accs + n_btile]) - a0pf = (vals[n_accs + n_btile], vals[n_accs + n_btile + 1]) - if not is_fp4: + def _unpack_state(vals, *, n_accs_v, n_btile_v, n_a0pf_v, is_fp4_v, n_fp4_asc_v, n_fp4_bsc_v): + accs_l = list(vals[:n_accs_v]) + bt_flat = list(vals[n_accs_v:n_accs_v + n_btile_v]) + a0pf = (vals[n_accs_v + n_btile_v], vals[n_accs_v + n_btile_v + 1]) + if not is_fp4_v: return accs_l, bt_flat, a0pf, None - sc_base = n_accs + n_btile + n_a0pf - a_scales = list(vals[sc_base:sc_base + n_fp4_asc]) - b_scales = list(vals[sc_base + n_fp4_asc:sc_base + n_fp4_asc + n_fp4_bsc]) + sc_base = n_accs_v + n_btile_v + n_a0pf_v + a_scales = list(vals[sc_base:sc_base + n_fp4_asc_v]) + b_scales = list(vals[sc_base + n_fp4_asc_v:sc_base + n_fp4_asc_v + n_fp4_bsc_v]) return accs_l, bt_flat, a0pf, (a_scales, b_scales) - def _build_pingpong_body(k_iv, inner_state): - accs_in, bt_flat_in, a0pf_in, fp4_scales_pong_in = _unpack_state(inner_state) + def _build_pingpong_body( + k_iv, + inner_state, + *, + _unpack_state, + _unflatten_b_tile, + _fp4_tilek128, + tile_k, + use_async_copy, + prefetch_a_to_lds, + a_elem_vec_pack, + dma_a_tile_to_lds, + prefetch_a_tile, + prefetch_b_tile, + compute_tile, + lds_a_pong, + lds_a_ping, + store_a_tile_to_lds, + hot_loop_scheduler, + num_b_loads, + gpu, + prefetch_a0_pack, + load_fp4_scale_chunk, + is_fp4, + rocdl, + _pack_state, + _flatten_b_tile, + lds_load_packs_k64, + row_a_lds, + col_offset_base_bytes, + n_accs, + n_btile, + n_a0pf, + n_fp4_asc, + n_fp4_bsc, + ): + accs_in, bt_flat_in, a0pf_in, fp4_scales_pong_in = _unpack_state( + inner_state, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_in = _unflatten_b_tile(bt_flat_in) if _fp4_tilek128: next_k1 = k_iv + tile_k - if use_async_copy: - prefetch_a_to_lds(next_k1, lds_a_ping) + if const_expr(use_async_copy): + prefetch_a_to_lds( + next_k1, lds_a_ping, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_tile_ping = prefetch_a_tile(next_k1) b_tile_ping = prefetch_b_tile(next_k1) @@ -1176,18 +1275,25 @@ def _build_pingpong_body(k_iv, inner_state): accs_in, b_tile_pong_in, lds_a_pong, a0_prefetch=a0pf_in, fp4_scales=fp4_scales_pong_in, fp4_scale_half=0, ) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_tile_ping, lds_a_ping) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) + a0_prefetch_ping = prefetch_a0_pack( + lds_a_ping, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) next_k2 = k_iv + (tile_k * 2) _sc_ping = load_fp4_scale_chunk(next_k2) if is_fp4 else None rocdl.sched_barrier(0) - if use_async_copy: - prefetch_a_to_lds(next_k2, lds_a_pong) + if const_expr(use_async_copy): + prefetch_a_to_lds( + next_k2, lds_a_pong, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_tile_pong = prefetch_a_tile(next_k2) b_tile_pong_new = prefetch_b_tile(next_k2) @@ -1195,76 +1301,161 @@ def _build_pingpong_body(k_iv, inner_state): accs_in, b_tile_ping, lds_a_ping, a0_prefetch=a0_prefetch_ping, fp4_scales=fp4_scales_pong_in, fp4_scale_half=1, ) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_tile_pong, lds_a_pong) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_pong_new = prefetch_a0_pack(lds_a_pong) + a0_prefetch_pong_new = prefetch_a0_pack( + lds_a_pong, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) - return _pack_state(accs_in, _flatten_b_tile(b_tile_pong_new), - a0_prefetch_pong_new, _sc_ping) + return _pack_state( + accs_in, + _flatten_b_tile(b_tile_pong_new), + a0_prefetch_pong_new, + _sc_ping, + is_fp4_v=is_fp4, + ) next_k1 = k_iv + tile_k - if use_async_copy: - prefetch_a_to_lds(next_k1, lds_a_ping) + if const_expr(use_async_copy): + prefetch_a_to_lds( + next_k1, lds_a_ping, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_tile = prefetch_a_tile(next_k1) _sc_ping = load_fp4_scale_chunk(k_iv + fx.Index(tile_k)) if is_fp4 else None b_tile_ping = prefetch_b_tile(next_k1) accs_in, _ = compute_tile(accs_in, b_tile_pong_in, lds_a_pong, a0_prefetch=a0pf_in, fp4_scales=fp4_scales_pong_in) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_tile, lds_a_ping) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) + a0_prefetch_ping = prefetch_a0_pack( + lds_a_ping, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) next_k2 = k_iv + (tile_k * 2) - if use_async_copy: - prefetch_a_to_lds(next_k2, lds_a_pong) + if const_expr(use_async_copy): + prefetch_a_to_lds( + next_k2, lds_a_pong, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_tile = prefetch_a_tile(next_k2) _sc_pong = load_fp4_scale_chunk(k_iv + (tile_k * 2)) if is_fp4 else None b_tile_pong_new = prefetch_b_tile(next_k2) accs_in, _ = compute_tile(accs_in, b_tile_ping, lds_a_ping, a0_prefetch=a0_prefetch_ping, fp4_scales=_sc_ping) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_tile, lds_a_pong) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_pong_new = prefetch_a0_pack(lds_a_pong) + a0_prefetch_pong_new = prefetch_a0_pack( + lds_a_pong, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) - return _pack_state(accs_in, _flatten_b_tile(b_tile_pong_new), - a0_prefetch_pong_new, _sc_pong) + return _pack_state( + accs_in, + _flatten_b_tile(b_tile_pong_new), + a0_prefetch_pong_new, + _sc_pong, + is_fp4_v=is_fp4, + ) - if lds_stage == 2: - def prefetch_a0_pack(lds_buffer): - return lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_buffer) + if const_expr(lds_stage == 2): + def prefetch_a0_pack(lds_buffer, *, lds_load_packs_k64_fn, row_a_lds_v, col_offset_base_bytes_v): + return lds_load_packs_k64_fn(row_a_lds_v, col_offset_base_bytes_v, lds_buffer) k0 = fx.Index(0) b_tile0 = prefetch_b_tile(k0) - if use_async_copy: - prefetch_a_to_lds(k0, lds_a_pong) + if const_expr(use_async_copy): + prefetch_a_to_lds( + k0, lds_a_pong, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: store_a_tile_to_lds(prefetch_a_tile(k0), lds_a_pong) gpu.barrier() accs = [acc_init] * n_accs - a0_prefetch_pong = prefetch_a0_pack(lds_a_pong) + a0_prefetch_pong = prefetch_a0_pack( + lds_a_pong, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) fp4_scales0 = load_fp4_scale_chunk(fx.Index(0)) if is_fp4 else None + final_accs = 1 + scales = 1 num_tiles = K // tile_k - if _fp4_tilek128: - if (num_tiles % 2) == 1: + if const_expr(_fp4_tilek128): + if const_expr((num_tiles % 2) == 1): c_k_main = K - tile_k - init_state = _pack_state(accs, _flatten_b_tile(b_tile0), - a0_prefetch_pong, fp4_scales0) + init_state = _pack_state( + accs, + _flatten_b_tile(b_tile0), + a0_prefetch_pong, + fp4_scales0, + is_fp4_v=is_fp4, + ) results = init_state for iv, inner in range(0, c_k_main, tile_k * 2, init=init_state): - results = yield _build_pingpong_body(iv, inner) - accs, bt_flat, a0pf, fp4_scales_final = _unpack_state(results) + results = yield _build_pingpong_body( + iv, + inner, + _unpack_state=_unpack_state, + _unflatten_b_tile=_unflatten_b_tile, + _fp4_tilek128=_fp4_tilek128, + tile_k=tile_k, + use_async_copy=use_async_copy, + prefetch_a_to_lds=prefetch_a_to_lds, + a_elem_vec_pack=a_elem_vec_pack, + dma_a_tile_to_lds=dma_a_tile_to_lds, + prefetch_a_tile=prefetch_a_tile, + prefetch_b_tile=prefetch_b_tile, + compute_tile=compute_tile, + lds_a_pong=lds_a_pong, + lds_a_ping=lds_a_ping, + store_a_tile_to_lds=store_a_tile_to_lds, + hot_loop_scheduler=hot_loop_scheduler, + num_b_loads=num_b_loads, + gpu=gpu, + prefetch_a0_pack=prefetch_a0_pack, + load_fp4_scale_chunk=load_fp4_scale_chunk, + is_fp4=is_fp4, + rocdl=rocdl, + _pack_state=_pack_state, + _flatten_b_tile=_flatten_b_tile, + lds_load_packs_k64=lds_load_packs_k64, + row_a_lds=row_a_lds, + col_offset_base_bytes=col_offset_base_bytes, + n_accs=n_accs, + n_btile=n_btile, + n_a0pf=n_a0pf, + n_fp4_asc=n_fp4_asc, + n_fp4_bsc=n_fp4_bsc, + ) + accs, bt_flat, a0pf, fp4_scales_final = _unpack_state( + results, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_final = _unflatten_b_tile(bt_flat) final_accs, scales = compute_tile( accs, b_tile_pong_final, lds_a_pong, @@ -1273,42 +1464,143 @@ def prefetch_a0_pack(lds_buffer): ) else: c_k_stop = K - (tile_k * 3) - init_state = _pack_state(accs, _flatten_b_tile(b_tile0), - a0_prefetch_pong, fp4_scales0) + init_state = _pack_state( + accs, + _flatten_b_tile(b_tile0), + a0_prefetch_pong, + fp4_scales0, + is_fp4_v=is_fp4, + ) results = init_state for iv, inner in range(0, c_k_stop, tile_k * 2, init=init_state): - results = yield _build_pingpong_body(iv, inner) - accs, bt_flat, a0pf, fp4_scales_ep = _unpack_state(results) + results = yield _build_pingpong_body( + iv, + inner, + _unpack_state=_unpack_state, + _unflatten_b_tile=_unflatten_b_tile, + _fp4_tilek128=_fp4_tilek128, + tile_k=tile_k, + use_async_copy=use_async_copy, + prefetch_a_to_lds=prefetch_a_to_lds, + a_elem_vec_pack=a_elem_vec_pack, + dma_a_tile_to_lds=dma_a_tile_to_lds, + prefetch_a_tile=prefetch_a_tile, + prefetch_b_tile=prefetch_b_tile, + compute_tile=compute_tile, + lds_a_pong=lds_a_pong, + lds_a_ping=lds_a_ping, + store_a_tile_to_lds=store_a_tile_to_lds, + hot_loop_scheduler=hot_loop_scheduler, + num_b_loads=num_b_loads, + gpu=gpu, + prefetch_a0_pack=prefetch_a0_pack, + load_fp4_scale_chunk=load_fp4_scale_chunk, + is_fp4=is_fp4, + rocdl=rocdl, + _pack_state=_pack_state, + _flatten_b_tile=_flatten_b_tile, + lds_load_packs_k64=lds_load_packs_k64, + row_a_lds=row_a_lds, + col_offset_base_bytes=col_offset_base_bytes, + n_accs=n_accs, + n_btile=n_btile, + n_a0pf=n_a0pf, + n_fp4_asc=n_fp4_asc, + n_fp4_bsc=n_fp4_bsc, + ) + accs, bt_flat, a0pf, fp4_scales_ep = _unpack_state( + results, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_ep = _unflatten_b_tile(bt_flat) last_k = arith.index(K - tile_k) b_tile_ping = prefetch_b_tile(last_k) - if use_async_copy: - prefetch_a_to_lds(last_k, lds_a_ping) + if const_expr(use_async_copy): + prefetch_a_to_lds( + last_k, lds_a_ping, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_regs_ping = prefetch_a_tile(last_k) accs, _ = compute_tile( accs, b_tile_pong_ep, lds_a_pong, a0_prefetch=a0pf, fp4_scales=fp4_scales_ep, fp4_scale_half=0, ) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_regs_ping, lds_a_ping) rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) + a0_prefetch_ping = prefetch_a0_pack( + lds_a_ping, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) final_accs, scales = compute_tile( accs, b_tile_ping, lds_a_ping, is_last_tile=not is_fp4, a0_prefetch=a0_prefetch_ping, fp4_scales=fp4_scales_ep, fp4_scale_half=1, ) - elif (num_tiles % 2) == 1: + elif const_expr((num_tiles % 2) == 1): c_k_main = K - tile_k - init_state = _pack_state(accs, _flatten_b_tile(b_tile0), - a0_prefetch_pong, fp4_scales0) + init_state = _pack_state( + accs, + _flatten_b_tile(b_tile0), + a0_prefetch_pong, + fp4_scales0, + is_fp4_v=is_fp4, + ) results = init_state for iv, inner in range(0, c_k_main, tile_k * 2, init=init_state): - results = yield _build_pingpong_body(iv, inner) - accs, bt_flat, a0pf, fp4_scales_final = _unpack_state(results) + results = yield _build_pingpong_body( + iv, + inner, + _unpack_state=_unpack_state, + _unflatten_b_tile=_unflatten_b_tile, + _fp4_tilek128=_fp4_tilek128, + tile_k=tile_k, + use_async_copy=use_async_copy, + prefetch_a_to_lds=prefetch_a_to_lds, + a_elem_vec_pack=a_elem_vec_pack, + dma_a_tile_to_lds=dma_a_tile_to_lds, + prefetch_a_tile=prefetch_a_tile, + prefetch_b_tile=prefetch_b_tile, + compute_tile=compute_tile, + lds_a_pong=lds_a_pong, + lds_a_ping=lds_a_ping, + store_a_tile_to_lds=store_a_tile_to_lds, + hot_loop_scheduler=hot_loop_scheduler, + num_b_loads=num_b_loads, + gpu=gpu, + prefetch_a0_pack=prefetch_a0_pack, + load_fp4_scale_chunk=load_fp4_scale_chunk, + is_fp4=is_fp4, + rocdl=rocdl, + _pack_state=_pack_state, + _flatten_b_tile=_flatten_b_tile, + lds_load_packs_k64=lds_load_packs_k64, + row_a_lds=row_a_lds, + col_offset_base_bytes=col_offset_base_bytes, + n_accs=n_accs, + n_btile=n_btile, + n_a0pf=n_a0pf, + n_fp4_asc=n_fp4_asc, + n_fp4_bsc=n_fp4_bsc, + ) + accs, bt_flat, a0pf, fp4_scales_final = _unpack_state( + results, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_final = _unflatten_b_tile(bt_flat) final_accs, scales = compute_tile( accs, b_tile_pong_final, lds_a_pong, @@ -1316,29 +1608,83 @@ def prefetch_a0_pack(lds_buffer): ) else: c_k_stop = K - (tile_k * 3) - init_state = _pack_state(accs, _flatten_b_tile(b_tile0), - a0_prefetch_pong, fp4_scales0) + init_state = _pack_state( + accs, + _flatten_b_tile(b_tile0), + a0_prefetch_pong, + fp4_scales0, + is_fp4_v=is_fp4, + ) results = init_state for iv, inner in range(0, c_k_stop, tile_k * 2, init=init_state): - results = yield _build_pingpong_body(iv, inner) - accs, bt_flat, a0pf, fp4_scales_ep = _unpack_state(results) + results = yield _build_pingpong_body( + iv, + inner, + _unpack_state=_unpack_state, + _unflatten_b_tile=_unflatten_b_tile, + _fp4_tilek128=_fp4_tilek128, + tile_k=tile_k, + use_async_copy=use_async_copy, + prefetch_a_to_lds=prefetch_a_to_lds, + a_elem_vec_pack=a_elem_vec_pack, + dma_a_tile_to_lds=dma_a_tile_to_lds, + prefetch_a_tile=prefetch_a_tile, + prefetch_b_tile=prefetch_b_tile, + compute_tile=compute_tile, + lds_a_pong=lds_a_pong, + lds_a_ping=lds_a_ping, + store_a_tile_to_lds=store_a_tile_to_lds, + hot_loop_scheduler=hot_loop_scheduler, + num_b_loads=num_b_loads, + gpu=gpu, + prefetch_a0_pack=prefetch_a0_pack, + load_fp4_scale_chunk=load_fp4_scale_chunk, + is_fp4=is_fp4, + rocdl=rocdl, + _pack_state=_pack_state, + _flatten_b_tile=_flatten_b_tile, + lds_load_packs_k64=lds_load_packs_k64, + row_a_lds=row_a_lds, + col_offset_base_bytes=col_offset_base_bytes, + n_accs=n_accs, + n_btile=n_btile, + n_a0pf=n_a0pf, + n_fp4_asc=n_fp4_asc, + n_fp4_bsc=n_fp4_bsc, + ) + accs, bt_flat, a0pf, fp4_scales_ep = _unpack_state( + results, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_ep = _unflatten_b_tile(bt_flat) last_k = arith.index(K - tile_k) b_tile_ping = prefetch_b_tile(last_k) - if use_async_copy: - prefetch_a_to_lds(last_k, lds_a_ping) + if const_expr(use_async_copy): + prefetch_a_to_lds( + last_k, lds_a_ping, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_regs_ping = prefetch_a_tile(last_k) _sc_last = load_fp4_scale_chunk(last_k) if is_fp4 else None accs, _ = compute_tile(accs, b_tile_pong_ep, lds_a_pong, a0_prefetch=a0pf, fp4_scales=fp4_scales_ep) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_regs_ping, lds_a_ping) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) + a0_prefetch_ping = prefetch_a0_pack( + lds_a_ping, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) final_accs, scales = compute_tile( accs, b_tile_ping, lds_a_ping, is_last_tile=not is_fp4, a0_prefetch=a0_prefetch_ping, fp4_scales=_sc_last, From d2dcc3fbef645ebefcd0f0dcb069e1127c892a90 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 06:38:59 +0000 Subject: [PATCH 14/31] [FLYDSL]: if const_expr(ast.Name) --- kernels/hgemm_splitk.py | 14 +++++++------- kernels/layernorm_kernel.py | 2 +- kernels/moe_gemm_2stage.py | 16 ++++++++-------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/kernels/hgemm_splitk.py b/kernels/hgemm_splitk.py index be8c80d0..c8952eb8 100644 --- a/kernels/hgemm_splitk.py +++ b/kernels/hgemm_splitk.py @@ -234,11 +234,11 @@ def hgemm_kernel( bs_ = STensor(smem_b_ptr, dtype_, shape=(STAGES, BLOCK_N, BLOCK_K)) smem_c_ptr = SmemPtr(base_ptr, smem_a_offset, dtype_, shape=(BLOCK_M * BLOCK_N,)) cs_ = STensor(smem_c_ptr, dtype_, shape=(BLOCK_M, BLOCK_N)) - if B_PRE_SHUFFLE: + if const_expr(B_PRE_SHUFFLE): # origin: n // WARP_ATOM_N, WARP_ATOM_N, k // WARP_ATOM_K, WARP_ATOM_K // LDG_VEC_SIZE, LDG_VEC_SIZE SHUFFLED_B_ = GTensor(B, dtype=dtype_, shape=( n // WARP_ATOM_N, k // WARP_ATOM_K, WARP_ATOM_K // LDG_VEC_SIZE, WARP_ATOM_N, LDG_VEC_SIZE)) - if IS_SPLIT_K: + if const_expr(IS_SPLIT_K): COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,)) tid = fx.Int32(fx.thread_idx.x) @@ -504,7 +504,7 @@ def ldg_matrix_b(k_offset): b_k0 = b_k0_base + kk for ii in range_constexpr(WARP_N_STEPS): b_n0 = b_n0_base + ii - if not B_PRE_SHUFFLE: + if const_expr(not B_PRE_SHUFFLE): warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N warp_atom_k_idx = kk * WARP_ATOM_K n_idx = n_offset + warp_atom_n_idx + ldmatrix_b_n_idx @@ -551,7 +551,7 @@ def block_mma_sync(a_frags, b_frags, c_frags): if IS_SPLIT_K: zero_c() - if B_TO_LDS: + if const_expr(B_TO_LDS): sts_a(ldg_a(ks_begin), 0) sts_b(ldg_b(ks_begin), 0) @@ -623,7 +623,7 @@ def hot_loop_scheduler(): # for i in range_constexpr(WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K): # rocdl.sched_mfma(1) # ================ Reordered ================ - if ASYNC_COPY: + if const_expr(ASYNC_COPY): AVG_MFMA_COUNT = (MFMA_TOTAL + LDG_TOTAL - 1) // LDG_TOTAL for i in range_constexpr(LDG_TOTAL): rocdl.sched_vmem(ldg_.consume(1)) @@ -646,13 +646,13 @@ def hot_loop_scheduler(): c_frags = state[2 : 2 + C_FRAGS_LEN] a_frags = state[2 + C_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN] b_frags = state[2 + C_FRAGS_LEN + A_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN + B_FRAGS_LEN] - if ASYNC_COPY: + if const_expr(ASYNC_COPY): ldg_sts_a_async(k_offset + BLOCK_K, next_stage) else: a_regs_next = ldg_a(k_offset + BLOCK_K) b_frags_next = ldg_matrix_b(k_offset + BLOCK_K) block_mma_sync(a_frags, b_frags, c_frags) - if not ASYNC_COPY: + if const_expr(not ASYNC_COPY): sts_a(a_regs_next, next_stage) hot_loop_scheduler() gpu.barrier() diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 05e3a310..68064aa9 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -198,7 +198,7 @@ def _store_vec(val, div_tensor, idx): idx = tid + tile_i * BLOCK_THREADS vec_e = _load_vec(in_div, idx) - if cache_as_elem: + if const_expr(cache_as_elem): in_local.append(vec_e) x = vec_e.extf(vec_type_c) else: diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 1c675384..68950d5b 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -596,7 +596,7 @@ def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): elem_type=w_elem, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes, - unpack_int4=is_int4, + unpack_int4=(is_int4 or is_int4_bf16), ) def load_b_tile(base_k, blk_list, intra_list): @@ -605,7 +605,7 @@ def load_b_tile(base_k, blk_list, intra_list): Returns a list of length `k_unroll`, where each entry is a tuple: (packs_half0[ni], packs_half1[ni]) for the K64 micro-step. """ - if is_int4_bf16: + if const_expr(is_int4_bf16): # W4A16: 2-phase load+unpack for VMEM latency hiding # Phase 1: Issue ALL buffer_loads first. raw_data = [] @@ -770,14 +770,14 @@ def _i64_to_v4i16(x_i64): return vector.bitcast(T.i16x4, v1) def mfma_k64(acc_in, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) b1v = _i64_to_v4f16(b1) acc_mid = mfma_fn(mfma_res_ty, [a0v, b0v, acc_in, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1v, b1v, acc_mid, 0, 0, 0]) - if is_bf16: + if const_expr(is_bf16): a0v = _i64_to_v4i16(a0) a1v = _i64_to_v4i16(a1) b0v = _i64_to_v4i16(b0) @@ -1791,7 +1791,7 @@ def load_b_pack(base_k, ki_step, ni): elem_type=w_elem, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes, - unpack_int4=is_int4, + unpack_int4=(is_int4 or is_int4_bf16), ) def load_b_tile(base_k): @@ -1800,7 +1800,7 @@ def load_b_tile(base_k): Returns a list of length `k_unroll`, where each entry is a tuple: (packs_half0[ni], packs_half1[ni]) for the K64 micro-step. """ - if is_int4_bf16: + if const_expr(is_int4_bf16): # W4A16: 2-phase load+unpack for VMEM latency hiding raw_data = [] for ku in range_constexpr(k_unroll): @@ -1959,14 +1959,14 @@ def _i64_to_v4i16(x_i64): return vector.bitcast(T.i16x4, v1) def mfma_k64(acc0, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) b1v = _i64_to_v4f16(b1) acc1 = mfma_fn(mfma_res_ty, [a0v, b0v, acc0, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1v, b1v, acc1, 0, 0, 0]) - if is_bf16: + if const_expr(is_bf16): a0v = _i64_to_v4i16(a0) a1v = _i64_to_v4i16(a1) b0v = _i64_to_v4i16(b0) From 3f82c18ece5dad226080eb7c6c20515a80858c7e Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 07:13:14 +0000 Subject: [PATCH 15/31] [FLYDSL]: if const_expr(ast.compare(ast.name)) --- kernels/rdna_fp8_preshuffle_gemm.py | 4 ++-- kernels/rmsnorm_kernel.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kernels/rdna_fp8_preshuffle_gemm.py b/kernels/rdna_fp8_preshuffle_gemm.py index 09fc56bd..d29ff9cf 100644 --- a/kernels/rdna_fp8_preshuffle_gemm.py +++ b/kernels/rdna_fp8_preshuffle_gemm.py @@ -347,7 +347,7 @@ def _unflatten_b(flat): init_state = _flatten_tile(a_cur) + list(accs) + _flatten_tile(b_cur) # Main K-loop: SCF outer with constexpr inner unroll - if full_outer_iters > 0: + if const_expr(full_outer_iters > 0): for iv, state in range(0, full_outer_iters * k_unroll, k_unroll, init=init_state): s_a = _unflatten_a(list(state[:n_a])) s_accs = list(state[n_a : n_a + n_acc]) @@ -369,7 +369,7 @@ def _unflatten_b(flat): b_cur = _unflatten_b(list(results[n_a + n_acc :])) # Handle remainder tiles - if remainder > 0: + if const_expr(remainder > 0): for j in range_constexpr(remainder): next_kt = fx.Index(full_outer_iters * k_unroll + j + 1) a_next = _load_a_tile(next_kt) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index 784b2e61..4a2a2b6d 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -174,7 +174,7 @@ def _store_vec(val, div_tensor, idx): idx = tid + tile_i * BLOCK_THREADS vec_e = _load_vec(in_div, idx) - if cache_as_elem: + if const_expr(cache_as_elem): in_local.append(vec_e) x = vec_e.extf(vec_type_c) else: From 0415671ba486d9dd444973ffaff8de749fe9cc00 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 07:50:13 +0000 Subject: [PATCH 16/31] [FLYDSL]: Non-None initialization + if dynamic(cond) --- kernels/moe_gemm_2stage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 68950d5b..14866166 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -2296,8 +2296,8 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): # For bf16 global atomics (gfx942 only), precompute the output base address. # gfx950+ has buffer_atomic_pk_add_bf16, so bf16 uses buffer atomics there. - out_base_idx = None - if _needs_global_atomic_bf16: + out_base_idx = fx.Index(0) + if const_expr(_needs_global_atomic_bf16): out_base_idx = buffer_ops.extract_base_index(arg_out) def write_row_to_lds( From b982dbd0a3422efd762dfae2b6351a1970d91a6c Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 09:14:49 +0000 Subject: [PATCH 17/31] [FLYDSL]: Example parameters --- examples/04-preshuffle_gemm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/04-preshuffle_gemm.py b/examples/04-preshuffle_gemm.py index 9cb3335e..a125a30a 100644 --- a/examples/04-preshuffle_gemm.py +++ b/examples/04-preshuffle_gemm.py @@ -73,18 +73,18 @@ def gemm_kernel( mma_frag_C_f16 = fx.make_fragment_like(mma_frag_C, fx.Float16.ir_type) mma_frag_C_retile = thr_copy_r2g_C.retile(mma_frag_C_f16) - def run_pipeline_stage(read_stage, next_k, read_next=True): + def run_pipeline_stage(read_stage, next_k, read_next=True, buffer_copy_128b_v=None): write_stage = read_stage ^ 1 if read_next: next_k = fx.Int32(next_k) fx.copy( - buffer_copy_128b.set_value("soffset", next_k * BLOCK_K), + buffer_copy_128b_v.set_value("soffset", next_k * BLOCK_K), thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom copy_frag_A, ) fx.copy( - buffer_copy_128b, + buffer_copy_128b_v, thr_gB_k[None, None, None, next_k], mma_frag_B_retile[None, None, None, write_stage], ) @@ -142,11 +142,11 @@ def sched_main_iter(with_vmem=False, with_dswr=False): fx.gpu.barrier() for k_iter in range(0, K // BLOCK_K - 2, 2): - run_pipeline_stage(read_stage=0, next_k=k_iter + 1) - run_pipeline_stage(read_stage=1, next_k=k_iter + 2) + run_pipeline_stage(read_stage=0, next_k=k_iter + 1, buffer_copy_128b_v=buffer_copy_128b) + run_pipeline_stage(read_stage=1, next_k=k_iter + 2, buffer_copy_128b_v=buffer_copy_128b) - run_pipeline_stage(read_stage=0, next_k=K // BLOCK_K - 1) - run_pipeline_stage(read_stage=1, next_k=None, read_next=False) + run_pipeline_stage(read_stage=0, next_k=K // BLOCK_K - 1, buffer_copy_128b_v=buffer_copy_128b) + run_pipeline_stage(read_stage=1, next_k=None, read_next=False, buffer_copy_128b_v=buffer_copy_128b) mma_frag_C_f16.store(fx.arith.trunc_f(fx.T.VectorType.get([64], fx.T.f16()), mma_frag_C.load())) fx.copy(buffer_copy_16b, mma_frag_C_retile, thr_gC) From 2c740a0713acc8bae40510cca4129901766b39b4 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 09:56:07 +0000 Subject: [PATCH 18/31] [FLYDSL]: gfx950 const_expr --- kernels/blockscale_preshuffle_gemm.py | 4 +- kernels/mixed_moe_gemm_2stage.py | 74 ++++++++++++++------------- kernels/moe_blockscale_2stage.py | 68 ++++++++++++------------ 3 files changed, 77 insertions(+), 69 deletions(-) diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index cd3cef07..a3552b83 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -446,7 +446,7 @@ def _mfma_fn_placeholder(*args, **kwargs): mfma_fn = _mfma_fn_placeholder - if _is_gfx950: + if const_expr(_is_gfx950): c0_i64 = arith.constant(0, type=T.i64) def pack_i64x4_to_i32x8(x0, x1, x2, x3): @@ -521,7 +521,7 @@ def compute_tile_blockscale( combined_scales = pre_scales[sb] block_accs = [acc_init] * (num_acc_n * m_repeat) - if _is_gfx950: + if const_expr(_is_gfx950): ku0 = sb * ku_per_sb ku1 = ku0 + 1 b0_packs0, b0_packs1 = b_tile_in[ku0] diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index 53d28b3f..d4fb99b6 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -17,6 +17,7 @@ from flydsl.expr.typing import T from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +from test.FLYIR2.FlyDSL.python.flydsl.expr.primitive import const_expr try: from flydsl.runtime.device import supports_bf16_global_atomics @@ -369,7 +370,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): arg_out, max_size=False, num_records_bytes=out_nbytes_i32 ) - if is_f16_a: + if const_expr(is_f16_a): sx_rsrc = None else: # A1 microscale: [sorted_rows, K/32] e8m0 bytes, packed as i32. @@ -384,7 +385,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) - if is_f16_b: + if const_expr(is_f16_b): sw_rsrc = None else: # W1 microscale: [experts * 2 * inter_dim, K/32] e8m0 bytes. @@ -413,7 +414,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): max_size=False, num_records_bytes=sorted_nbytes_i32, ) - if doweight_stage1 + if const_expr(doweight_stage1) else None ) @@ -766,7 +767,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): lds_store_16b_xor16( arith, vector, @@ -1046,7 +1047,7 @@ def hot_loop_scheduler(): os.environ.get("FLYDSL_STAGE1_SKIP_COMPUTE", "0") == "1" ) - if k_main2_py > 0: + if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): k_iv = k_iv_py next_k1 = k_iv + tile_k @@ -1056,7 +1057,7 @@ def hot_loop_scheduler(): prefetch_ab_scale_tile(next_k1 // pack_K // 128) ) - if _skip_compute: + if const_expr(_skip_compute): store_x_tile_to_lds(x_regs_ping, lds_base_ping) gpu.barrier() a0_prefetch_ping = None @@ -1112,7 +1113,7 @@ def hot_loop_scheduler(): a0_prefetch_pong = None - if odd_k_tiles: + if const_expr(odd_k_tiles): acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( acc_gate, acc_up, @@ -1184,7 +1185,7 @@ def hot_loop_scheduler(): _mask_even_i32 = fx.Int32(0xFFFFFFFE) - if _use_cshuffle_epilog: + if const_expr(_use_cshuffle_epilog): if lds_out is None: raise RuntimeError( "CShuffle epilogue enabled but lds_out is not allocated/aliased." @@ -1215,7 +1216,7 @@ def write_row_to_lds( _t2 = fused2 & mask24_i32 # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load( sorted_w_rsrc, row_safe, vec_width=1, dtype=T.f32 ) @@ -1240,12 +1241,12 @@ def write_row_to_lds( vg = vg + gate_bias_list[ni] vu = vu + up_bias_list[ni] - if act == "swiglu": + if const_expr(act == "swiglu"): y = swiglu(vg, vu) else: y = silu(vg) * vu - if doweight_stage1: + if const_expr(doweight_stage1): y = y * tw lds_idx = row_base_lds + col_local @@ -1279,7 +1280,7 @@ def store_pair( idx0 = row_ctx col_i32 = arith.index_cast(T.i32, col_g0) idx_out = idx0 + col_i32 - if out_dtype == "fp8": + if const_expr(out_dtype == "fp8"): frag = vector.bitcast(vec4_f32, frag) frag0 = vector.extract( frag, static_position=[0], dynamic_position=[] @@ -1373,7 +1374,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): idx0 = (t2_safe * topk_i32_v + s2_safe) * inter_i32_local # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load( sorted_w_rsrc, row_safe, vec_width=1, dtype=T.f32 ) @@ -1398,12 +1399,12 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): vg = vg + gate_bias_list[ni] vu = vu + up_bias_list[ni] - if act == "swiglu": + if const_expr(act == "swiglu"): y = swiglu(vg, vu) else: y = silu(vg) * vu - if doweight_stage1: + if const_expr(doweight_stage1): y = y * tw y = arith.trunc_f(_out_elem_type(), y) @@ -1816,10 +1817,10 @@ def moe_gemm2( num_valid_idx = arith.index_cast(T.index, num_valid_i32) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16_a: + if const_expr(is_f16_a): sx_rsrc = None else: - if is_f4_a: + if const_expr(is_f4_a): # A2 microscale: packed i32 holding e8m0 bytes for [sorted_size, K/32]. c32 = fx.Index(32) kblk = k_in // c32 @@ -1837,7 +1838,7 @@ def moe_gemm2( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) - if is_f16_b: + if const_expr(is_f16_b): sw_rsrc = None else: # Weight microscale buffer (packed i32 holding e8m0 bytes). @@ -1911,7 +1912,7 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. - if is_f16_a: + if const_expr(is_f16_a): if bytes_per_thread_x % 16 != 0: raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" @@ -2273,8 +2274,8 @@ def compute_tile( epilogue_pf = None bias = None - if prefetch_epilogue: - if enable_bias: + if const_expr(prefetch_epilogue): + if const_expr(enable_bias): bias = [] for ni in range_constexpr(num_acc_n): global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 @@ -2285,7 +2286,7 @@ def compute_tile( ) ) tw_pf = None - if doweight_stage2: + if const_expr(doweight_stage2): tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * arith.index(4) ii_idx_list_pf = [ @@ -2347,6 +2348,8 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): mi_val = fx.Index(mi_idx * 16) curr_row_a_lds = row_a_lds + mi_val + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if ( (a0_prefetch is not None) and (k_idx == 0) @@ -2358,7 +2361,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): curr_row_a_lds, col_base0, lds_base ) - if is_f8_a: + if const_expr(is_f8_a): col_base1 = col_base + 64 a2, a3 = lds_load_packs_k64( curr_row_a_lds, col_base1, lds_base @@ -2419,19 +2422,19 @@ def hot_loop_scheduler(): rocdl.sched_dsrd(2) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) if num_acc_n < 4: rocdl.sched_dsrd(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_dsrd(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_mfma(1) @@ -2498,7 +2501,7 @@ def hot_loop_scheduler(): # When k_main2_py == 0 the loop body is empty; emitting an scf.for # would create a region whose internal SSA values cannot be used # by the post-loop tail code. - if k_main2_py > 0: + if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): k_iv = k_iv_py next_k1 = k_iv + tile_k @@ -2549,7 +2552,7 @@ def hot_loop_scheduler(): row_a_lds, col_offset_base, lds_base_pong ) - if odd_k_tiles: + if const_expr(odd_k_tiles): # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). acc, epilogue_pf = compute_tile( acc, @@ -2644,7 +2647,8 @@ def write_row_to_lds( num_acc_n: int, lds_out, ): - if doweight_stage2: + tw = arith.constant(1.0, type=T.f32) + if const_expr(doweight_stage2): tw_idx = (mi * 4) + ii if tw_pf is not None: tw = tw_pf[tw_idx] @@ -2659,10 +2663,10 @@ def write_row_to_lds( v = vector.extract( acc[acc_idx], static_position=[ii], dynamic_position=[] ) - if enable_bias: + if const_expr(enable_bias): v = v + bias_pf[ni] - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw v_out = arith.trunc_f(out_elem(), v) @@ -2697,8 +2701,8 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): col_i32 = arith.index_cast(T.i32, col_g0) idx_elem = idx0 + col_i32 idx_elem_even = idx_elem & mask_even_i32 - if _needs_global_atomic_bf16: - if bool(accumulate): + if const_expr(_needs_global_atomic_bf16): + if const_expr(bool(accumulate)): byte_off = idx_elem_even * c2_i32 byte_off_idx = arith.index_cast(T.index, byte_off) ptr_addr_idx = out_base_idx + byte_off_idx @@ -2718,7 +2722,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) else: byte_off = idx_elem_even * c2_i32 - if bool(accumulate): + if const_expr(bool(accumulate)): atomic_add_f16x2(frag, byte_off) else: buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index 2fa5aab5..84e08313 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -720,7 +720,7 @@ def compute_tile_bs_s1(acc_gate_in, acc_up_in, b_gate_tile_in, b_up_tile_in, current_up = list(acc_up_in) mfma_res_ty = T.f32x4 - if _is_gfx950: + if const_expr(_is_gfx950): def _pack128(x0, x1, x2, x3): v4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3]) return vector.bitcast(T.vec(8, T.i32), v4) @@ -737,6 +737,8 @@ def _pack128(x0, x1, x2, x3): col1 = col_offset_base_bytes + arith.index(ku1 * 64) for mi in range_constexpr(m_repeat): curr_row = row_a_lds + arith.index(mi * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if a0_prefetch is not None and sb == 0 and mi == 0: a0, a1 = a0_prefetch else: @@ -782,7 +784,7 @@ def _pack128(x0, x1, x2, x3): else: mfma_fn = ( mfma_i32_k32 - if is_int8 + if const_expr(is_int8) else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) ) @@ -791,7 +793,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(T.f16x4, v1) def mfma_k64(acc_in, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -850,7 +852,7 @@ def compute_tile( mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 mfma_fn = ( mfma_i32_k32 - if is_int8 + if const_expr(is_int8) else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) ) @@ -867,12 +869,12 @@ def compute_tile( row_up_idx = row_gate_idx + inter_idx sw_gate_pf.append( fx.Float32(1.0) - if is_f16 + if const_expr(is_f16) else buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=T.f32) ) sw_up_pf.append( fx.Float32(1.0) - if is_f16 + if const_expr(is_f16) else buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=T.f32) ) epilogue_pf = (sw_gate_pf, sw_up_pf) @@ -882,7 +884,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(T.f16x4, v1) def mfma_k64(acc_in, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -1041,7 +1043,7 @@ def do_one_stage(acc_gate_in, acc_up_in, k_compute, k_next, lane_div_16_mul4 = lane_div_16 * fx.Index(4) inter_i32_local = inter_i32_v - if _use_cshuffle_epilog: + if const_expr(use_cshuffle_epilog): if lds_out is None: raise RuntimeError("CShuffle epilogue enabled but lds_out is not allocated/aliased.") @@ -1058,7 +1060,7 @@ def write_row_to_lds( ): # Blockscale: dequant already done in compute_tile_bs_s1. # Just apply silu + optional sorted weight. - if doweight_stage1: + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) for ni in range_constexpr(num_acc_n): @@ -1136,7 +1138,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): idx0 = (t2 * topk_i32_v + s2) * inter_i32_local # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) _if_valid = scf.IfOp(t_valid) @@ -1894,7 +1896,7 @@ def compute_tile_bs_s2(acc_in, b_tile_in, lds_base, pre_scales, *, a0_prefetch=N current_acc = list(acc_in) mfma_res_ty = T.f32x4 - if _is_gfx950: + if const_expr(_is_gfx950): def _pack128(x0, x1, x2, x3): v4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3]) return vector.bitcast(T.vec(8, T.i32), v4) @@ -1909,6 +1911,8 @@ def _pack128(x0, x1, x2, x3): col1 = col_offset_base_bytes + arith.index(ku1 * 64) for mi in range_constexpr(m_repeat): curr_row = row_a_lds + arith.index(mi * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if a0_prefetch is not None and sb == 0 and mi == 0: a0, a1 = a0_prefetch else: @@ -1941,7 +1945,7 @@ def _pack128(x0, x1, x2, x3): else: mfma_fn = ( mfma_i32_k32 - if is_int8 + if const_expr(is_int8) else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) ) @@ -1950,7 +1954,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(T.f16x4, v1) def mfma_k64(acc0, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -2009,7 +2013,7 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False ) # Also prefetch per-row routed/topk weights (sorted_weights) when enabled. tw_pf = None - if doweight_stage2: + if const_expr(doweight_stage2): tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * fx.Index(4) ii_idx_list_pf = [fx.Index(ii) for ii in range(4)] @@ -2031,7 +2035,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(T.f16x4, v1) def mfma_k64(acc0, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -2240,7 +2244,7 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): # Blockscale: dequant already done in compute_tile_bs_s2, no sw/sx needed here. - if out_is_f32: + if const_expr(out_is_f32): # origin/dev_a16w4: f32 output uses scalar f32 atomics and skips CShuffle/LDS. c4_i32 = fx.Int32(4) @@ -2259,7 +2263,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): t2 = fused2 & mask24_i32 s2 = fused2 >> 24 - if doweight_stage2: + if const_expr(doweight_stage2): tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) idx0 = t2 * model_i32 # i32 element index base @@ -2268,7 +2272,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): col_g = col_g_list[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw col_i32 = arith.index_cast(T.i32, col_g) idx_elem = idx0 + col_i32 @@ -2292,7 +2296,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): # For bf16 global atomics (gfx942 only), precompute the output base address. # gfx950+ has buffer_atomic_pk_add_bf16, so bf16 uses buffer atomics there. out_base_idx = None - if _needs_global_atomic_bf16: + if const_expr(_needs_global_atomic_bf16): out_base_idx = buffer_ops.extract_base_index(arg_out) def write_row_to_lds( @@ -2308,7 +2312,7 @@ def write_row_to_lds( ): # Blockscale: dequant already done in compute_tile_bs_s2. tw = arith.constant(1.0, type=T.f32) - if doweight_stage2: + if const_expr(doweight_stage2): tw = buffer_ops.buffer_load( sorted_w_rsrc, row, vec_width=1, dtype=T.f32 ) @@ -2317,7 +2321,7 @@ def write_row_to_lds( col_local = col_base_local + (ni * 16) acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw v_out = arith.trunc_f(out_elem(), v) @@ -2351,7 +2355,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): col_i32 = arith.index_cast(T.i32, col_g0) idx_elem = idx0 + col_i32 idx_elem_even = idx_elem & mask_even_i32 - if _needs_global_atomic_bf16: + if const_expr(_needs_global_atomic_bf16): # gfx942: no buffer_atomic_pk_add_bf16, use global atomicrmw fadd if bool(accumulate): byte_off = idx_elem_even * c2_i32 @@ -2584,12 +2588,12 @@ def moe_reduction_kernel( x_div = fx.logical_divide(x_tiled[None, tile_i32], fx.make_layout(VEC_WIDTH, 1)) x_thread = x_div[None, tid_i32] - if use_mask: + if const_expr(use_mask): m_idx_i32 = fx.Int32(token_idx * c_topk + fx.Index(k)) mv = buffer_ops.buffer_load(mask_rsrc, m_idx_i32, vec_width=1, dtype=i8_type()) mv_ok = mv != fx.Int8(0) - if n_sub > 1: + if const_expr(n_sub > 1): x_inner = fx.logical_divide(x_thread, fx.make_layout(copy_vec_width, 1)) for si in range_constexpr(n_sub): src = x_inner[None, fx.Int32(si)] if n_sub > 1 else x_thread @@ -2597,18 +2601,18 @@ def moe_reduction_kernel( fx.copy_atom_call(copy_atom, src, r) vec_e = fx.memref_load_vec(r) - if use_mask: + if const_expr(use_mask): zero_e = vector.broadcast(vec_type_e, arith.constant(0.0, type=elem_type())) vec_e = mv_ok.select(vec_e, zero_e) - if elem_bits < 32: + if const_expr(elem_bits < 32): vec_c = vec_e.extf(vec_type_c) else: vec_c = vec_e acc_vecs[si] = acc_vecs[si] + vec_c # ── Store results ── - if n_sub > 1: + if const_expr(n_sub > 1): y_row = Y_buf[tok_i32, None] y_tiled = fx.logical_divide(y_row, fx.make_layout(tile_cols, 1)) y_div = fx.logical_divide(y_tiled[None, tile_i32], fx.make_layout(VEC_WIDTH, 1)) @@ -2616,10 +2620,10 @@ def moe_reduction_kernel( for si in range_constexpr(n_sub): out_vec = acc_vecs[si] - if elem_bits < 32: + if const_expr(elem_bits < 32): out_vec = out_vec.truncf(vec_type_e) - if n_sub > 1: + if const_expr(n_sub > 1): dst = y_inner[None, fx.Int32(si)] else: y_row = Y_buf[tok_i32, None] @@ -2642,7 +2646,7 @@ def moe_reduction_kernel( for k in range_constexpr(topk): k_idx = fx.Index(k) x_idx_i32 = fx.Int32((token_base + k_idx) * c_model_dim + col) - if use_mask: + if const_expr(use_mask): m_idx_i32 = fx.Int32(token_base + k_idx) mv = buffer_ops.buffer_load( mask_rsrc, m_idx_i32, vec_width=1, dtype=i8_type() @@ -2655,12 +2659,12 @@ def moe_reduction_kernel( v = buffer_ops.buffer_load( x_rsrc, x_idx_i32, vec_width=1, dtype=elem_type() ) - if dtype_str in ("f16", "bf16"): + if const_expr(dtype_str in ("f16", "bf16")): v = v.extf(compute_type()) a = a + v out = a - if dtype_str in ("f16", "bf16"): + if const_expr(dtype_str in ("f16", "bf16")): out = out.truncf(elem_type()) y_idx_i32 = fx.Int32(token_idx * c_model_dim + col) buffer_ops.buffer_store(out, y_rsrc, y_idx_i32) From 0d45a29c00fc75d4ec8e4f64015df61bb7efc160 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 13:41:56 +0000 Subject: [PATCH 19/31] [FLYDSL]: rm import --- kernels/mixed_moe_gemm_2stage.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index d4fb99b6..804d6faa 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -17,7 +17,6 @@ from flydsl.expr.typing import T from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr -from test.FLYIR2.FlyDSL.python.flydsl.expr.primitive import const_expr try: from flydsl.runtime.device import supports_bf16_global_atomics From efef3d6d2de9766423e2754f85b7c2bb69b7df81 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 14:46:19 +0000 Subject: [PATCH 20/31] [FLYDSL]: Add initialization --- kernels/fused_rope_cache_kernel.py | 2 +- kernels/layernorm_kernel.py | 1 + kernels/rmsnorm_kernel.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/kernels/fused_rope_cache_kernel.py b/kernels/fused_rope_cache_kernel.py index d9258487..48d4c69c 100644 --- a/kernels/fused_rope_cache_kernel.py +++ b/kernels/fused_rope_cache_kernel.py @@ -243,7 +243,7 @@ def k_cache_kernel( i32_reg_ty = fx.MemRefType.get(T.i32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) i32_reg_lay = fx.make_layout(1, 1) - if not flash_layout: + if const_expr(not flash_layout): copy_atom_elem = fx.make_copy_atom(fx.rocdl.BufferCopy16b(), elem_bits) elem_reg_ty = fx.MemRefType.get( elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 9e7ae64a..642ebc79 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -223,6 +223,7 @@ def _store_vec(val, div_tensor, idx): y = (x - mean) * rstd y = y * g_cur + b_cur + out_e = y.to(elem_dtype) if dtype_str == "bf16": if USE_HW_CVT_PK_BF16_F32: out_e = y.to(elem_dtype) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index 5772544c..09bd36f2 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -191,6 +191,7 @@ def _store_vec(val, div_tensor, idx): y = (x * rrms) * g + out_e = y.to(elem_dtype) if dtype_str == "bf16": if USE_HW_CVT_PK_BF16_F32: out_e = y.to(elem_dtype) From c50606be7ab500a3cd9b6f68a5c6254eecdf8418 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 15:04:13 +0000 Subject: [PATCH 21/31] [FLYDSL]: MI355 const_expr cases --- kernels/blockscale_preshuffle_gemm.py | 2 ++ kernels/mixed_moe_gemm_2stage.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index a3552b83..777a215b 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -531,6 +531,8 @@ def compute_tile_blockscale( for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0 = ArithValue(arith.constant(-1, type=T.i64)) + a1 = ArithValue(arith.constant(-1, type=T.i64)) if a0_prefetch is not None and sb == 0 and mi == 0: a0, a1 = a0_prefetch else: diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index 804d6faa..a5dfc859 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -909,7 +909,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): curr_row_a_lds, col_base0, lds_base ) - if is_f8_a: + if const_expr(is_f8_a): col_base1 = col_base + 64 a2, a3 = lds_load_packs_k64( curr_row_a_lds, col_base1, lds_base From 723056d6216dca88e747769f8882b43108b2360a Mon Sep 17 00:00:00 2001 From: root Date: Fri, 10 Apr 2026 18:29:19 +0000 Subject: [PATCH 22/31] [FLYDSL]: MI355 add initialization --- kernels/mixed_moe_gemm_2stage.py | 1 + kernels/preshuffle_gemm.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index a5dfc859..f2770b30 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -1911,6 +1911,7 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. + x_load_bytes = 16 if const_expr(is_f16_a): if bytes_per_thread_x % 16 != 0: raise ValueError( diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index c2df5998..16f18500 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -1065,7 +1065,7 @@ def _build_scheduler(numer: int, denom: int): prev = cur return out - if _is_gfx942: + if const_expr(_is_gfx942): mfma_group = num_acc_n mfma_total = (k_unroll * 2) * m_repeat * mfma_group mfma_per_iter = 2 * mfma_group @@ -1120,6 +1120,7 @@ def _build_scheduler(numer: int, denom: int): dvmem_preload_eff = min(int(dvmem_preload), num_gmem_loads) vmem_remaining = num_gmem_loads - dvmem_preload_eff dsrd_remaining = num_ds_load - dsrd_preload_eff + vmem_schedule = [] if vmem_remaining > 0 and vmem_remaining < mfma_total: vmem_schedule = (_build_scheduler(vmem_remaining, vmem_remaining) + [0] * (mfma_total - vmem_remaining)) From d2e49c2ee40bf720587a587dc8608bce0a959d72 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Fri, 10 Apr 2026 19:29:49 +0000 Subject: [PATCH 23/31] [FLYDSL]: a0 a1 init --- kernels/preshuffle_gemm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 16f18500..be1a1565 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -863,6 +863,8 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for imxdl in range_constexpr(_fp4_pack_M): mi_idx = mi_p * _fp4_pack_M + imxdl curr_row_a_lds = row_a_lds + (mi_idx * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): a0, a1 = a0_prefetch else: @@ -894,6 +896,8 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if (a0_prefetch is not None) and (ku0 == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -942,6 +946,8 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): col_base = col_offset_base_bytes + ki64 for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): a0, a1 = a0_prefetch else: From 3c7c39d4788639403699c74bc7f90347f841f4dd Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Sat, 11 Apr 2026 02:33:30 +0000 Subject: [PATCH 24/31] [FLYDSL]: MI355 const_expr --- kernels/preshuffle_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index be1a1565..49929c33 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -819,7 +819,7 @@ def compute_tile(accs_in, b_tile_in, lds_buffer, *, is_last_tile=False, a0_prefe str(gpu_arch).startswith("gfx95") and (not is_int8) and (not is_int4) and (not is_f16_or_bf16) ) - if use_mfma_scale_128: + if const_expr(use_mfma_scale_128): if (int(tile_k) % 128) != 0: raise ValueError( f"tile_k must be divisible by 128 for mfma_scale_x128, got tile_k={tile_k}" @@ -842,7 +842,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3]) return vector.bitcast(T.vec(8, T.i32), v4) - if is_fp4: + if const_expr(is_fp4): _fp4_a_sc, _fp4_b_sc = fp4_scales if fp4_scales else ([], []) ku128_iters = 1 if _fp4_tilek128 else _k_unroll_packed ikxdl_iters = 1 if _fp4_tilek128 else _fp4_pack_K From 48d8e5bd131b625b53c915b75dee3b27f3ec9df4 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Sat, 11 Apr 2026 02:47:20 +0000 Subject: [PATCH 25/31] [FLYDSL]: const_expr --- kernels/preshuffle_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 49929c33..1daeb916 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -795,7 +795,7 @@ def load_fp4_scale_chunk(base_k): # ── Compute tile (MFMA) ─────────────────────────────────────────── def compute_tile(accs_in, b_tile_in, lds_buffer, *, is_last_tile=False, a0_prefetch=None, fp4_scales=None, fp4_scale_half=0): scales_pf = {} - if is_last_tile and (not is_f16_or_bf16): + if const_expr(is_last_tile and (not is_f16_or_bf16)): s_b_vals = [] for ni in range_constexpr(num_acc_n): col_g = by_n + n_tile_base + (ni * 16) + lane_mod_16 From 461c9cdeaaa5dc746b2b599b0f4213bcf5401c30 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Sat, 11 Apr 2026 07:11:20 +0000 Subject: [PATCH 26/31] [FLYDSL]: const_expr(_fp4_tilek128) --- kernels/preshuffle_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 1daeb916..73d49f9a 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -1269,7 +1269,7 @@ def _build_pingpong_body( ) b_tile_pong_in = _unflatten_b_tile(bt_flat_in) - if _fp4_tilek128: + if const_expr(_fp4_tilek128): next_k1 = k_iv + tile_k if const_expr(use_async_copy): prefetch_a_to_lds( From c4deaaf0ff3798826851d0c1bf66a44cc271f949 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Sun, 12 Apr 2026 08:46:01 +0000 Subject: [PATCH 27/31] [FLYDSL]: Cases Normalization --- kernels/moe_blockscale_2stage.py | 10 +++++----- kernels/moe_gemm_2stage.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index 84e08313..66549cb7 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -475,7 +475,7 @@ def load_x_tile(base_k, x_load_bytes_v): parts.append(vector.bitcast(T.i32x4, x_vec)) elif const_expr(x_load_bytes_v == 8): parts.append(x_vec) - if const_expr(x_load_bytes_v == 4): + else: parts.append(x_vec) return parts @@ -587,7 +587,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, @@ -602,7 +602,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes_v == 8: + elif const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, @@ -1764,7 +1764,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes_v == 16: + if const_expr(x_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, @@ -1779,7 +1779,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes_v == 8: + elif const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 14866166..bc060c61 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -512,9 +512,9 @@ def load_x_tile(base_k, x_load_bytes_v): x_vec = load_x(idx_i32, x_load_bytes_v) if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - if const_expr(x_load_bytes_v == 8): + elif const_expr(x_load_bytes_v == 8): parts.append(x_vec) - if const_expr(x_load_bytes_v == 4): + else: parts.append(x_vec) return parts @@ -1721,9 +1721,9 @@ def load_x_tile(base_k, x_load_bytes_v): x_vec = load_x(idx_i32, x_load_bytes_v) if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - if const_expr(x_load_bytes_v == 8): + elif const_expr(x_load_bytes_v == 8): parts.append(vector.bitcast(T.vec(2, T.i32), x_vec)) - if const_expr(x_load_bytes_v == 4): + else: parts.append(vector.bitcast(T.vec(1, T.i32), x_vec)) return parts @@ -1860,7 +1860,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - if const_expr(x_load_bytes_v == 8): + elif const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, @@ -1874,7 +1874,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): lds_base=lds_base, vec_part_i32x2=vec_x_in_parts[i], ) - if const_expr(x_load_bytes_v == 4): + else: lds_store_4b_xor16( arith, vector, From 82dc381e13ea9f5a545b8d7a08869f7965ad6746 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Sun, 12 Apr 2026 08:51:51 +0000 Subject: [PATCH 28/31] [FLYDSL]: rm notes --- kernels/preshuffle_gemm.py | 1 - python/flydsl/compiler/ast_rewriter.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 73d49f9a..31e814cb 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -1121,7 +1121,6 @@ def _build_scheduler(numer: int, denom: int): num_a_scale_loads = num_fp4_scale_k_groups * (m_repeat // 2) num_b_scale_loads = num_fp4_scale_k_groups * (num_acc_n // 2) num_gmem_loads += num_a_scale_loads + num_b_scale_loads - # print("mfma_total, dswr_tail, dstr_advance", mfma_total, dswr_tail, dstr_advance) dsrd_preload_eff = min(int(dsrd_preload), num_ds_load) dvmem_preload_eff = min(int(dvmem_preload), num_gmem_loads) vmem_remaining = num_gmem_loads - dvmem_preload_eff diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index eb329940..03b4f953 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -740,9 +740,6 @@ def visit_Call(self, node): invoked_args = [name for name in invoked_args if name not in write_args] write_args = [name for name in write_args if in_active_symbols(name)] invoked_args = [name for name in invoked_args if in_active_symbols(name)] - # print(f"write_args: {write_args}") - # print(f"invoked_args: {invoked_args}") - # print(f"active_symbols: {active_symbols}") return write_args + invoked_args @staticmethod From a98680f87b4504752ebb8941b15b257eaa2e4dc7 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Sun, 12 Apr 2026 09:23:47 +0000 Subject: [PATCH 29/31] [FLYDSL]: add if/else ST --- tests/system/test_control_flow_compile.py | 106 ++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/tests/system/test_control_flow_compile.py b/tests/system/test_control_flow_compile.py index ce240e24..5d7a6759 100644 --- a/tests/system/test_control_flow_compile.py +++ b/tests/system/test_control_flow_compile.py @@ -50,3 +50,109 @@ def vecAbs( c = torch.empty_like(a) t_a = flyc.from_dlpack(a).mark_layout_dynamic(leading_dim=0, divisibility=vec) vecAbs(t_a, c, size, size, threads, vec) + + +def test_control_flow_dynamic_if_end_to_end_numeric(monkeypatch): + if not torch.cuda.is_available(): + pytest.skip("CUDA device is required for dynamic if end-to-end test") + # Avoid compile-cache hits so dynamic dispatch is exercised in this test process. + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + + @flyc.kernel + def dynamicIfKernel( + A: fx.Tensor, + B: fx.Tensor, + C: fx.Tensor, + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + tile_elems = block_dim * vec_width + + A = fx.rocdl.make_buffer_tensor(A) + B = fx.rocdl.make_buffer_tensor(B) + C = fx.rocdl.make_buffer_tensor(C) + + tA = fx.logical_divide(A, fx.make_layout(tile_elems, 1)) + tB = fx.logical_divide(B, fx.make_layout(tile_elems, 1)) + tC = fx.logical_divide(C, fx.make_layout(tile_elems, 1)) + + tA = fx.slice(tA, (None, bid)) + tB = fx.slice(tB, (None, bid)) + tC = fx.slice(tC, (None, bid)) + + tA = fx.logical_divide(tA, fx.make_layout(vec_width, 1)) + tB = fx.logical_divide(tB, fx.make_layout(vec_width, 1)) + tC = fx.logical_divide(tC, fx.make_layout(vec_width, 1)) + + reg_ty = fx.MemRefType.get( + fx.T.f32(), fx.LayoutType.get(vec_width, 1), fx.AddressSpace.Register + ) + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32) + + rA = fx.memref_alloca(reg_ty, fx.make_layout(vec_width, 1)) + rB = fx.memref_alloca(reg_ty, fx.make_layout(vec_width, 1)) + rC = fx.memref_alloca(reg_ty, fx.make_layout(vec_width, 1)) + + fx.copy_atom_call(copy_atom, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(copy_atom, fx.slice(tB, (None, tid)), rB) + + vA = fx.memref_load_vec(rA) + vB = fx.memref_load_vec(rB) + vOut = fx.arith.addf(vA, vB) + + # Runtime branch (tid/bid come from GPU execution), so this should lower to dynamic scf.if. + if (tid % 2) == 0: + vOut = fx.arith.addf(vOut, vA) + else: + vOut = fx.arith.subf(vOut, vB) + + if (bid % 2) == 0: + vOut = fx.arith.addf(vOut, vB) + else: + vOut = fx.arith.subf(vOut, vA) + + fx.memref_store_vec(vOut, rC) + fx.copy_atom_call(copy_atom, rC, fx.slice(tC, (None, tid))) + + @flyc.jit + def dynamicIfVec( + A: fx.Tensor, + B: fx.Tensor, + C, + n: fx.Int32, + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + stream: fx.Stream = fx.Stream(None), + ): + tile_elems = block_dim * vec_width + grid_x = (n + tile_elems - 1) // tile_elems + dynamicIfKernel(A, B, C, block_dim, vec_width).launch( + grid=(grid_x, 1, 1), block=(block_dim, 1, 1), stream=stream + ) + + block_dim = 64 + vec_width = 4 + num_blocks = 5 + size = block_dim * vec_width * num_blocks + + a = torch.randn(size, device="cuda", dtype=torch.float32) + b = torch.randn(size, device="cuda", dtype=torch.float32) + c = torch.empty_like(a) + + t_a = flyc.from_dlpack(a).mark_layout_dynamic(leading_dim=0, divisibility=vec_width) + dynamicIfVec(t_a, b, c, size, block_dim, vec_width) + torch.cuda.synchronize() + + a3 = a.view(num_blocks, block_dim, vec_width) + b3 = b.view(num_blocks, block_dim, vec_width) + tid = torch.arange(block_dim, device="cuda").view(1, block_dim, 1) + bid = torch.arange(num_blocks, device="cuda").view(num_blocks, 1, 1) + + ref = a3 + b3 + ref = torch.where((tid % 2) == 0, ref + a3, ref - b3) + ref = torch.where((bid % 2) == 0, ref + b3, ref - a3) + ref = ref.reshape(-1) + + torch.testing.assert_close(c, ref, rtol=1e-5, atol=1e-5) From 600b09d403f90af6f75ce17f80f82bf30aa7c958 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 13 Apr 2026 09:15:06 +0000 Subject: [PATCH 30/31] Support kwargs to set atom_state --- examples/04-preshuffle_gemm.py | 23 ++++++++++++++--------- python/flydsl/expr/primitive.py | 16 ++++++++++++---- python/flydsl/expr/typing.py | 24 ++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/examples/04-preshuffle_gemm.py b/examples/04-preshuffle_gemm.py index a125a30a..869dc322 100644 --- a/examples/04-preshuffle_gemm.py +++ b/examples/04-preshuffle_gemm.py @@ -73,20 +73,25 @@ def gemm_kernel( mma_frag_C_f16 = fx.make_fragment_like(mma_frag_C, fx.Float16.ir_type) mma_frag_C_retile = thr_copy_r2g_C.retile(mma_frag_C_f16) - def run_pipeline_stage(read_stage, next_k, read_next=True, buffer_copy_128b_v=None): + gA_k_stride = fx.get_scalar(gA_k.stride[2]) + gB_k_stride = fx.get_scalar(gB_k.stride[2]) + + def run_pipeline_stage(read_stage, next_k, read_next=True): write_stage = read_stage ^ 1 if read_next: next_k = fx.Int32(next_k) fx.copy( - buffer_copy_128b_v.set_value("soffset", next_k * BLOCK_K), - thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom + buffer_copy_128b, + thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom copy_frag_A, + soffset=next_k * gA_k_stride, ) fx.copy( - buffer_copy_128b_v, - thr_gB_k[None, None, None, next_k], + buffer_copy_128b, + thr_gB_k[None, None, None, 0], mma_frag_B_retile[None, None, None, write_stage], + soffset=next_k * gB_k_stride, ) for block_k_iter in fx.range_constexpr(BLOCK_K // 32): @@ -142,11 +147,11 @@ def sched_main_iter(with_vmem=False, with_dswr=False): fx.gpu.barrier() for k_iter in range(0, K // BLOCK_K - 2, 2): - run_pipeline_stage(read_stage=0, next_k=k_iter + 1, buffer_copy_128b_v=buffer_copy_128b) - run_pipeline_stage(read_stage=1, next_k=k_iter + 2, buffer_copy_128b_v=buffer_copy_128b) + run_pipeline_stage(read_stage=0, next_k=k_iter + 1) + run_pipeline_stage(read_stage=1, next_k=k_iter + 2) - run_pipeline_stage(read_stage=0, next_k=K // BLOCK_K - 1, buffer_copy_128b_v=buffer_copy_128b) - run_pipeline_stage(read_stage=1, next_k=None, read_next=False, buffer_copy_128b_v=buffer_copy_128b) + run_pipeline_stage(read_stage=0, next_k=K // BLOCK_K - 1) + run_pipeline_stage(read_stage=1, next_k=None, read_next=False) mma_frag_C_f16.store(fx.arith.trunc_f(fx.T.VectorType.get([64], fx.T.f16()), mma_frag_C.load())) fx.copy(buffer_copy_16b, mma_frag_C_retile, thr_gC) diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 21aa55ed..f2eba9fe 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -748,16 +748,24 @@ def mma_make_fragment(operand_id, tiled_mma, input, loc=None, ip=None): @traced_op -def copy(copy_atom, src, dst, *, pred=None, loc=None, ip=None): - return fly.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip) +def copy(copy_atom, src, dst, *, pred=None, loc=None, ip=None, **kwargs): + return fly.copy(copy_atom.set_value(kwargs), src, dst, pred=pred, loc=loc, ip=ip) @traced_op -def gemm(mma_atom, d, a, b, c, *, traversal_order=None, traversal_layout=None, loc=None, ip=None): +def gemm(mma_atom, d, a, b, c, *, traversal_order=None, traversal_layout=None, loc=None, ip=None, **kwargs): if traversal_order is not None and traversal_layout is not None: raise ValueError("Only one of 'traversal_order' or 'traversal_layout' can be specified, not both") return fly.gemm( - mma_atom, d, a, b, c, traversal_order=traversal_order, traversal_layout=traversal_layout, loc=loc, ip=ip + mma_atom.set_value(kwargs), + d, + a, + b, + c, + traversal_order=traversal_order, + traversal_layout=traversal_layout, + loc=loc, + ip=ip, ) diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index 852d2071..425dd093 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -638,8 +638,18 @@ def layout_dst_tv(self): def layout_ref_tv(self): return static(self.type.tv_layout_ref) + @overload + def set_value(self, field: str, value, loc=None, ip=None): ... + @overload + def set_value(self, field: dict, loc=None, ip=None): ... + @traced_op - def set_value(self, field, value, loc=None, ip=None): + def set_value(self, field, value=None, loc=None, ip=None): + if isinstance(field, dict): + result = self + for k, v in field.items(): + result = atom_set_value(result, k, v, loc=loc, ip=ip) + return result return atom_set_value(self, field, value, loc=loc, ip=ip) @@ -669,8 +679,18 @@ def layout_B_tv(self): def layout_C_tv(self): return static(self.type.tv_layout_c) + @overload + def set_value(self, field: str, value, loc=None, ip=None): ... + @overload + def set_value(self, field: dict, loc=None, ip=None): ... + @traced_op - def set_value(self, field, value, loc=None, ip=None): + def set_value(self, field, value=None, loc=None, ip=None): + if isinstance(field, dict): + result = self + for k, v in field.items(): + result = atom_set_value(result, k, v, loc=loc, ip=ip) + return result return atom_set_value(self, field, value, loc=loc, ip=ip) From 38297fde7321d7e6bdedc66c88cfb64bc92c3d61 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Tue, 14 Apr 2026 03:28:50 +0000 Subject: [PATCH 31/31] [FLYDSL]: import overload --- python/flydsl/expr/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index deafb0d9..42200454 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -4,7 +4,7 @@ import ctypes import enum from inspect import isclass -from typing import Generic, Type, TypeVar +from typing import Generic, Type, TypeVar, overload from flydsl.runtime.device import get_rocm_arch