Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion kernels/kernels_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,48 @@
but this module is intentionally small and MLIR-dialect facing.
"""

from contextlib import contextmanager

from flydsl._mlir import ir
from flydsl.expr.typing import T
from flydsl._mlir.dialects import arith as _std_arith, builtin, gpu as _gpu, llvm as _llvm
from flydsl._mlir.dialects import arith as _std_arith, builtin, gpu as _gpu, llvm as _llvm, scf as _scf
from flydsl.expr import buffer_ops
from flydsl.runtime.device import get_rocm_arch, is_rdna_arch


@contextmanager
def _if_then(if_op, scf=None):
    """Build the then-region of an SCF ``IfOp`` inside a ``with`` statement.

    The insertion point is moved to ``if_op.then_block`` and that block is
    yielded to the caller. When the ``with`` body finishes (normally or by
    raising), an empty ``YieldOp`` is appended unless the block already ends
    with one, so the then-block is always terminated.

    The *scf* argument exists only for backward compatibility; it is never
    read. The module-level ``_scf`` import is used instead.
    """
    with ir.InsertionPoint(if_op.then_block):
        try:
            yield if_op.then_block
        finally:
            then_blk = if_op.then_block
            # A block needs a terminator: add one if the block is empty or
            # its last operation is something other than a YieldOp.
            already_terminated = bool(then_blk.operations) and isinstance(
                then_blk.operations[-1], _scf.YieldOp
            )
            if not already_terminated:
                _scf.YieldOp([])


_VALID_A_DTYPES = frozenset(("fp8", "fp16", "int8", "fp4"))
_VALID_B_DTYPES = frozenset(("fp8", "fp16", "int8", "int4", "fp4"))


def validate_moe_dtypes(a_dtype: str, b_dtype: str) -> None:
"""Validate a_dtype/b_dtype strings for mixed MoE kernels."""
if a_dtype not in _VALID_A_DTYPES:
raise ValueError(
f"a_dtype must be one of {tuple(sorted(_VALID_A_DTYPES))}, got {a_dtype!r}"
)
if b_dtype not in _VALID_B_DTYPES:
raise ValueError(
f"b_dtype must be one of {tuple(sorted(_VALID_B_DTYPES))}, got {b_dtype!r}"
)


def dtype_to_elem_type(dtype_str: str):
"""Map a dtype string to its MLIR scalar type.

Expand Down
181 changes: 181 additions & 0 deletions kernels/layout_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

"""Layout helpers for GEMM kernels.

Parses fly layout type strings (e.g. '(4,64):(64,1)') and computes
idx2crd / crd2idx with plain arith ops for static layouts.
Falls back to fly dialect ops for dynamic layouts.

Optimisation: power-of-2 strides/shapes emit ``shrui`` / ``andi`` instead of
``divui`` / ``remui``, avoiding 10-15-cycle V_DIV sequences on CDNA GPUs.
"""

import math as _math
import re
import builtins as _builtins

import flydsl.expr as fx
from flydsl._mlir import ir
from flydsl.expr import arith
from flydsl.expr.arith import ArithValue
from flydsl.expr.typing import T


def _wrap(v):
    """Return *v* as an ``ArithValue`` when it is a bare ``ir.Value``.

    Values that are already ``ArithValue`` instances, and anything that is
    not an ``ir.Value`` at all (e.g. Python ints), are returned unchanged.
    Wrapping gives raw IR values operator overloading.
    """
    needs_wrapping = not isinstance(v, ArithValue) and isinstance(v, ir.Value)
    return ArithValue(v) if needs_wrapping else v


def _is_pow2(n):
"""Return True when *n* is a positive power of two."""
return n > 0 and (n & (n - 1)) == 0


def _div_pow2(val, divisor):
"""Unsigned divide index *val* by a **compile-time** power-of-2 *divisor*.

Emits ``arith.shrui`` (1 VALU cycle) instead of ``arith.divui``
(10-15 VALU cycles on CDNA).
"""
shift = _math.log2(divisor)
assert shift == int(shift), f"{divisor} is not a power of 2"
return arith.shrui(val, arith.index(int(shift)))


def _mod_pow2(val, modulus):
"""Unsigned remainder of index *val* by a **compile-time** power-of-2 *modulus*.

Emits ``arith.andi`` (1 VALU cycle) instead of ``arith.remui``.
"""
return arith.andi(val, arith.index(modulus - 1))


def _parse_dim(tok):
"""Parse a single dimension token: '?' -> None, otherwise int."""
tok = tok.strip()
return None if tok == "?" else int(tok)


def _parse_layout(ly):
"""Parse '(s0,s1,...):(d0,d1,...)' -> (shapes, strides) as lists (None for '?')."""
ly_str = str(ly.type) if hasattr(ly, "type") else str(ly)
m = re.search(r"\(([^)]+)\):\(([^)]+)\)", ly_str)
if not m:
return None
shapes = [_parse_dim(s) for s in m.group(1).split(",")]
strides = [_parse_dim(s) for s in m.group(2).split(",")]
return shapes, strides


def _has_dynamic_strides(strides):
"""Check if any stride is dynamic (None)."""
return any(s is None for s in strides)


