From 7ef3cb9cce61ff1331836f159ec47e3f4981bd40 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Tue, 18 Nov 2025 19:19:52 +0800 Subject: [PATCH 01/17] test mla and nsa performance on modal --- docker/flash_mla/__init__.py | 14 + docker/flash_mla/flash_mla_interface.py | 333 ++++++++++++ docker/flash_mla/txl_mla_interface.py | 645 ++++++++++++++++++++++++ docker/flash_mla/txl_nsa_interface.py | 561 +++++++++++++++++++++ docker/lib.py | 73 +++ docker/test_flash_mla_decoding.py | 387 ++++++++++++++ docker/test_flash_mla_prefill.py | 293 +++++++++++ 7 files changed, 2306 insertions(+) create mode 100644 docker/flash_mla/__init__.py create mode 100644 docker/flash_mla/flash_mla_interface.py create mode 100644 docker/flash_mla/txl_mla_interface.py create mode 100644 docker/flash_mla/txl_nsa_interface.py create mode 100644 docker/lib.py create mode 100644 docker/test_flash_mla_decoding.py create mode 100644 docker/test_flash_mla_prefill.py diff --git a/docker/flash_mla/__init__.py b/docker/flash_mla/__init__.py new file mode 100644 index 0000000..84dc241 --- /dev/null +++ b/docker/flash_mla/__init__.py @@ -0,0 +1,14 @@ +__version__ = "1.0.0" + +from flash_mla.flash_mla_interface import ( + get_mla_metadata, + flash_mla_with_kvcache, + flash_attn_varlen_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_mla_sparse_fwd +) + +from flash_mla.txl_nsa_interface import txl_mla + +from flash_mla.txl_mla_interface import mla_test diff --git a/docker/flash_mla/flash_mla_interface.py b/docker/flash_mla/flash_mla_interface.py new file mode 100644 index 0000000..4d27621 --- /dev/null +++ b/docker/flash_mla/flash_mla_interface.py @@ -0,0 +1,333 @@ +from typing import Optional, Tuple + +import torch + +import flash_mla.cuda as flash_mla_cuda + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_q_tokens_per_head_k: int, + num_heads_k: int, + num_heads_q: Optional[int] = None, + is_fp8_kvcache: bool = False, + topk: Optional[int] = 
None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. + num_heads_k: The number of k heads. + num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled + is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. + topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. + + Returns: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head dimension of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. + softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. 
For the format of FP8 KV cache, please refer to README.md + indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. + + Returns: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if indices is not None: + assert causal == False, "causal must be `false` if sparse attention is enabled." + out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( + q, + k_cache, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + is_fp8_kvcache, + indices + ) + return out, softmax_lse + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. 
Can only be 512 + + Returns: + (output, max_logits, lse) + About the definition of output, max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + 
is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + 
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, 
max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/docker/flash_mla/txl_mla_interface.py b/docker/flash_mla/txl_mla_interface.py new file mode 100644 index 0000000..1f6e6d2 --- /dev/null +++ b/docker/flash_mla/txl_mla_interface.py @@ -0,0 +1,645 @@ +import triton +import triton.language as tl +import txl +import torch +import os +import sys +import math +from triton.tools.tensor_descriptor import TensorDescriptor +import triton.profiler as proton +from contextlib import contextmanager + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def _maybe_desc(x, shape, strides, block_shape): + if isinstance(x, tl.tensor_descriptor): + return x + else: + return tl.make_tensor_descriptor(x, shape, strides, block_shape) + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + +def supports_host_descriptor(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + +tma_cfg = dict(BLOCK_M=64, BLOCK_N=64, NUM_STAGES=1) + +def _pre_hook(nargs): + BM = nargs["BLOCK_M"] + BN = nargs["BLOCK_N"] + D = nargs["D"] + PE = nargs["PE_DIM"] + if not isinstance(nargs["desc_qhat"], TensorDescriptor): + return + nargs["desc_qhat"].block_shape = [BM, D] + nargs["desc_qpe"].block_shape = [BM, PE] + nargs["desc_zkv"].block_shape = [BN, D//2] + nargs["desc_kpe"].block_shape 
= [BN, PE] + nargs["desc_o_lat"].block_shape = [BM, D//2] +ws_cuta_cfg = dict(BLOCK_M=64, BLOCK_N=64, NUM_STAGES=1) # 可按需调优 +@txl.autotune( + configs=[txl.Config(ws_cuta_cfg, num_stages=1, num_warps=4, num_warpgroups=3, pre_hook=_pre_hook)], + key=["KV_SEQ_LEN","D","PE_DIM"] +) +@txl.jit +#@txl.jit(src_file='dump/1107mla/ENHKSWNKWUIZUX7DHO4HA2ZRZBQBQLWHT3AD4YLIWFQ5KCCEZSPA/mla_decode_latent_sharedZ_ws_dim2_2K_txl_change.ptx') +def mla_txl( # cutedsl 这里要求KV_SEQ_LEN是BLOCK_N*2的整数倍,因为我这边还没有写奇数倍数的逻辑 + sm_scale, M, + B, H, H_KV, + desc_qhat, # [B,H,N_Q,D] + desc_zkv, # [B,H_KV,KV_SEQ_LEN,D] + desc_o_lat, # [B,H,N_Q,D] + N_Q, KV_SEQ_LEN, + desc_qpe, desc_kpe, # Q_pe: [B,H,N_Q,PE], K_pe: [B,H_KV,KV_SEQ_LEN,PE] + PE_DIM: tl.constexpr, + D: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_STAGES: tl.constexpr, + NUM_SMS: tl.constexpr, +): + tl.static_assert(BLOCK_N <= D) + tl.static_assert(D % 2 == 0) + tl.static_assert(PE_DIM == BLOCK_N) + tl.static_assert(NUM_STAGES == 1) + dtype = tl.float16 + + pid_m = tl.program_id(0) + off_kvh_SMs = tl.program_id(1) # [KV_HEADS, NUM_SMS] + batch_per_cta = B // NUM_SMS + off_kvh = off_kvh_SMs // NUM_SMS + off_SMs = off_kvh_SMs % NUM_SMS + begin_batch = off_SMs * batch_per_cta + end_batch = begin_batch + batch_per_cta + + rows_q = B * H * N_Q + rows_kv = B * H_KV * KV_SEQ_LEN + + desc_qhat = _maybe_desc(desc_qhat, [rows_q, D], [D, 1], [BLOCK_M, D]) + desc_o_lat = _maybe_desc(desc_o_lat, [rows_q, D], [D, 1], [BLOCK_M, D]) + desc_zkv = _maybe_desc(desc_zkv, [rows_kv, D], [D, 1], [BLOCK_N, D]) + desc_qpe = _maybe_desc(desc_qpe, [rows_q, PE_DIM], [PE_DIM, 1], [BLOCK_M, PE_DIM]) + desc_kpe = _maybe_desc(desc_kpe, [rows_kv, PE_DIM], [PE_DIM, 1], [BLOCK_N, PE_DIM]) + + bQ = txl.smem_alloc([BLOCK_M, D], dtype=dtype) + bQpe = txl.smem_alloc([BLOCK_M, PE_DIM], dtype=dtype) + mQ = txl.mbar_alloc(1) + mQpe = txl.mbar_alloc(1) + + bZL0 = txl.smem_alloc([BLOCK_N, D//2], dtype=dtype, num_stages=1) + bZR0 = txl.smem_alloc([BLOCK_N, 
D//2], dtype=dtype, num_stages=1) + bZL1 = txl.smem_alloc([BLOCK_N, D//2], dtype=dtype, num_stages=1) + bZR1 = txl.smem_alloc([BLOCK_N, D//2], dtype=dtype, num_stages=1) + bKpe0 = txl.smem_alloc([BLOCK_N, PE_DIM], dtype=dtype, num_stages=1) + bKpe1 = txl.smem_alloc([BLOCK_N, PE_DIM], dtype=dtype, num_stages=1) + mZL0 = txl.mbar_alloc(1, num_stages=1) + mZR0 = txl.mbar_alloc(1, num_stages=1) + mZL1 = txl.mbar_alloc(1, num_stages=1) + mZR1 = txl.mbar_alloc(1, num_stages=1) + mKpe0 = txl.mbar_alloc(1, num_stages=1) + mKpe1 = txl.mbar_alloc(1, num_stages=1) + + mQK0 = txl.mbar_alloc(128, num_stages=1) + mQK1 = txl.mbar_alloc(128, num_stages=1) + mPV0_L = txl.mbar_alloc(128, num_stages=1) + mPV0_R = txl.mbar_alloc(128, num_stages=1) + mPV1_L = txl.mbar_alloc(128, num_stages=1) + mPV1_R = txl.mbar_alloc(128, num_stages=1) + mQ0 = txl.mbar_alloc(128, num_stages=1) + mQ1 = txl.mbar_alloc(128, num_stages=1) + + bP0 = txl.smem_alloc([BLOCK_M, BLOCK_N], dtype=dtype, num_stages=1) + bP1 = bQpe # assert BLOCK_N == PE_DIM + + bMax = txl.smem_alloc([BLOCK_M], dtype=tl.float32, num_stages=1) + bL0 = txl.smem_alloc([BLOCK_M], dtype=tl.float32, num_stages=1) + bL1 = txl.smem_alloc([BLOCK_M], dtype=tl.float32, num_stages=1) + + mma_layout: tl.constexpr = txl.NVMMADistributedLayout( + version=[3, 0], + warps_per_cta=[4, 1], + instr_shape=[16, 64, 16], + ) + + max_reg_layout: tl.constexpr = txl.SliceLayout( + dim=1, + parent=mma_layout + ) + + A_op_layout: tl.constexpr = txl.DotOperandLayout(operand_index=0, parent=mma_layout, k_width=2) + + cur_bQ = txl.get_buffer(bQ, 0) + cur_mQ = txl.get_buffer(mQ, 0) + cur_bQpe = txl.get_buffer(bQpe, 0) + cur_mQpe = txl.get_buffer(mQpe, 0) + + cur_mZL0 = txl.get_buffer(mZL0, 0) + cur_mZR0 = txl.get_buffer(mZR0, 0) + cur_bZL0 = txl.get_buffer(bZL0, 0) + cur_bZR0 = txl.get_buffer(bZR0, 0) + cur_bKpe0 = txl.get_buffer(bKpe0, 0) + cur_mKpe0 = txl.get_buffer(mKpe0, 0) + + cur_mZL1 = txl.get_buffer(mZL1, 0) + cur_mZR1 = txl.get_buffer(mZR1, 0) + cur_bZL1 
= txl.get_buffer(bZL1, 0) + cur_bZR1 = txl.get_buffer(bZR1, 0) + cur_bKpe1 = txl.get_buffer(bKpe1, 0) + cur_mKpe1 = txl.get_buffer(mKpe1, 0) + + cur_mQK0 = txl.get_buffer(mQK0, 0) + cur_mQK1 = txl.get_buffer(mQK1, 0) + cur_mPV0_L = txl.get_buffer(mPV0_L, 0) + cur_mPV0_R = txl.get_buffer(mPV0_R, 0) + cur_mPV1_L = txl.get_buffer(mPV1_L, 0) + cur_mPV1_R = txl.get_buffer(mPV1_R, 0) + cur_mQ0 = txl.get_buffer(mQ0, 0) + cur_mQ1 = txl.get_buffer(mQ1, 0) + + # TODO: 初始化cur_Max,每次循环置零 + cur_Max = txl.get_buffer(bMax, 0) + cur_P0 = txl.get_buffer(bP0, 0) + + cur_P1 = txl.get_buffer(bP1, 0) + + cur_L0 = txl.get_buffer(bL0, 0) + cur_L1 = txl.get_buffer(bL1, 0) + + if txl.is_warpgroup([0]): + + phase = 1 + q_phase = 0 + + for off_z in range(begin_batch, end_batch): + + heads_per_kv = H // H_KV + q_base = off_z * (H * N_Q) + off_kvh * (heads_per_kv * N_Q) # [Z,H,N_Q,R0] =>[Z,KV_HEADS,q_heads_per_kv_heads*N_Q,R0] + qo_off = q_base + pid_m * BLOCK_M + kv_head_idx = off_kvh + kv_base = off_z * (H_KV * KV_SEQ_LEN) + kv_head_idx * KV_SEQ_LEN # [Z,KV_HEADS,KV_SEQ_LEN,R0] + kv_off = kv_base + + txl.mbar_expect(cur_mQ, BLOCK_M * D * 2) + txl.mbar_expect(cur_mQpe, BLOCK_M * PE_DIM * 2) + + txl.mbar_wait(cur_mQ0, q_phase^1) + txl.mbar_wait(cur_mQ1, q_phase^1) + + txl.tma_load(cur_bQ, desc_qhat, [qo_off, 0], cur_mQ) + txl.mbar_wait(cur_mQ, q_phase) + txl.tma_load(cur_bQpe, desc_qpe, [qo_off, 0], cur_mQpe) + txl.mbar_wait(cur_mQpe, q_phase) + + q_phase ^= 1 + + for _n in range(0, KV_SEQ_LEN, BLOCK_N*2): + + # pe only participate in QK, not PV + txl.mbar_wait(cur_mQK0, phase) + txl.mbar_expect(cur_mKpe0, BLOCK_N*PE_DIM*2) + txl.tma_load(cur_bKpe0, desc_kpe, [kv_off, 0], cur_mKpe0) + + txl.mbar_wait(cur_mQK1, phase) + txl.mbar_expect(cur_mKpe1, BLOCK_N*PE_DIM*2) + txl.tma_load(cur_bKpe1, desc_kpe, [kv_off+BLOCK_N, 0], cur_mKpe1) + + txl.mbar_wait(cur_mPV0_L, phase) + txl.mbar_expect(cur_mZL0, BLOCK_N*D//2*2) + txl.tma_load(cur_bZL0, desc_zkv, [kv_off, 0], cur_mZL0) + + 
txl.mbar_wait(cur_mPV1_R, phase) + txl.mbar_expect(cur_mZR1, BLOCK_N*D//2*2) + txl.tma_load(cur_bZR1, desc_zkv, [kv_off+BLOCK_N, D//2], cur_mZR1) + + txl.mbar_wait(cur_mPV0_R, phase) + txl.mbar_expect(cur_mZR0, BLOCK_N*D//2*2) + txl.tma_load(cur_bZR0, desc_zkv, [kv_off, D//2], cur_mZR0) + + txl.mbar_wait(cur_mPV1_L, phase) + txl.mbar_expect(cur_mZL1, BLOCK_N*D//2*2) + txl.tma_load(cur_bZL1, desc_zkv, [kv_off+BLOCK_N, 0], cur_mZL1) + + kv_off += 2*BLOCK_N + phase ^= 1 + + # wg1: QK0, P0 and PVl, pass P0 to wg1 + if txl.is_warpgroup([1]): # left consumer + + phase = 0 + q_phase = 0 + + for off_z in range(begin_batch, end_batch): + + heads_per_kv = H // H_KV + q_base = off_z * (H * N_Q) + off_kvh * (heads_per_kv * N_Q) # [Z,H,N_Q,R0] =>[Z,KV_HEADS,q_heads_per_kv_heads*N_Q,R0] + qo_off = q_base + pid_m * BLOCK_M + kv_head_idx = off_kvh + kv_base = off_z * (H_KV * KV_SEQ_LEN) + kv_head_idx * KV_SEQ_LEN # [Z,KV_HEADS,KV_SEQ_LEN,R0] + kv_off = kv_base + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + qk_scale = sm_scale * 1.44269504 + + m_i = tl.full([BLOCK_M], -float("inf"), tl.float32) + l_i = tl.full([BLOCK_M], 1.0, tl.float32) + accL= tl.zeros([BLOCK_M, D//2], dtype=tl.float32) + + txl.mbar_wait(cur_mQ, q_phase) + txl.mbar_wait(cur_mQpe, q_phase) + rQpe = txl.smem_load(cur_bQpe, A_op_layout) # bQpe is reused for bP1 + #txl.print("rQpe:", rQpe) + + cur_bQl = txl.smem_slice(cur_bQ, 0, D//2, 1) # NOTE: currently should only slice on 1 is "slice" + cur_bQr = txl.smem_slice(cur_bQ, D//2, D//2, 1) + + q_phase ^= 1 + + for _n in range(0, KV_SEQ_LEN, BLOCK_N*2): + txl.mbar_wait(cur_mKpe0, phase) + acc_s = tl.dot(rQpe, cur_bKpe0.T) # QKpe0 + txl.dot_wait(0) + + txl.mbar_arrive(cur_mQK0) # only the pe part of QK + + # TODO: slice pipeline + txl.mbar_wait(cur_mZL0, phase) # load of Kl0 + acc_s += tl.dot(cur_bQl, cur_bZL0.T) # QKl0 + txl.dot_wait(0) + + txl.mbar_wait(cur_mZR0, phase) # load of Kr0 + acc_s += tl.dot(cur_bQr, cur_bZR0.T) # QKr0 -> QK0 + txl.dot_wait(0) + + 
m_ij0 = tl.maximum(m_i, tl.max(acc_s, 1) * qk_scale) + alpha0 = tl.math.exp2(m_i - m_ij0) + + # frag_smem_store + #if idx_in_warpgroup % 4 == 0: # first col of mma layout + txl.frag_smem_store(cur_Max, m_ij0, layout=max_reg_layout) + #txl.smem_store(cur_Max, m_ij0) + + txl.bar_arrive(12, 256) # BAR 12 Max0 ready + + acc_s = acc_s * qk_scale - m_ij0[:, None] + p0 = tl.math.exp2(acc_s) + l_ij0 = tl.sum(p0, 1) + l_i = l_i * alpha0 + l_ij0 + accL = accL * alpha0[:, None] + m_i = m_ij0 + accL = tl.dot(p0.to(dtype), cur_bZL0, accL) # PVl0 + txl.dot_wait(0) + txl.mbar_arrive(cur_mPV0_L) + + txl.bar_wait(14, 256) # BAR 14 Max1 ready + + m_ij1 = txl.frag_smem_load(cur_Max, [BLOCK_M], max_reg_layout) # load Max1 + alpha1 = tl.math.exp2(m_i - m_ij1) + m_i = m_ij1 + + # rescale P0 + p0 = p0 * alpha1[:, None] + + txl.smem_store(cur_P0, p0.to(dtype)) # store P0 -> wg2 + + txl.bar_arrive(11, 256) # BAR 11 P0 ready + + + txl.mbar_wait(cur_mZL1, phase) # wait Vl1 + + # rescale accL + accL = accL * alpha1[:, None] + l_i = l_i * alpha1 # TODO + + txl.bar_wait(10, 256) # BAR 10 P1 ready + + txl.fence_proxy_async() + accL = tl.dot(cur_P1.to(dtype), cur_bZL1, accL) # PVl1 -> PVl + txl.dot_wait(0) + txl.mbar_arrive(cur_mPV1_L) + + phase ^= 1 + + #if idx_in_warpgroup % 4 == 0: # 0123 has same value + txl.frag_smem_store(cur_L0, l_i, layout=max_reg_layout) + #txl.smem_store(cur_L0, l_i) + + txl.bar_arrive(9, 256) # BAR 9 L0 ready + + txl.bar_wait(8, 256) # BAR 8 L1 ready + + L1_reg = txl.frag_smem_load(cur_L1, (64,), max_reg_layout) + l_i = l_i + L1_reg # l0+l1 + m_i += tl.math.log2(l_i) + accL = accL / l_i[:, None] + m_ptrs = M + off_z * (H * N_Q) + off_kvh * (heads_per_kv * N_Q) + offs_m + tl.store(m_ptrs, m_i) + + # reg -> smem -> gmem + txl.smem_store(cur_bZL0, accL.to(dtype)) # store to Vl0, which reused as PVl + txl.tma_store(cur_bZL0, desc_o_lat, [qo_off, 0]) + txl.mbar_arrive(cur_mQ0) + + # wg2: QK1, P1 and PVr + if txl.is_warpgroup([2]): # right consumer + + phase = 0 + q_phase = 
0 + + for off_z in range(begin_batch, end_batch): + + heads_per_kv = H // H_KV + q_base = off_z * (H * N_Q) + off_kvh * (heads_per_kv * N_Q) # [Z,H,N_Q,R0] =>[Z,KV_HEADS,q_heads_per_kv_heads*N_Q,R0] + qo_off = q_base + pid_m * BLOCK_M + kv_head_idx = off_kvh + kv_base = off_z * (H_KV * KV_SEQ_LEN) + kv_head_idx * KV_SEQ_LEN # [Z,KV_HEADS,KV_SEQ_LEN,R0] + kv_off = kv_base + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + qk_scale = sm_scale * 1.44269504 + + m_i = tl.full([BLOCK_M], -float("inf"), tl.float32) + l_i = tl.full([BLOCK_M], 1.0, tl.float32) + accR= tl.zeros([BLOCK_M, D//2], dtype=tl.float32) + + txl.mbar_wait(cur_mQ, q_phase) + txl.mbar_wait(cur_mQpe, q_phase) + rQpe = txl.smem_load(cur_bQpe, A_op_layout) + #txl.print("rQpe:", rQpe) + + cur_bQl = txl.smem_slice(cur_bQ, 0, D//2, 1) + cur_bQr = txl.smem_slice(cur_bQ, D//2, D//2, 1) + + q_phase ^= 1 + + for _n in range(0, KV_SEQ_LEN, BLOCK_N*2): + + txl.mbar_wait(cur_mKpe1, phase) + acc_s = tl.dot(rQpe, cur_bKpe1.T) # QKpe1 + txl.dot_wait(0) + txl.mbar_arrive(cur_mQK1) + + # TODO: slice pipeline + txl.mbar_wait(cur_mZR1, phase) + acc_s += tl.dot(cur_bQr, cur_bZR1.T) # QKr1 + txl.dot_wait(0) + + txl.mbar_wait(cur_mZL1, phase) + acc_s += tl.dot(cur_bQl, cur_bZL1.T) # QKl1 -> QK1 + txl.dot_wait(0) + + txl.bar_wait(12, 256) # BAR 12 Max0 ready + + m_ij0 = txl.frag_smem_load(cur_Max, [64], max_reg_layout) # max0 -> wg2 + + m_ij1 = tl.maximum(m_ij0, tl.max(acc_s, 1) * qk_scale) + alpha1 = tl.math.exp2(m_i - m_ij1) + acc_s = acc_s * qk_scale - m_ij1[:, None] + p1 = tl.math.exp2(acc_s) + l_ij1 = tl.sum(p1, 1) + l_i = l_i * alpha1 + l_ij1 + accR = accR * alpha1[:, None] + m_i = m_ij1 + + # frag_smem_store + #if idx_in_warpgroup % 4 == 0: + txl.frag_smem_store(cur_Max, m_ij1, layout=max_reg_layout) # max_all -> wg1 + #txl.smem_store(cur_Max, m_ij1) # max_all -> wg1 + + txl.bar_arrive(14, 256) # BAR 14 Max1 ready + + accR = tl.dot(p1.to(dtype), cur_bZR1, accR) # PVr1 + txl.dot_wait(0) + txl.mbar_arrive(cur_mPV1_R) 
+ + # 1111 + txl.bar_wait(11, 256) # BAR 11 P0 ready + + txl.mbar_wait(cur_mZR0, phase) # Vr0 + txl.fence_proxy_async() + accR = tl.dot(cur_P0.to(dtype), cur_bZR0, accR) # PVr0 -> PVr + txl.dot_wait(0) + txl.mbar_arrive(cur_mPV0_R) + + txl.smem_store(cur_P1, p1.to(dtype)) # save P1 + + #txl.bar_wait(15, 128) + txl.bar_arrive(10, 256) # BAR 10 P1 ready + + phase ^= 1 + + # frag_smem_store + #if idx_in_warpgroup % 4 == 0: + txl.frag_smem_store(cur_L1, l_i, layout=max_reg_layout) # l1 -> wg1, TODO: sync + #txl.smem_store(cur_L1, l_i) # l1 -> wg1, TODO: sync + + #txl.bar_wait(9, 128) + txl.bar_arrive(8, 256) # BAR 8 L1 ready + + txl.bar_wait(9, 256) # BAR 9 L0 ready + + L0_reg = txl.frag_smem_load(cur_L0, (64,), max_reg_layout) + l_i = l_i + L0_reg # l0+l1 + accR = accR / l_i[:, None] + + txl.smem_store(cur_bZL1, accR.to(dtype)) # reuse Vl1 for PVr + txl.tma_store(cur_bZL1, desc_o_lat, [qo_off, D//2]) + #txl.smem_store(cur_AccR, accR.to(dtype)) + #txl.tma_store(cur_AccR, desc_o_lat, [qo_off, D//2]) + txl.mbar_arrive(cur_mQ1) + +def mla_test(q, kv, qpe, kpe, sm_scale, algo=0): + HEAD_DIM_Q = q.shape[-1] + HEAD_DIM_Z = kv.shape[-1] + HEAD_DIM_PE = qpe.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_Z + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + o = torch.empty_like(q) + extra_kern_args = {} + + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + if supports_host_descriptor(): + # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor + y_dim = q.shape[0] * q.shape[1] * q.shape[2] + kv_dim = kv.shape[0] * kv.shape[1] * kv.shape[2] + + dummy_block = [1, 1] + desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_Z], strides=[HEAD_DIM_Z, 1], block_shape=dummy_block) + desc_kv = TensorDescriptor(kv, shape=[kv_dim, HEAD_DIM_Z], strides=[HEAD_DIM_Z, 1], block_shape=dummy_block) + desc_qpe = TensorDescriptor(qpe, shape=[y_dim, HEAD_DIM_PE], strides=[HEAD_DIM_PE, 1], 
block_shape=dummy_block) + desc_kpe = TensorDescriptor(kpe, shape=[kv_dim, HEAD_DIM_PE], strides=[HEAD_DIM_PE, 1], block_shape=dummy_block) + desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_Z], strides=[HEAD_DIM_Z, 1], block_shape=dummy_block) + else: + desc_q = q + desc_kv = kv + desc_kpe = kpe + desc_qpe = qpe + desc_o = o + + def alloc_fn(size: int, align: int, _): + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + # q_heads_per_kv_heads*N_Q // BLOCK_M + # [KV_HEADS, NUM_SMS] + q_heads_per_kv_heads = q.shape[1] // kv.shape[1] + total_q_seqlen = q_heads_per_kv_heads * q.shape[2] + + def grid(META): + return (triton.cdiv(total_q_seqlen, META["BLOCK_M"]), kv.shape[1]*NUM_SMS, 1) + + algo_map = { + 0: mla_txl, + } + + algo_map[algo][grid]( + sm_scale, M, + q.shape[0], q.shape[1], kv.shape[1], + desc_q, + desc_kv, + desc_o, + q.shape[2], kv.shape[2], + desc_qpe, desc_kpe, + PE_DIM=HEAD_DIM_PE, + D=HEAD_DIM_Z, + NUM_SMS=NUM_SMS, + **extra_kern_args + ) + + return o + +def ref_mla(q, kv, qpe, kpe, sm_scale): + Z, H, N_Q, R0 = q.shape + _, KV_HEADS, KV_SEQ_LEN, _ = kv.shape + PE_DIM = qpe.shape[-1] + + heads_per_kv = H // KV_HEADS + o = torch.empty_like(q) + m_all = torch.empty((Z, H, N_Q), device=q.device, dtype=torch.float32) + + for z in range(Z): + for h in range(H): + kv_head_idx = h // heads_per_kv + q_ = q[z, h] # [N_Q, R0] + qpe_ = qpe[z, h] # [N_Q, PE_DIM] + kv_ = kv[z, kv_head_idx] # [KV_SEQ_LEN, R0] + kpe_ = kpe[z, kv_head_idx] # [KV_SEQ_LEN, PE_DIM] + + qk = (q_ @ kv_.T) + (qpe_ @ kpe_.T) # [N_Q, KV_SEQ_LEN] + qk = qk * sm_scale * 1.44269504 # log2(e)=1.44269504 + m_i = torch.max(qk, dim=1).values + p = torch.pow(2, qk - m_i[:, None]) + l_i = torch.sum(p, dim=1) + m_all[z, h] = m_i + torch.log2(l_i) + o[z, h] = (p @ kv_) / l_i[:, None] + return o + +@contextmanager +def proton_context(): + proton.activate(0) + try: + yield + finally: + proton.deactivate(0) + +def bench_fn(label, reps, warmup_reps, fn, 
*args): + print(f"Benchmarking {label}: ...", end="") + for _ in range(warmup_reps): + fn(*args) + with proton_context(): + for _ in range(reps): + fn(*args) + print(f"\rBenchmarking {label}: done") + +def bench_op(Z, H, N_Q, KV_HEADS, KV_SEQ_LEN, R0, PE_DIM, dtype=torch.float16, algo=0, reps=1000, warmup_reps=1000): + q = (torch.randn((Z, H, N_Q, R0), dtype=dtype, device=DEVICE)) + kv = (torch.randn((Z, KV_HEADS, KV_SEQ_LEN, R0), dtype=dtype, device=DEVICE)) + qpe = (torch.randn((Z, H, N_Q, PE_DIM), dtype=dtype, device=DEVICE)) + kpe = (torch.randn((Z, KV_HEADS, KV_SEQ_LEN, PE_DIM), dtype=dtype, device=DEVICE)) + sm_scale = 1 / math.sqrt(R0) + + bench_fn( + f"mla Z{Z} H{H} NQ{N_Q} KH{KV_HEADS} KS{KV_SEQ_LEN} R0{R0} PE{PE_DIM} algo{algo}", + reps, + warmup_reps, + lambda q, kv, qpe, kpe, sm_scale, algo: mla_test(q, kv, qpe, kpe, sm_scale, algo), + q, kv, qpe, kpe, sm_scale, algo + ) + +def test_op(Z, H, N_Q, KV_HEADS, KV_SEQ_LEN, R0, PE_DIM, dtype=torch.float16, algo=0, no_tune=False): + q = (torch.randn((Z, H, N_Q, R0), dtype=dtype, device=DEVICE)) + kv = (torch.randn((Z, KV_HEADS, KV_SEQ_LEN, R0), dtype=dtype, device=DEVICE)) + qpe = (torch.randn((Z, H, N_Q, PE_DIM), dtype=dtype, device=DEVICE)) + kpe = (torch.randn((Z, KV_HEADS, KV_SEQ_LEN, PE_DIM), dtype=dtype, device=DEVICE)) + # q = (torch.ones((Z, H, N_Q, R0), dtype=dtype, device=DEVICE)) + # kv = (torch.ones((Z, KV_HEADS, KV_SEQ_LEN, R0), dtype=dtype, device=DEVICE)) + # kvu = (torch.ones((Z, KV_HEADS, KV_SEQ_LEN//2, R0), dtype=dtype, device=DEVICE)*0.5) + # kvd = (torch.ones((Z, KV_HEADS, KV_SEQ_LEN//2, R0), dtype=dtype, device=DEVICE)*2.0) + # kv = torch.cat([kvu, kvd], dim=2) + # qpe = (torch.ones((Z, H, N_Q, PE_DIM), dtype=dtype, device=DEVICE)) + # qpe1 = (torch.randn((Z, H, N_Q-20, PE_DIM), dtype=dtype, device=DEVICE)) + # qpe2 = (torch.ones((Z, H, 20, PE_DIM), dtype=dtype, device=DEVICE)) + # qpe = torch.cat([qpe2, qpe1], dim=2) + # kpe = (torch.ones((Z, KV_HEADS, KV_SEQ_LEN, PE_DIM), dtype=dtype, 
device=DEVICE)) + sm_scale = 1 / math.sqrt(R0) + + tri_out = mla_test(q, kv, qpe, kpe, sm_scale, algo=algo) + print("finish") + print(tri_out.shape) + print(f"triton out: {tri_out[:1, :1, :5, :5]}") + # debug_out = mla_test(q, kv, qpe, kpe, sm_scale, algo=5) + ref_out = ref_mla(q, kv, qpe, kpe, sm_scale) + + max_err = (tri_out - ref_out).abs().max().item() + # location = (tri_out - ref_out).abs().argmax().item() + # print(f"debug out: {debug_out}") + # print(f"ref out: {ref_out}") + # tri_out_around_location = tri_out.view(-1)[max(0, location-3):location+4] + # ref_out_around_location = ref_out.view(-1)[max(0, location-3):location+4] + # num_equal_locations = (tri_out - ref_out).abs() >= max_err * 0.05 + # max_err_indices = torch.nonzero(num_equal_locations) + # print(f"max err location indices: {max_err_indices}") + # print(f"tri around max err location: {tri_out_around_location}") + # print(f"ref around max err location: {ref_out_around_location}") + print(f"Z{Z} H{H} NQ{N_Q} KH{KV_HEADS} KS{KV_SEQ_LEN} R0{R0} PE{PE_DIM} | max err: {max_err:.6f}") + +def show_profile(profile_name): + import triton.profiler.viewer as proton_viewer + metric_names = ["time/ms"] + # if precision == 'fp8': + # metric_names = metric_names + ["tflop8/s"] + # elif precision == 'fp16': + # metric_names = metric_names + ["tflop16/s"] + file_name = f"{profile_name}.hatchet" + tree, metrics = proton_viewer.parse(metric_names, file_name) + proton_viewer.print_tree(tree, metrics) + +if __name__ == "__main__": + no_tune=True + + dump_dir="/workspace/dump/" + #dump_dir=None + print("TEST...") + from triton import knobs + + #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-pipeliner" + knobs.runtime.override_arch='sm90' + knobs.autotuning.print=True + knobs.compilation.always_compile=True + + if dump_dir: + knobs.compilation.dump_ir=True + knobs.cache.dump_dir=dump_dir + # knobs.compilation.override=True + # knobs.cache.override_dir=dump_dir + # NUM_SMS = 
torch.cuda.get_device_properties("cuda").multi_processor_count + # print(f"NUM_SMS: {NUM_SMS}") # 132 + test_op(132, 32, 64, 1, 256, 512, 64, algo=0, no_tune=no_tune) + + proton.start("mla", hook="triton") + proton.deactivate() + bench_op(132, 32, 64, 1, 256, 512, 64, algo=0, reps=100, warmup_reps=100) + proton.finalize() + show_profile("mla") diff --git a/docker/flash_mla/txl_nsa_interface.py b/docker/flash_mla/txl_nsa_interface.py new file mode 100644 index 0000000..41927fe --- /dev/null +++ b/docker/flash_mla/txl_nsa_interface.py @@ -0,0 +1,561 @@ +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor +import txl + +from typing import Optional, Tuple + +import torch + +GROUP_SIZE = tl.constexpr(8) +# NamedBarriers +wg0_bunch_0_ready = tl.constexpr(8) +wg1_bunch_0_ready = tl.constexpr(9) +wg0_s0_ready = tl.constexpr(10) +wg1_s1_ready = tl.constexpr(11) +sL_ready = tl.constexpr(12) +warpgroup0_sync = tl.constexpr(13) +warpgroup1_sync = tl.constexpr(14) +@txl.jit +#@txl.jit(src_file='dump/smem/SKCNG6F2XUQBA3XUKJ4ASFQM2E6HEJDVNB3J2MBCN23UCKHYIFJQ/txl_mla0.ptx') +#@txl.jit(diff_mode="llir", log_dir='dump/') +#@txl.jit(diff_mode="ttgir", diff_select=4, log_dir='dump/smem') +def txl_mla0( + q_nope_desc, q_pe_desc, + kv_nope_ptr, kv_pe_ptr, + o_desc, + max_logits_desc, lse_desc, + indices_ptr, + SCALE_LOG2: tl.constexpr, + + S_Q : tl.constexpr, + S_KV : tl.constexpr, + B_H : tl.constexpr, + D_Q : tl.constexpr, + D_V : tl.constexpr, + NUM_HEAD_BLOCKS : tl.constexpr, + TOPK : tl.constexpr, + B_TOPK : tl.constexpr, + STRIDE_KV_NOPE_0: tl.constexpr, + STRIDE_KV_PE_0: tl.constexpr, + ): + + tid = txl.tid(0) + idx_in_warpgroup = tid % 128 + pid = tl.program_id(0) + s_q_idx = pid // NUM_HEAD_BLOCKS + q_h_idx = pid % NUM_HEAD_BLOCKS + NUM_TOPK_BLOCKS: tl.constexpr = TOPK // B_TOPK + + + offs_q = pid * B_H + + q_nope_buf = txl.smem_alloc([B_H, 512], dtype=tl.bfloat16, num_stages=1); sQnope = txl.get_buffer(q_nope_buf, 0) + 
q_pe_buf = txl.smem_alloc([B_H, 64], dtype=tl.bfloat16, num_stages=1); sQpe = txl.get_buffer(q_pe_buf, 0) + mbar_q_nope_buf = txl.mbar_alloc(1, num_stages=1); mbar_q_nope = txl.get_buffer(mbar_q_nope_buf, 0) + mbar_q_pe_buf = txl.mbar_alloc(1, num_stages=1); mbar_q_pe = txl.get_buffer(mbar_q_pe_buf, 0) + + + #index_layout1d: tl.constexpr = txl.BlockedLayout([8], [32], [4], [0]) + #index_layout2d: tl.constexpr = txl.BlockedLayout([1, 8], [32, 1], [4, 1], [1, 0]) + index_layout2d: tl.constexpr = txl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) + + Knope_buf = txl.smem_alloc([B_TOPK, 512], dtype=tl.bfloat16, num_stages=2); + Kpe_buf = txl.smem_alloc([B_TOPK, 64], dtype=tl.bfloat16, num_stages=2); + sKnope0 = txl.get_buffer(Knope_buf, 0) + sKpe0 = txl.get_buffer(Kpe_buf, 0) + sKnope1 = txl.get_buffer(Knope_buf, 1) + sKpe1 = txl.get_buffer(Kpe_buf, 1) + + sS0 = txl.get_buffer(Kpe_buf, 0) # reuse + sS1_buf = txl.smem_alloc([B_H, B_TOPK], dtype=tl.bfloat16, num_stages=1); + sS1 = txl.get_buffer(sS1_buf, 0) + + #sO_buf = txl.smem_alloc([B_H, 512], dtype=tl.bfloat16, num_stages=1); + sO = txl.get_buffer(q_nope_buf, 0) # reuse Q + + is_kv_valid_layout0: tl.constexpr = txl.BlockedLayout([1, 1], [4, 8], [4, 1], [1, 0]) # 8 threads are for dilation + is_kv_valid_layout: tl.constexpr = txl.SliceLayout(dim=1, parent=is_kv_valid_layout0) # 8 threads are for dilation + is_kv_valid_buf = txl.smem_alloc([B_TOPK], dtype=tl.int8, num_stages=2); + is_kv_valid0 = txl.get_buffer(is_kv_valid_buf, 0) + is_kv_valid1 = txl.get_buffer(is_kv_valid_buf, 1) + + layout_acc: tl.constexpr = txl.NVMMADistributedLayout([3, 0], [4, 1], [16, 64, 16]) + + sM_buf = txl.smem_alloc([B_H], dtype=tl.float32, num_stages=1); sM = txl.get_buffer(sM_buf, 0) + sL_buf = txl.smem_alloc([B_H], dtype=tl.float32, num_stages=2); sL0 = txl.get_buffer(sL_buf, 0); sL1 = txl.get_buffer(sL_buf, 1) + layout_sM: tl.constexpr = txl.SliceLayout(dim=1, parent=layout_acc) + layout_row: tl.constexpr = txl.SliceLayout(dim=0, 
parent=layout_acc) + + # mma barriers + mbar_k0_free = txl.mbar_alloc(128, num_stages=2); mbar_k0_free0 = txl.get_buffer(mbar_k0_free, 0); mbar_k0_free1 = txl.get_buffer(mbar_k0_free, 1) + mbar_k0_ready = txl.mbar_alloc(128, num_stages=2); mbar_k0_ready0 = txl.get_buffer(mbar_k0_ready, 0); mbar_k0_ready1 = txl.get_buffer(mbar_k0_ready, 1) + mbar_k1_free = txl.mbar_alloc(128, num_stages=2); mbar_k1_free0 = txl.get_buffer(mbar_k1_free, 0); mbar_k1_free1 = txl.get_buffer(mbar_k1_free, 1) + mbar_k1_ready = txl.mbar_alloc(128, num_stages=2); mbar_k1_ready0 = txl.get_buffer(mbar_k1_ready, 0); mbar_k1_ready1 = txl.get_buffer(mbar_k1_ready, 1) + + mbar_is_kv_valid_ready_buf = txl.mbar_alloc(16, num_stages=1); mbar_is_kv_valid_ready = txl.get_buffer(mbar_is_kv_valid_ready_buf, 0) + + + if txl.is_warpgroup([0, 1]): + txl.reg_alloc(216) + if txl.is_warpgroup([0]): + txl.mbar_expect(mbar_q_nope, B_H*512*2) + txl.tma_load(sQnope, q_nope_desc, [offs_q, 0], mbar_q_nope) + txl.mbar_expect(mbar_q_pe, B_H*64*2) + txl.tma_load(sQpe, q_pe_desc, [offs_q, 0], mbar_q_pe) + + txl.mbar_wait(mbar_q_nope, 0) + txl.mbar_wait(mbar_q_pe, 0) + + rP = tl.zeros([B_H, B_TOPK], dtype=tl.float32) + rM = tl.zeros([B_H], dtype=tl.float32) - float("inf") # m_i + rL = tl.zeros([B_H], dtype=tl.float32) # l_i + rO = tl.zeros([B_H, 256], dtype=tl.float32) # D_V//2 + + cur_bar_wait_phase = 0 + + sKnope0l = txl.smem_slice(sKnope0, 0, 256, dim=1) + # B_TOPK, 256 + sV0l = sKnope0l + sKnope1l = txl.smem_slice(sKnope1, 0, 256, dim=1) + # B_TOPK, 256 + sV1l = sKnope1l + + sKnope0r = txl.smem_slice(sKnope0, 256, 256, dim=1) + # B_TOPK, 256 + sV0r = sKnope0r + sKnope1r = txl.smem_slice(sKnope1, 256, 256, dim=1) + # B_TOPK, 256 + sV1r = sKnope1r + + for block_idx in range(0, NUM_TOPK_BLOCKS, 2): + # iter (-1) rP0 = sQ @ sK0 + if txl.is_warpgroup([0]): + if block_idx == 0: + # pipelined_wait_and_qkt_gemm_l + txl.mbar_wait(mbar_k0_ready0, cur_bar_wait_phase) + # slice 0-3 + for tile_index in tl.static_range(0, 256, 
64): + cur_sQ = txl.smem_slice(sQnope, tile_index, 64, dim=1) + cur_sK = txl.smem_slice(sKnope0, tile_index, 64, dim=1) + rP = tl.dot(cur_sQ, cur_sK.T, rP) + + # pipelined_wait_and_qkt_gemm_r + txl.mbar_wait(mbar_k0_ready1, cur_bar_wait_phase) + # slice 4-7 + for tile_index in tl.static_range(256, 512, 64): + cur_sQ = txl.smem_slice(sQnope, tile_index, 64, dim=1) + cur_sK = txl.smem_slice(sKnope0, tile_index, 64, dim=1) + rP = tl.dot(cur_sQ, cur_sK.T, rP) + rP = tl.dot(sQpe, sKpe0.T, rP) # [B_H, 64] x [64, B_TOPK] + + txl.dot_wait(0) + + # rP1 = sQ @ sK1 + if txl.is_warpgroup([1]): + # pipelined_wait_and_qkt_gemm_r + txl.mbar_wait(mbar_k1_ready1, cur_bar_wait_phase) + # slice 4-7 + for tile_index in tl.static_range(256, 512, 64): + cur_sQ = txl.smem_slice(sQnope, tile_index, 64, dim=1) + cur_sK = txl.smem_slice(sKnope1, tile_index, 64, dim=1) + rP = tl.dot(cur_sQ, cur_sK.T, rP) + rP = tl.dot(sQpe, sKpe1.T, rP) # [B_H, 64] x [64, B_TOPK] + + # pipelined_wait_and_qkt_gemm_l + txl.mbar_wait(mbar_k1_ready0, cur_bar_wait_phase) + # slice 0-3 + for tile_index in tl.static_range(0, 256, 64): + cur_sQ = txl.smem_slice(sQnope, tile_index, 64, dim=1) + cur_sK = txl.smem_slice(sKnope1, tile_index, 64, dim=1) + rP = tl.dot(cur_sQ, cur_sK.T, rP) + + + txl.dot_wait(0) + + # mask_rP + txl.mbar_wait(mbar_is_kv_valid_ready, cur_bar_wait_phase) + if txl.is_warpgroup([0]): + #reg_is_kv_valid = txl.smem_load(is_kv_valid0, layout_acc) # mind the shape + reg_is_kv_valid = txl.smem_load(is_kv_valid0, layout_row) # mind the shape + else: + #reg_is_kv_valid = txl.smem_load(is_kv_valid1, layout_acc) # mind the shape + reg_is_kv_valid = txl.smem_load(is_kv_valid1, layout_row) # mind the shape + rP = tl.where(reg_is_kv_valid != 0, rP, float('-inf')) + + if txl.is_warpgroup([1]): + txl.bar_wait(wg0_bunch_0_ready, 256) # wait for wg0 to provide half of sM + + # online_softmax_and_rescale_o + #txl.mbar_wait(mbar_is_kv_valid_ready, cur_bar_wait_phase) + cur_max: tl.tensor = tl.max(rP, axis = 1) # 
reduce on 2 rows, and a warp reduce + cur_max *= SCALE_LOG2 + # wg0: rM load from sM in the last round + # wg1: rM load from sM written by wg0 + if txl.is_warpgroup([0]): + new_maxs = tl.maximum(rM, cur_max) # m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + else: + r_sM = txl.frag_smem_load(sM, [64], layout_sM) + new_maxs = tl.maximum(r_sM, cur_max) # m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + scale_for_o = tl.exp2(rM - new_maxs) # alpha = tl.math.exp2(m_i - m_ij) + rO *= scale_for_o[:, None] # broadcast, then mind the encoding + + rP = tl.exp2(rP * SCALE_LOG2 - new_maxs[:, None]) # qk = qk * qk_scale - m_ij[:, None]; p = tl.math.exp2(qk) + # TODO: warp_reduce + cur_sum = tl.sum(rP, axis=1) # l_ij = tl.sum(p, 1) + rL = rL * scale_for_o + cur_sum # l_i = l_i * alpha + l_ij + rS = rP.to(tl.bfloat16) # p = p.to(dtype) + + # wg0 save half, then wg1 save whole + # TODO: syncwarp? + if idx_in_warpgroup % 4 == 0: # only store for every 4 because they are the same + #txl.frag_smem_store(sM, new_maxs, layout_sM) + txl.smem_store(sM, new_maxs) # TODO: remove the layout of frag_smem_store + + rM = new_maxs # m_i = m_ij + + if txl.is_warpgroup([0]): + # 1. NamedBarriers, half sM for wg1 is ready + txl.bar_arrive(wg0_bunch_0_ready, 256) # inform wg1 that half sM is saved + rO = tl.dot(rS, sV0l, rO) # O0 += S0 @ V0l, [B_H, B_TOPK] x [B_TOPK, 256], use scaled rO and ready rS + txl.dot_wait(0) + txl.mbar_arrive(mbar_k0_free0) + + # 1.1. NamedBarriers, wait full sM from wg1 + txl.bar_wait(wg1_bunch_0_ready, 256) + new_rM = txl.frag_smem_load(sM, [64], layout_sM) + scale_factors = tl.exp2(rM - new_rM) + rM = new_rM + + # 1.2. scale_rS, wg0 need additional scale of new rM, rL and rO later + rS = (rP * scale_factors[:, None]).to(tl.bfloat16) + + # 2. save_rS_to_sS + txl.smem_store(sS0, rS) # check rS nvmma or not + txl.bar_arrive(wg0_s0_ready, 256) + + # 3. 
wait for sS1 + txl.bar_wait(wg1_s1_ready, 256) # wait for wg1 s1 + + # 3.1 rescale_rO, wg0 need additional scale of new rM, for rL and rO + rO *= scale_factors[:, None] # final scale + rL *= scale_factors # check layout + + # 3.2 wait for sS1 + rO = tl.dot(sS1, sV1l, rO) # O0 += S1 @ V1l, get the whole Vl + + if txl.is_warpgroup([1]): + # 1. NamedBarriers, updated other half of sM for wg0 is ready + txl.bar_arrive(wg1_bunch_0_ready, 256) # inform wg0 of whole sM and let it rescale + rO = tl.dot(rS, sV1r, rO) # O1 += S1 @ V1r, [B_H, B_TOPK] x [B_TOPK, 256], should not release k1 mbar, bcoz no dot wait + + # 2. save_rS_to_sS, and use sS0 + txl.smem_store(sS1, rS) # check rS nvmma or not + txl.bar_wait(wg0_s0_ready, 256) # get s0 from wg0 + rO = tl.dot(sS0, sV0r, rO) # O1 += S0 @ V0r + txl.bar_arrive(wg1_s1_ready, 256) # tell wg0 s1 is ready + + # 3. wait all dots for s vr + # sV1r + txl.dot_wait(1) + txl.mbar_arrive(mbar_k1_free1) + # sV0r + txl.dot_wait(0) + txl.mbar_arrive(mbar_k0_free1) + + cur_bar_wait_phase ^= 1 + + if txl.is_warpgroup([0]): + if block_idx + 2 < NUM_TOPK_BLOCKS: + # pipelined_wait_and_qkt_gemm_l + txl.mbar_wait(mbar_k0_ready0, cur_bar_wait_phase) + # slice 0-3 + for tile_index in tl.static_range(0, 256, 64): + cur_sQ = txl.smem_slice(sQnope, tile_index, 64, dim=1) + cur_sK = txl.smem_slice(sKnope0, tile_index, 64, dim=1) + rP = tl.dot(cur_sQ, cur_sK.T, rP) + + txl.dot_wait(1) + # mark sV1l as free + txl.mbar_arrive(mbar_k1_free0) + + # pipelined_wait_and_qkt_gemm_r + txl.mbar_wait(mbar_k0_ready1, cur_bar_wait_phase) + # slice 4-7 + for tile_index in tl.static_range(256, 512, 64): + cur_sQ = txl.smem_slice(sQnope, tile_index, 64, dim=1) + cur_sK = txl.smem_slice(sKnope0, tile_index, 64, dim=1) + rP = tl.dot(cur_sQ, cur_sK.T, rP) + rP = tl.dot(sQpe, sKpe0.T, rP) # [B_H, 64] x [64, B_TOPK] + + # The whole sQ @ sK0 + txl.dot_wait(0) + else: + txl.dot_wait(0) + # mark sV1l as free + txl.mbar_arrive(mbar_k1_free0) + + # After block_idx + + # reduce_L 
+ # TODO: reduce L on warp_reduce + if txl.is_warpgroup([0]): + if idx_in_warpgroup % 4 == 0: # only store for every 4 because they are the same + #txl.frag_smem_store(sL0, rL, layout_sM) + txl.smem_store(sL0, rL) + if txl.is_warpgroup([1]): + if idx_in_warpgroup % 4 == 0: # only store for every 4 because they are the same + #txl.frag_smem_store(sL1, rL, layout_sM) + txl.smem_store(sL1, rL) + # all finished sL store + txl.bar_wait(sL_ready, 256) + if txl.is_warpgroup([0]): + peer_L = txl.frag_smem_load(sL1, (64,), layout_sM) + else: + peer_L = txl.frag_smem_load(sL0, (64,), layout_sM) + rL += peer_L + + # store_O + scale_factors = tl.where(rL == 0.0, 1.0, 1.0/rL) + cur_rO = (rO * scale_factors[:, None]).to(tl.bfloat16) + #sO_index = 0 + #cur_sO = txl.smem_slice(sO, sO_index, 256, 1) + #for tile_index in range(0, 256, 64): + # cur_tile_sO = txl.smem_slice((cur_sO, 0, 64, 1) + # txl.smem_store(cur_tile_sO, cur_rO) # TODO: split cur_rO + # txl.bar_wait(warpgroup0_sync, 128) + # txl.tma_store(cur_tile_sO, o_desc, [offs_q, tile_index]) + + if txl.is_warpgroup([0]): + cur_sO = txl.smem_slice(sO, 0, 256, 1) + txl.smem_store(cur_sO, cur_rO) + txl.bar_wait(warpgroup0_sync, 128) + txl.tma_store(cur_sO, o_desc, [offs_q, 0]) + else: + cur_sO = txl.smem_slice(sO, 256, 256, 1) + txl.smem_store(cur_sO, cur_rO) + txl.bar_wait(warpgroup1_sync, 128) + txl.tma_store(cur_sO, o_desc, [offs_q, 256]) + + #txl.bar_wait(warpgroup0_sync, 128) + #o_desc.store(cur_rO, [offs_q, 0]) + + if txl.is_warpgroup([1]): + # save lse + final_max_logits = tl.where(rL == 0.0, float('-inf'), rM) + final_lse = tl.where(rL == 0.0, float('-inf'), tl.log2(rL) + rM) + max_logits_desc.store([offs_q], final_max_logits) + lse_desc.store([offs_q], final_lse) + + + if txl.is_warpgroup([2]): + txl.reg_dealloc(72) + + gIndices = indices_ptr + s_q_idx * TOPK + index_offs = tl.arange(0, B_TOPK * GROUP_SIZE) // GROUP_SIZE # 4 strided with stride 128//8, e.g. 
0 16 32 48 + + inc_index_offs = tl.reshape(tl.arange(0, B_TOPK * 64) % 64, (64, 64)) + inc_index_offs = txl.relayout(inc_index_offs, (64, 64), index_layout2d) # TODO: col 8x8, natually redundant or merged in reshape? + + cur_bar_wait_phase = 1 + + #for block_idx in range(0, 2, 2): + for block_idx in range(0, NUM_TOPK_BLOCKS, 2): + # buf_idx 0 + cur_index_nope_arr = () + cur_index_pe_arr = () + is_token_valid_arr = () + + for buf_idx in tl.static_range(0, 2): + #for buf_idx in tl.static_range(0, 1): + cur_g_indices = gIndices + (block_idx + buf_idx) * B_TOPK + + cur_index = tl.load(cur_g_indices + index_offs) # 4 strided with stride 128//8 + cur_index_nope0 = cur_index * STRIDE_KV_NOPE_0 # 4 strided with 128 + cur_index_pe0 = cur_index * STRIDE_KV_PE_0 + cur_index_nope0 = tl.broadcast_to(cur_index_nope0[:, None], (512, 8)) # 4x8 strided with 128 rows + cur_index_nope0 = tl.reshape(cur_index_nope0, (64, 64)) # make it col 8x8 + cur_index_nope0 = txl.relayout(cur_index_nope0, (64, 64), index_layout2d) # TODO: redundant? + cur_index_nope0 = cur_index_nope0 + inc_index_offs + cur_index_pe0 = tl.broadcast_to(cur_index_pe0[:, None], (512, 8)) + cur_index_pe0 = tl.reshape(cur_index_pe0, (64, 64)) + cur_index_pe0 = txl.relayout(cur_index_pe0, (64, 64), index_layout2d) + cur_index_pe0 = cur_index_pe0 + inc_index_offs + + + is_token_valid = (cur_index >=0) & (cur_index < S_KV) + is_token_valid0 = tl.broadcast_to(is_token_valid[:, None], (512, 8)) # each thread repeat 8 times + is_token_valid0 = tl.reshape(is_token_valid0, (64, 64)) # reshape auto makes it correct + + is_token_valid0 = txl.relayout(is_token_valid0, (64, 64), index_layout2d) #TODO: not working? 
+ + cur_index_nope_arr = cur_index_nope_arr + (cur_index_nope0,) + cur_index_pe_arr = cur_index_pe_arr + (cur_index_pe0,) + is_token_valid_arr = is_token_valid_arr + (is_token_valid0,) + + + + # V0l + txl.mbar_wait(mbar_k0_free0, cur_bar_wait_phase) + is_token_valid = is_token_valid_arr[0] + #txl.async_load_wait(0) + for tile_index in tl.static_range(0, 256, 64): + #for tile_index in tl.static_range(0, 64, 64): + cur_sKnope = txl.smem_slice(sKnope0, tile_index, 64, dim=1) + offs_kv = cur_index_nope_arr[0] + tile_index + txl.async_load(cur_sKnope, kv_nope_ptr+offs_kv, mask=is_token_valid, contiguity=8) # col 8x8 + #txl.async_load(cur_sKnope, kv_nope_ptr+offs_kv, contiguity=8) + #txl.async_load_wait(0) + txl.mbar_arrive(mbar_k0_ready0, track_async_op=True) + + # V1r + txl.mbar_wait(mbar_k1_free1, cur_bar_wait_phase) + is_token_valid = is_token_valid_arr[1] + for tile_index in tl.static_range(256, 512, 64): + #for tile_index in tl.static_range(0, 64, 64): + cur_sKnope = txl.smem_slice(sKnope1, tile_index, 64, dim=1) + offs_kv = cur_index_nope_arr[1] + tile_index + txl.async_load(cur_sKnope, kv_nope_ptr+offs_kv, mask=is_token_valid, contiguity=8) + #txl.async_load(cur_sKnope, kv_nope_ptr+offs_kv, contiguity=8) + #txl.async_load_wait(0) + offs_kv = cur_index_pe_arr[1] + txl.async_load(sKpe1, kv_pe_ptr+offs_kv, mask=is_token_valid, contiguity=8) + txl.mbar_arrive(mbar_k1_ready1, track_async_op=True) + + # V0r + txl.mbar_wait(mbar_k0_free1, cur_bar_wait_phase) + is_token_valid = is_token_valid_arr[0] + for tile_index in tl.static_range(256, 512, 64): + cur_sKnope = txl.smem_slice(sKnope0, tile_index, 64, dim=1) + offs_kv = cur_index_nope_arr[0] + tile_index + txl.async_load(cur_sKnope, kv_nope_ptr+offs_kv, mask=is_token_valid, contiguity=8) + offs_kv = cur_index_pe_arr[0] + txl.async_load(sKpe0, kv_pe_ptr+offs_kv, mask=is_token_valid, contiguity=8) + txl.mbar_arrive(mbar_k0_ready1, track_async_op=True) + + # V1l + txl.mbar_wait(mbar_k1_free0, cur_bar_wait_phase) + 
is_token_valid = is_token_valid_arr[1] + for tile_index in tl.static_range(0, 256, 64): + cur_sKnope = txl.smem_slice(sKnope1, tile_index, 64, dim=1) + offs_kv = cur_index_nope_arr[1] + tile_index + txl.async_load(cur_sKnope, kv_nope_ptr+offs_kv, mask=is_token_valid, contiguity=8) + txl.mbar_arrive(mbar_k1_ready0, track_async_op=True) + + if tid % 8 == 0: + is_token_valid0 = is_token_valid_arr[0] + is_token_valid1 = is_token_valid_arr[1] + # smem and layout are same shape + txl.frag_smem_store(is_kv_valid0, is_token_valid0.to(tl.int8), is_kv_valid_layout) # frag: only the first for each thread + txl.frag_smem_store(is_kv_valid1, is_token_valid1.to(tl.int8), is_kv_valid_layout) + txl.mbar_arrive(mbar_is_kv_valid_ready) + + cur_bar_wait_phase ^= 1 + +def txl_mla( + #q: torch.Tensor, + #kv: torch.Tensor, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_nope: torch.Tensor, + kv_pe: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + ): + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. 
Can only be 512 + + Returns: + (output, max_logits, lse) + About the definition of output, max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp + """ + #dump_dir='dump/smem/' + dump_dir = None + + from triton import knobs + import os + #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "tritongpu-remove-layout-conversions" + #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-smem-alloc-legalize" + #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-smem-alloc-layout-conversions" + #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-pipeliner" + knobs.runtime.override_arch='sm90' + #knobs.autotuning.print=True + #knobs.compilation.always_compile=True + + if dump_dir: + knobs.compilation.dump_ir=True + knobs.cache.dump_dir=dump_dir + + + B_H = 64 + B_TOPK = 64 + D_Q = 576 + D_K = D_Q + D_V = 512 + + s_q = q_nope.size(0) + s_kv = kv_nope.size(0) + top_k = indices.size(2) + h_q = q_nope.size(1) + + h_kv = kv_nope.size(1) + assert h_kv == 1 + d_qk = D_Q + d_v = D_V + + qk_scale = sm_scale * 1.44269504 + + #q_nope, q_pe = torch.split(q, [512, 64], dim=-1) + #q_nope = q_nope.contiguous() + #q_pe = q_pe.contiguous() + + #kv_nope, kv_pe = torch.split(kv, [512, 64], dim=-1) + #kv_nope = kv_nope.contiguous() + #kv_pe = kv_pe.contiguous() + + + out = torch.empty((s_q, h_q, d_v), dtype=q_nope.dtype, device=q_nope.device) + max_logits = torch.empty((s_q, h_q), dtype=torch.float32, device=q_nope.device) + lse = torch.empty((s_q, h_q), dtype=torch.float32, device=q_nope.device) + + q_nope_desc = TensorDescriptor(q_nope, (s_q*h_q, 512), (512, 1), [B_H, 512]) + q_pe_desc = TensorDescriptor(q_pe, (s_q*h_q, 64), (64, 1), [B_H, 64]) + o_desc = TensorDescriptor(out, (s_q*h_q, d_v), (d_v, 1), [B_H, D_V]) + max_logits_desc = TensorDescriptor(max_logits, (s_q*h_q, ), (1, ), [B_H]) + lse_desc = TensorDescriptor(lse, (s_q*h_q, ), (1, ), [B_H]) + + # TESTS + #out1 = torch.empty((s_q, h_q, d_v//2), 
dtype=q.dtype, device=q.device) + #out2 = torch.empty((s_q, h_q, d_v//2), dtype=q.dtype, device=q.device) + #o1_desc = TensorDescriptor(out1, (s_q*h_q, d_v//2), (d_v//2, 1), [B_H, D_V//2]) + #o2_desc = TensorDescriptor(out2, (s_q*h_q, d_v//2), (d_v//2, 1), [B_H, D_V//2]) + + NUM_HEAD_BLOCKS = h_q // B_H + txl_mla0[(NUM_HEAD_BLOCKS * s_q,)]( + q_nope_desc, q_pe_desc, + kv_nope, kv_pe, + o_desc, + max_logits_desc, lse_desc, + indices, + + qk_scale, + s_q, s_kv, + B_H, D_Q, D_V, + NUM_HEAD_BLOCKS, + top_k, + B_TOPK, + kv_nope.stride(0), + kv_pe.stride(0), + num_warps=4, num_warpgroups=3) + return out, max_logits, lse diff --git a/docker/lib.py b/docker/lib.py new file mode 100644 index 0000000..f884721 --- /dev/null +++ b/docker/lib.py @@ -0,0 +1,73 @@ +from typing import List + +import torch + +def cdiv(x: int, y: int): + return (x+y-1) // y + +def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7): + """ + Check if two tensors are close enough + """ + def get_cos_diff(x: torch.Tensor, y: torch.Tensor) -> float: + """ + Calculate the cosine diff between two tensors + """ + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum().item() + if denominator == 0: + return 0 + sim = 2 * (x * y).sum().item() / denominator + return 1 - sim + assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}" + + ans = ans.clone().to(torch.float) + ref = ref.clone().to(torch.float) + + # Deal with anomalies + def deal_with_anomalies(val: float): + ref_mask = (ref == val) if (val == val) else (ref != ref) + ans_mask = (ans == val) if (val == val) else (ans != ans) + ref[ref_mask] = 0.0 + ans[ans_mask] = 0.0 + if not torch.equal(ref_mask, ans_mask): + print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref") + return False + return True + + anomalies_check_passed = True + anomalies_check_passed &= 
deal_with_anomalies(float("inf")) + anomalies_check_passed &= deal_with_anomalies(float("-inf")) + anomalies_check_passed &= deal_with_anomalies(float("nan")) + + if not anomalies_check_passed: + return False + + cos_diff = get_cos_diff(ans, ref) + raw_abs_err = torch.abs(ans-ref) + raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6)) + rel_err = raw_rel_err.masked_fill(raw_abs_err < abs_tol, 0.0) + pass_mask = rel_err < rel_tol + if not torch.all(pass_mask): + abs_err = raw_abs_err + max_abs_err_pos = torch.argmax(abs_err).item() + max_rel_err_pos = torch.argmax(rel_err).item() + def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]: + result = [] + for size in t.shape[::-1]: + result.append(pos % size) + pos = pos // size + assert pos == 0 + return result[::-1] + print(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}") + print(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}") + print(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)") + print(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})") + return False + else: + if abs(cos_diff) > cos_diff_tol: + print(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})") + return False + return True \ No newline at end of file diff --git a/docker/test_flash_mla_decoding.py b/docker/test_flash_mla_decoding.py new file mode 100644 index 0000000..5487694 --- /dev/null +++ b/docker/test_flash_mla_decoding.py @@ -0,0 +1,387 @@ +import argparse +import math +import random +import dataclasses +from typing import Optional, Tuple + +import torch +import triton + +import flash_mla +import sys as quant +from lib import cdiv, check_is_allclose + +@dataclasses.dataclass +class TestParam: + b: int # Batch size + s_q: int # Number of queries for one request + s_k: int # Seq len, or mean seq len if varlen == True + is_varlen: bool + is_causal: bool + is_fp8: bool + topk: Optional[int] = None + test_performance: bool = True + is_all_indices_invalid: 
bool = False + have_zero_seqlen_k: bool = False + block_size: int = 64 + h_q: int = 128 # Number of q heads + h_kv: int = 1 # Number of kv heads + d: int = 576 # Q/K head dim (= dv + RoPE dim) + dv: int = 512 # V head dim + seed: int = 0 + + +def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Generate test data from a given configuration + Return: [cache_seqlens, q, block_table, blocked_k] + Pay attention: This function changes the random seed + """ + random.seed(t.seed) + torch.manual_seed(t.seed) + torch.cuda.manual_seed(t.seed) + torch.backends.cudnn.deterministic = True + + assert t.h_q % t.h_kv == 0 + + cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device='cpu') + if t.is_varlen: + for i in range(t.b): + cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q) + + if t.have_zero_seqlen_k: + zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0 + cache_seqlens_cpu[zeros_mask] = 0 + + max_seqlen = cache_seqlens_cpu.max().item() + max_seqlen_pad = cdiv(max_seqlen, 256) * 256 #256 + cache_seqlens = cache_seqlens_cpu.cuda() # [132] each value is 256 + + q_value = torch.randn(t.b, t.h_q, t.s_q, 512).clamp(-1.0, 1.0) + q_pe = torch.randn(t.b, t.h_q, t.s_q, 64).clamp(-1.0, 1.0) + q = torch.cat([q_value, q_pe], dim=-1).permute(0, 2, 1, 3).contiguous() # [132, 64, 128, 576] + # q = torch.randn(t.b, t.s_q, t.h_q, t.d) # [132, 64, 128, 576] + # q.clamp_(min=-1.0, max=1.0) + + block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size) # [132, 4] + # block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(t.b, -1) + k_value = (torch.randn(t.b, t.h_kv, max_seqlen_pad, 512) / 10).clamp(-1.0, 1.0) + k_pe = (torch.randn(t.b, t.h_kv, max_seqlen_pad, 64) / 10).clamp(-1.0, 1.0) + k = torch.cat([k_value, k_pe], dim=-1).permute(0, 2, 1, 
3).contiguous() # [132, max_seqlen_pad, 1, 576] + blocked_k = k.view(t.b * (max_seqlen_pad // t.block_size), t.block_size, t.h_kv, t.d) # [?, block_size, h_kv, d] + assert blocked_k.size(0) == block_table.numel() + # blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10 + # blocked_k.clamp_(min=-1.0, max=1.0) + + if t.topk is None: + for i in range(t.b): + cur_len = cache_seqlens_cpu[i].item() + cur_num_blocks = cdiv(cur_len, t.block_size) + blocked_k[block_table[i][cur_num_blocks:]] = float("nan") + if cur_len % t.block_size != 0: + blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan") + block_table[i][cur_num_blocks:] = 2147480000 + return cache_seqlens, q, block_table, blocked_k, None, None, q_value, k_value, q_pe, k_pe + else: + block_table_cpu = block_table.cpu() + abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu") + indices_in_kvcache = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu") + for i in range(t.b): + # Generate indices + for j in range(t.s_q): + cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk] + cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size) + if len(cur_abs_indices) < t.topk: + pad_len = t.topk - len(cur_abs_indices) + cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')]) + cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')]) + + # Mask KV + perm = torch.randperm(t.topk, device='cpu') + cur_abs_indices = cur_abs_indices[perm] + cur_blocked_indices = cur_blocked_indices[perm] + + # Fill it with invalid indices if needed + if t.is_all_indices_invalid: + cur_abs_indices.fill_(-1) + cur_blocked_indices.fill_(-1) + + abs_indices[i, j, :] = cur_abs_indices + indices_in_kvcache[i, j, :] = cur_blocked_indices + + # Mask nonused KV as NaN + all_indices = 
indices_in_kvcache.flatten().tolist() + all_indices = list(set(all_indices)) + if -1 in all_indices: + all_indices.remove(-1) + all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu') + + blocked_k = blocked_k.view(-1, t.h_kv, t.d) + nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu') + nonused_indices_mask[all_indices] = False + blocked_k[nonused_indices_mask, :, :] = float("nan") + blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d) + + abs_indices = abs_indices.to(q.device) + indices_in_kvcache = indices_in_kvcache.to(q.device) + + return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache, q_value, k_value, q_pe, k_pe + + +def reference_torch( + cache_seqlens: torch.Tensor, # [batch_size] + block_table: torch.Tensor, # [batch_size, ?] + q: torch.Tensor, # [batch_size, s_q, h_q, d] + blocked_k: torch.Tensor, # [?, block_size, h_kv, d] + dv: int, + is_causal: bool, + indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A reference implementation in PyTorch + """ + def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): + mask = torch.zeros(s_q, s_k, dtype=torch.bool) + for i in range(s_q): + cur_indices = indices[i] + valid_indices = cur_indices[cur_indices != -1] + mask[i, valid_indices] = True + return mask + + def scaled_dot_product_attention( + batch_idx: int, + query: torch.Tensor, # [h_q, s_q, d] + kv: torch.Tensor, # [h_kv, s_k, d] + dv: int, + is_causal, + indices: Optional[torch.Tensor], # [s_q, topk] + ) -> Tuple[torch.Tensor, torch.Tensor]: + h_q = query.size(0) + h_kv = kv.size(0) + s_q = query.shape[-2] + s_k = kv.shape[-2] + query = query.float() + kv = kv.float() + if h_kv != 1: + kv = kv.repeat_interleave(h_q // h_kv, dim=0) + kv[kv != kv] = 0.0 + attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] + if (is_causal and query.size(1) > 1) or indices is not None: + mask = 
torch.ones(s_q, s_k, dtype=torch.bool) + if is_causal: + assert indices is None + mask = mask.tril(diagonal=s_k - s_q) + if indices is not None: + mask &= get_topk_attn_mask(s_q, s_k, indices) + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float) + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + attn_weight += attn_bias.to(q.dtype) + attn_weight /= math.sqrt(query.size(-1)) + lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv] + # Correct for q tokens which has no attendable k + lonely_q_mask = (lse == float("-inf")) + output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 + lse[lonely_q_mask] = float("+inf") + + return output, lse + + b, s_q, h_q, d = q.size() + block_size = blocked_k.size(1) + h_kv = blocked_k.size(2) + cache_seqlens_cpu = cache_seqlens.cpu() + out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + cur_len = cache_seqlens_cpu[i].item() + cur_num_blocks = cdiv(cur_len, block_size) + cur_block_indices = block_table[i][0: cur_num_blocks] + cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...] 
+ cur_out, cur_lse = scaled_dot_product_attention( + i, + q[i].transpose(0, 1), + cur_kv.transpose(0, 1), + dv, + is_causal, + indices[i] if indices is not None else None + ) + out_ref[i] = cur_out.transpose(0, 1) + lse_ref[i] = cur_lse + out_ref = out_ref.to(torch.bfloat16) + return out_ref, lse_ref + + +@torch.inference_mode() +def test_flash_mla(t: TestParam): + print('-------------------------------') + print(f"Running on {t}...") + + # Generating test data + torch.cuda.synchronize() + cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache, q_value, k_value, q_pe, k_pe = generate_test_data(t) + + if t.is_fp8: + # The quantization error may be too large to be distinguished from wrong kernels + # So we quantize and de-quantize kv-cache here to mitigate quantization error + blocked_k_quantized = quant.quantize_k_cache(blocked_k, t.dv, 128) + blocked_k_dequantized = quant.dequantize_k_cache(blocked_k_quantized) + blocked_k = blocked_k_dequantized + + # Get schedule metadata + torch.cuda.synchronize() + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata( + cache_seqlens, + t.s_q * t.h_q // t.h_kv, + t.h_kv, + t.h_q, + t.is_fp8, + t.topk + ) + + # q_txl = torch.randn((t.b, t.h_q, t.s_q, 512), dtype=torch.float16, device='cuda').clamp(-1.0, 1.0) + # q_pe_txl = (torch.randn((t.b, t.h_q, t.s_q, 64), dtype=torch.float16, device='cuda') / 10).clamp(-1.0, 1.0) + # kv_txl = torch.randn((t.b, t.h_kv, t.s_k, 512), dtype=torch.float16, device='cuda').clamp(-1.0, 1.0) + # kv_pe_txl = (torch.randn((t.b, t.h_kv, t.s_k, 64), dtype=torch.float16, device='cuda') / 10).clamp(-1.0, 1.0) + + torch.cuda.synchronize() + + def run_flash_mla(): + return flash_mla.flash_mla_with_kvcache( + q, + blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore + block_table, + cache_seqlens, + t.dv, + tile_scheduler_metadata, + num_splits, + causal=t.is_causal, + is_fp8_kvcache=t.is_fp8, + indices=indices_in_kvcache + ) + + def txl_mla(): + return 
flash_mla.mla_test( + q_value, + k_value, + q_pe, + k_pe, + 1 / math.sqrt(576), + algo = 0, + ) + + out_ans, lse_ans = run_flash_mla() + txl_ans_out = txl_mla().permute(0, 2, 1, 3).contiguous() + out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) + assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) + assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536) + assert check_is_allclose("txl_out", txl_ans_out, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) + print("Correctness check passed!") + print(f"ref_out result sample: {out_ref[0, 0, 0, :8]}") + print(f"flash_mla_out result sample: {out_ans[0, 0, 0, :8]}") + print(f"txl_mla_out result sample: {txl_ans_out[0, 0, 0, :8]}") + print("===============================") + print("Running performance test...") + if t.test_performance: + time_usage_txl: float = triton.testing.do_bench(txl_mla) / 1000 # type: ignore + time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore + mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk + compute_volume_flop = t.b * t.h_q * t.s_q * sum([ + 2 * t.d * mean_attended_seqlens, # Q * K^T + 2 * mean_attended_seqlens * t.dv, # attention * V + ]) + q_elem_size = torch.bfloat16.itemsize + kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize + memory_volume_B = t.b * sum([ + t.s_q * t.h_q * (t.d * q_elem_size), # Q + (t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V + t.s_q * t.h_q * (t.dv * q_elem_size), # Output + ]) + achieved_tflops = compute_volume_flop / time_usage / 1e12 + achieved_gBps = memory_volume_B / time_usage / 1e9 + achieved_tflops_txl = compute_volume_flop / time_usage_txl / 1e12 + achieved_gBps_txl = memory_volume_B / time_usage_txl / 1e9 + + print(f"FLASH MLA: {time_usage * 1000:.3f} ms, 
{achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") + print(f"TXL MLA: {time_usage_txl * 1000:.3f} ms, {achieved_tflops_txl:.0f} TFLOPS, {achieved_gBps_txl:.0f} GB/s") + + +def main(torch_dtype): + device = torch.device("cuda:0") + torch.set_default_dtype(torch_dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + + correctness_cases = [ + TestParam(b, s_q, s_k, is_varlen, is_causal, is_fp8, topk, test_performance=False) + for b in [1, 2, 6, 64] + for s_q in [1, 2, 4] + for s_k in [20, 140, 4096] + for is_varlen in [False, True] + for is_causal in [False, True] + for (is_fp8, topk) in [ + (False, None), + (True, 128), + (True, 2048) + ] + if not (is_causal and topk is not None) + ] + + corner_cases = [ + # Cases where all topk indices are invalid + TestParam(128, 2, 4096, is_varlen=True, is_causal=False, is_fp8=True, topk=topk, test_performance=False, is_all_indices_invalid=True) + for topk in [128, 2048, 4096] + ] + [ + # Cases where some kv cache have zero length + TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=False, have_zero_seqlen_k=True) + for (is_causal, is_fp8, topk) in [ + (False, False, None), + (True, False, None), + (False, True, 128), + (False, True, 2048), + ] + ] + + performance_cases = [ + TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=True) + for (is_causal, is_fp8, topk) in [ + (False, False, None), + (True, False, None), + (False, True, 2048), + ] + for s_q in [1, 2] + for s_k in [4096, 8192, 16384, 32768] + ] + + testcases = correctness_cases + corner_cases + performance_cases + testcases = [ + TestParam(132, 64, 256, is_varlen=False, is_causal=False, is_fp8=False, topk=None, test_performance=True), + # TestParam(132, 2, 4096, is_varlen=False, is_causal=False, is_fp8=False, topk=None, test_performance=True) + ] + + # Prune out unsupported cases + cc_major, cc_minor = torch.cuda.get_device_capability() + 
if cc_major == 10: + testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)] + + for testcase in testcases: + test_flash_mla(testcase) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", + type=str, + choices=["bf16", "fp16"], + default="fp16", + help="Data type to use for testing (bf16 or fp16)", + ) + + args = parser.parse_args() + + torch_dtype = torch.bfloat16 + if args.dtype == "fp16": + torch_dtype = torch.float16 + + main(torch_dtype) diff --git a/docker/test_flash_mla_prefill.py b/docker/test_flash_mla_prefill.py new file mode 100644 index 0000000..5ada4f9 --- /dev/null +++ b/docker/test_flash_mla_prefill.py @@ -0,0 +1,293 @@ +import math +import time +from typing import Tuple +import random +import dataclasses + +import torch +import triton + +from flash_mla import flash_mla_sparse_fwd, txl_mla +from lib import check_is_allclose + +@dataclasses.dataclass +class TestParam: + b: int + s_q: int + s_kv: int + topk: int + h_q: int = 128 + h_kv: int = 1 + d_qk: int = 576 + d_v: int = 512 + seed: int = 0 + check_correctness: bool = True + benchmark: bool = True + +@dataclasses.dataclass +class Testcase: + t: TestParam + q: torch.Tensor + kv: torch.Tensor + indices: torch.Tensor + +def generate_testcase(t: TestParam) -> Testcase: + torch.manual_seed(t.seed) + torch.cuda.manual_seed(t.seed) + random.seed(t.seed) + q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10 + kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10 + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((t.b, t.s_q, t.h_kv, t.topk), t.s_kv, dtype=torch.int32) + for b in range(t.b): + for s in range(t.s_q): + for h in range(t.h_kv): + # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention + near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 + cur_indices = 
torch.randperm(t.s_kv)[:t.topk] + cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),)) + if len(cur_indices) < t.topk: + cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)]) + cur_indices = cur_indices[torch.randperm(t.topk)] + indices[b, s, h] = cur_indices + indices = indices.to(q.device) + + return Testcase( + t=t, + q=q, + kv=kv, + indices=indices + ) + +def get_flop(p: TestParam) -> float: + flop = 2 * sum([ + p.h_q * p.d_qk * p.topk, + p.h_q * p.d_v * p.topk + ]) * p.b * p.s_q + return flop + +def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: + return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e) + + assert p.b == 1 + indices = t.indices[0, :, 0, :] # [s_q, topk] + invalid_indices_mask = (indices < 0) | (indices >= p.s_kv) + qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk] + kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk] + + kvs = torch.index_select(kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()).view(p.s_q, p.topk, p.d_qk) # [s_q, topk, d_qk] + attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk] + attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf')) + attn_score *= sm_scale * math.log2(math.e) + max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q] + lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q] + attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk] + result = attn_score @ kvs[:, :, :p.d_v] + return (max_logits, lse, result) + +@torch.inference_mode() +def run_test(p: TestParam) -> bool: + print("================") + print(f"Running on {p}") + torch.cuda.empty_cache() + assert p.b == 1 + + t = generate_testcase(p) + sm_scale = 1 / math.sqrt(p.d_qk) + torch.cuda.synchronize() + + q = t.q.squeeze(0) + kv = t.kv.squeeze(0) + + q_nope, q_pe = torch.split(q, [512, 
64], dim=-1) + q_nope = q_nope.contiguous() + q_pe = q_pe.contiguous() + + kv_nope, kv_pe = torch.split(kv, [512, 64], dim=-1) + kv_nope = kv_nope.contiguous() + kv_pe = kv_pe.contiguous() + + def run_ans_fmla(): + return flash_mla_sparse_fwd( + t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale + ) + + def run_ans_txl(): + return txl_mla( + q_nope, q_pe, kv_nope, kv_pe, t.indices.squeeze(0), sm_scale=sm_scale + ) + + ans_out, ans_max_logits, ans_lse = run_ans_fmla() + ans_out_txl, ans_max_logits_txl, ans_lse_txl = run_ans_txl() + + torch.cuda.synchronize() + + if p.benchmark: + flop = get_flop(p) + prefill_ans_time: float = triton.testing.do_bench(run_ans_fmla, warmup=10, rep=20) / 1000 # type: ignore + prefill_flops = flop / prefill_ans_time / 1e12 + print(f"FlashMLA Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops") + prefill_ans_time_txl: float = triton.testing.do_bench(run_ans_txl, warmup=10, rep=20) / 1000 # type: ignore + prefill_flops_txl = flop / prefill_ans_time_txl / 1e12 + print(f"TXL MLA Prefill: {prefill_ans_time_txl * 1e6:4.0f} us, {prefill_flops_txl:.3f} TFlops") + + if p.check_correctness: + torch.cuda.synchronize() + ref_max_logits, ref_lse, ref_out = reference_torch(p, t, sm_scale) + torch.cuda.synchronize() + res = (ans_out-ref_out) + + is_correct = True + is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6) + is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536) + is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536) + + is_correct &= check_is_allclose("out_txl", ans_out_txl, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6) + is_correct &= check_is_allclose("max_logits_txl", ans_max_logits_txl, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536) + is_correct &= check_is_allclose("lse_txl", ans_lse_txl, ref_lse, abs_tol=1e-6, 
rel_tol=2.01 / 65536) + + return is_correct + else: + return True + +def main(): + device = torch.device("cuda:0") + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + + correctness_cases = [ + # Regular shapes + TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) + for s_kv, topk in [ + # Regular shapes + (128, 128), + (256, 256), + (512, 512), + + # Irregular shapes + (592, 128), + (1840, 256), + (1592, 384), + (1521, 512), + + # Irregular shapes with OOB TopK + (95, 128), + (153, 256), + (114, 384), + ] + for s_q in [ + 1, 62 + ] + ] + + corner_cases = [ + # In these cases, some blocks may not have any valid topk indices + TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) + for s_kv, topk in [ + (32, 2048), + (64, 8192) + ] + for s_q in [1, 1024] + ] + + performance_cases = [ + TestParam(1, s_q, s_kv, topk, h_q=128) + for s_q in [4096] + for s_kv in [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072] + for topk in [2048] + ] + + testcases = correctness_cases + corner_cases + performance_cases + testcases = [ + TestParam(1, 64, 128, 128, h_q=128, benchmark=True, check_correctness=True) + ] + + failed_cases = [] + for test in testcases: + if test.benchmark: + time.sleep(0.2) + is_correct = run_test(test) + if not is_correct: + failed_cases.append(test) + + if len(failed_cases) > 0: + print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") + for case in failed_cases: + print(f" {case}") + else: + print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") + +if __name__ == '__main__': + device = torch.device("cuda:0") + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + + correctness_cases = [ + # Regular shapes + TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) + for s_kv, topk in [ + # 
Regular shapes + (128, 128), + (256, 256), + (512, 512), + + # Irregular shapes + (592, 128), + (1840, 256), + (1592, 384), + (1521, 512), + + # Irregular shapes with OOB TopK + (95, 128), + (153, 256), + (114, 384), + ] + for s_q in [ + 1, 62 + ] + ] + + corner_cases = [ + # In these cases, some blocks may not have any valid topk indices + TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) + for s_kv, topk in [ + (32, 2048), + (64, 8192) + ] + for s_q in [1, 1024] + ] + + performance_cases = [ + TestParam(1, s_q, s_kv, topk, h_q=128) + for s_q in [4096] + for s_kv in [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072] + for topk in [2048] + ] + + testcases = correctness_cases + corner_cases + performance_cases + testcases = [ + TestParam(1, 64, 128, 128, h_q=128, benchmark=True, check_correctness=True) + ] + + failed_cases = [] + for test in testcases: + if test.benchmark: + time.sleep(0.2) + is_correct = run_test(test) + if not is_correct: + failed_cases.append(test) + + if len(failed_cases) > 0: + print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") + for case in failed_cases: + print(f" {case}") + else: + print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") From 5b99796abe87cd3410813d9ee58eb9cca4412875 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Tue, 18 Nov 2025 19:23:55 +0800 Subject: [PATCH 02/17] add modal file --- docker/modal_mla_decoding.py | 116 +++++++++++++++++++++++++++++++++++ docker/modal_mla_prefill.py | 115 ++++++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 docker/modal_mla_decoding.py create mode 100644 docker/modal_mla_prefill.py diff --git a/docker/modal_mla_decoding.py b/docker/modal_mla_decoding.py new file mode 100644 index 0000000..e648be0 --- /dev/null +++ b/docker/modal_mla_decoding.py @@ -0,0 +1,116 @@ +from modal import Image, App, Volume +import pathlib +local_dir = pathlib.Path(__file__).parent +root_dir = 
local_dir.parent +requirements_file = root_dir / "requirements.txt" +txl_wheel_file = local_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" + +test_file_lib = local_dir / "lib.py" +# test_file = local_dir / "test_flash_mla_prefill.py" +test_file = local_dir / "test_flash_mla_decoding.py" + +flash_mla_dir = local_dir / "flash_mla" + +app = App(name="txl") # Note: this is optional since Modal 0.57 +volume = Volume.from_name("txl-dump", create_if_missing=True) # create a cloud volume to store compiled dump files + +txl_image = ( + Image.debian_slim(python_version="3.12") + #Image.from_dockerfile(path="./Dockerfile") + .workdir("/workspace") + .add_local_file(txl_wheel_file, remote_path="/workspace/", copy=True) # copy the local code to the image + .run_commands( "ls .") + .pip_install_from_requirements(requirements_file) # local file not remote file + .run_commands( + "pip install /workspace/txl-3.4.0-cp312-cp312-linux_x86_64.whl", + ) + .add_local_file(test_file, remote_path="/workspace/test_txl.py", copy=False) # copy after image build, no need rebuild + .add_local_file(test_file_lib, remote_path="/workspace/lib.py", copy=False) + .add_local_dir(flash_mla_dir, remote_path="/workspace/flash_mla", copy=False) +) + +# Example function that uses the image +@app.function(gpu="H100", image=txl_image, timeout=60*60, + volumes={"/workspace/dump": volume}) +def run_demo(): + import subprocess, sys, os, torch, time + def get_gpu_type(): + + try: + # Execute nvidia-smi command to query GPU details + result = subprocess.run(['nvidia-smi', '-q'], capture_output=True, text=True, check=True) + output = result.stdout + + # Look for indicators of SXM or PCIe in the output + for line in output.split("\n"): + if "Product Name" in line: + print(line) + if 'H100' in line and 'HBM3' in line: + return True + except subprocess.CalledProcessError as e: + print(f"Error running nvidia-smi: {e}") + except FileNotFoundError: + print("nvidia-smi not found. 
Please ensure NVIDIA drivers are installed and in your PATH.") + return False + + def test_demo(): + os.makedirs("/workspace/dump", exist_ok=True) + logs_dir = pathlib.Path("/workspace/dump/logs") + logs_dir.mkdir(parents=True, exist_ok=True) + ts = time.strftime("%Y%m%d-%H%M%S") + log_path = logs_dir / f"mla-{ts}.log" + + env = os.environ.copy() + # env["TRITON_PRINT_AUTOTUNING"] = "0" + # env["TRITON_KERNEL_DUMP"] = "1" + # env["TRITON_DUMP_DIR"] = "/workspace/dump" + # env["TRITON_ALWAYS_COMPILE"] = "1" + # env["CUDA_LAUNCH_BLOCKING"] = "1" + + cmd = [sys.executable, "-u", "/workspace/test_txl.py"] + + with open(log_path, "w", buffering=1, encoding="utf-8", errors="replace") as f: + proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, env=env, bufsize=1 + ) + assert proc.stdout is not None + for line in proc.stdout: + print(line, end="") + f.write(line) + + rc = proc.wait() + + print(f"\n=== FULL LOG SAVED ===\n{log_path}\n") + if rc != 0: + raise SystemExit(rc) + + def import_cuBLAS_lib(): + import os, ctypes, pathlib + import nvidia.cublas, nvidia.cuda_runtime + + cublas_dir = (pathlib.Path(nvidia.cublas.__file__).parent / "lib").resolve() + cudart_dir = (pathlib.Path(nvidia.cuda_runtime.__file__).parent / "lib").resolve() + + os.environ["LD_LIBRARY_PATH"] = f"{cublas_dir}:{cudart_dir}:" + os.environ.get("LD_LIBRARY_PATH","") + + for name12, name in [("libcublas.so.12","libcublas.so"), + ("libcublasLt.so.12","libcublasLt.so")]: + src = cublas_dir / name12 + dst = cublas_dir / name + if src.exists() and not dst.exists(): + try: dst.symlink_to(name12) + except FileExistsError: pass + except PermissionError: pass + + ctypes.CDLL(str(cudart_dir / "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublasLt.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublas.so.12"), mode=ctypes.RTLD_GLOBAL) + + def test_demo2(): + from test_txl import main + import torch + 
main(torch.float16) + + import_cuBLAS_lib() + test_demo2() \ No newline at end of file diff --git a/docker/modal_mla_prefill.py b/docker/modal_mla_prefill.py new file mode 100644 index 0000000..c467b37 --- /dev/null +++ b/docker/modal_mla_prefill.py @@ -0,0 +1,115 @@ +from modal import Image, App, Volume +import pathlib +local_dir = pathlib.Path(__file__).parent +root_dir = local_dir.parent +requirements_file = root_dir / "requirements.txt" +txl_wheel_file = local_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" + +test_file_lib = local_dir / "lib.py" +test_file = local_dir / "test_flash_mla_prefill.py" +# test_file = local_dir / "test_flash_mla_decoding.py" + +flash_mla_dir = local_dir / "flash_mla" + +app = App(name="txl") # Note: this is optional since Modal 0.57 +volume = Volume.from_name("txl-dump", create_if_missing=True) # create a cloud volume to store compiled dump files + +txl_image = ( + Image.debian_slim(python_version="3.12") + #Image.from_dockerfile(path="./Dockerfile") + .workdir("/workspace") + .add_local_file(txl_wheel_file, remote_path="/workspace/", copy=True) # copy the local code to the image + .run_commands( "ls .") + .pip_install_from_requirements(requirements_file) # local file not remote file + .run_commands( + "pip install /workspace/txl-3.4.0-cp312-cp312-linux_x86_64.whl", + ) + .add_local_file(test_file, remote_path="/workspace/test_txl.py", copy=False) # copy after image build, no need rebuild + .add_local_file(test_file_lib, remote_path="/workspace/lib.py", copy=False) + .add_local_dir(flash_mla_dir, remote_path="/workspace/flash_mla", copy=False) +) + +# Example function that uses the image +@app.function(gpu="H100", image=txl_image, timeout=60*60, + volumes={"/workspace/dump": volume}) +def run_demo(): + import subprocess, sys, os, torch, time + def get_gpu_type(): + + try: + # Execute nvidia-smi command to query GPU details + result = subprocess.run(['nvidia-smi', '-q'], capture_output=True, text=True, check=True) + output = 
result.stdout + + # Look for indicators of SXM or PCIe in the output + for line in output.split("\n"): + if "Product Name" in line: + print(line) + if 'H100' in line and 'HBM3' in line: + return True + except subprocess.CalledProcessError as e: + print(f"Error running nvidia-smi: {e}") + except FileNotFoundError: + print("nvidia-smi not found. Please ensure NVIDIA drivers are installed and in your PATH.") + return False + + def test_demo(): + os.makedirs("/workspace/dump", exist_ok=True) + logs_dir = pathlib.Path("/workspace/dump/logs") + logs_dir.mkdir(parents=True, exist_ok=True) + ts = time.strftime("%Y%m%d-%H%M%S") + log_path = logs_dir / f"mla-{ts}.log" + + env = os.environ.copy() + # env["TRITON_PRINT_AUTOTUNING"] = "0" + # env["TRITON_KERNEL_DUMP"] = "1" + # env["TRITON_DUMP_DIR"] = "/workspace/dump" + # env["TRITON_ALWAYS_COMPILE"] = "1" + # env["CUDA_LAUNCH_BLOCKING"] = "1" + + cmd = [sys.executable, "-u", "/workspace/test_txl.py"] + + with open(log_path, "w", buffering=1, encoding="utf-8", errors="replace") as f: + proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, env=env, bufsize=1 + ) + assert proc.stdout is not None + for line in proc.stdout: + print(line, end="") + f.write(line) + + rc = proc.wait() + + print(f"\n=== FULL LOG SAVED ===\n{log_path}\n") + if rc != 0: + raise SystemExit(rc) + + def import_cuBLAS_lib(): + import os, ctypes, pathlib + import nvidia.cublas, nvidia.cuda_runtime + + cublas_dir = (pathlib.Path(nvidia.cublas.__file__).parent / "lib").resolve() + cudart_dir = (pathlib.Path(nvidia.cuda_runtime.__file__).parent / "lib").resolve() + + os.environ["LD_LIBRARY_PATH"] = f"{cublas_dir}:{cudart_dir}:" + os.environ.get("LD_LIBRARY_PATH","") + + for name12, name in [("libcublas.so.12","libcublas.so"), + ("libcublasLt.so.12","libcublasLt.so")]: + src = cublas_dir / name12 + dst = cublas_dir / name + if src.exists() and not dst.exists(): + try: dst.symlink_to(name12) + except FileExistsError: pass 
+ except PermissionError: pass + + ctypes.CDLL(str(cudart_dir / "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublasLt.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublas.so.12"), mode=ctypes.RTLD_GLOBAL) + + def test_demo2(): + from test_txl import main + main() + + import_cuBLAS_lib() + test_demo2() \ No newline at end of file From 592971e92f6614fcf68696309c7db719ed7a42e0 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Tue, 18 Nov 2025 21:04:05 +0800 Subject: [PATCH 03/17] tl.store cause bug --- docker/flash_mla/txl_mla_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/flash_mla/txl_mla_interface.py b/docker/flash_mla/txl_mla_interface.py index 1f6e6d2..28abedf 100644 --- a/docker/flash_mla/txl_mla_interface.py +++ b/docker/flash_mla/txl_mla_interface.py @@ -333,8 +333,8 @@ def mla_txl( # cutedsl 这里要求KV_SEQ_LEN是BLOCK_N*2的整数倍,因为 l_i = l_i + L1_reg # l0+l1 m_i += tl.math.log2(l_i) accL = accL / l_i[:, None] - m_ptrs = M + off_z * (H * N_Q) + off_kvh * (heads_per_kv * N_Q) + offs_m - tl.store(m_ptrs, m_i) + # m_ptrs = M + off_z * (H * N_Q) + off_kvh * (heads_per_kv * N_Q) + offs_m + # tl.store(m_ptrs, m_i) # reg -> smem -> gmem txl.smem_store(cur_bZL0, accL.to(dtype)) # store to Vl0, which reused as PVl From 54d750363b9b1d3daf8c8effec0c53579f65cd6b Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Thu, 20 Nov 2025 09:53:00 +0800 Subject: [PATCH 04/17] fix performance test problem --- docker/flash_mla/__init__.py | 4 +- docker/flash_mla/txl_mla_interface.py | 54 ++++++++++++++++ docker/flash_mla/txl_nsa_interface.py | 75 ++++++++++++++++++++++ docker/modal_mla_decoding.py | 4 +- docker/test_flash_mla_decoding.py | 39 ++++++----- docker/test_flash_mla_prefill.py | 17 +++-- python/txl/tutorials/01-matmul.py | 19 +++--- python/txl/tutorials/02-flash-attention.py | 10 +-- 8 files changed, 183 insertions(+), 39 deletions(-) diff --git 
a/docker/flash_mla/__init__.py b/docker/flash_mla/__init__.py index 84dc241..d099973 100644 --- a/docker/flash_mla/__init__.py +++ b/docker/flash_mla/__init__.py @@ -9,6 +9,6 @@ flash_mla_sparse_fwd ) -from flash_mla.txl_nsa_interface import txl_mla +from flash_mla.txl_nsa_interface import txl_mla, make_txl_mla_runner -from flash_mla.txl_mla_interface import mla_test +from flash_mla.txl_mla_interface import mla_test, make_mla_runner diff --git a/docker/flash_mla/txl_mla_interface.py b/docker/flash_mla/txl_mla_interface.py index 28abedf..a466091 100644 --- a/docker/flash_mla/txl_mla_interface.py +++ b/docker/flash_mla/txl_mla_interface.py @@ -510,6 +510,60 @@ def grid(META): return o +def make_mla_runner(q, kv, qpe, kpe, sm_scale, algo=0): + HEAD_DIM_Q = q.shape[-1] + HEAD_DIM_Z = kv.shape[-1] + HEAD_DIM_PE = qpe.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_Z + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + o = torch.empty_like(q) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), + device=q.device, dtype=torch.float32) + + if supports_host_descriptor(): + y_dim = q.shape[0] * q.shape[1] * q.shape[2] + kv_dim = kv.shape[0] * kv.shape[1] * kv.shape[2] + dummy_block = [1, 1] + desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_Z], strides=[HEAD_DIM_Z, 1], block_shape=dummy_block) + desc_kv = TensorDescriptor(kv, shape=[kv_dim, HEAD_DIM_Z], strides=[HEAD_DIM_Z, 1], block_shape=dummy_block) + desc_qpe = TensorDescriptor(qpe, shape=[y_dim, HEAD_DIM_PE], strides=[HEAD_DIM_PE, 1], block_shape=dummy_block) + desc_kpe = TensorDescriptor(kpe, shape=[kv_dim, HEAD_DIM_PE], strides=[HEAD_DIM_PE, 1], block_shape=dummy_block) + desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_Z], strides=[HEAD_DIM_Z, 1], block_shape=dummy_block) + else: + desc_q, desc_kv, desc_qpe, desc_kpe, desc_o = q, kv, qpe, kpe, o + + def alloc_fn(size: int, align: int, _): + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + 
q_heads_per_kv_heads = q.shape[1] // kv.shape[1] + total_q_seqlen = q_heads_per_kv_heads * q.shape[2] + + def grid(META): + return (triton.cdiv(total_q_seqlen, META["BLOCK_M"]), + kv.shape[1] * NUM_SMS, 1) + + algo_map = {0: mla_txl} + kern = algo_map[algo] + + def run_once(): + kern[grid]( + sm_scale, M, + q.shape[0], q.shape[1], kv.shape[1], + desc_q, desc_kv, desc_o, + q.shape[2], kv.shape[2], + desc_qpe, desc_kpe, + PE_DIM=HEAD_DIM_PE, + D=HEAD_DIM_Z, + NUM_SMS=NUM_SMS, + ) + return o + + return run_once + + def ref_mla(q, kv, qpe, kpe, sm_scale): Z, H, N_Q, R0 = q.shape _, KV_HEADS, KV_SEQ_LEN, _ = kv.shape diff --git a/docker/flash_mla/txl_nsa_interface.py b/docker/flash_mla/txl_nsa_interface.py index 41927fe..58acf27 100644 --- a/docker/flash_mla/txl_nsa_interface.py +++ b/docker/flash_mla/txl_nsa_interface.py @@ -559,3 +559,78 @@ def txl_mla( kv_pe.stride(0), num_warps=4, num_warpgroups=3) return out, max_logits, lse + +def make_txl_mla_runner( + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_nope: torch.Tensor, + kv_pe: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + dump_dir=None, +): + + from triton import knobs + + knobs.runtime.override_arch = "sm90" + # knobs.autotuning.print = True + # knobs.compilation.always_compile = True + + if dump_dir is not None: + knobs.compilation.dump_ir = True + knobs.cache.dump_dir = dump_dir + + B_H = 64 + B_TOPK = 64 + D_Q = 576 + D_K = D_Q + D_V = d_v + + s_q = q_nope.size(0) + s_kv = kv_nope.size(0) + top_k = indices.size(2) + h_q = q_nope.size(1) + h_kv = kv_nope.size(1) + assert h_kv == 1 + + d_qk = D_Q + d_v = D_V + + qk_scale = sm_scale * 1.44269504 + + out = torch.empty((s_q, h_q, d_v), dtype=q_nope.dtype, device=q_nope.device) + max_logits = torch.empty((s_q, h_q), dtype=torch.float32, device=q_nope.device) + lse = torch.empty((s_q, h_q), dtype=torch.float32, device=q_nope.device) + + q_nope_desc = TensorDescriptor(q_nope, (s_q * h_q, 512), (512, 1), [B_H, 512]) + q_pe_desc = 
TensorDescriptor(q_pe, (s_q * h_q, 64), (64, 1), [B_H, 64]) + o_desc = TensorDescriptor(out, (s_q * h_q, d_v), (d_v, 1), [B_H, D_V]) + max_logits_desc = TensorDescriptor(max_logits, (s_q * h_q,), (1,), [B_H]) + lse_desc = TensorDescriptor(lse, (s_q * h_q,), (1,), [B_H]) + + NUM_HEAD_BLOCKS = h_q // B_H + + grid = (NUM_HEAD_BLOCKS * s_q,) + + def runner(): + txl_mla0[grid]( + q_nope_desc, q_pe_desc, + kv_nope, kv_pe, + o_desc, + max_logits_desc, lse_desc, + indices, + qk_scale, + s_q, s_kv, + B_H, D_Q, D_V, + NUM_HEAD_BLOCKS, + top_k, + B_TOPK, + kv_nope.stride(0), + kv_pe.stride(0), + num_warps=4, + num_warpgroups=3, + ) + return out, max_logits, lse + + return runner diff --git a/docker/modal_mla_decoding.py b/docker/modal_mla_decoding.py index e648be0..1fdba31 100644 --- a/docker/modal_mla_decoding.py +++ b/docker/modal_mla_decoding.py @@ -35,7 +35,8 @@ def run_demo(): import subprocess, sys, os, torch, time def get_gpu_type(): - + sm = torch.cuda.get_device_properties("cuda").multi_processor_count + print(f"SM count: {sm}") try: # Execute nvidia-smi command to query GPU details result = subprocess.run(['nvidia-smi', '-q'], capture_output=True, text=True, check=True) @@ -112,5 +113,6 @@ def test_demo2(): import torch main(torch.float16) + get_gpu_type() import_cuBLAS_lib() test_demo2() \ No newline at end of file diff --git a/docker/test_flash_mla_decoding.py b/docker/test_flash_mla_decoding.py index 5487694..f8f200b 100644 --- a/docker/test_flash_mla_decoding.py +++ b/docker/test_flash_mla_decoding.py @@ -260,30 +260,32 @@ def run_flash_mla(): indices=indices_in_kvcache ) - def txl_mla(): - return flash_mla.mla_test( - q_value, - k_value, - q_pe, - k_pe, - 1 / math.sqrt(576), - algo = 0, - ) + runner = flash_mla.make_mla_runner(q_value, k_value, q_pe, k_pe, 1 / math.sqrt(576), algo = 0) + + # def txl_mla(): + # return flash_mla.mla_test( + # q_value, + # k_value, + # q_pe, + # k_pe, + # 1 / math.sqrt(576), + # algo = 0, + # ) out_ans, lse_ans = run_flash_mla() - 
txl_ans_out = txl_mla().permute(0, 2, 1, 3).contiguous() + txl_ans_out = runner().permute(0, 2, 1, 3).contiguous() out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536) assert check_is_allclose("txl_out", txl_ans_out, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) print("Correctness check passed!") - print(f"ref_out result sample: {out_ref[0, 0, 0, :8]}") - print(f"flash_mla_out result sample: {out_ans[0, 0, 0, :8]}") - print(f"txl_mla_out result sample: {txl_ans_out[0, 0, 0, :8]}") + # print(f"ref_out result sample: {out_ref[0, 0, 0, :8]}") + # print(f"flash_mla_out result sample: {out_ans[0, 0, 0, :8]}") + # print(f"txl_mla_out result sample: {txl_ans_out[0, 0, 0, :8]}") print("===============================") print("Running performance test...") if t.test_performance: - time_usage_txl: float = triton.testing.do_bench(txl_mla) / 1000 # type: ignore + torch.cuda.synchronize() time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk compute_volume_flop = t.b * t.h_q * t.s_q * sum([ @@ -299,6 +301,8 @@ def txl_mla(): ]) achieved_tflops = compute_volume_flop / time_usage / 1e12 achieved_gBps = memory_volume_B / time_usage / 1e9 + torch.cuda.synchronize() + time_usage_txl: float = triton.testing.do_bench(runner) / 1000 # type: ignore achieved_tflops_txl = compute_volume_flop / time_usage_txl / 1e12 achieved_gBps_txl = memory_volume_B / time_usage_txl / 1e9 @@ -355,8 +359,9 @@ def main(torch_dtype): testcases = correctness_cases + corner_cases + performance_cases testcases = [ - TestParam(132, 64, 256, is_varlen=False, is_causal=False, is_fp8=False, topk=None, test_performance=True), - # TestParam(132, 
2, 4096, is_varlen=False, is_causal=False, is_fp8=False, topk=None, test_performance=True) + # TestParam(132, 64, 256, is_varlen=False, is_causal=False, is_fp8=False, topk=None, test_performance=True), + TestParam(132, 2, s_k, is_varlen=False, is_causal=False, is_fp8=False, topk=None, test_performance=True) + for s_k in [1024, 2048, 4096, 8192, 16384] ] # Prune out unsupported cases @@ -364,7 +369,9 @@ def main(torch_dtype): if cc_major == 10: testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)] + import time for testcase in testcases: + time.sleep(0.2) test_flash_mla(testcase) diff --git a/docker/test_flash_mla_prefill.py b/docker/test_flash_mla_prefill.py index 5ada4f9..38699f6 100644 --- a/docker/test_flash_mla_prefill.py +++ b/docker/test_flash_mla_prefill.py @@ -7,7 +7,7 @@ import torch import triton -from flash_mla import flash_mla_sparse_fwd, txl_mla +from flash_mla import flash_mla_sparse_fwd, txl_mla, make_txl_mla_runner from lib import check_is_allclose @dataclasses.dataclass @@ -116,13 +116,15 @@ def run_ans_fmla(): t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale ) - def run_ans_txl(): - return txl_mla( - q_nope, q_pe, kv_nope, kv_pe, t.indices.squeeze(0), sm_scale=sm_scale - ) + runner = make_txl_mla_runner(q_nope, q_pe, kv_nope, kv_pe, t.indices.squeeze(0), sm_scale=sm_scale) + + # def run_ans_txl(): + # return txl_mla( + # q_nope, q_pe, kv_nope, kv_pe, t.indices.squeeze(0), sm_scale=sm_scale + # ) ans_out, ans_max_logits, ans_lse = run_ans_fmla() - ans_out_txl, ans_max_logits_txl, ans_lse_txl = run_ans_txl() + ans_out_txl, ans_max_logits_txl, ans_lse_txl = runner() torch.cuda.synchronize() @@ -131,7 +133,8 @@ def run_ans_txl(): prefill_ans_time: float = triton.testing.do_bench(run_ans_fmla, warmup=10, rep=20) / 1000 # type: ignore prefill_flops = flop / prefill_ans_time / 1e12 print(f"FlashMLA Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops") - prefill_ans_time_txl: float = 
triton.testing.do_bench(run_ans_txl, warmup=10, rep=20) / 1000 # type: ignore + torch.cuda.synchronize() + prefill_ans_time_txl: float = triton.testing.do_bench(runner, warmup=10, rep=20) / 1000 # type: ignore prefill_flops_txl = flop / prefill_ans_time_txl / 1e12 print(f"TXL MLA Prefill: {prefill_ans_time_txl * 1e6:4.0f} us, {prefill_flops_txl:.3f} TFlops") diff --git a/python/txl/tutorials/01-matmul.py b/python/txl/tutorials/01-matmul.py index ad1637c..3ca189a 100644 --- a/python/txl/tutorials/01-matmul.py +++ b/python/txl/tutorials/01-matmul.py @@ -1975,15 +1975,15 @@ def profile(M, N, K, dtype, log=False): parser.add_argument("-K", type=int, required=False, default=512) parser.add_argument("--K_range", type=int, nargs=2) parser.add_argument("--K_step", type=int, default=512) - parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16") + parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp8") args = parser.parse_args() - dump_dir='dump/0930mm_split/' - #dump_dir = None + # dump_dir='dump/0930mm_split/' + dump_dir = None from triton import knobs #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "tritongpu-remove-layout-conversions" - os.environ["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-pipeliner" + # os.environ["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-pipeliner" #knobs.runtime.override_arch='sm100' knobs.autotuning.print=True knobs.compilation.always_compile=True @@ -2005,14 +2005,17 @@ def profile(M, N, K, dtype, log=False): #validate(32, 32, 32, dtype) #validate(128, 128, 512, dtype, log=True) - validate(8192, 8192, args.K_range[0], dtype, log=True) + # validate(8192, 8192, args.K_range[0], dtype, log=True) #profile(8192, 8192, args.K_range[0], dtype) - exit() + # exit() + print(dtype) proton.start("matmul", hook="triton") #proton.deactivate() - for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): - bench(K, dtype) + # for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + # bench(K, dtype) + for k_seqlen in 
[256, 512, 1024, 2048, 4096, 8192, 16384]: + bench(k_seqlen, dtype) proton.finalize() show_profile(args.prec, "matmul") diff --git a/python/txl/tutorials/02-flash-attention.py b/python/txl/tutorials/02-flash-attention.py index fbbc285..dd8e5d5 100644 --- a/python/txl/tutorials/02-flash-attention.py +++ b/python/txl/tutorials/02-flash-attention.py @@ -2845,7 +2845,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune= TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') -BATCH, N_HEADS, HEAD_DIM = 16, 32, 128 +BATCH, N_HEADS, HEAD_DIM = 4, 32, 128 TORCH_HAS_FP8=False # vary seq length for fixed head and batch=4 @@ -2858,8 +2858,8 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune= configs.append( triton.testing.Benchmark( x_names=["N_CTX"], - #x_vals=[2**i for i in range(10, 15)], - x_vals=[2**i for i in range(14, 15)], + x_vals=[2**i for i in range(10, 15)], + # x_vals=[2**i for i in range(14, 15)], line_arg="provider", line_vals=(["triton-fp16"] if Has_TXL else []) + (["triton-fp8"] if TORCH_HAS_FP8 else []) + (["flash"] if HAS_FLASH else []), @@ -2948,7 +2948,7 @@ def run_test(algo=0, dump_dir=None): #test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=2, no_tune=no_tune, profiling=PROFILING) #test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=4, no_tune=no_tune, profiling=PROFILING) #test_op(1, 2, 1536, 128, False, dtype=torch.float16, algo=4, no_tune=no_tune, profiling=PROFILING) - test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=algo, no_tune=no_tune, profiling=PROFILING) + # test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=algo, no_tune=no_tune, profiling=PROFILING) print("BENCH...") bench_flash_attention.run(save_path=".", print_data=True, algo=algo, no_tune=no_tune) @@ -2956,4 +2956,4 @@ def run_test(algo=0, dump_dir=None): if __name__ == "__main__": #run_test(6, dump_dir='dump/fa1113') #run_test(5, dump_dir='dump/fa1117') - run_test(5) + run_test(4) From 
5842775f338d914628c1ba76390bdc5d57ca6e00 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Thu, 20 Nov 2025 12:32:37 +0800 Subject: [PATCH 05/17] add draw.py; add mla benchmark --- docker/cross_entropy.py | 3 +- docker/draw/draw_attention.py | 168 +++++ docker/draw/draw_gemm.py | 95 +++ docker/draw/draw_mla.py | 102 +++ docker/draw/draw_softmax.py | 94 +++ docker/modal_tilelang.py | 118 ++++ docker/tilelang/benchmark_mla.py | 650 ++++++++++++++++++++ docker/tilelang/example_mla_decode_paged.py | 408 ++++++++++++ docker/tilelang/profiler/__init__.py | 287 +++++++++ python/txl/tutorials/04-softmax.py | 78 +-- 10 files changed, 1963 insertions(+), 40 deletions(-) create mode 100644 docker/draw/draw_attention.py create mode 100644 docker/draw/draw_gemm.py create mode 100644 docker/draw/draw_mla.py create mode 100644 docker/draw/draw_softmax.py create mode 100644 docker/modal_tilelang.py create mode 100644 docker/tilelang/benchmark_mla.py create mode 100644 docker/tilelang/example_mla_decode_paged.py create mode 100644 docker/tilelang/profiler/__init__.py diff --git a/docker/cross_entropy.py b/docker/cross_entropy.py index fc376d6..743969c 100644 --- a/docker/cross_entropy.py +++ b/docker/cross_entropy.py @@ -53,4 +53,5 @@ def get_gpu_type(): from test_txl import test_softmax #test_softmax(size=16*1024) - test_softmax(M=32*1024, N=32*1024) + for i in [1, 4, 8, 16, 32]: + test_softmax(M=32*1024, N=i*1024) diff --git a/docker/draw/draw_attention.py b/docker/draw/draw_attention.py new file mode 100644 index 0000000..013b22e --- /dev/null +++ b/docker/draw/draw_attention.py @@ -0,0 +1,168 @@ +import numpy as np +import matplotlib.pyplot as plt + +# 横轴:上下四张图的 context length +ctx = np.array([1024, 2048, 4096, 8192, 16384]) +x = np.arange(len(ctx)) # 真正用于画图的位置:0,1,2,... 
+ +# ====================== 示例数据,按需替换 ====================== +# FP16, causal = False +fp16_nc_fa3 = np.array([570, 600, 610, 630, 640]) +fp16_nc_txl = np.array([484, 530, 565, 588, 599]) +fp16_nc_triton= np.array([390, 460, 500, 520, 540]) +fp16_nc_tile = np.array([447, 590, 610, 570, 600]) +fp16_nc_tk = np.array([453, 597, 599, 610, 590]) + +# FP16, causal = True Txl 还没有支持 causal attention +fp16_c_fa3 = np.array([420, 520, 620, 680, 650]) +fp16_c_tawa = np.array([380, 480, 580, 630, 620]) +fp16_c_triton = np.array([320, 420, 500, 550, 540]) +fp16_c_tile = np.array([340, 440, 520, 570, 560]) +fp16_c_tk = np.array([330, 430, 510, 560, 550]) + +# FP8, causal = False Txl 还没有支持 FP8 +fp8_nc_fa3 = np.array([550, 750, 850, 900, 880]) +fp8_nc_tawa = np.array([500, 700, 780, 820, 800]) +fp8_nc_triton = np.array([420, 600, 700, 730, 710]) +fp8_nc_tile = np.array([440, 620, 720, 760, 740]) +fp8_nc_tk = np.array([430, 610, 710, 750, 730]) + +# FP8, causal = True Txl 还没有支持 FP8 +fp8_c_fa3 = np.array([600, 800, 900, 950, 930]) +fp8_c_tawa = np.array([540, 720, 800, 840, 820]) +fp8_c_triton = np.array([450, 630, 720, 760, 740]) +fp8_c_tile = np.array([470, 650, 740, 780, 760]) +fp8_c_tk = np.array([460, 640, 730, 770, 750]) +# =============================================================== + +methods_up = ["FA3 (CUTLASS)", "Txl", "Triton", "TileLang", "ThunderKittens"] +methods_down = ["FA3 (CUTLASS)", "Txl", "Triton"] +colors = { + "FA3 (CUTLASS)": "#f1c40f", + "Txl": "#e74c3c", + "Triton": "#1abc9c", + "TileLang": "#ff6fb3", + "ThunderKittens": "#3498db", +} + +data = { + ("FP16, causal=False"): { + "FA3 (CUTLASS)": fp16_nc_fa3, + "Txl": fp16_nc_txl, + "Triton": fp16_nc_triton, + "TileLang": fp16_nc_tile, + "ThunderKittens": fp16_nc_tk, + }, + ("FP16, causal=True"): { + "FA3 (CUTLASS)": fp16_c_fa3, + "Txl": fp16_c_tawa, + "Triton": fp16_c_triton, + "TileLang": fp16_c_tile, + "ThunderKittens": fp16_c_tk, + }, + ("FP8, causal=False"): { + "FA3 (CUTLASS)": fp8_nc_fa3, + "Txl": 
fp8_nc_tawa, + "Triton": fp8_nc_triton, + }, + ("FP8, causal=True"): { + "FA3 (CUTLASS)": fp8_c_fa3, + "Txl": fp8_c_tawa, + "Triton": fp8_c_triton, + }, +} + +fig, axes = plt.subplots(2, 2, figsize=(12, 4), sharex=True) + +bar_width = 0.16 + +def plot_panel(ax, title, panel_data, ylim, yticks, methods): + count = len(methods) + offsets = (np.arange(count) - (count - 1) / 2) * bar_width + for i, m in enumerate(methods): + vals = panel_data[m] + if m == "FA3 (CUTLASS)": + ax.bar( + x + offsets[i], # 注意这里用 x,而不是 ctx + vals, + bar_width, + label=m, + color=colors[m], + hatch="//", + ) + else: + ax.bar( + x + offsets[i], + vals, + bar_width, + label=m, + color=colors[m], + ) + ax.set_title(title) + ax.set_ylim(*ylim) + ax.set_yticks(yticks) + ax.grid(axis="y", linestyle="--", alpha=0.4) + ax.set_xticks(x) + ax.set_xticklabels(ctx) # 只把刻度标签写成 1024, 2048... + +# 上排:FP16 +plot_panel( + axes[0, 0], + "FP16, causal=false", + data[("FP16, causal=False")], + ylim=(0, 800), + yticks=[0, 200, 400, 600, 800], + methods=methods_up, +) +plot_panel( + axes[0, 1], + "FP16, causal=true", + data[("FP16, causal=True")], + ylim=(0, 800), + yticks=[0, 200, 400, 600, 800], + methods=methods_up, +) + +# 下排:FP8 +plot_panel( + axes[1, 0], + "FP8, causal=false", + data[("FP8, causal=False")], + ylim=(0, 1000), + yticks=[0, 250, 500, 750, 1000], + methods=methods_down, +) +plot_panel( + axes[1, 1], + "FP8, causal=true", + data[("FP8, causal=True")], + ylim=(0, 1000), + yticks=[0, 250, 500, 750, 1000], + methods=methods_down, +) + +# axes[0, 0].set_ylabel("Throughput (TFLOPs/s)") +# axes[1, 0].set_ylabel("Throughput (TFLOPs/s)") +# axes[1, 0].set_xlabel("Context length") +# axes[1, 1].set_xlabel("Context length") + +handles, labels = axes[0, 0].get_legend_handles_labels() +fig.legend( + handles, + labels, + loc="upper center", + ncol=5, + bbox_to_anchor=(0.5, 0.98), +) + +# 全局坐标轴标签,只写一次,自动在整张图居中对齐 +fig.supylabel("Throughput (TFLOPs/s)") # 左侧垂直居中 +fig.supxlabel("Context length") # 底部水平居中 + 
+plt.subplots_adjust(top=0.82, bottom=0.12, left=0.08, right=0.98, + wspace=0.15, hspace=0.35) + +# plt.tight_layout() +plt.show() + +# batch4-head32-d128 \ No newline at end of file diff --git a/docker/draw/draw_gemm.py b/docker/draw/draw_gemm.py new file mode 100644 index 0000000..84a69c6 --- /dev/null +++ b/docker/draw/draw_gemm.py @@ -0,0 +1,95 @@ +import numpy as np +import matplotlib.pyplot as plt + +gemm_k = np.array([256, 512, 1024, 2048, 4096, 8192, 16384]) + +# ------- 示例数据,自行替换 ------- +cublas_fp16 = np.array([517, 626, 712, 717, 697, 680, 667]) +# tawa_fp16 = np.array([580, 760, 780, 770, 780, 760, 740]) +txl_fp16 = np.array([473, 615, 705, 739, 748, 701, 694]) +triton_fp16 = np.array([470, 603, 680, 678, 670, 640, 630]) +tile_fp16 = np.array([300, 420, 600, 690, 700, 720, 740]) +tk_fp16 = np.array([400, 680, 680, 709, 780, 788, 798]) + +cublas_fp8 = np.array([1300, 1400, 1500, 1550, 1500, 1500, 1400]) +tawa_fp8 = np.array([900, 1470, 1600, 1600, 1550, 1500, 1400]) +# fp8 运行不了 +txl_fp8 = np.array([900, 1470, 1600, 1600, 1550, 1500, 1400]) +triton_fp8 = np.array([600, 1000, 1450, 1500, 1500, 1500, 1400]) +tile_fp8 = np.array([700, 1400, 1500, 1600, 1550, 1500, 1400]) +tk_fp8 = np.array([600, 1200, 1600, 1500, 1500, 1450, 1400]) +# ------------------------------- + +theoretical_fp16 = 1000 +theoretical_fp8 = 2000 + +bar_width = 0.14 +x = np.arange(len(gemm_k)) + +fig, axes = plt.subplots(1, 2, figsize=(12, 3), sharey=False) + +# ------------ 左图 FP16 ------------ +ax = axes[0] +ax.axhline(theoretical_fp16, color='gray', linewidth=3, label='Theoretical Peak') + +ax.bar(x - 2*bar_width, cublas_fp16, bar_width, + label='cuBLAS', color='#f1c40f', hatch='//') +ax.bar(x - 1*bar_width, txl_fp16, bar_width, + label='Txl', color='#e74c3c') +ax.bar(x + 0*bar_width, triton_fp16, bar_width, + label='Triton', color='#1abc9c') +ax.bar(x + 1*bar_width, tile_fp16, bar_width, + label='TileLang', color='#ff6fb3') +ax.bar(x + 2*bar_width, tk_fp16, bar_width, + 
label='ThunderKittens', color='#3498db') + +ax.set_title('FP16') +ax.set_ylabel('Throughput (TFLOPs/s)') +ax.set_xticks(x) +ax.set_xticklabels(gemm_k) +ax.set_xlabel('GEMM K size') +ax.grid(axis='y', linestyle='--', alpha=0.4) + +left_ylim = 1200 +ax.set_ylim(0, left_ylim) +ax.set_yticks([0, 200, 400, 600, 800, 1000, 1200]) + +# ------------ 右图 FP8 ------------ +ax = axes[1] +ax.axhline(theoretical_fp8, color='gray', linewidth=3) + +ax.bar(x - 2*bar_width, cublas_fp8, bar_width, + color='#f1c40f', hatch='//') +ax.bar(x - 1*bar_width, txl_fp8, bar_width, + color='#e74c3c') +ax.bar(x + 0*bar_width, triton_fp8, bar_width, + color='#1abc9c') +ax.bar(x + 1*bar_width, tile_fp8, bar_width, + color='#ff6fb3') +ax.bar(x + 2*bar_width, tk_fp8, bar_width, + color='#3498db') + +ax.set_title('FP8') +ax.set_xticks(x) +ax.set_xticklabels(gemm_k) +ax.set_xlabel('GEMM K size') +ax.grid(axis='y', linestyle='--', alpha=0.4) + +# 关键:让 2000 所在位置与 1000 对齐 +right_ylim = theoretical_fp8 * left_ylim / theoretical_fp16 # = 2400 +ax.set_ylim(0, right_ylim) +ax.set_yticks(np.arange(0, 2001, 500)) # 0, 500, 1000, 1500, 2000 + +# ------------ 顶部 legend & 布局 ------------ +handles, labels = axes[0].get_legend_handles_labels() +fig.legend(handles, labels, loc='upper center', ncol=6, + bbox_to_anchor=(0.5, 0.98)) + +# plt.subplots_adjust(top=0.78, bottom=0.18, left=0.07, right=0.98, wspace=0.25) +plt.subplots_adjust(top=0.78, bottom=0.18, left=0.08, right=0.98, + wspace=0.15, hspace=0.35) + +# plt.tight_layout() +plt.show() + +# m=8192, n=8192 \ No newline at end of file diff --git a/docker/draw/draw_mla.py b/docker/draw/draw_mla.py new file mode 100644 index 0000000..9707245 --- /dev/null +++ b/docker/draw/draw_mla.py @@ -0,0 +1,102 @@ +import numpy as np +import matplotlib.pyplot as plt + +# 横轴:上下四张图的 context length +ctx = np.array([1024, 2048, 4096, 8192, 16384, 32768]) +x = np.arange(len(ctx)) # 真正用于画图的位置:0,1,2,... 
+ +# ====================== 示例数据,按需替换 ====================== +# FP16, causal = False +fp16_nc_fm = np.array([529, 587, 612, 634, 596, 621]) +fp16_nc_txl = np.array([481, 527, 551, 550, 572, 581]) +fp16_nc_triton= np.array([20, 78, 98, 112, 136, 150]) +fp16_nc_tile = np.array([310, 402, 423, 453, 487, 476]) +fp16_nc_fi = np.array([290, 320, 360, 387, 350, 362]) +# =============================================================== + +methods = ["FlashMLA", "Txl", "Triton", "TileLang", "Flashinfer"] + +colors = { + "FlashMLA": "#f1c40f", + "Txl": "#e74c3c", + "Triton": "#1abc9c", + "TileLang": "#ff6fb3", + "Flashinfer": "#3498db", +} + +data = { + ("FP16, causal=False"): { + "FlashMLA": fp16_nc_fm, + "Txl": fp16_nc_txl, + "Triton": fp16_nc_triton, + "TileLang": fp16_nc_tile, + "Flashinfer": fp16_nc_fi, + }, +} + +# fig, axes = plt.subplots(2, 2, figsize=(12, 4), sharex=True) +fig, ax = plt.subplots(figsize=(6, 4)) + +bar_width = 0.16 + +def plot_panel(ax, title, panel_data, ylim, yticks, methods): + count = len(methods) + offsets = (np.arange(count) - (count - 1) / 2) * bar_width + for i, m in enumerate(methods): + vals = panel_data[m] + if m == "FlashMLA": + ax.bar( + x + offsets[i], # 注意这里用 x,而不是 ctx + vals, + bar_width, + label=m, + color=colors[m], + hatch="//", + ) + else: + ax.bar( + x + offsets[i], + vals, + bar_width, + label=m, + color=colors[m], + ) + ax.set_title(title) + ax.set_ylim(*ylim) + ax.set_yticks(yticks) + ax.grid(axis="y", linestyle="--", alpha=0.4) + ax.set_xticks(x) + ax.set_xticklabels(ctx) # 只把刻度标签写成 1024, 2048... 
+ +# 上排:FP16 +plot_panel( + ax, + "FP16, causal=false", + data[("FP16, causal=False")], + ylim=(0, 800), + yticks=[0, 200, 400, 600, 800], + methods=methods, +) + +# axes[0, 0].set_ylabel("Throughput (TFLOPs/s)") +# axes[1, 0].set_ylabel("Throughput (TFLOPs/s)") +# axes[1, 0].set_xlabel("Context length") +# axes[1, 1].set_xlabel("Context length") + +handles, labels = ax.get_legend_handles_labels() +fig.legend( + handles, + labels, + loc="upper center", + ncol=5, + bbox_to_anchor=(0.5, 0.98), +) + +ax.set_ylabel("Throughput (TFLOPs/s)") +ax.set_xlabel("Context length") + +plt.subplots_adjust(top=0.82) + +plt.show() + +# b=132, s_q=2, h_q=128, h_kv=1, d=576, dv=512, causal=False, dtype=torch.float16 diff --git a/docker/draw/draw_softmax.py b/docker/draw/draw_softmax.py new file mode 100644 index 0000000..e6f67b7 --- /dev/null +++ b/docker/draw/draw_softmax.py @@ -0,0 +1,94 @@ +import numpy as np +import matplotlib.pyplot as plt + +ctx = ["(32K, 1K)", "(32K, 4K)", "(32K, 8K)", "(32K, 32K)", "(32K, 65K)"] +x = np.arange(len(ctx)) # 真正用于画图的位置:0,1,2,... 
+ +# ====================== 示例数据,按需替换 ====================== +bf16_qk = np.array([2662, 2936, 2994, 3019, 3006]) +bf16_txl = np.array([1167, 2748, 3002, 3024, 3039]) +bf16_torch = np.array([2042, 1136, 1307, 1412, 1443]) +# =============================================================== + +methods = ["Quack", "Txl", "Torch"] + +colors = { + "Quack": "#f1c40f", + "Txl": "#e74c3c", + "Torch": "#1abc9c", +} + +data = { + ("Softmax, BF16"): { + "Quack": bf16_qk, + "Txl": bf16_txl, + "Torch": bf16_torch, + }, +} + +# fig, axes = plt.subplots(2, 2, figsize=(12, 4), sharex=True) +fig, ax = plt.subplots(figsize=(6, 4)) + +bar_width = 0.16 + +def plot_panel(ax, title, panel_data, ylim, yticks, methods): + count = len(methods) + offsets = (np.arange(count) - (count - 1) / 2) * bar_width + for i, m in enumerate(methods): + vals = panel_data[m] + if m == "Quack": + ax.bar( + x + offsets[i], # 注意这里用 x,而不是 ctx + vals, + bar_width, + label=m, + color=colors[m], + hatch="//", + ) + else: + ax.bar( + x + offsets[i], + vals, + bar_width, + label=m, + color=colors[m], + ) + ax.set_title(title) + ax.set_ylim(*ylim) + ax.set_yticks(yticks) + ax.grid(axis="y", linestyle="--", alpha=0.4) + ax.set_xticks(x) + ax.set_xticklabels(ctx) # 只把刻度标签写成 1024, 2048... 
+ +# 上排:FP16 +plot_panel( + ax, + "Softmax, BF16", + data[("Softmax, BF16")], + ylim=(0, 3500), + yticks=[0, 500, 1000, 1500, 2000, 2500, 3000, 3500], + methods=methods, +) + +# axes[0, 0].set_ylabel("Throughput (TFLOPs/s)") +# axes[1, 0].set_ylabel("Throughput (TFLOPs/s)") +# axes[1, 0].set_xlabel("Context length") +# axes[1, 1].set_xlabel("Context length") + +handles, labels = ax.get_legend_handles_labels() +fig.legend( + handles, + labels, + loc="upper center", + ncol=5, + bbox_to_anchor=(0.5, 0.98), +) + +ax.set_ylabel("Memory Bandwidth (GB/s)") +ax.set_xlabel("(M, N): (Batch size, Reduction dim)") + +plt.subplots_adjust(top=0.82) + +plt.show() + +# b=132, s_q=2, h_q=128, h_kv=1, d=576, dv=512, causal=False, dtype=torch.float16 diff --git a/docker/modal_tilelang.py b/docker/modal_tilelang.py new file mode 100644 index 0000000..1e017ae --- /dev/null +++ b/docker/modal_tilelang.py @@ -0,0 +1,118 @@ +from modal import Image, App, Volume +import pathlib +local_dir = pathlib.Path(__file__).parent +root_dir = local_dir.parent +requirements_file = root_dir / "requirements.txt" +txl_wheel_file = local_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" + +flash_mla_dir = local_dir / "flash_mla" +test_file = local_dir / "tilelang" / "benchmark_mla.py" +kernel_file = local_dir / "tilelang" / "example_mla_decode_paged.py" + +app = App(name="txl") # Note: this is optional since Modal 0.57 +volume = Volume.from_name("txl-dump", create_if_missing=True) # create a cloud volume to store compiled dump files + +txl_image = ( + Image.from_registry( + "nvidia/cuda:12.4.0-devel-ubuntu22.04", + add_python="3.12", + ) + #Image.from_dockerfile(path="./Dockerfile") + .workdir("/workspace") + .add_local_file(txl_wheel_file, remote_path="/workspace/", copy=True) # copy the local code to the image + .run_commands( "ls .") + .pip_install_from_requirements(requirements_file) # local file not remote file + .pip_install("tilelang") + .run_commands( + "pip install 
/workspace/txl-3.4.0-cp312-cp312-linux_x86_64.whl", + ) + .add_local_file(test_file, remote_path="/workspace/test_txl.py", copy=False) # copy after image build, no need rebuild + .add_local_file(kernel_file, remote_path="/workspace/example_mla_decode_paged.py", copy=False) + .add_local_dir(flash_mla_dir, remote_path="/workspace/flash_mla", copy=False) +) + +# Example function that uses the image +@app.function(gpu="H100", image=txl_image, timeout=60*60, + volumes={"/workspace/dump": volume}) +def run_demo(): + import subprocess, sys, os, torch, time + def get_gpu_type(): + + try: + # Execute nvidia-smi command to query GPU details + result = subprocess.run(['nvidia-smi', '-q'], capture_output=True, text=True, check=True) + output = result.stdout + + # Look for indicators of SXM or PCIe in the output + for line in output.split("\n"): + if "Product Name" in line: + print(line) + if 'H100' in line and 'HBM3' in line: + return True + except subprocess.CalledProcessError as e: + print(f"Error running nvidia-smi: {e}") + except FileNotFoundError: + print("nvidia-smi not found. 
Please ensure NVIDIA drivers are installed and in your PATH.") + return False + + def test_demo(): + os.makedirs("/workspace/dump", exist_ok=True) + logs_dir = pathlib.Path("/workspace/dump/logs") + logs_dir.mkdir(parents=True, exist_ok=True) + ts = time.strftime("%Y%m%d-%H%M%S") + log_path = logs_dir / f"mla-{ts}.log" + + env = os.environ.copy() + # env["TRITON_PRINT_AUTOTUNING"] = "0" + # env["TRITON_KERNEL_DUMP"] = "1" + # env["TRITON_DUMP_DIR"] = "/workspace/dump" + # env["TRITON_ALWAYS_COMPILE"] = "1" + # env["CUDA_LAUNCH_BLOCKING"] = "1" + + cmd = [sys.executable, "-u", "/workspace/test_txl.py"] + + with open(log_path, "w", buffering=1, encoding="utf-8", errors="replace") as f: + proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, env=env, bufsize=1 + ) + assert proc.stdout is not None + for line in proc.stdout: + print(line, end="") + f.write(line) + + rc = proc.wait() + + print(f"\n=== FULL LOG SAVED ===\n{log_path}\n") + if rc != 0: + raise SystemExit(rc) + + def import_cuBLAS_lib(): + import os, ctypes, pathlib + import nvidia.cublas, nvidia.cuda_runtime + + cublas_dir = (pathlib.Path(nvidia.cublas.__file__).parent / "lib").resolve() + cudart_dir = (pathlib.Path(nvidia.cuda_runtime.__file__).parent / "lib").resolve() + + # cuda_home = cudart_dir.parent + # os.environ.setdefault("CUDA_HOME", str(cuda_home)) + + os.environ["LD_LIBRARY_PATH"] = f"{cublas_dir}:{cudart_dir}:" + os.environ.get("LD_LIBRARY_PATH","") + + for name12, name in [("libcublas.so.12","libcublas.so"), + ("libcublasLt.so.12","libcublasLt.so")]: + src = cublas_dir / name12 + dst = cublas_dir / name + if src.exists() and not dst.exists(): + try: dst.symlink_to(name12) + except FileExistsError: pass + except PermissionError: pass + + ctypes.CDLL(str(cudart_dir / "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublasLt.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublas.so.12"), 
mode=ctypes.RTLD_GLOBAL) + + if get_gpu_type(): + print("Running on H100 GPU with HBM3 memory.") + import_cuBLAS_lib() + test_demo() \ No newline at end of file diff --git a/docker/tilelang/benchmark_mla.py b/docker/tilelang/benchmark_mla.py new file mode 100644 index 0000000..d8e40b5 --- /dev/null +++ b/docker/tilelang/benchmark_mla.py @@ -0,0 +1,650 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch +import triton +import triton.language as tl + +import tilelang +from tilelang.profiler import do_bench +from example_mla_decode_paged import mla_decode_tilelang + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, + h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + 
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@torch.inference_mode() +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, + h_kv, d, dv, causal, dtype): + from flash_mla import flash_mla_with_kvcache, get_mla_metadata + + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + ) + + out_flash, lse_flash = flash_mla() + t = triton.testing.do_bench(flash_mla) + return out_flash, lse_flash, t + +@torch.inference_mode() +def run_txl_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, + h_q, h_kv, d, dv, causal, dtype): + from flash_mla import make_mla_runner + + q_value = q[..., :dv].permute(0, 2, 1, 3).contiguous() + q_pe = q[..., dv:].permute(0, 2, 1, 3).contiguous() + blocked_k = blocked_k.view(b, max_seqlen_pad, h_kv, d) + k_value = blocked_k[..., :dv].permute(0, 2, 1, 3).contiguous() + k_pe = blocked_k[..., dv:].permute(0, 2, 1, 3).contiguous() + # print(f"{q_value.shape=}, {q_pe.shape=}, {k_value.shape=}, {k_pe.shape=}") + + runner = make_mla_runner(q_value, k_value, q_pe, k_pe, 1 / math.sqrt(576), algo = 0,) + + out_txl = runner() + torch.cuda.synchronize() + t = triton.testing.do_bench(runner) + + return out_txl, None, t + +@torch.inference_mode() +def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, + h_q, h_kv, d, dv, causal, dtype): + # pip install flashinfer-python + import flashinfer + assert d > dv, "mla with rope 
dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., + dv:].contiguous() + + kv_indptr = [0] + kv_indices = [] + for i in range(b): + seq_len = cache_seqlens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_table[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + for seq_len in cache_seqlens[1:]: + kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1]) + + q_indptr = torch.arange(0, b + 1).int() * s_q + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") + mla_wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + cache_seqlens, + h_q, + dv, + d - dv, + block_size, + causal, + 1 / math.sqrt(d), + q.dtype, + blocked_k.dtype, + ) + + def flashinfer(): + output, lse = mla_wrapper.run( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope, + blocked_k_pe, + return_lse=True) + return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) + + out_flash, lse_flash = flashinfer() + t = triton.testing.do_bench(flashinfer) + return out_flash, lse_flash, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len 
= tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ + None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, + None] * stride_o_h + split_kv_id * 
stride_o_s + offs_d_ckv[ + None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + ) + + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) 
+ + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + num_warps=4, + num_stages=2, + ) + + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, + cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., + dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton( 
+ q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, + num_kv_splits, 1 / math.sqrt(d), block_size) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +@torch.inference_mode() +def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, + cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., + dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = 64 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, + num_kv_splits, block_size) + + def flash_mla_tilelang(): + out = kernel( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), + blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + return out_flash, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "tilelang": run_flash_mla_tilelang, + "flash_mla": run_flash_mla, + "flashinfer": run_flashinfer, + "flash_mla_triton": run_flash_mla_triton, + "txl": run_txl_mla, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + 
torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, + s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, + s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flashinfer", "flash_mla_triton", "tilelang" + ] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: + # flashinfer has a different lse return value + # flash_mla_triton and flash_mla_tilelang doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( + torch.finfo(dtype).bits // 8) + print( + f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" + ) + print( + f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" + ) + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, 
causal, dtype): + print( + f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + mean_seqlens = cache_seqlens.float().mean().item() + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, + s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( + torch.finfo(dtype).bits // 8) + print( + f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" + ) + return bytes / 10**6 / perf_b + + +available_targets = [ + "torch", + # "tilelang", + "flash_mla", + # "flashinfer", + # "flash_mla_triton", + "txl", +] + +shape_configs = [{ + "b": + batch, + "s_q": + 2, + "cache_seqlens": + torch.tensor([seqlen for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": + head, + "h_kv": + 1, + "d": + 512 + 64, + "dv": + 512, + "causal": + False, + "dtype": + torch.float16 +} for batch in [132] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + 
parser.add_argument("--target", type=str, default="tilelang") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + # args = get_args() + # benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + # with open(f"{benchmark_type}_perf.csv", "w") as fout: + # fout.write("name,batch,seqlen,head,bw\n") + # for shape in shape_configs: + # if args.all: + # for target in available_targets: + # perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], + # shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], + # shape["causal"], shape["dtype"]) + # fout.write( + # f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + # ) + # elif args.compare: + # perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], + # shape["cache_seqlens"], shape["h_q"], shape["h_kv"], + # shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + # fout.write( + # f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + # ) + # fout.write( + # f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + # ) + # elif args.one: + # perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], + # shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], + # shape["causal"], shape["dtype"]) + # fout.write( + # f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + # ) + for shape in shape_configs: + for target in available_targets: + perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], + shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], + shape["causal"], 
shape["dtype"]) \ No newline at end of file diff --git a/docker/tilelang/example_mla_decode_paged.py b/docker/tilelang/example_mla_decode_paged.py new file mode 100644 index 0000000..d23ff00 --- /dev/null +++ b/docker/tilelang/example_mla_decode_paged.py @@ -0,0 +1,408 @@ +import torch +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import argparse +from tilelang.profiler import do_bench +import math + + +@tilelang.jit( + out_idx=[8], pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }) +def mla_decode_tilelang(batch, + h_q, + h_kv, + max_seqlen_pad, + dv, + dpe, + block_N, + block_H, + num_split, + block_size, + softmax_scale=None): + if softmax_scale is None: + softmax_scale = (dv + dpe)**-0.5 + scale = float(softmax_scale * 1.44269504) # log2(e) + dtype = "float16" + accum_dtype = "float" + kv_group_num = h_q // h_kv + VALID_BLOCK_H = min(block_H, kv_group_num) + assert h_kv == 1, "h_kv must be 1" + assert block_size >= block_N and block_size % block_N == 0, "block_size must be larger than block_N and a multiple of block_N" + + @T.macro + def flash_mla_kernel( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + CACHE_SEQLENS: T.Tensor([batch], "int32"), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): + Q_shared = T.alloc_shared([block_H, dv], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) + KV_shared = T.alloc_shared([block_N, dv], dtype) + K_pe_shared = T.alloc_shared([block_N, dpe], dtype) + O_shared = T.alloc_shared([block_H, dv], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_o = 
T.alloc_fragment([block_H, dv], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout({ + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + }) + + T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) + for kr in T.Pipelined(loop_range, num_stages=2): + k = loop_range - 1 - kr + kv_start = BLOCK_TABLE[bx, (k * block_N) // + block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm( + Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm( + Q_pe_shared, + K_pe_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + if kr == 0: + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], + -T.infinity(accum_dtype), acc_s[i, j]) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + 
T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] *= scores_scale[i] + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + + @T.macro + def flash_mla_split_kv_kernel( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + CACHE_SEQLENS: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + ): + with T.Kernel( + batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dv], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) + KV_shared = T.alloc_shared([block_N, dv], dtype) + K_pe_shared = T.alloc_shared([block_N, dpe], dtype) + O_shared = T.alloc_shared([block_H, dv], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dv], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout({ + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: 
tilelang.layout.make_swizzled_layout(S_shared), + }) + + T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) + blocks_per_split = T.floordiv(total_blocks, num_split) + remaining_blocks = T.floormod(total_blocks, num_split) + loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)) + start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N + + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = BLOCK_TABLE[bx, (start + k * block_N) // + block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm( + Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm( + Q_pe_shared, + K_pe_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], + -T.infinity(accum_dtype), acc_s[i, j]) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] *= 
scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) + + @T.macro + def combine( + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + with T.Kernel(h_q, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dv], dtype) + o_accum_local = T.alloc_fragment([dv], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout({ + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + }) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dv): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dv): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dv): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + 
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, + Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Tensor([batch, h_q, dv], dtype), + Q_pe: T.Tensor([batch, h_q, dpe], dtype), + KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), + cache_seqlens: T.Tensor([batch], "int32"), + glse: T.Tensor([batch, h_q, num_split], dtype), + Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), + Output: T.Tensor([batch, h_q, dv], dtype), + ): + flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) + temp_mask = torch.ones( + s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, 
dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, + h_kv, d, dv, causal, dtype): + # q: [b, s_q, h_q, d] + # block_table: [b, max_seqlen_pad // block_size] + # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] + # cache_seqlens: [b] + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out.to(dtype), lse.to(dtype) + + out_torch, _ = ref_mla() + return out_torch + + +def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, + h_q, h_kv, d, dv, causal, dtype): + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., + dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = min(64, h_q // h_kv) + softmax_scale = d**-0.5 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, + num_kv_splits, block_size, softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + + def flash_mla_tilelang(): + out = profiler.func( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), 
+ blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, + cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) + print("All close") + return out_flash, t + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=128, help='batch size') + parser.add_argument('--h_q', type=int, default=128, help='q heads number') + parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') + parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') + parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') + parser.add_argument('--dv', type=int, default=512, help='value head dim') + args = parser.parse_args() + b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv + + device = "cuda" + dtype = torch.float16 + + s_q = 1 # for decode, s_q = 1 + block_size = 64 + cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], + dtype=torch.int32, + device=device) + dpe = d - dv + causal = True + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 + + total_flops = s_q * total_seqlens * h_q * d * 2 + + q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32, + device=device).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) + out_flash, 
latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, + s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) diff --git a/docker/tilelang/profiler/__init__.py b/docker/tilelang/profiler/__init__.py new file mode 100644 index 0000000..5af1fc2 --- /dev/null +++ b/docker/tilelang/profiler/__init__.py @@ -0,0 +1,287 @@ +"""The profiler and convert to torch utils""" +from __future__ import annotations +from typing import Callable, Any, Literal +from functools import partial +import torch +from contextlib import suppress +from dataclasses import dataclass +import tvm +from tilelang.utils.tensor import ( + get_tensor_supply, + TensorSupplyType, + torch_assert_close, +) +from tilelang.engine.param import KernelParam +from tilelang.jit.adapter import BaseKernelAdapter +from tilelang.profiler.bench import do_bench + + +@dataclass +class Profiler: + """A profiler class for benchmarking and validating kernel implementations. + + Attributes: + params: List of kernel parameters defining the input/output specifications + result_idx: Indices indicating which parameters are output tensors + supply_type: Type of tensor supply to use (e.g., random, zeros, etc.) 
+ adapter: Optional kernel adapter for interfacing with different backends + """ + + params: list[KernelParam] + result_idx: list[int] + supply_type: TensorSupplyType + adapter: BaseKernelAdapter | None = None + + def __post_init__(self): + """Initialize tensor supply after dataclass initialization""" + self.result_idx = self._legalize_result_idx(self.result_idx) + self.supply = get_tensor_supply(self.supply_type) + + def _legalize_result_idx(self, result_idx: list[int] | None = None) -> list[int]: + params = self.params + # result_idx is a list of indices of the output tensors + if result_idx is None: + result_idx = [] + elif isinstance(result_idx, int): + if result_idx > len(params) or result_idx < -len(params): + raise ValueError( + f"result_idx should be an integer between {-len(params)} and {len(params) - 1}") + if result_idx < 0: + result_idx = len(params) + result_idx + result_idx = [result_idx] + elif not isinstance(result_idx, list): + raise ValueError("result_idx should be a list of integers") + + return result_idx + + def with_default_adapter(self, adapter: BaseKernelAdapter) -> Profiler: + self.adapter = adapter + return self + + def _get_inputs(self, with_output=False): + ins = [] + for i in range(len(self.params)): + if with_output or i not in self.result_idx: + ins.append(self.supply(self.params[i])) + return ins + + def _get_params(self, with_output=False): + params = [] + for i in range(len(self.params)): + if with_output or i not in self.result_idx: + params.append(self.params[i]) + return params + + def assert_allclose( + self, + reference_program: Callable, + input_tensors: list[torch.Tensor] | None = None, + atol: float = 1e-2, + rtol: float = 1e-2, + max_mismatched_ratio=0.01, + ): + """Validates kernel output against a reference implementation. 
+ + Args: + reference_program: Reference implementation to compare against + input_tensors: Optional pre-generated input tensors + atol: Absolute tolerance for comparison + rtol: Relative tolerance for comparison + max_mismatched_ratio: Maximum allowed ratio of mismatched elements + """ + ins = self._get_inputs() if input_tensors is None else input_tensors + ref_outs = reference_program(*ins) + torch.cuda.synchronize() + lib_outs = self.func(*ins) + torch.cuda.synchronize() + + if isinstance(lib_outs, torch.Tensor): + lib_outs = [lib_outs] + elif isinstance(lib_outs, tuple): + lib_outs = list(lib_outs) + elif lib_outs is None: + lib_outs = [] + + if isinstance(ref_outs, torch.Tensor): + ref_outs = [ref_outs] + elif isinstance(ref_outs, tuple): + ref_outs = list(ref_outs) + elif ref_outs is None: + ref_outs = [] + + ref_tensors = ins + ref_outs + lib_tensors = ins + lib_outs + + assert len(lib_tensors) == len( + ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !" + # torch.set_printoptions(edgeitems=torch.inf) + for lhs, rhs in zip(lib_tensors, ref_tensors): + # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) + # total_elements = lhs.numel() + # num_not_close = (~close_mask).sum().item() + # percentage_not_close = (num_not_close / total_elements) * 100 + # print(f"{percentage_not_close:.2f}% of the elements are not close.") + # print(f"Total elements: {total_elements}, Not close elements: {num_not_close}") + if lhs is not None and rhs is not None: + # in case of numsplit template, the ref output may be None + # which means the value is invalid, so we skip the comparison + def is_float8(tensor: torch.Tensor) -> bool: + return tensor.dtype in { + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + } + + torch_assert_close( + lhs if not is_float8(lhs) else lhs.to(torch.float32), + rhs if not is_float8(rhs) else rhs.to(torch.float32), + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, 
+ base_name="tilelang", + ref_name="ref", + ) + + def manual_assert_close( + self, + reference_program: Callable, + input_tensors: list[torch.Tensor] | None = None, + manual_check_prog: Callable = None, + ): + """Validates kernel output against a reference implementation. + + Args: + reference_program: Reference implementation to compare against + input_tensors: Optional pre-generated input tensors + atol: Absolute tolerance for comparison + rtol: Relative tolerance for comparison + max_mismatched_ratio: Maximum allowed ratio of mismatched elements + """ + ins = self._get_inputs() if input_tensors is None else input_tensors + ref_outs = reference_program(*ins) + torch.cuda.synchronize() + lib_outs = self.func(*ins) + torch.cuda.synchronize() + + if isinstance(lib_outs, torch.Tensor): + lib_outs = [lib_outs] + if isinstance(ref_outs, torch.Tensor): + ref_outs = [ref_outs] + elif ref_outs is None: + ref_outs = [] + assert len(lib_outs) == len(ref_outs), f"{len(lib_outs)=} not equals to {len(ref_outs)=} !" + torch.set_printoptions(edgeitems=torch.inf) + manual_check_prog(lib_outs, ref_outs) + + def assert_consistent(self, repeat=10): + """Checks for kernel consistency across multiple runs. + + Args: + repeat: Number of times to repeat the consistency check + """ + # Used to check no race condition inside the kernel + ins = self._get_inputs() + ref_outs = self.func(*ins) + + for _ in range(repeat): + lib_outs = self.func(*ins) + for lhs, rhs in zip(lib_outs, ref_outs): + assert torch.allclose(lhs, rhs), [ + "result is not consistent", + lhs, + rhs, + ] + + def run_once(self, func: Callable | None = None): + ins = self._get_inputs() + if not func: + func = self.__call__ + return func(*ins) + + def determine_profiler(self, func: Callable | None = None): + """Determines which profiler backend to use based on function type. 
+ + Args: + func: Function to be profiled + profiler: Explicitly specified profiler type or "auto" for automatic detection + + Returns: + str: The determined profiler type ("torch" or "tvm") + """ + if isinstance(func, tvm.runtime.Module): + return "tvm" + else: + return "torch" + + def do_bench( + self, + func: Callable | None = None, + warmup: int = 25, + rep: int = 100, + n_warmup: int = 1, + n_repeat: int = 1, + input_tensors: list[torch.Tensor] = None, + backend: Literal["event", "cupti"] = "event", + quantiles: list[float] | None = None, + return_mode: Literal["min", "max", "mean", "median"] = "mean", + ) -> float: + """Benchmarks the execution time of a given function. + + Args: + func: Function to benchmark (uses adapter if None) + warmup: Warmup time in milliseconds + rep: Number of repetitions for timing + n_warmup: Number of warmup iterations + n_repeat: Number of timing iterations + profiler: Which profiling backend to use + input_tensors: Optional pre-generated input tensors + + Returns: + float: Average execution time in milliseconds + """ + profiler = self.determine_profiler(func) + if profiler == "torch": + if func is None: + assert self.adapter is not None, "benchmarking function should be provided" + func = self.adapter + ins = self._get_inputs() if input_tensors is None else input_tensors + bench_func = partial(func, *ins) + return do_bench( + bench_func, + warmup=warmup, + rep=rep, + _n_warmup=n_warmup, + _n_repeat=n_repeat, + quantiles=quantiles, + backend=backend, + return_mode=return_mode, + ) + elif profiler == "tvm": + assert func is not None, "func should not be None" + assert isinstance( + func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" + + ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors) + target = "cuda" + + with suppress(Exception): + target = self.mod.imported_modules[0].type_key + + assert target in ["cuda", "hip"], f"Unknown target: {target}" + + device = 
tvm.cuda(0) if target == "cuda" else tvm.rocm(0) + time_evaluator = self.mod.time_evaluator( + self.mod.entry_name, device, number=rep, repeat=n_repeat) + # Transform Latency to ms + return time_evaluator(*ins).mean * 1e3 + else: + raise ValueError(f"Unknown profiler: {profiler}") + + @property + def func(self): + assert self.adapter is not None, "adapter should be provided" + return self.adapter + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.func(*args, **kwds) diff --git a/python/txl/tutorials/04-softmax.py b/python/txl/tutorials/04-softmax.py index efeafdc..53b4bf1 100644 --- a/python/txl/tutorials/04-softmax.py +++ b/python/txl/tutorials/04-softmax.py @@ -767,45 +767,45 @@ def test_softmax(dump_dir=None, M=32*1024, N=32*1024): ################################################ # Validate ################################################ - from triton import knobs - #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "axis-info,tritongpu-coalesce" - #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "tritongpu-coalesce" - #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "tritongpu-remove-layout-conversions" - #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "ttg-utility" - #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-pipeliner" - knobs.autotuning.print=True - knobs.compilation.always_compile=True - if dump_dir: - knobs.compilation.dump_ir=True - knobs.cache.dump_dir=dump_dir - - if HAS_QUACK: - quack_out = quack_fn().to(torch.float32).cpu().numpy() - if HAS_TXL: - txl_out = txl_fn(x, False).to(torch.float32).cpu().numpy() - - #exit() - - # NumPy reference - a = x32.cpu().numpy() - a_max = a.max(axis=1, keepdims=True) - ref = np.exp(a - a_max) - ref = ref / ref.sum(axis=1, keepdims=True) - - torch_out = compiled_func_ref(x, target).to(torch.float32).cpu().numpy() - - # relative error - outs = [] - outs.append(('torch', torch_out)) - if HAS_QUACK: - outs.append(('quack', quack_out)) - if HAS_TXL: - outs.append(('txl', txl_out)) - - for name, out in outs: - max_rel_err = 
np.max(np.abs(out - ref) / (ref + 1e-20)) - print("max relative error:", max_rel_err) - print("sum per row -1 (should be 0):", np.max(np.abs(out.sum(axis=1) - 1.0))) + # from triton import knobs + # #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "axis-info,tritongpu-coalesce" + # #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "tritongpu-coalesce" + # #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "tritongpu-remove-layout-conversions" + # #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "ttg-utility" + # #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-pipeliner" + # knobs.autotuning.print=True + # knobs.compilation.always_compile=True + # if dump_dir: + # knobs.compilation.dump_ir=True + # knobs.cache.dump_dir=dump_dir + + # if HAS_QUACK: + # quack_out = quack_fn().to(torch.float32).cpu().numpy() + # if HAS_TXL: + # txl_out = txl_fn(x, False).to(torch.float32).cpu().numpy() + + # #exit() + + # # NumPy reference + # a = x32.cpu().numpy() + # a_max = a.max(axis=1, keepdims=True) + # ref = np.exp(a - a_max) + # ref = ref / ref.sum(axis=1, keepdims=True) + + # torch_out = compiled_func_ref(x, target).to(torch.float32).cpu().numpy() + + # # relative error + # outs = [] + # outs.append(('torch', torch_out)) + # if HAS_QUACK: + # outs.append(('quack', quack_out)) + # if HAS_TXL: + # outs.append(('txl', txl_out)) + + # for name, out in outs: + # max_rel_err = np.max(np.abs(out - ref) / (ref + 1e-20)) + # print("max relative error:", max_rel_err) + # print("sum per row -1 (should be 0):", np.max(np.abs(out.sum(axis=1) - 1.0))) #import pdb;pdb.set_trace() #exit() From 26bff236676b54ca59df41e73ccbe80d57382000 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Thu, 20 Nov 2025 13:31:50 +0800 Subject: [PATCH 06/17] change color; delete nouse file --- docker/draw/draw_mla.py | 2 +- docker/draw/draw_softmax.py | 4 +- docker/tilelang/profiler/__init__.py | 287 --------------------------- 3 files changed, 3 insertions(+), 290 deletions(-) delete mode 100644 docker/tilelang/profiler/__init__.py diff 
--git a/docker/draw/draw_mla.py b/docker/draw/draw_mla.py index 9707245..7c80787 100644 --- a/docker/draw/draw_mla.py +++ b/docker/draw/draw_mla.py @@ -21,7 +21,7 @@ "Txl": "#e74c3c", "Triton": "#1abc9c", "TileLang": "#ff6fb3", - "Flashinfer": "#3498db", + "Flashinfer": "#e67e22", } data = { diff --git a/docker/draw/draw_softmax.py b/docker/draw/draw_softmax.py index e6f67b7..3421b34 100644 --- a/docker/draw/draw_softmax.py +++ b/docker/draw/draw_softmax.py @@ -13,9 +13,9 @@ methods = ["Quack", "Txl", "Torch"] colors = { - "Quack": "#f1c40f", + "Quack": "#2ecc71", "Txl": "#e74c3c", - "Torch": "#1abc9c", + "Torch": "#8e44ad", } data = { diff --git a/docker/tilelang/profiler/__init__.py b/docker/tilelang/profiler/__init__.py deleted file mode 100644 index 5af1fc2..0000000 --- a/docker/tilelang/profiler/__init__.py +++ /dev/null @@ -1,287 +0,0 @@ -"""The profiler and convert to torch utils""" -from __future__ import annotations -from typing import Callable, Any, Literal -from functools import partial -import torch -from contextlib import suppress -from dataclasses import dataclass -import tvm -from tilelang.utils.tensor import ( - get_tensor_supply, - TensorSupplyType, - torch_assert_close, -) -from tilelang.engine.param import KernelParam -from tilelang.jit.adapter import BaseKernelAdapter -from tilelang.profiler.bench import do_bench - - -@dataclass -class Profiler: - """A profiler class for benchmarking and validating kernel implementations. - - Attributes: - params: List of kernel parameters defining the input/output specifications - result_idx: Indices indicating which parameters are output tensors - supply_type: Type of tensor supply to use (e.g., random, zeros, etc.) 
- adapter: Optional kernel adapter for interfacing with different backends - """ - - params: list[KernelParam] - result_idx: list[int] - supply_type: TensorSupplyType - adapter: BaseKernelAdapter | None = None - - def __post_init__(self): - """Initialize tensor supply after dataclass initialization""" - self.result_idx = self._legalize_result_idx(self.result_idx) - self.supply = get_tensor_supply(self.supply_type) - - def _legalize_result_idx(self, result_idx: list[int] | None = None) -> list[int]: - params = self.params - # result_idx is a list of indices of the output tensors - if result_idx is None: - result_idx = [] - elif isinstance(result_idx, int): - if result_idx > len(params) or result_idx < -len(params): - raise ValueError( - f"result_idx should be an integer between {-len(params)} and {len(params) - 1}") - if result_idx < 0: - result_idx = len(params) + result_idx - result_idx = [result_idx] - elif not isinstance(result_idx, list): - raise ValueError("result_idx should be a list of integers") - - return result_idx - - def with_default_adapter(self, adapter: BaseKernelAdapter) -> Profiler: - self.adapter = adapter - return self - - def _get_inputs(self, with_output=False): - ins = [] - for i in range(len(self.params)): - if with_output or i not in self.result_idx: - ins.append(self.supply(self.params[i])) - return ins - - def _get_params(self, with_output=False): - params = [] - for i in range(len(self.params)): - if with_output or i not in self.result_idx: - params.append(self.params[i]) - return params - - def assert_allclose( - self, - reference_program: Callable, - input_tensors: list[torch.Tensor] | None = None, - atol: float = 1e-2, - rtol: float = 1e-2, - max_mismatched_ratio=0.01, - ): - """Validates kernel output against a reference implementation. 
- - Args: - reference_program: Reference implementation to compare against - input_tensors: Optional pre-generated input tensors - atol: Absolute tolerance for comparison - rtol: Relative tolerance for comparison - max_mismatched_ratio: Maximum allowed ratio of mismatched elements - """ - ins = self._get_inputs() if input_tensors is None else input_tensors - ref_outs = reference_program(*ins) - torch.cuda.synchronize() - lib_outs = self.func(*ins) - torch.cuda.synchronize() - - if isinstance(lib_outs, torch.Tensor): - lib_outs = [lib_outs] - elif isinstance(lib_outs, tuple): - lib_outs = list(lib_outs) - elif lib_outs is None: - lib_outs = [] - - if isinstance(ref_outs, torch.Tensor): - ref_outs = [ref_outs] - elif isinstance(ref_outs, tuple): - ref_outs = list(ref_outs) - elif ref_outs is None: - ref_outs = [] - - ref_tensors = ins + ref_outs - lib_tensors = ins + lib_outs - - assert len(lib_tensors) == len( - ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !" - # torch.set_printoptions(edgeitems=torch.inf) - for lhs, rhs in zip(lib_tensors, ref_tensors): - # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) - # total_elements = lhs.numel() - # num_not_close = (~close_mask).sum().item() - # percentage_not_close = (num_not_close / total_elements) * 100 - # print(f"{percentage_not_close:.2f}% of the elements are not close.") - # print(f"Total elements: {total_elements}, Not close elements: {num_not_close}") - if lhs is not None and rhs is not None: - # in case of numsplit template, the ref output may be None - # which means the value is invalid, so we skip the comparison - def is_float8(tensor: torch.Tensor) -> bool: - return tensor.dtype in { - torch.float8_e5m2, - torch.float8_e5m2fnuz, - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - } - - torch_assert_close( - lhs if not is_float8(lhs) else lhs.to(torch.float32), - rhs if not is_float8(rhs) else rhs.to(torch.float32), - rtol=rtol, - atol=atol, - max_mismatched_ratio=max_mismatched_ratio, 
- base_name="tilelang", - ref_name="ref", - ) - - def manual_assert_close( - self, - reference_program: Callable, - input_tensors: list[torch.Tensor] | None = None, - manual_check_prog: Callable = None, - ): - """Validates kernel output against a reference implementation. - - Args: - reference_program: Reference implementation to compare against - input_tensors: Optional pre-generated input tensors - atol: Absolute tolerance for comparison - rtol: Relative tolerance for comparison - max_mismatched_ratio: Maximum allowed ratio of mismatched elements - """ - ins = self._get_inputs() if input_tensors is None else input_tensors - ref_outs = reference_program(*ins) - torch.cuda.synchronize() - lib_outs = self.func(*ins) - torch.cuda.synchronize() - - if isinstance(lib_outs, torch.Tensor): - lib_outs = [lib_outs] - if isinstance(ref_outs, torch.Tensor): - ref_outs = [ref_outs] - elif ref_outs is None: - ref_outs = [] - assert len(lib_outs) == len(ref_outs), f"{len(lib_outs)=} not equals to {len(ref_outs)=} !" - torch.set_printoptions(edgeitems=torch.inf) - manual_check_prog(lib_outs, ref_outs) - - def assert_consistent(self, repeat=10): - """Checks for kernel consistency across multiple runs. - - Args: - repeat: Number of times to repeat the consistency check - """ - # Used to check no race condition inside the kernel - ins = self._get_inputs() - ref_outs = self.func(*ins) - - for _ in range(repeat): - lib_outs = self.func(*ins) - for lhs, rhs in zip(lib_outs, ref_outs): - assert torch.allclose(lhs, rhs), [ - "result is not consistent", - lhs, - rhs, - ] - - def run_once(self, func: Callable | None = None): - ins = self._get_inputs() - if not func: - func = self.__call__ - return func(*ins) - - def determine_profiler(self, func: Callable | None = None): - """Determines which profiler backend to use based on function type. 
- - Args: - func: Function to be profiled - profiler: Explicitly specified profiler type or "auto" for automatic detection - - Returns: - str: The determined profiler type ("torch" or "tvm") - """ - if isinstance(func, tvm.runtime.Module): - return "tvm" - else: - return "torch" - - def do_bench( - self, - func: Callable | None = None, - warmup: int = 25, - rep: int = 100, - n_warmup: int = 1, - n_repeat: int = 1, - input_tensors: list[torch.Tensor] = None, - backend: Literal["event", "cupti"] = "event", - quantiles: list[float] | None = None, - return_mode: Literal["min", "max", "mean", "median"] = "mean", - ) -> float: - """Benchmarks the execution time of a given function. - - Args: - func: Function to benchmark (uses adapter if None) - warmup: Warmup time in milliseconds - rep: Number of repetitions for timing - n_warmup: Number of warmup iterations - n_repeat: Number of timing iterations - profiler: Which profiling backend to use - input_tensors: Optional pre-generated input tensors - - Returns: - float: Average execution time in milliseconds - """ - profiler = self.determine_profiler(func) - if profiler == "torch": - if func is None: - assert self.adapter is not None, "benchmarking function should be provided" - func = self.adapter - ins = self._get_inputs() if input_tensors is None else input_tensors - bench_func = partial(func, *ins) - return do_bench( - bench_func, - warmup=warmup, - rep=rep, - _n_warmup=n_warmup, - _n_repeat=n_repeat, - quantiles=quantiles, - backend=backend, - return_mode=return_mode, - ) - elif profiler == "tvm": - assert func is not None, "func should not be None" - assert isinstance( - func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}" - - ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors) - target = "cuda" - - with suppress(Exception): - target = self.mod.imported_modules[0].type_key - - assert target in ["cuda", "hip"], f"Unknown target: {target}" - - device = 
tvm.cuda(0) if target == "cuda" else tvm.rocm(0) - time_evaluator = self.mod.time_evaluator( - self.mod.entry_name, device, number=rep, repeat=n_repeat) - # Transform Latency to ms - return time_evaluator(*ins).mean * 1e3 - else: - raise ValueError(f"Unknown profiler: {profiler}") - - @property - def func(self): - assert self.adapter is not None, "adapter should be provided" - return self.adapter - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.func(*args, **kwds) From 8f4c00713ffaf7d18d9c322e0e70833dc46997f6 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Thu, 20 Nov 2025 16:28:09 +0800 Subject: [PATCH 07/17] fix mla tilelang --- docker/draw/draw_mla.py | 2 +- docker/draw/draw_softmax.py | 2 +- docker/tilelang/benchmark_mla.py | 12 +++++------- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/docker/draw/draw_mla.py b/docker/draw/draw_mla.py index 7c80787..2f37712 100644 --- a/docker/draw/draw_mla.py +++ b/docker/draw/draw_mla.py @@ -10,7 +10,7 @@ fp16_nc_fm = np.array([529, 587, 612, 634, 596, 621]) fp16_nc_txl = np.array([481, 527, 551, 550, 572, 581]) fp16_nc_triton= np.array([20, 78, 98, 112, 136, 150]) -fp16_nc_tile = np.array([310, 402, 423, 453, 487, 476]) +fp16_nc_tile = np.array([237, 430, 459, 465, 465, 463]) fp16_nc_fi = np.array([290, 320, 360, 387, 350, 362]) # =============================================================== diff --git a/docker/draw/draw_softmax.py b/docker/draw/draw_softmax.py index 3421b34..fd50d41 100644 --- a/docker/draw/draw_softmax.py +++ b/docker/draw/draw_softmax.py @@ -13,7 +13,7 @@ methods = ["Quack", "Txl", "Torch"] colors = { - "Quack": "#2ecc71", + "Quack": "#f1c40f", "Txl": "#e74c3c", "Torch": "#8e44ad", } diff --git a/docker/tilelang/benchmark_mla.py b/docker/tilelang/benchmark_mla.py index d8e40b5..992765d 100644 --- a/docker/tilelang/benchmark_mla.py +++ b/docker/tilelang/benchmark_mla.py @@ -98,7 +98,7 @@ def run_txl_mla(q, block_table, blocked_k, max_seqlen_pad, 
block_size, b, s_q, c runner = make_mla_runner(q_value, k_value, q_pe, k_pe, 1 / math.sqrt(576), algo = 0,) - out_txl = runner() + out_txl = runner().permute(0, 2, 1, 3).contiguous() torch.cuda.synchronize() t = triton.testing.do_bench(runner) @@ -442,8 +442,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,dv:].contiguous() dpe = d - dv num_kv_splits = 1 @@ -452,8 +451,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) - kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, - num_kv_splits, block_size) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) def flash_mla_tilelang(): out = kernel( @@ -572,7 +570,7 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): available_targets = [ "torch", - # "tilelang", + "tilelang", "flash_mla", # "flashinfer", # "flash_mla_triton", @@ -583,7 +581,7 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): "b": batch, "s_q": - 2, + 1, # tilelang can run only when s_q=1 "cache_seqlens": torch.tensor([seqlen for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": From 9147c5148fa3c44c06048615d8ef3a5c217de50b Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Thu, 20 Nov 2025 17:35:28 +0800 Subject: [PATCH 08/17] fix gemm FP8 --- docker/draw/draw_gemm.py | 13 ++++++------- docker/draw/draw_mla.py | 4 ++-- 
python/txl/tutorials/01-matmul.py | 22 ++++++++++------------ 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/docker/draw/draw_gemm.py b/docker/draw/draw_gemm.py index 84a69c6..3227109 100644 --- a/docker/draw/draw_gemm.py +++ b/docker/draw/draw_gemm.py @@ -11,13 +11,12 @@ tile_fp16 = np.array([300, 420, 600, 690, 700, 720, 740]) tk_fp16 = np.array([400, 680, 680, 709, 780, 788, 798]) -cublas_fp8 = np.array([1300, 1400, 1500, 1550, 1500, 1500, 1400]) -tawa_fp8 = np.array([900, 1470, 1600, 1600, 1550, 1500, 1400]) -# fp8 运行不了 -txl_fp8 = np.array([900, 1470, 1600, 1600, 1550, 1500, 1400]) -triton_fp8 = np.array([600, 1000, 1450, 1500, 1500, 1500, 1400]) -tile_fp8 = np.array([700, 1400, 1500, 1600, 1550, 1500, 1400]) -tk_fp8 = np.array([600, 1200, 1600, 1500, 1500, 1450, 1400]) +cublas_fp8 = np.array([876, 1188, 1385, 1503, 1561, 1573, 1436]) +# tawa_fp8 = np.array([900, 1470, 1600, 1600, 1550, 1500, 1400]) +txl_fp8 = np.array([807, 1081, 1265, 1357, 1300, 1270, 1242]) +triton_fp8 = np.array([720, 1212, 1502, 1530, 1535, 1528, 1478]) +tile_fp8 = np.array([231, 312, 547, 712, 892, 930, 1003]) +tk_fp8 = np.array([579, 860, 1232, 1398, 1497, 1503, 1429]) # ------------------------------- theoretical_fp16 = 1000 diff --git a/docker/draw/draw_mla.py b/docker/draw/draw_mla.py index 2f37712..b82ceef 100644 --- a/docker/draw/draw_mla.py +++ b/docker/draw/draw_mla.py @@ -7,8 +7,8 @@ # ====================== 示例数据,按需替换 ====================== # FP16, causal = False -fp16_nc_fm = np.array([529, 587, 612, 634, 596, 621]) -fp16_nc_txl = np.array([481, 527, 551, 550, 572, 581]) +fp16_nc_fm = np.array([436, 510, 543, 560, 576, 579]) +fp16_nc_txl = np.array([393, 476, 518, 523, 525, 553]) fp16_nc_triton= np.array([20, 78, 98, 112, 136, 150]) fp16_nc_tile = np.array([237, 430, 459, 465, 465, 463]) fp16_nc_fi = np.array([290, 320, 360, 387, 350, 362]) diff --git a/python/txl/tutorials/01-matmul.py b/python/txl/tutorials/01-matmul.py index 3ca189a..beb56d2 100644 --- 
a/python/txl/tutorials/01-matmul.py +++ b/python/txl/tutorials/01-matmul.py @@ -1098,6 +1098,7 @@ def matmul_persistent_ws_tma_txl_kernel( WARP_SPECIALIZE: tl.constexpr, # ): dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 + byte_count: tl.constexpr = 2 if dtype == tl.float16 else 1 num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -1144,15 +1145,15 @@ def matmul_persistent_ws_tma_txl_kernel( b0_buf = txl.get_buffer(b0, bufIdx) txl.mbar_wait(mbar_c1, phase) - txl.mbar_expect(mbar_p_a0, BLOCK_SIZE_M//2*BLOCK_SIZE_K*2) + txl.mbar_expect(mbar_p_a0, BLOCK_SIZE_M//2*BLOCK_SIZE_K*byte_count) txl.tma_load(a0_buf, a_desc, [offs_am, offs_k], mbar_p_a0) txl.mbar_wait(mbar_c2, phase) - txl.mbar_expect(mbar_p_b0, BLOCK_SIZE_N*BLOCK_SIZE_K*2) + txl.mbar_expect(mbar_p_b0, BLOCK_SIZE_N*BLOCK_SIZE_K*byte_count) txl.tma_load(b0_buf, b_desc, [offs_bn, offs_k], mbar_p_b0) - txl.mbar_expect(mbar_p_a1, BLOCK_SIZE_M//2*BLOCK_SIZE_K*2) + txl.mbar_expect(mbar_p_a1, BLOCK_SIZE_M//2*BLOCK_SIZE_K*byte_count) txl.tma_load(a1_buf, a_desc, [offs_am + BLOCK_SIZE_M // 2, offs_k], mbar_p_a1) offs_k += BLOCK_SIZE_K @@ -1176,14 +1177,11 @@ def matmul_persistent_ws_tma_txl_kernel( offs_bn = pid_n * BLOCK_SIZE_N offs_k = 0 accumulator = tl.zeros((BLOCK_SIZE_M//2, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): mbar_p_b0 = txl.get_buffer(mbar_producer_b0, bufIdx) b0_buf = txl.get_buffer(b0, bufIdx) - txl.mbar_wait(mbar_p_b0, phase) - if txl.is_warpgroup([1]): mbar_p_a0 = txl.get_buffer(mbar_producer_a0, bufIdx) mbar_c1 = txl.get_buffer(mbar_consumer1, bufIdx) @@ -1899,7 +1897,7 @@ def run_test(expect, fn, a, b, label, enabled=True, log=False): print() print(expect) print(actual) - print((expect-actual).mean(dim=0)) + print((expect-actual.to(expect.dtype)).mean(dim=0)) passed = torch.allclose(expect, actual.to(expect.dtype), atol=1.0) icon = "✅" if passed else "❌" else: @@ 
-1927,9 +1925,9 @@ def validate(M, N, K, dtype, log=False): #run_test(naive_result, lambda a, b: matmul_naive_tma_txl(a, b), a, b, "TXL TMA Naive", log=log) #run_test(naive_result, lambda a, b: matmul_tma_persistent_txl(a, b), a, b, "TXL TMA Persistent", log=log) - #run_test(naive_result, lambda a, b: matmul_tma_ws_persistent_txl(a, b), a, b, "TXL TMA WS Persistent", log=True) + run_test(naive_result, lambda a, b: matmul_tma_ws_persistent_txl(a, b), a, b, "TXL TMA WS Persistent", log=True) #run_test(naive_result, lambda a, b: matmul_tma_ws_nn_persistent_txl(a, bn), a, bn, "TXL TMA WS NN Persistent", log=log) - run_test(naive_result, lambda a, b: matmul_separate_tma_txl(a, b), a, b, "TXL TMA split k", log=log) + # run_test(naive_result, lambda a, b: matmul_separate_tma_txl(a, b), a, b, "TXL TMA split k", log=log) return @@ -1979,7 +1977,7 @@ def profile(M, N, K, dtype, log=False): args = parser.parse_args() # dump_dir='dump/0930mm_split/' - dump_dir = None + dump_dir = '/workspace/dump' from triton import knobs #os.environ["TRITON_LLVM_DEBUG_ONLY"] = "tritongpu-remove-layout-conversions" @@ -2004,12 +2002,12 @@ def profile(M, N, K, dtype, log=False): torch.manual_seed(0) #validate(32, 32, 32, dtype) - #validate(128, 128, 512, dtype, log=True) + # validate(128, 128, 512, dtype, log=True) # validate(8192, 8192, args.K_range[0], dtype, log=True) #profile(8192, 8192, args.K_range[0], dtype) # exit() - print(dtype) + # print(dtype) proton.start("matmul", hook="triton") #proton.deactivate() From f653350a8947bd1aa5f61a5eb9600187bb7e989a Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Thu, 20 Nov 2025 17:41:45 +0800 Subject: [PATCH 09/17] add modal gemm --- docker/modal_gemm.py | 105 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 docker/modal_gemm.py diff --git a/docker/modal_gemm.py b/docker/modal_gemm.py new file mode 100644 index 0000000..9a2aa51 --- /dev/null +++ b/docker/modal_gemm.py @@ -0,0 +1,105 @@ 
+from modal import Image, App, Volume +import pathlib +local_dir = pathlib.Path(__file__).parent +root_dir = local_dir.parent +requirements_file = root_dir / "requirements.txt" +txl_wheel_file = local_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" + +test_file = root_dir / "python" / "txl" / "tutorials" / "01-matmul.py" + +app = App(name="txl") # Note: this is optional since Modal 0.57 +volume = Volume.from_name("txl-dump", create_if_missing=True) # create a cloud volume to store compiled dump files + +txl_image = ( + Image.debian_slim(python_version="3.12") + #Image.from_dockerfile(path="./Dockerfile") + .workdir("/workspace") + .add_local_file(txl_wheel_file, remote_path="/workspace/", copy=True) # copy the local code to the image + .run_commands( "ls .") + .pip_install_from_requirements(requirements_file) # local file not remote file + .run_commands( + "pip install /workspace/txl-3.4.0-cp312-cp312-linux_x86_64.whl", + ) + .add_local_file(test_file, remote_path="/workspace/test_txl.py", copy=False) # copy after image build, no need rebuild +) + +# Example function that uses the image +@app.function(gpu="H100", image=txl_image, timeout=60*60, + volumes={"/workspace/dump": volume}) +def run_demo(): + import subprocess, sys, os, torch, time + def get_gpu_type(): + + try: + # Execute nvidia-smi command to query GPU details + result = subprocess.run(['nvidia-smi', '-q'], capture_output=True, text=True, check=True) + output = result.stdout + + # Look for indicators of SXM or PCIe in the output + for line in output.split("\n"): + if "Product Name" in line: + print(line) + if 'H100' in line and 'HBM3' in line: + return True + except subprocess.CalledProcessError as e: + print(f"Error running nvidia-smi: {e}") + except FileNotFoundError: + print("nvidia-smi not found. 
Please ensure NVIDIA drivers are installed and in your PATH.") + return False + + def test_demo(): + os.makedirs("/workspace/dump", exist_ok=True) + logs_dir = pathlib.Path("/workspace/dump/logs") + logs_dir.mkdir(parents=True, exist_ok=True) + ts = time.strftime("%Y%m%d-%H%M%S") + log_path = logs_dir / f"mla-{ts}.log" + + env = os.environ.copy() + # env["TRITON_PRINT_AUTOTUNING"] = "0" + # env["TRITON_KERNEL_DUMP"] = "1" + # env["TRITON_DUMP_DIR"] = "/workspace/dump" + # env["TRITON_ALWAYS_COMPILE"] = "1" + # env["CUDA_LAUNCH_BLOCKING"] = "1" + + cmd = [sys.executable, "-u", "/workspace/test_txl.py"] + + with open(log_path, "w", buffering=1, encoding="utf-8", errors="replace") as f: + proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, env=env, bufsize=1 + ) + assert proc.stdout is not None + for line in proc.stdout: + print(line, end="") + f.write(line) + + rc = proc.wait() + + print(f"\n=== FULL LOG SAVED ===\n{log_path}\n") + if rc != 0: + raise SystemExit(rc) + + def import_cuBLAS_lib(): + import os, ctypes, pathlib + import nvidia.cublas, nvidia.cuda_runtime + + cublas_dir = (pathlib.Path(nvidia.cublas.__file__).parent / "lib").resolve() + cudart_dir = (pathlib.Path(nvidia.cuda_runtime.__file__).parent / "lib").resolve() + + os.environ["LD_LIBRARY_PATH"] = f"{cublas_dir}:{cudart_dir}:" + os.environ.get("LD_LIBRARY_PATH","") + + for name12, name in [("libcublas.so.12","libcublas.so"), + ("libcublasLt.so.12","libcublasLt.so")]: + src = cublas_dir / name12 + dst = cublas_dir / name + if src.exists() and not dst.exists(): + try: dst.symlink_to(name12) + except FileExistsError: pass + except PermissionError: pass + + ctypes.CDLL(str(cudart_dir / "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublasLt.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublas.so.12"), mode=ctypes.RTLD_GLOBAL) + + import_cuBLAS_lib() + test_demo() \ No newline at end of file From 
c430fc2f93aede451bbeb8de72b4706718eaa1bc Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Thu, 20 Nov 2025 22:35:08 +0800 Subject: [PATCH 10/17] fix attn FP8 --- docker/draw/draw_attention.py | 2 +- docker/modal_attn.py | 107 +++++++++++++++++++++ python/txl/tutorials/02-flash-attention.py | 17 ++-- 3 files changed, 117 insertions(+), 9 deletions(-) create mode 100644 docker/modal_attn.py diff --git a/docker/draw/draw_attention.py b/docker/draw/draw_attention.py index 013b22e..3f391e5 100644 --- a/docker/draw/draw_attention.py +++ b/docker/draw/draw_attention.py @@ -8,7 +8,7 @@ # ====================== 示例数据,按需替换 ====================== # FP16, causal = False fp16_nc_fa3 = np.array([570, 600, 610, 630, 640]) -fp16_nc_txl = np.array([484, 530, 565, 588, 599]) +fp16_nc_txl = np.array([484, 544, 578, 597, 608]) fp16_nc_triton= np.array([390, 460, 500, 520, 540]) fp16_nc_tile = np.array([447, 590, 610, 570, 600]) fp16_nc_tk = np.array([453, 597, 599, 610, 590]) diff --git a/docker/modal_attn.py b/docker/modal_attn.py new file mode 100644 index 0000000..cca0cb3 --- /dev/null +++ b/docker/modal_attn.py @@ -0,0 +1,107 @@ +from modal import Image, App, Volume +import pathlib +local_dir = pathlib.Path(__file__).parent +root_dir = local_dir.parent +requirements_file = root_dir / "requirements.txt" +txl_wheel_file = local_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" + +test_file = root_dir / "python" / "txl" / "tutorials" / "02-flash-attention.py" + +app = App(name="txl") # Note: this is optional since Modal 0.57 +volume = Volume.from_name("txl-dump", create_if_missing=True) # create a cloud volume to store compiled dump files + +txl_image = ( + Image.debian_slim(python_version="3.12") + #Image.from_dockerfile(path="./Dockerfile") + .workdir("/workspace") + .add_local_file(txl_wheel_file, remote_path="/workspace/", copy=True) # copy the local code to the image + .run_commands( "ls .") + .pip_install_from_requirements(requirements_file) # local file not remote 
file + .run_commands( + "pip install /workspace/txl-3.4.0-cp312-cp312-linux_x86_64.whl", + ) + .add_local_file(test_file, remote_path="/workspace/test_txl.py", copy=False) # copy after image build, no need rebuild +) + +# Example function that uses the image +@app.function(gpu="H100", image=txl_image, timeout=60*60, + volumes={"/workspace/dump": volume}) +def run_demo(): + import subprocess, sys, os, torch, time + def get_gpu_type(): + + try: + # Execute nvidia-smi command to query GPU details + result = subprocess.run(['nvidia-smi', '-q'], capture_output=True, text=True, check=True) + output = result.stdout + + # Look for indicators of SXM or PCIe in the output + for line in output.split("\n"): + if "Product Name" in line: + print(line) + if 'H100' in line and 'HBM3' in line: + return True + except subprocess.CalledProcessError as e: + print(f"Error running nvidia-smi: {e}") + except FileNotFoundError: + print("nvidia-smi not found. Please ensure NVIDIA drivers are installed and in your PATH.") + return False + + def test_demo(): + os.makedirs("/workspace/dump", exist_ok=True) + logs_dir = pathlib.Path("/workspace/dump/logs") + logs_dir.mkdir(parents=True, exist_ok=True) + ts = time.strftime("%Y%m%d-%H%M%S") + log_path = logs_dir / f"mla-{ts}.log" + + env = os.environ.copy() + # env["TRITON_PRINT_AUTOTUNING"] = "0" + # env["TRITON_KERNEL_DUMP"] = "1" + # env["TRITON_DUMP_DIR"] = "/workspace/dump" + # env["TRITON_ALWAYS_COMPILE"] = "1" + # env["CUDA_LAUNCH_BLOCKING"] = "1" + + cmd = [sys.executable, "-u", "/workspace/test_txl.py"] + + with open(log_path, "w", buffering=1, encoding="utf-8", errors="replace") as f: + proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, env=env, bufsize=1 + ) + assert proc.stdout is not None + for line in proc.stdout: + print(line, end="") + f.write(line) + + rc = proc.wait() + + print(f"\n=== FULL LOG SAVED ===\n{log_path}\n") + if rc != 0: + raise SystemExit(rc) + + def import_cuBLAS_lib(): 
+ import os, ctypes, pathlib + import nvidia.cublas, nvidia.cuda_runtime + + cublas_dir = (pathlib.Path(nvidia.cublas.__file__).parent / "lib").resolve() + cudart_dir = (pathlib.Path(nvidia.cuda_runtime.__file__).parent / "lib").resolve() + + os.environ["LD_LIBRARY_PATH"] = f"{cublas_dir}:{cudart_dir}:" + os.environ.get("LD_LIBRARY_PATH","") + + for name12, name in [("libcublas.so.12","libcublas.so"), + ("libcublasLt.so.12","libcublasLt.so")]: + src = cublas_dir / name12 + dst = cublas_dir / name + if src.exists() and not dst.exists(): + try: dst.symlink_to(name12) + except FileExistsError: pass + except PermissionError: pass + + ctypes.CDLL(str(cudart_dir / "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublasLt.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublas.so.12"), mode=ctypes.RTLD_GLOBAL) + + if get_gpu_type(): + print("Running on H100 SXM") + import_cuBLAS_lib() + test_demo() \ No newline at end of file diff --git a/python/txl/tutorials/02-flash-attention.py b/python/txl/tutorials/02-flash-attention.py index dd8e5d5..fcc46c2 100644 --- a/python/txl/tutorials/02-flash-attention.py +++ b/python/txl/tutorials/02-flash-attention.py @@ -1540,11 +1540,10 @@ def _attn_fwd_ws_tma_txl3(sm_scale, M, # num_warps=4, num_warpgroups=3, pre_hook = _host_descriptor_pre_hook, - #ir_override='dump/LNZRDQJRUVQP3KVJM5NGKARBSO3YM73N4M6D4UPZNFKU34LAERNA/_attn_fwd_ws_tma_txl4.ttgir', ) ], key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"], - ) +) @txl.jit def _attn_fwd_ws_tma_txl4(sm_scale, M, # Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # @@ -1562,6 +1561,7 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # with pl.scope("kernel"): dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 + byte_count: tl.constexpr = 2 if dtype == tl.float16 else 1 tl.static_assert(BLOCK_N <= HEAD_DIM) start_m = tl.program_id(0) off_hz = tl.program_id(1) @@ -1619,10 +1619,10 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # pMbar_bQ1i = txl.get_buffer(pMbar_bQ1, 0) 
with pl.scope("waitQ"): - txl.mbar_expect(pMbar_bQ0i, BLOCK_M // 2 * HEAD_DIM * 2) + txl.mbar_expect(pMbar_bQ0i, BLOCK_M // 2 * HEAD_DIM * byte_count) txl.tma_load(bQ0i, desc_q, [qo_offset_y, 0], pMbar_bQ0i) txl.mbar_wait(pMbar_bQ0i, 0) - txl.mbar_expect(pMbar_bQ1i, BLOCK_M // 2 * HEAD_DIM * 2) + txl.mbar_expect(pMbar_bQ1i, BLOCK_M // 2 * HEAD_DIM * byte_count) txl.tma_load(bQ1i, desc_q, [qo_offset_y+BLOCK_M//2, 0], pMbar_bQ1i) txl.mbar_wait(pMbar_bQ1i, 0) @@ -1645,13 +1645,13 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # with pl.scope("waitQK"): txl.mbar_wait(cur_mbar_QK1, phase) txl.mbar_wait(cur_mbar_QK2, phase) - txl.mbar_expect(cur_mbar_bK, BLOCK_N * HEAD_DIM * 2) + txl.mbar_expect(cur_mbar_bK, BLOCK_N * HEAD_DIM * byte_count) txl.tma_load(cur_bK, desc_k, [offsetkv_y, 0], cur_mbar_bK) with pl.scope("waitPV"): txl.mbar_wait(cur_mbar_PV1, phase) txl.mbar_wait(cur_mbar_PV2, phase) - txl.mbar_expect(cur_mbar_bV, BLOCK_N * HEAD_DIM * 2) + txl.mbar_expect(cur_mbar_bV, BLOCK_N * HEAD_DIM * byte_count) txl.tma_load(cur_bV, desc_v, [offsetkv_y, 0], cur_mbar_bV) offsetkv_y += BLOCK_N @@ -1671,7 +1671,8 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # txl.bar_arrive(8, 256) else: offs_m = start_m * BLOCK_M + tl.arange(BLOCK_M//2, BLOCK_M) - + # if txl.tid(0) == 129: + # txl.print('here1') # initialize pointer to m and l # These are in regs m_i = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf") @@ -2847,7 +2848,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune= TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') BATCH, N_HEADS, HEAD_DIM = 4, 32, 128 -TORCH_HAS_FP8=False +TORCH_HAS_FP8 = True # vary seq length for fixed head and batch=4 configs = [] for mode in ["fwd"]: From aec22895206f8fb2cc641704c7565cb87fbddf3c Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Fri, 21 Nov 2025 10:36:16 +0800 Subject: [PATCH 11/17] verified mla performance --- docker/draw/draw_mla.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) 
diff --git a/docker/draw/draw_mla.py b/docker/draw/draw_mla.py index b82ceef..eda6c54 100644 --- a/docker/draw/draw_mla.py +++ b/docker/draw/draw_mla.py @@ -7,10 +7,10 @@ # ====================== 示例数据,按需替换 ====================== # FP16, causal = False -fp16_nc_fm = np.array([436, 510, 543, 560, 576, 579]) -fp16_nc_txl = np.array([393, 476, 518, 523, 525, 553]) +fp16_nc_fm = np.array([436, 510, 543, 560, 592, 579]) +fp16_nc_txl = np.array([401, 490, 518, 523, 541, 570]) fp16_nc_triton= np.array([20, 78, 98, 112, 136, 150]) -fp16_nc_tile = np.array([237, 430, 459, 465, 465, 463]) +fp16_nc_tile = np.array([237, 412, 459, 465, 465, 463]) fp16_nc_fi = np.array([290, 320, 360, 387, 350, 362]) # =============================================================== From 64be2a62f82fa4fb1bb8994a2b4c3fef980f6d5a Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Fri, 21 Nov 2025 11:28:07 +0800 Subject: [PATCH 12/17] add flashinfer test --- docker/draw/draw_mla.py | 10 +- docker/modal_flashinfer.py | 116 +++++ docker/tilelang/benchmark_flashinfer.py | 648 ++++++++++++++++++++++++ 3 files changed, 769 insertions(+), 5 deletions(-) create mode 100644 docker/modal_flashinfer.py create mode 100644 docker/tilelang/benchmark_flashinfer.py diff --git a/docker/draw/draw_mla.py b/docker/draw/draw_mla.py index eda6c54..bad3f65 100644 --- a/docker/draw/draw_mla.py +++ b/docker/draw/draw_mla.py @@ -7,11 +7,11 @@ # ====================== 示例数据,按需替换 ====================== # FP16, causal = False -fp16_nc_fm = np.array([436, 510, 543, 560, 592, 579]) -fp16_nc_txl = np.array([401, 490, 518, 523, 541, 570]) -fp16_nc_triton= np.array([20, 78, 98, 112, 136, 150]) -fp16_nc_tile = np.array([237, 412, 459, 465, 465, 463]) -fp16_nc_fi = np.array([290, 320, 360, 387, 350, 362]) +fp16_nc_fm = np.array([436, 510, 543, 564, 582, 601]) +fp16_nc_txl = np.array([401, 490, 518, 535, 538, 561]) +fp16_nc_triton= np.array([19, 28, 34, 38, 44, 46]) +fp16_nc_tile = np.array([237, 412, 459, 473, 498, 477]) 
+fp16_nc_fi = np.array([406, 491, 527, 532, 528, 552]) # =============================================================== methods = ["FlashMLA", "Txl", "Triton", "TileLang", "Flashinfer"] diff --git a/docker/modal_flashinfer.py b/docker/modal_flashinfer.py new file mode 100644 index 0000000..b5b396a --- /dev/null +++ b/docker/modal_flashinfer.py @@ -0,0 +1,116 @@ +from modal import Image, App, Volume +import pathlib +local_dir = pathlib.Path(__file__).parent +root_dir = local_dir.parent +requirements_file = root_dir / "requirements.txt" +txl_wheel_file = local_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" + +flash_mla_dir = local_dir / "flash_mla" +test_file = local_dir / "tilelang" / "benchmark_flashinfer.py" +kernel_file = local_dir / "tilelang" / "example_mla_decode_paged.py" + +app = App(name="txl") # Note: this is optional since Modal 0.57 +volume = Volume.from_name("txl-dump", create_if_missing=True) # create a cloud volume to store compiled dump files + +txl_image = ( + Image.from_registry( + "nvidia/cuda:12.4.0-devel-ubuntu22.04", + add_python="3.12", + ) + #Image.from_dockerfile(path="./Dockerfile") + .workdir("/workspace") + .add_local_file(txl_wheel_file, remote_path="/workspace/", copy=True) # copy the local code to the image + .run_commands( "ls .") + .pip_install_from_requirements(requirements_file) # local file not remote file + .pip_install("triton") + .pip_install("flashinfer-python") + .add_local_file(test_file, remote_path="/workspace/test_txl.py", copy=False) # copy after image build, no need rebuild + .add_local_file(kernel_file, remote_path="/workspace/example_mla_decode_paged.py", copy=False) + .add_local_dir(flash_mla_dir, remote_path="/workspace/flash_mla", copy=False) +) + +# Example function that uses the image +@app.function(gpu="H100", image=txl_image, timeout=60*60, + volumes={"/workspace/dump": volume}) +def run_demo(): + import subprocess, sys, os, torch, time + def get_gpu_type(): + + try: + # Execute nvidia-smi command to query GPU 
details + result = subprocess.run(['nvidia-smi', '-q'], capture_output=True, text=True, check=True) + output = result.stdout + + # Look for indicators of SXM or PCIe in the output + for line in output.split("\n"): + if "Product Name" in line: + print(line) + if 'H100' in line and 'HBM3' in line: + return True + except subprocess.CalledProcessError as e: + print(f"Error running nvidia-smi: {e}") + except FileNotFoundError: + print("nvidia-smi not found. Please ensure NVIDIA drivers are installed and in your PATH.") + return False + + def test_demo(): + os.makedirs("/workspace/dump", exist_ok=True) + logs_dir = pathlib.Path("/workspace/dump/logs") + logs_dir.mkdir(parents=True, exist_ok=True) + ts = time.strftime("%Y%m%d-%H%M%S") + log_path = logs_dir / f"mla-{ts}.log" + + env = os.environ.copy() + # env["TRITON_PRINT_AUTOTUNING"] = "0" + # env["TRITON_KERNEL_DUMP"] = "1" + # env["TRITON_DUMP_DIR"] = "/workspace/dump" + # env["TRITON_ALWAYS_COMPILE"] = "1" + # env["CUDA_LAUNCH_BLOCKING"] = "1" + + cmd = [sys.executable, "-u", "/workspace/test_txl.py"] + + with open(log_path, "w", buffering=1, encoding="utf-8", errors="replace") as f: + proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, env=env, bufsize=1 + ) + assert proc.stdout is not None + for line in proc.stdout: + print(line, end="") + f.write(line) + + rc = proc.wait() + + print(f"\n=== FULL LOG SAVED ===\n{log_path}\n") + if rc != 0: + raise SystemExit(rc) + + def import_cuBLAS_lib(): + import os, ctypes, pathlib + import nvidia.cublas, nvidia.cuda_runtime + + cublas_dir = (pathlib.Path(nvidia.cublas.__file__).parent / "lib").resolve() + cudart_dir = (pathlib.Path(nvidia.cuda_runtime.__file__).parent / "lib").resolve() + + # cuda_home = cudart_dir.parent + # os.environ.setdefault("CUDA_HOME", str(cuda_home)) + + os.environ["LD_LIBRARY_PATH"] = f"{cublas_dir}:{cudart_dir}:" + os.environ.get("LD_LIBRARY_PATH","") + + for name12, name in 
[("libcublas.so.12","libcublas.so"), + ("libcublasLt.so.12","libcublasLt.so")]: + src = cublas_dir / name12 + dst = cublas_dir / name + if src.exists() and not dst.exists(): + try: dst.symlink_to(name12) + except FileExistsError: pass + except PermissionError: pass + + ctypes.CDLL(str(cudart_dir / "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublasLt.so.12"), mode=ctypes.RTLD_GLOBAL) + ctypes.CDLL(str(cublas_dir / "libcublas.so.12"), mode=ctypes.RTLD_GLOBAL) + + if get_gpu_type(): + print("Running on H100 GPU with HBM3 memory.") + import_cuBLAS_lib() + test_demo() \ No newline at end of file diff --git a/docker/tilelang/benchmark_flashinfer.py b/docker/tilelang/benchmark_flashinfer.py new file mode 100644 index 0000000..a61f77e --- /dev/null +++ b/docker/tilelang/benchmark_flashinfer.py @@ -0,0 +1,648 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py +# ruff: noqa +import argparse +import math +import random +import torch +import triton +import triton.language as tl + +# import tilelang +# from tilelang.profiler import do_bench +# from example_mla_decode_paged import mla_decode_tilelang + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ 
value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, + h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@torch.inference_mode() +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, + h_kv, d, dv, causal, dtype): + from flash_mla import flash_mla_with_kvcache, get_mla_metadata + + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + ) + + out_flash, lse_flash = flash_mla() + t = triton.testing.do_bench(flash_mla) + return out_flash, lse_flash, t + +@torch.inference_mode() +def run_txl_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, + h_q, h_kv, d, dv, causal, dtype): + from flash_mla import make_mla_runner + + q_value = q[..., :dv].permute(0, 2, 1, 3).contiguous() + q_pe = q[..., dv:].permute(0, 2, 1, 3).contiguous() + blocked_k = blocked_k.view(b, max_seqlen_pad, h_kv, d) + k_value = blocked_k[..., :dv].permute(0, 2, 1, 3).contiguous() + k_pe = blocked_k[..., dv:].permute(0, 2, 1, 3).contiguous() + # print(f"{q_value.shape=}, {q_pe.shape=}, 
{k_value.shape=}, {k_pe.shape=}") + + runner = make_mla_runner(q_value, k_value, q_pe, k_pe, 1 / math.sqrt(576), algo = 0,) + + out_txl = runner().permute(0, 2, 1, 3).contiguous() + torch.cuda.synchronize() + t = triton.testing.do_bench(runner) + + return out_txl, None, t + +@torch.inference_mode() +def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, + h_q, h_kv, d, dv, causal, dtype): + # pip install flashinfer-python + import flashinfer + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., + dv:].contiguous() + + kv_indptr = [0] + kv_indices = [] + for i in range(b): + seq_len = cache_seqlens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_table[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + for seq_len in cache_seqlens[1:]: + kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1]) + + q_indptr = torch.arange(0, b + 1).int() * s_q + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3") + mla_wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + cache_seqlens, + h_q, + dv, + d - dv, + block_size, + causal, + 1 / math.sqrt(d), + q.dtype, + blocked_k.dtype, + ) + + def flashinfer(): + output, lse = mla_wrapper.run( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, d - dv), + blocked_k_nope, + blocked_k_pe, + return_lse=True) + return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) + + out_flash, lse_flash = flashinfer() + t = triton.testing.do_bench(flashinfer) + return out_flash, lse_flash, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + 
K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[ + None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, 
mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, + None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[ + None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) + offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + ) + + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + 
cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + num_warps=4, + num_stages=2, + ) + + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, 
max_seqlen_pad, block_size, b, s_q, + cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., + dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton( + q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv), + blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits, + num_kv_splits, 1 / math.sqrt(d), block_size) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +@torch.inference_mode() +def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, + cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = 64 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) + + def flash_mla_tilelang(): + out = kernel( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), + blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + 
return out_flash, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "tilelang": run_flash_mla_tilelang, + "flash_mla": run_flash_mla, + "flashinfer": run_flashinfer, + "flash_mla_triton": run_flash_mla_triton, + "txl": run_txl_mla, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, + s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, + s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flashinfer", "flash_mla_triton", "tilelang" + ] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]: + # flashinfer has a different lse return value + # flash_mla_triton and flash_mla_tilelang doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = 
s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( + torch.finfo(dtype).bits // 8) + print( + f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s" + ) + print( + f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" + ) + return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + mean_seqlens = cache_seqlens.float().mean().item() + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, + s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( + torch.finfo(dtype).bits // 8) + print( + f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s" + ) + return bytes / 10**6 / perf_b + + +available_targets = [ + # "torch", + # "tilelang", + # "flash_mla", + "flashinfer", + 
"flash_mla_triton", + # "txl", +] + +shape_configs = [{ + "b": + batch, + "s_q": + 1, # tilelang can run only when s_q=1 + "cache_seqlens": + torch.tensor([seqlen for i in range(batch)], dtype=torch.int32, device="cuda"), + "h_q": + head, + "h_kv": + 1, + "d": + 512 + 64, + "dv": + 512, + "causal": + False, + "dtype": + torch.float16 +} for batch in [132] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="tilelang") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + # args = get_args() + # benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + # with open(f"{benchmark_type}_perf.csv", "w") as fout: + # fout.write("name,batch,seqlen,head,bw\n") + # for shape in shape_configs: + # if args.all: + # for target in available_targets: + # perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], + # shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], + # shape["causal"], shape["dtype"]) + # fout.write( + # f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + # ) + # elif args.compare: + # perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], + # shape["cache_seqlens"], shape["h_q"], shape["h_kv"], + # shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + # fout.write( + # f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n' + # ) + # fout.write( + # f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n' + # ) + 
# elif args.one: + # perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], + # shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], + # shape["causal"], shape["dtype"]) + # fout.write( + # f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n' + # ) + for shape in shape_configs: + for target in available_targets: + perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], + shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], + shape["causal"], shape["dtype"]) \ No newline at end of file From e2cc15a1ec37c169f79f19ed03ff5e0eda0ad2a6 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Fri, 21 Nov 2025 13:26:38 +0800 Subject: [PATCH 13/17] add mla sq2 draw --- docker/draw/draw_mla.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/docker/draw/draw_mla.py b/docker/draw/draw_mla.py index bad3f65..4feb549 100644 --- a/docker/draw/draw_mla.py +++ b/docker/draw/draw_mla.py @@ -6,15 +6,20 @@ x = np.arange(len(ctx)) # 真正用于画图的位置:0,1,2,... 
# ====================== 示例数据,按需替换 ====================== -# FP16, causal = False +# FP16, causal = False, sq=1 fp16_nc_fm = np.array([436, 510, 543, 564, 582, 601]) fp16_nc_txl = np.array([401, 490, 518, 535, 538, 561]) fp16_nc_triton= np.array([19, 28, 34, 38, 44, 46]) fp16_nc_tile = np.array([237, 412, 459, 473, 498, 477]) fp16_nc_fi = np.array([406, 491, 527, 532, 528, 552]) +# FP16, causal = False, sq=2 +fp16_nc2_fm = np.array([521, 591, 621, 628, 579, 626]) +fp16_nc2_txl = np.array([475, 531, 554, 557, 565, 587]) +fp16_nc2_fi = np.array([486, 534, 535, 545, 536, 539]) # =============================================================== -methods = ["FlashMLA", "Txl", "Triton", "TileLang", "Flashinfer"] +methods_q1 = ["FlashMLA", "Txl", "Triton", "TileLang", "Flashinfer"] +methods_q2 = ["FlashMLA", "Txl", "Flashinfer"] colors = { "FlashMLA": "#f1c40f", @@ -25,17 +30,22 @@ } data = { - ("FP16, causal=False"): { + ("FP16, causal=False, s_q=1"): { "FlashMLA": fp16_nc_fm, "Txl": fp16_nc_txl, "Triton": fp16_nc_triton, "TileLang": fp16_nc_tile, "Flashinfer": fp16_nc_fi, }, + ("FP16, causal=False, s_q=2"): { + "FlashMLA": fp16_nc2_fm, + "Txl": fp16_nc2_txl, + "Flashinfer": fp16_nc2_fi, + }, } # fig, axes = plt.subplots(2, 2, figsize=(12, 4), sharex=True) -fig, ax = plt.subplots(figsize=(6, 4)) +fig, ax = plt.subplots(1, 2, figsize=(12, 4), sharex=True) bar_width = 0.16 @@ -70,20 +80,28 @@ def plot_panel(ax, title, panel_data, ylim, yticks, methods): # 上排:FP16 plot_panel( - ax, - "FP16, causal=false", - data[("FP16, causal=False")], + ax[0], + "FP16, causal=false, s_q=1", + data[("FP16, causal=False, s_q=1")], ylim=(0, 800), yticks=[0, 200, 400, 600, 800], - methods=methods, + methods=methods_q1, ) +plot_panel( + ax[1], + "FP16, causal=false, s_q=2", + data[("FP16, causal=False, s_q=2")], + ylim=(0, 800), + yticks=[0, 200, 400, 600, 800], + methods=methods_q2, +) # axes[0, 0].set_ylabel("Throughput (TFLOPs/s)") # axes[1, 0].set_ylabel("Throughput (TFLOPs/s)") # axes[1, 
0].set_xlabel("Context length") # axes[1, 1].set_xlabel("Context length") -handles, labels = ax.get_legend_handles_labels() +handles, labels = ax[0].get_legend_handles_labels() fig.legend( handles, labels, @@ -92,8 +110,9 @@ def plot_panel(ax, title, panel_data, ylim, yticks, methods): bbox_to_anchor=(0.5, 0.98), ) -ax.set_ylabel("Throughput (TFLOPs/s)") -ax.set_xlabel("Context length") +ax[0].set_ylabel("Throughput (TFLOPs/s)") +ax[0].set_xlabel("Context length") +ax[1].set_xlabel("Context length") plt.subplots_adjust(top=0.82) From 9de279de6511d3c02a159b890db9699debbc0080 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Fri, 21 Nov 2025 18:18:22 +0800 Subject: [PATCH 14/17] update fp8 gemm performance --- docker/draw/draw_gemm.py | 4 +- python/txl/tutorials/01-matmul.py | 212 ++++++++++++++++++++++++++++-- 2 files changed, 201 insertions(+), 15 deletions(-) diff --git a/docker/draw/draw_gemm.py b/docker/draw/draw_gemm.py index 3227109..474e0c4 100644 --- a/docker/draw/draw_gemm.py +++ b/docker/draw/draw_gemm.py @@ -11,9 +11,9 @@ tile_fp16 = np.array([300, 420, 600, 690, 700, 720, 740]) tk_fp16 = np.array([400, 680, 680, 709, 780, 788, 798]) -cublas_fp8 = np.array([876, 1188, 1385, 1503, 1561, 1573, 1436]) +cublas_fp8 = np.array([888, 1203, 1385, 1503, 1561, 1573, 1436]) # tawa_fp8 = np.array([900, 1470, 1600, 1600, 1550, 1500, 1400]) -txl_fp8 = np.array([807, 1081, 1265, 1357, 1300, 1270, 1242]) +txl_fp8 = np.array([1015, 1308, 1437, 1543, 1565, 1509, 1432]) triton_fp8 = np.array([720, 1212, 1502, 1530, 1535, 1528, 1478]) tile_fp8 = np.array([231, 312, 547, 712, 892, 930, 1003]) tk_fp8 = np.array([579, 860, 1232, 1398, 1497, 1503, 1429]) diff --git a/python/txl/tutorials/01-matmul.py b/python/txl/tutorials/01-matmul.py index beb56d2..ad0d321 100644 --- a/python/txl/tutorials/01-matmul.py +++ b/python/txl/tutorials/01-matmul.py @@ -1073,6 +1073,19 @@ def grid(META): key=["M", "N", "K"], use_cuda_graph=True, ) +# @txl.autotune( +# configs=[ +# 
txl.Config({"BLOCK_SIZE_M": BM, "BLOCK_SIZE_N": BN, "BLOCK_SIZE_K": BK, "GROUP_SIZE_M": 8, "NUM_CONSUMER_GROUPS": 2, "NUM_STAGES": s,}, +# num_stages=s, +# num_warps=4, +# num_warpgroups=3, +# pre_hook=matmul_tma_set_block_size_hook +# ) +# for BM in [128] for BN in [128, 256] for BK in [64, 128] for s in [3,4,5] +# ], +# key=["M", "N", "K"], +# use_cuda_graph=True, +# ) #@txl.jit(launch_metadata=_matmul_launch_metadata, diff_mode='llir', log_dir='dump') #@txl.jit(launch_metadata=_matmul_launch_metadata, diff_mode='llir') #@txl.jit(launch_metadata=_matmul_launch_metadata, src_file=filename) @@ -1115,6 +1128,160 @@ def matmul_persistent_ws_tma_txl_kernel( mbar_consumer2 = txl.mbar_alloc(128, num_stages=NUM_STAGES) + if txl.is_warpgroup([0]): + + phase = 1 + bufIdx = 0 + for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)): + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + #pid_m = pid % num_pid_m + #pid_n = pid // num_pid_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + mbar_c1 = txl.get_buffer(mbar_consumer1, bufIdx) + mbar_c2 = txl.get_buffer(mbar_consumer2, bufIdx) + + mbar_p_a0 = txl.get_buffer(mbar_producer_a0, bufIdx) + mbar_p_a1 = txl.get_buffer(mbar_producer_a1, bufIdx) + mbar_p_b0 = txl.get_buffer(mbar_producer_b0, bufIdx) + + a0_buf = txl.get_buffer(a0, bufIdx) + a1_buf = txl.get_buffer(a1, bufIdx) + b0_buf = txl.get_buffer(b0, bufIdx) + + txl.mbar_wait(mbar_c1, phase) + txl.mbar_expect(mbar_p_a0, BLOCK_SIZE_M//2*BLOCK_SIZE_K*byte_count) + txl.tma_load(a0_buf, a_desc, [offs_am, offs_k], mbar_p_a0) + + txl.mbar_wait(mbar_c2, phase) + txl.mbar_expect(mbar_p_b0, BLOCK_SIZE_N*BLOCK_SIZE_K*byte_count) + txl.tma_load(b0_buf, b_desc, [offs_bn, offs_k], mbar_p_b0) + 
+ + txl.mbar_expect(mbar_p_a1, BLOCK_SIZE_M//2*BLOCK_SIZE_K*byte_count) + txl.tma_load(a1_buf, a_desc, [offs_am + BLOCK_SIZE_M // 2, offs_k], mbar_p_a1) + + offs_k += BLOCK_SIZE_K + bufIdx = (bufIdx + 1) % NUM_STAGES + if bufIdx == 0: + phase = phase^1 + + if txl.is_warpgroup([1, 2]): # TODO: else + phase = 0 + bufIdx = 0 + for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)): + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + #pid_m = pid % num_pid_m + #pid_n = pid // num_pid_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + accumulator = tl.zeros((BLOCK_SIZE_M//2, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + mbar_p_b0 = txl.get_buffer(mbar_producer_b0, bufIdx) + + b0_buf = txl.get_buffer(b0, bufIdx) + txl.mbar_wait(mbar_p_b0, phase) + if txl.is_warpgroup([1]): + mbar_p_a0 = txl.get_buffer(mbar_producer_a0, bufIdx) + mbar_c1 = txl.get_buffer(mbar_consumer1, bufIdx) + a0_buf = txl.get_buffer(a0, bufIdx) + txl.mbar_wait(mbar_p_a0, phase) + accumulator = tl.dot(a0_buf, b0_buf.T, accumulator) # accumulator is reg, no contention among buffers + txl.dot_wait(0) + txl.mbar_arrive(mbar_c1) + if txl.is_warpgroup([2]): # TODO: else test + mbar_p_a1 = txl.get_buffer(mbar_producer_a1, bufIdx) + mbar_c2 = txl.get_buffer(mbar_consumer2, bufIdx) + a1_buf = txl.get_buffer(a1, bufIdx) + txl.mbar_wait(mbar_p_a1, phase) + accumulator = tl.dot(a1_buf, b0_buf.T, accumulator) + txl.dot_wait(0) + txl.mbar_arrive(mbar_c2) + + offs_k += BLOCK_SIZE_K + bufIdx = (bufIdx + 1) % NUM_STAGES + if bufIdx == 0: # TODO: pipelinestate + phase = phase^1 + + c = accumulator.to(dtype) + if txl.is_warpgroup([1]): + c_desc.store([offs_am, offs_bn], c) + if txl.is_warpgroup([2]): + c_desc.store([offs_am + 
BLOCK_SIZE_M//2, offs_bn], c) + +@txl.autotune( + configs=[ + txl.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "NUM_CONSUMER_GROUPS": 2, + "NUM_STAGES": 4, + }, + num_stages=4, + num_warps=4, + num_warpgroups=3, + pre_hook=matmul_tma_set_block_size_hook + ), + ], + key=["M", "N", "K"], + use_cuda_graph=True, +) +@txl.jit(launch_metadata=_matmul_launch_metadata) +def matmul_persistent_ws_tma_txl_kernel_fp8( + a_desc, + b_desc, + c_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + NUM_CONSUMER_GROUPS: tl.constexpr, + NUM_STAGES: tl.constexpr, + + # 3.4.x + #EPILOGUE_SUBTILE: tl.constexpr, # + NUM_SMS: tl.constexpr, # + WARP_SPECIALIZE: tl.constexpr, # +): + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 + byte_count: tl.constexpr = 2 if dtype == tl.float16 else 1 + num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + a0 = txl.smem_alloc([BLOCK_SIZE_M//2, BLOCK_SIZE_K], dtype=dtype, num_stages=NUM_STAGES) + a1 = txl.smem_alloc([BLOCK_SIZE_M//2, BLOCK_SIZE_K], dtype=dtype, num_stages=NUM_STAGES) + b0 = txl.smem_alloc([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=dtype, num_stages=NUM_STAGES) + + mbar_producer_a0 = txl.mbar_alloc(1, num_stages=NUM_STAGES) + mbar_producer_a1 = txl.mbar_alloc(1, num_stages=NUM_STAGES) + mbar_producer_b0 = txl.mbar_alloc(1, num_stages=NUM_STAGES) + mbar_consumer1 = txl.mbar_alloc(128, num_stages=NUM_STAGES) + mbar_consumer2 = txl.mbar_alloc(128, num_stages=NUM_STAGES) + + if txl.is_warpgroup([0]): phase = 1 @@ -1237,14 +1404,22 @@ def grid(META): NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), ), ) - - matmul_persistent_ws_tma_txl_kernel[grid]( + if dtype == torch.float8_e4m3fn: + 
matmul_persistent_ws_tma_txl_kernel_fp8[grid]( a_desc, b_desc, c_desc, # M, N, K, # FP8_OUTPUT=dtype == torch.float8_e4m3fn, # NUM_SMS=NUM_SMS, # WARP_SPECIALIZE=False, # ) + if dtype == torch.float16: + matmul_persistent_ws_tma_txl_kernel[grid]( + a_desc, b_desc, c_desc, # + M, N, K, # + FP8_OUTPUT=dtype == torch.float8_e4m3fn, # + NUM_SMS=NUM_SMS, # + WARP_SPECIALIZE=False, # + ) return c ########################################## @@ -1862,8 +2037,8 @@ def bench(K, dtype, reps=100, warmup_reps=25): b = bn.T.contiguous() - if cublas is not None: - bench_fn("cublas", reps, warmup_reps, cublas_matmul, a, b) + # if cublas is not None: + # bench_fn("cublas", reps, warmup_reps, cublas_matmul, a, b) #if dtype == torch.float16: # bench_fn("torch", reps, warmup_reps, torch_matmul, a, b) #bench_fn("naive", reps, warmup_reps, matmul, a, b.T) @@ -1947,13 +2122,22 @@ def validate(M, N, K, dtype, log=False): print() -def show_profile(precision, profile_name): +# def show_profile(precision, profile_name): +# import triton.profiler.viewer as proton_viewer +# metric_names = ["time/ms"] +# if precision == 'fp8': +# metric_names = ["tflop8/s"] + metric_names +# elif precision == 'fp16': +# metric_names = ["tflop16/s"] + metric_names +# file_name = f"{profile_name}.hatchet" +# tree, metrics = proton_viewer.parse(metric_names, file_name) +# proton_viewer.print_tree(tree, metrics) + +def show_profile(profile_name): import triton.profiler.viewer as proton_viewer metric_names = ["time/ms"] - if precision == 'fp8': - metric_names = ["tflop8/s"] + metric_names - elif precision == 'fp16': - metric_names = ["tflop16/s"] + metric_names + metric_names = ["tflop8/s"] + metric_names + metric_names = ["tflop16/s"] + metric_names file_name = f"{profile_name}.hatchet" tree, metrics = proton_viewer.parse(metric_names, file_name) proton_viewer.print_tree(tree, metrics) @@ -2010,10 +2194,12 @@ def profile(M, N, K, dtype, log=False): # print(dtype) proton.start("matmul", hook="triton") - 
#proton.deactivate() + proton.deactivate() # for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): # bench(K, dtype) - for k_seqlen in [256, 512, 1024, 2048, 4096, 8192, 16384]: - bench(k_seqlen, dtype) + for dtype in [torch.float8_e4m3fn, torch.float16]: + for k_seqlen in [256, 512, 1024, 2048, 4096, 8192, 16384]: + bench(k_seqlen, dtype) proton.finalize() - show_profile(args.prec, "matmul") + + show_profile("matmul") From 0732868a87629a47e7cca6308ea2e73db34ad99c Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Sat, 22 Nov 2025 20:57:11 +0800 Subject: [PATCH 15/17] update attn FP8 performance --- docker/draw/draw_attention.py | 14 ++- python/txl/tutorials/02-flash-attention.py | 105 +++++++++++++++------ 2 files changed, 83 insertions(+), 36 deletions(-) diff --git a/docker/draw/draw_attention.py b/docker/draw/draw_attention.py index 3f391e5..ff5b054 100644 --- a/docker/draw/draw_attention.py +++ b/docker/draw/draw_attention.py @@ -21,11 +21,9 @@ fp16_c_tk = np.array([330, 430, 510, 560, 550]) # FP8, causal = False Txl 还没有支持 FP8 -fp8_nc_fa3 = np.array([550, 750, 850, 900, 880]) -fp8_nc_tawa = np.array([500, 700, 780, 820, 800]) -fp8_nc_triton = np.array([420, 600, 700, 730, 710]) -fp8_nc_tile = np.array([440, 620, 720, 760, 740]) -fp8_nc_tk = np.array([430, 610, 710, 750, 730]) +fp8_nc_fa3 = np.array([587, 770, 850, 900, 980]) +fp8_nc_txl = np.array([547, 627, 676, 703, 706]) +fp8_nc_triton = np.array([520, 610, 690, 716, 723]) # FP8, causal = True Txl 还没有支持 FP8 fp8_c_fa3 = np.array([600, 800, 900, 950, 930]) @@ -62,7 +60,7 @@ }, ("FP8, causal=False"): { "FA3 (CUTLASS)": fp8_nc_fa3, - "Txl": fp8_nc_tawa, + "Txl": fp8_nc_txl, "Triton": fp8_nc_triton, }, ("FP8, causal=True"): { @@ -128,7 +126,7 @@ def plot_panel(ax, title, panel_data, ylim, yticks, methods): axes[1, 0], "FP8, causal=false", data[("FP8, causal=False")], - ylim=(0, 1000), + ylim=(0, 1250), yticks=[0, 250, 500, 750, 1000], methods=methods_down, ) @@ -136,7 +134,7 @@ def 
plot_panel(ax, title, panel_data, ylim, yticks, methods): axes[1, 1], "FP8, causal=true", data[("FP8, causal=True")], - ylim=(0, 1000), + ylim=(0, 1250), yticks=[0, 250, 500, 750, 1000], methods=methods_down, ) diff --git a/python/txl/tutorials/02-flash-attention.py b/python/txl/tutorials/02-flash-attention.py index fcc46c2..b5eecc2 100644 --- a/python/txl/tutorials/02-flash-attention.py +++ b/python/txl/tutorials/02-flash-attention.py @@ -88,10 +88,14 @@ def _host_descriptor_pre_hook(nargs): BLOCK_M = nargs["BLOCK_M"] // NUM_CONSUMER_GROUPS BLOCK_N = nargs["BLOCK_N"] HEAD_DIM = nargs["HEAD_DIM"] + FP8_OUTPUT = nargs.get("FP8_OUTPUT", False) if not isinstance(nargs["desc_q"], TensorDescriptor): return nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] - nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] + if FP8_OUTPUT: + nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N] + else: + nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] @@ -1572,8 +1576,13 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # # If no host desc, then make device desc desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], block_shape=[BLOCK_M, HEAD_DIM]) - desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_N, HEAD_DIM]) + if FP8_OUTPUT: #v_shape = (BATCH, H, HEAD_DIM, N_CTX) + y_dim_v = Z * H * HEAD_DIM + desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim_v, N_CTX], strides=[N_CTX, 1], + block_shape=[HEAD_DIM, BLOCK_N]) + else: + desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], block_shape=[BLOCK_N, HEAD_DIM]) desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], @@ -1591,7 +1600,10 @@ def 
_attn_fwd_ws_tma_txl4(sm_scale, M, # pMbar_bQ1 = txl.mbar_alloc(1) bK = txl.smem_alloc([BLOCK_N, HEAD_DIM], dtype=dtype, num_stages=NUM_STAGES) - bV = txl.smem_alloc([BLOCK_N, HEAD_DIM], dtype=dtype, num_stages=NUM_STAGES) + if FP8_OUTPUT: + bV = txl.smem_alloc([HEAD_DIM, BLOCK_N], dtype=dtype, num_stages=NUM_STAGES) + else: + bV = txl.smem_alloc([BLOCK_N, HEAD_DIM], dtype=dtype, num_stages=NUM_STAGES) pMbar_bK = txl.mbar_alloc(1, num_stages=NUM_STAGES) pMbar_bV = txl.mbar_alloc(1, num_stages=NUM_STAGES) @@ -1652,7 +1664,10 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # txl.mbar_wait(cur_mbar_PV1, phase) txl.mbar_wait(cur_mbar_PV2, phase) txl.mbar_expect(cur_mbar_bV, BLOCK_N * HEAD_DIM * byte_count) - txl.tma_load(cur_bV, desc_v, [offsetkv_y, 0], cur_mbar_bV) + if FP8_OUTPUT: + txl.tma_load(cur_bV, desc_v, [off_hz * HEAD_DIM, start_n], cur_mbar_bV) + else: + txl.tma_load(cur_bV, desc_v, [offsetkv_y, 0], cur_mbar_bV) offsetkv_y += BLOCK_N bufIdxW = (bufIdxW + 1) % NUM_STAGES @@ -1797,7 +1812,10 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # # txl.bar_wait(WG2_BAR, WG_NUM_THREADS) # note that this non transposed v for FP8 is only supported on Blackwell - acc = tl.dot(p, cur_bV, acc) + if FP8_OUTPUT: + acc = tl.dot(p, cur_bV.T, acc) + else: + acc = tl.dot(p, cur_bV, acc) with pl.scope("dotQK"): txl.dot_wait(1) @@ -1864,7 +1882,10 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # txl.mbar_wait(cur_mbar_bV, phaseV) # note that this non transposed v for FP8 is only supported on Blackwell - acc = tl.dot(p, cur_bV, acc) + if FP8_OUTPUT: + acc = tl.dot(p, cur_bV.T, acc) + else: + acc = tl.dot(p, cur_bV, acc) txl.dot_wait(0) #txl.mbar_arrive(cur_mbar_PV) @@ -2615,7 +2636,7 @@ def forward(ctx, q, k, v, causal, sm_scale, algo=0, no_tune=False, profiling=Fal HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. 
HEAD_DIM_V = v.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_Q == HEAD_DIM_K assert HEAD_DIM_K in {16, 32, 64, 128, 256} o = torch.empty_like(q) stage = 3 if causal else 1 @@ -2630,10 +2651,16 @@ def forward(ctx, q, k, v, causal, sm_scale, algo=0, no_tune=False, profiling=Fal if supports_host_descriptor(): # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor y_dim = q.shape[0] * q.shape[1] * q.shape[2] + dtype = q.dtype dummy_block = [1, 1] desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) - desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + if dtype == torch.float8_e5m2: + n_ctx = v.shape[3] + y_dim_v = v.shape[0] * v.shape[1] * v.shape[2] + desc_v = TensorDescriptor(v, shape=[y_dim_v, n_ctx], strides=[n_ctx, 1], block_shape=dummy_block) + else: + desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) else: @@ -2781,19 +2808,20 @@ def backward(ctx, do): #@pytest.mark.parametrize("causal", [True]) # FIXME: Non-causal tests do not pass at the moment. 
#@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False]) def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune=False, profiling=False): - #torch.manual_seed(20) - #q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)) - #k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)) - #v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)) - q = (torch.randn((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)) - k = (torch.randn((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)) - v = (torch.randn((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE)) - #sm_scale = 0.5 - #sm_scale = 1.0 + torch.manual_seed(20) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE).normal_(mean=0.0, std=0.5)) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE).normal_(mean=0.0, std=0.5)) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=torch.float16, device=DEVICE).normal_(mean=0.0, std=0.5)) q1 = q.permute(0,2,1,3).contiguous() k1 = k.permute(0,2,1,3).contiguous() v1 = v.permute(0,2,1,3).contiguous() - + if dtype == torch.float8_e5m2: + v = v.permute(0,1,3,2).contiguous().to(torch.float8_e5m2) + q = q.to(torch.float8_e5m2) + k = k.to(torch.float8_e5m2) + print(f"q sample: {q[0,0,0,:8]}") + print(f"k sample: {k[0,0,0,:8]}") + print(f"v sample: {v[0,0,0,:8]}") test_outs = [] # txl if HAS_FLASH: @@ -2834,8 +2862,13 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune= ## OLD for actual_out in test_outs: + if dtype == torch.float8_e5m2: + actual_out = actual_out.to(torch.float16) + print(actual_out.shape) print(f"Output max diff: {(actual_out - ref_out).abs().max().item()}") print(f"Output mean diff: {(actual_out - ref_out).abs().mean().item()}") + print(f"actual sample: {actual_out[0,0,0,:8]}") + print(f"ref sample: 
{ref_out[0,0,0,:8]}") assert torch.allclose(ref_out, actual_out, atol=1e-2, rtol=0) #rtol = 0.0 @@ -2849,6 +2882,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune= BATCH, N_HEADS, HEAD_DIM = 4, 32, 128 TORCH_HAS_FP8 = True +TORCH_HAS_FP16 = True # vary seq length for fixed head and batch=4 configs = [] for mode in ["fwd"]: @@ -2862,9 +2896,9 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune= x_vals=[2**i for i in range(10, 15)], # x_vals=[2**i for i in range(14, 15)], line_arg="provider", - line_vals=(["triton-fp16"] if Has_TXL else []) + (["triton-fp8"] if TORCH_HAS_FP8 else []) + + line_vals=(["triton-fp16"] if TORCH_HAS_FP16 else []) + (["triton-fp8"] if TORCH_HAS_FP8 else []) + (["flash"] if HAS_FLASH else []), - line_names=(["Triton [FP16]"] if Has_TXL else []) + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + + line_names=(["Triton [FP16]"] if TORCH_HAS_FP16 else []) + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + (["Flash-3"] if HAS_FLASH else []), styles=[("red", "-"), ("blue", "-"), ("green", "-")], ylabel="TFLOPS", @@ -2886,15 +2920,30 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev BATCH = int(16384 / N_CTX) assert mode in ["fwd", "bwd"] dtype = torch.float16 - q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device) - k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device) - v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device) + # q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device) + # k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device) + # v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device) if "triton" in provider: + if mode == "fwd" and "fp8" in provider: + v_shape = (BATCH, H, HEAD_DIM, N_CTX) + else: + v_shape = (BATCH, H, N_CTX, HEAD_DIM) + q = torch.randn( + (BATCH, H, N_CTX, HEAD_DIM), + dtype=dtype, + device=device, + 
requires_grad=True, + ) + k = torch.randn( + (BATCH, H, N_CTX, HEAD_DIM), + dtype=dtype, + device=device, + requires_grad=True, + ) + v = torch.randn(v_shape, dtype=dtype, device=device, requires_grad=True) if mode == "fwd" and "fp8" in provider: q = q.to(torch.float8_e5m2) k = k.to(torch.float8_e5m2) - v = v.permute(0, 1, 3, 2).contiguous() - v = v.permute(0, 1, 3, 2) v = v.to(torch.float8_e5m2) sm_scale = 1/math.sqrt(HEAD_DIM) fn = lambda: attention(q, k, v, causal, sm_scale, algo, no_tune) @@ -2940,14 +2989,14 @@ def run_test(algo=0, dump_dir=None): no_tune=True # has best config #no_tune=False # no best config - print("TEST...") + # print("TEST...") #test_op(1, 2, 1024, 128, False, dtype=torch.float16, no_tune=no_tune) PROFILING=False #test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=0, no_tune=no_tune, profiling=PROFILING) #test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=1, no_tune=no_tune, profiling=PROFILING) #test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=2, no_tune=no_tune, profiling=PROFILING) - #test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=4, no_tune=no_tune, profiling=PROFILING) + # test_op(16, 32, 1024, 128, False, dtype=torch.float8_e5m2, algo=4, no_tune=no_tune, profiling=PROFILING) #test_op(1, 2, 1536, 128, False, dtype=torch.float16, algo=4, no_tune=no_tune, profiling=PROFILING) # test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=algo, no_tune=no_tune, profiling=PROFILING) From 96cedca65dbabd9633810c73598bfc1281cdfc16 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Sat, 22 Nov 2025 22:17:50 +0800 Subject: [PATCH 16/17] update txl attn causal --- docker/draw/draw_attention.py | 30 +- python/txl/tutorials/02-flash-attention.py | 303 ++++++++++++++++++++- 2 files changed, 308 insertions(+), 25 deletions(-) diff --git a/docker/draw/draw_attention.py b/docker/draw/draw_attention.py index ff5b054..a7a2039 100644 --- a/docker/draw/draw_attention.py +++ 
b/docker/draw/draw_attention.py @@ -13,24 +13,22 @@ fp16_nc_tile = np.array([447, 590, 610, 570, 600]) fp16_nc_tk = np.array([453, 597, 599, 610, 590]) -# FP16, causal = True Txl 还没有支持 causal attention -fp16_c_fa3 = np.array([420, 520, 620, 680, 650]) -fp16_c_tawa = np.array([380, 480, 580, 630, 620]) -fp16_c_triton = np.array([320, 420, 500, 550, 540]) -fp16_c_tile = np.array([340, 440, 520, 570, 560]) -fp16_c_tk = np.array([330, 430, 510, 560, 550]) - -# FP8, causal = False Txl 还没有支持 FP8 +# FP16, causal = True +fp16_c_fa3 = np.array([387, 512, 589, 613,622]) +fp16_c_txl = np.array([335, 429, 497, 539, 573]) +fp16_c_triton = np.array([238, 376, 421, 473, 500]) +fp16_c_tile = np.array([346, 436, 511, 478, 509]) +fp16_c_tk = np.array([302, 421, 437, 457, 477]) + +# FP8, causal = False fp8_nc_fa3 = np.array([587, 770, 850, 900, 980]) fp8_nc_txl = np.array([547, 627, 676, 703, 706]) fp8_nc_triton = np.array([520, 610, 690, 716, 723]) -# FP8, causal = True Txl 还没有支持 FP8 -fp8_c_fa3 = np.array([600, 800, 900, 950, 930]) -fp8_c_tawa = np.array([540, 720, 800, 840, 820]) -fp8_c_triton = np.array([450, 630, 720, 760, 740]) -fp8_c_tile = np.array([470, 650, 740, 780, 760]) -fp8_c_tk = np.array([460, 640, 730, 770, 750]) +# FP8, causal = True +fp8_c_fa3 = np.array([379, 578, 738, 812, 864]) +fp8_c_txl = np.array([343, 485, 591, 639, 664]) +fp8_c_triton = np.array([333, 418, 558, 623, 679]) # =============================================================== methods_up = ["FA3 (CUTLASS)", "Txl", "Triton", "TileLang", "ThunderKittens"] @@ -53,7 +51,7 @@ }, ("FP16, causal=True"): { "FA3 (CUTLASS)": fp16_c_fa3, - "Txl": fp16_c_tawa, + "Txl": fp16_c_txl, "Triton": fp16_c_triton, "TileLang": fp16_c_tile, "ThunderKittens": fp16_c_tk, @@ -65,7 +63,7 @@ }, ("FP8, causal=True"): { "FA3 (CUTLASS)": fp8_c_fa3, - "Txl": fp8_c_tawa, + "Txl": fp8_c_txl, "Triton": fp8_c_triton, }, } diff --git a/python/txl/tutorials/02-flash-attention.py b/python/txl/tutorials/02-flash-attention.py index 
b5eecc2..b9c49ae 100644 --- a/python/txl/tutorials/02-flash-attention.py +++ b/python/txl/tutorials/02-flash-attention.py @@ -1900,6 +1900,280 @@ def _attn_fwd_ws_tma_txl4(sm_scale, M, # if txl.is_warpgroup([2]): desc_o.store([qo_offset_y+BLOCK_M//2, 0], acc.to(dtype)) +@txl.autotune( + configs=[ + txl.Config( + tma_ws_best_config, + num_stages=2, + num_warps=4, + num_warpgroups=3, + pre_hook = _host_descriptor_pre_hook, + ) + ], + key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"], +) +@txl.jit +def _attn_fwd_ws_tma_txl4_causal(sm_scale, M, # + Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + STAGE: tl.constexpr, # + warp_specialize: tl.constexpr, # + + # NOTE: txl + NUM_STAGES: tl.constexpr, # + NUM_CONSUMERS: tl.constexpr # + ): + + dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 + byte_count: tl.constexpr = 2 if dtype == tl.float16 else 1 + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + y_dim = Z * H * N_CTX + desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_M, HEAD_DIM]) + if FP8_OUTPUT: + y_dim_v = Z * H * HEAD_DIM + desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim_v, N_CTX], strides=[N_CTX, 1], + block_shape=[HEAD_DIM, BLOCK_N]) + else: + desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_M, HEAD_DIM]) + + offset_y = off_z * (N_CTX * H) + off_h * N_CTX + qo_offset_y = offset_y + start_m * BLOCK_M + + bQ0 = txl.smem_alloc([BLOCK_M//2, HEAD_DIM], dtype=dtype) + pMbar_bQ0 = txl.mbar_alloc(1) + bQ1 
= txl.smem_alloc([BLOCK_M//2, HEAD_DIM], dtype=dtype) + pMbar_bQ1 = txl.mbar_alloc(1) + + bK = txl.smem_alloc([BLOCK_N, HEAD_DIM], dtype=dtype, num_stages=NUM_STAGES) + if FP8_OUTPUT: + bV = txl.smem_alloc([HEAD_DIM, BLOCK_N], dtype=dtype, num_stages=NUM_STAGES) + else: + bV = txl.smem_alloc([BLOCK_N, HEAD_DIM], dtype=dtype, num_stages=NUM_STAGES) + pMbar_bK = txl.mbar_alloc(1, num_stages=NUM_STAGES) + pMbar_bV = txl.mbar_alloc(1, num_stages=NUM_STAGES) + + cMbar_QK1 = txl.mbar_alloc(128, num_stages=NUM_STAGES) + cMbar_PV1 = txl.mbar_alloc(128, num_stages=NUM_STAGES) + cMbar_QK2 = txl.mbar_alloc(128, num_stages=NUM_STAGES) + cMbar_PV2 = txl.mbar_alloc(128, num_stages=NUM_STAGES) + + lo, hi = 0, (start_m + 1) * BLOCK_M + mask_begin = start_m * BLOCK_M + offsetkv_y = offset_y + lo + + if txl.is_warpgroup([0]): + bQ0i = txl.get_buffer(bQ0, 0) + pMbar_bQ0i = txl.get_buffer(pMbar_bQ0, 0) + bQ1i = txl.get_buffer(bQ1, 0) + pMbar_bQ1i = txl.get_buffer(pMbar_bQ1, 0) + + txl.mbar_expect(pMbar_bQ0i, BLOCK_M // 2 * HEAD_DIM * byte_count) + txl.tma_load(bQ0i, desc_q, [qo_offset_y, 0], pMbar_bQ0i) + txl.mbar_wait(pMbar_bQ0i, 0) + txl.mbar_expect(pMbar_bQ1i, BLOCK_M // 2 * HEAD_DIM * byte_count) + txl.tma_load(bQ1i, desc_q, [qo_offset_y+BLOCK_M//2, 0], pMbar_bQ1i) + txl.mbar_wait(pMbar_bQ1i, 0) + + bufIdxW = 0 + phase = 1 + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + cur_mbar_bK = txl.get_buffer(pMbar_bK, bufIdxW) + cur_mbar_bV = txl.get_buffer(pMbar_bV, bufIdxW) + cur_bK = txl.get_buffer(bK, bufIdxW) + cur_bV = txl.get_buffer(bV, bufIdxW) + + cur_mbar_QK1 = txl.get_buffer(cMbar_QK1, bufIdxW) # wait for the same buffer + cur_mbar_PV1 = txl.get_buffer(cMbar_PV1, bufIdxW) + cur_mbar_QK2 = txl.get_buffer(cMbar_QK2, bufIdxW) + cur_mbar_PV2 = txl.get_buffer(cMbar_PV2, bufIdxW) + + # TODO: tma_expect_and_load + txl.mbar_wait(cur_mbar_QK1, phase) + txl.mbar_wait(cur_mbar_QK2, phase) + txl.mbar_expect(cur_mbar_bK, BLOCK_N * HEAD_DIM * 
byte_count) + txl.tma_load(cur_bK, desc_k, [offsetkv_y, 0], cur_mbar_bK) + + txl.mbar_wait(cur_mbar_PV1, phase) + txl.mbar_wait(cur_mbar_PV2, phase) + txl.mbar_expect(cur_mbar_bV, BLOCK_N * HEAD_DIM * byte_count) + if FP8_OUTPUT: + txl.tma_load(cur_bV, desc_v, [off_hz * HEAD_DIM, start_n], cur_mbar_bV) + else: + txl.tma_load(cur_bV, desc_v, [offsetkv_y, 0], cur_mbar_bV) + + offsetkv_y += BLOCK_N + bufIdxW = (bufIdxW + 1) % NUM_STAGES + if bufIdxW == 0: + phase = phase^1 + + + if txl.is_warpgroup([1, 2]): + + if txl.is_warpgroup([1]): + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M//2) + txl.bar_arrive(8, 256) + else: + offs_m = start_m * BLOCK_M + tl.arange(BLOCK_M//2, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + + m_i = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M//2], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M//2, HEAD_DIM], dtype=tl.float32) + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + + bQ0i = txl.get_buffer(bQ0, 0) + pMbar_bQ0i = txl.get_buffer(pMbar_bQ0, 0) + bQ1i = txl.get_buffer(bQ1, 0) + pMbar_bQ1i = txl.get_buffer(pMbar_bQ1, 0) + + if txl.is_warpgroup([1]): + txl.mbar_wait(pMbar_bQ0i, 0) + if txl.is_warpgroup([2]): + txl.mbar_wait(pMbar_bQ1i, 0) + + cur_mbar_bK = txl.get_buffer(pMbar_bK, 0) + cur_bK = txl.get_buffer(bK, 0) + txl.mbar_wait(cur_mbar_bK, 0) + + if txl.is_warpgroup([1]): + cur_mbar_QK = txl.get_buffer(cMbar_QK1, 0) + qk = tl.dot(bQ0i, cur_bK.T) + txl.dot_wait(0) + txl.mbar_arrive(cur_mbar_QK) + + else: # [2] + cur_mbar_QK = txl.get_buffer(cMbar_QK2, 0) + qk = tl.dot(bQ1i, cur_bK.T) + txl.dot_wait(0) + txl.mbar_arrive(cur_mbar_QK) + + if lo >= mask_begin: + mask = offs_m[:, None] >= (lo + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - 
m_ij) + l_i = l_i * alpha + l_ij + m_i = m_ij + + p = p.to(dtype) + + bufIdxRK = 1 + bufIdxRV = 0 + phaseK = 0 + phaseV = 0 + + for start_n in range(lo+BLOCK_N, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + cur_mbar_bK = txl.get_buffer(pMbar_bK, bufIdxRK) + cur_bK = txl.get_buffer(bK, bufIdxRK) + txl.mbar_wait(cur_mbar_bK, phaseK) + + if txl.is_warpgroup([1]): + txl.bar_wait(8, 256) + if txl.is_warpgroup([2]): + txl.bar_wait(9, 256) + + if txl.is_warpgroup([1]): + cur_mbar_QK = txl.get_buffer(cMbar_QK1, bufIdxRK) + cur_mbar_PV = txl.get_buffer(cMbar_PV1, bufIdxRV) + qk = tl.dot(bQ0i, cur_bK.T) + + else: # [2] + cur_mbar_QK = txl.get_buffer(cMbar_QK2, bufIdxRK) + cur_mbar_PV = txl.get_buffer(cMbar_PV2, bufIdxRV) + qk = tl.dot(bQ1i, cur_bK.T) + + cur_mbar_bV = txl.get_buffer(pMbar_bV, bufIdxRV) + cur_bV = txl.get_buffer(bV, bufIdxRV) + txl.mbar_wait(cur_mbar_bV, phaseV) + + if FP8_OUTPUT: + acc = tl.dot(p, cur_bV.T, acc) + else: + acc = tl.dot(p, cur_bV, acc) + txl.dot_wait(1) + + if txl.is_warpgroup([1]): + txl.bar_arrive(9, 256) + else: + txl.bar_arrive(8, 256) + + txl.mbar_arrive(cur_mbar_QK) + + if start_n >= mask_begin: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + m_i = m_ij + + p = p.to(dtype) + + txl.dot_wait(0) + txl.mbar_arrive(cur_mbar_PV) + + acc = acc * alpha[:, None] + + bufIdxRK = (bufIdxRK + 1) % NUM_STAGES + if bufIdxRK == 0: + phaseK = phaseK ^ 1 + bufIdxRV = (bufIdxRV + 1) % NUM_STAGES + if bufIdxRV == 0: + phaseV = phaseV ^ 1 + + cur_mbar_bV = txl.get_buffer(pMbar_bV, bufIdxRV) + + cur_bV = txl.get_buffer(bV, bufIdxRV) + txl.mbar_wait(cur_mbar_bV, phaseV) + + if FP8_OUTPUT: + acc = tl.dot(p, cur_bV.T, acc) + 
else: + acc = tl.dot(p, cur_bV, acc) + txl.dot_wait(0) + + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + + if txl.is_warpgroup([1]): + desc_o.store([qo_offset_y, 0], acc.to(dtype)) + if txl.is_warpgroup([2]): + desc_o.store([qo_offset_y+BLOCK_M//2, 0], acc.to(dtype)) + ################################################### # TXL + TAWA ################################################### @@ -2690,9 +2964,10 @@ def grid(META): 2: _attn_fwd_ws_tma_txl2, 3: _attn_fwd_ws_tma_txl3, 4: _attn_fwd_ws_tma_txl4, - 5: _attn_fwd_ws_tma_txl_tawa, - 6: _attn_fwd_ws_tma_txl_tawa2, - 7: _attn_fwd_ws_tma_txl_test, + 5: _attn_fwd_ws_tma_txl4_causal, + # 5: _attn_fwd_ws_tma_txl_tawa, + # 6: _attn_fwd_ws_tma_txl_tawa2, + # 7: _attn_fwd_ws_tma_txl_test, } if profiling: @@ -2886,7 +3161,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune= # vary seq length for fixed head and batch=4 configs = [] for mode in ["fwd"]: - for causal in [False]: + for causal in [False, True]: for warp_specialize in [False, True] if is_blackwell() else [False]: if mode == "bwd" and not causal: continue @@ -2915,7 +3190,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16, algo=0, no_tune= @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE, algo=0, no_tune=False): +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE, no_tune=False): # follow fa3 paper BATCH = int(16384 / N_CTX) assert mode in ["fwd", "bwd"] @@ -2946,7 +3221,10 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev k = k.to(torch.float8_e5m2) v = v.to(torch.float8_e5m2) sm_scale = 1/math.sqrt(HEAD_DIM) - fn = lambda: attention(q, k, v, causal, sm_scale, algo, no_tune) + if causal: + fn = lambda: attention(q, k, v, causal, sm_scale, 5, no_tune) + else: + fn = lambda: attention(q, k, v, 
causal, sm_scale, 4, no_tune) if mode == "bwd": o = fn() do = torch.randn_like(o) @@ -2989,7 +3267,7 @@ def run_test(algo=0, dump_dir=None): no_tune=True # has best config #no_tune=False # no best config - # print("TEST...") + print("TEST...") #test_op(1, 2, 1024, 128, False, dtype=torch.float16, no_tune=no_tune) PROFILING=False @@ -2997,13 +3275,20 @@ def run_test(algo=0, dump_dir=None): #test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=1, no_tune=no_tune, profiling=PROFILING) #test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=2, no_tune=no_tune, profiling=PROFILING) # test_op(16, 32, 1024, 128, False, dtype=torch.float8_e5m2, algo=4, no_tune=no_tune, profiling=PROFILING) + + # torch.float8_e5m2 + causal may cause large numerical error, compare with tawa output the kernel is correct + # only algo5 is causal + # test_op(16, 32, 1024, 128, True, dtype=torch.float8_e5m2, algo=5, no_tune=no_tune, profiling=PROFILING) + # test_op(16, 32, 1024, 128, True, dtype=torch.float16, algo=5, no_tune=no_tune, profiling=PROFILING) + #test_op(1, 2, 1536, 128, False, dtype=torch.float16, algo=4, no_tune=no_tune, profiling=PROFILING) # test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=algo, no_tune=no_tune, profiling=PROFILING) print("BENCH...") - bench_flash_attention.run(save_path=".", print_data=True, algo=algo, no_tune=no_tune) + bench_flash_attention.run(save_path=".", print_data=True, no_tune=no_tune) if __name__ == "__main__": #run_test(6, dump_dir='dump/fa1113') #run_test(5, dump_dir='dump/fa1117') - run_test(4) + run_test() + From 1c0b84e263814034e38c3775b68cc65824b938d1 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Sun, 23 Nov 2025 17:56:52 +0800 Subject: [PATCH 17/17] update gemm benchmark --- docker/modal_gemm.py | 18 ++++- docker/tilelang/example_gemm.py | 63 +++++++++++++++ docker/tilelang/example_tilelang_gemm_fp8.py | 67 ++++++++++++++++ python/txl/tutorials/01-matmul.py | 84 +++++++++++++------- requirements.txt | 1 + 
5 files changed, 202 insertions(+), 31 deletions(-) create mode 100644 docker/tilelang/example_gemm.py create mode 100644 docker/tilelang/example_tilelang_gemm_fp8.py diff --git a/docker/modal_gemm.py b/docker/modal_gemm.py index 9a2aa51..12ed7c6 100644 --- a/docker/modal_gemm.py +++ b/docker/modal_gemm.py @@ -6,21 +6,30 @@ txl_wheel_file = local_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" test_file = root_dir / "python" / "txl" / "tutorials" / "01-matmul.py" +# test_file = local_dir / "tilelang" / "benchmark_tilelang_matmul.py" +tilelang_gemm_fp8_file = local_dir / "tilelang" / "example_tilelang_gemm_fp8.py" +tilelang_gemm_file = local_dir / "tilelang" / "example_gemm.py" app = App(name="txl") # Note: this is optional since Modal 0.57 volume = Volume.from_name("txl-dump", create_if_missing=True) # create a cloud volume to store compiled dump files txl_image = ( - Image.debian_slim(python_version="3.12") + Image.from_registry( + "nvidia/cuda:12.4.0-devel-ubuntu22.04", + add_python="3.12", + ) #Image.from_dockerfile(path="./Dockerfile") .workdir("/workspace") .add_local_file(txl_wheel_file, remote_path="/workspace/", copy=True) # copy the local code to the image .run_commands( "ls .") .pip_install_from_requirements(requirements_file) # local file not remote file + .pip_install("tilelang") .run_commands( "pip install /workspace/txl-3.4.0-cp312-cp312-linux_x86_64.whl", ) .add_local_file(test_file, remote_path="/workspace/test_txl.py", copy=False) # copy after image build, no need rebuild + .add_local_file(tilelang_gemm_fp8_file, remote_path="/workspace/example_tilelang_gemm_fp8.py", copy=False) + .add_local_file(tilelang_gemm_file, remote_path="/workspace/example_gemm.py", copy=False) ) # Example function that uses the image @@ -100,6 +109,7 @@ def import_cuBLAS_lib(): ctypes.CDLL(str(cudart_dir / "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL) ctypes.CDLL(str(cublas_dir / "libcublasLt.so.12"), mode=ctypes.RTLD_GLOBAL) ctypes.CDLL(str(cublas_dir / "libcublas.so.12"), 
mode=ctypes.RTLD_GLOBAL) - - import_cuBLAS_lib() - test_demo() \ No newline at end of file + if get_gpu_type(): + print("Running on H100 SXM") + import_cuBLAS_lib() + test_demo() \ No newline at end of file diff --git a/docker/tilelang/example_gemm.py b/docker/tilelang/example_gemm.py new file mode 100644 index 0000000..0affa91 --- /dev/null +++ b/docker/tilelang/example_gemm.py @@ -0,0 +1,63 @@ +import tilelang +import tilelang.language as T + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", enable_rasteration=True): + + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local,transpose_B=True) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm + + +def main(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + print("c:") + print(c) + print("ref_c:") + print(ref_c) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # Get CUDA Source + print("CUDA Source:") + print(kernel.get_kernel_source()) + + # benchmark + profiler = kernel.get_profiler() + latency = profiler.do_bench(backend="cupti") + # latency = profiler.do_bench() + 
print(f"tilelang Latency: {latency}ms") + + +if __name__ == "__main__": + main() diff --git a/docker/tilelang/example_tilelang_gemm_fp8.py b/docker/tilelang/example_tilelang_gemm_fp8.py new file mode 100644 index 0000000..a47b213 --- /dev/null +++ b/docker/tilelang/example_tilelang_gemm_fp8.py @@ -0,0 +1,67 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float", enable_rasteration=True): + + @T.prim_func + def gemm_fp8( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=4): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_fp8 + + +def test_gemm_fp8(M, N, K, dtype): + torch_dtype = map_torch_type(dtype) + + kernel = matmul(M, N, K, 128, 128, 64, dtype) + + a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) + b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype) + + c = kernel(a, b) + + ref_c = (a.half() @ b.half().T).to(dtype=torch_dtype) + + print(c) + print(ref_c) + + diff = calc_diff(c, ref_c) + print(f"diff: 
{diff}") + assert diff < 1e-3 + + +def main(): + test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3') + test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2') + + +if __name__ == "__main__": + main() diff --git a/python/txl/tutorials/01-matmul.py b/python/txl/tutorials/01-matmul.py index ad0d321..10a8794 100644 --- a/python/txl/tutorials/01-matmul.py +++ b/python/txl/tutorials/01-matmul.py @@ -34,6 +34,9 @@ from typing import Optional import txl +from example_gemm import matmul as tilelang_matmul +from example_tilelang_gemm_fp8 import matmul as tilelang_matmul_fp8 + if torch.cuda.is_available(): from triton._C.libtriton import nvidia cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) @@ -2009,27 +2012,49 @@ def torch_matmul(a, b): c = torch.matmul(a, b.T) return c +def matmul_tilelang(a, b): + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + if dtype == torch.float8_e4m3fn: + kernel = tilelang_matmul_fp8(M, N, K, 128, 256, 128, 'float8_e4m3') + else: + kernel = tilelang_matmul(M, N, K, 128, 256, 64) + + bytes_per_elem = a.element_size() + flops_str = f"flops{bytes_per_elem * 8}" + + dtype_str = str(dtype).split('.')[-1] + with proton.scope(f"matmul_tilelang_{dtype_str} [M={M}, N={N}, K={K}]", + {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. 
* M * N * K}): + c = kernel(a, b) + return c @contextmanager -def proton_context(): - proton.activate(0) +def proton_context(session): + proton.activate(session) try: yield finally: - proton.deactivate(0) + proton.deactivate(session) -def bench_fn(label, reps, warmup_reps, fn, *args): +def bench_fn(label, reps, warmup_reps, session, fn, *args): print(f"Benchmarking {label}: ...", end="") for _ in range(warmup_reps): fn(*args) - with proton_context(): + with proton_context(session): for _ in range(reps): fn(*args) print(f"\rBenchmarking {label}: done") -def bench(K, dtype, reps=100, warmup_reps=25): +def bench(K, dtype, session, reps=100, warmup_reps=25): M = 8192 N = 8192 a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) @@ -2041,15 +2066,18 @@ def bench(K, dtype, reps=100, warmup_reps=25): # bench_fn("cublas", reps, warmup_reps, cublas_matmul, a, b) #if dtype == torch.float16: # bench_fn("torch", reps, warmup_reps, torch_matmul, a, b) - #bench_fn("naive", reps, warmup_reps, matmul, a, b.T) + # bench_fn("naive", reps, warmup_reps, matmul, a, b.T) #bench_fn("persistent", reps, warmup_reps, matmul_persistent, a, b.T) - #bench_fn("tma", reps, warmup_reps, lambda a, b: matmul_tma(a, b, False), a, b) + # bench_fn("tma", reps, warmup_reps, lambda a, b: matmul_tma(a, b, False), a, b) #bench_fn("async_load", reps, warmup_reps, matmul, a, bn) #bench_fn("async_load_txl", reps, warmup_reps, matmul_async_load_txl, a, bn) #bench_fn("naive_tma_txl", reps, warmup_reps, matmul_naive_tma_txl, a, b) #0 #bench_fn("tma_persistent_txl", reps, warmup_reps, matmul_tma_persistent_txl, a, b) #1 - bench_fn("tma_ws_persistent_txl", reps, warmup_reps, matmul_tma_ws_persistent_txl, a, b) #2 + bench_fn("cublas", reps, warmup_reps, session, cublas_matmul, a, b) + bench_fn("tma_ws_persistent_txl", reps, warmup_reps, session, matmul_tma_ws_persistent_txl, a, b) #2 + bench_fn(f"tma_ws_persistent_triton", reps, warmup_reps, session, lambda a, b: matmul_tma_persistent(a, b, True), 
a, b) + bench_fn("tilelang", reps, warmup_reps, session, matmul_tilelang, a, b) #bench_fn("tma_ws_nn_persistent_txl", reps, warmup_reps, matmul_tma_ws_nn_persistent_txl, a, bn) return warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False] @@ -2122,22 +2150,13 @@ def validate(M, N, K, dtype, log=False): print() -# def show_profile(precision, profile_name): -# import triton.profiler.viewer as proton_viewer -# metric_names = ["time/ms"] -# if precision == 'fp8': -# metric_names = ["tflop8/s"] + metric_names -# elif precision == 'fp16': -# metric_names = ["tflop16/s"] + metric_names -# file_name = f"{profile_name}.hatchet" -# tree, metrics = proton_viewer.parse(metric_names, file_name) -# proton_viewer.print_tree(tree, metrics) - -def show_profile(profile_name): +def show_profile(precision, profile_name): import triton.profiler.viewer as proton_viewer metric_names = ["time/ms"] - metric_names = ["tflop8/s"] + metric_names - metric_names = ["tflop16/s"] + metric_names + if precision == 'fp8': + metric_names = ["tflop8/s"] + metric_names + elif precision == 'fp16': + metric_names = ["tflop16/s"] + metric_names file_name = f"{profile_name}.hatchet" tree, metrics = proton_viewer.parse(metric_names, file_name) proton_viewer.print_tree(tree, metrics) @@ -2193,13 +2212,24 @@ def profile(M, N, K, dtype, log=False): # exit() # print(dtype) - proton.start("matmul", hook="triton") + session = proton.start("matmul_fp16", hook="triton") + proton.deactivate() + # for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + # bench(K, dtype) + for dtype in [torch.float16]: + for k_seqlen in [256, 512, 1024, 2048, 4096, 8192, 16384]: + bench(k_seqlen, dtype, session) + proton.finalize() + + show_profile("fp16","matmul_fp16") + + session = proton.start("matmul_fp8", hook="triton") proton.deactivate() # for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): # bench(K, dtype) - for dtype in [torch.float8_e4m3fn, torch.float16]: + for dtype in 
[torch.float8_e4m3fn]: for k_seqlen in [256, 512, 1024, 2048, 4096, 8192, 16384]: - bench(k_seqlen, dtype) + bench(k_seqlen, dtype, session) proton.finalize() - show_profile("matmul") + show_profile("fp8","matmul_fp8") diff --git a/requirements.txt b/requirements.txt index 07ec82a..9f1f00a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ pandas # for triton.testing llnl-hatchet # for proton profiling einops nvidia-cutlass-dsl +pytest