From c0992b56aa8e7792b712fd0c87e6d8d26119c4cc Mon Sep 17 00:00:00 2001 From: "yashao@amd.com" Date: Thu, 9 Apr 2026 17:36:02 +0000 Subject: [PATCH 1/5] rm mem_ops --- kernels/custom_all_reduce_kernel.py | 169 ++++++++++++++++++----- python/flydsl/expr/__init__.py | 2 +- python/flydsl/expr/buffer_ops.py | 29 ++++ python/flydsl/expr/mem_ops.py | 199 ---------------------------- 4 files changed, 164 insertions(+), 235 deletions(-) delete mode 100644 python/flydsl/expr/mem_ops.py diff --git a/kernels/custom_all_reduce_kernel.py b/kernels/custom_all_reduce_kernel.py index 06bbaa47..f0581163 100644 --- a/kernels/custom_all_reduce_kernel.py +++ b/kernels/custom_all_reduce_kernel.py @@ -11,14 +11,113 @@ from __future__ import annotations import flydsl.compiler as flyc -from flydsl.expr import arith as ea, gpu, range_constexpr, mem_ops, vector as ev +from flydsl.expr import arith as ea, gpu, range_constexpr, vector as ev, buffer_ops from flydsl.expr.typing import T, Int32, Int64, Stream from flydsl._mlir import ir -from flydsl._mlir.dialects import scf +from flydsl._mlir.dialects import scf, llvm, rocdl from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from kernels.custom_all_reduce import _KMAXBLOCKS as _MAX_BLOCKS + +# --------------------------------------------------------------------------- +# Low-level memory helpers — all operate on raw i64 device addresses. +# +# Cache modifier bits for buffer_load / buffer_store (AMD GFX942 aux field): +# bit 0 = SC0 — bypass L1/TCP cache +# bit 1 = SC1 — bypass L2/TCC cache +# bit 2 = NT — nontemporal (bypass hardware prefetcher) +# --------------------------------------------------------------------------- +_CM_CACHED = 0 # normal cached access +_CM_SC1 = 2 # bypass L2 only (reads from signal bufs across GPUs) +_CM_SC0_SC1 = 3 # bypass L1+L2 (writes to signal bufs: fully uncached) +_CM_NT = 4 # nontemporal (bulk data writes, bypasses L2 prefetch) + + +# ---- bulk data: 16-byte (128-bit) load / store ---------------------------- + +def _load_v4i32(addr_i64): + """Load vector<4xi32> (16 bytes) from a raw i64 device address.""" + rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) + return buffer_ops.buffer_load(rsrc, ea.constant(0, type=T.i32), + vec_width=4, dtype=T.i32) + + +def _store_v4i32(addr_i64, data): + """Store vector<4xi32> (16 bytes) to a raw i64 device address (cached).""" + rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) + buffer_ops.buffer_store(data, rsrc, ea.constant(0, type=T.i32), + cache_modifier=_CM_CACHED) + + +def _store_v4i32_nt(addr_i64, v4i32_val): + """Store vector<4xi32> nontemporal (nt) — bypasses L2 prefetcher. + + Use for large output writes after all-reduce so dirty lines do not + pollute L2 and the end-sync signal remains the cache-line of interest. + """ + rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) + buffer_ops.buffer_store(v4i32_val, rsrc, ea.constant(0, type=T.i32), + cache_modifier=_CM_NT) + rocdl.s_waitcnt(0) + + +# ---- signal buffer: uncached i32 load / store ---------------------------- + +def _load_i32_uncached(addr_i64): + """Load i32 bypassing L2 (sc1) — for polling cross-GPU signal buffers.""" + rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) + val = buffer_ops.buffer_load(rsrc, ea.constant(0, type=T.i32), + vec_width=1, dtype=T.i32, + cache_modifier=_CM_SC1) + rocdl.s_waitcnt(0) + return val + + +def _store_i32_uncached(addr_i64, val_i32): + """Store i32 bypassing L1+L2 (sc0+sc1) — for signal buffer writes.""" + rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) + buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), + cache_modifier=_CM_SC0_SC1) + rocdl.s_waitcnt(0) + + +def _invalidate_l1(): + """Invalidate L1 scalar cache (buffer_inv sc1). + + Call inside a polling loop after an uncached load to discard stale L1 + lines so the next iteration sees fresh data from L2/HBM. + """ + llvm.InlineAsmOp(None, [], "buffer_inv sc1", "", has_side_effects=True) + + +def _store_i32_uncached_flush(addr_i64, val_i32): + """Store i32 with L2 writeback then sc0+sc1 store. + + Use after cached data stores (st_global / buffer_store cached) so that + dirty L2 lines reach HBM before the signal becomes visible to peer GPUs. + buffer_wbl2 cannot be expressed as a buffer_store flag, so it stays as + inline asm. + """ + llvm.InlineAsmOp(None, [], "buffer_wbl2 sc0 sc1", "", has_side_effects=True) + rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) + buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), + cache_modifier=_CM_SC0_SC1) + rocdl.s_waitcnt(0) + + +# ---- pointer array helpers ----------------------------------------------- + +def _load_device_ptr(array_base_i64, index): + """Load i64 pointer from a device-side pointer array at *index*. + + Uses buffer_load(dtype=i64): offset is in elements so buffer_load + automatically scales by 8 bytes internally. + """ + rsrc = buffer_ops.create_buffer_resource_from_addr(array_base_i64) + return buffer_ops.buffer_load(rsrc, index, vec_width=1, dtype=T.i64) + + # Signal buffer layout offsets (bytes), derived from _MAX_BLOCKS. # start[_MAX_BLOCKS][8] of uint32 | end[_MAX_BLOCKS][8] of uint32 | flag[_MAX_BLOCKS] of uint32 _SG_START_OFF_B = 0 @@ -67,7 +166,7 @@ def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngp flag_addr = (self_sg_i64 + ea.constant(_SG_FLAG_OFF_B, type=i64) + bid_i32.extui(i64) * ea.constant(4, type=i64)) - flag = mem_ops.load_i32_uncached(flag_addr) + ea.constant(1, type=i32) + flag = _load_i32_uncached(flag_addr) + ea.constant(1, type=i32) bid8 = bid_i32 * ea.constant(8, type=i32) lin_lane = bid8 + lane_i32 @@ -81,8 +180,8 @@ def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngp if_op = scf.IfOp(is_lane, results_=[], has_else=False) with ir.InsertionPoint(if_op.then_block): peer_sg = ea.select_by_index(lane_i32, sgs_i64) - mem_ops.store_i32_uncached_flush(peer_sg + start_rank_off, flag) - init_cur = mem_ops.load_i32_uncached(start_wait_addr) + _store_i32_uncached_flush(peer_sg + start_rank_off, flag) + init_cur = _load_i32_uncached(start_wait_addr) w = scf.WhileOp([i32], [init_cur]) wb = ir.Block.create_at_start(w.before, [i32]) wa = ir.Block.create_at_start(w.after, [i32]) @@ -91,14 +190,14 @@ def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngp need_wait = _u(cur) < flag scf.ConditionOp(need_wait, [cur]) with ir.InsertionPoint(wa): - scf.YieldOp([mem_ops.load_i32_uncached(start_wait_addr)]) + scf.YieldOp([_load_i32_uncached(start_wait_addr)]) scf.YieldOp([]) gpu.barrier() is_t0 = lane_i32 == ea.constant(0, type=i32) if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) with ir.InsertionPoint(if_t0.then_block): - mem_ops.store_i32(flag_addr, flag) + _store_i32_uncached(flag_addr, flag) scf.YieldOp([]) return flag_addr @@ -122,7 +221,7 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, gpu.barrier() flag_addr = (self_sg_i64 + ea.constant(_SG_FLAG_OFF_B, type=i64) + bid_i32.extui(i64) * ea.constant(4, type=i64)) - flag = mem_ops.load_i32_uncached(flag_addr) + ea.constant(1, type=i32) + flag = _load_i32_uncached(flag_addr) + ea.constant(1, type=i32) bid8 = bid_i32 * ea.constant(8, type=i32) lin_lane = bid8 + lane_i32 @@ -137,10 +236,10 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, with ir.InsertionPoint(if_op.then_block): peer_sg = ea.select_by_index(lane_i32, sgs_i64) if need_wbl2: - mem_ops.store_i32_uncached_flush(peer_sg + end_rank_off, flag) + _store_i32_uncached_flush(peer_sg + end_rank_off, flag) else: - mem_ops.store_i32_uncached(peer_sg + end_rank_off, flag) - init_cur = mem_ops.load_i32_uncached(end_wait_addr) + _store_i32_uncached(peer_sg + end_rank_off, flag) + init_cur = _load_i32_uncached(end_wait_addr) w = scf.WhileOp([i32], [init_cur]) wb = ir.Block.create_at_start(w.before, [i32]) wa = ir.Block.create_at_start(w.after, [i32]) @@ -149,8 +248,8 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, need_wait = _u(cur) < flag scf.ConditionOp(need_wait, [cur]) with ir.InsertionPoint(wa): - nxt = mem_ops.load_i32_uncached(end_wait_addr) - mem_ops.invalidate_l1() + nxt = _load_i32_uncached(end_wait_addr) + _invalidate_l1() scf.YieldOp([nxt]) scf.YieldOp([]) @@ -158,7 +257,7 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, is_t0 = lane_i32 == ea.constant(0, type=i32) if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) with ir.InsertionPoint(if_t0.then_block): - mem_ops.store_i32(flag_addr, flag) + _store_i32_uncached(flag_addr, flag) scf.YieldOp([]) @@ -272,8 +371,8 @@ def allreduce_1stage_arr( in_ptrs_i64 = in_ptrs.ir_value() out_ptr_i64 = out_ptr.ir_value() - sgs = [mem_ops.load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - in_ptrs_arr = [mem_ops.load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] smem_sym = f"allreduce_1s_smem_ws{world_size}_t{threads}" n_smem = 2 * threads @@ -311,7 +410,7 @@ def allreduce_1stage_arr( # Each warp loads data from its rank into shared memory in_base = ea.select_by_index(warp_id, in_ptrs_arr) off16 = p.extui(i64) * ea.constant(16, type=i64) - raw = mem_ops.load_v4i32(in_base + off16) + raw = _load_v4i32(in_base + off16) sm_base = parity * ea.constant(threads, type=i32) sm_idx = ea.index_cast(idx, sm_base + lane_i32) smem_ptr.store(raw, [sm_idx]) @@ -338,7 +437,7 @@ def allreduce_1stage_arr( else: out_bits = ev.bitcast(v4i32, acc.truncf(v8half)) dst_off = p.extui(i64) * ea.constant(16, type=i64) - mem_ops.store_v4i32(out_ptr_i64 + dst_off, out_bits) + _store_v4i32(out_ptr_i64 + dst_off, out_bits) scf.YieldOp([]) # No barrier 2 needed: parity double-buffer ensures next iteration @@ -384,9 +483,9 @@ def allreduce_2stage_arr( tmp_ptrs_i64 = tmp_ptrs.ir_value() out_ptr_i64 = out_ptr.ir_value() - sgs = [mem_ops.load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - in_ptrs_arr = [mem_ops.load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - tmp_ptrs_arr = [mem_ops.load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + tmp_ptrs_arr = [_load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] # Compute pack range for this rank's reduce-scatter partition start_p = rank_i32 * ea.constant(part_p, type=i32) @@ -423,7 +522,7 @@ def allreduce_2stage_arr( def _build_reduce_body(cur, smem_base_expr=None): """Emit reduce body: load → smem → barrier1 → warp0 reduce → [barrier2].""" in_base = ea.select_by_index(warp_id, in_ptrs_arr) - raw = mem_ops.load_v4i32(in_base + cur.extui(i64) * ea.constant(16, type=i64)) + raw = _load_v4i32(in_base + cur.extui(i64) * ea.constant(16, type=i64)) if smem_base_expr is None: sm_idx = ea.index_cast(idx, lane_i32) else: @@ -453,7 +552,7 @@ def _build_reduce_body(cur, smem_base_expr=None): else: out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) rel_p = cur - start_p - mem_ops.store_v4i32(tmp_out_i64 + rel_p.extui(i64) * ea.constant(16, type=i64), + _store_v4i32(tmp_out_i64 + rel_p.extui(i64) * ea.constant(16, type=i64), out_raw) scf.YieldOp([]) @@ -515,9 +614,9 @@ def _build_reduce_body(cur, smem_base_expr=None): else: dst_rank = _u(sum_rw) % ea.constant(world_size, type=i32) tmp_base = ea.select_by_index(warp_id, tmp_ptrs_arr) - raw = mem_ops.load_v4i32(tmp_base + cur.extui(i64) * ea.constant(16, type=i64)) + raw = _load_v4i32(tmp_base + cur.extui(i64) * ea.constant(16, type=i64)) dst_pack = dst_rank * ea.constant(part_p, type=i32) + cur - mem_ops.store_v4i32(out_ptr_i64 + dst_pack.extui(i64) * ea.constant(16, type=i64), + _store_v4i32(out_ptr_i64 + dst_pack.extui(i64) * ea.constant(16, type=i64), raw) scf.YieldOp([cur + stride_pack2]) else: @@ -542,10 +641,10 @@ def _build_reduce_body(cur, smem_base_expr=None): ifp = scf.IfOp(ok, results_=[], has_else=False) with ir.InsertionPoint(ifp.then_block): src_off = cur.extui(i64) * ea.constant(16, type=i64) - raw = mem_ops.load_v4i32(tmp_ptrs_arr[p] + src_off) + raw = _load_v4i32(tmp_ptrs_arr[p] + src_off) dst_pack_idx = ea.constant(p * part_p, type=i32) + cur dst_off = dst_pack_idx.extui(i64) * ea.constant(16, type=i64) - mem_ops.store_v4i32(out_ptr_i64 + dst_off, raw) + _store_v4i32(out_ptr_i64 + dst_off, raw) scf.YieldOp([]) scf.YieldOp([cur + stride_i32]) @@ -585,8 +684,8 @@ def allreduce_2stage_write_mode( out_ptrs_i64 = out_ptrs.ir_value() tmp_ptrs_i64 = tmp_ptrs.ir_value() - sgs = [mem_ops.load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - out_ptrs_arr = [mem_ops.load_device_ptr(out_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + out_ptrs_arr = [_load_device_ptr(out_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] tnum_gpu_i32 = ea.constant(tnum_gpu, type=i32) log2_tnum = int(math.log2(tnum_gpu)) @@ -605,7 +704,7 @@ def allreduce_2stage_write_mode( allocator_wm.finalize() smem_ptr = SmemPtr(allocator_wm.get_base(), smem_wm_off, v4i32, shape=(n_smem_wm,)) smem_ptr.get() - tmp_out_i64 = mem_ops.load_device_ptr(tmp_ptrs_i64, rank_i32) + tmp_out_i64 = _load_device_ptr(tmp_ptrs_i64, rank_i32) # ---- Stage 1: scatter local input to REMOTE tmp buffers ---- start_w = warp_id * ea.constant(part_p, type=i32) @@ -628,10 +727,10 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(as1): cur = as1.arguments[0] stride_s1 = as1.arguments[1] - raw = mem_ops.load_v4i32(inp_ptr_i64 + cur.extui(i64) * ea.constant(16, type=i64)) + raw = _load_v4i32(inp_ptr_i64 + cur.extui(i64) * ea.constant(16, type=i64)) rel_idx = cur - start_w dst_off = rank_i32 * ea.constant(part_p, type=i32) + rel_idx - dst_tmp = mem_ops.load_device_ptr(tmp_ptrs_i64, warp_id) + dst_tmp = _load_device_ptr(tmp_ptrs_i64, warp_id) tmp_addr = dst_tmp + dst_off.extui(i64) * ea.constant(16, type=i64) is_tmp_null = dst_tmp == ea.constant(0, type=i64) tmp_low4 = tmp_addr & ea.constant(0xF, type=i64) @@ -641,7 +740,7 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(if_tmp_ok.then_block): scf.YieldOp([]) with ir.InsertionPoint(if_tmp_ok.else_block): - mem_ops.store_v4i32(tmp_addr, raw) + _store_v4i32(tmp_addr, raw) scf.YieldOp([]) scf.YieldOp([cur + stride_s1, stride_s1]) @@ -682,7 +781,7 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(raw_if.then_block): scf.YieldOp([ea.constant_vector(0, v4i32)]) with ir.InsertionPoint(raw_if.else_block): - scf.YieldOp([mem_ops.load_v4i32(load_addr)]) + scf.YieldOp([_load_v4i32(load_addr)]) raw = raw_if.results[0] sm_idx = ea.index_cast(idx, lane_i32) @@ -731,7 +830,7 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(if_out_ok.then_block): scf.YieldOp([]) with ir.InsertionPoint(if_out_ok.else_block): - mem_ops.store_v4i32_nt(out_addr, out_raw) + _store_v4i32_nt(out_addr, out_raw) scf.YieldOp([]) scf.YieldOp([cur + stride_s2, stride_s2]) diff --git a/python/flydsl/expr/__init__.py b/python/flydsl/expr/__init__.py index 3fef52c7..7c2bd331 100644 --- a/python/flydsl/expr/__init__.py +++ b/python/flydsl/expr/__init__.py @@ -9,5 +9,5 @@ from . import utils -from . import arith, vector, gpu, buffer_ops, rocdl, math, mem_ops +from . import arith, vector, gpu, buffer_ops, rocdl, math from .rocdl import tdm_ops diff --git a/python/flydsl/expr/buffer_ops.py b/python/flydsl/expr/buffer_ops.py index 6c41c311..4355c884 100644 --- a/python/flydsl/expr/buffer_ops.py +++ b/python/flydsl/expr/buffer_ops.py @@ -72,6 +72,7 @@ def _get_buffer_flags(arch=None): 'create_llvm_ptr', 'get_element_ptr', 'create_buffer_resource', + 'create_buffer_resource_from_addr', 'buffer_load', 'buffer_store', 'BufferResourceDescriptor', @@ -323,6 +324,34 @@ def _num_records_from_memref_type() -> Optional[int]: return BufferResourceDescriptor(rsrc) +def create_buffer_resource_from_addr(addr_i64: ir.Value) -> ir.Value: + """Create AMD buffer resource descriptor from a raw i64 device address. + + Useful when working with runtime pointer arrays (e.g. IPC-mapped addresses + or device-side pointer tables) where no fly.memref is available. + The full address is encoded as the buffer base; callers should pass + byte offset 0 to buffer_load / buffer_store. + + Args: + addr_i64: Raw 64-bit device address (i64 MLIR value). + + Returns: + ROCDL buffer resource descriptor (!llvm.ptr<8>). + + Example: + >>> rsrc = create_buffer_resource_from_addr(raw_addr_i64) + >>> data = buffer_load(rsrc, i32_zero, vec_width=4, dtype=T.i32) + """ + addr_i64 = _unwrap_value(addr_i64) + ptr_type = ir.Type.parse('!llvm.ptr') + base_ptr = llvm.IntToPtrOp(ptr_type, addr_i64).result + flags = _create_i32_constant(_get_buffer_flags()) + stride = _create_i16_constant(0) + num_records = _create_i64_constant(0xFFFFFFFF) + rsrc_type = ir.Type.parse('!llvm.ptr<8>') + return rocdl.MakeBufferRsrcOp(rsrc_type, base_ptr, stride, num_records, flags).result + + @traced_op def create_buffer_resource(memref_val: ir.Value, stride: int = 0, diff --git a/python/flydsl/expr/mem_ops.py b/python/flydsl/expr/mem_ops.py deleted file mode 100644 index 4e74ea9e..00000000 --- a/python/flydsl/expr/mem_ops.py +++ /dev/null @@ -1,199 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 FlyDSL Project Contributors - -"""Low-level memory operations via inline assembly for multi-GPU kernels. - -Provides wrappers around GFX942 inline assembly instructions for: - -- **Uncached** loads/stores (``sc0 sc1`` — system-scope coherent, - for cross-GPU signal buffers allocated with ``hipDeviceMallocUncached``) -- **Nontemporal** stores (``nt`` — bypasses L1/L2 cache, works on any - memory type including regular ``hipMalloc`` / IPC-mapped addresses) -- **Cached** vector loads/stores (16-byte / ``v4i32``) -- Device-side pointer access - -All functions operate on raw ``ir.Value`` (i32/i64/vector<4xi32>). - -Example:: - - from flydsl.expr import mem_ops - - val = mem_ops.load_i32_uncached(addr) - mem_ops.store_i32_uncached_flush(peer_addr, flag) - data = mem_ops.load_v4i32(data_addr) -""" - -from __future__ import annotations - -from .._mlir import ir -from .._mlir.dialects import arith as _arith, llvm, rocdl -from .meta import traced_op -from .typing import T - - -# --------------------------------------------------------------------------- -# Uncached i32 operations (system-scope coherent, for signal buffers) -# --------------------------------------------------------------------------- - -@traced_op -def load_i32_uncached(addr_i64): - """Load i32 from global address, bypassing L1 cache (system-scope). - - Emits ``global_load_dword ... sc1`` on GFX942. - Typically used to poll cross-GPU signal buffers. - """ - v = llvm.InlineAsmOp( - T.i32, [addr_i64], - "global_load_dword $0, $1, off sc1", "=v,v", - has_side_effects=True, - ).result - rocdl.s_waitcnt(0) - return v - - -@traced_op -def store_i32_uncached_flush(addr_i64, val_i32): - """Store i32 with L2 flush + system-scope coherence for XGMI visibility. - - Emits ``buffer_wbl2 sc0 sc1`` followed by ``global_store_dword ... sc0 sc1``. - Use after cached data stores (``store_v4i32``) to ensure L2 dirty lines - reach HBM before the signal becomes visible to peer GPUs. - """ - llvm.InlineAsmOp(None, [], "buffer_wbl2 sc0 sc1", "", has_side_effects=True) - llvm.InlineAsmOp( - None, [addr_i64, val_i32], - "global_store_dword $0, $1, off sc0 sc1", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -@traced_op -def store_i32_uncached(addr_i64, val_i32): - """Store i32 with system-scope coherence (no L2 flush). - - Emits ``global_store_dword ... sc0 sc1``. - Use after nontemporal data stores (``store_v4i32_nt``) which already - bypass L2 — no ``buffer_wbl2`` is needed. - """ - llvm.InlineAsmOp( - None, [addr_i64, val_i32], - "global_store_dword $0, $1, off sc0 sc1", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -@traced_op -def store_i32(addr_i64, val_i32): - """Store i32 to global address (normal cached store). - - Emits ``global_store_dword ... off`` with no cache coherence flags. - Use for writes visible only to the local GPU (e.g. updating own signal). - """ - llvm.InlineAsmOp( - None, [addr_i64, val_i32], - "global_store_dword $0, $1, off", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -# --------------------------------------------------------------------------- -# v4i32 (16-byte) vector operations -# --------------------------------------------------------------------------- - -@traced_op -def load_v4i32(addr_i64): - """Load 16 bytes (``vector<4xi32>``) from global address. - - Emits ``flat_load_dwordx4``. - """ - v = llvm.InlineAsmOp( - T.i32x4, [addr_i64], - "flat_load_dwordx4 $0, $1", "=v,v", - has_side_effects=True, - ).result - rocdl.s_waitcnt(0) - return v - - -@traced_op -def store_v4i32(addr_i64, v4i32_val): - """Store 16 bytes (``vector<4xi32>``) to global address. - - Emits ``global_store_dwordx4 ... off``. - """ - llvm.InlineAsmOp( - None, [addr_i64, v4i32_val], - "global_store_dwordx4 $0, $1, off", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -@traced_op -def store_v4i32_nt(addr_i64, v4i32_val): - """Store 16 bytes with nontemporal hint, bypassing L1/L2 cache. - - Emits ``flat_store_dwordx4 ... nt``. - Suitable for large data writes across XGMI — works on any memory type - (regular ``hipMalloc``, IPC-mapped coarse-grained memory). - """ - llvm.InlineAsmOp( - None, [addr_i64, v4i32_val], - "flat_store_dwordx4 $0, $1 nt", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -# --------------------------------------------------------------------------- -# Pointer helpers -# --------------------------------------------------------------------------- - -@traced_op -def load_device_ptr(array_base_i64, index): - """Load an i64 pointer from a device-side pointer array. - - Computes ``base + index * 8``, casts to ``!llvm.ptr``, and loads i64. - - Args: - array_base_i64: Base address of the pointer array (i64). - index: Array index (i32 or i64). - """ - from . import arith as ea - - i64 = T.i64 - if hasattr(index, 'type') and isinstance(index.type, ir.IntegerType) and index.type.width == 32: - index = _arith.ExtUIOp(i64, index).result - elem_addr = array_base_i64 + index * ea.constant(8, type=i64) - ptr = llvm.IntToPtrOp(ir.Type.parse("!llvm.ptr"), elem_addr).result - return llvm.LoadOp(i64, ptr).result - - -@traced_op -def invalidate_l1(): - """Invalidate L1 scalar cache (``buffer_inv sc1``). - - Use inside a polling loop after a remote-visible load to discard stale - L1 cache lines so the next iteration sees fresh data from L2/HBM. - """ - llvm.InlineAsmOp(None, [], "buffer_inv sc1", "", has_side_effects=True) - - -__all__ = [ - # Uncached i32 (system-scope coherent) - "load_i32_uncached", - "store_i32_uncached_flush", - "store_i32_uncached", - "store_i32", - # v4i32 (16-byte vector) - "load_v4i32", - "store_v4i32", - "store_v4i32_nt", - # Cache control - "invalidate_l1", - # Pointer helpers - "load_device_ptr", -] From e5fd5d4134b88a61c347c1a8140bf34be569e7c3 Mon Sep 17 00:00:00 2001 From: "yashao@amd.com" Date: Fri, 10 Apr 2026 02:59:11 +0000 Subject: [PATCH 2/5] add ci testcases of allreduce --- .github/workflows/flydsl.yaml | 125 ++++++++++++++++++++++++- scripts/run_tests.sh | 49 ++++++++++ tests/arch_compat.py | 1 + tests/kernels/test_flydsl_allreduce.py | 99 ++++++++++++++++++++ tests/pytest.ini | 1 + 5 files changed, 274 insertions(+), 1 deletion(-) diff --git a/.github/workflows/flydsl.yaml b/.github/workflows/flydsl.yaml index 6b54a304..d53b9ae7 100644 --- a/.github/workflows/flydsl.yaml +++ b/.github/workflows/flydsl.yaml @@ -19,10 +19,18 @@ env: GITHUB_COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.event.head_commit.id }} jobs: + # --------------------------------------------------------------------------- + # Single-GPU tests: kernels, unit, examples, MLIR FileCheck, benchmarks. + # Runs on 1-GPU and Navi runners only. + # --------------------------------------------------------------------------- test: strategy: matrix: - runners: [ 'linux-flydsl-mi325-1', 'linux-flydsl-mi355-1', 'linux-flydsl-navi-2' ] + runners: [ + 'linux-flydsl-mi325-1', + 'linux-flydsl-mi355-1', + 'linux-flydsl-navi-2', + ] fail-fast: false runs-on: ${{ matrix.runners }} steps: @@ -169,3 +177,118 @@ jobs: run: | docker stop flydsl_test docker rm flydsl_test + + # --------------------------------------------------------------------------- + # Multi-GPU allreduce tests: ONLY for 8-GPU runners. + # Runs on BOTH linux-flydsl-mi325-8 AND linux-flydsl-mi355-8 independently. + # fail-fast: false ensures both runners always complete even if one fails. + # --------------------------------------------------------------------------- + multi-gpu: + needs: test + name: Multi-GPU AllReduce Tests (${{ matrix.runners }}) + strategy: + matrix: + runners: [ + 'linux-flydsl-mi325-8', + 'linux-flydsl-mi355-8', + ] + fail-fast: false + runs-on: ${{ matrix.runners }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + repository: ${{ env.GITHUB_REPO_NAME }} + ref: ${{ env.GITHUB_COMMIT_SHA }} + path: flydsl-test + + - name: Start CI container + run: | + echo "Clean up containers..." + docker ps -aq -f name=flydsl_test | xargs -r docker stop | xargs -r docker rm || true + + echo "Start CI container..." + if [ -f "/etc/podinfo/gha-render-devices" ]; then + DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices) + else + DEVICE_FLAG="--device /dev/dri" + fi + + docker run -dt --network=host --user root --device=/dev/kfd $DEVICE_FLAG \ + -v "${GITHUB_WORKSPACE:-$PWD}/flydsl-test:/flydsl-test" \ + --ipc=host --group-add video \ + --shm-size 16g \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + -w /flydsl-test \ + --name flydsl_test \ + ${{ env.DOCKER_IMAGE }} + env: + GITHUB_WORKSPACE: ${{ github.workspace }} + + - name: Install dependencies + run: | + docker exec flydsl_test bash -c "apt-get update && apt-get install -y cmake build-essential patchelf" + docker exec flydsl_test bash -c "python3 -m pip install -U pip setuptools wheel" + docker exec flydsl_test bash -c "python3 -m pip install ninja>=1.11.1" + docker exec flydsl_test bash -c "python3 -m pip install -U 'hypothesis>=6.82.0'" + docker exec flydsl_test bash -c "git config --global --add safe.directory /flydsl-test && cd /flydsl-test && git log" + + - name: Restore cached MLIR install tarball (if available) + id: mlir-cache + uses: actions/cache@v4 + with: + path: mlir_install.tgz + key: mlir-install-${{ matrix.runners }}-${{ hashFiles('flydsl-test/thirdparty/llvm-hash.txt', 'flydsl-test/scripts/build_llvm.sh', 'flydsl-test/CMakeLists.txt', 'flydsl-test/.github/workflows/flydsl.yaml') }} + + - name: Use cached MLIR install tarball (skip LLVM build) + if: steps.mlir-cache.outputs.cache-hit == 'true' + run: | + ls -lh mlir_install.tgz + docker cp mlir_install.tgz flydsl_test:/tmp/mlir_install.tgz + docker exec flydsl_test bash -c "rm -rf /llvm-project/mlir_install && mkdir -p /llvm-project && tar -xzf /tmp/mlir_install.tgz -C /llvm-project" + docker exec flydsl_test bash -c "ls -la /llvm-project/mlir_install/lib/cmake/mlir" + + - name: Build LLVM + if: steps.mlir-cache.outputs.cache-hit != 'true' + run: | + set -ex + docker exec flydsl_test bash -c "cd /flydsl-test && bash scripts/build_llvm.sh" + docker exec flydsl_test bash -c "ls -la /llvm-project/mlir_install/lib/cmake/mlir" + docker cp flydsl_test:/llvm-project/mlir_install.tgz ./mlir_install.tgz || true + + - name: Build FlyDSL (uses MLIR install prefix) + run: | + docker exec flydsl_test bash -c "export MLIR_PATH=/llvm-project/mlir_install && cd /flydsl-test && python3 -m pip install -e . --use-pep517" + + - name: Prepare aiter + run: | + docker exec flydsl_test bash -c "rm -rf /tmp/aiter && git clone --depth 1 --recursive --shallow-submodules https://github.com/ROCm/aiter.git /tmp/aiter" + docker exec flydsl_test bash -c "python3 -c \"from pathlib import Path; src = Path('/tmp/aiter/requirements.txt'); dst = Path('/tmp/aiter/requirements-flydsl-ci.txt'); lines = [line for line in src.read_text().splitlines() if line.strip() and not line.strip().startswith('flydsl==')]; dst.write_text('\\n'.join(lines) + '\\n')\" && python3 -m pip install -r /tmp/aiter/requirements-flydsl-ci.txt" + + - name: Run multi-GPU allreduce tests + run: | + docker exec flydsl_test bash -c " + export PYTHONPATH=/tmp/aiter:\${PYTHONPATH:-} + cd /flydsl-test + python3 -m pytest tests/kernels/test_flydsl_allreduce.py \ + -m multi_gpu -v --no-header --tb=short + " + + - name: Show test logs + if: failure() + run: | + docker exec flydsl_test bash -c 'cd /tmp && tar czf /tmp/logs.tgz *.log 2>/dev/null || echo "no logs"' + docker cp flydsl_test:/tmp/logs.tgz . || true + if [ -f logs.tgz ]; then + tar -xzf logs.tgz || true + cat *.log || true + else + echo "logs.tgz not found; skipping log extraction" + fi + + - name: Clean up + if: always() + run: | + docker stop flydsl_test + docker rm flydsl_test diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index cdaf4bb1..ad3ee6a7 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -116,6 +116,55 @@ done fi +# --------------------------------------------------------------------------- +# 4. Multi-GPU AllReduce tests (requires >= 8 GPUs; skipped otherwise) +# --------------------------------------------------------------------------- +echo "" +echo "========================================================================" +echo "Multi-GPU AllReduce Tests (world_size=8)" +echo "========================================================================" + +# Detect physical GPU count in a subprocess without HIP_VISIBLE_DEVICES +# so that the auto-selected single-GPU index set above does not hide GPUs. +_phys_gpu_count=$( + env -u HIP_VISIBLE_DEVICES python3 -c \ + "import torch; print(torch.cuda.device_count())" 2>/dev/null || echo "0" +) + +if [[ "${_phys_gpu_count}" -ge 8 ]]; then + echo "[run_tests] Detected ${_phys_gpu_count} GPUs — running 8-GPU allreduce tests" + + _run_allreduce() { + local dtype_str="$1" shape="$2" mode="${3:-eager}" iters="${4:-10}" warmup="${5:-2}" + echo " RUN allreduce shape=${shape} dtype=${dtype_str} mode=${mode}" + if env -u HIP_VISIBLE_DEVICES python3 \ + "${REPO_ROOT}/tests/kernels/test_flydsl_allreduce.py" \ + --world_size 8 --iters "${iters}" --warmup "${warmup}" \ + --shapes "${shape},${dtype_str}" --mode "${mode}" \ + --allreduce_impl flydsl; then + echo " PASS allreduce shape=${shape} dtype=${dtype_str} mode=${mode}" + else + echo " FAIL allreduce shape=${shape} dtype=${dtype_str} mode=${mode}" + exit 1 + fi + } + + # Basic accuracy tests (always run on 8-GPU machines) + _run_allreduce fp16 128,8192 eager 10 2 + _run_allreduce bf16 256,8192 eager 10 2 + _run_allreduce fp16 512,4096 eager 10 2 + _run_allreduce fp32 64,4096 eager 10 2 + + # Extended tests (only in full CI) + if [ "${RUN_TESTS_FULL:-0}" = "1" ]; then + _run_allreduce fp16 1024,8192 eager 20 3 + _run_allreduce fp16 128,8192 cudagraph 20 3 + _run_allreduce bf16 256,8192 cudagraph 20 3 + fi +else + echo " SKIP 8-GPU allreduce tests (need >= 8 GPUs, found ${_phys_gpu_count})" +fi + echo "" echo "========================================================================" echo "All tests passed." diff --git a/tests/arch_compat.py b/tests/arch_compat.py index c79b7be6..a436896e 100644 --- a/tests/arch_compat.py +++ b/tests/arch_compat.py @@ -16,6 +16,7 @@ "test_moe_reduce.py", "test_pa.py", "test_quant.py", + "test_flydsl_allreduce.py", # custom_all_reduce requires CDNA (gfx9xx) }) # Example scripts verified to work on RDNA (non-CDNA) GPUs. diff --git a/tests/kernels/test_flydsl_allreduce.py b/tests/kernels/test_flydsl_allreduce.py index 45d2d3fd..c8b12cba 100644 --- a/tests/kernels/test_flydsl_allreduce.py +++ b/tests/kernels/test_flydsl_allreduce.py @@ -684,6 +684,105 @@ def run_all_tests( return pd.DataFrame() +# ============================================================================ +# Pytest test functions for 8-GPU allreduce CI testing +# ============================================================================ + +def _count_physical_gpus() -> int: + """Return number of physically available GPUs via a fresh subprocess. + + Using a subprocess bypasses both HIP_VISIBLE_DEVICES restrictions and + PyTorch's internal device-count cache in the parent pytest process. + """ + import subprocess as _sp + env = {k: v for k, v in os.environ.items() if k != "HIP_VISIBLE_DEVICES"} + try: + r = _sp.run( + [sys.executable, "-c", "import torch; print(torch.cuda.device_count())"], + capture_output=True, text=True, timeout=30, env=env, + ) + return int(r.stdout.strip()) if r.returncode == 0 else 0 + except Exception: + return 0 + + +# All 8-GPU test configurations (always run, no large_shape distinction). +_8GPU_PARAMS = [ + # (shape, dtype_str, mode) + # --- small shapes (edge-case coverage, aligned with aiter) --- + ((2, 7168), "bf16", "cudagraph"), # 14 K elements · BF16 · cudagraph (aiter shape) + ((16, 4096), "fp16", "eager"), # 64 K elements · FP16 · eager + # --- medium shapes --- + ((128, 8192), "bf16", "cudagraph"), # 1 M elements · BF16 · cudagraph + ((96, 4096), "fp16", "eager"), # 384 K elements · FP16 · eager + # --- eager + cudagraph cross-dtype --- + ((512, 8192), "bf16", "eager"), # 4 M elements · BF16 · eager + ((1024, 8192), "fp16", "cudagraph"), # 8 M elements · FP16 · cudagraph + # --- fp32 coverage --- + ((64, 4096), "fp32", "eager"), # 256 K elements · FP32 · eager +] + +# 4-GPU test configurations (fp32 + smaller world_size coverage). +_4GPU_PARAMS = [ + # (shape, dtype_str, mode) + ((64, 4096), "fp32", "eager"), # 256 K elements · FP32 · eager + ((128, 8192), "fp16", "eager"), # 1 M elements · FP16 · eager + ((64, 8192), "bf16", "cudagraph"), # 512 K elements · BF16 · cudagraph +] + + +def _run_subprocess_test(*, world_size, shape, dtype_str, mode): + """Launch the allreduce harness in a subprocess and assert success.""" + import subprocess as _sp + + env = {k: v for k, v in os.environ.items() if k != "HIP_VISIBLE_DEVICES"} + shape_str = ",".join(str(d) for d in shape) + f",{dtype_str}" + + cmd = [ + sys.executable, __file__, + "--world_size", str(world_size), + "--iters", "10", + "--warmup", "2", + "--shapes", shape_str, + "--mode", mode, + "--allreduce_impl", "flydsl", + ] + result = _sp.run(cmd, env=env, timeout=600, capture_output=True, text=True) + assert result.returncode == 0, ( + f"{world_size}-GPU allreduce FAILED: shape={shape}, dtype={dtype_str}, " + f"mode={mode} (exit code {result.returncode})\n" + f"stdout (last 2000 chars):\n{result.stdout[-2000:]}\n" + f"stderr (last 2000 chars):\n{result.stderr[-2000:]}" + ) + + +@pytest.mark.multi_gpu +@pytest.mark.parametrize("shape,dtype_str,mode", _8GPU_PARAMS) +def test_allreduce_8gpu_accuracy(shape, dtype_str, mode): + """8-GPU allreduce accuracy test. + + Runs the allreduce harness in a child subprocess so that + HIP_VISIBLE_DEVICES (auto-set by run_tests.sh to one GPU index) + does not limit device visibility inside the distributed workers. + + Skipped automatically on machines with fewer than 8 physical GPUs. + """ + phys_ng = _count_physical_gpus() + if phys_ng < 8: + pytest.skip(f"Requires >= 8 physical GPUs, found {phys_ng}.") + _run_subprocess_test(world_size=8, shape=shape, dtype_str=dtype_str, mode=mode) + + +@pytest.mark.multi_gpu +@pytest.mark.parametrize("shape,dtype_str,mode", _4GPU_PARAMS) +def test_allreduce_4gpu_accuracy(shape, dtype_str, mode): + """4-GPU allreduce accuracy test (covers fp32 and world_size=4).""" + phys_ng = _count_physical_gpus() + if phys_ng < 4: + pytest.skip(f"Requires >= 4 physical GPUs, found {phys_ng}.") + _run_subprocess_test(world_size=4, shape=shape, dtype_str=dtype_str, mode=mode) + + if __name__ == "__main__": freeze_support() # Align with AIter harness: use spawn to avoid fork+CUDA issues. diff --git a/tests/pytest.ini b/tests/pytest.ini index 765a874a..11b92e48 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -11,3 +11,4 @@ markers = l1b_target_dialect: requires a target lowering stack; pair with rocm_lower (or future backend markers) l2_device: requires GPU and full runtime stack for execution/correctness rocm_lower: L1b/L2 tests that assume the ROCDL lowering path + multi_gpu: marks tests that require >= 8 physical GPUs (skipped automatically on single-GPU machines) From 11bc8aaea12ffcf269edf3330d3d5cda6d9c930d Mon Sep 17 00:00:00 2001 From: "yashao@amd.com" Date: Fri, 10 Apr 2026 06:43:02 +0000 Subject: [PATCH 3/5] fix wm bug --- .github/workflows/flydsl.yaml | 6 - kernels/custom_all_reduce_kernel.py | 223 ++++++++++++++++------------ python/flydsl/expr/buffer_ops.py | 19 ++- scripts/run_tests.sh | 49 ------ tests/pytest.ini | 2 +- 5 files changed, 143 insertions(+), 156 deletions(-) diff --git a/.github/workflows/flydsl.yaml b/.github/workflows/flydsl.yaml index d53b9ae7..4ea8c02d 100644 --- a/.github/workflows/flydsl.yaml +++ b/.github/workflows/flydsl.yaml @@ -261,15 +261,9 @@ jobs: run: | docker exec flydsl_test bash -c "export MLIR_PATH=/llvm-project/mlir_install && cd /flydsl-test && python3 -m pip install -e . --use-pep517" - - name: Prepare aiter - run: | - docker exec flydsl_test bash -c "rm -rf /tmp/aiter && git clone --depth 1 --recursive --shallow-submodules https://github.com/ROCm/aiter.git /tmp/aiter" - docker exec flydsl_test bash -c "python3 -c \"from pathlib import Path; src = Path('/tmp/aiter/requirements.txt'); dst = Path('/tmp/aiter/requirements-flydsl-ci.txt'); lines = [line for line in src.read_text().splitlines() if line.strip() and not line.strip().startswith('flydsl==')]; dst.write_text('\\n'.join(lines) + '\\n')\" && python3 -m pip install -r /tmp/aiter/requirements-flydsl-ci.txt" - - name: Run multi-GPU allreduce tests run: | docker exec flydsl_test bash -c " - export PYTHONPATH=/tmp/aiter:\${PYTHONPATH:-} cd /flydsl-test python3 -m pytest tests/kernels/test_flydsl_allreduce.py \ -m multi_gpu -v --no-header --tb=short diff --git a/kernels/custom_all_reduce_kernel.py b/kernels/custom_all_reduce_kernel.py index f0581163..1213a1db 100644 --- a/kernels/custom_all_reduce_kernel.py +++ b/kernels/custom_all_reduce_kernel.py @@ -34,35 +34,52 @@ _CM_NT = 4 # nontemporal (bulk data writes, bypasses L2 prefetch) +# ---- buffer resource descriptor helper ------------------------------------ + +def _make_rsrc(addr_i64): + """Create buffer resource descriptor from a wave-uniform i64 base address.""" + return buffer_ops.create_buffer_resource_from_addr(addr_i64) + + # ---- bulk data: 16-byte (128-bit) load / store ---------------------------- +# These accept a pre-built rsrc descriptor and a per-lane byte offset (i32). -def _load_v4i32(addr_i64): - """Load vector<4xi32> (16 bytes) from a raw i64 device address.""" - rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) - return buffer_ops.buffer_load(rsrc, ea.constant(0, type=T.i32), - vec_width=4, dtype=T.i32) +def _load_v4i32(rsrc, byte_off_i32): + """Buffer-load vector<4xi32> (16 bytes) with pre-built descriptor.""" + return buffer_ops.buffer_load(rsrc, byte_off_i32, + vec_width=4, dtype=T.i32, + offset_is_bytes=True) -def _store_v4i32(addr_i64, data): - """Store vector<4xi32> (16 bytes) to a raw i64 device address (cached).""" - rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) - buffer_ops.buffer_store(data, rsrc, ea.constant(0, type=T.i32), - cache_modifier=_CM_CACHED) +def _store_v4i32(rsrc, byte_off_i32, data): + """Buffer-store vector<4xi32> (16 bytes), cached.""" + buffer_ops.buffer_store(data, rsrc, byte_off_i32, + cache_modifier=_CM_CACHED, + offset_is_bytes=True) -def _store_v4i32_nt(addr_i64, v4i32_val): - """Store vector<4xi32> nontemporal (nt) — bypasses L2 prefetcher. +def _store_v4i32_nt(rsrc, byte_off_i32, v4i32_val): + """Buffer-store vector<4xi32> nontemporal — bypasses L2 prefetcher.""" + buffer_ops.buffer_store(v4i32_val, rsrc, byte_off_i32, + cache_modifier=_CM_NT, + offset_is_bytes=True) + rocdl.s_waitcnt(0) + + +# ---- signal buffer: i32 load / store -------------------------------------- - Use for large output writes after all-reduce so dirty lines do not - pollute L2 and the end-sync signal remains the cache-line of interest. +def _store_i32(addr_i64, val_i32): + """Store i32 with default caching — for local flag counter updates. + + Signal buffers live in hipDeviceMallocUncached memory, so cached stores + are safe (hardware bypasses caches for UC memory automatically). + Avoids the unnecessary sc0+sc1 bypass and s_waitcnt overhead of the + uncached variant. """ rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) - buffer_ops.buffer_store(v4i32_val, rsrc, ea.constant(0, type=T.i32), - cache_modifier=_CM_NT) - rocdl.s_waitcnt(0) - + buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), + cache_modifier=_CM_CACHED) -# ---- signal buffer: uncached i32 load / store ---------------------------- def _load_i32_uncached(addr_i64): """Load i32 bypassing L2 (sc1) — for polling cross-GPU signal buffers.""" @@ -197,7 +214,7 @@ def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngp is_t0 = lane_i32 == ea.constant(0, type=i32) if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) with ir.InsertionPoint(if_t0.then_block): - _store_i32_uncached(flag_addr, flag) + _store_i32(flag_addr, flag) scf.YieldOp([]) return flag_addr @@ -257,7 +274,7 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, is_t0 = lane_i32 == ea.constant(0, type=i32) if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) with ir.InsertionPoint(if_t0.then_block): - _store_i32_uncached(flag_addr, flag) + _store_i32(flag_addr, flag) scf.YieldOp([]) @@ -371,8 +388,9 @@ def allreduce_1stage_arr( in_ptrs_i64 = in_ptrs.ir_value() out_ptr_i64 = out_ptr.ir_value() - sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + in_rsrcs = [_make_rsrc(in_ptrs_arr[i]) for i in range(world_size)] smem_sym = f"allreduce_1s_smem_ws{world_size}_t{threads}" n_smem = 2 * threads @@ -396,6 +414,8 @@ def allreduce_1stage_arr( tid_pack = bid_i32 * tnum_gpu_i32 + lane_id stride_pack = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 + out_rsrc = _make_rsrc(out_ptr_i64) + loop = scf.WhileOp([i32, i32], [tid_pack, ea.constant(0, type=i32)]) bfor = ir.Block.create_at_start(loop.before, [i32, i32]) afor = ir.Block.create_at_start(loop.after, [i32, i32]) @@ -408,9 +428,9 @@ def allreduce_1stage_arr( parity = afor.arguments[1] # Each warp loads data from its rank into shared memory - in_base = ea.select_by_index(warp_id, in_ptrs_arr) - off16 = p.extui(i64) * ea.constant(16, type=i64) - raw = _load_v4i32(in_base + off16) + in_rsrc = ea.select_by_index(warp_id, in_rsrcs) + off_i32 = p * ea.constant(16, type=i32) + raw = _load_v4i32(in_rsrc, off_i32) sm_base = parity * ea.constant(threads, type=i32) sm_idx = ea.index_cast(idx, sm_base + lane_i32) smem_ptr.store(raw, [sm_idx]) @@ -436,8 +456,8 @@ def allreduce_1stage_arr( out_bits = acc.bitcast(v4i32) else: out_bits = ev.bitcast(v4i32, acc.truncf(v8half)) - dst_off = p.extui(i64) * ea.constant(16, type=i64) - _store_v4i32(out_ptr_i64 + dst_off, out_bits) + dst_off_i32 = p * ea.constant(16, type=i32) + _store_v4i32(out_rsrc, dst_off_i32, out_bits) scf.YieldOp([]) # No barrier 2 needed: parity double-buffer ensures next iteration @@ -483,9 +503,11 @@ def allreduce_2stage_arr( tmp_ptrs_i64 = tmp_ptrs.ir_value() out_ptr_i64 = out_ptr.ir_value() - sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - tmp_ptrs_arr = [_load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + tmp_ptrs_arr = [_load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + in_rsrcs = [_make_rsrc(in_ptrs_arr[i]) for i in range(world_size)] + tmp_rsrcs = [_make_rsrc(tmp_ptrs_arr[i]) for i in range(world_size)] # Compute pack range for this rank's reduce-scatter partition start_p = rank_i32 * ea.constant(part_p, type=i32) @@ -513,6 +535,7 @@ def allreduce_2stage_arr( smem_ptr = SmemPtr(allocator.get_base(), smem_off, v4i32, shape=(smem_slots,)) smem_ptr.get() tmp_out_i64 = tmp_ptrs_arr[0] + tmp_out_rsrc = tmp_rsrcs[0] # ---- Stage 1: reduce-scatter ---- # Two implementations selected at compile time via _use_single_buf_2stage: @@ -521,8 +544,9 @@ def allreduce_2stage_arr( def _build_reduce_body(cur, smem_base_expr=None): """Emit reduce body: load → smem → barrier1 → warp0 reduce → [barrier2].""" - in_base = ea.select_by_index(warp_id, in_ptrs_arr) - raw = _load_v4i32(in_base + cur.extui(i64) * ea.constant(16, type=i64)) + in_rsrc = ea.select_by_index(warp_id, in_rsrcs) + off_i32 = cur * ea.constant(16, type=i32) + raw = _load_v4i32(in_rsrc, off_i32) if smem_base_expr is None: sm_idx = ea.index_cast(idx, lane_i32) else: @@ -552,8 +576,8 @@ def _build_reduce_body(cur, smem_base_expr=None): else: out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) rel_p = cur - start_p - _store_v4i32(tmp_out_i64 + rel_p.extui(i64) * ea.constant(16, type=i64), - out_raw) + rel_off_i32 = rel_p * ea.constant(16, type=i32) + _store_v4i32(tmp_out_rsrc, rel_off_i32, out_raw) scf.YieldOp([]) idx_p = start_p + tid_pack @@ -595,6 +619,8 @@ def _build_reduce_body(cur, smem_base_expr=None): self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) # ---- Stage 2: all-gather ---- + out_rsrc = _make_rsrc(out_ptr_i64) + if vec_ok: tid_pack2 = bid_i32 * tnum_gpu_i32 + lane_id stride_pack2 = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 @@ -613,11 +639,12 @@ def _build_reduce_body(cur, smem_base_expr=None): dst_rank = sum_rw & ea.constant(world_size - 1, type=i32) else: dst_rank = _u(sum_rw) % ea.constant(world_size, type=i32) - tmp_base = ea.select_by_index(warp_id, tmp_ptrs_arr) - raw = _load_v4i32(tmp_base + cur.extui(i64) * ea.constant(16, type=i64)) + tmp_rsrc = ea.select_by_index(warp_id, tmp_rsrcs) + src_off_i32 = cur * ea.constant(16, type=i32) + raw = _load_v4i32(tmp_rsrc, src_off_i32) dst_pack = dst_rank * ea.constant(part_p, type=i32) + cur - _store_v4i32(out_ptr_i64 + dst_pack.extui(i64) * ea.constant(16, type=i64), - raw) + dst_off_i32 = dst_pack * ea.constant(16, type=i32) + _store_v4i32(out_rsrc, dst_off_i32, raw) scf.YieldOp([cur + stride_pack2]) else: # Non-vectorized fallback (world_size=6 or num_packs % world_size != 0) @@ -640,11 +667,11 @@ def _build_reduce_body(cur, smem_base_expr=None): ok = _u(cur) < ea.constant(part_p, type=i32) ifp = scf.IfOp(ok, results_=[], has_else=False) with ir.InsertionPoint(ifp.then_block): - src_off = cur.extui(i64) * ea.constant(16, type=i64) - raw = _load_v4i32(tmp_ptrs_arr[p] + src_off) + src_off_i32 = cur * ea.constant(16, type=i32) + raw = _load_v4i32(tmp_rsrcs[p], src_off_i32) dst_pack_idx = ea.constant(p * part_p, type=i32) + cur - dst_off = dst_pack_idx.extui(i64) * ea.constant(16, type=i64) - _store_v4i32(out_ptr_i64 + dst_off, raw) + dst_off_i32 = dst_pack_idx * ea.constant(16, type=i32) + _store_v4i32(out_rsrc, dst_off_i32, raw) scf.YieldOp([]) scf.YieldOp([cur + stride_i32]) @@ -684,8 +711,11 @@ def allreduce_2stage_write_mode( out_ptrs_i64 = out_ptrs.ir_value() tmp_ptrs_i64 = tmp_ptrs.ir_value() - sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - out_ptrs_arr = [_load_device_ptr(out_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + out_ptrs_arr = [_load_device_ptr(out_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + tmp_ptrs_arr = [_load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + tmp_rsrcs = [_make_rsrc(tmp_ptrs_arr[i]) for i in range(world_size)] + out_rsrcs = [_make_rsrc(out_ptrs_arr[i]) for i in range(world_size)] tnum_gpu_i32 = ea.constant(tnum_gpu, type=i32) log2_tnum = int(math.log2(tnum_gpu)) @@ -704,9 +734,11 @@ def allreduce_2stage_write_mode( allocator_wm.finalize() smem_ptr = SmemPtr(allocator_wm.get_base(), smem_wm_off, v4i32, shape=(n_smem_wm,)) smem_ptr.get() - tmp_out_i64 = _load_device_ptr(tmp_ptrs_i64, rank_i32) + tmp_out_i64 = ea.select_by_index(rank_i32, tmp_ptrs_arr) # ---- Stage 1: scatter local input to REMOTE tmp buffers ---- + inp_rsrc = _make_rsrc(inp_ptr_i64) + start_w = warp_id * ea.constant(part_p, type=i32) is_last_w = warp_id == ea.constant(world_size - 1, type=i32) end_w_if = scf.IfOp(is_last_w, results_=[i32], has_else=True) @@ -727,20 +759,22 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(as1): cur = as1.arguments[0] stride_s1 = as1.arguments[1] - raw = _load_v4i32(inp_ptr_i64 + cur.extui(i64) * ea.constant(16, type=i64)) + cur_off_i32 = cur * ea.constant(16, type=i32) + raw = _load_v4i32(inp_rsrc, cur_off_i32) rel_idx = cur - start_w dst_off = rank_i32 * ea.constant(part_p, type=i32) + rel_idx - dst_tmp = _load_device_ptr(tmp_ptrs_i64, warp_id) - tmp_addr = dst_tmp + dst_off.extui(i64) * ea.constant(16, type=i64) + dst_tmp = ea.select_by_index(warp_id, tmp_ptrs_arr) is_tmp_null = dst_tmp == ea.constant(0, type=i64) - tmp_low4 = tmp_addr & ea.constant(0xF, type=i64) - is_tmp_misaligned = tmp_low4 != ea.constant(0, type=i64) + dst_tmp_low4 = dst_tmp & ea.constant(0xF, type=i64) + is_tmp_misaligned = dst_tmp_low4 != ea.constant(0, type=i64) bad_tmp_addr = is_tmp_null | is_tmp_misaligned if_tmp_ok = scf.IfOp(bad_tmp_addr, results_=[], has_else=True) with ir.InsertionPoint(if_tmp_ok.then_block): scf.YieldOp([]) with ir.InsertionPoint(if_tmp_ok.else_block): - _store_v4i32(tmp_addr, raw) + dst_tmp_rsrc = ea.select_by_index(warp_id, tmp_rsrcs) + dst_off_i32 = dst_off * ea.constant(16, type=i32) + _store_v4i32(dst_tmp_rsrc, dst_off_i32, raw) scf.YieldOp([]) scf.YieldOp([cur + stride_s1, stride_s1]) @@ -749,10 +783,8 @@ def allreduce_2stage_write_mode( self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) # ---- Stage 2: reduce local tmp and write to REMOTE outputs ---- + tmp_out_rsrc = ea.select_by_index(rank_i32, tmp_rsrcs) part_p_i32 = ea.constant(part_p, type=i32) - # The last rank's output partition has largest_part_p elements - # (= part_p + num_packs % world_size). Use a runtime branch so that - # when num_packs is evenly divisible the overhead is minimal (same value). is_last_rank_s2 = rank_i32 == ea.constant(world_size - 1, type=i32) end_s2_if = scf.IfOp(is_last_rank_s2, results_=[i32], has_else=True) with ir.InsertionPoint(end_s2_if.then_block): @@ -771,66 +803,73 @@ def allreduce_2stage_write_mode( cur = as2.arguments[0] stride_s2 = as2.arguments[1] + # All warps load their chunk from tmp into smem src_off = warp_id * ea.constant(part_p, type=i32) + cur - load_addr = tmp_out_i64 + src_off.extui(i64) * ea.constant(16, type=i64) + src_off_i32 = src_off * ea.constant(16, type=i32) is_tmpout_null = tmp_out_i64 == ea.constant(0, type=i64) - load_low4 = load_addr & ea.constant(0xF, type=i64) - is_load_misaligned = load_low4 != ea.constant(0, type=i64) + tmpout_low4 = tmp_out_i64 & ea.constant(0xF, type=i64) + is_load_misaligned = tmpout_low4 != ea.constant(0, type=i64) bad_load_addr = is_tmpout_null | is_load_misaligned raw_if = scf.IfOp(bad_load_addr, results_=[v4i32], has_else=True) with ir.InsertionPoint(raw_if.then_block): scf.YieldOp([ea.constant_vector(0, v4i32)]) with ir.InsertionPoint(raw_if.else_block): - scf.YieldOp([_load_v4i32(load_addr)]) + scf.YieldOp([_load_v4i32(tmp_out_rsrc, src_off_i32)]) raw = raw_if.results[0] sm_idx = ea.index_cast(idx, lane_i32) smem_ptr.store(raw, [sm_idx]) gpu.barrier() - warp_id_local = _u(lane_i32) >> ea.constant(log2_tnum, type=i32) - lane_id_local = lane_i32 - warp_id_local * ea.constant(tnum_gpu, type=i32) - - raw_vals = [] - for wi in range_constexpr(world_size): - sm_i_idx = ea.index_cast(idx, ea.constant(wi * tnum_gpu, type=i32) + lane_id_local) - raw_vals.append(smem_ptr.load([sm_i_idx])) - - acc = None - for wi in range_constexpr(world_size): - raw_i = raw_vals[wi] + # Warp 0 reduces across all warps, writes result to res area + # (smem[threads .. threads+tnum_gpu-1]). Two-barrier pattern + # matching aiter: barrier1 guards tmp_smem, barrier2 guards + # res_smem; between iterations tmp and res are disjoint so no + # WAR hazard exists. + is_w0 = warp_id == ea.constant(0, type=i32) + ifw0 = scf.IfOp(is_w0, results_=[], has_else=False) + with ir.InsertionPoint(ifw0.then_block): + acc = None + for wi in range_constexpr(world_size): + sm_i_idx = ea.index_cast( + idx, ea.constant(wi * tnum_gpu, type=i32) + lane_id) + raw_i = smem_ptr.load([sm_i_idx]) + if is_f32: + vf = raw_i.bitcast(v4f32) + acc = vf if acc is None else acc + vf + else: + v16 = ev.bitcast(v8half, raw_i) + v32 = v16.extf(v8f32) + acc = v32 if acc is None else acc + v32 if is_f32: - vf = raw_i.bitcast(v4f32) - acc = vf if acc is None else acc + vf + out_raw = acc.bitcast(v4i32) else: - v16 = ev.bitcast(v8half, raw_i) - v32 = v16.extf(v8f32) - acc = v32 if acc is None else acc + v32 - if is_f32: - out_raw = acc.bitcast(v4i32) - else: - out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) + out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) + res_idx = ea.index_cast(idx, ea.constant(threads, type=i32) + lane_id) + smem_ptr.store(out_raw, [res_idx]) + scf.YieldOp([]) + + gpu.barrier() + + # All warps read the same reduced result from res area and + # nontemporal-write to their respective remote output buffers. + res_read_idx = ea.index_cast(idx, ea.constant(threads, type=i32) + lane_id) + reduced_val = smem_ptr.load([res_read_idx]) dst_out_off = rank_i32 * ea.constant(part_p, type=i32) + cur - dst_byte_off = dst_out_off.extui(i64) * ea.constant(16, type=i64) - - # Each warp writes its reduced partition directly to the target - # output via flat_store_dwordx4 nt. The nt hint bypasses L1/L2 - # and works for all memory types (including IPC-mapped addresses). - dst_ptr = out_ptrs_arr[0] - for w in range_constexpr(1, world_size): - is_warp_w = warp_id_local == ea.constant(w, type=i32) - dst_ptr = ea.select(is_warp_w, out_ptrs_arr[w], dst_ptr) - out_addr = dst_ptr + dst_byte_off + dst_off_i32 = dst_out_off * ea.constant(16, type=i32) + + dst_ptr = ea.select_by_index(warp_id, out_ptrs_arr) + dst_out_rsrc = ea.select_by_index(warp_id, out_rsrcs) is_out_null = dst_ptr == ea.constant(0, type=i64) - out_low4 = out_addr & ea.constant(0xF, type=i64) - is_out_misaligned = out_low4 != ea.constant(0, type=i64) + dst_ptr_low4 = dst_ptr & ea.constant(0xF, type=i64) + is_out_misaligned = dst_ptr_low4 != ea.constant(0, type=i64) bad_out_addr = is_out_null | is_out_misaligned if_out_ok = scf.IfOp(bad_out_addr, results_=[], has_else=True) with ir.InsertionPoint(if_out_ok.then_block): scf.YieldOp([]) with ir.InsertionPoint(if_out_ok.else_block): - _store_v4i32_nt(out_addr, out_raw) + _store_v4i32_nt(dst_out_rsrc, dst_off_i32, reduced_val) scf.YieldOp([]) scf.YieldOp([cur + stride_s2, stride_s2]) diff --git a/python/flydsl/expr/buffer_ops.py b/python/flydsl/expr/buffer_ops.py index 4355c884..0f76d02f 100644 --- a/python/flydsl/expr/buffer_ops.py +++ b/python/flydsl/expr/buffer_ops.py @@ -388,7 +388,8 @@ def buffer_load(rsrc: ir.Value, dtype = None, mask: Optional[ir.Value] = None, cache_modifier: int = 0, - soffset_bytes: Optional[Union[int, ir.Value]] = None) -> ir.Value: + soffset_bytes: Optional[Union[int, ir.Value]] = None, + offset_is_bytes: bool = False) -> ir.Value: """AMD buffer load operation. Load data from global memory using buffer descriptor and offset. @@ -396,7 +397,7 @@ def buffer_load(rsrc: ir.Value, Args: rsrc: Buffer resource descriptor (!llvm.ptr<8>) - offset: Offset in elements (i32 type) + offset: Offset in elements (i32 type), or in bytes if offset_is_bytes=True vec_width: Vector width (1, 2, or 4) dtype: Element data type (None for f32, or ir.F32Type, etc.) mask: Optional mask for predicated load (i1 type) @@ -404,6 +405,7 @@ def buffer_load(rsrc: ir.Value, soffset_bytes: Optional scalar offset (in BYTES) added by the buffer instruction (soffset). Use this to fold small constant deltas into the instruction instead of emitting extra VGPR address arithmetic. + offset_is_bytes: If True, treat offset as already in bytes (skip element-to-byte scaling). Returns: Loaded data (scalar or vector depending on vec_width) @@ -432,12 +434,13 @@ def buffer_load(rsrc: ir.Value, op = std_arith.IndexCastOp(T.i32(), offset) offset = _unwrap_value(op.result) - # IMPORTANT: Buffer load offset is in BYTES, not elements! - # For vec4xf32, each element is 4 bytes, so multiply offset by 4 - element_bytes = dtype.width // 8 - bytes_const = _create_i32_constant(element_bytes) - op = std_arith.MulIOp(offset, bytes_const) - offset = _unwrap_value(op.result) + # RawPtrBufferLoadOp offset is in BYTES. By default we accept element + # offsets and scale to bytes; set offset_is_bytes=True to skip scaling. + if not offset_is_bytes: + element_bytes = dtype.width // 8 + bytes_const = _create_i32_constant(element_bytes) + op = std_arith.MulIOp(offset, bytes_const) + offset = _unwrap_value(op.result) # Apply mask by setting invalid offsets to max if mask is not None: diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index ad3ee6a7..cdaf4bb1 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -116,55 +116,6 @@ done fi -# --------------------------------------------------------------------------- -# 4. Multi-GPU AllReduce tests (requires >= 8 GPUs; skipped otherwise) -# --------------------------------------------------------------------------- -echo "" -echo "========================================================================" -echo "Multi-GPU AllReduce Tests (world_size=8)" -echo "========================================================================" - -# Detect physical GPU count in a subprocess without HIP_VISIBLE_DEVICES -# so that the auto-selected single-GPU index set above does not hide GPUs. -_phys_gpu_count=$( - env -u HIP_VISIBLE_DEVICES python3 -c \ - "import torch; print(torch.cuda.device_count())" 2>/dev/null || echo "0" -) - -if [[ "${_phys_gpu_count}" -ge 8 ]]; then - echo "[run_tests] Detected ${_phys_gpu_count} GPUs — running 8-GPU allreduce tests" - - _run_allreduce() { - local dtype_str="$1" shape="$2" mode="${3:-eager}" iters="${4:-10}" warmup="${5:-2}" - echo " RUN allreduce shape=${shape} dtype=${dtype_str} mode=${mode}" - if env -u HIP_VISIBLE_DEVICES python3 \ - "${REPO_ROOT}/tests/kernels/test_flydsl_allreduce.py" \ - --world_size 8 --iters "${iters}" --warmup "${warmup}" \ - --shapes "${shape},${dtype_str}" --mode "${mode}" \ - --allreduce_impl flydsl; then - echo " PASS allreduce shape=${shape} dtype=${dtype_str} mode=${mode}" - else - echo " FAIL allreduce shape=${shape} dtype=${dtype_str} mode=${mode}" - exit 1 - fi - } - - # Basic accuracy tests (always run on 8-GPU machines) - _run_allreduce fp16 128,8192 eager 10 2 - _run_allreduce bf16 256,8192 eager 10 2 - _run_allreduce fp16 512,4096 eager 10 2 - _run_allreduce fp32 64,4096 eager 10 2 - - # Extended tests (only in full CI) - if [ "${RUN_TESTS_FULL:-0}" = "1" ]; then - _run_allreduce fp16 1024,8192 eager 20 3 - _run_allreduce fp16 128,8192 cudagraph 20 3 - _run_allreduce bf16 256,8192 cudagraph 20 3 - fi -else - echo " SKIP 8-GPU allreduce tests (need >= 8 GPUs, found ${_phys_gpu_count})" -fi - echo "" echo "========================================================================" echo "All tests passed." diff --git a/tests/pytest.ini b/tests/pytest.ini index 11b92e48..0d00c23f 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -11,4 +11,4 @@ markers = l1b_target_dialect: requires a target lowering stack; pair with rocm_lower (or future backend markers) l2_device: requires GPU and full runtime stack for execution/correctness rocm_lower: L1b/L2 tests that assume the ROCDL lowering path - multi_gpu: marks tests that require >= 8 physical GPUs (skipped automatically on single-GPU machines) + multi_gpu: marks tests that require multi-GPU (>=4 physical GPUs; skipped automatically when unavailable) From 12614798e71bada184e7d0e22e8d059f52a22ffc Mon Sep 17 00:00:00 2001 From: "yashao@amd.com" Date: Mon, 13 Apr 2026 06:06:18 +0000 Subject: [PATCH 4/5] optimization allreduce perf --- kernels/custom_all_reduce.py | 12 +- kernels/custom_all_reduce_kernel.py | 249 ++++++++++++++-------------- python/flydsl/expr/arith.py | 20 --- python/flydsl/expr/vector.py | 23 ++- 4 files changed, 157 insertions(+), 147 deletions(-) diff --git a/kernels/custom_all_reduce.py b/kernels/custom_all_reduce.py index 5f6b7631..a00c0a1b 100644 --- a/kernels/custom_all_reduce.py +++ b/kernels/custom_all_reduce.py @@ -268,6 +268,14 @@ def __init__(self, *, group, device, max_size: int, world_size: int, rank: int, if self.world_size not in {2, 4, 8}: raise ValueError(f"world_size must be one of {{2, 4, 8}}, got {self.world_size}") + # Pre-initialize resource attributes so close() is safe on partial init failure. + self._meta_ptr = None + self._meta_bases = [None] * self.world_size + self._input_buffer_bases = [None] * self.world_size + self._output_buffer_bases = [None] * self.world_size + self._graph_ipc_reg_list = [] + self._out_ptrs_cache = None + alloc_size = self._SIGNAL_SIZE + int(self.max_size) self._meta_ptr = self._alloc_uncached(alloc_size) @@ -373,7 +381,9 @@ def __init__(self, *, group, device, max_size: int, world_size: int, rank: int, def close(self): """Release IPC memory handles for peer GPU buffers.""" - for bases in [self._meta_bases, self._input_buffer_bases, self._output_buffer_bases]: + for bases in [getattr(self, '_meta_bases', []), + getattr(self, '_input_buffer_bases', []), + getattr(self, '_output_buffer_bases', [])]: for b in bases: if b is not None: self._close_mem_handle(int(b)) diff --git a/kernels/custom_all_reduce_kernel.py b/kernels/custom_all_reduce_kernel.py index 1213a1db..569ac1bb 100644 --- a/kernels/custom_all_reduce_kernel.py +++ b/kernels/custom_all_reduce_kernel.py @@ -10,6 +10,8 @@ from __future__ import annotations +import math + import flydsl.compiler as flyc from flydsl.expr import arith as ea, gpu, range_constexpr, vector as ev, buffer_ops from flydsl.expr.typing import T, Int32, Int64, Stream @@ -68,22 +70,14 @@ def _store_v4i32_nt(rsrc, byte_off_i32, v4i32_val): # ---- signal buffer: i32 load / store -------------------------------------- -def _store_i32(addr_i64, val_i32): - """Store i32 with default caching — for local flag counter updates. - - Signal buffers live in hipDeviceMallocUncached memory, so cached stores - are safe (hardware bypasses caches for UC memory automatically). - Avoids the unnecessary sc0+sc1 bypass and s_waitcnt overhead of the - uncached variant. - """ - rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) +def _store_i32(rsrc, val_i32): + """Store i32 with default caching via pre-built rsrc descriptor.""" buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), cache_modifier=_CM_CACHED) -def _load_i32_uncached(addr_i64): - """Load i32 bypassing L2 (sc1) — for polling cross-GPU signal buffers.""" - rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) +def _load_i32_uncached(rsrc): + """Load i32 bypassing L2 (sc1) via pre-built rsrc descriptor.""" val = buffer_ops.buffer_load(rsrc, ea.constant(0, type=T.i32), vec_width=1, dtype=T.i32, cache_modifier=_CM_SC1) @@ -91,9 +85,8 @@ def _load_i32_uncached(addr_i64): return val -def _store_i32_uncached(addr_i64, val_i32): - """Store i32 bypassing L1+L2 (sc0+sc1) — for signal buffer writes.""" - rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) +def _store_i32_uncached(rsrc, val_i32): + """Store i32 bypassing L1+L2 (sc0+sc1) via pre-built rsrc descriptor.""" buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), cache_modifier=_CM_SC0_SC1) rocdl.s_waitcnt(0) @@ -108,16 +101,12 @@ def _invalidate_l1(): llvm.InlineAsmOp(None, [], "buffer_inv sc1", "", has_side_effects=True) -def _store_i32_uncached_flush(addr_i64, val_i32): - """Store i32 with L2 writeback then sc0+sc1 store. +def _store_i32_uncached_flush(rsrc, val_i32): + """Store i32 with L2 writeback then sc0+sc1 store via pre-built rsrc. - Use after cached data stores (st_global / buffer_store cached) so that - dirty L2 lines reach HBM before the signal becomes visible to peer GPUs. - buffer_wbl2 cannot be expressed as a buffer_store flag, so it stays as - inline asm. + buffer_wbl2 flushes dirty L2 lines to HBM before the signal store. """ llvm.InlineAsmOp(None, [], "buffer_wbl2 sc0 sc1", "", has_side_effects=True) - rsrc = buffer_ops.create_buffer_resource_from_addr(addr_i64) buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), cache_modifier=_CM_SC0_SC1) rocdl.s_waitcnt(0) @@ -125,6 +114,24 @@ def _store_i32_uncached_flush(addr_i64, val_i32): # ---- pointer array helpers ----------------------------------------------- +def _pack_i64_vec(values): + """Pack preloaded i64 values into vector for contiguous VGPR storage. + + On AMDGPU the subsequent ``ev.extract`` with a dynamic index lowers + through ``ConvertVectorToLLVM`` to ``llvm.extractelement`` which the + backend emits as ``v_movrels_b32`` (VGPR-relative addressing, ~3 insns) + instead of a chained ``arith.select`` costing 2*(N-1) insns. + """ + vec_type = T.vec(len(values), T.i64) + return ev.from_elements(vec_type, values) + + +def _extract_i64(vec, index): + """Extract i64 from a packed vector by dynamic index (VGPR-relative).""" + idx = ea.index_cast(T.index, index) + return ev.extract(vec, dynamic_position=[idx]) + + def _load_device_ptr(array_base_i64, index): """Load i64 pointer from a device-side pointer array at *index*. @@ -146,6 +153,19 @@ def _load_device_ptr(array_base_i64, index): # Element type helpers # --------------------------------------------------------------------------- +_BYTES_PER_PACK = 16 # sizeof(vector<4xi32>), the atomic load/store unit + + +def _elem_bytes(dtype_str: str) -> int: + """Return byte width of one scalar element for the given dtype.""" + d = (dtype_str or "").strip().lower() + if d in {"f32", "fp32"}: + return 4 + if d in {"f16", "fp16", "bf16"}: + return 2 + raise ValueError(f"unsupported dtype_str: {dtype_str!r}") + + def _elem_type(dtype_str: str) -> ir.Type: d = (dtype_str or "").strip().lower() if d in {"f16", "fp16"}: @@ -158,12 +178,8 @@ def _elem_type(dtype_str: str) -> ir.Type: def _pack_elems(dtype_str: str) -> int: - d = (dtype_str or "").strip().lower() - if d in {"f32", "fp32"}: - return 4 - if d in {"f16", "fp16", "bf16"}: - return 8 - raise ValueError(f"unsupported dtype_str: {dtype_str!r}") + """Number of elements per pack, derived from _BYTES_PER_PACK.""" + return _BYTES_PER_PACK // _elem_bytes(dtype_str) def _u(v): @@ -177,18 +193,18 @@ def _u(v): def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngpus: int): """Start-sync: write start flag to all peers, wait for all to arrive.""" - - i32, i64 = T.i32, T.i64 flag_addr = (self_sg_i64 + ea.constant(_SG_FLAG_OFF_B, type=i64) + bid_i32.extui(i64) * ea.constant(4, type=i64)) - flag = _load_i32_uncached(flag_addr) + ea.constant(1, type=i32) + flag_rsrc = _make_rsrc(flag_addr) + flag = _load_i32_uncached(flag_rsrc) + ea.constant(1, type=i32) bid8 = bid_i32 * ea.constant(8, type=i32) lin_lane = bid8 + lane_i32 start_wait_addr = (self_sg_i64 + ea.constant(_SG_START_OFF_B, type=i64) + lin_lane.extui(i64) * ea.constant(4, type=i64)) + wait_rsrc = _make_rsrc(start_wait_addr) lin_rank = bid8 + rank_i32 start_rank_off = (ea.constant(_SG_START_OFF_B, type=i64) + lin_rank.extui(i64) * ea.constant(4, type=i64)) @@ -196,9 +212,10 @@ def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngp is_lane = _u(lane_i32) < ea.constant(ngpus, type=i32) if_op = scf.IfOp(is_lane, results_=[], has_else=False) with ir.InsertionPoint(if_op.then_block): - peer_sg = ea.select_by_index(lane_i32, sgs_i64) - _store_i32_uncached_flush(peer_sg + start_rank_off, flag) - init_cur = _load_i32_uncached(start_wait_addr) + peer_sg = _extract_i64(_pack_i64_vec(sgs_i64), lane_i32) + peer_rsrc = _make_rsrc(peer_sg + start_rank_off) + _store_i32_uncached(peer_rsrc, flag) + init_cur = _load_i32_uncached(wait_rsrc) w = scf.WhileOp([i32], [init_cur]) wb = ir.Block.create_at_start(w.before, [i32]) wa = ir.Block.create_at_start(w.after, [i32]) @@ -207,43 +224,34 @@ def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngp need_wait = _u(cur) < flag scf.ConditionOp(need_wait, [cur]) with ir.InsertionPoint(wa): - scf.YieldOp([_load_i32_uncached(start_wait_addr)]) + scf.YieldOp([_load_i32_uncached(wait_rsrc)]) scf.YieldOp([]) gpu.barrier() is_t0 = lane_i32 == ea.constant(0, type=i32) if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) with ir.InsertionPoint(if_t0.then_block): - _store_i32(flag_addr, flag) + _store_i32(flag_rsrc, flag) scf.YieldOp([]) return flag_addr def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, - ngpus: int, need_wbl2: bool = False): - """End-sync: write end flag to all peers, wait for all to finish. - - Args: - need_wbl2: True → use st_xgmi_u32 (buffer_wbl2 + signal store). - Required after cached stores (st_global_16b) so - that L2 dirty lines reach HBM before the signal. - False → use st_signal_u32 (signal store only, no wbl2). - For nt data stores (st_nt_16b) which already bypass - L2; uses ATOMIC_RELAXED + MEMORY_SCOPE_SYSTEM. - """ - + ngpus: int): + """End-sync: write end flag to all peers, wait for all to finish.""" i32, i64 = T.i32, T.i64 - gpu.barrier() flag_addr = (self_sg_i64 + ea.constant(_SG_FLAG_OFF_B, type=i64) + bid_i32.extui(i64) * ea.constant(4, type=i64)) - flag = _load_i32_uncached(flag_addr) + ea.constant(1, type=i32) + flag_rsrc = _make_rsrc(flag_addr) + flag = _load_i32_uncached(flag_rsrc) + ea.constant(1, type=i32) bid8 = bid_i32 * ea.constant(8, type=i32) lin_lane = bid8 + lane_i32 end_wait_addr = (self_sg_i64 + ea.constant(_SG_END_OFF_B, type=i64) + lin_lane.extui(i64) * ea.constant(4, type=i64)) + wait_rsrc = _make_rsrc(end_wait_addr) lin_rank = bid8 + rank_i32 end_rank_off = (ea.constant(_SG_END_OFF_B, type=i64) + lin_rank.extui(i64) * ea.constant(4, type=i64)) @@ -251,12 +259,10 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, is_lane = _u(lane_i32) < ea.constant(ngpus, type=i32) if_op = scf.IfOp(is_lane, results_=[], has_else=False) with ir.InsertionPoint(if_op.then_block): - peer_sg = ea.select_by_index(lane_i32, sgs_i64) - if need_wbl2: - _store_i32_uncached_flush(peer_sg + end_rank_off, flag) - else: - _store_i32_uncached(peer_sg + end_rank_off, flag) - init_cur = _load_i32_uncached(end_wait_addr) + peer_sg = _extract_i64(_pack_i64_vec(sgs_i64), lane_i32) + peer_rsrc = _make_rsrc(peer_sg + end_rank_off) + _store_i32_uncached(peer_rsrc, flag) + init_cur = _load_i32_uncached(wait_rsrc) w = scf.WhileOp([i32], [init_cur]) wb = ir.Block.create_at_start(w.before, [i32]) wa = ir.Block.create_at_start(w.after, [i32]) @@ -265,7 +271,7 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, need_wait = _u(cur) < flag scf.ConditionOp(need_wait, [cur]) with ir.InsertionPoint(wa): - nxt = _load_i32_uncached(end_wait_addr) + nxt = _load_i32_uncached(wait_rsrc) _invalidate_l1() scf.YieldOp([nxt]) scf.YieldOp([]) @@ -274,7 +280,7 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, is_t0 = lane_i32 == ea.constant(0, type=i32) if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) with ir.InsertionPoint(if_t0.then_block): - _store_i32(flag_addr, flag) + _store_i32(flag_rsrc, flag) scf.YieldOp([]) @@ -367,10 +373,7 @@ def allreduce_1stage_arr( Each warp loads data from one rank into shared memory, then warp 0 reduces across all warps and writes the result to global memory. """ - - i32, i64 = T.i32, T.i64 - idx = ir.IndexType.get() v4i32 = T.i32x4 if is_f32: v4f32 = T.f32x4 @@ -390,13 +393,13 @@ def allreduce_1stage_arr( sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] - in_rsrcs = [_make_rsrc(in_ptrs_arr[i]) for i in range(world_size)] + in_ptrs_vec = _pack_i64_vec(in_ptrs_arr) smem_sym = f"allreduce_1s_smem_ws{world_size}_t{threads}" n_smem = 2 * threads allocator = SmemAllocator(None, global_sym_name=smem_sym) smem_off = allocator._align(allocator.ptr, 16) - allocator.ptr = smem_off + n_smem * 16 + allocator.ptr = smem_off + n_smem * _BYTES_PER_PACK with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): allocator.finalize() smem_ptr = SmemPtr(allocator.get_base(), smem_off, v4i32, shape=(n_smem,)) @@ -415,6 +418,7 @@ def allreduce_1stage_arr( stride_pack = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 out_rsrc = _make_rsrc(out_ptr_i64) + in_rsrc = _make_rsrc(_extract_i64(in_ptrs_vec, warp_id)) loop = scf.WhileOp([i32, i32], [tid_pack, ea.constant(0, type=i32)]) bfor = ir.Block.create_at_start(loop.before, [i32, i32]) @@ -427,12 +431,10 @@ def allreduce_1stage_arr( p = afor.arguments[0] parity = afor.arguments[1] - # Each warp loads data from its rank into shared memory - in_rsrc = ea.select_by_index(warp_id, in_rsrcs) - off_i32 = p * ea.constant(16, type=i32) + off_i32 = p * ea.constant(_BYTES_PER_PACK, type=i32) raw = _load_v4i32(in_rsrc, off_i32) sm_base = parity * ea.constant(threads, type=i32) - sm_idx = ea.index_cast(idx, sm_base + lane_i32) + sm_idx = ea.index_cast(T.index, sm_base + lane_i32) smem_ptr.store(raw, [sm_idx]) gpu.barrier() @@ -443,7 +445,7 @@ def allreduce_1stage_arr( acc = None for wi in range_constexpr(world_size): sm_i_idx = ea.index_cast( - idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + sm_base) + T.index, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + sm_base) raw_i = smem_ptr.load([sm_i_idx]) if is_f32: vf = raw_i.bitcast(v4f32) @@ -456,7 +458,7 @@ def allreduce_1stage_arr( out_bits = acc.bitcast(v4i32) else: out_bits = ev.bitcast(v4i32, acc.truncf(v8half)) - dst_off_i32 = p * ea.constant(16, type=i32) + dst_off_i32 = p * ea.constant(_BYTES_PER_PACK, type=i32) _store_v4i32(out_rsrc, dst_off_i32, out_bits) scf.YieldOp([]) @@ -481,10 +483,7 @@ def allreduce_2stage_arr( tmp_ptrs: Int64, out_ptr: Int64, ): - - i32, i64 = T.i32, T.i64 - idx = ir.IndexType.get() v4i32 = T.i32x4 if is_f32: v4f32 = T.f32x4 @@ -506,8 +505,7 @@ def allreduce_2stage_arr( sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] tmp_ptrs_arr = [_load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] - in_rsrcs = [_make_rsrc(in_ptrs_arr[i]) for i in range(world_size)] - tmp_rsrcs = [_make_rsrc(tmp_ptrs_arr[i]) for i in range(world_size)] + in_ptrs_vec = _pack_i64_vec(in_ptrs_arr) # Compute pack range for this rank's reduce-scatter partition start_p = rank_i32 * ea.constant(part_p, type=i32) @@ -529,28 +527,27 @@ def allreduce_2stage_arr( smem_slots = threads if _use_single_buf_2stage else 2 * threads allocator = SmemAllocator(None, global_sym_name=smem_sym) smem_off = allocator._align(allocator.ptr, 16) - allocator.ptr = smem_off + smem_slots * 16 + allocator.ptr = smem_off + smem_slots * _BYTES_PER_PACK with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): allocator.finalize() smem_ptr = SmemPtr(allocator.get_base(), smem_off, v4i32, shape=(smem_slots,)) smem_ptr.get() - tmp_out_i64 = tmp_ptrs_arr[0] - tmp_out_rsrc = tmp_rsrcs[0] + tmp_out_rsrc = _make_rsrc(tmp_ptrs_arr[0]) # ---- Stage 1: reduce-scatter ---- # Two implementations selected at compile time via _use_single_buf_2stage: # Single-buffer (large tensor): 8KB LDS, 2 barriers/iter, higher occupancy. # Double-buffer (small tensor): 16KB LDS, 1 barrier/iter (parity trick). + in_rsrc = _make_rsrc(_extract_i64(in_ptrs_vec, warp_id)) def _build_reduce_body(cur, smem_base_expr=None): """Emit reduce body: load → smem → barrier1 → warp0 reduce → [barrier2].""" - in_rsrc = ea.select_by_index(warp_id, in_rsrcs) - off_i32 = cur * ea.constant(16, type=i32) + off_i32 = cur * ea.constant(_BYTES_PER_PACK, type=i32) raw = _load_v4i32(in_rsrc, off_i32) if smem_base_expr is None: - sm_idx = ea.index_cast(idx, lane_i32) + sm_idx = ea.index_cast(T.index, lane_i32) else: - sm_idx = ea.index_cast(idx, smem_base_expr + lane_i32) + sm_idx = ea.index_cast(T.index, smem_base_expr + lane_i32) smem_ptr.store(raw, [sm_idx]) gpu.barrier() # barrier 1: all warps have written smem @@ -560,9 +557,9 @@ def _build_reduce_body(cur, smem_base_expr=None): acc = None for wi in range_constexpr(world_size): if smem_base_expr is None: - sm_r_idx = ea.index_cast(idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id) + sm_r_idx = ea.index_cast(T.index, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id) else: - sm_r_idx = ea.index_cast(idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + smem_base_expr) + sm_r_idx = ea.index_cast(T.index, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + smem_base_expr) raw_i = smem_ptr.load([sm_r_idx]) if is_f32: vf = raw_i.bitcast(v4f32) @@ -576,7 +573,7 @@ def _build_reduce_body(cur, smem_base_expr=None): else: out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) rel_p = cur - start_p - rel_off_i32 = rel_p * ea.constant(16, type=i32) + rel_off_i32 = rel_p * ea.constant(_BYTES_PER_PACK, type=i32) _store_v4i32(tmp_out_rsrc, rel_off_i32, out_raw) scf.YieldOp([]) @@ -615,6 +612,7 @@ def _build_reduce_body(cur, smem_base_expr=None): # smem half, so warp-0 reads and all-warp writes are disjoint. scf.YieldOp([cur + stride_pack, ea.constant(1, type=i32) - parity]) + gpu.barrier() _signal_end_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) @@ -622,9 +620,10 @@ def _build_reduce_body(cur, smem_base_expr=None): out_rsrc = _make_rsrc(out_ptr_i64) if vec_ok: + tmp_ptrs_vec = _pack_i64_vec(tmp_ptrs_arr) tid_pack2 = bid_i32 * tnum_gpu_i32 + lane_id stride_pack2 = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 - + tmp_rsrc = _make_rsrc(_extract_i64(tmp_ptrs_vec, warp_id)) loop2 = scf.WhileOp([i32], [tid_pack2]) b2 = ir.Block.create_at_start(loop2.before, [i32]) a2 = ir.Block.create_at_start(loop2.after, [i32]) @@ -639,15 +638,14 @@ def _build_reduce_body(cur, smem_base_expr=None): dst_rank = sum_rw & ea.constant(world_size - 1, type=i32) else: dst_rank = _u(sum_rw) % ea.constant(world_size, type=i32) - tmp_rsrc = ea.select_by_index(warp_id, tmp_rsrcs) - src_off_i32 = cur * ea.constant(16, type=i32) + src_off_i32 = cur * ea.constant(_BYTES_PER_PACK, type=i32) raw = _load_v4i32(tmp_rsrc, src_off_i32) dst_pack = dst_rank * ea.constant(part_p, type=i32) + cur - dst_off_i32 = dst_pack * ea.constant(16, type=i32) + dst_off_i32 = dst_pack * ea.constant(_BYTES_PER_PACK, type=i32) _store_v4i32(out_rsrc, dst_off_i32, raw) scf.YieldOp([cur + stride_pack2]) else: - # Non-vectorized fallback (world_size=6 or num_packs % world_size != 0) + tmp_rsrcs = [_make_rsrc(tmp_ptrs_arr[i]) for i in range(world_size)] tid_i32 = bid_i32 * ea.constant(threads, type=i32) + lane_i32 stride_i32 = gpu.grid_dim.x.ir_value() * ea.constant(threads, type=i32) @@ -667,10 +665,10 @@ def _build_reduce_body(cur, smem_base_expr=None): ok = _u(cur) < ea.constant(part_p, type=i32) ifp = scf.IfOp(ok, results_=[], has_else=False) with ir.InsertionPoint(ifp.then_block): - src_off_i32 = cur * ea.constant(16, type=i32) + src_off_i32 = cur * ea.constant(_BYTES_PER_PACK, type=i32) raw = _load_v4i32(tmp_rsrcs[p], src_off_i32) dst_pack_idx = ea.constant(p * part_p, type=i32) + cur - dst_off_i32 = dst_pack_idx * ea.constant(16, type=i32) + dst_off_i32 = dst_pack_idx * ea.constant(_BYTES_PER_PACK, type=i32) _store_v4i32(out_rsrc, dst_off_i32, raw) scf.YieldOp([]) scf.YieldOp([cur + stride_i32]) @@ -688,11 +686,7 @@ def allreduce_2stage_write_mode( out_ptrs: Int64, tmp_ptrs: Int64, ): - import math - - i32, i64 = T.i32, T.i64 - idx = ir.IndexType.get() v4i32 = T.i32x4 if is_f32: v4f32 = T.f32x4 @@ -714,8 +708,8 @@ def allreduce_2stage_write_mode( sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] out_ptrs_arr = [_load_device_ptr(out_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] tmp_ptrs_arr = [_load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] - tmp_rsrcs = [_make_rsrc(tmp_ptrs_arr[i]) for i in range(world_size)] - out_rsrcs = [_make_rsrc(out_ptrs_arr[i]) for i in range(world_size)] + tmp_ptrs_vec = _pack_i64_vec(tmp_ptrs_arr) + out_ptrs_vec = _pack_i64_vec(out_ptrs_arr) tnum_gpu_i32 = ea.constant(tnum_gpu, type=i32) log2_tnum = int(math.log2(tnum_gpu)) @@ -729,12 +723,12 @@ def allreduce_2stage_write_mode( n_smem_wm = 2 * threads allocator_wm = SmemAllocator(None, global_sym_name=smem_sym_wm) smem_wm_off = allocator_wm._align(allocator_wm.ptr, 16) - allocator_wm.ptr = smem_wm_off + n_smem_wm * 16 + allocator_wm.ptr = smem_wm_off + n_smem_wm * _BYTES_PER_PACK with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): allocator_wm.finalize() smem_ptr = SmemPtr(allocator_wm.get_base(), smem_wm_off, v4i32, shape=(n_smem_wm,)) smem_ptr.get() - tmp_out_i64 = ea.select_by_index(rank_i32, tmp_ptrs_arr) + tmp_out_i64 = _extract_i64(tmp_ptrs_vec, rank_i32) # ---- Stage 1: scatter local input to REMOTE tmp buffers ---- inp_rsrc = _make_rsrc(inp_ptr_i64) @@ -748,6 +742,13 @@ def allreduce_2stage_write_mode( scf.YieldOp([start_w + ea.constant(part_p, type=i32)]) end_w = end_w_if.results[0] + dst_tmp = _extract_i64(tmp_ptrs_vec, warp_id) + is_tmp_null = dst_tmp == ea.constant(0, type=i64) + dst_tmp_low4 = dst_tmp & ea.constant(0xF, type=i64) + is_tmp_misaligned = dst_tmp_low4 != ea.constant(0, type=i64) + bad_tmp_addr = is_tmp_null | is_tmp_misaligned + dst_tmp_rsrc = _make_rsrc(dst_tmp) + idx_s1 = start_w + tid_pack loop_s1 = scf.WhileOp([i32, i32], [idx_s1, stride_pack]) bs1 = ir.Block.create_at_start(loop_s1.before, [i32, i32]) @@ -759,21 +760,15 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(as1): cur = as1.arguments[0] stride_s1 = as1.arguments[1] - cur_off_i32 = cur * ea.constant(16, type=i32) + cur_off_i32 = cur * ea.constant(_BYTES_PER_PACK, type=i32) raw = _load_v4i32(inp_rsrc, cur_off_i32) rel_idx = cur - start_w dst_off = rank_i32 * ea.constant(part_p, type=i32) + rel_idx - dst_tmp = ea.select_by_index(warp_id, tmp_ptrs_arr) - is_tmp_null = dst_tmp == ea.constant(0, type=i64) - dst_tmp_low4 = dst_tmp & ea.constant(0xF, type=i64) - is_tmp_misaligned = dst_tmp_low4 != ea.constant(0, type=i64) - bad_tmp_addr = is_tmp_null | is_tmp_misaligned if_tmp_ok = scf.IfOp(bad_tmp_addr, results_=[], has_else=True) with ir.InsertionPoint(if_tmp_ok.then_block): scf.YieldOp([]) with ir.InsertionPoint(if_tmp_ok.else_block): - dst_tmp_rsrc = ea.select_by_index(warp_id, tmp_rsrcs) - dst_off_i32 = dst_off * ea.constant(16, type=i32) + dst_off_i32 = dst_off * ea.constant(_BYTES_PER_PACK, type=i32) _store_v4i32(dst_tmp_rsrc, dst_off_i32, raw) scf.YieldOp([]) scf.YieldOp([cur + stride_s1, stride_s1]) @@ -783,7 +778,7 @@ def allreduce_2stage_write_mode( self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) # ---- Stage 2: reduce local tmp and write to REMOTE outputs ---- - tmp_out_rsrc = ea.select_by_index(rank_i32, tmp_rsrcs) + tmp_out_rsrc = _make_rsrc(tmp_out_i64) part_p_i32 = ea.constant(part_p, type=i32) is_last_rank_s2 = rank_i32 == ea.constant(world_size - 1, type=i32) end_s2_if = scf.IfOp(is_last_rank_s2, results_=[i32], has_else=True) @@ -792,6 +787,19 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(end_s2_if.else_block): scf.YieldOp([part_p_i32]) end_s2 = end_s2_if.results[0] + + is_tmpout_null = tmp_out_i64 == ea.constant(0, type=i64) + tmpout_low4 = tmp_out_i64 & ea.constant(0xF, type=i64) + is_load_misaligned = tmpout_low4 != ea.constant(0, type=i64) + bad_load_addr = is_tmpout_null | is_load_misaligned + + dst_ptr = _extract_i64(out_ptrs_vec, warp_id) + dst_out_rsrc = _make_rsrc(dst_ptr) + is_out_null = dst_ptr == ea.constant(0, type=i64) + dst_ptr_low4 = dst_ptr & ea.constant(0xF, type=i64) + is_out_misaligned = dst_ptr_low4 != ea.constant(0, type=i64) + bad_out_addr = is_out_null | is_out_misaligned + loop_s2 = scf.WhileOp([i32, i32], [tid_pack, stride_pack]) bs2 = ir.Block.create_at_start(loop_s2.before, [i32, i32]) as2 = ir.Block.create_at_start(loop_s2.after, [i32, i32]) @@ -805,11 +813,7 @@ def allreduce_2stage_write_mode( # All warps load their chunk from tmp into smem src_off = warp_id * ea.constant(part_p, type=i32) + cur - src_off_i32 = src_off * ea.constant(16, type=i32) - is_tmpout_null = tmp_out_i64 == ea.constant(0, type=i64) - tmpout_low4 = tmp_out_i64 & ea.constant(0xF, type=i64) - is_load_misaligned = tmpout_low4 != ea.constant(0, type=i64) - bad_load_addr = is_tmpout_null | is_load_misaligned + src_off_i32 = src_off * ea.constant(_BYTES_PER_PACK, type=i32) raw_if = scf.IfOp(bad_load_addr, results_=[v4i32], has_else=True) with ir.InsertionPoint(raw_if.then_block): scf.YieldOp([ea.constant_vector(0, v4i32)]) @@ -817,7 +821,7 @@ def allreduce_2stage_write_mode( scf.YieldOp([_load_v4i32(tmp_out_rsrc, src_off_i32)]) raw = raw_if.results[0] - sm_idx = ea.index_cast(idx, lane_i32) + sm_idx = ea.index_cast(T.index, lane_i32) smem_ptr.store(raw, [sm_idx]) gpu.barrier() @@ -832,7 +836,7 @@ def allreduce_2stage_write_mode( acc = None for wi in range_constexpr(world_size): sm_i_idx = ea.index_cast( - idx, ea.constant(wi * tnum_gpu, type=i32) + lane_id) + T.index, ea.constant(wi * tnum_gpu, type=i32) + lane_id) raw_i = smem_ptr.load([sm_i_idx]) if is_f32: vf = raw_i.bitcast(v4f32) @@ -845,7 +849,7 @@ def allreduce_2stage_write_mode( out_raw = acc.bitcast(v4i32) else: out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) - res_idx = ea.index_cast(idx, ea.constant(threads, type=i32) + lane_id) + res_idx = ea.index_cast(T.index, ea.constant(threads, type=i32) + lane_id) smem_ptr.store(out_raw, [res_idx]) scf.YieldOp([]) @@ -853,18 +857,12 @@ def allreduce_2stage_write_mode( # All warps read the same reduced result from res area and # nontemporal-write to their respective remote output buffers. - res_read_idx = ea.index_cast(idx, ea.constant(threads, type=i32) + lane_id) + res_read_idx = ea.index_cast(T.index, ea.constant(threads, type=i32) + lane_id) reduced_val = smem_ptr.load([res_read_idx]) dst_out_off = rank_i32 * ea.constant(part_p, type=i32) + cur - dst_off_i32 = dst_out_off * ea.constant(16, type=i32) - - dst_ptr = ea.select_by_index(warp_id, out_ptrs_arr) - dst_out_rsrc = ea.select_by_index(warp_id, out_rsrcs) - is_out_null = dst_ptr == ea.constant(0, type=i64) - dst_ptr_low4 = dst_ptr & ea.constant(0xF, type=i64) - is_out_misaligned = dst_ptr_low4 != ea.constant(0, type=i64) - bad_out_addr = is_out_null | is_out_misaligned + dst_off_i32 = dst_out_off * ea.constant(_BYTES_PER_PACK, type=i32) + if_out_ok = scf.IfOp(bad_out_addr, results_=[], has_else=True) with ir.InsertionPoint(if_out_ok.then_block): scf.YieldOp([]) @@ -874,6 +872,7 @@ def allreduce_2stage_write_mode( scf.YieldOp([cur + stride_s2, stride_s2]) + gpu.barrier() _signal_end_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index 459ccbdd..38403f6c 100644 --- a/python/flydsl/expr/arith.py +++ b/python/flydsl/expr/arith.py @@ -64,24 +64,4 @@ def cmpf(predicate, lhs, rhs, **kwargs): return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs) -@traced_op -def select_by_index(index_val, values): - """Select one of *values* by integer *index_val* via chained ``arith.select``. - - Equivalent to a compile-time switch: returns ``values[index_val]``. - - Args: - index_val: Integer index (i32 ``ir.Value``). - values: List of ``ir.Value`` to select from. - - Returns: - The selected ``ir.Value``. - """ - out = values[0] - for i in range(1, len(values)): - pred = _mlir_arith.CmpIOp( - _mlir_arith.CmpIPredicate.eq, index_val, constant(i, type=index_val.type) - ).result - out = _mlir_arith.SelectOp(pred, values[i], out).result - return out diff --git a/python/flydsl/expr/vector.py b/python/flydsl/expr/vector.py index 002adac9..f59eaa26 100644 --- a/python/flydsl/expr/vector.py +++ b/python/flydsl/expr/vector.py @@ -55,9 +55,24 @@ def store(value, memref, indices, *, loc=None, ip=None, **kwargs): ) +# ----------------------------------------------------------------------------- +# Thin wrappers for common op classes that otherwise require `.result` access. +# ----------------------------------------------------------------------------- + + +_KDYNAMIC = -0x8000_0000_0000_0000 # ShapedType::kDynamic sentinel + + @traced_op def extract(vector, static_position=None, dynamic_position=None, *, loc=None, ip=None): - """Wrapper around `vector.ExtractOp(...).result`.""" + """Wrapper around `vector.ExtractOp(...).result`. + + When only ``dynamic_position`` is supplied (without explicit + ``static_position``), each dynamic index needs a corresponding + ``kDynamic`` sentinel in the static attribute so the ODS builder + pairs them correctly. This wrapper fills in the sentinels + automatically. + """ from . import arith as _arith_ext if static_position is None: @@ -65,6 +80,12 @@ def extract(vector, static_position=None, dynamic_position=None, *, loc=None, ip if dynamic_position is None: dynamic_position = [] dynamic_position = [_arith_ext.unwrap(i, index=True, loc=loc) for i in dynamic_position] + + n_static = len(static_position) + n_dynamic = len(dynamic_position) + if n_dynamic > 0 and n_static < n_dynamic: + static_position = list(static_position) + [_KDYNAMIC] * (n_dynamic - n_static) + return _vector.ExtractOp( _arith_ext.unwrap(vector, loc=loc), static_position=static_position, From 8285777c2b59d48b303af5391e818d3d1549860b Mon Sep 17 00:00:00 2001 From: "yashao@amd.com" Date: Mon, 13 Apr 2026 08:30:20 +0000 Subject: [PATCH 5/5] add CI benchmark of allreduce --- .github/workflows/flydsl.yaml | 54 ++++++++- README.md | 2 +- docs/prebuilt_kernels_guide.md | 2 +- kernels/custom_all_reduce_kernel.py | 12 +- python/flydsl/_version.py | 2 +- python/flydsl/expr/buffer_ops.py | 3 +- tests/arch_compat.py | 2 +- tests/kernels/compare_allreduce_benchmark.py | 90 ++++++++++++++ ..._flydsl_allreduce.py => test_allreduce.py} | 112 +++++++++++++++--- tests/pytest.ini | 1 + 10 files changed, 251 insertions(+), 29 deletions(-) create mode 100644 tests/kernels/compare_allreduce_benchmark.py rename tests/kernels/{test_flydsl_allreduce.py => test_allreduce.py} (89%) diff --git a/.github/workflows/flydsl.yaml b/.github/workflows/flydsl.yaml index 4ea8c02d..56cb06a2 100644 --- a/.github/workflows/flydsl.yaml +++ b/.github/workflows/flydsl.yaml @@ -186,6 +186,7 @@ jobs: multi-gpu: needs: test name: Multi-GPU AllReduce Tests (${{ matrix.runners }}) + timeout-minutes: 120 strategy: matrix: runners: [ @@ -262,13 +263,64 @@ jobs: docker exec flydsl_test bash -c "export MLIR_PATH=/llvm-project/mlir_install && cd /flydsl-test && python3 -m pip install -e . --use-pep517" - name: Run multi-GPU allreduce tests + timeout-minutes: 30 run: | docker exec flydsl_test bash -c " cd /flydsl-test - python3 -m pytest tests/kernels/test_flydsl_allreduce.py \ + python3 -m pytest tests/kernels/test_allreduce.py \ -m multi_gpu -v --no-header --tb=short " + - name: Run allreduce benchmark (PR) + timeout-minutes: 30 + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test + python3 tests/kernels/test_allreduce.py \ + --world_size 8 --iters 51 --warmup 5 \ + --allreduce_impl flydsl --mode cudagraph \ + --shapes '2,7168,fp16;32,8192,fp32;128,8192,fp16;1024,7168,bf16;4096,8192,bf16' \ + --output_csv /tmp/bench_pr.csv + " + + - name: Build main branch baseline + timeout-minutes: 20 + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test + git fetch origin main --depth=1 + git worktree add /tmp/flydsl-main origin/main + cd /tmp/flydsl-main + export MLIR_PATH=/llvm-project/mlir_install + python3 -m pip install -e . --use-pep517 2>&1 | tail -5 + " + + - name: Run allreduce benchmark (main) + id: bench-main + timeout-minutes: 30 + continue-on-error: true + run: | + docker exec flydsl_test bash -c " + cp /flydsl-test/tests/kernels/test_allreduce.py \ + /tmp/flydsl-main/tests/kernels/test_allreduce.py + cd /tmp/flydsl-main + python3 tests/kernels/test_allreduce.py \ + --world_size 8 --iters 51 --warmup 5 \ + --allreduce_impl flydsl --mode cudagraph \ + --shapes '2,7168,fp16;32,8192,fp32;128,8192,fp16;1024,7168,bf16;4096,8192,bf16' \ + --output_csv /tmp/bench_main.csv + " + + - name: Check performance regression (PR vs main) + if: steps.bench-main.outcome == 'success' + timeout-minutes: 5 + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test + python3 tests/kernels/compare_allreduce_benchmark.py \ + /tmp/bench_main.csv /tmp/bench_pr.csv + " + - name: Show test logs if: failure() run: | diff --git a/README.md b/README.md index d040c81e..b6ddda36 100644 --- a/README.md +++ b/README.md @@ -363,7 +363,7 @@ See `examples/` for more examples including tiled copy (`02-tiledCopy.py`), tile | **RMSNorm** | `test_rmsnorm.py` | RMSNorm (layout API) | | **Softmax** | `test_softmax.py` | Softmax (layout API) | | **Fused RoPE** | `test_fused_rope_cache.py` | Fused RoPE + KV cache | -| **AllReduce** | `test_flydsl_allreduce.py` | Multi-GPU all-reduce | +| **AllReduce** | `test_allreduce.py` | Multi-GPU all-reduce | | **RDNA GEMM** | `test_rdna_gemm.py` | RDNA FP16/FP8 GEMM | | **GFX1250 GEMM** | `test_gemm_fp8fp4_gfx1250.py` | GFX1250 FP8/FP4 GEMM | | **WMMA GEMM** | `test_wmma_gemm_gfx1250.py` | GFX1250 WMMA GEMM | diff --git a/docs/prebuilt_kernels_guide.md b/docs/prebuilt_kernels_guide.md index 4d3745b5..018b122f 100644 --- a/docs/prebuilt_kernels_guide.md +++ b/docs/prebuilt_kernels_guide.md @@ -338,7 +338,7 @@ What operation do you need? | `tests/kernels/test_rmsnorm.py` | RMSNorm | | `tests/kernels/test_softmax.py` | Softmax | | `tests/kernels/test_fused_rope_cache.py` | Fused RoPE + KV cache | -| `tests/kernels/test_flydsl_allreduce.py` | Multi-GPU all-reduce | +| `tests/kernels/test_allreduce.py` | Multi-GPU all-reduce | | `tests/kernels/test_rdna_gemm.py` | RDNA GEMM | | `tests/kernels/test_gemm_fp8fp4_gfx1250.py` | GFX1250 FP8/FP4 GEMM | | `tests/kernels/test_wmma_gemm_gfx1250.py` | GFX1250 WMMA GEMM | diff --git a/kernels/custom_all_reduce_kernel.py b/kernels/custom_all_reduce_kernel.py index 569ac1bb..ec03e1a9 100644 --- a/kernels/custom_all_reduce_kernel.py +++ b/kernels/custom_all_reduce_kernel.py @@ -448,14 +448,14 @@ def allreduce_1stage_arr( T.index, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + sm_base) raw_i = smem_ptr.load([sm_i_idx]) if is_f32: - vf = raw_i.bitcast(v4f32) + vf = ev.bitcast(v4f32, raw_i) acc = vf if acc is None else acc + vf else: v16 = ev.bitcast(v8half, raw_i) v32 = v16.extf(v8f32) acc = v32 if acc is None else acc + v32 if is_f32: - out_bits = acc.bitcast(v4i32) + out_bits = ev.bitcast(v4i32, acc) else: out_bits = ev.bitcast(v4i32, acc.truncf(v8half)) dst_off_i32 = p * ea.constant(_BYTES_PER_PACK, type=i32) @@ -562,14 +562,14 @@ def _build_reduce_body(cur, smem_base_expr=None): sm_r_idx = ea.index_cast(T.index, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + smem_base_expr) raw_i = smem_ptr.load([sm_r_idx]) if is_f32: - vf = raw_i.bitcast(v4f32) + vf = ev.bitcast(v4f32, raw_i) acc = vf if acc is None else acc + vf else: v16 = ev.bitcast(v8half, raw_i) v32 = v16.extf(v8f32) acc = v32 if acc is None else acc + v32 if is_f32: - out_raw = acc.bitcast(v4i32) + out_raw = ev.bitcast(v4i32, acc) else: out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) rel_p = cur - start_p @@ -839,14 +839,14 @@ def allreduce_2stage_write_mode( T.index, ea.constant(wi * tnum_gpu, type=i32) + lane_id) raw_i = smem_ptr.load([sm_i_idx]) if is_f32: - vf = raw_i.bitcast(v4f32) + vf = ev.bitcast(v4f32, raw_i) acc = vf if acc is None else acc + vf else: v16 = ev.bitcast(v8half, raw_i) v32 = v16.extf(v8f32) acc = v32 if acc is None else acc + v32 if is_f32: - out_raw = acc.bitcast(v4i32) + out_raw = ev.bitcast(v4i32, acc) else: out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) res_idx = ea.index_cast(T.index, ea.constant(threads, type=i32) + lane_id) diff --git a/python/flydsl/_version.py b/python/flydsl/_version.py index b1aa5330..f097a31c 100644 --- a/python/flydsl/_version.py +++ b/python/flydsl/_version.py @@ -1 +1 @@ -__version__ = "0.1.3.1" +__version__ = "0.1.3.1.dev479" diff --git a/python/flydsl/expr/buffer_ops.py b/python/flydsl/expr/buffer_ops.py index 0f76d02f..5c68870d 100644 --- a/python/flydsl/expr/buffer_ops.py +++ b/python/flydsl/expr/buffer_ops.py @@ -437,7 +437,8 @@ def buffer_load(rsrc: ir.Value, # RawPtrBufferLoadOp offset is in BYTES. By default we accept element # offsets and scale to bytes; set offset_is_bytes=True to skip scaling. if not offset_is_bytes: - element_bytes = dtype.width // 8 + elem_ty = dtype.element_type if hasattr(dtype, 'element_type') else dtype + element_bytes = elem_ty.width // 8 bytes_const = _create_i32_constant(element_bytes) op = std_arith.MulIOp(offset, bytes_const) offset = _unwrap_value(op.result) diff --git a/tests/arch_compat.py b/tests/arch_compat.py index a436896e..fb6bab52 100644 --- a/tests/arch_compat.py +++ b/tests/arch_compat.py @@ -16,7 +16,7 @@ "test_moe_reduce.py", "test_pa.py", "test_quant.py", - "test_flydsl_allreduce.py", # custom_all_reduce requires CDNA (gfx9xx) + "test_allreduce.py", # custom_all_reduce requires CDNA (gfx9xx) }) # Example scripts verified to work on RDNA (non-CDNA) GPUs. diff --git a/tests/kernels/compare_allreduce_benchmark.py b/tests/kernels/compare_allreduce_benchmark.py new file mode 100644 index 00000000..2b0942cc --- /dev/null +++ b/tests/kernels/compare_allreduce_benchmark.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +"""Compare two allreduce benchmark CSVs (main vs PR) and flag regressions. + +Usage: + python3 compare_benchmark.py + +Exit code 1 if any case regresses more than BOTH thresholds: + - relative increase > MAX_REGRESSION_PCT (default 10%) + - absolute increase > MIN_ABS_REGRESSION_US (default 5 us) +""" +import sys +import pandas as pd + +MAX_REGRESSION_PCT = 10.0 +MIN_ABS_REGRESSION_US = 5.0 + + +def main(): + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} ") + sys.exit(2) + + main_csv, pr_csv = sys.argv[1], sys.argv[2] + + main_df = pd.read_csv(main_csv) + pr_df = pd.read_csv(pr_csv) + + main_agg = main_df[main_df["rank"] == "aggregate"].copy() + pr_agg = pr_df[pr_df["rank"] == "aggregate"].copy() + + # Detect cases that failed/skipped in PR but succeeded on main + pr_agg_indexed = pr_agg.set_index(["shape", "dtype"]) + main_agg_indexed = main_agg.set_index(["shape", "dtype"]) + + pr_broken = pr_agg_indexed[ + (pr_agg_indexed["avg_time_us"] <= 0) | pr_agg_indexed["avg_time_us"].isna() + ] + main_ok = main_agg_indexed[ + (main_agg_indexed["avg_time_us"] > 0) & main_agg_indexed["avg_time_us"].notna() + ] + newly_broken = pr_broken.index.intersection(main_ok.index) + + # Performance comparison for cases that both sides ran successfully + pr_valid = pr_agg_indexed[["avg_time_us"]].loc[ + (pr_agg_indexed["avg_time_us"] > 0) & pr_agg_indexed["avg_time_us"].notna() + ] + main_valid = main_agg_indexed[["avg_time_us"]].loc[ + (main_agg_indexed["avg_time_us"] > 0) & main_agg_indexed["avg_time_us"].notna() + ] + merged = pr_valid.join(main_valid, lsuffix="_pr", rsuffix="_main").dropna() + + fail_count = 0 + + if not merged.empty: + merged["delta_us"] = merged["avg_time_us_pr"] - merged["avg_time_us_main"] + merged["delta_pct"] = (merged["delta_us"] / merged["avg_time_us_main"]) * 100.0 + + print("=== Allreduce Benchmark: PR vs main ===") + for (shape, dtype), row in merged.iterrows(): + regressed = ( + row["delta_pct"] > MAX_REGRESSION_PCT + and row["delta_us"] > MIN_ABS_REGRESSION_US + ) + tag = "REGRESSION" if regressed else "OK" + if regressed: + fail_count += 1 + print( + f" {shape:>20s} {dtype:>4s} " + f"main={row['avg_time_us_main']:8.2f} us " + f"PR={row['avg_time_us_pr']:8.2f} us " + f"delta={row['delta_us']:+8.2f} us ({row['delta_pct']:+5.1f}%) " + f"[{tag}]" + ) + + if len(newly_broken) > 0: + print("\n=== Cases BROKEN in PR (work on main but fail on PR) ===") + for shape, dtype in newly_broken: + fail_count += 1 + err = pr_agg_indexed.loc[(shape, dtype)].get("error", "unknown") + print(f" {shape:>20s} {dtype:>4s} [BROKEN] error: {err}") + + if fail_count > 0: + print(f"\nFAILED: {fail_count} issue(s) detected.") + sys.exit(1) + else: + print("\nPASSED: No regression or breakage detected.") + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/test_flydsl_allreduce.py b/tests/kernels/test_allreduce.py similarity index 89% rename from tests/kernels/test_flydsl_allreduce.py rename to tests/kernels/test_allreduce.py index c8b12cba..eb3d9269 100644 --- a/tests/kernels/test_flydsl_allreduce.py +++ b/tests/kernels/test_allreduce.py @@ -257,6 +257,13 @@ def _dist_worker( mode: "eager" or "cudagraph" - which path to run (separate flows) result_dict: Shared dictionary to collect results from all ranks """ + import warnings + warnings.filterwarnings("ignore") + _devnull_fd = os.open(os.devnull, os.O_WRONLY) + os.dup2(_devnull_fd, 1) + os.dup2(_devnull_fd, 2) + os.close(_devnull_fd) + torch.cuda.set_device(rank) device = torch.device(f"cuda:{rank}") @@ -305,15 +312,6 @@ def _dist_worker( out = torch.empty_like(x_flat) fa = init_custom_ar(meta, rank_data, handles, offsets, rank=rank, full_nvlink=True, out=out) - if rank == 0: - fa_mod = getattr(getattr(fa, "__class__", None), "__module__", None) - fa_name = getattr(getattr(fa, "__class__", None), "__name__", None) - print( - f"[custom_all_reduce] backend=aiter " - f"allreduce_impl={allreduce_impl!r} fa={fa_mod}.{fa_name}", - flush=True, - ) - # Warmup: align all ranks dist.all_reduce(torch.zeros(1, device=device), group=group) torch.cuda.synchronize() @@ -405,7 +403,7 @@ def _run_eager(): elif mode == "cudagraph": if not hasattr(fa, "capture"): if rank == 0: - print("[test_flydsl_allreduce] WARN: fa has no capture(); skipping cudagraph.", flush=True) + print("[test_allreduce] WARN: fa has no capture(); skipping cudagraph.", flush=True) result_dict[rank] = { "rank": rank, "shape": shape, "dtype": dtype_str, "world_size": world_size, "mode": "cudagraph", "max_error": float("nan"), "avg_time_us": 0.0, @@ -638,6 +636,7 @@ def run_all_tests( print(f" Avg time: mean={mean_avg_time:.3f} us/iter, min={min_avg_time:.3f}, max={max_avg_time:.3f}") # Add aggregate row + rank0 = rank_results[0] if rank_results else {} aggregate_result = { "rank": "aggregate", "shape": str(shape), @@ -649,9 +648,10 @@ def run_all_tests( "min_avg_time_us": min_avg_time, "max_avg_time_us": max_avg_time, "device_time_sum_us": sum(r["device_time_sum_us"] for r in rank_results), - "kernel_name": rank_results[0]["kernel_name"] if rank_results else "unknown", + "kernel_name": rank0.get("kernel_name", "unknown"), "num_iters": num_iters, "num_warmup": num_warmup, + "error": rank0.get("error"), } all_results.append(aggregate_result) @@ -678,6 +678,19 @@ def run_all_tests( if not aggregate_df.empty: print(aggregate_df.to_string(index=False)) print("=" * 80) + + failed = [ + r for r in all_results + if r.get("rank") == "aggregate" + and (r.get("kernel_name") in ("skip", "error") or r.get("error")) + ] + if failed: + print("\n✗ FAILED cases:") + for r in failed: + reason = r.get("error") or r.get("kernel_name", "unknown") + print(f" {r['shape']} {r['dtype']} {r['mode']} → {reason}") + sys.exit(1) + return df else: print("\nNo results to save.") @@ -731,7 +744,22 @@ def _count_physical_gpus() -> int: ] -def _run_subprocess_test(*, world_size, shape, dtype_str, mode): +# 8-GPU benchmark configurations: cover all 3 kernel paths × 3 dtypes, cudagraph mode. +# small (2×7168) → 1-stage kernel +# medium (128×8192) → 2-stage kernel +# large (1024×8192) → write-mode kernel +_BENCHMARK_PARAMS = [ + # (shape, dtype_str, mode) + ((2, 7168), "fp16", "cudagraph"), + ((32, 8192), "fp32", "cudagraph"), + ((128, 8192), "fp16", "cudagraph"), + ((1024, 7168), "bf16", "cudagraph"), + ((4096, 8192), "bf16", "cudagraph") +] + + +def _run_subprocess(*, world_size, shape, dtype_str, mode, iters=10, warmup=2, + output_csv=None, timeout=600): """Launch the allreduce harness in a subprocess and assert success.""" import subprocess as _sp @@ -741,23 +769,55 @@ def _run_subprocess_test(*, world_size, shape, dtype_str, mode): cmd = [ sys.executable, __file__, "--world_size", str(world_size), - "--iters", "10", - "--warmup", "2", + "--iters", str(iters), + "--warmup", str(warmup), "--shapes", shape_str, "--mode", mode, "--allreduce_impl", "flydsl", ] - result = _sp.run(cmd, env=env, timeout=600, capture_output=True, text=True) + if output_csv: + cmd += ["--output_csv", output_csv] + result = _sp.run(cmd, env=env, timeout=timeout, capture_output=True, text=True) assert result.returncode == 0, ( f"{world_size}-GPU allreduce FAILED: shape={shape}, dtype={dtype_str}, " f"mode={mode} (exit code {result.returncode})\n" f"stdout (last 2000 chars):\n{result.stdout[-2000:]}\n" f"stderr (last 2000 chars):\n{result.stderr[-2000:]}" ) + return result + + +def _run_subprocess_test(*, world_size, shape, dtype_str, mode): + """Launch the allreduce accuracy test in a subprocess.""" + _run_subprocess(world_size=world_size, shape=shape, dtype_str=dtype_str, mode=mode) + + +def _run_subprocess_benchmark(*, world_size, shape, dtype_str, mode): + """Launch the allreduce benchmark in a subprocess with more iterations. + + Returns the CSV output path for downstream baseline comparison. + """ + shape_tag = "x".join(str(d) for d in shape) + csv_path = f"/tmp/allreduce_bench_{shape_tag}_{dtype_str}_{mode}.csv" + result = _run_subprocess( + world_size=world_size, shape=shape, dtype_str=dtype_str, mode=mode, + iters=51, warmup=5, output_csv=csv_path, timeout=900, + ) + if result.stdout: + for line in result.stdout.splitlines(): + if "avg_time" in line.lower() or "max_error" in line.lower() or "aggregate" in line.lower(): + print(line) + return csv_path + + +def _param_id(shape, dtype_str, mode): + s = "x".join(str(d) for d in shape) + return f"{s}-{dtype_str}-{mode}" @pytest.mark.multi_gpu -@pytest.mark.parametrize("shape,dtype_str,mode", _8GPU_PARAMS) +@pytest.mark.parametrize("shape,dtype_str,mode", _8GPU_PARAMS, + ids=[_param_id(*p) for p in _8GPU_PARAMS]) def test_allreduce_8gpu_accuracy(shape, dtype_str, mode): """8-GPU allreduce accuracy test. @@ -774,7 +834,25 @@ def test_allreduce_8gpu_accuracy(shape, dtype_str, mode): @pytest.mark.multi_gpu -@pytest.mark.parametrize("shape,dtype_str,mode", _4GPU_PARAMS) +@pytest.mark.benchmark +@pytest.mark.parametrize("shape,dtype_str,mode", _BENCHMARK_PARAMS, + ids=[_param_id(*p) for p in _BENCHMARK_PARAMS]) +def test_allreduce_8gpu_benchmark(shape, dtype_str, mode): + """8-GPU allreduce benchmark test. + + Uses 51 iters / 5 warmup to get stable timing data. + Performance regression is checked at the CI workflow level by comparing + this PR's results against the main branch (run separately). + """ + phys_ng = _count_physical_gpus() + if phys_ng < 8: + pytest.skip(f"Requires >= 8 physical GPUs, found {phys_ng}.") + _run_subprocess_benchmark(world_size=8, shape=shape, dtype_str=dtype_str, mode=mode) + + +@pytest.mark.multi_gpu +@pytest.mark.parametrize("shape,dtype_str,mode", _4GPU_PARAMS, + ids=[_param_id(*p) for p in _4GPU_PARAMS]) def test_allreduce_4gpu_accuracy(shape, dtype_str, mode): """4-GPU allreduce accuracy test (covers fp32 and world_size=4).""" phys_ng = _count_physical_gpus() diff --git a/tests/pytest.ini b/tests/pytest.ini index 0d00c23f..af4f0e41 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -12,3 +12,4 @@ markers = l2_device: requires GPU and full runtime stack for execution/correctness rocm_lower: L1b/L2 tests that assume the ROCDL lowering path multi_gpu: marks tests that require multi-GPU (>=4 physical GPUs; skipped automatically when unavailable) + benchmark: marks performance benchmark tests (longer runtime, more iterations)