def idx2crd(idx, layout):
    """Split a flat index into one coordinate per layout dimension.

    Dynamic path (layout text cannot be parsed, or any stride is ``'?'``):
    delegates to ``fx.idx2crd`` and returns each ``fx.get`` component
    wrapped via ``_wrap``. An unparseable layout is treated as 1-D.

    Static path: *idx* is cast to ``index`` if it carries a different type.
    Every dimension with a non-zero stride then gets ``idx`` divided by its
    stride and, when the shape is known, reduced modulo that shape.
    Power-of-2 strides/shapes use shift/mask instead of div/rem. Each
    coordinate is derived from *idx* directly, so the descending-stride
    ordering only fixes the order in which ops are emitted. Dimensions with
    stride 0 receive *idx* itself.

    Returns:
        A list with one coordinate value per dimension.
    """
    parsed = _parse_layout(layout)

    if parsed is None or _has_dynamic_strides(parsed[1]):
        dyn_crd = fx.idx2crd(idx, layout)
        rank = 1 if not parsed else len(parsed[1])
        return [_wrap(fx.get(dyn_crd, dim)) for dim in range(rank)]

    if hasattr(idx, "type") and str(idx.type) != "index":
        idx = arith.index_cast(T.index, idx)
    shapes, strides = parsed
    rank = len(strides)

    # Collect the dimensions that actually contribute (stride != 0) and
    # visit them from the largest stride to the smallest.
    live_dims = []
    for dim, stride, extent in _builtins.zip(range(rank), strides, shapes):
        if stride != 0:
            live_dims.append((dim, stride, extent))
    live_dims.sort(key=lambda entry: entry[1], reverse=True)

    coords = [None] * rank
    for dim, stride, extent in live_dims:
        if stride == 1:
            coord = idx
        elif _is_pow2(stride):
            coord = _div_pow2(idx, stride)
        else:
            coord = idx / arith.index(stride)
        if extent is not None:
            if _is_pow2(extent):
                coord = _mod_pow2(coord, extent)
            else:
                coord = coord % arith.index(extent)
        coords[dim] = coord

    # Slots still empty belong to stride-0 dimensions; they get the index.
    return [idx if coord is None else coord for coord in coords]


def _coord_to_i32(coord):
    """Prepare one coordinate for ``fx.make_coord`` on the dynamic path.

    Python ints become ``i32`` constants. Objects exposing ``ir_value()``
    are unwrapped first. A resulting ``ir.Value`` of ``index`` type is cast
    to ``i32``; anything else is passed through as is.
    """
    if isinstance(coord, int):
        return arith.constant(coord, T.i32)
    raw = coord.ir_value() if hasattr(coord, "ir_value") else coord
    if isinstance(raw, ir.Value) and isinstance(raw.type, ir.IndexType):
        return arith.index_cast(T.i32, raw)
    return raw


def crd2idx(crd, layout):
    """Compute flat index from a coordinate tuple/list.

    A single non-sequence coordinate is treated as a 1-element list.

    Dynamic path (layout text cannot be parsed, or any stride is ``'?'``):
    coordinates are normalised with ``_coord_to_i32``, combined with
    ``fx.make_coord`` and passed to ``fx.crd2idx``. The scalar result is
    cast to ``index`` if it is an ``ir.Value`` of another type, then
    wrapped via ``_wrap``.

    Static path: returns the sum of ``coord * stride`` over all dimensions
    with a non-zero stride (the multiply is skipped for stride 1), built
    with plain arith ops. Coordinates and strides are paired with ``zip``,
    so extra entries on either side are ignored. When nothing contributes,
    ``arith.index(0)`` is returned.
    """
    if not isinstance(crd, (list, tuple)):
        crd = [crd]
    parsed = _parse_layout(layout)

    if parsed is None or _has_dynamic_strides(parsed[1]):
        crd_i32 = [_coord_to_i32(c) for c in crd]
        coord_val = fx.make_coord(*crd_i32)
        result = fx.crd2idx(coord_val, layout)
        scalar = fx.get_scalar(result)
        if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
            scalar = arith.index_cast(T.index, scalar)
        return _wrap(scalar)

    _, strides = parsed
    result = None
    for coord_v, stride_v in _builtins.zip(crd, strides):
        if stride_v == 0:
            # Stride-0 dimensions do not contribute to the offset.
            continue
        term = coord_v if stride_v == 1 else coord_v * arith.index(stride_v)
        result = term if result is None else result + term
    return result if result is not None else arith.index(0)


def get(int_tuple, mode):
    """Return the entry of the Python list/tuple *int_tuple* at position *mode*."""
    element = int_tuple[mode]
    return element
13 changes: 1 addition & 12 deletions kernels/mfma_epilogues.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,13 @@

from __future__ import annotations

from contextlib import contextmanager
from typing import Callable

from flydsl._mlir import ir
import flydsl.expr as fx
from flydsl.expr.typing import T


@contextmanager
def _if_then(if_op, scf):
    """Build the then-region of an SCF ``IfOp`` inside a ``with`` statement.

    Compat helper across old/new Python APIs: the insertion point is moved
    to ``if_op.then_block`` and that block is yielded. On exit (normal or
    via an exception) an empty ``scf.YieldOp`` is appended unless the block
    already ends with one. *scf* is the dialect module providing ``YieldOp``.
    """
    with ir.InsertionPoint(if_op.then_block):
        try:
            yield if_op.then_block
        finally:
            then_blk = if_op.then_block
            # Add a terminator only when the block is empty or its last
            # operation is not already a YieldOp.
            already_terminated = bool(then_blk.operations) and isinstance(
                then_blk.operations[-1], scf.YieldOp
            )
            if not already_terminated:
                scf.YieldOp([])
from kernels.kernels_common import _if_then


def default_epilog(
Expand Down
Loading
Loading