Changes from all commits
37 commits
8a797b1
[FLYDSL]: if_dispatch dynamic process
xudoyuan Mar 30, 2026
5a2d2aa
[FLYDSL]: Fix for local variable issues in if/else statements
xudoyuan Mar 30, 2026
ab074ea
Merge branch 'main' into yxd/if_dispatch_dynamic
xudoyuan Mar 30, 2026
4e075dc
[FLYDSL]: Derivation of dynamic if/else results
xudoyuan Apr 2, 2026
aab6097
[FLYDSL]: _could_be_dynamic Reconstruction
xudoyuan Apr 3, 2026
938992e
[FLYDSL]: Only test and verification
xudoyuan Apr 3, 2026
bc65391
[FLYDSL]: Standardized use cases
xudoyuan Apr 7, 2026
652fcc3
[FLYDSL]: Supplement the static condition of ast.Compare
xudoyuan Apr 7, 2026
7f1cfa1
[FLYDSL]: dsl_not_/dsl_and_/dsl_or_ is dynamic.
xudoyuan Apr 8, 2026
7faf35f
[FLYDSL]: Fixed the issue of type merging in Python's multi-type retu…
xudoyuan Apr 8, 2026
95ea586
Merge branch 'main' into yxd/if_dispatch_dynamic_tests_refactor
xudoyuan Apr 8, 2026
c8fbb43
[FLYDSL]: Refactor the bitcast code
xudoyuan Apr 9, 2026
acc4b98
[FLYDSL]: moe_blockscale_2stage.py variable definition
xudoyuan Apr 9, 2026
427b66c
Merge branch 'main' into yxd/if_dispatch_dynamic_tests_refactor
xudoyuan Apr 9, 2026
e73fa4f
Merge branch 'main' into yxd/if_dispatch_dynamic_tests_refactor
xudoyuan Apr 9, 2026
0e034a5
[FLYDSL]: kernels/moe_gemm_2stage.py refactor
xudoyuan Apr 9, 2026
30dc632
[FLYDSL]: kernels/preshuffle_gemm.py refactor
xudoyuan Apr 10, 2026
d2dcc3f
[FLYDSL]: if const_expr(ast.Name)
xudoyuan Apr 10, 2026
3f82c18
[FLYDSL]: if const_expr(ast.compare(ast.name))
xudoyuan Apr 10, 2026
0415671
[FLYDSL]: Non-None initialization + if dynamic(cond)
xudoyuan Apr 10, 2026
b982dbd
[FLYDSL]: Example parameters
xudoyuan Apr 10, 2026
2c740a0
[FLYDSL]: gfx950 const_expr
xudoyuan Apr 10, 2026
0d45a29
[FLYDSL]: rm import
xudoyuan Apr 10, 2026
0f1f913
Merge branch 'main' into yxd/if_dispatch_dynamic_tests_refactor
xudoyuan Apr 10, 2026
efef3d6
[FLYDSL]: Add initialization
xudoyuan Apr 10, 2026
c50606b
[FLYDSL]: MI355 const_expr cases
xudoyuan Apr 10, 2026
723056d
[FLYDSL]: MI355 add initialization
Apr 10, 2026
d2e49c2
[FLYDSL]: a0 a1 init
xudoyuan Apr 10, 2026
3c7c39d
[FLYDSL]: MI355 const_expr
xudoyuan Apr 11, 2026
48d8e5b
[FLYDSL]: const_expr
xudoyuan Apr 11, 2026
461c9cd
[FLYDSL]: const_expr(_fp4_tilek128)
xudoyuan Apr 11, 2026
c4deaaf
[FLYDSL]: Cases Normalization
xudoyuan Apr 12, 2026
82dc381
[FLYDSL]: rm notes
xudoyuan Apr 12, 2026
a98680f
[FLYDSL]: add if/else ST
xudoyuan Apr 12, 2026
600b09d
Support kwargs to set atom_state
sjfeng1999 Apr 13, 2026
c82a757
Merge branch 'main' into yxd/if_dispatch_dynamic_tests_refactor
xudoyuan Apr 14, 2026
38297fd
[FLYDSL]: import overload
xudoyuan Apr 14, 2026
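
Taken together, the commits implement `if` dispatch for the tracer: a condition wrapped in `const_expr(...)` must be a Python-level constant and the branch is folded while tracing, while any other condition is treated as dynamic and lowered to runtime control flow, which in turn requires values assigned inside such a branch to be bound beforehand (the "Non-None initialization" commits). The snippet below is a self-contained toy model of that dispatch decision; `Dyn`, `const_expr`, and `trace_if` are illustrative names, not the flydsl implementation.

# Toy model of static-vs-dynamic `if` dispatch (illustrative only, not flydsl).
from dataclasses import dataclass

@dataclass
class Dyn:
    """A value only known at kernel run time (stands in for a traced IR value)."""
    name: str

def const_expr(cond):
    """Mark a condition as compile time: it must reduce to a plain Python bool."""
    if isinstance(cond, Dyn):
        raise TypeError("const_expr() applied to a runtime value")
    return bool(cond)

def trace_if(cond, then_fn, else_fn):
    """Fold the branch when the condition is static; trace both arms when dynamic."""
    if isinstance(cond, bool):                      # static: only one arm is ever traced
        return then_fn() if cond else else_fn()
    return ("select", cond, then_fn(), else_fn())   # dynamic: both arms plus a merge

# Static flag: the branch disappears at trace time.
print(trace_if(const_expr(16 == 16), lambda: "dwordx4 load", lambda: "dword load"))
# Dynamic condition: both arms are materialized behind a select-like node.
print(trace_if(Dyn("tid < n"), lambda: 1, lambda: 0))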
11 changes: 8 additions & 3 deletions examples/04-preshuffle_gemm.py
@@ -73,20 +73,25 @@ def gemm_kernel(
mma_frag_C_f16 = fx.make_fragment_like(mma_frag_C, fx.Float16.ir_type)
mma_frag_C_retile = thr_copy_r2g_C.retile(mma_frag_C_f16)

gA_k_stride = fx.get_scalar(gA_k.stride[2])
gB_k_stride = fx.get_scalar(gB_k.stride[2])

def run_pipeline_stage(read_stage, next_k, read_next=True):
write_stage = read_stage ^ 1

if read_next:
next_k = fx.Int32(next_k)
fx.copy(
buffer_copy_128b.set_value("soffset", next_k * BLOCK_K),
thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom
buffer_copy_128b,
thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom
copy_frag_A,
soffset=next_k * gA_k_stride,
)
fx.copy(
buffer_copy_128b,
thr_gB_k[None, None, None, next_k],
thr_gB_k[None, None, None, 0],
mma_frag_B_retile[None, None, None, write_stage],
soffset=next_k * gB_k_stride,
)

for block_k_iter in fx.range_constexpr(BLOCK_K // 32):
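The hunk above replaces the mutating `buffer_copy_128b.set_value("soffset", ...)` call with a per-call `soffset=` keyword argument computed from the k-stride read once via `fx.get_scalar`, in line with the "Support kwargs to set atom_state" commit. A minimal plain-Python sketch of that calling-convention change follows; `CopyAtom` and `copy` are stand-ins, not the flydsl API.

# Stand-in sketch: per-call kwargs instead of mutating shared copy-atom state.
class CopyAtom:
    def __init__(self):
        self.state = {}

    def set_value(self, key, val):          # old style: mutate the shared atom
        self.state[key] = val
        return self

def copy(atom, src, dst, **atom_state):
    # New style: per-call state arrives as kwargs and is merged on use,
    # so the shared atom is never mutated between pipeline stages.
    return (src, dst, {**atom.state, **atom_state})

atom = CopyAtom()
BLOCK_K = 64                                # tile size used by the old soffset formula
gA_k_stride = 8192                          # stands in for fx.get_scalar(gA_k.stride[2])
next_k = 2
print(copy(atom.set_value("soffset", next_k * BLOCK_K), "thr_gA_k", "copy_frag_A"))  # old style
print(copy(atom, "thr_gA_k", "copy_frag_A", soffset=next_k * gA_k_stride))           # new style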
84 changes: 49 additions & 35 deletions kernels/blockscale_preshuffle_gemm.py
@@ -321,44 +321,44 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer):
c_chunk_a = fx.Index(chunk_i32_a)
tx_i32_base = tx * c_chunk_a

def load_a(idx_i32):
if a_load_bytes == 16:
def load_a(idx_i32, a_load_bytes_v):
if const_expr(a_load_bytes_v == 16):
return buffer_copy_gmem16_dwordx4(
buffer_ops, vector,
elem_type=T.f8, idx_i32=idx_i32,
rsrc=a_rsrc, vec_elems=16, elem_bytes=elem_bytes,
)
if a_load_bytes == 8:
if const_expr(a_load_bytes_v == 8):
return buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=2, dtype=T.i32)
return buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=1, dtype=T.i32)

def a_tile_chunk_coord_i32(i: int):
def a_tile_chunk_coord_i32(i: int, tx_i32_base_v, chunk_i32_a_v):
return tile_chunk_coord_i32(
arith, tx_i32_base=tx_i32_base, i=i,
arith, tx_i32_base=tx_i32_base_v, i=i,
total_threads=total_threads,
layout_tile_div4=layout_a_tile_div4,
chunk_i32=chunk_i32_a,
chunk_i32=chunk_i32_a_v,
)

def load_a_tile(base_k_div4):
def load_a_tile(base_k_div4, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v):
parts = []
for i in range_constexpr(num_a_loads):
row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i)
row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i, tx_i32_base_v, chunk_i32_a_v)
row_a_global = bx_m + row_a_local
idx_i32 = row_a_global * _k_div4_factor + (base_k_div4 + col_a_local_i32)
a_vec = load_a(idx_i32)
if a_load_bytes == 16:
a_vec = load_a(idx_i32, a_load_bytes_v)
if const_expr(a_load_bytes_v == 16):
parts.append(vector.bitcast(T.i32x4, a_vec))
else:
parts.append(a_vec)
return parts

c4_bytes = fx.Index(4) # bytes per dword (always 4, used for LDS byte addressing)

def store_a_tile_to_lds(vec_a_parts, lds_buffer):
def store_a_tile_to_lds(vec_a_parts, lds_buffer, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v):
for i in range_constexpr(num_a_loads):
row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i)
if a_load_bytes == 16:
row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i, tx_i32_base_v, chunk_i32_a_v)
if const_expr(a_load_bytes_v == 16):
lds_store_16b_xor16(
arith, vector,
lds_memref=lds_buffer, vec16_ty=T.f8x16,
@@ -368,7 +368,7 @@ def store_a_tile_to_lds(vec_a_parts, lds_buffer):
lds_base=fx.Index(0),
vec_part_i32x4=vec_a_parts[i], elem_bytes=elem_bytes,
)
elif a_load_bytes == 8:
elif const_expr(a_load_bytes_v == 8):
lds_store_8b_xor16(
arith, vector,
lds_memref=lds_buffer, vec8_ty=T.f8x8,
@@ -431,17 +431,22 @@ def prefetch_a_to_lds(base_k, lds_buffer):
base_k_div4 = base_k // 4
dma_a_tile_to_lds(base_k_div4, lds_buffer)

def prefetch_a_tile(base_k):
def prefetch_a_tile(base_k, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v):
base_k_div4 = base_k // 4
return load_a_tile(base_k_div4)
return load_a_tile(base_k_div4, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v)

def prefetch_b_tile(base_k):
return load_b_tile(base_k)

# ── MFMA ──────────────────────────────────────────────────────────
mfma_res_ty = T.f32x4

def _mfma_fn_placeholder(*args, **kwargs):
raise RuntimeError("mfma_fn placeholder should be overwritten before use")

if _is_gfx950:
mfma_fn = _mfma_fn_placeholder

if const_expr(_is_gfx950):
c0_i64 = arith.constant(0, type=T.i64)

def pack_i64x4_to_i32x8(x0, x1, x2, x3):
@@ -450,12 +455,12 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3):
else:
mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8

def mfma_step(acc_in, a, b):
return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0])
def mfma_step(acc_in, a, b):
return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0])

def mfma_k64_bytes(acc_in, a0, a1, b0, b1):
acc_mid = mfma_step(acc_in, a0, b0)
return mfma_step(acc_mid, a1, b1)
def mfma_k64_bytes(acc_in, a0, a1, b0, b1):
acc_mid = mfma_step(acc_in, a0, b0)
return mfma_step(acc_mid, a1, b1)

# ── Blockscale compute tile ───────────────────────────────────────
from flydsl._mlir.dialects import math as math_dialect
@@ -516,7 +521,7 @@ def compute_tile_blockscale(
combined_scales = pre_scales[sb]
block_accs = [acc_init] * (num_acc_n * m_repeat)

if _is_gfx950:
if const_expr(_is_gfx950):
ku0 = sb * ku_per_sb
ku1 = ku0 + 1
b0_packs0, b0_packs1 = b_tile_in[ku0]
@@ -526,6 +531,8 @@

for mi in range_constexpr(m_repeat):
curr_row_a_lds = row_a_lds + (mi * 16)
a0 = ArithValue(arith.constant(-1, type=T.i64))
a1 = ArithValue(arith.constant(-1, type=T.i64))
if a0_prefetch is not None and sb == 0 and mi == 0:
a0, a1 = a0_prefetch
else:
@@ -553,6 +560,9 @@

for mi in range_constexpr(m_repeat):
curr_row_a_lds = row_a_lds + (mi * 16)
a0, a1 = lds_load_packs_k64(
curr_row_a_lds, col_base, lds_buffer
)

if (
a0_prefetch is not None
@@ -561,10 +571,6 @@
and mi == 0
):
a0, a1 = a0_prefetch
else:
a0, a1 = lds_load_packs_k64(
curr_row_a_lds, col_base, lds_buffer
)

for ni in range_constexpr(num_acc_n):
acc_idx = mi * num_acc_n + ni
@@ -673,6 +679,7 @@ def body_row(*, mi, ii, row_in_tile, row):

def hot_loop_scheduler():
mfma_group = num_acc_n
mfma_total = -1
if _is_gfx950:
mfma_total = sb_per_tile * m_repeat * mfma_group
else:
@@ -715,29 +722,36 @@ def hot_loop_scheduler():
def prefetch_a0_pack(lds_buffer):
return lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_buffer)

def _load_a_to_lds(base_k, lds_buffer):
def _load_a_to_lds(base_k, lds_buffer, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v):
if use_async_copy:
prefetch_a_to_lds(base_k, lds_buffer)
else:
store_a_tile_to_lds(prefetch_a_tile(base_k), lds_buffer)
store_a_tile_to_lds(
prefetch_a_tile(base_k, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v),
lds_buffer,
a_load_bytes_v,
tx_i32_base_v,
chunk_i32_a_v,
)

# ── Main pipeline: prologue ───────────────────────────────────────
k0 = fx.Index(0)
b_tile_pong = prefetch_b_tile(k0)
scales_pong = load_scales_for_tile(k0)
_load_a_to_lds(k0, lds_a_pong)
_load_a_to_lds(k0, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a)
gpu.barrier()
global_accs = [acc_init] * (num_acc_n * m_repeat)

a0_prefetch_pong = prefetch_a0_pack(lds_a_pong)

num_tiles = K // tile_k
final_accs = global_accs

if (num_tiles % 2) == 1:
for k_iv in range_constexpr(0, K - tile_k, tile_k * 2):
_k = fx.Index(k_iv)
next_k1 = _k + tile_k
_load_a_to_lds(next_k1, lds_a_ping)
_load_a_to_lds(next_k1, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a)
b_tile_ping = prefetch_b_tile(next_k1)
scales_ping = load_scales_for_tile(next_k1)

@@ -754,7 +768,7 @@ def _load_a_to_lds(base_k, lds_buffer):
a0_prefetch_ping = prefetch_a0_pack(lds_a_ping)

next_k2 = _k + tile_k * 2
_load_a_to_lds(next_k2, lds_a_pong)
_load_a_to_lds(next_k2, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a)
b_tile_pong = prefetch_b_tile(next_k2)
scales_pong = load_scales_for_tile(next_k2)

@@ -779,7 +793,7 @@ def _load_a_to_lds(base_k, lds_buffer):
for k_iv in range_constexpr(0, K - tile_k * 3, tile_k * 2):
_k = fx.Index(k_iv)
next_k1 = _k + tile_k
_load_a_to_lds(next_k1, lds_a_ping)
_load_a_to_lds(next_k1, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a)
b_tile_ping = prefetch_b_tile(next_k1)
scales_ping = load_scales_for_tile(next_k1)

@@ -796,7 +810,7 @@ def _load_a_to_lds(base_k, lds_buffer):
a0_prefetch_ping = prefetch_a0_pack(lds_a_ping)

next_k2 = _k + tile_k * 2
_load_a_to_lds(next_k2, lds_a_pong)
_load_a_to_lds(next_k2, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a)
b_tile_pong = prefetch_b_tile(next_k2)
scales_pong = load_scales_for_tile(next_k2)

@@ -815,7 +829,7 @@ def _load_a_to_lds(base_k, lds_buffer):
last_k = arith.index(K - tile_k)
second_last_k = arith.index(K - tile_k * 2)

_load_a_to_lds(last_k, lds_a_ping)
_load_a_to_lds(last_k, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a)
b_tile_ping = prefetch_b_tile(last_k)
scales_ping = load_scales_for_tile(last_k)

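A recurring pattern in the hunks above is the "Non-None initialization" rule from the commit list: values consumed after a branch that may stay dynamic (`a0`/`a1`, `mfma_total`, `final_accs`) are now bound before the `if`, so both arms define them when the branch is lowered instead of folded. A toy illustration of why the merge needs that pre-binding (plain Python, not the flydsl tracer):

# Toy illustration: merging the two arms of a lowered (dynamic) if/else.
def merge_dynamic_branch(then_val, else_val):
    if else_val is None:
        # With no binding on the untaken path there is nothing to merge,
        # so the tracer would have to reject the program.
        raise ValueError("value undefined on one path of a dynamic branch")
    return ("select", then_val, else_val)

try:  # without pre-initialization the else arm contributes no value
    merge_dynamic_branch(then_val="a0_from_prefetch", else_val=None)
except ValueError as err:
    print("rejected:", err)

a0_init = -1  # analogue of a0 = ArithValue(arith.constant(-1, type=T.i64))
print(merge_dynamic_branch(then_val="a0_from_prefetch", else_val=a0_init))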
2 changes: 1 addition & 1 deletion kernels/fused_rope_cache_kernel.py
@@ -243,7 +243,7 @@ def k_cache_kernel(
i32_reg_ty = fx.MemRefType.get(T.i32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register)
i32_reg_lay = fx.make_layout(1, 1)

if not flash_layout:
if const_expr(not flash_layout):
copy_atom_elem = fx.make_copy_atom(fx.rocdl.BufferCopy16b(), elem_bits)
elem_reg_ty = fx.MemRefType.get(
elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register
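The single change here wraps a host-side Python boolean in `const_expr` so the layout branch is resolved at trace time; per the commit list, `not`/`and`/`or` and comparisons over traced values stay dynamic, so only conditions built purely from Python constants qualify. A tiny self-contained sketch of that classification rule (illustrative names only):

# Toy classification: only conditions over Python-level constants are static.
class Dyn:                                  # stands in for a traced runtime value
    pass

def is_static(cond) -> bool:
    return not isinstance(cond, Dyn)

flash_layout = False                        # host-side compile-time flag
print(is_static(not flash_layout))          # True: the branch folds at trace time
print(is_static(Dyn()))                     # False: would lower to runtime control flow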
14 changes: 7 additions & 7 deletions kernels/hgemm_splitk.py
@@ -234,11 +234,11 @@ def hgemm_kernel(
bs_ = STensor(smem_b_ptr, dtype_, shape=(STAGES, BLOCK_N, BLOCK_K))
smem_c_ptr = SmemPtr(base_ptr, smem_a_offset, dtype_, shape=(BLOCK_M * BLOCK_N,))
cs_ = STensor(smem_c_ptr, dtype_, shape=(BLOCK_M, BLOCK_N))
if B_PRE_SHUFFLE:
if const_expr(B_PRE_SHUFFLE):
# origin: n // WARP_ATOM_N, WARP_ATOM_N, k // WARP_ATOM_K, WARP_ATOM_K // LDG_VEC_SIZE, LDG_VEC_SIZE
SHUFFLED_B_ = GTensor(B, dtype=dtype_, shape=(
n // WARP_ATOM_N, k // WARP_ATOM_K, WARP_ATOM_K // LDG_VEC_SIZE, WARP_ATOM_N, LDG_VEC_SIZE))
if IS_SPLIT_K:
if const_expr(IS_SPLIT_K):
COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,))

tid = fx.Int32(fx.thread_idx.x)
@@ -504,7 +504,7 @@ def ldg_matrix_b(k_offset):
b_k0 = b_k0_base + kk
for ii in range_constexpr(WARP_N_STEPS):
b_n0 = b_n0_base + ii
if not B_PRE_SHUFFLE:
if const_expr(not B_PRE_SHUFFLE):
warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N
warp_atom_k_idx = kk * WARP_ATOM_K
n_idx = n_offset + warp_atom_n_idx + ldmatrix_b_n_idx
@@ -551,7 +551,7 @@ def block_mma_sync(a_frags, b_frags, c_frags):
if IS_SPLIT_K:
zero_c()

if B_TO_LDS:
if const_expr(B_TO_LDS):

sts_a(ldg_a(ks_begin), 0)
sts_b(ldg_b(ks_begin), 0)
@@ -623,7 +623,7 @@ def hot_loop_scheduler():
# for i in range_constexpr(WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K):
# rocdl.sched_mfma(1)
# ================ Reordered ================
if ASYNC_COPY:
if const_expr(ASYNC_COPY):
AVG_MFMA_COUNT = (MFMA_TOTAL + LDG_TOTAL - 1) // LDG_TOTAL
for i in range_constexpr(LDG_TOTAL):
rocdl.sched_vmem(ldg_.consume(1))
@@ -646,13 +646,13 @@ def hot_loop_scheduler():
c_frags = state[2 : 2 + C_FRAGS_LEN]
a_frags = state[2 + C_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN]
b_frags = state[2 + C_FRAGS_LEN + A_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN + B_FRAGS_LEN]
if ASYNC_COPY:
if const_expr(ASYNC_COPY):
ldg_sts_a_async(k_offset + BLOCK_K, next_stage)
else:
a_regs_next = ldg_a(k_offset + BLOCK_K)
b_frags_next = ldg_matrix_b(k_offset + BLOCK_K)
block_mma_sync(a_frags, b_frags, c_frags)
if not ASYNC_COPY:
if const_expr(not ASYNC_COPY):
sts_a(a_regs_next, next_stage)
hot_loop_scheduler()
gpu.barrier()
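The flags touched in this file (`B_PRE_SHUFFLE`, `IS_SPLIT_K`, `B_TO_LDS`, `ASYNC_COPY`) are host-side Python booleans, so marking them with `const_expr` specializes each kernel variant at trace time and drops the untaken branches from the emitted code. A plain-Python sketch of that specialization effect, using hypothetical names rather than the flydsl tracer:

# Toy specialization: a static flag selects the emitted pipeline at build time.
def build_mainloop(async_copy: bool):
    ops = []
    ops.append("ldg_sts_a_async" if async_copy else "ldg_a")  # folded at build time
    ops.append("block_mma_sync")
    if not async_copy:                       # only present in the synchronous variant
        ops.append("sts_a")
    ops.append("gpu.barrier")
    return ops

print(build_mainloop(async_copy=True))   # ['ldg_sts_a_async', 'block_mma_sync', 'gpu.barrier']
print(build_mainloop(async_copy=False))  # ['ldg_a', 'block_mma_sync', 'sts_a', 'gpu.barrier']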
18 changes: 11 additions & 7 deletions kernels/layernorm_kernel.py
@@ -104,16 +104,16 @@ def block_reduce_add2(val0, val1):

if lane == fx.Int32(0):
wave_idx = ArithValue(wave).index_cast(T.index)
s_sum.store(w0, [wave_idx])
s_sumsq.store(w1, [wave_idx])
SmemPtr.store(s_sum, w0, [wave_idx])
SmemPtr.store(s_sumsq, w1, [wave_idx])
gpu.barrier()

if wave == fx.Int32(0):
in_range = lane < RED_SLOTS
lane_safe = in_range.select(lane, fx.Int32(0))
lane_safe_idx = ArithValue(lane_safe).index_cast(T.index)
v0 = s_sum.load([lane_safe_idx])
v1 = s_sumsq.load([lane_safe_idx])
v0 = SmemPtr.load(s_sum, [lane_safe_idx])
v1 = SmemPtr.load(s_sumsq, [lane_safe_idx])
z = fx.Float32(0.0)
ww0 = in_range.select(v0, z)
ww1 = in_range.select(v1, z)
@@ -122,12 +122,12 @@ def block_reduce_add2(val0, val1):

if lane == fx.Int32(0):
c0_idx = fx.Index(0)
s_sum.store(ww0, [c0_idx])
s_sumsq.store(ww1, [c0_idx])
SmemPtr.store(s_sum, ww0, [c0_idx])
SmemPtr.store(s_sumsq, ww1, [c0_idx])
gpu.barrier()

c0_idx = fx.Index(0)
return s_sum.load([c0_idx]), s_sumsq.load([c0_idx])
return SmemPtr.load(s_sum, [c0_idx]), SmemPtr.load(s_sumsq, [c0_idx])

def compute_mean_rstd(sum_val, sumsq_val):
inv_n = arith.constant(1.0 / float(N), type=compute_type)
@@ -209,6 +209,8 @@ def _store_vec(val, div_tensor, idx):

# ── Pass 2: normalize + affine + store ───────────────────────
for tile_i in range_constexpr(num_tiles_py):
g_next = g_cur
b_next = b_cur
if tile_i + 1 < num_tiles_py:
next_idx = tid + (tile_i + 1) * BLOCK_THREADS
g_next = _load_vec(gamma_div, next_idx).to(Float32)
@@ -221,6 +223,7 @@ def _store_vec(val, div_tensor, idx):
y = (x - mean) * rstd
y = y * g_cur + b_cur

out_e = y.to(elem_dtype)
if dtype_str == "bf16":
if USE_HW_CVT_PK_BF16_F32:
out_e = y.to(elem_dtype)
@@ -342,6 +345,7 @@ def _store_scalar(divided_tensor, index, val):
norm = diff * rstd
scaled = norm * g
y = scaled + b
y_e = y
if dtype_str == "bf16":
y_e = y.truncf(elem_type)
elif dtype_str == "f32":
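The layernorm changes apply the same pre-initialization rule to the dtype-dependent conversions: `out_e`, `y_e`, and the `g_next`/`b_next` prefetch values get a default binding before their `if`, so every path of a possibly dynamic branch defines them. A compact plain-Python analogue of the `y_e` case (illustrative only):

# Pre-binding a result before a dtype-dependent branch (toy analogue of y_e).
def convert(y: float, dtype_str: str):
    y_e = y                     # default binding: every path defines y_e
    if dtype_str == "bf16":
        y_e = round(y, 2)       # stands in for y.truncf(elem_type)
    elif dtype_str == "f32":
        pass                    # default binding already has the right type
    return y_e

print(convert(1.23456, "bf16"), convert(1.23456, "f32"))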