diff --git a/examples/04-preshuffle_gemm.py b/examples/04-preshuffle_gemm.py index d47b6956..cc049af8 100644 --- a/examples/04-preshuffle_gemm.py +++ b/examples/04-preshuffle_gemm.py @@ -75,20 +75,25 @@ def gemm_kernel( mma_frag_C_f16 = fx.make_fragment_like(mma_frag_C, fx.Float16.ir_type) mma_frag_C_retile = thr_copy_r2g_C.retile(mma_frag_C_f16) + gA_k_stride = fx.get_scalar(gA_k.stride[2]) + gB_k_stride = fx.get_scalar(gB_k.stride[2]) + def run_pipeline_stage(read_stage, next_k, read_next=True): write_stage = read_stage ^ 1 if read_next: next_k = fx.Int32(next_k) fx.copy( - buffer_copy_128b.set_value("soffset", next_k * BLOCK_K), - thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom + buffer_copy_128b, + thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom copy_frag_A, + soffset=next_k * gA_k_stride, ) fx.copy( buffer_copy_128b, - thr_gB_k[None, None, None, next_k], + thr_gB_k[None, None, None, 0], mma_frag_B_retile[None, None, None, write_stage], + soffset=next_k * gB_k_stride, ) for block_k_iter in fx.range_constexpr(BLOCK_K // 32): diff --git a/kernels/blockscale_preshuffle_gemm.py b/kernels/blockscale_preshuffle_gemm.py index 2c9e5ca0..777a215b 100644 --- a/kernels/blockscale_preshuffle_gemm.py +++ b/kernels/blockscale_preshuffle_gemm.py @@ -321,33 +321,33 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): c_chunk_a = fx.Index(chunk_i32_a) tx_i32_base = tx * c_chunk_a - def load_a(idx_i32): - if a_load_bytes == 16: + def load_a(idx_i32, a_load_bytes_v): + if const_expr(a_load_bytes_v == 16): return buffer_copy_gmem16_dwordx4( buffer_ops, vector, elem_type=T.f8, idx_i32=idx_i32, rsrc=a_rsrc, vec_elems=16, elem_bytes=elem_bytes, ) - if a_load_bytes == 8: + if const_expr(a_load_bytes_v == 8): return buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=1, dtype=T.i32) - def a_tile_chunk_coord_i32(i: int): + def a_tile_chunk_coord_i32(i: int, tx_i32_base_v, chunk_i32_a_v): return tile_chunk_coord_i32( - arith, tx_i32_base=tx_i32_base, i=i, + arith, tx_i32_base=tx_i32_base_v, i=i, total_threads=total_threads, layout_tile_div4=layout_a_tile_div4, - chunk_i32=chunk_i32_a, + chunk_i32=chunk_i32_a_v, ) - def load_a_tile(base_k_div4): + def load_a_tile(base_k_div4, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): parts = [] for i in range_constexpr(num_a_loads): - row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i) + row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i, tx_i32_base_v, chunk_i32_a_v) row_a_global = bx_m + row_a_local idx_i32 = row_a_global * _k_div4_factor + (base_k_div4 + col_a_local_i32) - a_vec = load_a(idx_i32) - if a_load_bytes == 16: + a_vec = load_a(idx_i32, a_load_bytes_v) + if const_expr(a_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, a_vec)) else: parts.append(a_vec) @@ -355,10 +355,10 @@ def load_a_tile(base_k_div4): c4_bytes = fx.Index(4) # bytes per dword (always 4, used for LDS byte addressing) - def store_a_tile_to_lds(vec_a_parts, lds_buffer): + def store_a_tile_to_lds(vec_a_parts, lds_buffer, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): for i in range_constexpr(num_a_loads): - row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i) - if a_load_bytes == 16: + row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32(i, tx_i32_base_v, chunk_i32_a_v) + if const_expr(a_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, lds_memref=lds_buffer, vec16_ty=T.f8x16, @@ -368,7 +368,7 @@ def store_a_tile_to_lds(vec_a_parts, lds_buffer): lds_base=fx.Index(0), vec_part_i32x4=vec_a_parts[i], elem_bytes=elem_bytes, ) - elif a_load_bytes == 8: + elif const_expr(a_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, lds_memref=lds_buffer, vec8_ty=T.f8x8, @@ -431,17 +431,22 @@ def prefetch_a_to_lds(base_k, lds_buffer): base_k_div4 = base_k // 4 dma_a_tile_to_lds(base_k_div4, lds_buffer) - def prefetch_a_tile(base_k): + def prefetch_a_tile(base_k, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): base_k_div4 = base_k // 4 - return load_a_tile(base_k_div4) + return load_a_tile(base_k_div4, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v) def prefetch_b_tile(base_k): return load_b_tile(base_k) # ── MFMA ────────────────────────────────────────────────────────── mfma_res_ty = T.f32x4 + + def _mfma_fn_placeholder(*args, **kwargs): + raise RuntimeError("mfma_fn placeholder should be overwritten before use") - if _is_gfx950: + mfma_fn = _mfma_fn_placeholder + + if const_expr(_is_gfx950): c0_i64 = arith.constant(0, type=T.i64) def pack_i64x4_to_i32x8(x0, x1, x2, x3): @@ -450,12 +455,12 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): else: mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 - def mfma_step(acc_in, a, b): - return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) + def mfma_step(acc_in, a, b): + return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) - def mfma_k64_bytes(acc_in, a0, a1, b0, b1): - acc_mid = mfma_step(acc_in, a0, b0) - return mfma_step(acc_mid, a1, b1) + def mfma_k64_bytes(acc_in, a0, a1, b0, b1): + acc_mid = mfma_step(acc_in, a0, b0) + return mfma_step(acc_mid, a1, b1) # ── Blockscale compute tile ─────────────────────────────────────── from flydsl._mlir.dialects import math as math_dialect @@ -516,7 +521,7 @@ def compute_tile_blockscale( combined_scales = pre_scales[sb] block_accs = [acc_init] * (num_acc_n * m_repeat) - if _is_gfx950: + if const_expr(_is_gfx950): ku0 = sb * ku_per_sb ku1 = ku0 + 1 b0_packs0, b0_packs1 = b_tile_in[ku0] @@ -526,6 +531,8 @@ def compute_tile_blockscale( for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0 = ArithValue(arith.constant(-1, type=T.i64)) + a1 = ArithValue(arith.constant(-1, type=T.i64)) if a0_prefetch is not None and sb == 0 and mi == 0: a0, a1 = a0_prefetch else: @@ -553,6 +560,9 @@ def compute_tile_blockscale( for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0, a1 = lds_load_packs_k64( + curr_row_a_lds, col_base, lds_buffer + ) if ( a0_prefetch is not None @@ -561,10 +571,6 @@ def compute_tile_blockscale( and mi == 0 ): a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64( - curr_row_a_lds, col_base, lds_buffer - ) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni @@ -673,6 +679,7 @@ def body_row(*, mi, ii, row_in_tile, row): def hot_loop_scheduler(): mfma_group = num_acc_n + mfma_total = -1 if _is_gfx950: mfma_total = sb_per_tile * m_repeat * mfma_group else: @@ -715,29 +722,36 @@ def hot_loop_scheduler(): def prefetch_a0_pack(lds_buffer): return lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_buffer) - def _load_a_to_lds(base_k, lds_buffer): + def _load_a_to_lds(base_k, lds_buffer, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v): if use_async_copy: prefetch_a_to_lds(base_k, lds_buffer) else: - store_a_tile_to_lds(prefetch_a_tile(base_k), lds_buffer) + store_a_tile_to_lds( + prefetch_a_tile(base_k, a_load_bytes_v, tx_i32_base_v, chunk_i32_a_v), + lds_buffer, + a_load_bytes_v, + tx_i32_base_v, + chunk_i32_a_v, + ) # ── Main pipeline: prologue ─────────────────────────────────────── k0 = fx.Index(0) b_tile_pong = prefetch_b_tile(k0) scales_pong = load_scales_for_tile(k0) - _load_a_to_lds(k0, lds_a_pong) + _load_a_to_lds(k0, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a) gpu.barrier() global_accs = [acc_init] * (num_acc_n * m_repeat) a0_prefetch_pong = prefetch_a0_pack(lds_a_pong) num_tiles = K // tile_k + final_accs = global_accs if (num_tiles % 2) == 1: for k_iv in range_constexpr(0, K - tile_k, tile_k * 2): _k = fx.Index(k_iv) next_k1 = _k + tile_k - _load_a_to_lds(next_k1, lds_a_ping) + _load_a_to_lds(next_k1, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_ping = prefetch_b_tile(next_k1) scales_ping = load_scales_for_tile(next_k1) @@ -754,7 +768,7 @@ def _load_a_to_lds(base_k, lds_buffer): a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) next_k2 = _k + tile_k * 2 - _load_a_to_lds(next_k2, lds_a_pong) + _load_a_to_lds(next_k2, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_pong = prefetch_b_tile(next_k2) scales_pong = load_scales_for_tile(next_k2) @@ -779,7 +793,7 @@ def _load_a_to_lds(base_k, lds_buffer): for k_iv in range_constexpr(0, K - tile_k * 3, tile_k * 2): _k = fx.Index(k_iv) next_k1 = _k + tile_k - _load_a_to_lds(next_k1, lds_a_ping) + _load_a_to_lds(next_k1, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_ping = prefetch_b_tile(next_k1) scales_ping = load_scales_for_tile(next_k1) @@ -796,7 +810,7 @@ def _load_a_to_lds(base_k, lds_buffer): a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) next_k2 = _k + tile_k * 2 - _load_a_to_lds(next_k2, lds_a_pong) + _load_a_to_lds(next_k2, lds_a_pong, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_pong = prefetch_b_tile(next_k2) scales_pong = load_scales_for_tile(next_k2) @@ -815,7 +829,7 @@ def _load_a_to_lds(base_k, lds_buffer): last_k = arith.index(K - tile_k) second_last_k = arith.index(K - tile_k * 2) - _load_a_to_lds(last_k, lds_a_ping) + _load_a_to_lds(last_k, lds_a_ping, a_load_bytes, tx_i32_base, chunk_i32_a) b_tile_ping = prefetch_b_tile(last_k) scales_ping = load_scales_for_tile(last_k) diff --git a/kernels/fused_rope_cache_kernel.py b/kernels/fused_rope_cache_kernel.py index d9258487..48d4c69c 100644 --- a/kernels/fused_rope_cache_kernel.py +++ b/kernels/fused_rope_cache_kernel.py @@ -243,7 +243,7 @@ def k_cache_kernel( i32_reg_ty = fx.MemRefType.get(T.i32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) i32_reg_lay = fx.make_layout(1, 1) - if not flash_layout: + if const_expr(not flash_layout): copy_atom_elem = fx.make_copy_atom(fx.rocdl.BufferCopy16b(), elem_bits) elem_reg_ty = fx.MemRefType.get( elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register diff --git a/kernels/hgemm_splitk.py b/kernels/hgemm_splitk.py index 723b4ab5..e91506b7 100644 --- a/kernels/hgemm_splitk.py +++ b/kernels/hgemm_splitk.py @@ -234,11 +234,11 @@ def hgemm_kernel( bs_ = STensor(smem_b_ptr, dtype_, shape=(STAGES, BLOCK_N, BLOCK_K)) smem_c_ptr = SmemPtr(base_ptr, smem_a_offset, dtype_, shape=(BLOCK_M * BLOCK_N,)) cs_ = STensor(smem_c_ptr, dtype_, shape=(BLOCK_M, BLOCK_N)) - if B_PRE_SHUFFLE: + if const_expr(B_PRE_SHUFFLE): # origin: n // WARP_ATOM_N, WARP_ATOM_N, k // WARP_ATOM_K, WARP_ATOM_K // LDG_VEC_SIZE, LDG_VEC_SIZE SHUFFLED_B_ = GTensor(B, dtype=dtype_, shape=( n // WARP_ATOM_N, k // WARP_ATOM_K, WARP_ATOM_K // LDG_VEC_SIZE, WARP_ATOM_N, LDG_VEC_SIZE)) - if IS_SPLIT_K: + if const_expr(IS_SPLIT_K): COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,)) tid = fx.Int32(fx.thread_idx.x) @@ -504,7 +504,7 @@ def ldg_matrix_b(k_offset): b_k0 = b_k0_base + kk for ii in range_constexpr(WARP_N_STEPS): b_n0 = b_n0_base + ii - if not B_PRE_SHUFFLE: + if const_expr(not B_PRE_SHUFFLE): warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N warp_atom_k_idx = kk * WARP_ATOM_K n_idx = n_offset + warp_atom_n_idx + ldmatrix_b_n_idx @@ -551,7 +551,7 @@ def block_mma_sync(a_frags, b_frags, c_frags): if IS_SPLIT_K: zero_c() - if B_TO_LDS: + if const_expr(B_TO_LDS): sts_a(ldg_a(ks_begin), 0) sts_b(ldg_b(ks_begin), 0) @@ -623,7 +623,7 @@ def hot_loop_scheduler(): # for i in range_constexpr(WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K): # rocdl.sched_mfma(1) # ================ Reordered ================ - if ASYNC_COPY: + if const_expr(ASYNC_COPY): AVG_MFMA_COUNT = (MFMA_TOTAL + LDG_TOTAL - 1) // LDG_TOTAL for i in range_constexpr(LDG_TOTAL): rocdl.sched_vmem(ldg_.consume(1)) @@ -646,13 +646,13 @@ def hot_loop_scheduler(): c_frags = state[2 : 2 + C_FRAGS_LEN] a_frags = state[2 + C_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN] b_frags = state[2 + C_FRAGS_LEN + A_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN + B_FRAGS_LEN] - if ASYNC_COPY: + if const_expr(ASYNC_COPY): ldg_sts_a_async(k_offset + BLOCK_K, next_stage) else: a_regs_next = ldg_a(k_offset + BLOCK_K) b_frags_next = ldg_matrix_b(k_offset + BLOCK_K) block_mma_sync(a_frags, b_frags, c_frags) - if not ASYNC_COPY: + if const_expr(not ASYNC_COPY): sts_a(a_regs_next, next_stage) hot_loop_scheduler() gpu.barrier() diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index 6f1441ea..642ebc79 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -104,16 +104,16 @@ def block_reduce_add2(val0, val1): if lane == fx.Int32(0): wave_idx = ArithValue(wave).index_cast(T.index) - s_sum.store(w0, [wave_idx]) - s_sumsq.store(w1, [wave_idx]) + SmemPtr.store(s_sum, w0, [wave_idx]) + SmemPtr.store(s_sumsq, w1, [wave_idx]) gpu.barrier() if wave == fx.Int32(0): in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, fx.Int32(0)) lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v0 = s_sum.load([lane_safe_idx]) - v1 = s_sumsq.load([lane_safe_idx]) + v0 = SmemPtr.load(s_sum, [lane_safe_idx]) + v1 = SmemPtr.load(s_sumsq, [lane_safe_idx]) z = fx.Float32(0.0) ww0 = in_range.select(v0, z) ww1 = in_range.select(v1, z) @@ -122,12 +122,12 @@ def block_reduce_add2(val0, val1): if lane == fx.Int32(0): c0_idx = fx.Index(0) - s_sum.store(ww0, [c0_idx]) - s_sumsq.store(ww1, [c0_idx]) + SmemPtr.store(s_sum, ww0, [c0_idx]) + SmemPtr.store(s_sumsq, ww1, [c0_idx]) gpu.barrier() c0_idx = fx.Index(0) - return s_sum.load([c0_idx]), s_sumsq.load([c0_idx]) + return SmemPtr.load(s_sum, [c0_idx]), SmemPtr.load(s_sumsq, [c0_idx]) def compute_mean_rstd(sum_val, sumsq_val): inv_n = arith.constant(1.0 / float(N), type=compute_type) @@ -209,6 +209,8 @@ def _store_vec(val, div_tensor, idx): # ── Pass 2: normalize + affine + store ─────────────────────── for tile_i in range_constexpr(num_tiles_py): + g_next = g_cur + b_next = b_cur if tile_i + 1 < num_tiles_py: next_idx = tid + (tile_i + 1) * BLOCK_THREADS g_next = _load_vec(gamma_div, next_idx).to(Float32) @@ -221,6 +223,7 @@ def _store_vec(val, div_tensor, idx): y = (x - mean) * rstd y = y * g_cur + b_cur + out_e = y.to(elem_dtype) if dtype_str == "bf16": if USE_HW_CVT_PK_BF16_F32: out_e = y.to(elem_dtype) @@ -342,6 +345,7 @@ def _store_scalar(divided_tensor, index, val): norm = diff * rstd scaled = norm * g y = scaled + b + y_e = y if dtype_str == "bf16": y_e = y.truncf(elem_type) elif dtype_str == "f32": diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index eb23631a..f2770b30 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -369,7 +369,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): arg_out, max_size=False, num_records_bytes=out_nbytes_i32 ) - if is_f16_a: + if const_expr(is_f16_a): sx_rsrc = None else: # A1 microscale: [sorted_rows, K/32] e8m0 bytes, packed as i32. @@ -384,7 +384,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) - if is_f16_b: + if const_expr(is_f16_b): sw_rsrc = None else: # W1 microscale: [experts * 2 * inter_dim, K/32] e8m0 bytes. @@ -413,7 +413,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): max_size=False, num_records_bytes=sorted_nbytes_i32, ) - if doweight_stage1 + if const_expr(doweight_stage1) else None ) @@ -766,7 +766,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): lds_store_16b_xor16( arith, vector, @@ -909,7 +909,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): curr_row_a_lds, col_base0, lds_base ) - if is_f8_a: + if const_expr(is_f8_a): col_base1 = col_base + 64 a2, a3 = lds_load_packs_k64( curr_row_a_lds, col_base1, lds_base @@ -1046,7 +1046,7 @@ def hot_loop_scheduler(): os.environ.get("FLYDSL_STAGE1_SKIP_COMPUTE", "0") == "1" ) - if k_main2_py > 0: + if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): k_iv = k_iv_py next_k1 = k_iv + tile_k @@ -1056,7 +1056,7 @@ def hot_loop_scheduler(): prefetch_ab_scale_tile(next_k1 // pack_K // 128) ) - if _skip_compute: + if const_expr(_skip_compute): store_x_tile_to_lds(x_regs_ping, lds_base_ping) gpu.barrier() a0_prefetch_ping = None @@ -1112,7 +1112,7 @@ def hot_loop_scheduler(): a0_prefetch_pong = None - if odd_k_tiles: + if const_expr(odd_k_tiles): acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( acc_gate, acc_up, @@ -1184,7 +1184,7 @@ def hot_loop_scheduler(): _mask_even_i32 = fx.Int32(0xFFFFFFFE) - if _use_cshuffle_epilog: + if const_expr(_use_cshuffle_epilog): if lds_out is None: raise RuntimeError( "CShuffle epilogue enabled but lds_out is not allocated/aliased." @@ -1215,7 +1215,7 @@ def write_row_to_lds( _t2 = fused2 & mask24_i32 # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load( sorted_w_rsrc, row_safe, vec_width=1, dtype=T.f32 ) @@ -1240,12 +1240,12 @@ def write_row_to_lds( vg = vg + gate_bias_list[ni] vu = vu + up_bias_list[ni] - if act == "swiglu": + if const_expr(act == "swiglu"): y = swiglu(vg, vu) else: y = silu(vg) * vu - if doweight_stage1: + if const_expr(doweight_stage1): y = y * tw lds_idx = row_base_lds + col_local @@ -1279,7 +1279,7 @@ def store_pair( idx0 = row_ctx col_i32 = arith.index_cast(T.i32, col_g0) idx_out = idx0 + col_i32 - if out_dtype == "fp8": + if const_expr(out_dtype == "fp8"): frag = vector.bitcast(vec4_f32, frag) frag0 = vector.extract( frag, static_position=[0], dynamic_position=[] @@ -1373,7 +1373,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): idx0 = (t2_safe * topk_i32_v + s2_safe) * inter_i32_local # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load( sorted_w_rsrc, row_safe, vec_width=1, dtype=T.f32 ) @@ -1398,12 +1398,12 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): vg = vg + gate_bias_list[ni] vu = vu + up_bias_list[ni] - if act == "swiglu": + if const_expr(act == "swiglu"): y = swiglu(vg, vu) else: y = silu(vg) * vu - if doweight_stage1: + if const_expr(doweight_stage1): y = y * tw y = arith.trunc_f(_out_elem_type(), y) @@ -1816,10 +1816,10 @@ def moe_gemm2( num_valid_idx = arith.index_cast(T.index, num_valid_i32) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16_a: + if const_expr(is_f16_a): sx_rsrc = None else: - if is_f4_a: + if const_expr(is_f4_a): # A2 microscale: packed i32 holding e8m0 bytes for [sorted_size, K/32]. c32 = fx.Index(32) kblk = k_in // c32 @@ -1837,7 +1837,7 @@ def moe_gemm2( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) - if is_f16_b: + if const_expr(is_f16_b): sw_rsrc = None else: # Weight microscale buffer (packed i32 holding e8m0 bytes). @@ -1911,7 +1911,8 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. - if is_f16_a: + x_load_bytes = 16 + if const_expr(is_f16_a): if bytes_per_thread_x % 16 != 0: raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" @@ -1971,7 +1972,7 @@ def load_x(idx_i32): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): idx_elem = ( idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) ) @@ -2034,9 +2035,9 @@ def load_x_tile(base_k): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32) - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): parts.append(vector.bitcast(vec4_i32, x_vec)) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes == 8): parts.append(vector.bitcast(vec2_i32, x_vec)) else: parts.append(vector.bitcast(vec1_i32, x_vec)) @@ -2185,7 +2186,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): lds_store_16b_xor16( arith, vector, @@ -2200,7 +2201,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes == 8): lds_store_8b_xor16( arith, vector, @@ -2273,8 +2274,8 @@ def compute_tile( epilogue_pf = None bias = None - if prefetch_epilogue: - if enable_bias: + if const_expr(prefetch_epilogue): + if const_expr(enable_bias): bias = [] for ni in range_constexpr(num_acc_n): global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 @@ -2285,7 +2286,7 @@ def compute_tile( ) ) tw_pf = None - if doweight_stage2: + if const_expr(doweight_stage2): tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * arith.index(4) ii_idx_list_pf = [ @@ -2347,6 +2348,8 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): mi_val = fx.Index(mi_idx * 16) curr_row_a_lds = row_a_lds + mi_val + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if ( (a0_prefetch is not None) and (k_idx == 0) @@ -2358,7 +2361,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): curr_row_a_lds, col_base0, lds_base ) - if is_f8_a: + if const_expr(is_f8_a): col_base1 = col_base + 64 a2, a3 = lds_load_packs_k64( curr_row_a_lds, col_base1, lds_base @@ -2419,19 +2422,19 @@ def hot_loop_scheduler(): rocdl.sched_dsrd(2) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) if num_acc_n < 4: rocdl.sched_dsrd(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_dsrd(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_mfma(1) @@ -2498,7 +2501,7 @@ def hot_loop_scheduler(): # When k_main2_py == 0 the loop body is empty; emitting an scf.for # would create a region whose internal SSA values cannot be used # by the post-loop tail code. - if k_main2_py > 0: + if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): k_iv = k_iv_py next_k1 = k_iv + tile_k @@ -2549,7 +2552,7 @@ def hot_loop_scheduler(): row_a_lds, col_offset_base, lds_base_pong ) - if odd_k_tiles: + if const_expr(odd_k_tiles): # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). acc, epilogue_pf = compute_tile( acc, @@ -2644,7 +2647,8 @@ def write_row_to_lds( num_acc_n: int, lds_out, ): - if doweight_stage2: + tw = arith.constant(1.0, type=T.f32) + if const_expr(doweight_stage2): tw_idx = (mi * 4) + ii if tw_pf is not None: tw = tw_pf[tw_idx] @@ -2659,10 +2663,10 @@ def write_row_to_lds( v = vector.extract( acc[acc_idx], static_position=[ii], dynamic_position=[] ) - if enable_bias: + if const_expr(enable_bias): v = v + bias_pf[ni] - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw v_out = arith.trunc_f(out_elem(), v) @@ -2697,8 +2701,8 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): col_i32 = arith.index_cast(T.i32, col_g0) idx_elem = idx0 + col_i32 idx_elem_even = idx_elem & mask_even_i32 - if _needs_global_atomic_bf16: - if bool(accumulate): + if const_expr(_needs_global_atomic_bf16): + if const_expr(bool(accumulate)): byte_off = idx_elem_even * c2_i32 byte_off_idx = arith.index_cast(T.index, byte_off) ptr_addr_idx = out_base_idx + byte_off_idx @@ -2718,7 +2722,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) else: byte_off = idx_elem_even * c2_i32 - if bool(accumulate): + if const_expr(bool(accumulate)): atomic_add_f16x2(frag, byte_off) else: buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) diff --git a/kernels/moe_blockscale_2stage.py b/kernels/moe_blockscale_2stage.py index a4fbc5d9..66549cb7 100644 --- a/kernels/moe_blockscale_2stage.py +++ b/kernels/moe_blockscale_2stage.py @@ -336,10 +336,11 @@ def silu(x): ) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16: - sx_rsrc = None - sw_rsrc = None - else: + x_load_bytes = 16 + + sx_rsrc = -1 + sw_rsrc = -1 + if not is_f16: # scale_x: [nblk_k_w1, tokens] f32 transposed -> total = nblk_k_w1 * tokens sx_nbytes_idx = arith.index(nblk_k_w1) * tokens_in * fx.Index(4) sx_rsrc = buffer_ops.create_buffer_resource( @@ -364,6 +365,7 @@ def silu(x): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. + x_load_bytes = 16 if is_f16: if bytes_per_thread_x % 16 != 0: raise ValueError( @@ -442,12 +444,12 @@ def x_tile_chunk_coord_i32(i: int): vec2_i32 = T.vec(2, T.i32) vec4_x = T.vec(4, x_elem) - def load_x(idx_i32): + def load_x(idx_i32, x_load_bytes_v): """Load `x_load_bytes` bytes from X (gmem) into regs. For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ - if x_load_bytes == 16: + if const_expr(x_load_bytes_v == 16): idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -458,20 +460,20 @@ def load_x(idx_i32): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes == 8: + if const_expr(x_load_bytes_v == 8): return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) - def load_x_tile(base_k): + def load_x_tile(base_k, x_load_bytes_v): """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" base_k_div4 = (base_k * arith.index(int(elem_bytes))) // fx.Index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - if x_load_bytes == 16: + x_vec = load_x(idx_i32, x_load_bytes_v) + if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes_v == 8): parts.append(x_vec) else: parts.append(x_vec) @@ -581,11 +583,11 @@ def load_b_tile(base_k, blk_list, intra_list): acc_up = [arith.constant_vector(0.0, T.f32x4)] * (num_acc_n * m_repeat) # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, @@ -600,7 +602,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, @@ -695,6 +697,8 @@ def load_scales_s1(k_base): _sw_shared_n = (n_per_wave <= 128) s_w_gate_vals = [] s_w_up_vals = [] + s_w_gate = fx.Float32(1.0) + s_w_up = fx.Float32(1.0) for ni in range_constexpr(num_acc_n): if ni == 0 or not _sw_shared_n: sw_gate_idx = _pre_n_block_gate[ni] * c_nblk_k_w1 + kb @@ -716,7 +720,7 @@ def compute_tile_bs_s1(acc_gate_in, acc_up_in, b_gate_tile_in, b_up_tile_in, current_up = list(acc_up_in) mfma_res_ty = T.f32x4 - if _is_gfx950: + if const_expr(_is_gfx950): def _pack128(x0, x1, x2, x3): v4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3]) return vector.bitcast(T.vec(8, T.i32), v4) @@ -733,6 +737,8 @@ def _pack128(x0, x1, x2, x3): col1 = col_offset_base_bytes + arith.index(ku1 * 64) for mi in range_constexpr(m_repeat): curr_row = row_a_lds + arith.index(mi * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if a0_prefetch is not None and sb == 0 and mi == 0: a0, a1 = a0_prefetch else: @@ -778,7 +784,7 @@ def _pack128(x0, x1, x2, x3): else: mfma_fn = ( mfma_i32_k32 - if is_int8 + if const_expr(is_int8) else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) ) @@ -787,7 +793,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(T.f16x4, v1) def mfma_k64(acc_in, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -811,10 +817,14 @@ def mfma_k64(acc_in, a0, a1, b0, b1): b_up_packs0, b_up_packs1 = b_up_tile_in[ku] ki64 = arith.index(ku * 64) col_base = col_offset_base_bytes + ki64 + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (sb == 0) and (ku_local == 0) and (mi == 0): a0, a1 = a0_prefetch else: - a0, a1 = lds_load_packs_k64(row_a_lds + arith.index(mi * 16), col_base, lds_base) + a0, a1 = lds_load_packs_k64( + row_a_lds + arith.index(mi * 16), col_base, lds_base + ) blk_g = mfma_k64(blk_g, a0, a1, b_gate_packs0[ni], b_gate_packs1[ni]) blk_u = mfma_k64(blk_u, a0, a1, b_up_packs0[ni], b_up_packs1[ni]) s_wg_bc = vector.broadcast(T.f32x4, s_w_gate_vals[ni]) @@ -842,7 +852,7 @@ def compute_tile( mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 mfma_fn = ( mfma_i32_k32 - if is_int8 + if const_expr(is_int8) else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) ) @@ -859,12 +869,12 @@ def compute_tile( row_up_idx = row_gate_idx + inter_idx sw_gate_pf.append( fx.Float32(1.0) - if is_f16 + if const_expr(is_f16) else buffer_ops.buffer_load(sw_rsrc, row_gate_idx, vec_width=1, dtype=T.f32) ) sw_up_pf.append( fx.Float32(1.0) - if is_f16 + if const_expr(is_f16) else buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=T.f32) ) epilogue_pf = (sw_gate_pf, sw_up_pf) @@ -874,7 +884,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(T.f16x4, v1) def mfma_k64(acc_in, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -894,6 +904,8 @@ def mfma_k64(acc_in, a0, a1, b0, b1): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -941,7 +953,7 @@ def do_one_stage(acc_gate_in, acc_up_in, k_compute, k_next, """One pipeline stage: load next tile data, compute current tile, store X to LDS.""" scale_fn = load_scales_s1 pre_scales = scale_fn(k_compute) - x_regs_next = load_x_tile(k_next) + x_regs_next = load_x_tile(k_next, x_load_bytes) b_gate_cur = load_b_tile(k_compute, n_blk_gate, n_intra_gate) b_up_cur = load_b_tile(k_compute, n_blk_up, n_intra_up) @@ -949,15 +961,15 @@ def do_one_stage(acc_gate_in, acc_up_in, k_compute, k_next, acc_gate_in, acc_up_in, b_gate_cur, b_up_cur, lds_compute, pre_scales) - store_x_tile_to_lds(x_regs_next, lds_store) + store_x_tile_to_lds(x_regs_next, lds_store, x_load_bytes) hot_loop_scheduler() gpu.barrier() return ag, au # Prologue: prefetch tile0 X into LDS, sync. k0 = fx.Index(0) - x_regs0 = load_x_tile(k0) - store_x_tile_to_lds(x_regs0, lds_base_cur) + x_regs0 = load_x_tile(k0, x_load_bytes) + store_x_tile_to_lds(x_regs0, lds_base_cur, x_load_bytes) gpu.barrier() lds_base_pong = lds_base_cur @@ -1031,7 +1043,7 @@ def do_one_stage(acc_gate_in, acc_up_in, k_compute, k_next, lane_div_16_mul4 = lane_div_16 * fx.Index(4) inter_i32_local = inter_i32_v - if _use_cshuffle_epilog: + if const_expr(use_cshuffle_epilog): if lds_out is None: raise RuntimeError("CShuffle epilogue enabled but lds_out is not allocated/aliased.") @@ -1048,7 +1060,7 @@ def write_row_to_lds( ): # Blockscale: dequant already done in compute_tile_bs_s1. # Just apply silu + optional sorted weight. - if doweight_stage1: + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) for ni in range_constexpr(num_acc_n): @@ -1126,7 +1138,7 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): idx0 = (t2 * topk_i32_v + s2) * inter_i32_local # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) _if_valid = scf.IfOp(t_valid) @@ -1505,10 +1517,9 @@ def moe_blockscale_gemm2( arg_out, max_size=False, num_records_bytes=arith.index_cast(T.i64, out_nbytes_idx) ) # fp16 path ignores scales completely (implicit scale=1.0). - if is_f16: - sx_rsrc = None - sw_rsrc = None - else: + sx_rsrc = -1 + sw_rsrc = -1 + if not is_f16: # scale_x (A2 scale): [nblk_k_w2, tokens*topk] f32 transposed -> total = nblk_k_w2 * tokens * topk sx_nbytes_idx = arith.index(nblk_k_w2) * (tokens_in * c_topk) * fx.Index(4) sx_rsrc = buffer_ops.create_buffer_resource( @@ -1559,6 +1570,7 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. + x_load_bytes = 0 if is_f16: if bytes_per_thread_x % 16 != 0: raise ValueError( @@ -1606,8 +1618,8 @@ def x_tile_chunk_coord_i32(i: int): vec2_i32 = T.vec(2, T.i32) vec4_x = T.vec(4, x_elem) - def load_x(idx_i32): - if x_load_bytes == 16: + def load_x(idx_i32, x_load_bytes_v): + if const_expr(x_load_bytes_v == 16): idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -1618,7 +1630,7 @@ def load_x(idx_i32): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes == 8: + if const_expr(x_load_bytes_v == 8): return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) @@ -1647,15 +1659,15 @@ def load_x(idx_i32): # Base row offset in dword units: row_ts_idx * (k_in/4) x_row_base_div4.append(row_ts_idx * c_k_div4) - def load_x_tile(base_k): + def load_x_tile(base_k, x_load_bytes_v): base_k_div4 = (base_k * arith.index(int(elem_bytes))) // fx.Index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - if x_load_bytes == 16: + x_vec = load_x(idx_i32, x_load_bytes_v) + if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes_v == 8): parts.append(x_vec) else: parts.append(x_vec) @@ -1748,11 +1760,11 @@ def load_b_tile(base_k): return b_tile # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, @@ -1767,7 +1779,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, @@ -1867,6 +1879,7 @@ def load_scales_s2(k_base): _sw_shared_n_s2 = (n_per_wave <= 128) s_w_vals = [] + s_w = arith.constant(1.0, type=T.f32) for ni in range_constexpr(num_acc_n): if ni == 0 or not _sw_shared_n_s2: sw_idx = _pre_n_block_s2[ni] * c_nblk_k_w2 + kb @@ -1883,7 +1896,7 @@ def compute_tile_bs_s2(acc_in, b_tile_in, lds_base, pre_scales, *, a0_prefetch=N current_acc = list(acc_in) mfma_res_ty = T.f32x4 - if _is_gfx950: + if const_expr(_is_gfx950): def _pack128(x0, x1, x2, x3): v4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3]) return vector.bitcast(T.vec(8, T.i32), v4) @@ -1898,6 +1911,8 @@ def _pack128(x0, x1, x2, x3): col1 = col_offset_base_bytes + arith.index(ku1 * 64) for mi in range_constexpr(m_repeat): curr_row = row_a_lds + arith.index(mi * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if a0_prefetch is not None and sb == 0 and mi == 0: a0, a1 = a0_prefetch else: @@ -1930,7 +1945,7 @@ def _pack128(x0, x1, x2, x3): else: mfma_fn = ( mfma_i32_k32 - if is_int8 + if const_expr(is_int8) else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) ) @@ -1939,7 +1954,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(T.f16x4, v1) def mfma_k64(acc0, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -1960,10 +1975,14 @@ def mfma_k64(acc0, a0, a1, b0, b1): b_packs0, b_packs1 = b_tile_in[ku] ki64 = arith.index(ku * 64) col_base = col_offset_base_bytes + ki64 + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (sb == 0) and (ku_local == 0) and (mi == 0): a0, a1 = a0_prefetch else: - a0, a1 = lds_load_packs_k64(row_a_lds + arith.index(mi * 16), col_base, lds_base) + a0, a1 = lds_load_packs_k64( + row_a_lds + arith.index(mi * 16), col_base, lds_base + ) blk = mfma_k64(blk, a0, a1, b_packs0[ni], b_packs1[ni]) s_w_bc = vector.broadcast(T.f32x4, s_w_vals[ni]) scale = ArithValue(s_a_vec4_list[mi]) * ArithValue(s_w_bc) @@ -1994,7 +2013,7 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False ) # Also prefetch per-row routed/topk weights (sorted_weights) when enabled. tw_pf = None - if doweight_stage2: + if const_expr(doweight_stage2): tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * fx.Index(4) ii_idx_list_pf = [fx.Index(ii) for ii in range(4)] @@ -2016,7 +2035,7 @@ def _i64_to_v4f16(x_i64): return vector.bitcast(T.f16x4, v1) def mfma_k64(acc0, a0, a1, b0, b1): - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) @@ -2035,6 +2054,8 @@ def mfma_k64(acc0, a0, a1, b0, b1): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) if (a0_prefetch is not None) and (ku == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -2105,9 +2126,9 @@ def hot_loop_scheduler(): # Prologue. k0 = fx.Index(0) - x_regs0 = load_x_tile(k0) + x_regs0 = load_x_tile(k0, x_load_bytes) b_cur = load_b_tile(k0) - store_x_tile_to_lds(x_regs0, lds_base_cur) + store_x_tile_to_lds(x_regs0, lds_base_cur, x_load_bytes) gpu.barrier() acc = [arith.constant_vector(0.0, T.f32x4)] * (num_acc_n * m_repeat) @@ -2137,12 +2158,12 @@ def hot_loop_scheduler(): # Issue scale loads FIRST so their latency hides behind heavy tile VMEM. pre_scales_pong = load_scales_s2(k_iv) next_k1 = k_iv + tile_k - x_regs_ping = load_x_tile(next_k1) + x_regs_ping = load_x_tile(next_k1, x_load_bytes) b_ping = load_b_tile(next_k1) acc = compute_tile_bs_s2(acc, b_cur, lds_base_pong, pre_scales_pong, a0_prefetch=a0_prefetch_pong) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -2152,12 +2173,12 @@ def hot_loop_scheduler(): # Issue scale loads FIRST so their latency hides behind heavy tile VMEM. pre_scales_ping = load_scales_s2(next_k1) next_k2 = k_iv + c2_tile_k - x_regs_pong = load_x_tile(next_k2) + x_regs_pong = load_x_tile(next_k2, x_load_bytes) b_next = load_b_tile(next_k2) acc = compute_tile_bs_s2(acc, b_ping, lds_base_ping, pre_scales_ping, a0_prefetch=a0_prefetch_ping) a0_prefetch_ping = None - store_x_tile_to_lds(x_regs_pong, lds_base_pong) + store_x_tile_to_lds(x_regs_pong, lds_base_pong, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -2183,12 +2204,12 @@ def hot_loop_scheduler(): k_tail1 = k_in - tile_k # Issue scale loads FIRST so their latency hides behind heavy tile VMEM. pre_scales_tail0 = load_scales_s2(k_tail0) - x_regs_ping = load_x_tile(k_tail1) + x_regs_ping = load_x_tile(k_tail1, x_load_bytes) b_ping = load_b_tile(k_tail1) acc = compute_tile_bs_s2(acc, b_cur, lds_base_pong, pre_scales_tail0, a0_prefetch=a0_prefetch_pong) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -2223,7 +2244,7 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): # Blockscale: dequant already done in compute_tile_bs_s2, no sw/sx needed here. - if out_is_f32: + if const_expr(out_is_f32): # origin/dev_a16w4: f32 output uses scalar f32 atomics and skips CShuffle/LDS. c4_i32 = fx.Int32(4) @@ -2242,7 +2263,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): t2 = fused2 & mask24_i32 s2 = fused2 >> 24 - if doweight_stage2: + if const_expr(doweight_stage2): tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) idx0 = t2 * model_i32 # i32 element index base @@ -2251,7 +2272,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): col_g = col_g_list[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw col_i32 = arith.index_cast(T.i32, col_g) idx_elem = idx0 + col_i32 @@ -2275,7 +2296,7 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): # For bf16 global atomics (gfx942 only), precompute the output base address. # gfx950+ has buffer_atomic_pk_add_bf16, so bf16 uses buffer atomics there. out_base_idx = None - if _needs_global_atomic_bf16: + if const_expr(_needs_global_atomic_bf16): out_base_idx = buffer_ops.extract_base_index(arg_out) def write_row_to_lds( @@ -2290,7 +2311,8 @@ def write_row_to_lds( lds_out, ): # Blockscale: dequant already done in compute_tile_bs_s2. - if doweight_stage2: + tw = arith.constant(1.0, type=T.f32) + if const_expr(doweight_stage2): tw = buffer_ops.buffer_load( sorted_w_rsrc, row, vec_width=1, dtype=T.f32 ) @@ -2299,7 +2321,7 @@ def write_row_to_lds( col_local = col_base_local + (ni * 16) acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw v_out = arith.trunc_f(out_elem(), v) @@ -2333,7 +2355,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): col_i32 = arith.index_cast(T.i32, col_g0) idx_elem = idx0 + col_i32 idx_elem_even = idx_elem & mask_even_i32 - if _needs_global_atomic_bf16: + if const_expr(_needs_global_atomic_bf16): # gfx942: no buffer_atomic_pk_add_bf16, use global atomicrmw fadd if bool(accumulate): byte_off = idx_elem_even * c2_i32 @@ -2566,12 +2588,12 @@ def moe_reduction_kernel( x_div = fx.logical_divide(x_tiled[None, tile_i32], fx.make_layout(VEC_WIDTH, 1)) x_thread = x_div[None, tid_i32] - if use_mask: + if const_expr(use_mask): m_idx_i32 = fx.Int32(token_idx * c_topk + fx.Index(k)) mv = buffer_ops.buffer_load(mask_rsrc, m_idx_i32, vec_width=1, dtype=i8_type()) mv_ok = mv != fx.Int8(0) - if n_sub > 1: + if const_expr(n_sub > 1): x_inner = fx.logical_divide(x_thread, fx.make_layout(copy_vec_width, 1)) for si in range_constexpr(n_sub): src = x_inner[None, fx.Int32(si)] if n_sub > 1 else x_thread @@ -2579,18 +2601,18 @@ def moe_reduction_kernel( fx.copy_atom_call(copy_atom, src, r) vec_e = fx.memref_load_vec(r) - if use_mask: + if const_expr(use_mask): zero_e = vector.broadcast(vec_type_e, arith.constant(0.0, type=elem_type())) vec_e = mv_ok.select(vec_e, zero_e) - if elem_bits < 32: + if const_expr(elem_bits < 32): vec_c = vec_e.extf(vec_type_c) else: vec_c = vec_e acc_vecs[si] = acc_vecs[si] + vec_c # ── Store results ── - if n_sub > 1: + if const_expr(n_sub > 1): y_row = Y_buf[tok_i32, None] y_tiled = fx.logical_divide(y_row, fx.make_layout(tile_cols, 1)) y_div = fx.logical_divide(y_tiled[None, tile_i32], fx.make_layout(VEC_WIDTH, 1)) @@ -2598,10 +2620,10 @@ def moe_reduction_kernel( for si in range_constexpr(n_sub): out_vec = acc_vecs[si] - if elem_bits < 32: + if const_expr(elem_bits < 32): out_vec = out_vec.truncf(vec_type_e) - if n_sub > 1: + if const_expr(n_sub > 1): dst = y_inner[None, fx.Int32(si)] else: y_row = Y_buf[tok_i32, None] @@ -2624,7 +2646,7 @@ def moe_reduction_kernel( for k in range_constexpr(topk): k_idx = fx.Index(k) x_idx_i32 = fx.Int32((token_base + k_idx) * c_model_dim + col) - if use_mask: + if const_expr(use_mask): m_idx_i32 = fx.Int32(token_base + k_idx) mv = buffer_ops.buffer_load( mask_rsrc, m_idx_i32, vec_width=1, dtype=i8_type() @@ -2637,12 +2659,12 @@ def moe_reduction_kernel( v = buffer_ops.buffer_load( x_rsrc, x_idx_i32, vec_width=1, dtype=elem_type() ) - if dtype_str in ("f16", "bf16"): + if const_expr(dtype_str in ("f16", "bf16")): v = v.extf(compute_type()) a = a + v out = a - if dtype_str in ("f16", "bf16"): + if const_expr(dtype_str in ("f16", "bf16")): out = out.truncf(elem_type()) y_idx_i32 = fx.Int32(token_idx * c_model_dim + col) buffer_ops.buffer_store(out, y_rsrc, y_idx_i32) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 06eddb24..d0a12e4e 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -22,7 +22,7 @@ from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith from flydsl.expr import gpu, buffer_ops, vector, rocdl -from flydsl.expr import range_constexpr +from flydsl.expr import range_constexpr, const_expr from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr @@ -374,7 +374,7 @@ def silu(x): by = gpu.block_id("x") # tile along inter_dim bx = gpu.block_id("y") # tile along sorted M - if _is_splitk: + if const_expr(_is_splitk): bz = gpu.block_id("z") # K-batch id k_base_idx = bz * arith.index(_k_per_batch) else: @@ -435,7 +435,7 @@ def silu(x): # OUT: normal=[tokens, topk, inter] f16/bf16, split-K=[tokens*topk, 2*inter] f32 out_elem_bytes = 4 if _is_splitk else 2 - if _is_splitk: + if const_expr(_is_splitk): out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(2 * out_elem_bytes) else: out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(out_elem_bytes) @@ -444,18 +444,17 @@ def silu(x): ) # scale_x: fp16/bf16 path ignores (implicit scale=1.0); int4_bf16 also uses 1.0. - if is_f16_or_bf16: - sx_rsrc = None - else: + x_load_bytes = 16 + sx_rsrc = -1 + if const_expr(not is_f16_or_bf16): sx_rows = tokens_in * (c_topk if x_is_token_slot else fx.Index(1)) sx_nbytes_idx = sx_rows * fx.Index(4) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_idx ) # scale_w: fp16/bf16 (non-int4) path ignores; int4_bf16 needs dequant scale. - if not needs_scale_w: - sw_rsrc = None - else: + sw_rsrc = -1 + if const_expr(needs_scale_w): sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False) @@ -477,18 +476,18 @@ def silu(x): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16/bf16 we require 16B. - if is_f16_or_bf16: - if bytes_per_thread_x % 16 != 0: + if const_expr(is_f16_or_bf16): + if const_expr(bytes_per_thread_x % 16 != 0): raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" ) x_load_bytes = 16 else: - if bytes_per_thread_x % 16 == 0: + if const_expr(bytes_per_thread_x % 16 == 0): x_load_bytes = 16 - elif bytes_per_thread_x % 8 == 0: + elif const_expr(bytes_per_thread_x % 8 == 0): x_load_bytes = 8 - elif bytes_per_thread_x % 4 == 0: + elif const_expr(bytes_per_thread_x % 4 == 0): x_load_bytes = 4 else: raise ValueError( @@ -535,7 +534,7 @@ def x_tile_chunk_coord_i32(i: int): # NOTE: aiter moe_sorting uses sentinel token_id == tokens for padding. # Do NOT rely on buffer OOB semantics for X loads; explicitly mask to a safe row. t_valid_i32 = arith.cmpi(arith.CmpIPredicate.ult, t_raw, tokens_i32) - if x_is_token_slot: + if const_expr(x_is_token_slot): s_raw = fused_i >> 24 # X is indexed by token-slot in **slot-major** order: # row_ts = slot * tokens + token @@ -552,13 +551,13 @@ def x_tile_chunk_coord_i32(i: int): vec4_x = T.vec(4, x_elem) - def load_x(idx_i32): - """Load `x_load_bytes` bytes from X (gmem) into regs. + def load_x(idx_i32, x_load_bytes_v): + """Load `x_load_bytes_v` bytes from X (gmem) into regs. For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. idx_i32 is in dword units; convert to element index for _buffer_load_vec. """ - if x_load_bytes == 16: + if const_expr(x_load_bytes_v == 16): idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -570,20 +569,22 @@ def load_x(idx_i32): elem_bytes=elem_bytes, ) # For 8B/4B, load raw i32 dwords directly. - if x_load_bytes == 8: + if const_expr(x_load_bytes_v == 8): return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) - return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) + if const_expr(x_load_bytes_v == 4): + return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) + raise ValueError(f"Invalid x_load_bytes_v: {x_load_bytes_v}") - def load_x_tile(base_k): + def load_x_tile(base_k, x_load_bytes_v): """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" base_k_div4 = (base_k * arith.index(int(elem_bytes))) // fx.Index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - if x_load_bytes == 16: + x_vec = load_x(idx_i32, x_load_bytes_v) + if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes_v == 8): parts.append(x_vec) else: parts.append(x_vec) @@ -667,7 +668,7 @@ def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): elem_type=w_elem, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes, - unpack_int4=is_int4, + unpack_int4=(is_int4 or is_int4_bf16), ) def load_b_tile(base_k, blk_list, intra_list): @@ -678,7 +679,7 @@ def load_b_tile(base_k, blk_list, intra_list): For groupwise variants, each entry also includes per-group scales: (packs0[ni], packs1[ni], scales0[ni], scales1[ni]) """ - if is_int4_bf16_groupwise: + if const_expr(is_int4_bf16_groupwise): # W4A16 groupwise: load raw packed32 + scale; defer dequant to compute_tile. raw_data = [] for ku in range_constexpr(k_unroll): @@ -701,7 +702,7 @@ def load_b_tile(base_k, blk_list, intra_list): raw_ku.append((packed32, scale_val)) raw_data.append(raw_ku) return raw_data - elif is_int4_bf16: + elif const_expr(is_int4_bf16): # W4A16 per-row: load raw packed32; defer dequant to compute_tile. raw_data = [] for ku in range_constexpr(k_unroll): @@ -738,11 +739,11 @@ def load_b_tile(base_k, blk_list, intra_list): acc_up = [acc_init] * (num_acc_n * m_repeat) # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, @@ -757,7 +758,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, @@ -815,7 +816,7 @@ def compute_tile( gate_list = list(acc_gate_in) up_list = list(acc_up_in) mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 - if _use_mfma_k32: + if const_expr(_use_mfma_k32): mfma_fn = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 else: mfma_fn = ( @@ -830,8 +831,8 @@ def compute_tile( # Optional: prefetch epilogue scales while we are about to run the last MFMA tile, # matching the preshuffle GEMM pattern of overlapping scale loads with MFMA. - epilogue_pf = None - if prefetch_epilogue and not use_groupwise_scale: + epilogue_pf = [] + if const_expr(prefetch_epilogue and not use_groupwise_scale): expert_off_pf = expert_off_idx sw_gate_pf = [] sw_up_pf = [] @@ -868,23 +869,23 @@ def _i64x2_to_v8bf16(lo, hi): return vector.bitcast(T.bf16x8, v2) def mfma_k64(acc_in, a0, a1, b0, b1): - if _use_mfma_k32: + if const_expr(_use_mfma_k32): # gfx950: single 16x16x32 MFMA consuming all 128 bits (K=32 f16/bf16) - if is_f16: + if const_expr(is_f16): av = _i64x2_to_v8f16(a0, a1) bv = _i64x2_to_v8f16(b0, b1) else: av = _i64x2_to_v8bf16(a0, a1) bv = _i64x2_to_v8bf16(b0, b1) return mfma_fn(mfma_res_ty, [av, bv, acc_in, 0, 0, 0]) - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) b1v = _i64_to_v4f16(b1) acc_mid = mfma_fn(mfma_res_ty, [a0v, b0v, acc_in, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1v, b1v, acc_mid, 0, 0, 0]) - if is_bf16: + if const_expr(is_bf16): a0v = _i64_to_v4i16(a0) a1v = _i64_to_v4i16(a1) b0v = _i64_to_v4i16(b0) @@ -901,7 +902,7 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): scale_vec = _uw(vector.broadcast(T.f32x4, scale_val)) return arith.ArithValue(_math_fma(scale_vec, _uw(f32_partial_vec), _uw(f32_acc_vec))) - if is_int4_bf16 or is_int4_bf16_groupwise: + if const_expr(is_int4_bf16 or is_int4_bf16_groupwise): # W4A16: deferred dequant — unpack int4->bf16 right before MFMA # to minimize VGPR lifetime of dequantized bf16 values. _pending_gate_up = None @@ -914,24 +915,25 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): for mi in range_constexpr(m_repeat): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val - - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0 = arith.constant(-1, type=T.i64) + a1 = arith.constant(-1, type=T.i64) + if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): a0, a1 = a0_prefetch else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni - if is_int4_bf16_groupwise: + if const_expr(is_int4_bf16_groupwise): packed_g, sc_g = b_gate_raw[ni] packed_u, sc_u = b_up_raw[ni] - if _scale_is_bf16: + if const_expr(_scale_is_bf16): sc_g = extract_bf16_scale(arith, sc_g, ku) sc_u = extract_bf16_scale(arith, sc_u, ku) else: packed_g, sc_g = b_gate_raw[ni], None packed_u, sc_u = b_up_raw[ni], None - if is_int4_bf16_groupwise and use_gfx950_cvt: + if const_expr(is_int4_bf16_groupwise and use_gfx950_cvt): # Defer group scale to post-MFMA FMA with pipeline: # Issue current MFMA, then apply FMA for previous iteration's result. bg0, bg1 = unpack_b_w4a16(packed_g, arith, vector, scale_val=None, use_gfx950_cvt=True, defer_scale16=True) @@ -939,7 +941,7 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): bu0, bu1 = unpack_b_w4a16(packed_u, arith, vector, scale_val=None, use_gfx950_cvt=True, defer_scale16=True) tmp_u = mfma_k64(zero_f32_acc, a0, a1, bu0, bu1) # Apply FMA for previous pending result (MFMA already completed). - if _pending_gate_up is not None: + if const_expr(_pending_gate_up is not None): p_idx, p_g, p_u, p_sc_g, p_sc_u = _pending_gate_up gate_list[p_idx] = _acc_scaled_f32(gate_list[p_idx], p_g, p_sc_g) up_list[p_idx] = _acc_scaled_f32(up_list[p_idx], p_u, p_sc_u) @@ -950,7 +952,7 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): bu0, bu1 = unpack_b_w4a16(packed_u, arith, vector, scale_val=sc_u, use_gfx950_cvt=use_gfx950_cvt, defer_scale16=use_gfx950_cvt) up_list[acc_idx] = mfma_k64(up_list[acc_idx], a0, a1, bu0, bu1) # Drain last pending FMA. - if _pending_gate_up is not None: + if const_expr(_pending_gate_up is not None): p_idx, p_g, p_u, p_sc_g, p_sc_u = _pending_gate_up gate_list[p_idx] = _acc_scaled_f32(gate_list[p_idx], p_g, p_sc_g) up_list[p_idx] = _acc_scaled_f32(up_list[p_idx], p_u, p_sc_u) @@ -965,7 +967,7 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): a0, a1 = a0_prefetch else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) @@ -1014,7 +1016,7 @@ def hot_loop_scheduler(): # DS-write hints near the end: match total X LDS-store micro-ops per thread. dswr_tail = num_x_loads - if dswr_tail > sche_iters: + if const_expr(dswr_tail > sche_iters): dswr_tail = sche_iters dswr_start = sche_iters - dswr_tail for sche_i in range_constexpr(sche_iters): @@ -1022,7 +1024,7 @@ def hot_loop_scheduler(): rocdl.sched_mfma(mfma_group) rocdl.sched_dsrd(1) rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: + if const_expr(sche_i >= dswr_start - 1): rocdl.sched_dswr(1) rocdl.sched_barrier(0) @@ -1031,7 +1033,7 @@ def hot_loop_scheduler(): x_regs0 = load_x_tile(k0) b_gate_cur = load_b_tile(k0, n_blk_gate, n_intra_gate) b_up_cur = load_b_tile(k0, n_blk_up, n_intra_up) - store_x_tile_to_lds(x_regs0, lds_base_cur) + store_x_tile_to_lds(x_regs0, lds_base_cur, x_load_bytes) gpu.barrier() # Loop-carried ping/pong state. @@ -1074,11 +1076,11 @@ def _flatten_b_tile(b_tile): """Flatten B tile to a 1-D list for scf.for loop-carried state.""" flat = [] for ku_entry in b_tile: - if is_int4_bf16_groupwise: + if const_expr(is_int4_bf16_groupwise): # [(packed, scale), ...] → [packed_0..N, scale_0..N] flat.extend(t[0] for t in ku_entry) flat.extend(t[1] for t in ku_entry) - elif int4_bf16_single_field: + elif const_expr(int4_bf16_single_field): # [raw_i64, ...] → [raw_0..N] flat.extend(ku_entry) else: @@ -1091,13 +1093,13 @@ def _unflatten_b_tile(vals): """Reconstruct B tile from flattened scf.for loop-carried state.""" b_tile, idx = [], 0 for _ in range_constexpr(k_unroll): - if is_int4_bf16_groupwise: + if const_expr(is_int4_bf16_groupwise): packed = list(vals[idx:idx + num_acc_n]) idx += num_acc_n scales = list(vals[idx:idx + num_acc_n]) idx += num_acc_n b_tile.append([(packed[ni], scales[ni]) for ni in range_constexpr(num_acc_n)]) - elif int4_bf16_single_field: + elif const_expr(int4_bf16_single_field): b_tile.append(list(vals[idx:idx + num_acc_n])) idx += num_acc_n else: @@ -1162,7 +1164,7 @@ def _unflatten_b_tile(vals): # After scf.for: extract final state from yielded results. SmemPtr._view_cache = None - if pair_iters > 0: + if const_expr(pair_iters > 0): acc_gate = list(loop_results[:_n_acc]) acc_up = list(loop_results[_n_acc:_p_bg]) b_gate_cur = _unflatten_b_tile(list(loop_results[_p_bg:_p_bu])) @@ -1182,7 +1184,7 @@ def _unflatten_b_tile(vals): a0_prefetch=a0_prefetch_pong, ) a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) + store_x_tile_to_lds(x_regs_ping, lds_base_ping, x_load_bytes) hot_loop_scheduler() gpu.barrier() @@ -1207,15 +1209,15 @@ def _unflatten_b_tile(vals): topk_i32_v = topk_i32 inter_i32_v = fx.Int32(inter_dim) mask24_i32 = fx.Int32(0xFFFFFF) + sw_gate_vals = [] + sw_up_vals = [] - if use_groupwise_scale: + if const_expr(use_groupwise_scale): sw_gate_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n sw_up_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n - elif epilogue_pf is not None: + elif const_expr(epilogue_pf is not None): sw_gate_vals, sw_up_vals = epilogue_pf else: - sw_gate_vals = [] - sw_up_vals = [] for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] row_gate_idx = expert_off + col_g @@ -1233,7 +1235,7 @@ def _unflatten_b_tile(vals): # When defer_scale16 was used, the x16 correction for v_cvt_off_f32_i4 # was omitted from the hot loop. Fold it into the epilogue scale. - if use_gfx950_cvt: + if const_expr(use_gfx950_cvt): _c16 = fx.Float32(16.0) sw_gate_vals = [v * _c16 for v in sw_gate_vals] sw_up_vals = [v * _c16 for v in sw_up_vals] @@ -1250,8 +1252,8 @@ def _unflatten_b_tile(vals): use_cshuffle_epilog_flag = _use_cshuffle_epilog # ─── Split-K epilogue: two-pass gate/up with f32 atomic fadd ─── - if _is_splitk: - if lds_out is None: + if const_expr(_is_splitk): + if const_expr(lds_out is None): raise RuntimeError("Split-K epilogue requires lds_out (CShuffle)") out_base_idx = buffer_ops.extract_base_index(arg_out) @@ -1283,7 +1285,7 @@ def write_row_to_lds_splitk( fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=T.i32) t2 = fused2 & mask24_i32 t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) - if x_is_token_slot: + if const_expr(x_is_token_slot): s2 = fused2 >> 24 ts2 = s2 * tokens_i32_v + t2 sx = ( @@ -1311,7 +1313,7 @@ def write_row_to_lds_splitk( v = vector.extract( _acc[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8: + if const_expr(is_int8): v = arith.sitofp(T.f32, v) v = v * sx * _sw[ni] lds_idx = row_base_lds + col_local @@ -1412,8 +1414,8 @@ def store_pair_splitk(*, row_local, row, row_ctx, col_pair0, col_g0, frag): ) return - if use_cshuffle_epilog_flag: - if lds_out is None: + if const_expr(use_cshuffle_epilog_flag): + if const_expr(lds_out is None): raise RuntimeError("CShuffle epilogue enabled but lds_out is not allocated/aliased.") def write_row_to_lds( @@ -1434,7 +1436,7 @@ def write_row_to_lds( # aiter moe_sorting uses sentinel token_id == tokens for padding. # Do NOT rely on buffer OOB semantics for scale loads; explicitly mask. t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) - if x_is_token_slot: + if const_expr(x_is_token_slot): # slot-major: slot*tokens + token ts2 = s2 * tokens_i32_v + t2 sx = ( @@ -1458,7 +1460,8 @@ def write_row_to_lds( ) # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + tw = fx.Float32(1.0) + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) for ni in range_constexpr(num_acc_n): @@ -1474,14 +1477,14 @@ def write_row_to_lds( acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8: + if const_expr(is_int8): vg = arith.sitofp(T.f32, vg) vu = arith.sitofp(T.f32, vu) vg = vg * sx * sw_gate vu = vu * sx * sw_up y = silu(vg) * vu - if doweight_stage1: + if const_expr(doweight_stage1): y = y * tw y16 = arith.trunc_f(T.f16, y) @@ -1545,7 +1548,8 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) # Do NOT rely on buffer OOB semantics for scale loads; explicitly mask. - if x_is_token_slot: + sx0 = fx.Float32(1.0) + if const_expr(x_is_token_slot): # slot-major: slot*tokens + token ts2 = s2 * tokens_i32_v + t2 sx0 = ( @@ -1574,7 +1578,8 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): idx0 = (t2 * topk_i32_v + s2) * inter_i32_local # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: + tw = fx.Float32(1.0) + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load(sorted_w_rsrc, row, vec_width=1, dtype=T.f32) _if_valid = scf.IfOp(t_valid) @@ -1592,14 +1597,14 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): acc_up[acc_idx], static_position=[ii], dynamic_position=[] ) - if is_int8: + if const_expr(is_int8): vg = arith.sitofp(T.f32, vg) vu = arith.sitofp(T.f32, vu) vg = vg * sx * sw_gate vu = vu * sx * sw_up y = silu(vg) * vu - if doweight_stage1: + if const_expr(doweight_stage1): y = y * tw y = arith.trunc_f(out_mlir(), y) idx_out0 = idx0 + col_i32 @@ -1991,7 +1996,7 @@ def moe_gemm2( # OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens. out_elem_bytes = 4 if out_is_f32 else 2 out_nbytes_idx = tokens_in * n_in * fx.Index(out_elem_bytes) - if not bool(accumulate): + if const_expr(not bool(accumulate)): out_nbytes_idx = ( tokens_in * fx.Index(topk) @@ -2002,18 +2007,16 @@ def moe_gemm2( arg_out, max_size=False, num_records_bytes=out_nbytes_idx ) # scale_x: fp16/bf16 path ignores (implicit scale=1.0); int4_bf16 also uses 1.0. - if is_f16_or_bf16: - sx_rsrc = None - else: + sx_rsrc = -1 + if const_expr(not is_f16_or_bf16): # scale_x (A2 scale): [tokens*topk] f32 -> bytes = tokens*topk*4 sx_nbytes_idx = (tokens_in * c_topk) * fx.Index(4) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_idx ) # scale_w: fp16/bf16 (non-int4) path ignores; int4_bf16 needs dequant scale. - if not needs_scale_w: - sw_rsrc = None - else: + sw_rsrc = -1 + if const_expr(needs_scale_w): # scale_w: [experts*model_dim] f32 (static shape in practice) sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) @@ -2060,18 +2063,19 @@ def _moe_gemm2_then_body(): # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16/bf16 we require 16B. - if is_f16_or_bf16: - if bytes_per_thread_x % 16 != 0: + x_load_bytes = 16 + if const_expr(is_f16_or_bf16): + if const_expr(bytes_per_thread_x % 16 != 0): raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" ) x_load_bytes = 16 else: - if bytes_per_thread_x % 16 == 0: + if const_expr(bytes_per_thread_x % 16 == 0): x_load_bytes = 16 - elif bytes_per_thread_x % 8 == 0: + elif const_expr(bytes_per_thread_x % 8 == 0): x_load_bytes = 8 - elif bytes_per_thread_x % 4 == 0: + elif const_expr(bytes_per_thread_x % 4 == 0): x_load_bytes = 4 else: raise ValueError( @@ -2105,8 +2109,8 @@ def x_tile_chunk_coord_i32(i: int): vec4_x = T.vec(4, x_elem) - def load_x(idx_i32): - if x_load_bytes == 16: + def load_x(idx_i32, x_load_bytes_v): + if const_expr(x_load_bytes_v == 16): idx_elem = idx_i32 if elem_bytes == 1 else (idx_i32 * fx.Index(2)) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -2117,7 +2121,7 @@ def load_x(idx_i32): vec_elems=vec16_elems, elem_bytes=elem_bytes, ) - if x_load_bytes == 8: + if const_expr(x_load_bytes_v == 8): return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=2, dtype=T.i32) return buffer_ops.buffer_load(x_rsrc, idx_i32, vec_width=1, dtype=T.i32) @@ -2146,15 +2150,15 @@ def load_x(idx_i32): # Base row offset in dword units: row_ts_idx * (k_in/4) x_row_base_div4.append(row_ts_idx * c_k_div4) - def load_x_tile(base_k): + def load_x_tile(base_k, x_load_bytes_v): base_k_div4 = (base_k * arith.index(int(elem_bytes))) // fx.Index(4) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - if x_load_bytes == 16: + x_vec = load_x(idx_i32, x_load_bytes_v) + if const_expr(x_load_bytes_v == 16): parts.append(vector.bitcast(T.i32x4, x_vec)) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes_v == 8): parts.append(vector.bitcast(T.vec(2, T.i32), x_vec)) else: parts.append(vector.bitcast(T.vec(1, T.i32), x_vec)) @@ -2224,7 +2228,7 @@ def load_b_pack(base_k, ki_step, ni): elem_type=w_elem, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes, - unpack_int4=is_int4, + unpack_int4=(is_int4 or is_int4_bf16), ) def load_b_tile(base_k): @@ -2235,7 +2239,7 @@ def load_b_tile(base_k): For groupwise variants, each entry also includes per-group scales: (packs0[ni], packs1[ni], scales0[ni], scales1[ni]) """ - if is_int4_bf16_groupwise: + if const_expr(is_int4_bf16_groupwise): # W4A16 groupwise: load raw packed32 + scale; defer dequant to compute_tile. raw_data = [] for ku in range_constexpr(k_unroll): @@ -2258,7 +2262,7 @@ def load_b_tile(base_k): raw_ku.append((packed32, scale_val)) raw_data.append(raw_ku) return raw_data - elif is_int4_bf16: + elif const_expr(is_int4_bf16): # W4A16 per-row: load raw packed32; defer dequant to compute_tile. raw_data = [] for ku in range_constexpr(k_unroll): @@ -2292,11 +2296,11 @@ def load_b_tile(base_k): return b_tile # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes_v == 16): lds_store_16b_xor16( arith, vector, @@ -2311,7 +2315,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes_v == 8): lds_store_8b_xor16( arith, vector, @@ -2356,10 +2360,12 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) return a0, a1 + epilogue_pf = [] + def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False, a0_prefetch=None): acc_list = list(acc_in) mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 - if _use_mfma_k32: + if const_expr(_use_mfma_k32): mfma_fn = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 else: mfma_fn = ( @@ -2373,7 +2379,7 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False ) epilogue_pf = None - if prefetch_epilogue and not use_groupwise_scale: + if const_expr(prefetch_epilogue and not use_groupwise_scale): expert_off_pf = expert_off_idx sw_pf = [] for ni in range_constexpr(num_acc_n): @@ -2385,9 +2391,8 @@ def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False else buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=T.f32) ) # Also prefetch per-row routed/topk weights (sorted_weights) when enabled. - tw_pf = None - if doweight_stage2: - tw_pf = [] + tw_pf = [] + if const_expr(doweight_stage2): lane_div_16_mul4_pf = lane_div_16 * fx.Index(4) ii_idx_list_pf = [fx.Index(ii) for ii in range(4)] for mi in range_constexpr(m_repeat): @@ -2420,23 +2425,23 @@ def _i64x2_to_v8bf16(lo, hi): return vector.bitcast(T.bf16x8, v2) def mfma_k64(acc0, a0, a1, b0, b1): - if _use_mfma_k32: + if const_expr(_use_mfma_k32): # gfx950: single 16x16x32 MFMA consuming all 128 bits (K=32 f16/bf16) - if is_f16: + if const_expr(is_f16): av = _i64x2_to_v8f16(a0, a1) bv = _i64x2_to_v8f16(b0, b1) else: av = _i64x2_to_v8bf16(a0, a1) bv = _i64x2_to_v8bf16(b0, b1) return mfma_fn(mfma_res_ty, [av, bv, acc0, 0, 0, 0]) - if is_f16: + if const_expr(is_f16): a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) b0v = _i64_to_v4f16(b0) b1v = _i64_to_v4f16(b1) acc1 = mfma_fn(mfma_res_ty, [a0v, b0v, acc0, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1v, b1v, acc1, 0, 0, 0]) - if is_bf16: + if const_expr(is_bf16): a0v = _i64_to_v4i16(a0) a1v = _i64_to_v4i16(a1) b0v = _i64_to_v4i16(b0) @@ -2453,7 +2458,7 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): scale_vec = _uw(vector.broadcast(T.f32x4, scale_val)) return arith.ArithValue(_math_fma(scale_vec, _uw(f32_partial_vec), _uw(f32_acc_vec))) - if is_int4_bf16 or is_int4_bf16_groupwise: + if const_expr(is_int4_bf16 or is_int4_bf16_groupwise): # W4A16: deferred dequant -- unpack int4->bf16 right before MFMA # to minimize VGPR lifetime of dequantized bf16 values. _pending_acc = None @@ -2466,23 +2471,23 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): a0, a1 = a0_prefetch else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) for ni in range_constexpr(num_acc_n): acc_idx = mi * num_acc_n + ni - if is_int4_bf16_groupwise: + if const_expr(is_int4_bf16_groupwise): packed, sc = b_raw[ni] - if _scale_is_bf16: + if const_expr(_scale_is_bf16): sc = extract_bf16_scale(arith, sc, ku) else: packed, sc = b_raw[ni], None - if is_int4_bf16_groupwise and use_gfx950_cvt: + if const_expr(is_int4_bf16_groupwise and use_gfx950_cvt): b0, b1 = unpack_b_w4a16(packed, arith, vector, scale_val=None, use_gfx950_cvt=True, defer_scale16=True) tmp = mfma_k64(zero_f32_acc, a0, a1, b0, b1) - if _pending_acc is not None: + if const_expr(_pending_acc is not None): p_idx, p_tmp, p_sc = _pending_acc acc_list[p_idx] = _acc_scaled_f32(acc_list[p_idx], p_tmp, p_sc) _pending_acc = (acc_idx, tmp, sc) @@ -2490,7 +2495,7 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): b0, b1 = unpack_b_w4a16(packed, arith, vector, scale_val=sc, use_gfx950_cvt=use_gfx950_cvt, defer_scale16=use_gfx950_cvt) acc_list[acc_idx] = mfma_k64(acc_list[acc_idx], a0, a1, b0, b1) # Drain last pending FMA. - if _pending_acc is not None: + if const_expr(_pending_acc is not None): p_idx, p_tmp, p_sc = _pending_acc acc_list[p_idx] = _acc_scaled_f32(acc_list[p_idx], p_tmp, p_sc) else: @@ -2503,7 +2508,7 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): mi_val = arith.index(mi * 16) curr_row_a_lds = row_a_lds + mi_val - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): a0, a1 = a0_prefetch else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) @@ -2575,25 +2580,25 @@ def hot_loop_scheduler(): rocdl.sched_dsrd(2) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) - if num_acc_n < 4: + if const_expr(num_acc_n < 4): rocdl.sched_dsrd(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_dsrd(1) rocdl.sched_mfma(1) - if tile_m == 16: + if const_expr(tile_m == 16): rocdl.sched_vmem(1) rocdl.sched_mfma(1) # DS-write hints near the end: match total A LDS-store micro-ops per thread. dswr_tail = num_x_loads - if dswr_tail > sche_iters: + if const_expr(dswr_tail > sche_iters): dswr_tail = sche_iters dswr_start = sche_iters - dswr_tail @@ -2602,15 +2607,15 @@ def hot_loop_scheduler(): rocdl.sched_mfma(mfma_group) rocdl.sched_dsrd(1) rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: + if const_expr(sche_i >= dswr_start - 1): rocdl.sched_dswr(1) rocdl.sched_barrier(0) # Prologue. k0 = fx.Index(0) - x_regs0 = load_x_tile(k0) + x_regs0 = load_x_tile(k0, x_load_bytes) b_cur = load_b_tile(k0) - store_x_tile_to_lds(x_regs0, lds_base_cur) + store_x_tile_to_lds(x_regs0, lds_base_cur, x_load_bytes) gpu.barrier() acc = [acc_init] * (num_acc_n * m_repeat) @@ -2630,7 +2635,7 @@ def hot_loop_scheduler(): odd_k_tiles = (num_k_tiles_py % 2) == 1 tail_tiles = 1 if odd_k_tiles else 2 k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) - if k_main2_py < 0: + if const_expr(k_main2_py < 0): k_main2_py = 0 c2_tile_k = arith.index(tile_k * 2) @@ -2650,10 +2655,10 @@ def _flatten_b_tile(b_tile): """Flatten B tile to a 1-D list for scf.for loop-carried state.""" flat = [] for ku_entry in b_tile: - if is_int4_bf16_groupwise: + if const_expr(is_int4_bf16_groupwise): flat.extend(t[0] for t in ku_entry) flat.extend(t[1] for t in ku_entry) - elif int4_bf16_single_field: + elif const_expr(int4_bf16_single_field): flat.extend(ku_entry) else: flat.extend(ku_entry[0]) @@ -2664,13 +2669,13 @@ def _unflatten_b_tile(vals): """Reconstruct B tile from flattened scf.for loop-carried state.""" b_tile, idx = [], 0 for _ in range_constexpr(k_unroll): - if is_int4_bf16_groupwise: + if const_expr(is_int4_bf16_groupwise): packed = list(vals[idx:idx + num_acc_n]) idx += num_acc_n scales = list(vals[idx:idx + num_acc_n]) idx += num_acc_n b_tile.append([(packed[ni], scales[ni]) for ni in range_constexpr(num_acc_n)]) - elif int4_bf16_single_field: + elif const_expr(int4_bf16_single_field): b_tile.append(list(vals[idx:idx + num_acc_n])) idx += num_acc_n else: @@ -2715,12 +2720,12 @@ def _unflatten_b_tile(vals): loop_results = yield list(_ac) + _flatten_b_tile(_bn) + list(_a0n) SmemPtr._view_cache = None - if pair_iters > 0: + if const_expr(pair_iters > 0): acc = list(loop_results[:_n_acc]) b_cur = _unflatten_b_tile(list(loop_results[_p_b:_p_a0])) a0_prefetch_pong = (loop_results[_p_a0], loop_results[_p_a0 + 1]) - if odd_k_tiles: + if const_expr(odd_k_tiles): acc, epilogue_pf = compute_tile( acc, b_cur, @@ -2730,7 +2735,7 @@ def _unflatten_b_tile(vals): ) else: k_tail1 = k_in - tile_k - x_regs_ping = load_x_tile(k_tail1) + x_regs_ping = load_x_tile(k_tail1, x_load_bytes) b_ping = load_b_tile(k_tail1) acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) @@ -2767,17 +2772,17 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): sw_pf = None tw_pf = None - if epilogue_pf is not None: + if const_expr(epilogue_pf is not None): sw_pf, tw_pf = epilogue_pf # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). - if use_groupwise_scale: + sw_vals = [] + if const_expr(use_groupwise_scale): # Groupwise: weight scale already applied per-group in K-loop. sw_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n - elif sw_pf is not None: + elif const_expr(sw_pf is not None): sw_vals = sw_pf else: - sw_vals = [] for ni in range_constexpr(num_acc_n): col_g = col_g_list[ni] row_w_idx = expert_off + col_g @@ -2789,11 +2794,11 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): # When defer_scale16 was used, the x16 correction for v_cvt_off_f32_i4 # was omitted from the hot loop. Fold it into the epilogue scale. - if use_gfx950_cvt: + if const_expr(use_gfx950_cvt): _c16 = fx.Float32(16.0) sw_vals = [v * _c16 for v in sw_vals] - if out_is_f32: + if const_expr(out_is_f32): # origin/dev_a16w4: f32 output uses scalar f32 atomics and skips CShuffle/LDS. c4_i32 = fx.Int32(4) @@ -2829,9 +2834,10 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): ) ) - if doweight_stage2: + tw = fx.Float32(1.0) + if const_expr(doweight_stage2): tw_idx = (mi * 4) + ii - if tw_pf is not None: + if const_expr(tw_pf is not None): tw = ts_ok.select(tw_pf[tw_idx], fx.Float32(0.0)) else: tw = arith.select( @@ -2847,10 +2853,10 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): sw = sw_vals[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8: + if const_expr(is_int8): v = arith.sitofp(T.f32, v) v = v * sx * sw - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw col_i32 = arith.index_cast(T.i32, col_g) idx_elem = idx0 + col_i32 @@ -2866,15 +2872,15 @@ def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): body_row=_stage2_row_atomic, ) else: - if lds_out is None: + if const_expr(lds_out is None): raise RuntimeError( "FLYDSL_MOE_STAGE2_CSHUFFLE=1 but lds_out is not allocated/aliased." ) # For bf16 global atomics (gfx942 only), precompute the output base address. # gfx950+ has buffer_atomic_pk_add_bf16, so bf16 uses buffer atomics there. - out_base_idx = None - if _needs_global_atomic_bf16: + out_base_idx = fx.Index(0) + if const_expr(_needs_global_atomic_bf16): out_base_idx = buffer_ops.extract_base_index(arg_out) def write_row_to_lds( @@ -2908,9 +2914,10 @@ def write_row_to_lds( ) ) - if doweight_stage2: + tw = fx.Float32(1.0) + if const_expr(doweight_stage2): tw_idx = (mi * 4) + ii - if tw_pf is not None: + if const_expr(tw_pf is not None): tw = tw_pf[tw_idx] else: tw = buffer_ops.buffer_load( @@ -2922,10 +2929,10 @@ def write_row_to_lds( sw = sw_vals[ni] acc_idx = mi * num_acc_n + ni v = vector.extract(acc[acc_idx], static_position=[ii], dynamic_position=[]) - if is_int8: + if const_expr(is_int8): v = arith.sitofp(T.f32, v) v = v * sx * sw - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw v_out = arith.trunc_f(out_elem(), v) @@ -2953,15 +2960,15 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): t = fused & mask24_i32 s = fused >> 24 idx0 = t * model_i32 - if not bool(accumulate): + if const_expr(not bool(accumulate)): ts = t * topk_i32_v + s idx0 = ts * model_i32 col_i32 = arith.index_cast(T.i32, col_g0) idx_elem = idx0 + col_i32 idx_elem_even = idx_elem & mask_even_i32 - if _needs_global_atomic_bf16: + if const_expr(_needs_global_atomic_bf16): # gfx942: no buffer_atomic_pk_add_bf16, use global atomicrmw fadd - if bool(accumulate): + if const_expr(bool(accumulate)): byte_off = idx_elem_even * c2_i32 byte_off_idx = arith.index_cast(T.index, byte_off) ptr_addr_idx = out_base_idx + byte_off_idx @@ -2981,7 +2988,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): else: # f16, or bf16 on gfx950+ (has buffer_atomic_pk_add_bf16) byte_off = idx_elem_even * c2_i32 - if bool(accumulate): + if const_expr(bool(accumulate)): atomic_add_f16x2(frag, byte_off) else: buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) @@ -3187,12 +3194,13 @@ def moe_reduction_kernel( x_div = fx.logical_divide(x_tiled[None, tile_i32], fx.make_layout(VEC_WIDTH, 1)) x_thread = x_div[None, tid_i32] - if use_mask: + mv_ok = True + if const_expr(use_mask): m_idx_i32 = fx.Int32(token_idx * c_topk + fx.Index(k)) mv = buffer_ops.buffer_load(mask_rsrc, m_idx_i32, vec_width=1, dtype=i8_type()) mv_ok = mv != fx.Int8(0) - if n_sub > 1: + if const_expr(n_sub > 1): x_inner = fx.logical_divide(x_thread, fx.make_layout(copy_vec_width, 1)) for si in range_constexpr(n_sub): src = x_inner[None, fx.Int32(si)] if n_sub > 1 else x_thread @@ -3200,18 +3208,15 @@ def moe_reduction_kernel( fx.copy_atom_call(copy_atom, src, r) vec_e = fx.memref_load_vec(r) - if use_mask: + if const_expr(use_mask): zero_e = vector.broadcast(vec_type_e, arith.constant(0.0, type=elem_type())) vec_e = mv_ok.select(vec_e, zero_e) - if elem_bits < 32: - vec_c = vec_e.extf(vec_type_c) - else: - vec_c = vec_e + vec_c = vec_e.extf(vec_type_c) if elem_bits < 32 else vec_e acc_vecs[si] = acc_vecs[si] + vec_c # ── Store results ── - if n_sub > 1: + if const_expr(n_sub > 1): y_row = Y_buf[tok_i32, None] y_tiled = fx.logical_divide(y_row, fx.make_layout(tile_cols, 1)) y_div = fx.logical_divide(y_tiled[None, tile_i32], fx.make_layout(VEC_WIDTH, 1)) @@ -3219,10 +3224,12 @@ def moe_reduction_kernel( for si in range_constexpr(n_sub): out_vec = acc_vecs[si] - if elem_bits < 32: + if const_expr(elem_bits < 32): out_vec = out_vec.truncf(vec_type_e) - if n_sub > 1: + # Placeholder init to avoid unbound name before branch assignment. + dst = fx.make_layout(1, 1) + if const_expr(n_sub > 1): dst = y_inner[None, fx.Int32(si)] else: y_row = Y_buf[tok_i32, None] @@ -3246,7 +3253,8 @@ def moe_reduction_kernel( for k in range_constexpr(topk): k_idx = fx.Index(k) x_idx_i32 = fx.Int32((token_base + k_idx) * c_model_dim + col) - if use_mask: + v = arith.constant(0.0, type=elem_type()) + if const_expr(use_mask): m_idx_i32 = fx.Int32(token_base + k_idx) mv = buffer_ops.buffer_load( mask_rsrc, m_idx_i32, vec_width=1, dtype=i8_type() @@ -3259,12 +3267,12 @@ def moe_reduction_kernel( v = buffer_ops.buffer_load( x_rsrc, x_idx_i32, vec_width=1, dtype=elem_type() ) - if dtype_str in ("f16", "bf16"): + if const_expr(dtype_str in ("f16", "bf16")): v = v.extf(compute_type()) a = a + v out = a - if dtype_str in ("f16", "bf16"): + if const_expr(dtype_str in ("f16", "bf16")): out = out.truncf(elem_type()) y_idx_i32 = fx.Int32(token_idx * c_model_dim + col) buffer_ops.buffer_store(out, y_rsrc, y_idx_i32) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 1e6d38ed..31e814cb 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -258,7 +258,10 @@ def _elem_type(): return T.f16 if is_bf16: return T.bf16 - if is_fp4: + def load_fp4_scale_chunk(_base_k): + raise RuntimeError("load_fp4_scale_chunk called when is_fp4=False") + + if const_expr(is_fp4): return T.i8 return T.i8 if is_int8 else T.f8 @@ -288,6 +291,9 @@ def _out_elem(): lds_tile_bytes = int(tile_m) * int(lds_stride_bytes) // a_elem_vec_pack lds_out_bytes = 2 * int(tile_m) * int(tile_n) if use_cshuffle_epilog else 0 + lds_pong_offset = 0 + lds_ping_offset = 0 + lds_alloc_offset = 0 if int(lds_stage) == 2: assert lds_out_bytes % 2 == 0, "lds_out_bytes should be multiple of 2" buffer_size_bytes = max(lds_tile_bytes, lds_out_bytes // lds_stage) @@ -327,6 +333,7 @@ def kernel_gemm( ) # ---- Layouts ---- + _k_div4_factor = (K * elem_bytes) // 4 // a_elem_vec_pack kpack_bytes = 8 if is_int4 else 16 @@ -359,31 +366,39 @@ def kernel_gemm( base_ptr_pong = allocator_pong.get_base() base_ptr_ping = allocator_ping.get_base() - if lds_stage == 2: - lds_a_pong = SmemPtr( + lds_a_pong_ptr = SmemPtr(base_ptr_pong, lds_alloc_offset, _elem_type(), shape=(1,)) + lds_a_ping_ptr = lds_a_pong_ptr + lds_out_ptr = SmemPtr(base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(1,)) + + if const_expr(lds_stage == 2): + lds_a_pong_ptr = SmemPtr( base_ptr_pong, lds_pong_offset, _elem_type(), shape=(tile_m * tile_k,) - ).get() - lds_a_ping = SmemPtr( + ) + lds_a_ping_ptr = SmemPtr( base_ptr_ping, lds_ping_offset, _elem_type(), shape=(tile_m * tile_k,) - ).get() + ) if use_cshuffle_epilog: - lds_out = SmemPtr( + lds_out_ptr = SmemPtr( base_ptr_pong, lds_pong_offset, _out_elem(), shape=(tile_m * tile_n,) - ).get() + ) else: - lds_out = None + lds_out_ptr = SmemPtr(base_ptr_pong, lds_pong_offset, _out_elem(), shape=(1,)) else: - lds_a_ptr = SmemPtr( + lds_a_pong_ptr = SmemPtr( base_ptr_pong, lds_alloc_offset, _elem_type(), shape=(lds_total_elems,) ) - lds_a_pong = lds_a_ptr.get() - lds_a_ping = lds_a_pong - lds_out = ( - SmemPtr(base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(tile_m * tile_n,)).get() - if use_cshuffle_epilog - else None - ) + lds_a_ping_ptr = lds_a_pong_ptr + if use_cshuffle_epilog: + lds_out_ptr = SmemPtr( + base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(tile_m * tile_n,) + ) + else: + lds_out_ptr = SmemPtr(base_ptr_pong, lds_alloc_offset, _out_elem(), shape=(1,)) + + lds_a_pong = lds_a_pong_ptr.get() + lds_a_ping = lds_a_ping_ptr.get() + lds_out = lds_out_ptr.get() # ---- Buffer resources (runtime byte sizes for OOB protection) ---- _a_nrec = arith.index_cast(T.i64, c_m * (K * elem_bytes // a_elem_vec_pack)) @@ -476,11 +491,11 @@ def _extract_b_packs(b16): b_i64x2 = vector.bitcast(T.i64x2, b16) b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) - if not is_f16_or_bf16: + if const_expr(not is_f16_or_bf16): return b0_i64, b1_i64 b0_v1 = vector.from_elements(T.vec(1, T.i64), [b0_i64]) b1_v1 = vector.from_elements(T.vec(1, T.i64), [b1_i64]) - if is_f16: + if const_expr(is_f16): return vector.bitcast(T.f16x4, b0_v1), vector.bitcast(T.f16x4, b1_v1) return vector.bitcast(T.i16x4, b0_v1), vector.bitcast(T.i16x4, b1_v1) @@ -494,7 +509,7 @@ def _load_b_single(k_dword_offset, ni): return _extract_b_packs(b16) def load_b_packs_k64(base_k, ku: int, ni: int): - if is_int4: + if const_expr(is_int4): ki0 = (ku * 2) + 0 ki1 = (ku * 2) + 1 return load_b_pack(base_k, ki0, ni), load_b_pack(base_k, ki1, ni) @@ -511,7 +526,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): return _extract_b_packs(b16) def load_b_tile(base_k): - if not is_int4 and not is_f16_or_bf16: + if const_expr((not is_int4) and (not is_f16_or_bf16)): base_k_bytes = base_k * elem_bytes k0_base = base_k_bytes // c64_b k_dwords = [] @@ -558,12 +573,12 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): a0_i64 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) a1_i64 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) - if not is_f16_or_bf16: + if const_expr(not is_f16_or_bf16): return a0_i64, a1_i64 a0_v1 = vector.from_elements(T.vec(1, T.i64), [a0_i64]) a1_v1 = vector.from_elements(T.vec(1, T.i64), [a1_i64]) - if is_f16: + if const_expr(is_f16): return vector.bitcast(T.f16x4, a0_v1), vector.bitcast(T.f16x4, a1_v1) return vector.bitcast(T.i16x4, a0_v1), vector.bitcast(T.i16x4, a1_v1) @@ -623,41 +638,54 @@ def a_tile_chunk_coord_i32_async(i: int): chunk_i32=a_async_load_dword, ) - def dma_a_tile_to_lds(base_k_div4, lds_buffer): + def dma_a_tile_to_lds( + base_k_div4, + lds_buffer, + *, + wave_id_v, + wave_size_v, + dma_bytes_v, + num_a_async_loads_v, + a_tile_chunk_coord_i32_async_fn, + c4_v, + k_blocks16_v, + bx_m_v, + k_bytes_factor_v, + total_threads_v, + a_rsrc_v, + ): from flydsl._mlir.dialects import memref as memref_dialect - dma_bytes = a_async_load_bytes wave_offset = rocdl.readfirstlane( T.i64, arith.index_cast( - T.i64, wave_id * arith.constant(wave_size * dma_bytes, index=True) + T.i64, wave_id_v * arith.constant(wave_size_v * dma_bytes_v, index=True) ), ) - - for i in range_constexpr(num_a_async_loads): - row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32_async(i) - col_a_local_sw = swizzle_xor16(row_a_local, col_a_local_i32 * c4, k_blocks16) - row_a_global = bx_m + row_a_local - global_byte_idx = row_a_global * k_bytes_factor + (base_k_div4 * c4 + col_a_local_sw) + lds_base = memref_dialect.extract_aligned_pointer_as_index(lds_buffer) + lds_ptr_base = buffer_ops.create_llvm_ptr(arith.index_cast(T.i64, lds_base), address_space=3) + lds_ptr = buffer_ops.get_element_ptr(lds_ptr_base, wave_offset) + + for i in range_constexpr(num_a_async_loads_v): + row_a_local, col_a_local_i32 = a_tile_chunk_coord_i32_async_fn(i) + col_a_local_sw = swizzle_xor16(row_a_local, col_a_local_i32 * c4_v, k_blocks16_v) + row_a_global = bx_m_v + row_a_local + global_byte_idx = row_a_global * k_bytes_factor_v + (base_k_div4 * c4_v + col_a_local_sw) global_offset = arith.index_cast(T.i32, global_byte_idx) - if i == 0: - lds_base = memref_dialect.extract_aligned_pointer_as_index(lds_buffer) - lds_ptr_base = buffer_ops.create_llvm_ptr(arith.index_cast(T.i64, lds_base), address_space=3) - lds_ptr = buffer_ops.get_element_ptr(lds_ptr_base, wave_offset) - else: + if const_expr(i > 0): lds_ptr = buffer_ops.get_element_ptr( lds_ptr, - static_byte_offset=total_threads * dma_bytes, + static_byte_offset=total_threads_v * dma_bytes_v, ) - size_i32 = arith.constant(dma_bytes, type=T.i32) + size_i32 = arith.constant(dma_bytes_v, type=T.i32) soffset = arith.constant(0, type=T.i32) offset_imm = arith.constant(0, type=T.i32) aux = arith.constant(1, type=T.i32) rocdl.raw_ptr_buffer_load_lds( - a_rsrc, + a_rsrc_v, lds_ptr, size_i32, global_offset, @@ -666,9 +694,23 @@ def dma_a_tile_to_lds(base_k_div4, lds_buffer): aux, ) - def prefetch_a_to_lds(base_k, lds_buffer): - base_k_div4 = base_k // 4 // a_elem_vec_pack - dma_a_tile_to_lds(base_k_div4, lds_buffer) + def prefetch_a_to_lds(base_k, lds_buffer, *, a_elem_vec_pack_v, dma_a_tile_to_lds_fn): + base_k_div4 = base_k // 4 // a_elem_vec_pack_v + dma_a_tile_to_lds_fn( + base_k_div4, + lds_buffer, + wave_id_v=wave_id, + wave_size_v=wave_size, + dma_bytes_v=a_async_load_bytes, + num_a_async_loads_v=num_a_async_loads, + a_tile_chunk_coord_i32_async_fn=a_tile_chunk_coord_i32_async, + c4_v=c4, + k_blocks16_v=k_blocks16, + bx_m_v=bx_m, + k_bytes_factor_v=k_bytes_factor, + total_threads_v=total_threads, + a_rsrc_v=a_rsrc, + ) def prefetch_a_tile(base_k): base_k_bytes = base_k * elem_bytes // a_elem_vec_pack @@ -686,7 +728,11 @@ def prefetch_ab_tile(base_k): # ── FP4 scale pre-fetch (outside compute_tile for latency hiding) ── _fp4_tilek128 = False - if is_fp4: + + def load_fp4_scale_chunk(_base_k): + raise RuntimeError("load_fp4_scale_chunk called when is_fp4=False") + + if const_expr(is_fp4): _fp4_pack_M_outer = 2 _fp4_pack_N_outer = 2 _fp4_pack_K_outer = 2 @@ -749,7 +795,7 @@ def load_fp4_scale_chunk(base_k): # ── Compute tile (MFMA) ─────────────────────────────────────────── def compute_tile(accs_in, b_tile_in, lds_buffer, *, is_last_tile=False, a0_prefetch=None, fp4_scales=None, fp4_scale_half=0): scales_pf = {} - if is_last_tile and (not is_f16_or_bf16): + if const_expr(is_last_tile and (not is_f16_or_bf16)): s_b_vals = [] for ni in range_constexpr(num_acc_n): col_g = by_n + n_tile_base + (ni * 16) + lane_mod_16 @@ -773,7 +819,7 @@ def compute_tile(accs_in, b_tile_in, lds_buffer, *, is_last_tile=False, a0_prefe str(gpu_arch).startswith("gfx95") and (not is_int8) and (not is_int4) and (not is_f16_or_bf16) ) - if use_mfma_scale_128: + if const_expr(use_mfma_scale_128): if (int(tile_k) % 128) != 0: raise ValueError( f"tile_k must be divisible by 128 for mfma_scale_x128, got tile_k={tile_k}" @@ -796,7 +842,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3]) return vector.bitcast(T.vec(8, T.i32), v4) - if is_fp4: + if const_expr(is_fp4): _fp4_a_sc, _fp4_b_sc = fp4_scales if fp4_scales else ([], []) ku128_iters = 1 if _fp4_tilek128 else _k_unroll_packed ikxdl_iters = 1 if _fp4_tilek128 else _fp4_pack_K @@ -817,6 +863,8 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for imxdl in range_constexpr(_fp4_pack_M): mi_idx = mi_p * _fp4_pack_M + imxdl curr_row_a_lds = row_a_lds + (mi_idx * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if (a0_prefetch is not None) and (k_idx == 0) and (mi_idx == 0): a0, a1 = a0_prefetch else: @@ -848,6 +896,8 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) if (a0_prefetch is not None) and (ku0 == 0) and (mi == 0): a0, a1 = a0_prefetch else: @@ -869,11 +919,16 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): return current_accs_list, scales_pf mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 - if is_int8: + + def _mfma_fn_placeholder(_res_ty, _ops): + raise RuntimeError("mfma_fn is not selected for current dtype path") + + mfma_fn = _mfma_fn_placeholder + if const_expr(is_int8): mfma_fn = mfma_i32_k32 - elif is_f16: + elif const_expr(is_f16): mfma_fn = rocdl.mfma_f32_16x16x16f16 - elif is_bf16: + elif const_expr(is_bf16): mfma_fn = rocdl.mfma_f32_16x16x16bf16_1k else: mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 @@ -891,7 +946,9 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): col_base = col_offset_base_bytes + ki64 for mi in range_constexpr(m_repeat): curr_row_a_lds = row_a_lds + (mi * 16) - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0 = arith.constant(0, type=T.i64) + a1 = arith.constant(0, type=T.i64) + if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): a0, a1 = a0_prefetch else: a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer) @@ -907,10 +964,9 @@ def mfma_k64_bytes(acc_in, a0, a1, b0, b1): vec1_out = T.vec(1, _out_elem()) def store_output(final_accs, scales): - if is_f16_or_bf16 or is_fp4: - s_b_vals = None - s_a_vecs = None - else: + s_b_vals = [] + s_a_vecs = [] + if const_expr(not (is_f16_or_bf16 or is_fp4)): s_b_vals = scales["s_b_vals"] s_a_vecs = scales["s_a_vecs"] @@ -921,7 +977,8 @@ def store_output(final_accs, scales): def write_row_to_lds(*, mi, ii, row_in_tile, row, row_base_lds, col_base_local, num_acc_n, lds_out): - if _needs_per_token_scale: + s_a = arith.constant(1.0, type=T.f32) + if const_expr(_needs_per_token_scale): s_a_vec4 = s_a_vecs[mi] s_a = vector.extract(s_a_vec4, static_position=[ii], dynamic_position=[]) for ni in range_constexpr(num_acc_n): @@ -929,11 +986,11 @@ def write_row_to_lds(*, mi, ii, row_in_tile, row, row_base_lds, acc_idx = mi * num_acc_n + ni acc = final_accs[acc_idx] val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - if is_int8: + if const_expr(is_int8): val = arith.sitofp(T.f32, val) - if is_f16_or_bf16 or is_fp4: + if const_expr(is_f16_or_bf16 or is_fp4): val_s = val - elif _needs_per_token_scale: + elif const_expr(_needs_per_token_scale): val_s = (val * s_a) * s_b_vals[ni] else: val_s = val @@ -968,7 +1025,8 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): return def body_row(*, mi, ii, row_in_tile, row): - if _needs_per_token_scale: + s_a = arith.constant(1.0, type=T.f32) + if const_expr(_needs_per_token_scale): s_a_vec4 = s_a_vecs[mi] s_a = vector.extract(s_a_vec4, static_position=[ii], dynamic_position=[]) col_base = by_n + n_tile_base + lane_mod_16 @@ -977,11 +1035,11 @@ def body_row(*, mi, ii, row_in_tile, row): acc_idx = mi * num_acc_n + ni acc = final_accs[acc_idx] val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - if is_int8: + if const_expr(is_int8): val = arith.sitofp(T.f32, val) - if is_f16_or_bf16 or is_fp4: + if const_expr(is_f16_or_bf16 or is_fp4): val_s = val - elif _needs_per_token_scale: + elif const_expr(_needs_per_token_scale): val_s = (val * s_a) * s_b_vals[ni] else: val_s = val @@ -1013,7 +1071,7 @@ def _build_scheduler(numer: int, denom: int): prev = cur return out - if _is_gfx942: + if const_expr(_is_gfx942): mfma_group = num_acc_n mfma_total = (k_unroll * 2) * m_repeat * mfma_group mfma_per_iter = 2 * mfma_group @@ -1063,11 +1121,11 @@ def _build_scheduler(numer: int, denom: int): num_a_scale_loads = num_fp4_scale_k_groups * (m_repeat // 2) num_b_scale_loads = num_fp4_scale_k_groups * (num_acc_n // 2) num_gmem_loads += num_a_scale_loads + num_b_scale_loads - # print("mfma_total, dswr_tail, dstr_advance", mfma_total, dswr_tail, dstr_advance) dsrd_preload_eff = min(int(dsrd_preload), num_ds_load) dvmem_preload_eff = min(int(dvmem_preload), num_gmem_loads) vmem_remaining = num_gmem_loads - dvmem_preload_eff dsrd_remaining = num_ds_load - dsrd_preload_eff + vmem_schedule = [] if vmem_remaining > 0 and vmem_remaining < mfma_total: vmem_schedule = (_build_scheduler(vmem_remaining, vmem_remaining) + [0] * (mfma_total - vmem_remaining)) @@ -1137,38 +1195,85 @@ def _unflatten_b_tile(flat): n_accs = num_acc_n * m_repeat n_btile = k_unroll * 2 * num_acc_n n_a0pf = 2 + n_fp4_asc = 0 + n_fp4_bsc = 0 if is_fp4: n_fp4_asc = _k_unroll_packed_outer * _m_repeat_packed_outer n_fp4_bsc = _k_unroll_packed_outer * _num_acc_n_packed_outer - def _pack_state(accs_l, bt_flat, a0pf, fp4_scales=None): + def _pack_state(accs_l, bt_flat, a0pf, fp4_scales=None, *, is_fp4_v): state = list(accs_l) + list(bt_flat) + [a0pf[0], a0pf[1]] - if is_fp4: + if is_fp4_v: a_scales, b_scales = fp4_scales state.extend(a_scales) state.extend(b_scales) return state - def _unpack_state(vals): - accs_l = list(vals[:n_accs]) - bt_flat = list(vals[n_accs:n_accs + n_btile]) - a0pf = (vals[n_accs + n_btile], vals[n_accs + n_btile + 1]) - if not is_fp4: + def _unpack_state(vals, *, n_accs_v, n_btile_v, n_a0pf_v, is_fp4_v, n_fp4_asc_v, n_fp4_bsc_v): + accs_l = list(vals[:n_accs_v]) + bt_flat = list(vals[n_accs_v:n_accs_v + n_btile_v]) + a0pf = (vals[n_accs_v + n_btile_v], vals[n_accs_v + n_btile_v + 1]) + if not is_fp4_v: return accs_l, bt_flat, a0pf, None - sc_base = n_accs + n_btile + n_a0pf - a_scales = list(vals[sc_base:sc_base + n_fp4_asc]) - b_scales = list(vals[sc_base + n_fp4_asc:sc_base + n_fp4_asc + n_fp4_bsc]) + sc_base = n_accs_v + n_btile_v + n_a0pf_v + a_scales = list(vals[sc_base:sc_base + n_fp4_asc_v]) + b_scales = list(vals[sc_base + n_fp4_asc_v:sc_base + n_fp4_asc_v + n_fp4_bsc_v]) return accs_l, bt_flat, a0pf, (a_scales, b_scales) - def _build_pingpong_body(k_iv, inner_state): - accs_in, bt_flat_in, a0pf_in, fp4_scales_pong_in = _unpack_state(inner_state) + def _build_pingpong_body( + k_iv, + inner_state, + *, + _unpack_state, + _unflatten_b_tile, + _fp4_tilek128, + tile_k, + use_async_copy, + prefetch_a_to_lds, + a_elem_vec_pack, + dma_a_tile_to_lds, + prefetch_a_tile, + prefetch_b_tile, + compute_tile, + lds_a_pong, + lds_a_ping, + store_a_tile_to_lds, + hot_loop_scheduler, + num_b_loads, + gpu, + prefetch_a0_pack, + load_fp4_scale_chunk, + is_fp4, + rocdl, + _pack_state, + _flatten_b_tile, + lds_load_packs_k64, + row_a_lds, + col_offset_base_bytes, + n_accs, + n_btile, + n_a0pf, + n_fp4_asc, + n_fp4_bsc, + ): + accs_in, bt_flat_in, a0pf_in, fp4_scales_pong_in = _unpack_state( + inner_state, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_in = _unflatten_b_tile(bt_flat_in) - if _fp4_tilek128: + if const_expr(_fp4_tilek128): next_k1 = k_iv + tile_k - if use_async_copy: - prefetch_a_to_lds(next_k1, lds_a_ping) + if const_expr(use_async_copy): + prefetch_a_to_lds( + next_k1, lds_a_ping, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_tile_ping = prefetch_a_tile(next_k1) b_tile_ping = prefetch_b_tile(next_k1) @@ -1176,18 +1281,25 @@ def _build_pingpong_body(k_iv, inner_state): accs_in, b_tile_pong_in, lds_a_pong, a0_prefetch=a0pf_in, fp4_scales=fp4_scales_pong_in, fp4_scale_half=0, ) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_tile_ping, lds_a_ping) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) + a0_prefetch_ping = prefetch_a0_pack( + lds_a_ping, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) next_k2 = k_iv + (tile_k * 2) _sc_ping = load_fp4_scale_chunk(next_k2) if is_fp4 else None rocdl.sched_barrier(0) - if use_async_copy: - prefetch_a_to_lds(next_k2, lds_a_pong) + if const_expr(use_async_copy): + prefetch_a_to_lds( + next_k2, lds_a_pong, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_tile_pong = prefetch_a_tile(next_k2) b_tile_pong_new = prefetch_b_tile(next_k2) @@ -1195,76 +1307,161 @@ def _build_pingpong_body(k_iv, inner_state): accs_in, b_tile_ping, lds_a_ping, a0_prefetch=a0_prefetch_ping, fp4_scales=fp4_scales_pong_in, fp4_scale_half=1, ) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_tile_pong, lds_a_pong) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_pong_new = prefetch_a0_pack(lds_a_pong) + a0_prefetch_pong_new = prefetch_a0_pack( + lds_a_pong, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) - return _pack_state(accs_in, _flatten_b_tile(b_tile_pong_new), - a0_prefetch_pong_new, _sc_ping) + return _pack_state( + accs_in, + _flatten_b_tile(b_tile_pong_new), + a0_prefetch_pong_new, + _sc_ping, + is_fp4_v=is_fp4, + ) next_k1 = k_iv + tile_k - if use_async_copy: - prefetch_a_to_lds(next_k1, lds_a_ping) + if const_expr(use_async_copy): + prefetch_a_to_lds( + next_k1, lds_a_ping, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_tile = prefetch_a_tile(next_k1) _sc_ping = load_fp4_scale_chunk(k_iv + fx.Index(tile_k)) if is_fp4 else None b_tile_ping = prefetch_b_tile(next_k1) accs_in, _ = compute_tile(accs_in, b_tile_pong_in, lds_a_pong, a0_prefetch=a0pf_in, fp4_scales=fp4_scales_pong_in) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_tile, lds_a_ping) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) + a0_prefetch_ping = prefetch_a0_pack( + lds_a_ping, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) next_k2 = k_iv + (tile_k * 2) - if use_async_copy: - prefetch_a_to_lds(next_k2, lds_a_pong) + if const_expr(use_async_copy): + prefetch_a_to_lds( + next_k2, lds_a_pong, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_tile = prefetch_a_tile(next_k2) _sc_pong = load_fp4_scale_chunk(k_iv + (tile_k * 2)) if is_fp4 else None b_tile_pong_new = prefetch_b_tile(next_k2) accs_in, _ = compute_tile(accs_in, b_tile_ping, lds_a_ping, a0_prefetch=a0_prefetch_ping, fp4_scales=_sc_ping) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_tile, lds_a_pong) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_pong_new = prefetch_a0_pack(lds_a_pong) + a0_prefetch_pong_new = prefetch_a0_pack( + lds_a_pong, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) - return _pack_state(accs_in, _flatten_b_tile(b_tile_pong_new), - a0_prefetch_pong_new, _sc_pong) + return _pack_state( + accs_in, + _flatten_b_tile(b_tile_pong_new), + a0_prefetch_pong_new, + _sc_pong, + is_fp4_v=is_fp4, + ) - if lds_stage == 2: - def prefetch_a0_pack(lds_buffer): - return lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_buffer) + if const_expr(lds_stage == 2): + def prefetch_a0_pack(lds_buffer, *, lds_load_packs_k64_fn, row_a_lds_v, col_offset_base_bytes_v): + return lds_load_packs_k64_fn(row_a_lds_v, col_offset_base_bytes_v, lds_buffer) k0 = fx.Index(0) b_tile0 = prefetch_b_tile(k0) - if use_async_copy: - prefetch_a_to_lds(k0, lds_a_pong) + if const_expr(use_async_copy): + prefetch_a_to_lds( + k0, lds_a_pong, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: store_a_tile_to_lds(prefetch_a_tile(k0), lds_a_pong) gpu.barrier() accs = [acc_init] * n_accs - a0_prefetch_pong = prefetch_a0_pack(lds_a_pong) + a0_prefetch_pong = prefetch_a0_pack( + lds_a_pong, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) fp4_scales0 = load_fp4_scale_chunk(fx.Index(0)) if is_fp4 else None + final_accs = 1 + scales = 1 num_tiles = K // tile_k - if _fp4_tilek128: - if (num_tiles % 2) == 1: + if const_expr(_fp4_tilek128): + if const_expr((num_tiles % 2) == 1): c_k_main = K - tile_k - init_state = _pack_state(accs, _flatten_b_tile(b_tile0), - a0_prefetch_pong, fp4_scales0) + init_state = _pack_state( + accs, + _flatten_b_tile(b_tile0), + a0_prefetch_pong, + fp4_scales0, + is_fp4_v=is_fp4, + ) results = init_state for iv, inner in range(0, c_k_main, tile_k * 2, init=init_state): - results = yield _build_pingpong_body(iv, inner) - accs, bt_flat, a0pf, fp4_scales_final = _unpack_state(results) + results = yield _build_pingpong_body( + iv, + inner, + _unpack_state=_unpack_state, + _unflatten_b_tile=_unflatten_b_tile, + _fp4_tilek128=_fp4_tilek128, + tile_k=tile_k, + use_async_copy=use_async_copy, + prefetch_a_to_lds=prefetch_a_to_lds, + a_elem_vec_pack=a_elem_vec_pack, + dma_a_tile_to_lds=dma_a_tile_to_lds, + prefetch_a_tile=prefetch_a_tile, + prefetch_b_tile=prefetch_b_tile, + compute_tile=compute_tile, + lds_a_pong=lds_a_pong, + lds_a_ping=lds_a_ping, + store_a_tile_to_lds=store_a_tile_to_lds, + hot_loop_scheduler=hot_loop_scheduler, + num_b_loads=num_b_loads, + gpu=gpu, + prefetch_a0_pack=prefetch_a0_pack, + load_fp4_scale_chunk=load_fp4_scale_chunk, + is_fp4=is_fp4, + rocdl=rocdl, + _pack_state=_pack_state, + _flatten_b_tile=_flatten_b_tile, + lds_load_packs_k64=lds_load_packs_k64, + row_a_lds=row_a_lds, + col_offset_base_bytes=col_offset_base_bytes, + n_accs=n_accs, + n_btile=n_btile, + n_a0pf=n_a0pf, + n_fp4_asc=n_fp4_asc, + n_fp4_bsc=n_fp4_bsc, + ) + accs, bt_flat, a0pf, fp4_scales_final = _unpack_state( + results, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_final = _unflatten_b_tile(bt_flat) final_accs, scales = compute_tile( accs, b_tile_pong_final, lds_a_pong, @@ -1273,42 +1470,143 @@ def prefetch_a0_pack(lds_buffer): ) else: c_k_stop = K - (tile_k * 3) - init_state = _pack_state(accs, _flatten_b_tile(b_tile0), - a0_prefetch_pong, fp4_scales0) + init_state = _pack_state( + accs, + _flatten_b_tile(b_tile0), + a0_prefetch_pong, + fp4_scales0, + is_fp4_v=is_fp4, + ) results = init_state for iv, inner in range(0, c_k_stop, tile_k * 2, init=init_state): - results = yield _build_pingpong_body(iv, inner) - accs, bt_flat, a0pf, fp4_scales_ep = _unpack_state(results) + results = yield _build_pingpong_body( + iv, + inner, + _unpack_state=_unpack_state, + _unflatten_b_tile=_unflatten_b_tile, + _fp4_tilek128=_fp4_tilek128, + tile_k=tile_k, + use_async_copy=use_async_copy, + prefetch_a_to_lds=prefetch_a_to_lds, + a_elem_vec_pack=a_elem_vec_pack, + dma_a_tile_to_lds=dma_a_tile_to_lds, + prefetch_a_tile=prefetch_a_tile, + prefetch_b_tile=prefetch_b_tile, + compute_tile=compute_tile, + lds_a_pong=lds_a_pong, + lds_a_ping=lds_a_ping, + store_a_tile_to_lds=store_a_tile_to_lds, + hot_loop_scheduler=hot_loop_scheduler, + num_b_loads=num_b_loads, + gpu=gpu, + prefetch_a0_pack=prefetch_a0_pack, + load_fp4_scale_chunk=load_fp4_scale_chunk, + is_fp4=is_fp4, + rocdl=rocdl, + _pack_state=_pack_state, + _flatten_b_tile=_flatten_b_tile, + lds_load_packs_k64=lds_load_packs_k64, + row_a_lds=row_a_lds, + col_offset_base_bytes=col_offset_base_bytes, + n_accs=n_accs, + n_btile=n_btile, + n_a0pf=n_a0pf, + n_fp4_asc=n_fp4_asc, + n_fp4_bsc=n_fp4_bsc, + ) + accs, bt_flat, a0pf, fp4_scales_ep = _unpack_state( + results, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_ep = _unflatten_b_tile(bt_flat) last_k = arith.index(K - tile_k) b_tile_ping = prefetch_b_tile(last_k) - if use_async_copy: - prefetch_a_to_lds(last_k, lds_a_ping) + if const_expr(use_async_copy): + prefetch_a_to_lds( + last_k, lds_a_ping, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_regs_ping = prefetch_a_tile(last_k) accs, _ = compute_tile( accs, b_tile_pong_ep, lds_a_pong, a0_prefetch=a0pf, fp4_scales=fp4_scales_ep, fp4_scale_half=0, ) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_regs_ping, lds_a_ping) rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) + a0_prefetch_ping = prefetch_a0_pack( + lds_a_ping, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) final_accs, scales = compute_tile( accs, b_tile_ping, lds_a_ping, is_last_tile=not is_fp4, a0_prefetch=a0_prefetch_ping, fp4_scales=fp4_scales_ep, fp4_scale_half=1, ) - elif (num_tiles % 2) == 1: + elif const_expr((num_tiles % 2) == 1): c_k_main = K - tile_k - init_state = _pack_state(accs, _flatten_b_tile(b_tile0), - a0_prefetch_pong, fp4_scales0) + init_state = _pack_state( + accs, + _flatten_b_tile(b_tile0), + a0_prefetch_pong, + fp4_scales0, + is_fp4_v=is_fp4, + ) results = init_state for iv, inner in range(0, c_k_main, tile_k * 2, init=init_state): - results = yield _build_pingpong_body(iv, inner) - accs, bt_flat, a0pf, fp4_scales_final = _unpack_state(results) + results = yield _build_pingpong_body( + iv, + inner, + _unpack_state=_unpack_state, + _unflatten_b_tile=_unflatten_b_tile, + _fp4_tilek128=_fp4_tilek128, + tile_k=tile_k, + use_async_copy=use_async_copy, + prefetch_a_to_lds=prefetch_a_to_lds, + a_elem_vec_pack=a_elem_vec_pack, + dma_a_tile_to_lds=dma_a_tile_to_lds, + prefetch_a_tile=prefetch_a_tile, + prefetch_b_tile=prefetch_b_tile, + compute_tile=compute_tile, + lds_a_pong=lds_a_pong, + lds_a_ping=lds_a_ping, + store_a_tile_to_lds=store_a_tile_to_lds, + hot_loop_scheduler=hot_loop_scheduler, + num_b_loads=num_b_loads, + gpu=gpu, + prefetch_a0_pack=prefetch_a0_pack, + load_fp4_scale_chunk=load_fp4_scale_chunk, + is_fp4=is_fp4, + rocdl=rocdl, + _pack_state=_pack_state, + _flatten_b_tile=_flatten_b_tile, + lds_load_packs_k64=lds_load_packs_k64, + row_a_lds=row_a_lds, + col_offset_base_bytes=col_offset_base_bytes, + n_accs=n_accs, + n_btile=n_btile, + n_a0pf=n_a0pf, + n_fp4_asc=n_fp4_asc, + n_fp4_bsc=n_fp4_bsc, + ) + accs, bt_flat, a0pf, fp4_scales_final = _unpack_state( + results, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_final = _unflatten_b_tile(bt_flat) final_accs, scales = compute_tile( accs, b_tile_pong_final, lds_a_pong, @@ -1316,29 +1614,83 @@ def prefetch_a0_pack(lds_buffer): ) else: c_k_stop = K - (tile_k * 3) - init_state = _pack_state(accs, _flatten_b_tile(b_tile0), - a0_prefetch_pong, fp4_scales0) + init_state = _pack_state( + accs, + _flatten_b_tile(b_tile0), + a0_prefetch_pong, + fp4_scales0, + is_fp4_v=is_fp4, + ) results = init_state for iv, inner in range(0, c_k_stop, tile_k * 2, init=init_state): - results = yield _build_pingpong_body(iv, inner) - accs, bt_flat, a0pf, fp4_scales_ep = _unpack_state(results) + results = yield _build_pingpong_body( + iv, + inner, + _unpack_state=_unpack_state, + _unflatten_b_tile=_unflatten_b_tile, + _fp4_tilek128=_fp4_tilek128, + tile_k=tile_k, + use_async_copy=use_async_copy, + prefetch_a_to_lds=prefetch_a_to_lds, + a_elem_vec_pack=a_elem_vec_pack, + dma_a_tile_to_lds=dma_a_tile_to_lds, + prefetch_a_tile=prefetch_a_tile, + prefetch_b_tile=prefetch_b_tile, + compute_tile=compute_tile, + lds_a_pong=lds_a_pong, + lds_a_ping=lds_a_ping, + store_a_tile_to_lds=store_a_tile_to_lds, + hot_loop_scheduler=hot_loop_scheduler, + num_b_loads=num_b_loads, + gpu=gpu, + prefetch_a0_pack=prefetch_a0_pack, + load_fp4_scale_chunk=load_fp4_scale_chunk, + is_fp4=is_fp4, + rocdl=rocdl, + _pack_state=_pack_state, + _flatten_b_tile=_flatten_b_tile, + lds_load_packs_k64=lds_load_packs_k64, + row_a_lds=row_a_lds, + col_offset_base_bytes=col_offset_base_bytes, + n_accs=n_accs, + n_btile=n_btile, + n_a0pf=n_a0pf, + n_fp4_asc=n_fp4_asc, + n_fp4_bsc=n_fp4_bsc, + ) + accs, bt_flat, a0pf, fp4_scales_ep = _unpack_state( + results, + n_accs_v=n_accs, + n_btile_v=n_btile, + n_a0pf_v=n_a0pf, + is_fp4_v=is_fp4, + n_fp4_asc_v=n_fp4_asc, + n_fp4_bsc_v=n_fp4_bsc, + ) b_tile_pong_ep = _unflatten_b_tile(bt_flat) last_k = arith.index(K - tile_k) b_tile_ping = prefetch_b_tile(last_k) - if use_async_copy: - prefetch_a_to_lds(last_k, lds_a_ping) + if const_expr(use_async_copy): + prefetch_a_to_lds( + last_k, lds_a_ping, a_elem_vec_pack_v=a_elem_vec_pack, dma_a_tile_to_lds_fn=dma_a_tile_to_lds + ) else: a_regs_ping = prefetch_a_tile(last_k) _sc_last = load_fp4_scale_chunk(last_k) if is_fp4 else None accs, _ = compute_tile(accs, b_tile_pong_ep, lds_a_pong, a0_prefetch=a0pf, fp4_scales=fp4_scales_ep) - if not use_async_copy: + if const_expr(not use_async_copy): store_a_tile_to_lds(a_regs_ping, lds_a_ping) hot_loop_scheduler() rocdl.s_waitcnt(num_b_loads) gpu.barrier() - a0_prefetch_ping = prefetch_a0_pack(lds_a_ping) + a0_prefetch_ping = prefetch_a0_pack( + lds_a_ping, + lds_load_packs_k64_fn=lds_load_packs_k64, + row_a_lds_v=row_a_lds, + col_offset_base_bytes_v=col_offset_base_bytes, + ) final_accs, scales = compute_tile( accs, b_tile_ping, lds_a_ping, is_last_tile=not is_fp4, a0_prefetch=a0_prefetch_ping, fp4_scales=_sc_last, @@ -1453,4 +1805,4 @@ def compile_preshuffle_gemm_w4( return inner -__all__ = ["compile_preshuffle_gemm_a8", "compile_preshuffle_gemm_w4"] +__all__ = ["compile_preshuffle_gemm_a8", "compile_preshuffle_gemm_w4"] \ No newline at end of file diff --git a/kernels/rdna_fp8_preshuffle_gemm.py b/kernels/rdna_fp8_preshuffle_gemm.py index 09fc56bd..d29ff9cf 100644 --- a/kernels/rdna_fp8_preshuffle_gemm.py +++ b/kernels/rdna_fp8_preshuffle_gemm.py @@ -347,7 +347,7 @@ def _unflatten_b(flat): init_state = _flatten_tile(a_cur) + list(accs) + _flatten_tile(b_cur) # Main K-loop: SCF outer with constexpr inner unroll - if full_outer_iters > 0: + if const_expr(full_outer_iters > 0): for iv, state in range(0, full_outer_iters * k_unroll, k_unroll, init=init_state): s_a = _unflatten_a(list(state[:n_a])) s_accs = list(state[n_a : n_a + n_acc]) @@ -369,7 +369,7 @@ def _unflatten_b(flat): b_cur = _unflatten_b(list(results[n_a + n_acc :])) # Handle remainder tiles - if remainder > 0: + if const_expr(remainder > 0): for j in range_constexpr(remainder): next_kt = fx.Index(full_outer_iters * k_unroll + j + 1) a_next = _load_a_tile(next_kt) diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index 09f0c88e..09bd36f2 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -101,16 +101,16 @@ def block_reduce_add2(val0, val1): if lane == fx.Int32(0): wave_idx = ArithValue(wave).index_cast(T.index) - s_red.store(w0, [wave_idx]) - s_red2.store(w1, [wave_idx]) + SmemPtr.store(s_red, w0, [wave_idx]) + SmemPtr.store(s_red2, w1, [wave_idx]) gpu.barrier() if wave == fx.Int32(0): in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, fx.Int32(0)) lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v0 = s_red.load([lane_safe_idx]) - v1 = s_red2.load([lane_safe_idx]) + v0 = SmemPtr.load(s_red, [lane_safe_idx]) + v1 = SmemPtr.load(s_red2, [lane_safe_idx]) z = fx.Float32(0.0) ww0 = in_range.select(v0, z) ww1 = in_range.select(v1, z) @@ -119,12 +119,12 @@ def block_reduce_add2(val0, val1): if lane == fx.Int32(0): c0_idx = fx.Index(0) - s_red.store(ww0, [c0_idx]) - s_red2.store(ww1, [c0_idx]) + SmemPtr.store(s_red, ww0, [c0_idx]) + SmemPtr.store(s_red2, ww1, [c0_idx]) gpu.barrier() c0_idx = fx.Index(0) - return s_red.load([c0_idx]), s_red2.load([c0_idx]) + return SmemPtr.load(s_red, [c0_idx]), SmemPtr.load(s_red2, [c0_idx]) # ================================================================== # Fast path: N is a multiple of tile_cols @@ -191,6 +191,7 @@ def _store_vec(val, div_tensor, idx): y = (x * rrms) * g + out_e = y.to(elem_dtype) if dtype_str == "bf16": if USE_HW_CVT_PK_BF16_F32: out_e = y.to(elem_dtype) diff --git a/kernels/softmax_kernel.py b/kernels/softmax_kernel.py index 8f5d4d32..893daa49 100644 --- a/kernels/softmax_kernel.py +++ b/kernels/softmax_kernel.py @@ -88,7 +88,7 @@ def wave_reduce(x, mode): w = w.addf(peer, fastmath=fm_fast) return w - def block_reduce(val, mode): + def block_reduce(val, mode, s_red_buffer): if RED_SLOTS == 1: return wave_reduce(val, mode) @@ -100,25 +100,25 @@ def block_reduce(val, mode): if lane == fx.Int32(0): wave_idx = ArithValue(wave).index_cast(T.index) - s_red.store(w, [wave_idx]) + SmemPtr.store(s_red_buffer, w, [wave_idx]) gpu.barrier() if wave == fx.Int32(0): in_range = lane < RED_SLOTS lane_safe = in_range.select(lane, fx.Int32(0)) lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) - v = s_red.load([lane_safe_idx]) + v = SmemPtr.load(s_red_buffer, [lane_safe_idx]) z = neutral ww = in_range.select(v, z) ww = wave_reduce(ww, mode) if lane == fx.Int32(0): c0_idx = fx.Index(0) - s_red.store(ww, [c0_idx]) + SmemPtr.store(s_red_buffer, ww, [c0_idx]) gpu.barrier() c0_idx = fx.Index(0) - return s_red.load([c0_idx]) + return SmemPtr.load(s_red_buffer, [c0_idx]) # ================================================================== # Fast path: N is a multiple of tile_cols @@ -167,7 +167,7 @@ def _store_vec(val, div_tensor, idx): red_max = x.reduce(ReductionOp.MAX) thread_max = thread_max.maximumf(red_max) - global_max = block_reduce(thread_max, "max") + global_max = block_reduce(thread_max, "max", s_red) # 2. Exp + local sum thread_sum = c_zero_f @@ -180,7 +180,7 @@ def _store_vec(val, div_tensor, idx): red_sum = exp_val.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sum = thread_sum + red_sum - global_sum = block_reduce(thread_sum, "sum") + global_sum = block_reduce(thread_sum, "sum", s_red) # 3. Normalize + store c_one = arith.constant(1.0, type=compute_type) @@ -243,7 +243,7 @@ def _store_scalar(divided, index, val): row_buffer.append((safe_val, is_valid)) thread_max = thread_max.maximumf(safe_val) - global_max = block_reduce(thread_max, "max") + global_max = block_reduce(thread_max, "max", s_red) # 2. Exp + sum thread_sum = c_zero_f @@ -256,7 +256,7 @@ def _store_scalar(divided, index, val): thread_sum = thread_sum + safe_exp new_buffer.append((exp_val, is_valid)) - global_sum = block_reduce(thread_sum, "sum") + global_sum = block_reduce(thread_sum, "sum", s_red) c_one = arith.constant(1.0, type=compute_type) inv_sum = c_one / ArithValue(global_sum) @@ -268,6 +268,7 @@ def _store_scalar(divided, index, val): buf_idx += 1 if arith.cmpi(arith.CmpIPredicate.ult, idx, Int32(N)): norm_val = ArithValue(exp_val) * inv_sum + out_e = norm_val if dtype_str == "f32": out_e = norm_val else: diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 5bf62901..e43c302b 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -2,6 +2,7 @@ # Copyright (c) 2025 FlyDSL Project Contributors import ast +import contextlib import difflib import inspect import types @@ -11,6 +12,7 @@ from .._mlir import ir from .._mlir.dialects import arith, scf from ..expr import const_expr +from ..expr.numeric import _unwrap_value, _wrap_like from ..utils import env, log @@ -71,7 +73,7 @@ def transform(cls, f): orig_code = ast.unparse(module) if env.debug.ast_diff else None func_node = module.body[0] rewriter = transformer_ctor(context=context, first_lineno=f.__code__.co_firstlineno - 1) - func_node = rewriter.generic_visit(func_node) + func_node = rewriter.visit(func_node) if env.debug.ast_diff: new_code = ast.unparse(func_node) diff = list( @@ -125,18 +127,136 @@ def transform(cls, f): _ASTREWRITE_MARKER = "_flydsl_ast_rewriter_generated_" +class SymbolScopeTracker: + def __init__(self): + self.scopes = [] + self.callables = [] + + def record_symbol(self, name: str): + if not self.scopes: + return + if name == "_": + return + self.scopes[-1].add(name) + + def record_callable(self, name: str): + if not self.callables: + return + self.callables[-1].add(name) + + def snapshot_symbol_scopes(self): + return self.scopes.copy() + + def snapshot_callable_scopes(self): + return self.callables.copy() + + @contextlib.contextmanager + def function_scope(self): + self.scopes.append(set()) + self.callables.append(set()) + try: + yield + finally: + self.scopes.pop() + self.callables.pop() + + @contextlib.contextmanager + def control_flow_scope(self): + self.scopes.append(set()) + try: + yield + finally: + self.scopes.pop() + + class Transformer(ast.NodeTransformer): def __init__(self, context, first_lineno): super().__init__() self.context = context self.first_lineno = first_lineno + self.symbol_scopes = SymbolScopeTracker() + + def _record_target_symbols(self, target): + if isinstance(target, ast.Name): + self.symbol_scopes.record_symbol(target.id) + elif isinstance(target, (ast.Tuple, ast.List)): + for t in target.elts: + self._record_target_symbols(t) + elif isinstance(target, ast.Starred): + self._record_target_symbols(target.value) + + def _visit_stmt_block(self, stmts): + new_stmts = [] + for stmt in stmts: + transformed = self.visit(stmt) + if isinstance(transformed, list): + new_stmts.extend(transformed) + else: + new_stmts.append(transformed) + return new_stmts def visit_FunctionDef(self, node: ast.FunctionDef): if getattr(node, _ASTREWRITE_MARKER, False): return node - node = self.generic_visit(node) + + with self.symbol_scopes.function_scope(): + for arg in node.args.posonlyargs: + self.symbol_scopes.record_symbol(arg.arg) + for arg in node.args.args: + self.symbol_scopes.record_symbol(arg.arg) + for arg in node.args.kwonlyargs: + self.symbol_scopes.record_symbol(arg.arg) + node = self.generic_visit(node) + return node + def visit_Assign(self, node: ast.Assign): + for target in node.targets: + self._record_target_symbols(target) + return self.generic_visit(node) + + def visit_AugAssign(self, node: ast.AugAssign): + self._record_target_symbols(node.target) + return self.generic_visit(node) + + def visit_AnnAssign(self, node: ast.AnnAssign): + self._record_target_symbols(node.target) + return self.generic_visit(node) + + def visit_For(self, node: ast.For): + self._record_target_symbols(node.target) + node.iter = self.visit(node.iter) + with self.symbol_scopes.control_flow_scope(): + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) + return node + + def visit_If(self, node: ast.If): + node.test = self.visit(node.test) + with self.symbol_scopes.control_flow_scope(): + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) + return node + + def visit_While(self, node: ast.While): + node.test = self.visit(node.test) + with self.symbol_scopes.control_flow_scope(): + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) + return node + + def visit_With(self, node: ast.With): + for item in node.items: + if item.optional_vars is not None: + self._record_target_symbols(item.optional_vars) + return self.generic_visit(node) + @ASTRewriter.register class RewriteBoolOps(Transformer): @@ -232,110 +352,547 @@ def _is_dynamic(cond): @staticmethod def _to_i1(cond): - if hasattr(cond, "ir_value"): - return cond.ir_value() - return cond + return _unwrap_value(cond) + + @staticmethod + def _normalize_named_values(names, values, names_label="names", values_label="values"): + names = tuple(names or ()) + values = tuple(values or ()) + if len(names) != len(values): + raise ValueError( + f"{names_label} and {values_label} must have the same length, " + f"got {len(names)} and {len(values)}" + ) + return names, values @staticmethod - def scf_if_dispatch(cond, then_fn, else_fn=None): + def _normalize_branch_result(branch_result, state_names, state_map, branch_label): + if not state_names: + return [] + + if isinstance(branch_result, dict): + result_map = dict(branch_result) + elif branch_result is None: + result_map = {} + elif len(state_names) == 1 and not isinstance(branch_result, (list, tuple)): + result_map = {state_names[0]: branch_result} + elif isinstance(branch_result, (list, tuple)) and len(branch_result) == len(state_names): + result_map = dict(zip(state_names, branch_result)) + else: + raise TypeError( + f"{branch_label} must return dict/tuple/list for stateful dispatch; got {type(branch_result).__name__}" + ) + + values = [] + for name in state_names: + if name in result_map: + values.append(result_map[name]) + elif name in state_map: + values.append(state_map[name]) + else: + raise NameError( + f"variable '{name}' is not available before if/else and is not assigned in {branch_label}" + ) + return values + + @staticmethod + def _unwrap_mlir_values(values, state_names, branch_label): + raw_values = [] + for name, value in zip(state_names, values): + raw = _unwrap_value(value) + if not isinstance(raw, ir.Value): + raise TypeError( + f"if/else variable '{name}' in {branch_label} is {type(raw).__name__}, " + "not an MLIR Value. Only MLIR Values can be yielded from dynamic if/else branches." + ) + raw_values.append(raw) + return raw_values + + @staticmethod + def _pack_dispatch_results(results, state_values): + if not results: + return None + wrapped = [_wrap_like(v, exemplar) for v, exemplar in zip(results, state_values)] + if len(wrapped) == 1: + return wrapped[0] + return tuple(wrapped) + + @staticmethod + def _collect_result_dict(result_names, local_vars): + return {name: local_vars[name] for name in result_names} + + @staticmethod + def _pack_named_values(names, values): + if not names: + return None + if len(names) == 1: + return values[0] + return tuple(values) + + @staticmethod + def _merge_partial_results(base_names, base_values, part_names, part_values): + merged = {name: value for name, value in zip(base_names, base_values)} + merged.update({name: value for name, value in zip(part_names, part_values)}) + return [merged[name] for name in base_names] + + @staticmethod + def _call_branch(fn, result_names, state_values): + sig = inspect.signature(fn) + params = list(sig.parameters.values()) + pos_params = [ + p + for p in params + if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + has_varargs = any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in params) + if has_varargs or len(pos_params) >= len(state_values) + 1: + return fn(result_names, *state_values) + return fn(*state_values) + + @staticmethod + def scf_if_dispatch( + cond, + then_fn, + else_fn=None, + *, + result_names=(), + result_values=(), + state_names=(), + state_values=(), + auto_else=False, + ): + # Backward compatibility: old call-sites pass state_* only. + if not result_names and state_names: + result_names = state_names + if not result_values and state_values: + result_values = state_values + result_names, result_values = ReplaceIfWithDispatch._normalize_named_values( + result_names, result_values, "result_names", "result_values" + ) + # Only variables with an incoming value can be scf.if results/yields. + effective_result_pairs = [ + (name, value) + for name, value in zip(result_names, result_values) + if _unwrap_value(value) is not None + ] + effective_result_names = tuple(name for name, _ in effective_result_pairs) + effective_result_values = tuple(value for _, value in effective_result_pairs) + effective_result_map = {name: value for name, value in effective_result_pairs} + if not ReplaceIfWithDispatch._is_dynamic(cond): - # compile-time evaluation - if cond: - then_fn() - elif else_fn is not None: - else_fn() - return + taken = then_fn if cond else else_fn + if taken is None: + return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) + result = ReplaceIfWithDispatch._call_branch(taken, effective_result_names, result_values) + if not effective_result_names: + return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) + partial_values = ReplaceIfWithDispatch._normalize_branch_result( + result, effective_result_names, effective_result_map, "selected branch" + ) + merged_values = ReplaceIfWithDispatch._merge_partial_results( + result_names, result_values, effective_result_names, partial_values + ) + return ReplaceIfWithDispatch._pack_named_values(result_names, merged_values) - has_else = else_fn is not None - loc = ir.Location.unknown() - if_op = scf.IfOp(ReplaceIfWithDispatch._to_i1(cond), [], has_else=has_else, loc=loc) - with ir.InsertionPoint(if_op.regions[0].blocks[0]): - then_fn() - scf.YieldOp([]) - if has_else: - if len(if_op.regions[1].blocks) == 0: - if_op.regions[1].blocks.append(*[]) - with ir.InsertionPoint(if_op.regions[1].blocks[0]): - else_fn() + cond_i1 = ReplaceIfWithDispatch._to_i1(cond) + if not isinstance(cond_i1, ir.Value): + raise TypeError(f"dynamic if condition must lower to ir.Value, got {type(cond_i1).__name__}") + + if not effective_result_names: + has_else = else_fn is not None + if_op = scf.IfOp(cond_i1, [], has_else=has_else, loc=ir.Location.unknown()) + with ir.InsertionPoint(if_op.regions[0].blocks[0]): + ReplaceIfWithDispatch._call_branch(then_fn, effective_result_names, result_values) scf.YieldOp([]) + if has_else: + if len(if_op.regions[1].blocks) == 0: + if_op.regions[1].blocks.append(*[]) + with ir.InsertionPoint(if_op.regions[1].blocks[0]): + ReplaceIfWithDispatch._call_branch(else_fn, effective_result_names, result_values) + scf.YieldOp([]) + return ReplaceIfWithDispatch._pack_named_values(result_names, result_values) + + if else_fn is None: + else_fn = lambda *args: {} + + state_raw = [] + for name, value in zip(effective_result_names, effective_result_values): + raw = _unwrap_value(value) + if not isinstance(raw, ir.Value): + raise TypeError( + f"state variable '{name}' is {type(raw).__name__}, not an MLIR Value; " + "stateful dynamic if requires MLIR-backed values." + ) + state_raw.append(raw) + + result_types = [v.type for v in state_raw] + if_op = scf.IfOp(cond_i1, result_types, has_else=True, loc=ir.Location.unknown()) + + with ir.InsertionPoint(if_op.regions[0].blocks[0]): + then_result = ReplaceIfWithDispatch._call_branch(then_fn, effective_result_names, result_values) + then_values = ReplaceIfWithDispatch._normalize_branch_result( + then_result, effective_result_names, effective_result_map, "then-branch" + ) + then_raw = ReplaceIfWithDispatch._unwrap_mlir_values(then_values, effective_result_names, "then-branch") + for name, expect_ty, got in zip(effective_result_names, result_types, then_raw): + if got.type != expect_ty: + raise TypeError( + f"if/else variable '{name}' type mismatch in then-branch: " + f"expected {expect_ty}, got {got.type}" + ) + scf.YieldOp(then_raw) + + if len(if_op.regions[1].blocks) == 0: + if_op.regions[1].blocks.append(*[]) + with ir.InsertionPoint(if_op.regions[1].blocks[0]): + else_result = ReplaceIfWithDispatch._call_branch(else_fn, effective_result_names, result_values) + else_values = ReplaceIfWithDispatch._normalize_branch_result( + else_result, effective_result_names, effective_result_map, "else-branch" + ) + else_raw = ReplaceIfWithDispatch._unwrap_mlir_values(else_values, effective_result_names, "else-branch") + for name, expect_ty, got in zip(effective_result_names, result_types, else_raw): + if got.type != expect_ty: + raise TypeError( + f"if/else variable '{name}' type mismatch in else-branch: " + f"expected {expect_ty}, got {got.type}" + ) + scf.YieldOp(else_raw) + + partial_wrapped = ReplaceIfWithDispatch._pack_dispatch_results( + list(if_op.results), effective_result_values + ) + if len(effective_result_names) == 1: + partial_values = [partial_wrapped] + else: + partial_values = list(partial_wrapped) + merged_values = ReplaceIfWithDispatch._merge_partial_results( + result_names, result_values, effective_result_names, partial_values + ) + return ReplaceIfWithDispatch._pack_named_values(result_names, merged_values) @classmethod def rewrite_globals(cls): return { "const_expr": const_expr, "scf_if_dispatch": cls.scf_if_dispatch, + "scf_if_collect_results": cls._collect_result_dict, } - _REWRITE_HELPER_NAMES = {"dsl_not_", "dsl_and_", "dsl_or_", - "scf_if_dispatch", "const_expr", "type", - "bool", "isinstance", "hasattr"} + _REWRITE_HELPER_NAMES = { + "const_expr", + "type", + "bool", + "isinstance", + "hasattr", + } @staticmethod def _could_be_dynamic(test_node): """Check if an if-condition AST could produce an MLIR Value at runtime. - Calls to RewriteBoolOps helpers (dsl_not_, dsl_and_, dsl_or_) and - Python builtins are NOT considered dynamic — they just wrap Python-level - boolean logic. Only calls to user/MLIR functions can produce Values. + Layer-by-layer recursive check: + 1) classify current node if possible, + 2) otherwise recurse into direct children, + 3) unresolved nodes default to static (no forced rewrite). """ - for child in ast.walk(test_node): - if isinstance(child, ast.Call): - func = child.func - if isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES: - continue + def _is_literal_expr(node): + if isinstance(node, ast.Constant): return True - return False + if isinstance(node, (ast.Tuple, ast.List, ast.Set)): + return all(_is_literal_expr(e) for e in node.elts) + if isinstance(node, ast.Dict): + return all( + (k is None or _is_literal_expr(k)) and _is_literal_expr(v) + for k, v in zip(node.keys, node.values) + ) + return False + + def _try_static_value(node): + if not _is_literal_expr(node): + return False, None + if isinstance(node, ast.Constant): + return True, node.value + try: + return True, ast.literal_eval(node) + except Exception as e: + log().error( + "[FlyDSL ast_rewriter] literal_eval failed: " + f"node={ast.dump(node, include_attributes=False)}, err={e!r}" + ) + return False, None + + def _eval_static_compare_pair(lhs, op, rhs): + op_text_map = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", + ast.Is: "is", + ast.IsNot: "is not", + ast.In: "in", + ast.NotIn: "not in", + } + try: + op_text = op_text_map.get(type(op)) + if op_text is None: + return None + return eval( + f"lhs_val {op_text} rhs_val", + {"__builtins__": {}}, + {"lhs_val": lhs, "rhs_val": rhs}, + ) + except Exception as e: + log().error( + "[FlyDSL ast_rewriter] static compare eval failed: " + f"op={type(op).__name__}, lhs={lhs!r}, rhs={rhs!r}, err={e!r}" + ) + return None + + def _visit(node): + if _is_literal_expr(node): + return False + if isinstance(node, ast.Compare): + compare_parts = [node.left, *node.comparators] + for i, op in enumerate(node.ops): + lhs_node = compare_parts[i] + rhs_node = compare_parts[i + 1] + lhs_ok, lhs_val = _try_static_value(lhs_node) + rhs_ok, rhs_val = _try_static_value(rhs_node) + if lhs_ok and rhs_ok: + pair_result = _eval_static_compare_pair(lhs_val, op, rhs_val) + if pair_result is False: + return False + return any(_visit(part) for part in compare_parts) + if isinstance(node, ast.Call): + func = node.func + if not (isinstance(func, ast.Name) and func.id in ReplaceIfWithDispatch._REWRITE_HELPER_NAMES): + return True + if isinstance(node, ast.Name): + return True + + for child in ast.iter_child_nodes(node): + if _visit(child): + return True + + # If this expression cannot be proven dynamic from itself or children, + # keep it static to avoid over-rewriting unrelated Python control flow. + return False + + return _visit(test_node) + + @staticmethod + def _collect_assigned_vars(node: ast.If, active_symbols): + write_args = [] + invoked_args = [] + + def add_unique(items, name): + if isinstance(name, str) and name not in items: + items.append(name) + + def in_active_symbols(name): + return any(name in symbol_scope for symbol_scope in active_symbols) + + class RegionAnalyzer(ast.NodeVisitor): + force_store = False + + @staticmethod + def _get_call_base(func_node): + if isinstance(func_node, ast.Attribute): + if isinstance(func_node.value, ast.Attribute): + return RegionAnalyzer._get_call_base(func_node.value) + if isinstance(func_node.value, ast.Name): + return func_node.value.id + return None + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Store) or self.force_store: + add_unique(write_args, node.id) + + def visit_Subscript(self, node): + if isinstance(node.ctx, ast.Store): + self.force_store = True + self.visit(node.value) + self.force_store = False + self.visit(node.slice) + else: + self.generic_visit(node) + + def visit_Assign(self, node): + self.force_store = True + for target in node.targets: + self.visit(target) + self.force_store = False + self.visit(node.value) + + def visit_AugAssign(self, node): + self.force_store = True + self.visit(node.target) + self.force_store = False + self.visit(node.value) + + def visit_Call(self, node): + base_name = RegionAnalyzer._get_call_base(node.func) + if base_name is not None and base_name != "self": + add_unique(invoked_args, base_name) + + self.generic_visit(node) + + analyzer = RegionAnalyzer() + analyzer.visit(ast.Module(body=node.body, type_ignores=[])) + if node.orelse: + analyzer.visit(ast.Module(body=node.orelse, type_ignores=[])) + + invoked_args = [name for name in invoked_args if name not in write_args] + write_args = [name for name in write_args if in_active_symbols(name)] + invoked_args = [name for name in invoked_args if in_active_symbols(name)] + return write_args + invoked_args + + @staticmethod + def _state_value_expr(name): + return ast.Call( + func=ast.Attribute( + value=ast.Call(func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]), + attr="get", + ctx=ast.Load(), + ), + args=[ast.Constant(value=name), ast.Constant(value=None)], + keywords=[], + ) def visit_If(self, node: ast.If) -> List[ast.AST]: + active_symbols_before_if = self.symbol_scopes.snapshot_symbol_scopes() if _is_constexpr(node.test): node.test = _unwrap_constexpr(node.test) - node = self.generic_visit(node) + node = super().visit_If(node) return node if not self._could_be_dynamic(node.test): - node = self.generic_visit(node) + node = super().visit_If(node) return node - node = self.generic_visit(node) - uid = ReplaceIfWithDispatch._counter - ReplaceIfWithDispatch._counter += 1 - - then_name = f"__then_{uid}" - then_func = ast.FunctionDef( - name=then_name, - args=ast.arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), - body=node.body, - decorator_list=[], - type_params=[], - ) - setattr(then_func, _ASTREWRITE_MARKER, True) - then_func = ast.copy_location(then_func, node) - then_func = ast.fix_missing_locations(then_func) - - dispatch_args = [node.test, ast.Name(then_name, ctx=ast.Load())] - result = [then_func] + with self.symbol_scopes.control_flow_scope(): + node.test = self.visit(node.test) + with self.symbol_scopes.control_flow_scope(): + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) + uid = ReplaceIfWithDispatch._counter + ReplaceIfWithDispatch._counter += 1 + + then_name = f"__then_{uid}" + result_names = self._collect_assigned_vars(node, active_symbols_before_if) + + fn_args = [ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in result_names] + + def _state_return_node(): + return ast.Return( + value=ast.Call( + func=ast.Name(id="scf_if_collect_results", ctx=ast.Load()), + args=[ + ast.Name(id="__ret_names", ctx=ast.Load()), + ast.Call(func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]), + ], + keywords=[], + ) + ) - if node.orelse: - else_name = f"__else_{uid}" - else_func = ast.FunctionDef( - name=else_name, - args=ast.arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), - body=node.orelse, + then_func = ast.FunctionDef( + name=then_name, + args=ast.arguments(posonlyargs=[], args=fn_args, kwonlyargs=[], kw_defaults=[], defaults=[]), + body=list(node.body) + ([_state_return_node()] if result_names else []), decorator_list=[], type_params=[], ) - setattr(else_func, _ASTREWRITE_MARKER, True) - else_func = ast.copy_location(else_func, node) - else_func = ast.fix_missing_locations(else_func) - dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) - result.append(else_func) - - dispatch_call = ast.Expr( - value=ast.Call(func=ast.Name("scf_if_dispatch", ctx=ast.Load()), args=dispatch_args, keywords=[]) - ) - dispatch_call = ast.copy_location(dispatch_call, node) - dispatch_call = ast.fix_missing_locations(dispatch_call) - result.append(dispatch_call) + setattr(then_func, _ASTREWRITE_MARKER, True) + then_func = ast.copy_location(then_func, node) + then_func = ast.fix_missing_locations(then_func) + + dispatch_args = [node.test, ast.Name(then_name, ctx=ast.Load())] + dispatch_keywords = [] + if result_names: + dispatch_keywords.extend( + [ + ast.keyword( + arg="result_names", + value=ast.Tuple(elts=[ast.Constant(value=v) for v in result_names], ctx=ast.Load()), + ), + ast.keyword( + arg="result_values", + value=ast.Tuple( + elts=[self._state_value_expr(v) for v in result_names], + ctx=ast.Load(), + ), + ), + ] + ) + result = [then_func] + + else_name = None + synthesized_else = False + if node.orelse: + else_name = f"__else_{uid}" + else_func = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in result_names], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=list(node.orelse) + ([_state_return_node()] if result_names else []), + decorator_list=[], + type_params=[], + ) + setattr(else_func, _ASTREWRITE_MARKER, True) + else_func = ast.copy_location(else_func, node) + else_func = ast.fix_missing_locations(else_func) + dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) + result.append(else_func) + elif result_names: + else_name = f"__else_{uid}" + synthesized_else = True + else_func = ast.FunctionDef( + name=else_name, + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg="__ret_names", annotation=None)] + [ast.arg(arg=v, annotation=None) for v in result_names], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[_state_return_node()], + decorator_list=[], + type_params=[], + ) + setattr(else_func, _ASTREWRITE_MARKER, True) + else_func = ast.copy_location(else_func, node) + else_func = ast.fix_missing_locations(else_func) + dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) + result.append(else_func) + + if synthesized_else: + dispatch_keywords.append(ast.keyword(arg="auto_else", value=ast.Constant(value=True))) + + dispatch_value = ast.Call( + func=ast.Name("scf_if_dispatch", ctx=ast.Load()), + args=dispatch_args, + keywords=dispatch_keywords, + ) + if result_names and else_name is not None: + if len(result_names) == 1: + target = ast.Name(id=result_names[0], ctx=ast.Store()) + else: + target = ast.Tuple(elts=[ast.Name(id=v, ctx=ast.Store()) for v in result_names], ctx=ast.Store()) + dispatch_stmt = ast.Assign(targets=[target], value=dispatch_value) + else: + dispatch_stmt = ast.Expr(value=dispatch_value) + dispatch_stmt = ast.copy_location(dispatch_stmt, node) + dispatch_stmt = ast.fix_missing_locations(dispatch_stmt) + result.append(dispatch_stmt) - return result + return result @ASTRewriter.register @@ -413,9 +970,16 @@ def visit_For(self, node: ast.For) -> ast.For: node.iter.func = ast.Name(id="scf_range", ctx=ast.Load()) line = ast.dump(node.iter) if "for_" in line or "scf.for_" in line or "scf_range" in line: - node = self.generic_visit(node) + node.iter = self.visit(node.iter) + with self.symbol_scopes.control_flow_scope(): + if isinstance(node.target, ast.Name): + self.symbol_scopes.record_symbol(node.target.id) + node.body = self._visit_stmt_block(node.body) + if node.orelse: + with self.symbol_scopes.control_flow_scope(): + node.orelse = self._visit_stmt_block(node.orelse) new_yield = ast.Expr(ast.Yield(value=None)) - if not self._is_yield(node.body[-1]): + if node.body and not self._is_yield(node.body[-1]): last_statement = node.body[-1] assert last_statement.end_lineno is not None, ( f"last_statement {ast.unparse(last_statement)} must have end_lineno" @@ -509,37 +1073,38 @@ def rewrite_globals(cls): def visit_While(self, node: ast.While) -> List[ast.AST]: if _is_constexpr(node.test): node.test = _unwrap_constexpr(node.test) - node = self.generic_visit(node) + node = super().visit_While(node) return node - node = self.generic_visit(node) - if isinstance(node.test, ast.NamedExpr): - test = node.test.value - else: - test = node.test - w = ast.Call(func=ast.Name("scf_while_gen", ctx=ast.Load()), args=[test], keywords=[]) - w = ast.copy_location(w, node) - assign = ast.Assign( - targets=[ast.Name(f"w_{node.lineno}", ctx=ast.Store())], - value=w, - ) - assign = ast.fix_missing_locations(ast.copy_location(assign, node)) - - next_ = ast.Call( - func=ast.Name("next", ctx=ast.Load()), - args=[ - ast.Name(f"w_{node.lineno}", ctx=ast.Load()), - ast.Constant(False, kind="bool"), - ], - keywords=[], - ) - next_ = ast.fix_missing_locations(ast.copy_location(next_, node)) - if isinstance(node.test, ast.NamedExpr): - node.test.value = next_ - else: - new_test = ast.NamedExpr(target=ast.Name(f"__init__{node.lineno}", ctx=ast.Store()), value=next_) - new_test = ast.copy_location(new_test, node) - node.test = new_test + with self.symbol_scopes.control_flow_scope(): + node = super().visit_While(node) + if isinstance(node.test, ast.NamedExpr): + test = node.test.value + else: + test = node.test + w = ast.Call(func=ast.Name("scf_while_gen", ctx=ast.Load()), args=[test], keywords=[]) + w = ast.copy_location(w, node) + assign = ast.Assign( + targets=[ast.Name(f"w_{node.lineno}", ctx=ast.Store())], + value=w, + ) + assign = ast.fix_missing_locations(ast.copy_location(assign, node)) - node = ast.fix_missing_locations(node) - assign = ast.fix_missing_locations(assign) - return [assign, node] + next_ = ast.Call( + func=ast.Name("next", ctx=ast.Load()), + args=[ + ast.Name(f"w_{node.lineno}", ctx=ast.Load()), + ast.Constant(False, kind="bool"), + ], + keywords=[], + ) + next_ = ast.fix_missing_locations(ast.copy_location(next_, node)) + if isinstance(node.test, ast.NamedExpr): + node.test.value = next_ + else: + new_test = ast.NamedExpr(target=ast.Name(f"__init__{node.lineno}", ctx=ast.Store()), value=next_) + new_test = ast.copy_location(new_test, node) + node.test = new_test + + node = ast.fix_missing_locations(node) + assign = ast.fix_missing_locations(assign) + return [assign, node] diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index 315357be..e797dc57 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -215,6 +215,48 @@ def _extract_arith(val, signed): return v.with_signedness(signed) if isinstance(v, ArithValue) else v +def _unwrap_value(value): + """Convert FlyDSL wrappers to raw MLIR values when possible.""" + if isinstance(value, ir.Value): + return value + if isinstance(value, (bool, int, float)): + try: + return as_numeric(value).ir_value() + except Exception: + log().error(f"failed to construct {as_numeric(value)} from {value}") + return value + if hasattr(value, "__fly_values__"): + values = value.__fly_values__() + if len(values) == 1: + return values[0] + if hasattr(value, "ir_value"): + return value.ir_value() + return value + + +def _wrap_like(value, exemplar=None): + """Wrap an MLIR value back to a FlyDSL wrapper when possible.""" + if not isinstance(value, ir.Value): + return value + + if exemplar is not None: + if isinstance(exemplar, Numeric): + return type(exemplar)(value) + ctor = getattr(type(exemplar), "__fly_construct__", None) + if ctor is not None: + try: + return ctor([value]) + except Exception: + log().error(f"failed to construct {type(exemplar)} from {value}") + return value + + try: + return Numeric.from_ir_type(value.type)(value) + except Exception: + log().error(f"failed to construct {Numeric.from_ir_type(value.type)} from {value}") + return value + + def _make_binop(op, promote=True, widen_bool=False, swap=False): """Create a binary-operator closure for Numeric subclasses.""" def _apply(lhs, rhs, *, loc=None, ip=None): diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 4451bbd2..793274fd 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -769,16 +769,24 @@ def mma_make_fragment(operand_id, tiled_mma, input, loc=None, ip=None): @traced_op -def copy(copy_atom, src, dst, *, pred=None, loc=None, ip=None): - return fly.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip) +def copy(copy_atom, src, dst, *, pred=None, loc=None, ip=None, **kwargs): + return fly.copy(copy_atom.set_value(kwargs), src, dst, pred=pred, loc=loc, ip=ip) @traced_op -def gemm(mma_atom, d, a, b, c, *, traversal_order=None, traversal_layout=None, loc=None, ip=None): +def gemm(mma_atom, d, a, b, c, *, traversal_order=None, traversal_layout=None, loc=None, ip=None, **kwargs): if traversal_order is not None and traversal_layout is not None: raise ValueError("Only one of 'traversal_order' or 'traversal_layout' can be specified, not both") return fly.gemm( - mma_atom, d, a, b, c, traversal_order=traversal_order, traversal_layout=traversal_layout, loc=loc, ip=ip + mma_atom.set_value(kwargs), + d, + a, + b, + c, + traversal_order=traversal_order, + traversal_layout=traversal_layout, + loc=loc, + ip=ip, ) diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index bd7fa71c..e0524317 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -4,7 +4,7 @@ import ctypes import enum from inspect import isclass -from typing import Generic, Type, TypeVar +from typing import Generic, Type, TypeVar, overload from flydsl.runtime.device import get_rocm_arch @@ -652,8 +652,18 @@ def layout_dst_tv(self): def layout_ref_tv(self): return static(self.type.tv_layout_ref) + @overload + def set_value(self, field: str, value, loc=None, ip=None): ... + @overload + def set_value(self, field: dict, loc=None, ip=None): ... + @traced_op - def set_value(self, field, value, loc=None, ip=None): + def set_value(self, field, value=None, loc=None, ip=None): + if isinstance(field, dict): + result = self + for k, v in field.items(): + result = atom_set_value(result, k, v, loc=loc, ip=ip) + return result return atom_set_value(self, field, value, loc=loc, ip=ip) @@ -683,8 +693,18 @@ def layout_B_tv(self): def layout_C_tv(self): return static(self.type.tv_layout_c) + @overload + def set_value(self, field: str, value, loc=None, ip=None): ... + @overload + def set_value(self, field: dict, loc=None, ip=None): ... + @traced_op - def set_value(self, field, value, loc=None, ip=None): + def set_value(self, field, value=None, loc=None, ip=None): + if isinstance(field, dict): + result = self + for k, v in field.items(): + result = atom_set_value(result, k, v, loc=loc, ip=ip) + return result return atom_set_value(self, field, value, loc=loc, ip=ip) diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index cdaf4bb1..897ff34d 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -42,15 +42,16 @@ if [ "${RUN_TESTS_FULL:-0}" != "1" ]; then fi # --------------------------------------------------------------------------- -# 1. All pytest-based tests (kernels + unit + examples) +# 1. All pytest-based tests (kernels + unit + system + examples) # --------------------------------------------------------------------------- echo "========================================================================" -echo "Pytest: kernels + unit + examples" +echo "Pytest: kernels + unit + system + examples" echo "========================================================================" python3 -m pytest \ tests/kernels/ \ tests/unit/ \ + tests/system/ \ tests/python/examples/ \ "${pytest_args[@]}" diff --git a/tests/kernels/test_quant.py b/tests/kernels/test_quant.py index e0c842c5..a7215690 100644 --- a/tests/kernels/test_quant.py +++ b/tests/kernels/test_quant.py @@ -128,24 +128,24 @@ def block_reduce_max(val): if arith.cmpi(arith.CmpIPredicate.eq, lane, Int32(0)): wave_idx = arith.index_cast(T.index, wave) - s_red.store(w, [wave_idx]) + SmemPtr.store(s_red, w, [wave_idx]) gpu.barrier() if arith.cmpi(arith.CmpIPredicate.eq, wave, Int32(0)): in_range = lane < RED_SLOTS lane_safe = arith.select(in_range, lane, Int32(0)) lane_safe_idx = arith.index_cast(T.index, lane_safe) - v = s_red.load([lane_safe_idx]) + v = SmemPtr.load(s_red, [lane_safe_idx]) ww = arith.select(in_range, v, c_zero_f) ww = wave_reduce_max(ww) if arith.cmpi(arith.CmpIPredicate.eq, lane, Int32(0)): c0_idx = arith.constant(0, index=True) - s_red.store(ww, [c0_idx]) + SmemPtr.store(s_red, ww, [c0_idx]) gpu.barrier() c0_idx = arith.constant(0, index=True) - return s_red.load([c0_idx]) + return SmemPtr.load(s_red, [c0_idx]) # ── Layout API: buffer-backed tensors ──────────────────────────── Input_buf = fx.rocdl.make_buffer_tensor(Input) diff --git a/tests/system/test_control_flow_compile.py b/tests/system/test_control_flow_compile.py new file mode 100644 index 00000000..5d7a6759 --- /dev/null +++ b/tests/system/test_control_flow_compile.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +import flydsl.compiler as flyc +import flydsl.expr as fx +import pytest +import torch + + +def test_control_flow_kernel_snippet_compiles_without_error(monkeypatch): + if not torch.cuda.is_available(): + pytest.skip("CUDA device is required for control-flow compile coverage test") + + @flyc.kernel + def vecAbsKernel( + A: fx.Tensor, + C: fx.Tensor, + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + print_debug: fx.Constexpr[bool] = True, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + if print_debug and bid == 0 and tid <= 2: + fx.printf("[kernel] bid={}, tid={}", bid, tid) + + @flyc.jit + def vecAbs( + A: fx.Tensor, + C, + n: fx.Int32, + const_n: fx.Constexpr[int], + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + stream: fx.Stream = fx.Stream(None), + ): + tile_elems = block_dim * vec_width + grid_x = (n + tile_elems - 1) // tile_elems + vecAbsKernel(A, C, block_dim, vec_width).launch( + grid=(grid_x, 1, 1), block=(block_dim, 1, 1), stream=stream + ) + + monkeypatch.setenv("FLYDSL_COMPILE_ONLY", "1") + threads = 64 + vec = 4 + size = threads * vec + a = torch.randn(size, device="cuda", dtype=torch.float32) + c = torch.empty_like(a) + t_a = flyc.from_dlpack(a).mark_layout_dynamic(leading_dim=0, divisibility=vec) + vecAbs(t_a, c, size, size, threads, vec) + + +def test_control_flow_dynamic_if_end_to_end_numeric(monkeypatch): + if not torch.cuda.is_available(): + pytest.skip("CUDA device is required for dynamic if end-to-end test") + # Avoid compile-cache hits so dynamic dispatch is exercised in this test process. + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + + @flyc.kernel + def dynamicIfKernel( + A: fx.Tensor, + B: fx.Tensor, + C: fx.Tensor, + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + tile_elems = block_dim * vec_width + + A = fx.rocdl.make_buffer_tensor(A) + B = fx.rocdl.make_buffer_tensor(B) + C = fx.rocdl.make_buffer_tensor(C) + + tA = fx.logical_divide(A, fx.make_layout(tile_elems, 1)) + tB = fx.logical_divide(B, fx.make_layout(tile_elems, 1)) + tC = fx.logical_divide(C, fx.make_layout(tile_elems, 1)) + + tA = fx.slice(tA, (None, bid)) + tB = fx.slice(tB, (None, bid)) + tC = fx.slice(tC, (None, bid)) + + tA = fx.logical_divide(tA, fx.make_layout(vec_width, 1)) + tB = fx.logical_divide(tB, fx.make_layout(vec_width, 1)) + tC = fx.logical_divide(tC, fx.make_layout(vec_width, 1)) + + reg_ty = fx.MemRefType.get( + fx.T.f32(), fx.LayoutType.get(vec_width, 1), fx.AddressSpace.Register + ) + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32) + + rA = fx.memref_alloca(reg_ty, fx.make_layout(vec_width, 1)) + rB = fx.memref_alloca(reg_ty, fx.make_layout(vec_width, 1)) + rC = fx.memref_alloca(reg_ty, fx.make_layout(vec_width, 1)) + + fx.copy_atom_call(copy_atom, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(copy_atom, fx.slice(tB, (None, tid)), rB) + + vA = fx.memref_load_vec(rA) + vB = fx.memref_load_vec(rB) + vOut = fx.arith.addf(vA, vB) + + # Runtime branch (tid/bid come from GPU execution), so this should lower to dynamic scf.if. + if (tid % 2) == 0: + vOut = fx.arith.addf(vOut, vA) + else: + vOut = fx.arith.subf(vOut, vB) + + if (bid % 2) == 0: + vOut = fx.arith.addf(vOut, vB) + else: + vOut = fx.arith.subf(vOut, vA) + + fx.memref_store_vec(vOut, rC) + fx.copy_atom_call(copy_atom, rC, fx.slice(tC, (None, tid))) + + @flyc.jit + def dynamicIfVec( + A: fx.Tensor, + B: fx.Tensor, + C, + n: fx.Int32, + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + stream: fx.Stream = fx.Stream(None), + ): + tile_elems = block_dim * vec_width + grid_x = (n + tile_elems - 1) // tile_elems + dynamicIfKernel(A, B, C, block_dim, vec_width).launch( + grid=(grid_x, 1, 1), block=(block_dim, 1, 1), stream=stream + ) + + block_dim = 64 + vec_width = 4 + num_blocks = 5 + size = block_dim * vec_width * num_blocks + + a = torch.randn(size, device="cuda", dtype=torch.float32) + b = torch.randn(size, device="cuda", dtype=torch.float32) + c = torch.empty_like(a) + + t_a = flyc.from_dlpack(a).mark_layout_dynamic(leading_dim=0, divisibility=vec_width) + dynamicIfVec(t_a, b, c, size, block_dim, vec_width) + torch.cuda.synchronize() + + a3 = a.view(num_blocks, block_dim, vec_width) + b3 = b.view(num_blocks, block_dim, vec_width) + tid = torch.arange(block_dim, device="cuda").view(1, block_dim, 1) + bid = torch.arange(num_blocks, device="cuda").view(num_blocks, 1, 1) + + ref = a3 + b3 + ref = torch.where((tid % 2) == 0, ref + a3, ref - b3) + ref = torch.where((bid % 2) == 0, ref + b3, ref - a3) + ref = ref.reshape(-1) + + torch.testing.assert_close(c, ref, rtol=1e-5, atol=1e-5) diff --git a/tests/unit/test_if_dispatch_paths.py b/tests/unit/test_if_dispatch_paths.py new file mode 100644 index 00000000..fbc1aceb --- /dev/null +++ b/tests/unit/test_if_dispatch_paths.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +import ast + +import pytest + +from flydsl._mlir.ir import Context, FunctionType, InsertionPoint, IntegerType, Location, Module +from flydsl._mlir.dialects import arith, func +from flydsl.compiler.ast_rewriter import ASTRewriter, ReplaceIfWithDispatch +from flydsl.expr.numeric import Int32 + + +def test_collect_assigned_vars_supports_tuple_and_augassign(): + code = """ +a, (b, c) = foo() +d += 1 +""" + stmts = ast.parse(code).body + node = ast.If(test=ast.Constant(value=True), body=stmts, orelse=[]) + active_symbols = [{"a", "b", "c", "d"}] + assigned = ReplaceIfWithDispatch._collect_assigned_vars(node, active_symbols) + assert assigned == ["a", "b", "c", "d"] + + +def test_collect_assigned_vars_supports_annassign_walrus_with_except_for(): + code = """ +x: int = 1 +for i in range(4): + y = i +with ctx() as w: + z = w +try: + pass +except Exception as e: + err = e +if (n := foo()): + out = n +""" + stmts = ast.parse(code).body + node = ast.If(test=ast.Constant(value=True), body=stmts, orelse=[]) + active_symbols = [{"x", "i", "y", "w", "z", "e", "err", "n", "out"}] + assigned = ReplaceIfWithDispatch._collect_assigned_vars(node, active_symbols) + assert assigned == ["x", "i", "y", "w", "z", "err", "n", "out"] + + +def test_scf_if_dispatch_static_with_states_no_ifop(): + with Context(), Location.unknown(): + module = Module.create() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + f = func.FuncOp("test_static_states", FunctionType.get([], [i32])) + entry = f.add_entry_block() + with InsertionPoint(entry): + x = Int32(1) + + def then_fn(x): + return {"x": Int32(42)} + + def else_fn(x): + return {"x": Int32(99)} + + out = ReplaceIfWithDispatch.scf_if_dispatch( + True, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + func.ReturnOp([out.ir_value()]) + + assert module.operation.verify() + assert "scf.if" not in str(module) + + +def test_scf_if_dispatch_dynamic_with_states_build_ifop(): + with Context(), Location.unknown(): + module = Module.create() + i1 = IntegerType.get_signless(1) + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + f = func.FuncOp("test_dynamic_states", FunctionType.get([i1], [i32])) + entry = f.add_entry_block() + with InsertionPoint(entry): + cond = entry.arguments[0] + x = Int32(arith.ConstantOp(i32, 1).result) + + def then_fn(x): + return {"x": Int32(arith.ConstantOp(i32, 42).result)} + + def else_fn(x): + return {"x": Int32(arith.ConstantOp(i32, 99).result)} + + out = ReplaceIfWithDispatch.scf_if_dispatch( + cond, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + assert isinstance(out, Int32) + func.ReturnOp([out.ir_value()]) + + assert module.operation.verify() + ir_text = str(module) + assert "scf.if" in ir_text + assert "-> (i32)" in ir_text + + +def test_scf_if_dispatch_dynamic_type_mismatch_has_clear_error(): + with Context(), Location.unknown(): + module = Module.create() + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + i1 = IntegerType.get_signless(1) + with InsertionPoint(module.body): + f = func.FuncOp("test_dynamic_type_mismatch", FunctionType.get([i1], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + cond = entry.arguments[0] + x = Int32(arith.ConstantOp(i32, 1).result) + + def then_fn(x): + return {"x": arith.ConstantOp(i32, 2).result} + + def else_fn(x): + return {"x": arith.ConstantOp(i64, 3).result} + + with pytest.raises(TypeError, match="type mismatch|mismatched types"): + ReplaceIfWithDispatch.scf_if_dispatch( + cond, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + + +def test_scf_if_dispatch_dynamic_non_mlir_value_is_promoted(): + with Context(), Location.unknown(): + module = Module.create() + i32 = IntegerType.get_signless(32) + i1 = IntegerType.get_signless(1) + with InsertionPoint(module.body): + f = func.FuncOp("test_dynamic_non_mlir", FunctionType.get([i1], [])) + entry = f.add_entry_block() + with InsertionPoint(entry): + cond = entry.arguments[0] + x = Int32(arith.ConstantOp(i32, 1).result) + + def then_fn(x): + return {"x": 7} + + def else_fn(x): + return {"x": arith.ConstantOp(i32, 3).result} + + out = ReplaceIfWithDispatch.scf_if_dispatch( + cond, + then_fn, + else_fn, + state_names=("x",), + state_values=(x,), + ) + assert isinstance(out, Int32) + + +def test_ast_rewrite_keeps_semantics_for_static_bool(): + called = {"n": 0} + + def sample(flag): + x = 1 + if flag: + x = 2 + else: + x = 3 + return x + + ASTRewriter.transform(sample) + original_dispatch = sample.__globals__["scf_if_dispatch"] + + def traced_dispatch(*args, **kwargs): + called["n"] += 1 + return original_dispatch(*args, **kwargs) + + sample.__globals__["scf_if_dispatch"] = traced_dispatch + assert sample(True) == 2 + assert sample(False) == 3 + assert called["n"] in (0, 2) + + +def test_ast_rewrite_does_not_rewrite_static_string_compare(): + called = {"n": 0} + + def sample(dtype_str): + out = 0 + if dtype_str == "f32": + out = 1 + else: + out = 2 + return out + + ASTRewriter.transform(sample) + original_dispatch = sample.__globals__["scf_if_dispatch"] + + def traced_dispatch(*args, **kwargs): + called["n"] += 1 + return original_dispatch(*args, **kwargs) + + sample.__globals__["scf_if_dispatch"] = traced_dispatch + assert sample("f32") == 1 + assert sample("bf16") == 2 + assert called["n"] == 2