From cdf8fa290243e9d2f15c3699a905c106103d9e48 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 23 Jan 2026 13:22:33 +0000 Subject: [PATCH] issue/979 optimize paged attention --- .../ops/paged_attention/cuda/kernel_v2.cuh | 2085 +++++++++++++++ src/infiniop/ops/paged_attention/info.h | 149 +- .../nvidia/paged_attention_hd128.cu | 1024 +++++++ .../nvidia/paged_attention_hd64.cu | 524 ++++ .../nvidia/paged_attention_nvidia.cu | 425 ++- .../cuda/kernel_v2.cuh | 2361 +++++++++++++++++ .../ops/paged_attention_prefill/info.h | 166 +- .../nvidia/paged_attention_prefill_nvidia.cu | 1720 +++++++++++- test/infiniop/paged_attention.py | 3 +- test/infiniop/paged_attention_prefill.py | 3 +- 10 files changed, 8209 insertions(+), 251 deletions(-) create mode 100644 src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh create mode 100644 src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu create mode 100644 src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu create mode 100644 src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh diff --git a/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh new file mode 100644 index 000000000..e63dd68e2 --- /dev/null +++ b/src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh @@ -0,0 +1,2085 @@ +#ifndef __PAGED_ATTENTION_KERNEL_V2_CUH__ +#define __PAGED_ATTENTION_KERNEL_V2_CUH__ + +namespace op::paged_attention::cuda { + +struct OnlineSoftmaxState { + float m = -INFINITY; + float l = 0.0f; + + __device__ __forceinline__ void update(float x, float &alpha, float &beta) { + const float m_new = fmaxf(m, x); + alpha = expf(m - m_new); + beta = expf(x - m_new); + l = l * alpha + beta; + m = m_new; + } +}; +__device__ __forceinline__ float warpReduceSum(float x) { + for (int offset = 16; offset > 0; offset >>= 1) { + x += __shfl_down_sync(0xffffffff, x, offset); + } + return x; +} + +__device__ __forceinline__ float warpReduceMax(float x) { + for (int offset = 16; offset > 0; offset >>= 1) { + x = fmaxf(x, __shfl_down_sync(0xffffffff, x, offset)); + } + return x; +} + +__device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) { + return static_cast(__cvta_generic_to_shared(ptr)); +} + +__device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + const unsigned int dst = cvtaToShared(dst_shared); + asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(src_global)); +#else + auto *dst = reinterpret_cast(dst_shared); + const auto *src = reinterpret_cast(src_global); + *dst = *src; +#endif +} + +__device__ __forceinline__ void cpAsyncCommit() { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +template +__device__ __forceinline__ void cpAsyncWaitGroup() { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +// cp.async.wait_group requires a compile-time immediate, so for small fixed +// stage counts we provide a tiny runtime switch. +__device__ __forceinline__ void cpAsyncWaitGroupRt(int n) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + if (n <= 0) { + cpAsyncWaitGroup<0>(); + } else if (n == 1) { + cpAsyncWaitGroup<1>(); + } else { + // Clamp to 2 because v0.4 CTA kernel uses STAGES=3. 
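+        // cp.async.wait_group N blocks until at most N of the committed
+        // groups are still in flight, so waiting with N=2 guarantees the
+        // oldest of three outstanding prefetch groups has landed.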
+ cpAsyncWaitGroup<2>(); + } +#else + (void)n; +#endif +} + +__device__ __forceinline__ void cpAsyncWaitAll() { + cpAsyncWaitGroup<0>(); +} + +template +__device__ void flashAttentionDecodeWarpKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int lane = threadIdx.x; + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0) { + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + + // q/out are [num_seqs, num_heads, head_size] + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE; + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __half22float2(q2[j]); + } + } + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __bfloat1622float2(q2[j]); + } + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + const int pbs = static_cast(page_block_size); + + // Iterate by blocks to avoid per-token division/mod and redundant block_table loads. + // Note: Per-token cp.async prefetching is generally too fine-grained for decode and can regress. + // We keep the warp kernel simple and reserve cp.async pipelining for CTA tile kernels. 
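+    // Example: with page_block_size = 16, token t = 37 lives in logical block
+    // 2 (37 / 16) at offset 5 (37 - 2 * 16); walking block by block below means
+    // that division never has to happen per token.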
+ int t_base = 0; + for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) { + int physical_block = 0; + if (lane == 0) { + physical_block = static_cast(block_table[logical_block]); + } + physical_block = __shfl_sync(0xffffffff, physical_block, 0); + + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int token_end = min(pbs, seq_len - t_base); + for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) { + const int t = t_base + token_in_block; + const Tdata *k_ptr = k_base + token_in_block * k_row_stride; + const Tdata *v_ptr = v_base + token_in_block * v_row_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + } + + qk = warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float o = acc[i] * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + 
out_ptr[dim] = static_cast(o); + } + } +} + +// Split-KV decode (FA2-style): each split scans a shard of KV and writes partial (m, l, acc) +// to workspace, then a combine kernel merges splits into final out. +template +__device__ void flashAttentionDecodeSplitKvWarpKernel( + float *partial_acc, // [num_splits, num_seqs, num_heads, head_size] + float *partial_m, // [num_splits, num_seqs, num_heads] + float *partial_l, // [num_splits, num_seqs, num_heads] + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int split_idx = static_cast(blockIdx.z); + const int lane = threadIdx.x; + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0 || num_splits <= 0) { + return; + } + + // Split the [0, seq_len) range into num_splits contiguous shards. + const int shard = (seq_len + num_splits - 1) / num_splits; + const int start = split_idx * shard; + const int end = min(seq_len, start + shard); + if (start >= end) { + // Empty shard => write neutral element. + const int n = gridDim.y * gridDim.x; + const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx); + if (lane == 0) { + partial_m[idx] = -INFINITY; + partial_l[idx] = 0.0f; + } +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = 0.0f; + } + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __half22float2(q2[j]); + } + } + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __bfloat1622float2(q2[j]); + } + } +#endif + + float m = -INFINITY; + float l = 0.0f; + const int pbs = static_cast(page_block_size); + + // Scan only [start, end). 
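+    // Example: seq_len = 1000 with num_splits = 4 gives shard = 250, so
+    // split_idx 2 scans tokens [500, 750) and keeps its own (m, l, acc).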
+ int t = start; + int logical_block = t / pbs; + int token_in_block = t - logical_block * pbs; + for (; t < end; ++logical_block) { + int physical_block = 0; + if (lane == 0) { + physical_block = static_cast(block_table[logical_block]); + } + physical_block = __shfl_sync(0xffffffff, physical_block, 0); + + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int token_end = min(pbs, end - logical_block * pbs); + for (; token_in_block < token_end && t < end; ++token_in_block, ++t) { + const Tdata *k_ptr = k_base + token_in_block * k_row_stride; + const Tdata *v_ptr = v_base + token_in_block * v_row_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + } + + qk = warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + token_in_block = 0; + } + + const int n = gridDim.y * gridDim.x; + const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx); + if (lane == 0) { + partial_m[idx] = m; + partial_l[idx] = l; + } +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = acc[i]; + } +} + +template +__device__ void flashAttentionDecodeSplitKvCombineWarpKernel( + 
Tdata *out_, + const float *partial_acc, // [num_splits, num_seqs, num_heads, head_size] + const float *partial_m, // [num_splits, num_seqs, num_heads] + const float *partial_l, // [num_splits, num_seqs, num_heads] + int num_splits, + ptrdiff_t o_stride) { + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int lane = threadIdx.x; + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int n = gridDim.y * gridDim.x; + const int base = (seq_idx * gridDim.x + head_idx); + + float m = -INFINITY; + if (lane == 0) { + for (int s = 0; s < num_splits; ++s) { + m = fmaxf(m, partial_m[s * n + base]); + } + } + m = __shfl_sync(0xffffffff, m, 0); + + float l = 0.0f; + if (lane == 0) { + for (int s = 0; s < num_splits; ++s) { + const float ms = partial_m[s * n + base]; + const float ls = partial_l[s * n + base]; + if (ls > 0.0f) { + l += ls * exp2f(ms - m); + } + } + } + l = __shfl_sync(0xffffffff, l, 0); + const float inv_l = 1.0f / (l + 1e-6f); + + // Combine acc for each dim. + Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + float acc = 0.0f; + for (int s = 0; s < num_splits; ++s) { + const float ms = partial_m[s * n + base]; + const float w = exp2f(ms - m); + acc += partial_acc[(s * n + base) * HEAD_SIZE + dim] * w; + } + const float o = acc * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + out_ptr[dim] = static_cast(o); + } + } +} + +// Split-KV decode with a CTA tile kernel (FA2-style): each CTA scans a shard of KV, +// writes partial (m, l, acc) to workspace, then a combine kernel merges splits. 
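+// The combine kernel above merges the per-split partials with the log-sum-exp
+// identity (in base 2, matching the exp2f math used here):
+//   m = max_s m_s,  l = sum_s l_s * 2^(m_s - m),  out = (sum_s acc_s * 2^(m_s - m)) / l.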
+template +__device__ void flashAttentionDecodeSplitKvCtaKernel( + float *partial_acc, // [num_splits, num_seqs, num_heads, head_size] + float *partial_m, // [num_splits, num_seqs, num_heads] + float *partial_l, // [num_splits, num_seqs, num_heads] + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + + constexpr int kWarpSize = 32; + static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32."); + static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small."); + constexpr int NUM_WARPS = CTA_THREADS / kWarpSize; + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS."); + constexpr int kPack = HEAD_SIZE / CTA_THREADS; // 2 (64@32t, 128@64t) or 4 (128@32t) + static_assert(kPack == 2 || kPack == 4, "v0.4 split-kv CTA kernel supports kPack=2/4 only."); + constexpr int kPackedDims = CTA_THREADS; + constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize; + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int split_idx = static_cast(blockIdx.z); + const int tid = threadIdx.x; + const int lane = tid % kWarpSize; + const int warp_id = tid / kWarpSize; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0 || num_splits <= 0) { + return; + } + + // Split the [0, seq_len) range into num_splits contiguous shards. + const int shard = (seq_len + num_splits - 1) / num_splits; + const int start = split_idx * shard; + const int end = min(seq_len, start + shard); + + const int n = gridDim.y * gridDim.x; + const int idx = (split_idx * n + seq_idx * gridDim.x + head_idx); + + if (start >= end) { + // Empty shard => write neutral element. 
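+        // (m = -inf, l = 0, acc = 0) is the identity of the split-combine
+        // merge: this shard contributes weight 2^(-inf - m) = 0 when merged.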
+ if (tid == 0) { + partial_m[idx] = -INFINITY; + partial_l[idx] = 0.0f; + } + const int dim = tid * kPack; + if constexpr (kPack == 2) { + partial_acc[idx * HEAD_SIZE + dim + 0] = 0.0f; + partial_acc[idx * HEAD_SIZE + dim + 1] = 0.0f; + } else { + partial_acc[idx * HEAD_SIZE + dim + 0] = 0.0f; + partial_acc[idx * HEAD_SIZE + dim + 1] = 0.0f; + partial_acc[idx * HEAD_SIZE + dim + 2] = 0.0f; + partial_acc[idx * HEAD_SIZE + dim + 3] = 0.0f; + } + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + + const int dim = tid * kPack; + float q0 = 0.0f, q1 = 0.0f, q2 = 0.0f, q3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 qh2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __half22float2(qh2); + q0 = qf.x; + q1 = qf.y; + } else { + const half2 qh2_0 = *reinterpret_cast(q_ptr + dim + 0); + const half2 qh2_1 = *reinterpret_cast(q_ptr + dim + 2); + const float2 qf0 = __half22float2(qh2_0); + const float2 qf1 = __half22float2(qh2_1); + q0 = qf0.x; + q1 = qf0.y; + q2 = qf1.x; + q3 = qf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 qb2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __bfloat1622float2(qb2); + q0 = qf.x; + q1 = qf.y; + } else { + const __nv_bfloat162 qb2_0 = *reinterpret_cast(q_ptr + dim + 0); + const __nv_bfloat162 qb2_1 = *reinterpret_cast(q_ptr + dim + 2); + const float2 qf0 = __bfloat1622float2(qb2_0); + const float2 qf1 = __bfloat1622float2(qb2_1); + q0 = qf0.x; + q1 = qf0.y; + q2 = qf1.x; + q3 = qf1.y; + } + } else +#endif + { + q0 = static_cast(q_ptr[dim + 0]); + q1 = static_cast(q_ptr[dim + 1]); + if constexpr (kPack == 4) { + q2 = static_cast(q_ptr[dim + 2]); + q3 = static_cast(q_ptr[dim + 3]); + } + } + + float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f; + + float m = -INFINITY; + float l = 0.0f; + + __shared__ float warp_sums[TOKENS_PER_TILE][kComputeWarps]; + __shared__ float alpha_shared; + __shared__ float weights_shared[TOKENS_PER_TILE]; + + const int pbs = static_cast(page_block_size); + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + static_assert(sizeof(Tdata) == 2, "CTA split-kv kernel assumes fp16/bf16."); + constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes. + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + constexpr int STAGES = 3; + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + + const int first_block = start / pbs; + const int last_block = (end - 1) / pbs; + + for (int logical_block = first_block; logical_block <= last_block; ++logical_block) { + const int physical_block = static_cast(block_table[logical_block]); + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int t_base = logical_block * pbs; + const int token_begin = (logical_block == first_block) ? 
(start - t_base) : 0; + const int token_end = (logical_block == last_block) ? (end - t_base) : pbs; + const int token_count = token_end - token_begin; + if (token_count <= 0) { + continue; + } + + const int num_tiles = (token_count + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = token_begin + ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = token_begin + tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + + float partial[TOKENS_PER_TILE]; +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + if (j < tile_n) { + float k0 = 0.0f, k1 = 0.0f, k2 = 0.0f, k3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 kh2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __half22float2(kh2); + k0 = kf.x; + k1 = kf.y; + } else { + const half2 kh2_0 = *reinterpret_cast(&sh_k[buf][j][dim + 0]); + const half2 kh2_1 = *reinterpret_cast(&sh_k[buf][j][dim + 2]); + const float2 kf0 = __half22float2(kh2_0); + const float2 kf1 = __half22float2(kh2_1); + k0 = kf0.x; + k1 = kf0.y; + k2 = kf1.x; + k3 = kf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 kb2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __bfloat1622float2(kb2); + k0 = kf.x; + k1 = kf.y; + } else { + const __nv_bfloat162 kb2_0 = *reinterpret_cast(&sh_k[buf][j][dim + 0]); + const __nv_bfloat162 kb2_1 = *reinterpret_cast(&sh_k[buf][j][dim + 2]); + const float2 kf0 = __bfloat1622float2(kb2_0); + const float2 kf1 = __bfloat1622float2(kb2_1); + k0 = kf0.x; + k1 = kf0.y; + k2 = kf1.x; + k3 = kf1.y; + } + } else +#endif + { + k0 = static_cast(sh_k[buf][j][dim + 0]); + k1 = static_cast(sh_k[buf][j][dim + 1]); + if constexpr (kPack == 4) { + k2 = static_cast(sh_k[buf][j][dim + 2]); + k3 = static_cast(sh_k[buf][j][dim + 3]); + } + } + if constexpr (kPack == 2) { + partial[j] = fmaf(q0, k0, q1 * k1); + } else { + partial[j] = fmaf(q0, k0, fmaf(q1, k1, fmaf(q2, k2, q3 * k3))); + } + } else { + partial[j] = 0.0f; + } + } + +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + const float sum = warpReduceSum(partial[j]); + if (lane == 0 && warp_id < kComputeWarps) { + 
warp_sums[j][warp_id] = sum; + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + if (warp_id == 0) { + float score = -INFINITY; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + float qk = 0.0f; +#pragma unroll + for (int w = 0; w < kComputeWarps; ++w) { + qk += warp_sums[lane][w]; + } + const int t = t_base + token_in_block + lane; + score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + } + + float tile_max = warpReduceMax(score); + tile_max = __shfl_sync(0xffffffff, tile_max, 0); + + float m_new = 0.0f; + if (lane == 0) { + m_new = fmaxf(m, tile_max); + } + m_new = __shfl_sync(0xffffffff, m_new, 0); + + float w = 0.0f; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + w = exp2f(score - m_new); + } + if (lane < TOKENS_PER_TILE) { + weights_shared[lane] = (lane < tile_n) ? w : 0.0f; + } + + const float tile_sum = warpReduceSum(w); + if (lane == 0) { + const float alpha = exp2f(m - m_new); + alpha_shared = alpha; + l = l * alpha + tile_sum; + m = m_new; + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + const float alpha = alpha_shared; + float sum_wv0 = 0.0f, sum_wv1 = 0.0f, sum_wv2 = 0.0f, sum_wv3 = 0.0f; +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + const float w = weights_shared[j]; + float v0 = 0.0f, v1 = 0.0f, v2 = 0.0f, v3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 vh2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __half22float2(vh2); + v0 = vf.x; + v1 = vf.y; + } else { + const half2 vh2_0 = *reinterpret_cast(&sh_v[buf][j][dim + 0]); + const half2 vh2_1 = *reinterpret_cast(&sh_v[buf][j][dim + 2]); + const float2 vf0 = __half22float2(vh2_0); + const float2 vf1 = __half22float2(vh2_1); + v0 = vf0.x; + v1 = vf0.y; + v2 = vf1.x; + v3 = vf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 vb2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __bfloat1622float2(vb2); + v0 = vf.x; + v1 = vf.y; + } else { + const __nv_bfloat162 vb2_0 = *reinterpret_cast(&sh_v[buf][j][dim + 0]); + const __nv_bfloat162 vb2_1 = *reinterpret_cast(&sh_v[buf][j][dim + 2]); + const float2 vf0 = __bfloat1622float2(vb2_0); + const float2 vf1 = __bfloat1622float2(vb2_1); + v0 = vf0.x; + v1 = vf0.y; + v2 = vf1.x; + v3 = vf1.y; + } + } else +#endif + { + v0 = static_cast(sh_v[buf][j][dim + 0]); + v1 = static_cast(sh_v[buf][j][dim + 1]); + if constexpr (kPack == 4) { + v2 = static_cast(sh_v[buf][j][dim + 2]); + v3 = static_cast(sh_v[buf][j][dim + 3]); + } + } + sum_wv0 = fmaf(w, v0, sum_wv0); + sum_wv1 = fmaf(w, v1, sum_wv1); + if constexpr (kPack == 4) { + sum_wv2 = fmaf(w, v2, sum_wv2); + sum_wv3 = fmaf(w, v3, sum_wv3); + } + } + acc0 = acc0 * alpha + sum_wv0; + acc1 = acc1 * alpha + sum_wv1; + if constexpr (kPack == 4) { + acc2 = acc2 * alpha + sum_wv2; + acc3 = acc3 * alpha + sum_wv3; + } + + const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = token_begin + prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + 
off; + const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + int desired_pending2 = pending_groups - 1; + if (desired_pending2 < 0) { + desired_pending2 = 0; + } + if (desired_pending2 > (STAGES - 1)) { + desired_pending2 = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending2); + pending_groups = desired_pending2; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + } + + cpAsyncWaitAll(); + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + + if (tid == 0) { + partial_m[idx] = m; + partial_l[idx] = l; + } + if constexpr (kPack == 2) { + partial_acc[idx * HEAD_SIZE + dim + 0] = acc0; + partial_acc[idx * HEAD_SIZE + dim + 1] = acc1; + } else { + partial_acc[idx * HEAD_SIZE + dim + 0] = acc0; + partial_acc[idx * HEAD_SIZE + dim + 1] = acc1; + partial_acc[idx * HEAD_SIZE + dim + 2] = acc2; + partial_acc[idx * HEAD_SIZE + dim + 3] = acc3; + } +} + +template +__device__ void flashAttentionDecodeCtaPipelinedKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int NUM_WARPS = HEAD_SIZE / kWarpSize; + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int tid = threadIdx.x; + const int lane = tid % kWarpSize; + const int warp_id = tid / kWarpSize; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0) { + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE; + + const float q_val = static_cast(q_ptr[tid]); + float acc = 0.0f; + + float m = -INFINITY; + float l = 0.0f; + + __shared__ Tdata sh_k[2][HEAD_SIZE]; + __shared__ Tdata sh_v[2][HEAD_SIZE]; + __shared__ float warp_sums[NUM_WARPS]; + __shared__ float alpha_s; + __shared__ float beta_s; + __shared__ int physical_block_s; + constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes. + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + + const int pbs = static_cast(page_block_size); + + // Prefetch the very first token. 
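+    // Double buffering: while the CTA computes on sh_k/sh_v[buf] for token t,
+    // cp.async fills sh_k/sh_v[buf ^ 1] with token t + 1; the buffers swap
+    // once cpAsyncWaitAll() confirms the copy has landed.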
+ int buf = 0; + int t_base = 0; + int token_in_block = 0; + int logical_block = 0; + { + if (tid == 0) { + physical_block_s = static_cast(block_table[0]); + } + __syncthreads(); + const Tdata *k_base = k_cache_ + physical_block_s * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block_s * v_batch_stride + kv_head_idx * v_head_stride; + if (tid < CHUNKS) { + const int off = tid * CHUNK_ELEMS; + cpAsyncCaSharedGlobal16(&sh_k[buf][off], (k_base + 0 * k_row_stride) + off); + cpAsyncCaSharedGlobal16(&sh_v[buf][off], (v_base + 0 * v_row_stride) + off); + } + cpAsyncCommit(); + cpAsyncWaitAll(); + __syncthreads(); + } + + for (int t = 0; t < seq_len; ++t) { + // Compute current token location within paged KV. + const int next_t = t + 1; + const bool has_next = next_t < seq_len; + + if (has_next) { + const int next_block = next_t / pbs; + const int next_in_block = next_t - next_block * pbs; + if (next_block != logical_block) { + logical_block = next_block; + if (tid == 0) { + physical_block_s = static_cast(block_table[logical_block]); + } + __syncthreads(); + } + + const Tdata *k_base = k_cache_ + physical_block_s * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block_s * v_batch_stride + kv_head_idx * v_head_stride; + const Tdata *k_src = k_base + next_in_block * k_row_stride; + const Tdata *v_src = v_base + next_in_block * v_row_stride; + if (tid < CHUNKS) { + const int off = tid * CHUNK_ELEMS; + cpAsyncCaSharedGlobal16(&sh_k[buf ^ 1][off], k_src + off); + cpAsyncCaSharedGlobal16(&sh_v[buf ^ 1][off], v_src + off); + } + cpAsyncCommit(); + } + + // Dot: each thread handles one dim, reduce across head dim. + const float k_val = static_cast(sh_k[buf][tid]); + float partial = q_val * k_val; + float warp_sum = warpReduceSum(partial); + if (lane == 0) { + warp_sums[warp_id] = warp_sum; + } + __syncthreads(); + + float qk = 0.0f; + if (warp_id == 0) { + float v = (lane < NUM_WARPS) ? 
warp_sums[lane] : 0.0f; + v = warpReduceSum(v); + if (lane == 0) { + qk = v; + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + const float alpha = exp2f(m - m_new); + const float beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + alpha_s = alpha; + beta_s = beta; + } + } + __syncthreads(); + + const float alpha = alpha_s; + const float beta = beta_s; + const float v_val = static_cast(sh_v[buf][tid]); + acc = acc * alpha + beta * v_val; + + if (has_next) { + cpAsyncWaitAll(); + __syncthreads(); + buf ^= 1; + } + } + + __shared__ float inv_l_s; + if (tid == 0) { + inv_l_s = 1.0f / (l + 1e-6f); + } + __syncthreads(); + out_ptr[tid] = static_cast(acc * inv_l_s); +} + +template +__device__ void flashAttentionDecodeCtaKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + + constexpr int kWarpSize = 32; + static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32."); + static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small."); + constexpr int NUM_WARPS = CTA_THREADS / kWarpSize; + + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int tid = threadIdx.x; + const int lane = tid % kWarpSize; + const int warp_id = tid / kWarpSize; + + // Each thread owns a small packed vector of head dims. This lets us shrink the + // CTA to 1-2 warps and reduce block-wide synchronization overhead. 
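+    // Example: HEAD_SIZE = 128 with CTA_THREADS = 32 gives kPack = 4 (thread 7
+    // owns dims 28..31); with CTA_THREADS = 64, kPack = 2 (thread 7 owns dims
+    // 14..15).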
+ static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS."); + constexpr int kPack = HEAD_SIZE / CTA_THREADS; // 2 (64@32t, 128@64t) or 4 (128@32t) + static_assert(kPack == 2 || kPack == 4, "v0.4 CTA tile kernel supports kPack=2/4 only."); + constexpr int kPackedDims = CTA_THREADS; + constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize; + const int dim = tid * kPack; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0) { + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + + // q/out are [num_seqs, num_heads, head_size] + const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE; + + float q0 = 0.0f; + float q1 = 0.0f; + float q2 = 0.0f; + float q3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 qh2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __half22float2(qh2); + q0 = qf.x; + q1 = qf.y; + } else { + const half2 qh2_0 = *reinterpret_cast(q_ptr + dim + 0); + const half2 qh2_1 = *reinterpret_cast(q_ptr + dim + 2); + const float2 qf0 = __half22float2(qh2_0); + const float2 qf1 = __half22float2(qh2_1); + q0 = qf0.x; + q1 = qf0.y; + q2 = qf1.x; + q3 = qf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 qb2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __bfloat1622float2(qb2); + q0 = qf.x; + q1 = qf.y; + } else { + const __nv_bfloat162 qb2_0 = *reinterpret_cast(q_ptr + dim + 0); + const __nv_bfloat162 qb2_1 = *reinterpret_cast(q_ptr + dim + 2); + const float2 qf0 = __bfloat1622float2(qb2_0); + const float2 qf1 = __bfloat1622float2(qb2_1); + q0 = qf0.x; + q1 = qf0.y; + q2 = qf1.x; + q3 = qf1.y; + } + } else +#endif + { + q0 = static_cast(q_ptr[dim + 0]); + q1 = static_cast(q_ptr[dim + 1]); + if constexpr (kPack == 4) { + q2 = static_cast(q_ptr[dim + 2]); + q3 = static_cast(q_ptr[dim + 3]); + } + } + + float acc0 = 0.0f; + float acc1 = 0.0f; + float acc2 = 0.0f; + float acc3 = 0.0f; + + float m = -INFINITY; + float l = 0.0f; + + // Only the compute warps contribute QK partial sums. Keeping this array + // compact reduces shared-memory traffic and bank pressure. + __shared__ float warp_sums[TOKENS_PER_TILE][kComputeWarps]; + __shared__ float alpha_shared; + __shared__ float weights_shared[TOKENS_PER_TILE]; + + const int pbs = static_cast(page_block_size); + + static_assert(sizeof(Tdata) == 2, "CTA tile kernel assumes 16B chunks map to 8 elements for fp16/bf16."); + constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes. + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + // Multi-stage cp.async pipeline. Using >= 3 stages allows us to keep + // multiple groups in-flight and overlap global->shared copies with compute. + constexpr int STAGES = 3; + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 
0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int t_base = 0; + for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) { + const int physical_block = static_cast(block_table[logical_block]); + + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int token_end = min(pbs, seq_len - t_base); + const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + if (num_tiles <= 0) { + continue; + } + + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + // Ensure tile 0 is ready. We want to keep up to (STAGES - 1) groups + // in flight for overlap, but still make forward progress in the tail + // when we stop issuing new prefetch groups. + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + + float partial[TOKENS_PER_TILE]; +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + if (j < tile_n) { + float k0 = 0.0f; + float k1 = 0.0f; + float k2 = 0.0f; + float k3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 kh2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __half22float2(kh2); + k0 = kf.x; + k1 = kf.y; + } else { + const half2 kh2_0 = *reinterpret_cast(&sh_k[buf][j][dim + 0]); + const half2 kh2_1 = *reinterpret_cast(&sh_k[buf][j][dim + 2]); + const float2 kf0 = __half22float2(kh2_0); + const float2 kf1 = __half22float2(kh2_1); + k0 = kf0.x; + k1 = kf0.y; + k2 = kf1.x; + k3 = kf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 kb2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __bfloat1622float2(kb2); + k0 = kf.x; + k1 = kf.y; + } else { + const __nv_bfloat162 kb2_0 = *reinterpret_cast(&sh_k[buf][j][dim + 0]); + const __nv_bfloat162 kb2_1 = *reinterpret_cast(&sh_k[buf][j][dim + 2]); + const float2 kf0 = __bfloat1622float2(kb2_0); + const float2 kf1 = __bfloat1622float2(kb2_1); + k0 = kf0.x; + k1 = kf0.y; + k2 = kf1.x; + k3 = kf1.y; + } + } else 
+#endif + { + k0 = static_cast(sh_k[buf][j][dim + 0]); + k1 = static_cast(sh_k[buf][j][dim + 1]); + if constexpr (kPack == 4) { + k2 = static_cast(sh_k[buf][j][dim + 2]); + k3 = static_cast(sh_k[buf][j][dim + 3]); + } + } + if constexpr (kPack == 2) { + partial[j] = fmaf(q0, k0, q1 * k1); + } else { + partial[j] = fmaf(q0, k0, fmaf(q1, k1, fmaf(q2, k2, q3 * k3))); + } + } else { + partial[j] = 0.0f; + } + } + +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + float sum = warpReduceSum(partial[j]); + // Only compute warps contribute to qk; load-only warps would + // otherwise write zeros and increase reduction overhead. + if (lane == 0 && warp_id < kComputeWarps) { + warp_sums[j][warp_id] = sum; + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + if (warp_id == 0) { + // Distribute token-wise score computation across lanes to avoid + // serial loops in lane0. TOKENS_PER_TILE <= 16 by construction. + float score = -INFINITY; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + float qk = 0.0f; +#pragma unroll + for (int w = 0; w < kComputeWarps; ++w) { + qk += warp_sums[lane][w]; + } + const int t = t_base + token_in_block + lane; + score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(t - (seq_len - 1))) * kLog2e; + } + } + + float tile_max = warpReduceMax(score); + tile_max = __shfl_sync(0xffffffff, tile_max, 0); + + float m_new = 0.0f; + if (lane == 0) { + m_new = fmaxf(m, tile_max); + } + m_new = __shfl_sync(0xffffffff, m_new, 0); + + float w = 0.0f; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + w = exp2f(score - m_new); + } + + if (lane < TOKENS_PER_TILE) { + weights_shared[lane] = (lane < tile_n) ? w : 0.0f; + } + + float tile_sum = warpReduceSum(w); + if (lane == 0) { + const float alpha = exp2f(m - m_new); + alpha_shared = alpha; + l = l * alpha + tile_sum; + m = m_new; + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + const float alpha = alpha_shared; + float sum_wv0 = 0.0f; + float sum_wv1 = 0.0f; + float sum_wv2 = 0.0f; + float sum_wv3 = 0.0f; +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + const float w = weights_shared[j]; + float v0 = 0.0f; + float v1 = 0.0f; + float v2 = 0.0f; + float v3 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const half2 vh2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __half22float2(vh2); + v0 = vf.x; + v1 = vf.y; + } else { + const half2 vh2_0 = *reinterpret_cast(&sh_v[buf][j][dim + 0]); + const half2 vh2_1 = *reinterpret_cast(&sh_v[buf][j][dim + 2]); + const float2 vf0 = __half22float2(vh2_0); + const float2 vf1 = __half22float2(vh2_1); + v0 = vf0.x; + v1 = vf0.y; + v2 = vf1.x; + v3 = vf1.y; + } + } else if constexpr (std::is_same_v) { + if constexpr (kPack == 2) { + const __nv_bfloat162 vb2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __bfloat1622float2(vb2); + v0 = vf.x; + v1 = vf.y; + } else { + const __nv_bfloat162 vb2_0 = *reinterpret_cast(&sh_v[buf][j][dim + 0]); + const __nv_bfloat162 vb2_1 = *reinterpret_cast(&sh_v[buf][j][dim + 2]); + const float2 vf0 = __bfloat1622float2(vb2_0); + const float2 vf1 = __bfloat1622float2(vb2_1); + v0 = vf0.x; + v1 = vf0.y; + v2 = vf1.x; + v3 = vf1.y; + } + } else +#endif + { + v0 = static_cast(sh_v[buf][j][dim + 0]); + v1 = static_cast(sh_v[buf][j][dim + 1]); + if constexpr (kPack == 4) { + v2 = static_cast(sh_v[buf][j][dim + 2]); + v3 = 
static_cast(sh_v[buf][j][dim + 3]); + } + } + sum_wv0 = fmaf(w, v0, sum_wv0); + sum_wv1 = fmaf(w, v1, sum_wv1); + if constexpr (kPack == 4) { + sum_wv2 = fmaf(w, v2, sum_wv2); + sum_wv3 = fmaf(w, v3, sum_wv3); + } + } + acc0 = acc0 * alpha + sum_wv0; + acc1 = acc1 * alpha + sum_wv1; + if constexpr (kPack == 4) { + acc2 = acc2 * alpha + sum_wv2; + acc3 = acc3 * alpha + sum_wv3; + } + + // Prefetch the tile that will reuse this buffer (STAGES steps ahead). + const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + // Before consuming the next tile, ensure at least one group + // completes. In steady state we keep (STAGES - 1) in flight; in + // the tail (no more prefetches) we gradually drain. + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + } + + // Drain any in-flight async copies before moving to the next paged block. + cpAsyncWaitAll(); + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + + __shared__ float inv_l_shared; + if (tid == 0) { + inv_l_shared = 1.0f / (l + 1e-6f); + } + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + { + const float s = inv_l_shared; + const float o0 = acc0 * s; + const float o1 = acc1 * s; + const float o2 = acc2 * s; + const float o3 = acc3 * s; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + out_ptr[dim + 0] = __float2half_rn(o0); + out_ptr[dim + 1] = __float2half_rn(o1); + if constexpr (kPack == 4) { + out_ptr[dim + 2] = __float2half_rn(o2); + out_ptr[dim + 3] = __float2half_rn(o3); + } + } else if constexpr (std::is_same_v) { + out_ptr[dim + 0] = __float2bfloat16_rn(o0); + out_ptr[dim + 1] = __float2bfloat16_rn(o1); + if constexpr (kPack == 4) { + out_ptr[dim + 2] = __float2bfloat16_rn(o2); + out_ptr[dim + 3] = __float2bfloat16_rn(o3); + } + } else +#endif + { + out_ptr[dim + 0] = static_cast(o0); + out_ptr[dim + 1] = static_cast(o1); + if constexpr (kPack == 4) { + out_ptr[dim + 2] = static_cast(o2); + out_ptr[dim + 3] = static_cast(o3); + } + } + } +} + +// GQA/MQA fused decode kernel: one CTA computes outputs for NGROUPS query heads that +// share the same KV head. This reduces redundant K/V reads when num_heads > num_kv_heads. +// +// v0.4: implemented for head_dim=128 and NGROUPS=4 (common case: 32 Q heads / 8 KV heads). 
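+// Example: with 32 query heads over 8 KV heads, the per-head kernels read each
+// K/V page once per query head (4 times per KV head); the fused kernel reads it
+// once per KV head, cutting K/V global traffic by roughly 4x.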
+template +__device__ void flashAttentionDecodeCtaGqaKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const Tindex *cache_lens_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 128, "v0.4 GQA fused CTA kernel is implemented for head_size=128 only."); + static_assert(NGROUPS == 4, "v0.4 GQA fused CTA kernel is implemented for NGROUPS=4 only."); + static_assert(CTA_THREADS % kWarpSize == 0, "CTA_THREADS must be a multiple of 32."); + static_assert(TOKENS_PER_TILE > 0 && TOKENS_PER_TILE <= 16, "TOKENS_PER_TILE should stay small."); + constexpr int NUM_WARPS = CTA_THREADS / kWarpSize; + + // Pack dims per thread. For head_dim=128 and CTA_THREADS=64, kPack=2. + static_assert(HEAD_SIZE % CTA_THREADS == 0, "HEAD_SIZE must be divisible by CTA_THREADS."); + constexpr int kPack = HEAD_SIZE / CTA_THREADS; + static_assert(kPack == 2, "v0.4 GQA fused CTA kernel expects kPack=2."); + constexpr int kPackedDims = CTA_THREADS; + constexpr int kComputeWarps = (kPackedDims + kWarpSize - 1) / kWarpSize; + + const int seq_idx = blockIdx.y; + const int kv_head_idx = blockIdx.x; + const int tid = threadIdx.x; + const int lane = tid % kWarpSize; + const int warp_id = tid / kWarpSize; + const int dim = tid * kPack; + + const int seq_len = static_cast(cache_lens_[seq_idx]); + if (seq_len <= 0) { + return; + } + + // v0.4 limitation: alibi slopes are per query head; support can be added later. + if (alibi_slopes_ != nullptr) { + return; + } + + const Tindex *block_table = block_tables_ + seq_idx * static_cast(max_num_blocks_per_seq); + + // q/out are [num_seqs, num_heads, head_size]. 
For a KV head, we handle NGROUPS query heads: + // q_head = kv_head * NGROUPS + g + float q0[NGROUPS]; + float q1[NGROUPS]; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const int q_head = kv_head_idx * NGROUPS + g; + const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE; + const half2 qh2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __half22float2(qh2); + q0[g] = qf.x; + q1[g] = qf.y; + } + } else if constexpr (std::is_same_v) { +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const int q_head = kv_head_idx * NGROUPS + g; + const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE; + const __nv_bfloat162 qb2 = *reinterpret_cast(q_ptr + dim); + const float2 qf = __bfloat1622float2(qb2); + q0[g] = qf.x; + q1[g] = qf.y; + } + } else +#endif + { +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const int q_head = kv_head_idx * NGROUPS + g; + const Tdata *q_ptr = q_ + seq_idx * q_stride + q_head * HEAD_SIZE; + q0[g] = static_cast(q_ptr[dim + 0]); + q1[g] = static_cast(q_ptr[dim + 1]); + } + } + + float acc0[NGROUPS]; + float acc1[NGROUPS]; + float m[NGROUPS]; + float l[NGROUPS]; +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + acc0[g] = 0.0f; + acc1[g] = 0.0f; + m[g] = -INFINITY; + l[g] = 0.0f; + } + + __shared__ float warp_sums[NGROUPS][TOKENS_PER_TILE][kComputeWarps]; + __shared__ float alpha_shared[NGROUPS]; + __shared__ float weights_shared[NGROUPS][TOKENS_PER_TILE]; + + const int pbs = static_cast(page_block_size); + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + static_assert(sizeof(Tdata) == 2, "CTA GQA kernel assumes fp16/bf16."); + constexpr int CHUNK_ELEMS = 8; // 8 * 2 bytes = 16 bytes. 
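+    // Each cp.async transfer below moves one 16-byte chunk (8 fp16/bf16
+    // elements), so a 128-wide K or V row is covered by CHUNKS = 16 transfers.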
+ constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + constexpr int STAGES = 3; + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + + int t_base = 0; + for (int logical_block = 0; t_base < seq_len; ++logical_block, t_base += pbs) { + const int physical_block = static_cast(block_table[logical_block]); + + const Tdata *k_base = k_cache_ + physical_block * k_batch_stride + kv_head_idx * k_head_stride; + const Tdata *v_base = v_cache_ + physical_block * v_batch_stride + kv_head_idx * v_head_stride; + + const int token_end = min(pbs, seq_len - t_base); + const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + if (num_tiles <= 0) { + continue; + } + + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + (token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_in_block + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + + // Compute QK partial sums for each group and each token in the tile. 
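+            // Each thread holds two head dims, so for token j and group g the full
+            // dot product is the sum over all CTA threads of (q0*k0 + q1*k1). The
+            // per-thread products below are first folded inside each warp with
+            // warpReduceSum(), then the per-warp partials in warp_sums[] are added
+            // by warp 0 before scaling by scale_log2 (scores stay in the log2 domain).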
+ float partial_qk[NGROUPS][TOKENS_PER_TILE]; +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + if (j < tile_n) { + float k0 = 0.0f; + float k1 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const half2 kh2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __half22float2(kh2); + k0 = kf.x; + k1 = kf.y; + } else if constexpr (std::is_same_v) { + const __nv_bfloat162 kb2 = *reinterpret_cast(&sh_k[buf][j][dim]); + const float2 kf = __bfloat1622float2(kb2); + k0 = kf.x; + k1 = kf.y; + } else +#endif + { + k0 = static_cast(sh_k[buf][j][dim + 0]); + k1 = static_cast(sh_k[buf][j][dim + 1]); + } + partial_qk[g][j] = fmaf(q0[g], k0, q1[g] * k1); + } else { + partial_qk[g][j] = 0.0f; + } + } + } + +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + const float sum = warpReduceSum(partial_qk[g][j]); + if (lane == 0 && warp_id < kComputeWarps) { + warp_sums[g][j][warp_id] = sum; + } + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + if (warp_id == 0) { +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + float score = -INFINITY; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + float qk = 0.0f; +#pragma unroll + for (int w = 0; w < kComputeWarps; ++w) { + qk += warp_sums[g][lane][w]; + } + score = qk * scale_log2; + } + + float tile_max = warpReduceMax(score); + tile_max = __shfl_sync(0xffffffff, tile_max, 0); + + float m_new = 0.0f; + if (lane == 0) { + m_new = fmaxf(m[g], tile_max); + } + m_new = __shfl_sync(0xffffffff, m_new, 0); + + float w = 0.0f; + if (lane < TOKENS_PER_TILE && lane < tile_n) { + w = exp2f(score - m_new); + } + if (lane < TOKENS_PER_TILE) { + weights_shared[g][lane] = (lane < tile_n) ? 
w : 0.0f; + } + + const float tile_sum = warpReduceSum(w); + if (lane == 0) { + const float alpha = exp2f(m[g] - m_new); + alpha_shared[g] = alpha; + l[g] = l[g] * alpha + tile_sum; + m[g] = m_new; + } + } + } + + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + + float alpha[NGROUPS]; + float sum_wv0[NGROUPS]; + float sum_wv1[NGROUPS]; +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + alpha[g] = alpha_shared[g]; + sum_wv0[g] = 0.0f; + sum_wv1[g] = 0.0f; + } + +#pragma unroll + for (int j = 0; j < TOKENS_PER_TILE; ++j) { + float v0 = 0.0f; + float v1 = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const half2 vh2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __half22float2(vh2); + v0 = vf.x; + v1 = vf.y; + } else if constexpr (std::is_same_v) { + const __nv_bfloat162 vb2 = *reinterpret_cast(&sh_v[buf][j][dim]); + const float2 vf = __bfloat1622float2(vb2); + v0 = vf.x; + v1 = vf.y; + } else +#endif + { + v0 = static_cast(sh_v[buf][j][dim + 0]); + v1 = static_cast(sh_v[buf][j][dim + 1]); + } + +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const float w = weights_shared[g][j]; + sum_wv0[g] = fmaf(w, v0, sum_wv0[g]); + sum_wv1[g] = fmaf(w, v1, sum_wv1[g]); + } + } + +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + acc0[g] = acc0[g] * alpha[g] + sum_wv0[g]; + acc1[g] = acc1[g] * alpha[g] + sum_wv1[g]; + } + + const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += CTA_THREADS) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + (token_prefetch + tok) * k_row_stride + off; + const Tdata *v_src = v_base + (token_prefetch + tok) * v_row_stride + off; + cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + int desired_pending2 = pending_groups - 1; + if (desired_pending2 < 0) { + desired_pending2 = 0; + } + if (desired_pending2 > (STAGES - 1)) { + desired_pending2 = (STAGES - 1); + } + cpAsyncWaitGroupRt(desired_pending2); + pending_groups = desired_pending2; + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + } + + cpAsyncWaitAll(); + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + } + + // Write outputs for each group. 
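+    // Sketch of the invariant maintained above (no extra computation here):
+    //   m_new      = max(m_old, max_j score_j)
+    //   l_new      = l_old * exp2(m_old - m_new) + sum_j exp2(score_j - m_new)
+    //   acc_new[d] = acc_old[d] * exp2(m_old - m_new) + sum_j exp2(score_j - m_new) * v_j[d]
+    // so the normalized output per dim is simply acc / l. The 1e-6 below merely
+    // guards the reciprocal; with seq_len > 0 the sum l is already positive.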
+ __shared__ float inv_l_shared[NGROUPS]; + if (tid < NGROUPS) { + inv_l_shared[tid] = 1.0f / (l[tid] + 1e-6f); + } + if constexpr (NUM_WARPS == 1) { + __syncwarp(); + } else { + __syncthreads(); + } + +#pragma unroll + for (int g = 0; g < NGROUPS; ++g) { + const int q_head = kv_head_idx * NGROUPS + g; + Tdata *out_ptr = out_ + seq_idx * o_stride + q_head * HEAD_SIZE; + const float s = inv_l_shared[g]; + const float o0 = acc0[g] * s; + const float o1 = acc1[g] * s; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + out_ptr[dim + 0] = __float2half_rn(o0); + out_ptr[dim + 1] = __float2half_rn(o1); + } else if constexpr (std::is_same_v) { + out_ptr[dim + 0] = __float2bfloat16_rn(o0); + out_ptr[dim + 1] = __float2bfloat16_rn(o1); + } else +#endif + { + out_ptr[dim + 0] = static_cast(o0); + out_ptr[dim + 1] = static_cast(o1); + } + } +} +} // namespace op::paged_attention::cuda + +#endif // __PAGED_ATTENTION_KERNEL_V2_CUH__ diff --git a/src/infiniop/ops/paged_attention/info.h b/src/infiniop/ops/paged_attention/info.h index 216bb2360..4b840af69 100644 --- a/src/infiniop/ops/paged_attention/info.h +++ b/src/infiniop/ops/paged_attention/info.h @@ -13,92 +13,171 @@ class PagedAttentionInfo { PagedAttentionInfo() = default; public: - // --- Data Types and Scale --- infiniDtype_t dtype; + infiniDtype_t index_dtype; float scale; - // --- Shape Dimensions --- size_t num_seqs; size_t num_heads; size_t num_kv_heads; size_t head_size; - size_t block_size; + size_t page_block_size; size_t max_num_blocks_per_seq; - // --- Strides for Memory Layout --- ptrdiff_t q_stride; - ptrdiff_t kv_block_stride; - ptrdiff_t kv_head_stride; + ptrdiff_t k_batch_stride; + ptrdiff_t k_row_stride; + ptrdiff_t k_head_stride; + ptrdiff_t v_batch_stride; + ptrdiff_t v_row_stride; + ptrdiff_t v_head_stride; ptrdiff_t o_stride; + ptrdiff_t block_table_batch_stride; + ptrdiff_t cache_lens_stride; + static utils::Result create( infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t block_tables_desc, - infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t cache_lens_desc, const std::optional &alibi_slopes_desc, float scale) { auto dtype = q_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) { + if (q_desc->ndim() != 3 || out_desc->ndim() != 3) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (block_tables_desc->ndim() != 2 || cache_lens_desc->ndim() != 1) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (block_tables_desc->dtype() != INFINI_DTYPE_I64) { + CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + const auto block_tables_dt = block_tables_desc->dtype(); + const auto cache_lens_dt = cache_lens_desc->dtype(); + 
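        // Any non-empty INFINIOP_FLASH_DEBUG_DTYPE value (the check below is a plain
        // getenv != nullptr test) turns on stderr logging for rejected index dtypes.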
const bool debug_dtype = (std::getenv("INFINIOP_FLASH_DEBUG_DTYPE") != nullptr); + const bool block_tables_ok = (block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32); + const bool cache_lens_ok = (cache_lens_dt == INFINI_DTYPE_I64) || (cache_lens_dt == INFINI_DTYPE_I32) || (cache_lens_dt == INFINI_DTYPE_U32); + if (!(block_tables_ok && cache_lens_ok)) { + if (debug_dtype) { + std::fprintf(stderr, + "[flash_attention] Bad index dtype: block_tables=%d cache_lens=%d (expected I32/I64/U32)\n", + static_cast(block_tables_dt), static_cast(cache_lens_dt)); + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } - - if (seq_lens_desc->dtype() != INFINI_DTYPE_I64) { + if (block_tables_dt != cache_lens_dt) { + // Keep them consistent to simplify backend dispatch. + if (debug_dtype) { + std::fprintf(stderr, + "[flash_attention] Mismatched index dtype: block_tables=%d cache_lens=%d\n", + static_cast(block_tables_dt), static_cast(cache_lens_dt)); + } return INFINI_STATUS_BAD_TENSOR_DTYPE; } - // --- Extract shape dimensions --- + CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(cache_lens_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) { + if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (alibi_slopes_desc.value()->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + } + + // Shapes auto q_shape = q_desc->shape(); - auto k_cache_shape = k_cache_desc->shape(); + auto k_shape = k_cache_desc->shape(); + + const size_t num_seqs = q_shape[0]; + const size_t num_heads = q_shape[1]; + const size_t head_size = q_shape[2]; + + const size_t num_blocks = k_shape[0]; + (void)num_blocks; + const size_t page_block_size = k_shape[2]; + const size_t num_kv_heads = k_shape[1]; + + // if (page_block_size % 256 != 0) { + // printf("paged block size %zu\n", page_block_size); + // return INFINI_STATUS_BAD_TENSOR_SHAPE; + // } + if (head_size != 64 && head_size != 128) { + // First build only targets common FA2 head dims (expand later). + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (num_heads % num_kv_heads != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[1] != k_shape[1] || v_cache_desc->shape()[2] != k_shape[2] || v_cache_desc->shape()[3] != k_shape[3]) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } - size_t num_seqs = q_shape[0]; - size_t num_heads = q_shape[1]; - size_t head_size = q_shape[2]; + if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } - if (head_size != 16 && head_size != 32 && head_size != 64 && head_size != 128 && head_size != 256) { - std::cerr << "[Error] Now only supports head_size = 16/32/64/128/256, but got " - << head_size << "." 
<< std::endl; + if (cache_lens_desc->shape()[0] != num_seqs) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - size_t num_kv_heads = k_cache_shape[1]; - size_t block_size = v_cache_desc->shape()[2]; // 使用V cache的block size维度更可靠 - size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + + // Strides (in elements) + const ptrdiff_t q_stride = q_desc->stride(0); + const ptrdiff_t o_stride = out_desc->stride(0); + + const ptrdiff_t k_batch_stride = k_cache_desc->stride(0); + const ptrdiff_t k_row_stride = k_cache_desc->stride(2); + const ptrdiff_t k_head_stride = k_cache_desc->stride(1); + + const ptrdiff_t v_batch_stride = v_cache_desc->stride(0); + const ptrdiff_t v_row_stride = v_cache_desc->stride(2); + const ptrdiff_t v_head_stride = v_cache_desc->stride(1); - // --- Calculate max_seq_len for shared memory allocation --- - // This is a safe upper bound. - // info.max_seq_len = info.max_num_blocks_per_seq * info.block_size; - // --- Extract strides for memory access --- - ptrdiff_t q_stride = q_desc->stride(0); - ptrdiff_t kv_block_stride = k_cache_desc->stride(0); - ptrdiff_t kv_head_stride = k_cache_desc->stride(1); - ptrdiff_t o_stride = out_desc->stride(0); + const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0); + const ptrdiff_t cache_lens_stride = cache_lens_desc->stride(0); return utils::Result(PagedAttentionInfo{ dtype, + block_tables_dt, scale, num_seqs, num_heads, num_kv_heads, head_size, - block_size, + page_block_size, max_num_blocks_per_seq, q_stride, - kv_block_stride, - kv_head_stride, - o_stride}); + k_batch_stride, + k_row_stride, + k_head_stride, + v_batch_stride, + v_row_stride, + v_head_stride, + o_stride, + block_table_batch_stride, + cache_lens_stride, + }); } }; diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu new file mode 100644 index 000000000..c16b48e48 --- /dev/null +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd128.cu @@ -0,0 +1,1024 @@ +#include + +#include +#include +#include +#include + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention::nvidia { + +namespace { +constexpr int kMaxSplits = 8; + +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline int getSmCount() { + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { + return 0; + } + int sm_count = 0; + if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device) != cudaSuccess) { + return 0; + } + return sm_count; +} + +// A lightweight FA2-style "waves" heuristic. +// +// Important: our split-kv kernel shards the KV sequence length, so the main "work" +// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k +// (max pages * page size), which matches common decode microbench where all seqs +// share the same cache length. +inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) { + if (sm_count <= 0) { + return 1; + } + if (num_heads == 0 || num_seqs == 0) { + return 1; + } + if (seqlen_k <= 256) { + return 1; + } + + const size_t base_blocks = num_heads * num_seqs; + int best_splits = 1; + // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens. 
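+    // Worked example (illustrative numbers): sm_count=108, num_heads*num_seqs=256,
+    // seqlen_k=4096. The baseline scores ceilDiv(256,108)*4096 = 3*4096 = 12288,
+    // while s=4 scores ceilDiv(1024,108)*ceilDiv(4096,4) + ceilDiv(256,108)
+    // = 10*1024 + 3 = 10243, so splitting wins under this metric.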
+ size_t best_score = (ceilDiv(base_blocks, static_cast(sm_count)) * seqlen_k); + + size_t prev_work_per_block = seqlen_k; + for (int s = 2; s <= kMaxSplits; ++s) { + const size_t blocks = base_blocks * static_cast(s); + const size_t waves_split = ceilDiv(blocks, static_cast(sm_count)); + const size_t work_per_block = ceilDiv(seqlen_k, static_cast(s)); + // If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant. + if (work_per_block == prev_work_per_block) { + continue; + } + prev_work_per_block = work_per_block; + // Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit. + const size_t waves_combine = ceilDiv(base_blocks, static_cast(sm_count)); + const size_t score = waves_split * work_per_block + waves_combine; + if (score < best_score) { + best_score = score; + best_splits = s; + } + } + return best_splits; +} +} // namespace + +inline bool envBool(const char *name) { + if (const char *env = std::getenv(name)) { + return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + return false; +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeWarpKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Default CTA variant (lower overhead). 
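+    // Variant naming used by the hd128 wrappers in this file: the bare "Cta" suffix
+    // is the 64-thread / 8-token-tile configuration, "Tile16" widens the tile to 16
+    // tokens, and "Cta32" is the experimental 1-warp (32-thread) CTA; these mirror
+    // the INFINIOP_FLASH_CTA_TILE and INFINIOP_FLASH_CTA_THREADS knobs in the dispatcher.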
+ op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128CtaTile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta32( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Experimental 1-warp CTA variant for head_dim=128 (kPack=4). + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128Cta32Tile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128CtaGqa4( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // GQA fused kernel: CTA computes 4 query heads for one KV head (head_dim=128). 
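+    // The dispatcher only routes here when num_heads == num_kv_heads * 4, alibi is
+    // disabled and split-kv is off (see the use_gqa_fused branch below); otherwise
+    // it falls back to the per-query-head kernels.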
+ op::paged_attention::cuda::flashAttentionDecodeCtaGqaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCtaTile16( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta32( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t 
k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCta32Tile16( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCtaKernel( + partial_acc, partial_m, partial_l, + q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd128SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, o_stride); +} + +template +infiniStatus_t launch_decode_hd128_impl( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + + // Default decode config (2026-01-22): + // decode_flash_cta8_64_gqa_splitkv_4 + // Users can override any knob via the corresponding INFINIOP_FLASH_* env vars. + bool use_cta = true; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) { + // Backward-compatible: any non-"cta" value means "warp". + use_cta = (std::strcmp(env, "cta") == 0); + } + bool use_gqa_fused = true; + if (const char *env = std::getenv("INFINIOP_FLASH_GQA_FUSED")) { + if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) { + use_gqa_fused = false; + } else { + use_gqa_fused = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + } + int cta_tile = 8; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) { + const int v = std::atoi(env); + if (v == 8 || v == 16) { + cta_tile = v; + } + } + int cta_threads = 64; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_THREADS")) { + const int v = std::atoi(env); + if (v == 32 || v == 64) { + cta_threads = v; + } + } + dim3 block(use_cta ? 
static_cast(cta_threads) : 32); + + bool use_split = true; + bool use_split_auto = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + if (std::strcmp(env, "auto") == 0) { + use_split_auto = true; + use_split = false; + } else { + if (std::strcmp(env, "0") == 0 || std::strcmp(env, "false") == 0) { + use_split = false; + } else { + use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + } + } + int num_splits = 4; + bool fixed_num_splits = true; + if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) { + if (std::strcmp(env, "auto") == 0) { + fixed_num_splits = false; + } else { + num_splits = std::atoi(env); + fixed_num_splits = (num_splits > 0); + } + } + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + const bool debug_dispatch = envBool("INFINIOP_FLASH_DEBUG_DISPATCH"); + auto dump_dispatch = [&](const char *path) { + if (!debug_dispatch) { + return; + } + // Avoid spamming: only print when the key dispatch signature changes. + struct Sig { + const char *path; + int dtype; + size_t heads; + size_t kv_heads; + size_t seqs; + size_t pbs; + size_t max_blocks; + int cta_tile; + int cta_threads; + int split; + int split_auto; + int num_splits; + int fixed; + int gqa_fused; + }; + static Sig last{}; + static bool has_last = false; + + Sig cur{ + path, + static_cast(dtype), + num_heads, + num_kv_heads, + num_seqs, + page_block_size, + max_num_blocks_per_seq, + cta_tile, + cta_threads, + static_cast(use_split), + static_cast(use_split_auto), + num_splits, + static_cast(fixed_num_splits), + static_cast(use_gqa_fused), + }; + + if (has_last && cur.path == last.path && cur.dtype == last.dtype && cur.heads == last.heads && cur.kv_heads == last.kv_heads && cur.seqs == last.seqs && cur.pbs == last.pbs && cur.max_blocks == last.max_blocks && cur.cta_tile == last.cta_tile && cur.cta_threads == last.cta_threads && cur.split == last.split && cur.split_auto == last.split_auto && cur.num_splits == last.num_splits && cur.fixed == last.fixed && cur.gqa_fused == last.gqa_fused) { + return; + } + last = cur; + has_last = true; + + fprintf(stderr, + "[INFINIOP][paged_attention][hd128] dispatch: path=%s dtype=%d heads=%zu kv_heads=%zu seqs=%zu " + "pbs=%zu max_blocks=%zu cta_tile=%d cta_threads=%d split=%d split_auto=%d num_splits=%d fixed=%d gqa_fused=%d\n", + path, static_cast(dtype), num_heads, num_kv_heads, num_seqs, + page_block_size, max_num_blocks_per_seq, cta_tile, cta_threads, + static_cast(use_split), static_cast(use_split_auto), num_splits, static_cast(fixed_num_splits), + static_cast(use_gqa_fused)); + }; + + // Split-kv auto mode: decide whether to split based on a heuristic. + if (use_split_auto) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). + const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto(mode): sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + // If auto picks 1, fall back to non-split to avoid extra workspace and kernel overhead. 
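+        // Example: with INFINIOP_FLASH_DECODE_SPLITKV=auto and a short cache
+        // (seqlen_k <= 256) the heuristic returns 1, so we take the plain
+        // single-pass CTA/warp path below instead of paying for partial buffers
+        // plus a combine kernel.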
+ use_split = (num_splits > 1); + } + + // const bool debug_dispatch = [] { + // if (const char *env = std::getenv("INFINIOP_FLASH_DEBUG_DISPATCH")) { + // return (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + // } + // return false; + // }(); + + // const char *selected_path = "unknown"; + + // Optional: fuse GQA groups (4) when seqlen_q=1 decode and alibi is disabled. + // This reuses K/V loads across query heads that share the same KV head. + // Controlled by INFINIOP_FLASH_GQA_FUSED (default: enabled). + if (use_gqa_fused && use_cta && !use_split && alibi_slopes == nullptr && num_kv_heads > 0 && num_heads == num_kv_heads * 4) { + dump_dispatch("cta_gqa_fused"); + dim3 grid_gqa(static_cast(num_kv_heads), static_cast(num_seqs), 1); + if (dtype == INFINI_DTYPE_F16) { + flashAttentionDecodeHd128CtaGqa4<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, nullptr, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + flashAttentionDecodeHd128CtaGqa4<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, nullptr, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + dim3 grid(static_cast(num_heads), static_cast(num_seqs), 1); + if (use_split) { + dump_dispatch(use_cta ? "splitkv_cta" : "splitkv_warp"); + // } + if (!fixed_num_splits) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). + const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + } + + const size_t n = num_seqs * num_heads; + const size_t acc_elems = static_cast(kMaxSplits) * n * 128; + const size_t m_elems = static_cast(kMaxSplits) * n; + const size_t l_elems = static_cast(kMaxSplits) * n; + const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float); + if (workspace == nullptr || workspace_size < needed_bytes) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *ws = static_cast(workspace); + float *partial_acc = ws; + float *partial_m = partial_acc + acc_elems; + float *partial_l = partial_m + m_elems; + + dim3 grid_split(static_cast(num_heads), static_cast(num_seqs), static_cast(num_splits)); + dim3 block_split(use_cta ? 
static_cast(cta_threads) : 32); + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_threads == 32) { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCta32Tile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta32<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } else { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCtaTile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } + } else { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + flashAttentionDecodeHd128SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_threads == 32) { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCta32Tile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + flashAttentionDecodeHd128SplitKvCta32<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } else { + if (cta_tile == 16) { + flashAttentionDecodeHd128SplitKvCtaTile16<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } else { + 
flashAttentionDecodeHd128SplitKvCta<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + } + } else { + flashAttentionDecodeHd128SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + } + flashAttentionDecodeHd128SplitKvCombine<__nv_bfloat16><<>>( + static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + dump_dispatch(use_cta ? "cta_nosplit" : "warp_nosplit"); + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_tile == 16) { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32Tile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128CtaTile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128Cta<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } + } else { + flashAttentionDecodeHd128Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_tile == 16) { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32Tile16<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128CtaTile16<<>>( + 
static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + if (cta_threads == 32) { + flashAttentionDecodeHd128Cta32<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd128Cta<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } + } else { + flashAttentionDecodeHd128Warp<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +infiniStatus_t launch_decode_hd128_i64( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int64_t *block_tables, + const int64_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd128_i32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int32_t *block_tables, + const int32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + 
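+// Documentation-only sketch of the split-kv workspace contract checked by
+// launch_decode_hd128_impl above (the acc/m/l layout is sized for kMaxSplits, so
+// the buffer does not depend on the runtime num_splits choice). The helper name
+// is illustrative and is not referenced elsewhere in this file.
+static constexpr size_t decodeHd128SplitKvWorkspaceBytes(size_t num_seqs, size_t num_heads) {
+    const size_t n = num_seqs * num_heads;
+    return (static_cast<size_t>(kMaxSplits) * n * 128 // partial_acc: one float per head dim
+            + static_cast<size_t>(kMaxSplits) * n     // partial_m: running max per (seq, head, split)
+            + static_cast<size_t>(kMaxSplits) * n)    // partial_l: running sum per (seq, head, split)
+           * sizeof(float);
+}
+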
+infiniStatus_t launch_decode_hd128_u32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const uint32_t *block_tables, + const uint32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd128_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +} // namespace op::paged_attention::nvidia diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu new file mode 100644 index 000000000..421fd22ef --- /dev/null +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_hd64.cu @@ -0,0 +1,524 @@ +#include + +#include +#include +#include +#include + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention::nvidia { + +namespace { +constexpr int kMaxSplits = 8; + +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline int getSmCount() { + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { + return 0; + } + int sm_count = 0; + if (cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device) != cudaSuccess) { + return 0; + } + return sm_count; +} + +// A lightweight FA2-style "waves" heuristic. +// +// Important: our split-kv kernel shards the KV sequence length, so the main "work" +// dimension is tokens, not the number of pages. We use an upper bound for seqlen_k +// (max pages * page size), which matches common decode microbench where all seqs +// share the same cache length. +inline int chooseNumSplitsHeuristic(size_t num_heads, size_t num_seqs, size_t seqlen_k, int sm_count) { + if (sm_count <= 0) { + return 1; + } + if (num_heads == 0 || num_seqs == 0) { + return 1; + } + if (seqlen_k <= 256) { + return 1; + } + + const size_t base_blocks = num_heads * num_seqs; + int best_splits = 1; + // Baseline: one kernel, base_blocks CTAs, each scanning seqlen_k tokens. + size_t best_score = (ceilDiv(base_blocks, static_cast(sm_count)) * seqlen_k); + + size_t prev_work_per_block = seqlen_k; + for (int s = 2; s <= kMaxSplits; ++s) { + const size_t blocks = base_blocks * static_cast(s); + const size_t waves_split = ceilDiv(blocks, static_cast(sm_count)); + const size_t work_per_block = ceilDiv(seqlen_k, static_cast(s)); + // If this split count doesn't reduce per-block work vs the previous split, it's effectively redundant. + if (work_per_block == prev_work_per_block) { + continue; + } + prev_work_per_block = work_per_block; + // Combine is one extra kernel with base_blocks blocks; approximate as one more wave unit. 
+ const size_t waves_combine = ceilDiv(base_blocks, static_cast(sm_count)); + const size_t score = waves_split * work_per_block + waves_combine; + if (score < best_score) { + best_score = score; + best_splits = s; + } + } + return best_splits; +} +} // namespace + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeWarpKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64Cta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + // Default CTA variant (lower overhead). + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64CtaTile16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeCtaKernel( + out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, o_stride); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64SplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + int num_splits) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvWarpKernel( + partial_acc, partial_m, partial_l, + q, k_cache, 
v_cache, block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, + k_batch_stride, k_row_stride, k_head_stride, v_batch_stride, v_row_stride, + v_head_stride, num_splits); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionDecodeHd64SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + ptrdiff_t o_stride) { + op::paged_attention::cuda::flashAttentionDecodeSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, o_stride); +} + +template +infiniStatus_t launch_decode_hd64_impl( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const Tindex *block_tables, + const Tindex *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + + dim3 grid(static_cast(num_heads), static_cast(num_seqs), 1); + bool use_cta = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_KERNEL")) { + use_cta = (std::strcmp(env, "cta") == 0); + } + int cta_tile = 8; + if (const char *env = std::getenv("INFINIOP_FLASH_CTA_TILE")) { + const int v = std::atoi(env); + if (v == 8 || v == 16) { + cta_tile = v; + } + } + // For head_dim=64 we use a 1-warp CTA (32 threads) with packed loads. + dim3 block(32); + + bool use_split = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + use_split = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 4; + bool fixed_num_splits = false; + if (const char *env = std::getenv("INFINIOP_FLASH_NUM_SPLITS")) { + if (std::strcmp(env, "auto") == 0) { + fixed_num_splits = false; + } else { + num_splits = std::atoi(env); + fixed_num_splits = (num_splits > 0); + } + } + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + if (use_split) { + if (use_cta) { + // We currently only implement the split-kv path with warp kernels. + // The CTA kernel is a separate non-split implementation. + static bool warned = false; + if (!warned) { + warned = true; + fprintf(stderr, + "[INFINIOP][paged_attention] split-kv is enabled; ignoring INFINIOP_FLASH_DECODE_KERNEL=cta " + "(CTA split-kv not implemented yet)\n"); + } + } + + if (!fixed_num_splits) { + // Approximate seqlen_k by the per-seq KV capacity (paged KV upper bound). 
+ const size_t seqlen_k = max_num_blocks_per_seq * page_block_size; + const int sm_count = getSmCount(); + num_splits = chooseNumSplitsHeuristic(num_heads, num_seqs, seqlen_k, sm_count); + if (const char *dbg = std::getenv("INFINIOP_FLASH_DEBUG_SPLITS")) { + if (std::strcmp(dbg, "1") == 0 || std::strcmp(dbg, "true") == 0) { + static size_t last_seqlen_k = 0; + if (last_seqlen_k != seqlen_k) { + last_seqlen_k = seqlen_k; + fprintf(stderr, + "[INFINIOP][paged_attention] splitkv auto: sm=%d heads=%zu seqs=%zu seqlen_k~%zu -> num_splits=%d\n", + sm_count, num_heads, num_seqs, seqlen_k, num_splits); + } + } + } + } + + const size_t n = num_seqs * num_heads; + const size_t acc_elems = static_cast(kMaxSplits) * n * 64; + const size_t m_elems = static_cast(kMaxSplits) * n; + const size_t l_elems = static_cast(kMaxSplits) * n; + const size_t needed_bytes = (acc_elems + m_elems + l_elems) * sizeof(float); + if (workspace == nullptr || workspace_size < needed_bytes) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *ws = static_cast(workspace); + float *partial_acc = ws; + float *partial_m = partial_acc + acc_elems; + float *partial_l = partial_m + m_elems; + + dim3 grid_split(static_cast(num_heads), static_cast(num_seqs), static_cast(num_splits)); + dim3 block_split(32); + + if (dtype == INFINI_DTYPE_F16) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<<>>( + static_cast(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + flashAttentionDecodeHd64SplitKv<<>>( + partial_acc, partial_m, partial_l, + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, num_splits); + flashAttentionDecodeHd64SplitKvCombine<__nv_bfloat16><<>>( + static_cast<__nv_bfloat16 *>(out), partial_acc, partial_m, partial_l, num_splits, o_stride); + return INFINI_STATUS_SUCCESS; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (dtype == INFINI_DTYPE_F16) { + if (use_cta) { + if (cta_tile == 16) { + flashAttentionDecodeHd64CtaTile16<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd64Cta<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + flashAttentionDecodeHd64Warp<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, 
k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + return INFINI_STATUS_SUCCESS; + } + if (dtype == INFINI_DTYPE_BF16) { + if (use_cta) { + if (cta_tile == 16) { + flashAttentionDecodeHd64CtaTile16<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } else { + flashAttentionDecodeHd64Cta<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + } else { + flashAttentionDecodeHd64Warp<<>>( + static_cast<__nv_bfloat16 *>(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, cache_lens, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + q_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride); + } + + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +infiniStatus_t launch_decode_hd64_i64( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int64_t *block_tables, + const int64_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd64_i32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + infiniDtype_t dtype, + const int32_t *block_tables, + const int32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +infiniStatus_t launch_decode_hd64_u32( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + 
infiniDtype_t dtype, + const uint32_t *block_tables, + const uint32_t *cache_lens, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t q_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + cudaStream_t stream) { + return launch_decode_hd64_impl( + workspace, workspace_size, + out, q, k_cache, v_cache, dtype, block_tables, cache_lens, alibi_slopes, num_heads, num_seqs, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, q_stride, k_batch_stride, k_row_stride, + k_head_stride, v_batch_stride, v_row_stride, v_head_stride, o_stride, stream); +} + +} // namespace op::paged_attention::nvidia diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu index d544fd34a..18b6ef073 100644 --- a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu @@ -1,29 +1,68 @@ -#include +#include -#include "../../../devices/nvidia/nvidia_common.cuh" -#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include +#include +#include -#include "../../../reduce/cuda/reduce.cuh" -#include "../cuda/kernel.cuh" +#include "../../../devices/nvidia/nvidia_common.cuh" #include "paged_attention_nvidia.cuh" -template -INFINIOP_CUDA_KERNEL pagedAttention( - Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, - const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes, - const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq, - const size_t block_size, - const ptrdiff_t q_stride, - const ptrdiff_t kv_block_stride, - const ptrdiff_t kv_head_stride, - const ptrdiff_t o_stride) { - op::paged_attention::cuda::pagedAttentionKernel( - out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale, - max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride); -} - namespace op::paged_attention::nvidia { +infiniStatus_t launch_decode_hd64_i64( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd64_i32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd64_u32( + void *workspace, size_t 
workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd128_i64( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int64_t *block_tables, const int64_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd128_i32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const int32_t *block_tables, const int32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + +infiniStatus_t launch_decode_hd128_u32( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, const uint32_t *block_tables, const uint32_t *cache_lens, const float *alibi_slopes, + size_t num_heads, size_t num_seqs, size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t page_block_size, + ptrdiff_t q_stride, ptrdiff_t k_batch_stride, ptrdiff_t k_row_stride, ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, ptrdiff_t v_row_stride, ptrdiff_t v_head_stride, ptrdiff_t o_stride, + cudaStream_t stream); + struct Descriptor::Opaque { std::shared_ptr internal; }; @@ -40,108 +79,284 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t block_tables_desc, - infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t cache_lens_desc, const std::optional &alibi_slopes_desc, float scale) { - auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale); - CHECK_RESULT(info); + + auto info_res = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, alibi_slopes_desc, scale); + CHECK_RESULT(info_res); + auto info = info_res.take(); + // Reserve workspace for optional split-kv decode (partial acc + m/l). + // Workspace is independent of runtime env toggles; kernels will clamp num_splits <= kMaxSplits. 
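+    // Worked example (hypothetical shapes): num_seqs = 4, num_heads = 32, head_size = 128 gives
+    // per_split = 4 * 32 * (128 + 2) * sizeof(float) = 66,560 bytes (head_size floats of partial
+    // accumulator plus 2 for the running max/sum), so with kMaxSplits = 8 the descriptor reserves
+    // 8 * 66,560 = 532,480 bytes (~520 KiB).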
+ constexpr size_t kMaxSplits = 8; + const size_t per_split = info.num_seqs * info.num_heads * (info.head_size + 2) * sizeof(float); + const size_t workspace_bytes = kMaxSplits * per_split; + *desc_ptr = new Descriptor( new Opaque{reinterpret_cast(handle)->internal()}, - info.take(), 0, handle->device, handle->device_id); + info, workspace_bytes, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } -template -infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache, - infiniDtype_t dtype, - const void *block_tables, const void *seq_lens, const void *alibi_slopes, - size_t num_heads, size_t num_seqs, - size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size, - ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride, - cudaStream_t stream) { - dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1); - dim3 block(NUM_THREADS); - size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); - - if (dtype == INFINI_DTYPE_F16) { - pagedAttention - <<>>( - (half *)out, - (const half *)q, (const half *)k_cache, (const half *)v_cache, - (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, - scale, max_num_blocks_per_seq, block_size, - q_stride, kv_block_stride, kv_head_stride, o_stride); - } else if (dtype == INFINI_DTYPE_BF16) { - pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS> - <<>>( - (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache, - (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, - scale, max_num_blocks_per_seq, block_size, - q_stride, kv_block_stride, kv_head_stride, o_stride); - } else if (dtype == INFINI_DTYPE_F32) { - pagedAttention - <<>>( - (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache, - (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, - scale, max_num_blocks_per_seq, block_size, - q_stride, kv_block_stride, kv_head_stride, o_stride); - } else { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - return INFINI_STATUS_SUCCESS; -} - infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache, - const void *block_tables, const void *seq_lens, const void *alibi_slopes, + const void *block_tables, const void *cache_lens, const void *alibi_slopes, void *stream_) const { - cudaStream_t stream = (cudaStream_t)stream_; - -#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \ - launchKernel<__H_SIZE, __B_SIZE>( \ - out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \ - _info.num_heads, _info.num_seqs, \ - _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \ - _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \ - stream); - -#define SWITCH_HEAD_SIZE(__B_SIZE) \ - switch (_info.head_size) { \ - case 16: \ - LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \ - break; \ - case 32: \ - LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \ - break; \ - case 64: \ - LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \ - break; \ - case 128: \ - LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \ - break; \ - case 256: \ - LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \ - break; \ - default: \ - return INFINI_STATUS_BAD_TENSOR_SHAPE; \ - } - if (_opaque->internal->maxThreadsPerBlock() 
== CUDA_BLOCK_SIZE_1024) { - SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024) - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { - SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512) - } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { - SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096) + bool need_workspace = false; + if (const char *env = std::getenv("INFINIOP_FLASH_DECODE_SPLITKV")) { + // "auto" may enable split-kv depending on the runtime heuristic. + need_workspace = (std::strcmp(env, "auto") == 0) || (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); } else { - return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + // Keep hd64 behavior unchanged, but for hd128 we default to split-kv decode, which needs workspace. + need_workspace = (_info.head_size == 128); + } + if (need_workspace && workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } -#undef LAUNCH_HEADSIZE_BLOCKSIZE -#undef SWITCH_HEAD_SIZE + auto stream = static_cast(stream_); - return INFINI_STATUS_SUCCESS; + const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast(alibi_slopes); + + if (_info.index_dtype == INFINI_DTYPE_I64) { + const auto *block_table_i64 = static_cast(block_tables); + const auto *cache_lens_i64 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_i64( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i64, cache_lens_i64, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_i64( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i64, cache_lens_i64, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (_info.index_dtype == INFINI_DTYPE_I32) { + const auto *block_table_i32 = static_cast(block_tables); + const auto *cache_lens_i32 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_i32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i32, cache_lens_i32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_i32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_i32, cache_lens_i32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (_info.index_dtype == INFINI_DTYPE_U32) { + const auto 
*block_table_u32 = static_cast(block_tables); + const auto *cache_lens_u32 = static_cast(cache_lens); + switch (_info.head_size) { + case 64: + return launch_decode_hd64_u32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_u32, cache_lens_u32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + case 128: + return launch_decode_hd128_u32( + workspace, workspace_size, + out, q, k_cache, v_cache, _info.dtype, + block_table_u32, cache_lens_u32, alibi_ptr, + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.scale, + _info.max_num_blocks_per_seq, _info.page_block_size, + _info.q_stride, _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, + _info.o_stride, stream); + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; } } // namespace op::paged_attention::nvidia + +// #include + +// #include "../../../devices/nvidia/nvidia_common.cuh" +// #include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +// #include "../../../reduce/cuda/reduce.cuh" +// #include "../cuda/kernel.cuh" +// #include "paged_attention_nvidia.cuh" + +// template +// INFINIOP_CUDA_KERNEL pagedAttention( +// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, +// const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes, +// const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq, +// const size_t block_size, +// const ptrdiff_t q_stride, +// const ptrdiff_t kv_block_stride, +// const ptrdiff_t kv_head_stride, +// const ptrdiff_t o_stride) { +// op::paged_attention::cuda::pagedAttentionKernel( +// out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale, +// max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride); +// } + +// namespace op::paged_attention::nvidia { + +// struct Descriptor::Opaque { +// std::shared_ptr internal; +// }; + +// Descriptor::~Descriptor() { +// delete _opaque; +// } + +// infiniStatus_t Descriptor::create( +// infiniopHandle_t handle, +// Descriptor **desc_ptr, +// infiniopTensorDescriptor_t out_desc, +// infiniopTensorDescriptor_t q_desc, +// infiniopTensorDescriptor_t k_cache_desc, +// infiniopTensorDescriptor_t v_cache_desc, +// infiniopTensorDescriptor_t block_tables_desc, +// infiniopTensorDescriptor_t seq_lens_desc, +// const std::optional &alibi_slopes_desc, +// float scale) { +// auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale); +// CHECK_RESULT(info); +// *desc_ptr = new Descriptor( +// new Opaque{reinterpret_cast(handle)->internal()}, +// info.take(), 0, handle->device, handle->device_id); + +// return INFINI_STATUS_SUCCESS; +// } + +// template +// infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache, +// infiniDtype_t dtype, +// const void *block_tables, const void *seq_lens, const void *alibi_slopes, +// size_t num_heads, size_t num_seqs, +// size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size, +// ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, 
ptrdiff_t o_stride, +// cudaStream_t stream) { +// dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1); +// dim3 block(NUM_THREADS); +// size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); + +// if (dtype == INFINI_DTYPE_F16) { +// pagedAttention +// <<>>( +// (half *)out, +// (const half *)q, (const half *)k_cache, (const half *)v_cache, +// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, +// scale, max_num_blocks_per_seq, block_size, +// q_stride, kv_block_stride, kv_head_stride, o_stride); +// } else if (dtype == INFINI_DTYPE_BF16) { +// pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS> +// <<>>( +// (__nv_bfloat16 *)out, (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k_cache, (const __nv_bfloat16 *)v_cache, +// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, +// scale, max_num_blocks_per_seq, block_size, +// q_stride, kv_block_stride, kv_head_stride, o_stride); +// } else if (dtype == INFINI_DTYPE_F32) { +// pagedAttention +// <<>>( +// (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache, +// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads, +// scale, max_num_blocks_per_seq, block_size, +// q_stride, kv_block_stride, kv_head_stride, o_stride); +// } else { +// return INFINI_STATUS_BAD_TENSOR_DTYPE; +// } +// return INFINI_STATUS_SUCCESS; +// } + +// infiniStatus_t Descriptor::calculate( +// void *workspace, size_t workspace_size, +// void *out, const void *q, const void *k_cache, const void *v_cache, +// const void *block_tables, const void *seq_lens, const void *alibi_slopes, +// void *stream_) const { +// cudaStream_t stream = (cudaStream_t)stream_; + +// #define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \ +// launchKernel<__H_SIZE, __B_SIZE>( \ +// out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \ +// _info.num_heads, _info.num_seqs, \ +// _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \ +// _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \ +// stream); + +// #define SWITCH_HEAD_SIZE(__B_SIZE) \ +// switch (_info.head_size) { \ +// case 16: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \ +// break; \ +// case 32: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \ +// break; \ +// case 64: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \ +// break; \ +// case 128: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \ +// break; \ +// case 256: \ +// LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \ +// break; \ +// default: \ +// return INFINI_STATUS_BAD_TENSOR_SHAPE; \ +// } + +// if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { +// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024) +// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { +// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512) +// } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { +// SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096) +// } else { +// return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; +// } + +// #undef LAUNCH_HEADSIZE_BLOCKSIZE +// #undef SWITCH_HEAD_SIZE + +// return INFINI_STATUS_SUCCESS; +// } + +// } // namespace op::paged_attention::nvidia diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh new file mode 100644 index 000000000..6790f12d8 --- /dev/null +++ 
b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh @@ -0,0 +1,2361 @@ +#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ +#define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ + +#include +#include +#include +#include + +#include +#include + +// Reuse warp-level primitives and math helpers from decode flash_attention kernels. +#include "../../paged_attention/cuda/kernel_v2.cuh" + +namespace op::paged_attention_prefill::cuda { + +__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cu_seqlens_q, size_t num_seqs) { + size_t low = 0, high = (num_seqs == 0) ? 0 : (num_seqs - 1); + while (low <= high) { + size_t mid = (low + high) >> 1; + const size_t start = static_cast(cu_seqlens_q[mid]); + const size_t end = static_cast(cu_seqlens_q[mid + 1]); + if (token_idx >= start && token_idx < end) { + return mid; + } else if (token_idx < start) { + if (mid == 0) { + break; + } + high = mid - 1; + } else { + low = mid + 1; + } + } + return 0; +} + +template +__device__ void PagedAttentionPrefillWarpKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int lane = threadIdx.x; + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int q_token_local = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_token_local >= q_len) { + return; + } + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = history_len + q_token_local + 1; + if (allowed_k_len <= 0) { + return; + } + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 
0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const int64_t q_token = q_start + static_cast(q_token_local); + const Tdata *q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + Tdata *out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __half22float2(q2[j]); + } + } + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __bfloat1622float2(q2[j]); + } + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + const int pbs = static_cast(page_block_size); + int t_base = 0; + for (int logical_block = 0; t_base < allowed_k_len; ++logical_block, t_base += pbs) { + int physical_block = 0; + if (lane == 0) { + physical_block = static_cast(block_table[logical_block]); + } + physical_block = __shfl_sync(0xffffffff, physical_block, 0); + + const Tdata *k_base = k_cache_ + static_cast(physical_block) * k_batch_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base = v_cache_ + static_cast(physical_block) * v_batch_stride + static_cast(kv_head_idx) * v_head_stride; + + const int token_end = min(pbs, allowed_k_len - t_base); + for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) { + const int t = t_base + token_in_block; + const Tdata *k_ptr = k_base + static_cast(token_in_block) * k_row_stride; + const Tdata *v_ptr = v_base + static_cast(token_in_block) * v_row_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + qk = op::paged_attention::cuda::warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + const int causal_limit = allowed_k_len - 1; + score += (alibi_slope * static_cast(t - causal_limit)) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + 
alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float o = acc[i] * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + out_ptr[dim] = static_cast(o); + } + } +} + +template +__global__ void PagedAttentionPrefillWarpGlobalKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int lane = threadIdx.x; + const size_t head_idx = static_cast(blockIdx.x); + const size_t global_token_idx = static_cast(blockIdx.y); + + if (lane >= kWarpSize || head_idx >= num_heads || global_token_idx >= total_q_tokens) { + return; + } + + const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs); + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + + const int q_token_local = static_cast(global_token_idx - static_cast(q_start)); + if (q_token_local < 0 || q_token_local >= q_len) { + return; + } + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = history_len + q_token_local + 1; + if (allowed_k_len <= 0) { + return; + } + + const int num_queries_per_kv = static_cast(num_heads / num_kv_heads); + const int kv_head_idx = 
static_cast(head_idx) / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + const Tdata *q_ptr = q_ + static_cast(global_token_idx) * q_stride + static_cast(head_idx) * q_head_stride; + Tdata *out_ptr = out_ + static_cast(global_token_idx) * o_stride + static_cast(head_idx) * o_head_stride; + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + const int pbs = static_cast(page_block_size); + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = static_cast(q_ptr[dim]); + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __half22float2(q2[j]); + } + } + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = __bfloat1622float2(q2[j]); + } + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + // Iterate by pages to avoid per-token division/mod and redundant block_table loads. + int t_base = 0; + for (int logical_block = 0; t_base < allowed_k_len; ++logical_block, t_base += pbs) { + const int32_t phys = static_cast(block_table[logical_block]); + const Tdata *k_base = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(kv_head_idx) * v_head_stride; + + const int token_end = min(pbs, allowed_k_len - t_base); + for (int token_in_block = 0; token_in_block < token_end; ++token_in_block) { + const int t = t_base + token_in_block; + const Tdata *k_ptr = k_base + static_cast(token_in_block) * k_row_stride; + const Tdata *v_ptr = v_base + static_cast(token_in_block) * v_row_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + } + qk = op::paged_attention::cuda::warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + const int causal_limit = allowed_k_len - 1; + score += (alibi_slope * static_cast(t - causal_limit)) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - 
m_new); + l = l * alpha + beta; + m = m_new; + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float o = acc[i] * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + out_ptr[dim] = static_cast(o); + } + } +} + +template +__global__ void PagedAttentionPrefillReferenceKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_heads, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + size_t num_seqs) { + + const size_t global_token_idx = static_cast(blockIdx.x); + const size_t head_idx = static_cast(blockIdx.y); + const size_t dim_idx = static_cast(threadIdx.x); + + if (dim_idx >= HEAD_SIZE || head_idx >= num_heads) { + return; + } + + const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs); + const size_t q_token_idx = global_token_idx - static_cast(cu_seqlens_q_[seq_idx]); + const size_t q_len = static_cast(cu_seqlens_q_[seq_idx + 1] - cu_seqlens_q_[seq_idx]); + + const size_t total_kv_len = static_cast(total_kv_lens_[seq_idx]); + const size_t history_len = total_kv_len - q_len; + const size_t causal_limit = history_len + q_token_idx; + + const size_t num_queries_per_kv = num_heads / num_kv_heads; + const size_t kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 
0.0f : alibi_slopes_[head_idx]; + + const Tdata *q_vec = q_ + static_cast(global_token_idx) * q_stride + static_cast(head_idx) * q_head_stride; + Tdata *out_ptr = out_ + static_cast(global_token_idx) * o_stride + static_cast(head_idx) * o_head_stride; + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + const size_t pbs = page_block_size; + + Tcompute max_score = -INFINITY; + for (size_t t = 0; t <= causal_limit; ++t) { + const size_t page = t / pbs; + const size_t off = t - page * pbs; + const ptrdiff_t phys = static_cast(block_table[page]); + const Tdata *k_vec = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + + Tcompute score = 0; + for (size_t d = 0; d < HEAD_SIZE; ++d) { + score += static_cast(q_vec[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += static_cast(alibi_slope * static_cast(t - causal_limit)); + } + if (score > max_score) { + max_score = score; + } + } + + Tcompute sum_exp = 0; + for (size_t t = 0; t <= causal_limit; ++t) { + const size_t page = t / pbs; + const size_t off = t - page * pbs; + const ptrdiff_t phys = static_cast(block_table[page]); + const Tdata *k_vec = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + + Tcompute score = 0; + for (size_t d = 0; d < HEAD_SIZE; ++d) { + score += static_cast(q_vec[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += static_cast(alibi_slope * static_cast(t - causal_limit)); + } + sum_exp += static_cast(expf(static_cast(score - max_score))); + } + + const Tcompute inv_sum = static_cast(1.0f) / (sum_exp + static_cast(1e-6f)); + Tcompute acc = 0; + for (size_t t = 0; t <= causal_limit; ++t) { + const size_t page = t / pbs; + const size_t off = t - page * pbs; + const ptrdiff_t phys = static_cast(block_table[page]); + const Tdata *k_vec = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + + Tcompute score = 0; + for (size_t d = 0; d < HEAD_SIZE; ++d) { + score += static_cast(q_vec[d]) * static_cast(k_vec[d]); + } + score *= static_cast(scale); + if (alibi_slope != 0.0f) { + score += static_cast(alibi_slope * static_cast(t - causal_limit)); + } + const Tcompute prob = static_cast(expf(static_cast(score - max_score))) * inv_sum; + + const Tdata *v_vec = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(off) * v_row_stride + static_cast(kv_head_idx) * v_head_stride; + acc += prob * static_cast(v_vec[dim_idx]); + } + + out_ptr[dim_idx] = static_cast(acc); +} + +template +__device__ void PagedAttentionPrefillWarpCtaKernel( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(BLOCK_M > 
0 && BLOCK_M <= 16, "BLOCK_M must be small (warp-per-query design)."); + static_assert(BLOCK_N == 64 || BLOCK_N == 128, "BLOCK_N must be 64/128 in v0.4."); + + constexpr int kWarpSize = 32; + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= BLOCK_M) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int m_block = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * BLOCK_M; + const int q_token_local = m_start + warp_id; + // IMPORTANT: do not early-return for a subset of warps in this CTA because we use __syncthreads() + // later. Tail tiles are handled by masking inactive warps. + if (m_start >= q_len) { + return; // uniform across the CTA + } + const bool is_active = (q_token_local < q_len); + + const int64_t kv_len_total_i64 = total_kv_lens_[seq_idx]; + const int kv_len_total = static_cast(kv_len_total_i64); + // history_len = total_kv_len - q_len (KV already includes current q tokens). + const int history_len = kv_len_total - q_len; + const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0; + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int64_t q_token = q_start; + if (is_active) { + q_token += static_cast(q_token_local); + } + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + const Tdata *q_ptr = nullptr; + Tdata *out_ptr = nullptr; + if (is_active) { + q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; + } + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = is_active ? static_cast(q_ptr[dim]) : 0.0f; + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]); + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + // For this CTA, we only need to scan up to the max allowed k among active warps. + const int max_q_in_tile = min(m_start + BLOCK_M, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + + __shared__ int32_t s_phys[BLOCK_N]; + __shared__ int32_t s_off[BLOCK_N]; + // Ensure shared-memory tiles are aligned for half2/bfloat162 vector loads. + __shared__ __align__(16) Tdata s_k[BLOCK_N * HEAD_SIZE]; + __shared__ __align__(16) Tdata s_v[BLOCK_N * HEAD_SIZE]; + + const int pbs = static_cast(page_block_size); + + for (int k_base = 0; k_base < max_allowed_k_len; k_base += BLOCK_N) { + const int tile_n = min(BLOCK_N, max_allowed_k_len - k_base); + + // Precompute page mapping once per token in the tile. 
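+        // Example (hypothetical values): with page_block_size = 16 and kpos = 37, page = 37 / 16 = 2
+        // and off = 37 - 2 * 16 = 5; the pbs == 256 fast path below replaces the divide/modulo with
+        // a shift (>> 8) and mask (& 255).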
+ for (int t = threadIdx.x; t < tile_n; t += blockDim.x) { + const int kpos = k_base + t; + const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs); + const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs); + const int32_t phys = static_cast(block_table[page]); + s_phys[t] = phys; + s_off[t] = off; + } + __syncthreads(); + + // Load K/V tile into shared memory (contiguous in head_dim). + const int tile_elems = tile_n * HEAD_SIZE; + for (int idx = threadIdx.x; idx < tile_elems; idx += blockDim.x) { + const int t = idx / HEAD_SIZE; + const int dim = idx - t * HEAD_SIZE; + const int32_t phys = s_phys[t]; + const int32_t off = s_off[t]; + const Tdata *k_base_ptr = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base_ptr = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(off) * v_row_stride + static_cast(kv_head_idx) * v_head_stride; + s_k[t * HEAD_SIZE + dim] = k_base_ptr[dim]; + s_v[t * HEAD_SIZE + dim] = v_base_ptr[dim]; + } + __syncthreads(); + + // Each warp processes one query token and scans the K/V tile. + for (int t = 0; t < tile_n; ++t) { + const int kpos = k_base + t; + if (kpos >= allowed_k_len) { + break; + } + const Tdata *k_ptr = s_k + t * HEAD_SIZE; + const Tdata *v_ptr = s_v + t * HEAD_SIZE; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + + qk = op::paged_attention::cuda::warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + // Causal prefill: last position is (allowed_k_len - 1) for this query. 
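+                    // The relative-position term (kpos - (allowed_k_len - 1)) is <= 0 for every
+                    // visible key and 0 at the newest allowed position; its product with alibi_slope
+                    // is scaled by kLog2e because the running softmax is kept in the exp2 domain.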
+ score += (alibi_slope * static_cast(kpos - (allowed_k_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + + __syncthreads(); + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float out_val = acc[i] * inv_l; + if (!is_active) { + continue; + } + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(out_val); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(out_val); + } else { + out_ptr[dim] = static_cast(out_val); + } + } +} + +// Pipelined CTA kernel (FA2-style): stage K/V loads with cp.async and overlap global->shared +// copies with compute. +// +// Design notes: +// - Keep shared memory <= 48KB for compatibility with multi-arch builds that include SM75. +// - Iterate by paged blocks (logical pages) so each tile stays within one physical block and +// avoids per-token (page, off) mapping arrays in shared memory. +// - One warp computes one query token (same as warpcta kernels). Warps with shorter causal +// limits simply mask the tail tokens but still participate in CTA-wide barriers. 
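+// Minimal illustrative sketch (not referenced by the kernels below): the generic STAGES-deep
+// cp.async prefetch schedule that the pipelined kernel roughly follows, written against the
+// helpers from the decode kernel_v2.cuh. `loadTile` / `computeTile` are hypothetical callables
+// standing in for the 16B K/V tile copies and the per-tile softmax + V accumulation; STAGES is
+// assumed to be 2 or 3, matching the clamp inside cpAsyncWaitGroupRt.
+template <int STAGES, typename LoadTile, typename ComputeTile>
+__device__ inline void pipelinedTileLoopSketch(int num_tiles, LoadTile loadTile, ComputeTile computeTile) {
+    // Prefetch the first STAGES tiles, one cp.async commit group per tile.
+    const int preload = min(STAGES, num_tiles);
+    for (int t = 0; t < preload; ++t) {
+        loadTile(t, t % STAGES); // issue cp.async copies into stage buffer (t % STAGES)
+        op::paged_attention::cuda::cpAsyncCommit();
+    }
+    for (int t = 0; t < num_tiles; ++t) {
+        // Tile t is complete once at most the groups committed after it remain pending.
+        op::paged_attention::cuda::cpAsyncWaitGroupRt(min(STAGES - 1, num_tiles - 1 - t));
+        __syncthreads(); // make every thread's copies visible CTA-wide before reading the tile
+        computeTile(t, t % STAGES);
+        __syncthreads(); // buffer (t % STAGES) is about to be overwritten by the next prefetch
+        if (t + STAGES < num_tiles) {
+            loadTile(t + STAGES, (t + STAGES) % STAGES);
+            op::paged_attention::cuda::cpAsyncCommit();
+        }
+    }
+}
+// The pipelined CTA kernel below applies this kind of schedule per logical page, with the K/V
+// cp.async copies playing the role of loadTile and the tile-wise online softmax as computeTile.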
+template +__device__ void PagedAttentionPrefillWarpCtaKernelPipelined( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <= 16."); + static_assert(TOKENS_PER_TILE == 32, "Pipelined CTA kernel currently assumes TOKENS_PER_TILE == 32."); + static_assert(STAGES >= 2 && STAGES <= 3, "STAGES must be 2 or 3."); + static_assert(sizeof(Tdata) == 2, "Pipelined CTA kernel supports only fp16/bf16."); + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= BLOCK_M) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int m_block = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * BLOCK_M; + const int q_token_local = m_start + warp_id; + // Uniform return for empty tail CTAs (avoid deadlock with __syncthreads). + if (m_start >= q_len) { + return; + } + const bool is_active = (q_token_local < q_len); + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0; + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int64_t q_token = q_start; + if (is_active) { + q_token += static_cast(q_token_local); + } + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + const Tdata *q_ptr = nullptr; + Tdata *out_ptr = nullptr; + if (is_active) { + q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; + } + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = is_active ? 
static_cast(q_ptr[dim]) : 0.0f; + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]); + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + // For this CTA, scan KV up to the max causal limit among active warps. + const int max_q_in_tile = min(m_start + BLOCK_M, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + if (max_allowed_k_len <= 0) { + // Nothing to attend to (should be rare). Produce zeros. + if (is_active) { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + out_ptr[dim] = Tdata{}; + } + } + return; + } + + // cp.async uses 16B chunks; for fp16/bf16 that's 8 elements. + constexpr int CHUNK_ELEMS = 8; + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + // Multi-stage pipeline buffers. + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + // Per-warp scratch for tile-wise softmax (scores over TOKENS_PER_TILE). + // We keep scores in shared so each lane can load its token score (lane -> token index), + // then weights are broadcast via warp shuffles to avoid extra shared-memory traffic. + __shared__ float sh_scores[BLOCK_M][TOKENS_PER_TILE]; + // Store Q in shared (per warp). This enables more tile-level parallelism in score + // computation without expensive cross-lane shuffles of Q registers. + __shared__ __align__(16) Tdata sh_q[BLOCK_M][HEAD_SIZE]; + + const int pbs = static_cast(page_block_size); + const int tid = threadIdx.x; + + // Populate per-warp Q shared tile once. +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + sh_q[warp_id][dim] = is_active ? 
q_ptr[dim] : Tdata{}; + } + __syncwarp(); + + int t_base = 0; + for (int logical_block = 0; t_base < max_allowed_k_len; ++logical_block, t_base += pbs) { + const int physical_block = static_cast(block_table[logical_block]); + + const Tdata *k_base = k_cache_ + static_cast(physical_block) * k_batch_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base = v_cache_ + static_cast(physical_block) * v_batch_stride + static_cast(kv_head_idx) * v_head_stride; + + const int token_end = min(pbs, max_allowed_k_len - t_base); + const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + if (num_tiles <= 0) { + continue; + } + + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + static_cast(token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + static_cast(token_in_block + tok) * v_row_stride + off; + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + op::paged_attention::cuda::cpAsyncCommit(); + ++pending_groups; + } + + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + desired_pending = (STAGES - 1); + } + op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + __syncthreads(); + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + + const int global_k_base = t_base + token_in_block; + // Tile-wise online softmax (more FA2-like than per-token update): + // 1) Compute scores for this tile (masked to each warp's causal limit). + // 2) Compute tile max + sumexp. + // 3) Accumulate weighted V for the tile. + // 4) Merge into running (m, l, acc) in a numerically stable way. + // + // NOTE: this does not yet implement MMA / full tile-level GEMM; it mainly reduces + // the serial (lane0) online-softmax update frequency from per-token to per-tile. + float alpha = 1.0f; + float beta = 0.0f; + float tile_sumexp = 0.0f; + float tile_m = -INFINITY; + + if (allowed_k_len > 0) { + // 1) scores + // Increase tile-level parallelism vs the previous per-token loop: + // split the warp into 4 groups of 8 lanes; each group computes one token score in parallel. 
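+ // Lane mapping used below: group_id = lane / 8, lane_g = lane % 8, and
+ // group_mask = 0xFF << (group_id * 8) selects the 8 lanes of one group for the
+ // __shfl_down_sync tree reduction (e.g. lane 13 -> group_id 1, lane_g 5, mask 0xFF00).
+ // In each pass over j_base, group g scores token j_base + g, every lane of the group
+ // accumulating HEAD_SIZE / 8 dimensions of the Q.K dot product before the 8-lane
+ // reduction leaves the full score on lane_g == 0.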
+ constexpr int LANES_PER_GROUP = 8; + constexpr int GROUPS_PER_WARP = 4; + constexpr int DIMS_PER_GROUP_LANE = HEAD_SIZE / LANES_PER_GROUP; + static_assert(HEAD_SIZE % LANES_PER_GROUP == 0, "HEAD_SIZE must be divisible by 8."); + + const int group_id = lane / LANES_PER_GROUP; // [0..3] + const int lane_g = lane & (LANES_PER_GROUP - 1); // [0..7] + const unsigned int group_mask = 0xFFu << (group_id * LANES_PER_GROUP); + + for (int j_base = 0; j_base < TOKENS_PER_TILE; j_base += GROUPS_PER_WARP) { + const int j = j_base + group_id; // token index in [0..31] + const int kpos = global_k_base + j; + + const bool token_in_tile = (j < tile_n); + const bool token_unmasked = token_in_tile && (kpos < allowed_k_len); + + float qk_part = 0.0f; + if (token_unmasked) { + const Tdata *k_ptr = &sh_k[buf][j][0]; + const int dim_base = lane_g * DIMS_PER_GROUP_LANE; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const half2 *q2 = reinterpret_cast(&sh_q[warp_id][dim_base]); + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int t = 0; t < DIMS_PER_GROUP_LANE / 2; ++t) { + const float2 qf = __half22float2(q2[t]); + const float2 kf = __half22float2(k2[t]); + qk_part += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const __nv_bfloat162 *q2 = reinterpret_cast(&sh_q[warp_id][dim_base]); + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int t = 0; t < DIMS_PER_GROUP_LANE / 2; ++t) { + const float2 qf = __bfloat1622float2(q2[t]); + const float2 kf = __bfloat1622float2(k2[t]); + qk_part += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int t = 0; t < DIMS_PER_GROUP_LANE; ++t) { + qk_part += static_cast(sh_q[warp_id][dim_base + t]) * static_cast(k_ptr[dim_base + t]); + } + } + } + + // Reduce within 8-lane group. + for (int offset = LANES_PER_GROUP / 2; offset > 0; offset >>= 1) { + qk_part += __shfl_down_sync(group_mask, qk_part, offset, LANES_PER_GROUP); + } + + if (lane_g == 0) { + float score = -INFINITY; + if (token_unmasked) { + score = qk_part * scale_log2; + if (alibi_slope != 0.0f) { + const int causal_limit = allowed_k_len - 1; + score += (alibi_slope * static_cast(kpos - causal_limit)) * kLog2e; + } + } + sh_scores[warp_id][j] = score; + } + } + __syncwarp(); + + // 2) tile max + sumexp (lane t corresponds to token t within the tile) + const float score_lane = (lane < tile_n) ? sh_scores[warp_id][lane] : -INFINITY; + float tile_m_tmp = op::paged_attention::cuda::warpReduceMax(score_lane); + tile_m_tmp = __shfl_sync(0xffffffff, tile_m_tmp, 0); + tile_m = tile_m_tmp; + + float w_lane = 0.0f; + if (lane < tile_n && tile_m != -INFINITY) { + w_lane = exp2f(score_lane - tile_m); + } + float sumexp_tmp = op::paged_attention::cuda::warpReduceSum(w_lane); + sumexp_tmp = __shfl_sync(0xffffffff, sumexp_tmp, 0); + tile_sumexp = sumexp_tmp; + + // 3) weighted V for this tile (per lane owns HEAD_SIZE/32 dims) + float acc_tile[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + acc_tile[i] = 0.0f; + } + + if (tile_sumexp > 0.0f) { + for (int j = 0; j < tile_n; ++j) { + // Broadcast weight for token j from lane j. 
+ const float wj = __shfl_sync(0xffffffff, w_lane, j); + const Tdata *v_ptr = &sh_v[buf][j][0]; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) { + const float2 vf = __half22float2(v2[jj]); + acc_tile[jj * 2 + 0] += wj * vf.x; + acc_tile[jj * 2 + 1] += wj * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) { + const float2 vf = __bfloat1622float2(v2[jj]); + acc_tile[jj * 2 + 0] += wj * vf.x; + acc_tile[jj * 2 + 1] += wj * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + acc_tile[i] += wj * static_cast(v_ptr[dim]); + } + } + } + } + + // 4) merge tile into running (m, l, acc) + if (lane == 0) { + if (tile_sumexp > 0.0f && tile_m != -INFINITY) { + const float m_new = fmaxf(m, tile_m); + alpha = exp2f(m - m_new); + beta = exp2f(tile_m - m_new); + l = l * alpha + tile_sumexp * beta; + m = m_new; + } else { + alpha = 1.0f; + beta = 0.0f; + } + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + acc[i] = acc[i] * alpha + beta * acc_tile[i]; + } + } + + // IMPORTANT: warps in this CTA can have different allowed_k_len (due to causal mask + history), + // so they may finish the token loop at different times. We must not start prefetching into + // the circular shared-memory buffer until all warps finish consuming the current tile. + __syncthreads(); + + // Prefetch the tile that will reuse this buffer (STAGES steps ahead). 
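+ // Tile tile_idx occupies buffer tile_idx % STAGES, so the next tile that may be staged
+ // into this buffer is tile_idx + STAGES. After issuing its cp.async group we wait until
+ // at most STAGES - 1 groups remain outstanding, which guarantees the buffer consumed by
+ // tile_idx + 1 has fully landed before the next compute step.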
+ const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + static_cast(token_prefetch + tok) * k_row_stride + off; + const Tdata *v_src = v_base + static_cast(token_prefetch + tok) * v_row_stride + off; + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + op::paged_attention::cuda::cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + int desired_pending2 = pending_groups - 1; + if (desired_pending2 < 0) { + desired_pending2 = 0; + } + if (desired_pending2 > (STAGES - 1)) { + desired_pending2 = (STAGES - 1); + } + op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending2); + pending_groups = desired_pending2; + __syncthreads(); + } + } + + op::paged_attention::cuda::cpAsyncWaitAll(); + __syncthreads(); + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float out_val = acc[i] * inv_l; + if (!is_active) { + continue; + } + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(out_val); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(out_val); + } else { + out_ptr[dim] = static_cast(out_val); + } + } +} + +// Split-KV prefill (FA2-style): each split scans a shard of KV and writes partial (m, l, acc) +// to workspace. A separate combine kernel merges splits into the final output. +// +// Notes: +// - Implemented for the pipelined CTA kernel family (warpcta8pipe). We split by logical paged blocks. +// - Each warp still applies its own causal limit (allowed_k_len) so correctness is preserved. 
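The combine step referenced above is a log-sum-exp reduction over the splits, again in the exp2 domain; it is implemented by PagedAttentionPrefillSplitKvCombineWarpKernel further down. A host-side sketch of the same math for one (token, head) pair (combineSplits is an illustrative name, not part of the patch):

#include <cmath>

// Merge per-split partials (m_s, l_s, acc_s[head_size]) into the final output row.
inline void combineSplits(int num_splits, int head_size,
                          const float *m_s,   // [num_splits]
                          const float *l_s,   // [num_splits]
                          const float *acc_s, // [num_splits * head_size]
                          float *out) {       // [head_size]
    float m = -INFINITY;
    for (int s = 0; s < num_splits; ++s) {
        m = std::fmax(m, m_s[s]);
    }
    float l = 0.0f;
    for (int s = 0; s < num_splits; ++s) {
        if (l_s[s] > 0.0f) { // splits with no assigned blocks store (m, l) = (-inf, 0)
            l += l_s[s] * std::exp2(m_s[s] - m);
        }
    }
    const float inv_l = 1.0f / (l + 1e-6f);
    for (int d = 0; d < head_size; ++d) {
        float acc = 0.0f;
        for (int s = 0; s < num_splits; ++s) {
            if (l_s[s] > 0.0f) {
                acc += acc_s[s * head_size + d] * std::exp2(m_s[s] - m);
            }
        }
        out[d] = acc * inv_l;
    }
}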
+template +__device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( + float *partial_acc, // [num_splits, total_q_tokens, num_heads, head_size] + float *partial_m, // [num_splits, total_q_tokens, num_heads] + float *partial_l, // [num_splits, total_q_tokens, num_heads] + int split_idx, + int num_splits, + int m_block, + size_t total_q_tokens, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride) { + + (void)max_num_blocks_per_seq; + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <= 16."); + static_assert(TOKENS_PER_TILE == 32, "Split-KV prefill assumes TOKENS_PER_TILE == 32."); + static_assert(STAGES >= 2 && STAGES <= 3, "STAGES must be 2 or 3."); + static_assert(sizeof(Tdata) == 2, "Split-KV prefill supports only fp16/bf16."); + + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= BLOCK_M) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * BLOCK_M; + const int q_token_local = m_start + warp_id; + if (m_start >= q_len) { + return; // uniform + } + const bool is_active = (q_token_local < q_len); + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0; + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int64_t q_token = q_start; + if (is_active) { + q_token += static_cast(q_token_local); + } + + const size_t n = total_q_tokens * static_cast(num_heads); + size_t base = 0; + if (is_active) { + base = static_cast(q_token) * static_cast(num_heads) + static_cast(head_idx); + } + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + const Tdata *q_ptr = nullptr; + if (is_active) { + q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + } + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = is_active ? 
static_cast(q_ptr[dim]) : 0.0f; + acc[i] = 0.0f; + } + + float m = -INFINITY; + float l = 0.0f; + + const int max_q_in_tile = min(m_start + BLOCK_M, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + if (max_allowed_k_len <= 0) { + if (is_active) { + const size_t idx = static_cast(split_idx) * n + base; + if (lane == 0) { + partial_m[idx] = -INFINITY; + partial_l[idx] = 0.0f; + } +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = 0.0f; + } + } + return; + } + + const int pbs = static_cast(page_block_size); + const int num_blocks_total = (max_allowed_k_len + pbs - 1) / pbs; + const int blocks_per_split = (num_blocks_total + num_splits - 1) / num_splits; + const int start_block = split_idx * blocks_per_split; + const int end_block = min(num_blocks_total, start_block + blocks_per_split); + if (start_block >= end_block) { + if (is_active) { + const size_t idx = static_cast(split_idx) * n + base; + if (lane == 0) { + partial_m[idx] = -INFINITY; + partial_l[idx] = 0.0f; + } +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = 0.0f; + } + } + return; + } + + const int max_allowed_k_len_split = min(max_allowed_k_len, end_block * pbs); + + constexpr int CHUNK_ELEMS = 8; + constexpr int CHUNKS = HEAD_SIZE / CHUNK_ELEMS; + constexpr int LOADS_PER_TILE = CHUNKS * TOKENS_PER_TILE; + + __shared__ __align__(16) Tdata sh_k[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ __align__(16) Tdata sh_v[STAGES][TOKENS_PER_TILE][HEAD_SIZE]; + __shared__ float sh_scores[BLOCK_M][TOKENS_PER_TILE]; + + const int tid = threadIdx.x; + + int t_base = start_block * pbs; + for (int logical_block = start_block; t_base < max_allowed_k_len_split; ++logical_block, t_base += pbs) { + const int physical_block = static_cast(block_table[logical_block]); + + const Tdata *k_base = k_cache_ + static_cast(physical_block) * k_batch_stride + static_cast(kv_head_idx) * k_head_stride; + const Tdata *v_base = v_cache_ + static_cast(physical_block) * v_batch_stride + static_cast(kv_head_idx) * v_head_stride; + + const int token_end = min(pbs, max_allowed_k_len_split - t_base); + const int num_tiles = (token_end + TOKENS_PER_TILE - 1) / TOKENS_PER_TILE; + if (num_tiles <= 0) { + continue; + } + + int pending_groups = 0; + const int preload = min(STAGES, num_tiles); + for (int ti = 0; ti < preload; ++ti) { + const int token_in_block = ti * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < tile_n) { + const Tdata *k_src = k_base + static_cast(token_in_block + tok) * k_row_stride + off; + const Tdata *v_src = v_base + static_cast(token_in_block + tok) * v_row_stride + off; + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[ti][tok][off], k_src); + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[ti][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[ti][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + op::paged_attention::cuda::cpAsyncCommit(); + ++pending_groups; + } + + int desired_pending = pending_groups - 1; + if (desired_pending < 0) { + desired_pending = 0; + } + if (desired_pending > (STAGES - 1)) { + 
desired_pending = (STAGES - 1); + } + op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending); + pending_groups = desired_pending; + __syncthreads(); + + for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + const int buf = tile_idx % STAGES; + const int token_in_block = tile_idx * TOKENS_PER_TILE; + const int tile_n = min(TOKENS_PER_TILE, token_end - token_in_block); + const int global_k_base = t_base + token_in_block; + + float alpha = 1.0f; + float beta = 0.0f; + float tile_sumexp = 0.0f; + float tile_m = -INFINITY; + float w_lane = 0.0f; + + if (allowed_k_len > 0) { + // 1) scores + for (int j = 0; j < tile_n; ++j) { + const int kpos = global_k_base + j; + const bool token_unmasked = (kpos < allowed_k_len); + float qk = 0.0f; + if (token_unmasked) { + const Tdata *k_ptr = &sh_k[buf][j][0]; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *q2 = reinterpret_cast(q_ptr + dim_base); + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int ii = 0; ii < DIMS_PER_THREAD / 2; ++ii) { + const float2 qf = __half22float2(q2[ii]); + const float2 kf = __half22float2(k2[ii]); + qk = fmaf(qf.x, kf.x, qk); + qk = fmaf(qf.y, kf.y, qk); + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *q2 = reinterpret_cast(q_ptr + dim_base); + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int ii = 0; ii < DIMS_PER_THREAD / 2; ++ii) { + const float2 qf = __bfloat1622float2(q2[ii]); + const float2 kf = __bfloat1622float2(k2[ii]); + qk = fmaf(qf.x, kf.x, qk); + qk = fmaf(qf.y, kf.y, qk); + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk = fmaf(q_reg[i], static_cast(k_ptr[dim]), qk); + } + } + } + qk = op::paged_attention::cuda::warpReduceSum(qk); + if (lane == 0) { + float score = token_unmasked ? (qk * scale_log2) : -INFINITY; + if (token_unmasked && alibi_slope != 0.0f) { + const int causal_limit = allowed_k_len - 1; + score += (alibi_slope * static_cast(kpos - causal_limit)) * kLog2e; + } + sh_scores[warp_id][j] = score; + } + } + __syncwarp(); + + // 2) tile max / sumexp + float max_tmp = -INFINITY; + if (lane < tile_n) { + max_tmp = sh_scores[warp_id][lane]; + } + max_tmp = op::paged_attention::cuda::warpReduceMax(max_tmp); + max_tmp = __shfl_sync(0xffffffff, max_tmp, 0); + tile_m = max_tmp; + + if (lane < tile_n) { + const float s = sh_scores[warp_id][lane]; + w_lane = (s == -INFINITY) ? 
0.0f : exp2f(s - tile_m); + } else { + w_lane = 0.0f; + } + float sumexp_tmp = op::paged_attention::cuda::warpReduceSum(w_lane); + sumexp_tmp = __shfl_sync(0xffffffff, sumexp_tmp, 0); + tile_sumexp = sumexp_tmp; + + // 3) weighted V for this tile + float acc_tile[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + acc_tile[i] = 0.0f; + } + if (tile_sumexp > 0.0f) { + for (int j = 0; j < tile_n; ++j) { + const float wj = __shfl_sync(0xffffffff, w_lane, j); + const Tdata *v_ptr = &sh_v[buf][j][0]; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) { + const float2 vf = __half22float2(v2[jj]); + acc_tile[jj * 2 + 0] += wj * vf.x; + acc_tile[jj * 2 + 1] += wj * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int jj = 0; jj < DIMS_PER_THREAD / 2; ++jj) { + const float2 vf = __bfloat1622float2(v2[jj]); + acc_tile[jj * 2 + 0] += wj * vf.x; + acc_tile[jj * 2 + 1] += wj * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + acc_tile[i] += wj * static_cast(v_ptr[dim]); + } + } + } + } + + // 4) merge tile into running (m, l, acc) + if (lane == 0) { + if (tile_sumexp > 0.0f && tile_m != -INFINITY) { + const float m_new = fmaxf(m, tile_m); + alpha = exp2f(m - m_new); + beta = exp2f(tile_m - m_new); + l = l * alpha + tile_sumexp * beta; + m = m_new; + } else { + alpha = 1.0f; + beta = 0.0f; + } + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + acc[i] = acc[i] * alpha + beta * acc_tile[i]; + } + } + + __syncthreads(); + + const int prefetch_tile = tile_idx + STAGES; + if (prefetch_tile < num_tiles) { + const int token_prefetch = prefetch_tile * TOKENS_PER_TILE; + const int prefetch_n = min(TOKENS_PER_TILE, token_end - token_prefetch); + for (int li = tid; li < LOADS_PER_TILE; li += blockDim.x) { + const int tok = li / CHUNKS; + const int chunk = li - tok * CHUNKS; + const int off = chunk * CHUNK_ELEMS; + if (tok < prefetch_n) { + const Tdata *k_src = k_base + static_cast(token_prefetch + tok) * k_row_stride + off; + const Tdata *v_src = v_base + static_cast(token_prefetch + tok) * v_row_stride + off; + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_k[buf][tok][off], k_src); + op::paged_attention::cuda::cpAsyncCaSharedGlobal16(&sh_v[buf][tok][off], v_src); + } else { + reinterpret_cast(&sh_k[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + reinterpret_cast(&sh_v[buf][tok][off])[0] = make_uint4(0, 0, 0, 0); + } + } + op::paged_attention::cuda::cpAsyncCommit(); + ++pending_groups; + } + + if (tile_idx + 1 < num_tiles) { + int desired_pending2 = pending_groups - 1; + if (desired_pending2 < 0) { + desired_pending2 = 0; + } + if (desired_pending2 > (STAGES - 1)) { + desired_pending2 = (STAGES - 1); + } + op::paged_attention::cuda::cpAsyncWaitGroupRt(desired_pending2); + pending_groups = desired_pending2; + __syncthreads(); + } + } + + op::paged_attention::cuda::cpAsyncWaitAll(); + __syncthreads(); + } + + if (is_active) { + const size_t idx = static_cast(split_idx) * n + base; + if (lane == 0) { + partial_m[idx] = m; + partial_l[idx] = l; + } +#pragma unroll + for (int i 
= 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + partial_acc[idx * HEAD_SIZE + dim] = acc[i]; + } + } +} + +template +__device__ void PagedAttentionPrefillSplitKvCombineWarpKernel( + Tdata *out_, + const float *partial_acc, // [num_splits, total_q_tokens, num_heads, head_size] + const float *partial_m, // [num_splits, total_q_tokens, num_heads] + const float *partial_l, // [num_splits, total_q_tokens, num_heads] + int num_splits, + size_t total_q_tokens, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + const int head_idx = static_cast(blockIdx.x); + const int token_idx = static_cast(blockIdx.y); + const int lane = threadIdx.x; + constexpr int kWarpSize = 32; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + + const int num_heads = gridDim.x; + const size_t n = total_q_tokens * static_cast(num_heads); + const size_t base = static_cast(token_idx) * static_cast(num_heads) + static_cast(head_idx); + + float m = -INFINITY; + if (lane == 0) { + for (int s = 0; s < num_splits; ++s) { + m = fmaxf(m, partial_m[static_cast(s) * n + base]); + } + } + m = __shfl_sync(0xffffffff, m, 0); + + float l = 0.0f; + if (lane == 0) { + for (int s = 0; s < num_splits; ++s) { + const float ms = partial_m[static_cast(s) * n + base]; + const float ls = partial_l[static_cast(s) * n + base]; + if (ls > 0.0f) { + l += ls * exp2f(ms - m); + } + } + } + l = __shfl_sync(0xffffffff, l, 0); + const float inv_l = 1.0f / (l + 1e-6f); + + Tdata *out_ptr = out_ + static_cast(token_idx) * o_stride + static_cast(head_idx) * o_head_stride; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + float acc = 0.0f; + for (int s = 0; s < num_splits; ++s) { + const float ms = partial_m[static_cast(s) * n + base]; + const float w = exp2f(ms - m); + acc += partial_acc[(static_cast(s) * n + base) * HEAD_SIZE + dim] * w; + } + const float o = acc * inv_l; + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(o); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(o); + } else { + out_ptr[dim] = static_cast(o); + } + } +} + +// Variant for large K tile where (K+V) shared memory would exceed the per-block limit on some GPUs. +// We keep K in shared memory for reuse across warps, but load V directly from global memory. 
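To make the trade-off concrete, here is a back-of-envelope shared-memory budget for the BLOCK_N=128 / HEAD_SIZE=128 / fp16 configuration targeted by the PagedAttentionPrefillHd128WarpCta8N128 wrapper below (illustrative constants only; the 48 KiB figure is the static shared-memory cap the design notes aim for):

#include <cstddef>
#include <cstdint>

constexpr std::size_t kKTileBytes = 128 /*BLOCK_N*/ * 128 /*HEAD_SIZE*/ * sizeof(std::uint16_t); // s_k: 32 KiB of fp16
constexpr std::size_t kPageMapBytes = 2 * 128 * sizeof(std::int32_t);                            // s_phys + s_off: 1 KiB
static_assert(kKTileBytes + kPageMapBytes <= 48 * 1024,
              "K-only tile fits within the 48 KiB static shared-memory budget");
// Staging V in shared memory as well would add another 32 KiB (65 KiB total) and exceed the
// cap, which is why this variant streams V rows straight from global memory.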
+template +__device__ void PagedAttentionPrefillWarpCtaKernelKOnly( + Tdata *out_, + const Tdata *q_, + const Tdata *k_cache_, + const Tdata *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + static_assert(HEAD_SIZE == 64 || HEAD_SIZE == 128, "Only head_size 64/128 supported in v0.4."); + static_assert(BLOCK_M > 0 && BLOCK_M <= 16, "BLOCK_M must be <=16."); + static_assert(BLOCK_N > 0 && BLOCK_N <= 128, "BLOCK_N must be <=128."); + + constexpr int kWarpSize = 32; + constexpr int DIMS_PER_THREAD = HEAD_SIZE / kWarpSize; + static_assert(HEAD_SIZE % kWarpSize == 0, "HEAD_SIZE must be divisible by 32."); + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= BLOCK_M) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int m_block = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * BLOCK_M; + const int q_token_local = m_start + warp_id; + // IMPORTANT: do not early-return for a subset of warps in this CTA because we use __syncthreads() + // later. Tail tiles are handled by masking inactive warps. + if (m_start >= q_len) { + return; // uniform across the CTA + } + const bool is_active = (q_token_local < q_len); + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + const int allowed_k_len = is_active ? (history_len + q_token_local + 1) : 0; + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + + int64_t q_token = q_start; + if (is_active) { + q_token += static_cast(q_token_local); + } + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + const Tdata *q_ptr = nullptr; + Tdata *out_ptr = nullptr; + if (is_active) { + q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; + } + + float q_reg[DIMS_PER_THREAD]; + float acc[DIMS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + q_reg[i] = is_active ? 
static_cast(q_ptr[dim]) : 0.0f; + acc[i] = 0.0f; + } + +#if defined(__CUDA_ARCH__) + float2 q_reg2[DIMS_PER_THREAD / 2]; +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + q_reg2[j] = make_float2(q_reg[j * 2 + 0], q_reg[j * 2 + 1]); + } +#endif + + float m = -INFINITY; + float l = 0.0f; + + const int max_q_in_tile = min(m_start + BLOCK_M, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + + __shared__ int32_t s_phys[BLOCK_N]; + __shared__ int32_t s_off[BLOCK_N]; + __shared__ __align__(16) Tdata s_k[BLOCK_N * HEAD_SIZE]; + + const int pbs = static_cast(page_block_size); + + for (int k_base = 0; k_base < max_allowed_k_len; k_base += BLOCK_N) { + const int tile_n = min(BLOCK_N, max_allowed_k_len - k_base); + + for (int t = threadIdx.x; t < tile_n; t += blockDim.x) { + const int kpos = k_base + t; + const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs); + const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs); + const int32_t phys = static_cast(block_table[page]); + s_phys[t] = phys; + s_off[t] = off; + } + __syncthreads(); + + const int tile_elems = tile_n * HEAD_SIZE; + for (int idx = threadIdx.x; idx < tile_elems; idx += blockDim.x) { + const int t = idx / HEAD_SIZE; + const int dim = idx - t * HEAD_SIZE; + const int32_t phys = s_phys[t]; + const int32_t off = s_off[t]; + const Tdata *k_base_ptr = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + s_k[t * HEAD_SIZE + dim] = k_base_ptr[dim]; + } + __syncthreads(); + + for (int t = 0; t < tile_n; ++t) { + const int kpos = k_base + t; + if (kpos >= allowed_k_len) { + break; + } + const Tdata *k_ptr = s_k + t * HEAD_SIZE; + const int32_t phys = s_phys[t]; + const int32_t off = s_off[t]; + const Tdata *v_ptr = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(off) * v_row_stride + static_cast(kv_head_idx) * v_head_stride; + + float qk = 0.0f; +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __half22float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *k2 = reinterpret_cast(k_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 qf = q_reg2[j]; + const float2 kf = __bfloat1622float2(k2[j]); + qk += qf.x * kf.x + qf.y * kf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + qk += q_reg[i] * static_cast(k_ptr[dim]); + } + } + + qk = op::paged_attention::cuda::warpReduceSum(qk); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + float score = qk * scale_log2; + if (alibi_slope != 0.0f) { + score += (alibi_slope * static_cast(kpos - (allowed_k_len - 1))) * kLog2e; + } + const float m_new = fmaxf(m, score); + alpha = exp2f(m - m_new); + beta = exp2f(score - m_new); + l = l * alpha + beta; + m = m_new; + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); + +#if defined(__CUDA_ARCH__) + if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const half2 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; 
++j) { + const float2 vf = __half22float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else if constexpr (std::is_same_v) { + const int dim_base = lane * DIMS_PER_THREAD; + const __nv_bfloat162 *v2 = reinterpret_cast(v_ptr + dim_base); +#pragma unroll + for (int j = 0; j < DIMS_PER_THREAD / 2; ++j) { + const float2 vf = __bfloat1622float2(v2[j]); + acc[j * 2 + 0] = acc[j * 2 + 0] * alpha + beta * vf.x; + acc[j * 2 + 1] = acc[j * 2 + 1] * alpha + beta * vf.y; + } + } else +#endif + { +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float v_val = static_cast(v_ptr[dim]); + acc[i] = acc[i] * alpha + beta * v_val; + } + } + } + + __syncthreads(); + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + +#pragma unroll + for (int i = 0; i < DIMS_PER_THREAD; ++i) { + const int dim = lane * DIMS_PER_THREAD + i; + const float out_val = acc[i] * inv_l; + if (!is_active) { + continue; + } + if constexpr (std::is_same_v) { + out_ptr[dim] = __float2half_rn(out_val); + } else if constexpr (std::is_same_v) { + out_ptr[dim] = __float2bfloat16_rn(out_val); + } else { + out_ptr[dim] = static_cast(out_val); + } + } +} + +// TensorCore (WMMA) score kernel (v0.4 experimental): +// - Target shape: head_dim=128, page_block_size=256, fp16. +// - Compute QK^T with WMMA into shared memory, then reuse the existing online-softmax + V accumulation +// pattern (SIMT) per query row. +// +// Notes: +// - This is a correctness-first kernel. It doesn't yet use MMA for PV (P * V) update. +// - We keep the same grid mapping as other prefill kernels: blockIdx = (head, seq, m_block). +template +__device__ __forceinline__ void PagedAttentionPrefillMmaScoreUpdateRow( + int lane, + int k_base, + int allowed_k_len, + const float *scores_row, // [kBlockN] + const half *v_tile, // [kBlockN, kHeadDim] + float scale_log2, + float alibi_slope_log2, + float &m, + float &l, + float *acc) { // [kDimsPerThread] + + // Max over keys in this tile. + float local_max = -INFINITY; + for (int t = lane; t < kBlockN; t += kWarpSize) { + const int kpos = k_base + t; + if (kpos >= allowed_k_len) { + continue; + } + float score = scores_row[t] * scale_log2; + if (alibi_slope_log2 != 0.0f) { + score += alibi_slope_log2 * static_cast(kpos - (allowed_k_len - 1)); + } + local_max = fmaxf(local_max, score); + } + float tile_m = op::paged_attention::cuda::warpReduceMax(local_max); + tile_m = __shfl_sync(0xffffffff, tile_m, 0); + + // Sumexp + weighted V over keys in this tile, partitioned by lanes. 
+ float sumexp_lane = 0.0f; + float acc_tile[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f}; + const int dim_base = lane * kDimsPerThread; + if (tile_m != -INFINITY) { + for (int t = lane; t < kBlockN; t += kWarpSize) { + const int kpos = k_base + t; + if (kpos >= allowed_k_len) { + continue; + } + float score = scores_row[t] * scale_log2; + if (alibi_slope_log2 != 0.0f) { + score += alibi_slope_log2 * static_cast(kpos - (allowed_k_len - 1)); + } + const float w = exp2f(score - tile_m); + sumexp_lane += w; + + const half *v_ptr = v_tile + t * kHeadDim + dim_base; + const half2 *v2 = reinterpret_cast(v_ptr); +#pragma unroll + for (int j = 0; j < kDimsPerThread / 2; ++j) { + const float2 vf = __half22float2(v2[j]); + acc_tile[j * 2 + 0] += w * vf.x; + acc_tile[j * 2 + 1] += w * vf.y; + } + } + } + + float tile_sumexp = op::paged_attention::cuda::warpReduceSum(sumexp_lane); + tile_sumexp = __shfl_sync(0xffffffff, tile_sumexp, 0); + + float alpha = 1.0f; + float beta = 0.0f; + if (lane == 0) { + if (tile_sumexp > 0.0f && tile_m != -INFINITY) { + const float m_new = fmaxf(m, tile_m); + alpha = exp2f(m - m_new); + beta = exp2f(tile_m - m_new); + l = l * alpha + tile_sumexp * beta; + m = m_new; + } else { + alpha = 1.0f; + beta = 0.0f; + } + } + alpha = __shfl_sync(0xffffffff, alpha, 0); + beta = __shfl_sync(0xffffffff, beta, 0); +#pragma unroll + for (int i = 0; i < kDimsPerThread; ++i) { + acc[i] = acc[i] * alpha + beta * acc_tile[i]; + } +} + +template +__device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow( + int lane, + bool active, + int q_token_local, + int64_t q_start, + int head_idx, + half *out_, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + float l, + const float *acc) { // [kDimsPerThread] + if (!active) { + return; + } + + float inv_l = 0.0f; + if (lane == 0) { + inv_l = 1.0f / (l + 1e-6f); + } + inv_l = __shfl_sync(0xffffffff, inv_l, 0); + + const int64_t q_token = q_start + static_cast(q_token_local); + half *out_ptr = out_ + q_token * o_stride + static_cast(head_idx) * o_head_stride; +#pragma unroll + for (int i = 0; i < kDimsPerThread; ++i) { + const int dim = lane * kDimsPerThread + i; + out_ptr[dim] = __float2half_rn(acc[i] * inv_l); + } +} + +template +__device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel( + half *out_, + const half *q_, + const half *k_cache_, + const half *v_cache_, + const Tindex *block_tables_, + const int64_t *total_kv_lens_, + const int64_t *cu_seqlens_q_, + const float *alibi_slopes_, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + + (void)max_num_blocks_per_seq; + + constexpr int kWarpSize = 32; + constexpr int kWarps = 8; + constexpr int kHeadDim = 128; + // Extra padding in the K dimension to reduce shared-memory bank conflicts for ldmatrix / wmma loads. + // NOTE: FA2 uses a swizzled smem layout; padding is a smaller step that keeps our code simple. + constexpr int kHeadDimSmem = 136; // must be a multiple of 8 for wmma::load_matrix_sync + constexpr int kBlockM = 16; // 2 rows per warp + // Keep static shared memory <= 48KB for compatibility with build targets that cap SMEM at 0xC000. + // kBlockN=64 brings s_q+s_k+s_v+s_scores+s_phys/s_off down to ~41KB. 
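+ // With the 136-wide padding counted, the static total declared below is
+ //   s_q 16*136*2 B + s_k 64*136*2 B + s_v 64*136*2 B + s_scores 16*64*4 B + s_phys/s_off 2*64*4 B
+ //   = 43776 B (~42.8 KB), still well below the 48 KB cap.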
+ constexpr int kBlockN = 64; + constexpr int kDimsPerThread = kHeadDim / kWarpSize; + + static_assert(kHeadDim % kWarpSize == 0, "head_dim must be divisible by 32."); + + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_id = threadIdx.x / kWarpSize; + if (warp_id >= kWarps) { + return; + } + + const int head_idx = static_cast(blockIdx.x); + const int seq_idx = static_cast(blockIdx.y); + const int m_block = static_cast(blockIdx.z); + + const int64_t q_start = cu_seqlens_q_[seq_idx]; + const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const int q_len = static_cast(q_end - q_start); + if (q_len <= 0) { + return; + } + + const int m_start = m_block * kBlockM; + // Uniform early return for empty tail tiles (avoid deadlock with __syncthreads()). + if (m_start >= q_len) { + return; + } + + const int kv_len_total = static_cast(total_kv_lens_[seq_idx]); + const int history_len = kv_len_total - q_len; + + // Clamp max k length for this CTA based on the last active query row in the tile. + const int max_q_in_tile = min(m_start + kBlockM, q_len); + const int max_allowed_k_len = min(history_len + max_q_in_tile, kv_len_total); + + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / static_cast(num_kv_heads); + const int kv_head_idx = head_idx / num_queries_per_kv; + + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + constexpr float kLog2e = 1.4426950408889634f; + const float scale_log2 = scale * kLog2e; + const float alibi_slope_log2 = alibi_slope * kLog2e; + + const int pbs = static_cast(page_block_size); + + const Tindex *block_table = block_tables_ + static_cast(seq_idx) * static_cast(block_table_batch_stride); + + // Shared memory: + // - s_q: [kBlockM, kHeadDimSmem] (padded) + // - s_k/s_v: [kBlockN, kHeadDim] + // - s_scores: [kBlockM, kBlockN] raw dot products (no scale / alibi) + __shared__ __align__(16) half s_q[kBlockM * kHeadDimSmem]; + __shared__ int32_t s_phys[kBlockN]; + __shared__ int32_t s_off[kBlockN]; + __shared__ __align__(16) half s_k[kBlockN * kHeadDimSmem]; + __shared__ __align__(16) half s_v[kBlockN * kHeadDimSmem]; + __shared__ __align__(16) float s_scores[kBlockM * kBlockN]; + + // Load Q tile (pad inactive rows with 0). + for (int idx = threadIdx.x; idx < kBlockM * kHeadDim; idx += blockDim.x) { + const int r = idx / kHeadDim; + const int d = idx - r * kHeadDim; + const int q_token_local = m_start + r; + if (q_token_local < q_len) { + const int64_t q_token = q_start + static_cast(q_token_local); + const half *q_ptr = q_ + q_token * q_stride + static_cast(head_idx) * q_head_stride; + s_q[r * kHeadDimSmem + d] = q_ptr[d]; + } else { + s_q[r * kHeadDimSmem + d] = __float2half_rn(0.0f); + } + } + __syncthreads(); + + // Two rows per warp: row0=warp_id, row1=warp_id+kWarps. + const int row0 = warp_id; + const int row1 = warp_id + kWarps; + const bool active0 = (row0 < kBlockM) && ((m_start + row0) < q_len); + const bool active1 = (row1 < kBlockM) && ((m_start + row1) < q_len); + const int allowed0 = active0 ? min(history_len + (m_start + row0) + 1, kv_len_total) : 0; + const int allowed1 = active1 ? min(history_len + (m_start + row1) + 1, kv_len_total) : 0; + + float m0 = -INFINITY, l0 = 0.0f; + float m1 = -INFINITY, l1 = 0.0f; + float acc0[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f}; + float acc1[kDimsPerThread] = {0.0f, 0.0f, 0.0f, 0.0f}; + + // Iterate over K/V tiles. 
+ for (int k_base = 0; k_base < max_allowed_k_len; k_base += kBlockN) { + // Map logical k positions to physical blocks for this tile (pad the tail with -1). + for (int t = threadIdx.x; t < kBlockN; t += blockDim.x) { + const int kpos = k_base + t; + if (kpos < max_allowed_k_len) { + const int page = (pbs == 256) ? (kpos >> 8) : (kpos / pbs); + const int off = (pbs == 256) ? (kpos & 255) : (kpos - page * pbs); + s_phys[t] = static_cast(block_table[page]); + s_off[t] = off; + } else { + s_phys[t] = -1; + s_off[t] = 0; + } + } + __syncthreads(); + + // Load K/V tile into shared memory (pad with 0 for inactive tokens). + for (int idx = threadIdx.x; idx < kBlockN * kHeadDim; idx += blockDim.x) { + const int t = idx / kHeadDim; + const int d = idx - t * kHeadDim; + const int32_t phys = s_phys[t]; + if (phys >= 0) { + const int32_t off = s_off[t]; + const half *k_ptr = k_cache_ + static_cast(phys) * k_batch_stride + static_cast(off) * k_row_stride + static_cast(kv_head_idx) * k_head_stride; + const half *v_ptr = v_cache_ + static_cast(phys) * v_batch_stride + static_cast(off) * v_row_stride + static_cast(kv_head_idx) * v_head_stride; + s_k[t * kHeadDimSmem + d] = k_ptr[d]; + s_v[t * kHeadDimSmem + d] = v_ptr[d]; + } else { + s_k[t * kHeadDimSmem + d] = __float2half_rn(0.0f); + s_v[t * kHeadDimSmem + d] = __float2half_rn(0.0f); + } + } + __syncthreads(); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) + // WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows. + // For kBlockN=64, only the first 4 warps participate in WMMA score computation. + namespace wmma = nvcuda::wmma; + constexpr int kNSub = kBlockN / 16; + if (warp_id < kNSub) { + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + const int n_sub = warp_id; // [0, kNSub) + const half *q_tile = s_q; + const half *k_tile = s_k + (n_sub * 16) * kHeadDimSmem; + // K loop (head_dim=128). +#pragma unroll + for (int kk = 0; kk < (kHeadDim / 16); ++kk) { + wmma::load_matrix_sync(a_frag, q_tile + kk * 16, kHeadDimSmem); + wmma::load_matrix_sync(b_frag, k_tile + kk * 16, kHeadDimSmem); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + float *scores_tile = s_scores + n_sub * 16; + wmma::store_matrix_sync(scores_tile, c_frag, kBlockN, wmma::mem_row_major); + } +#else + // No WMMA support on this architecture: fall back to scalar dot in the existing kernels. + // (We keep scores as 0 so this kernel is effectively incorrect; host dispatch must avoid selecting it.) + if (threadIdx.x == 0) { + // Intentionally empty. + } +#endif + __syncthreads(); + + // Online softmax + V update per row handled by the same warp across tiles. + if (row0 < kBlockM) { + PagedAttentionPrefillMmaScoreUpdateRow( + lane, k_base, allowed0, s_scores + row0 * kBlockN, s_v, scale_log2, alibi_slope_log2, m0, l0, acc0); + } + if (row1 < kBlockM) { + PagedAttentionPrefillMmaScoreUpdateRow( + lane, k_base, allowed1, s_scores + row1 * kBlockN, s_v, scale_log2, alibi_slope_log2, m1, l1, acc1); + } + __syncthreads(); + } + + // Write outputs. 
+ if (row0 < kBlockM) { + PagedAttentionPrefillMmaScoreWriteRow( + lane, active0, m_start + row0, q_start, head_idx, out_, o_stride, o_head_stride, l0, acc0); + } + if (row1 < kBlockM) { + PagedAttentionPrefillMmaScoreWriteRow( + lane, active1, m_start + row1, q_start, head_idx, out_, o_stride, o_head_stride, l1, acc1); + } +} + +} // namespace op::paged_attention_prefill::cuda + +#endif diff --git a/src/infiniop/ops/paged_attention_prefill/info.h b/src/infiniop/ops/paged_attention_prefill/info.h index 6f1809f06..a40f4ceaf 100644 --- a/src/infiniop/ops/paged_attention_prefill/info.h +++ b/src/infiniop/ops/paged_attention_prefill/info.h @@ -3,6 +3,7 @@ #include "../../../utils.h" #include "../../tensor.h" +#include #include #include #include @@ -14,21 +15,30 @@ class PagedAttentionPrefillInfo { public: infiniDtype_t dtype; + infiniDtype_t index_dtype; float scale; size_t num_seqs; + size_t total_q_tokens; size_t num_heads; size_t num_kv_heads; size_t head_size; - size_t block_size; + size_t page_block_size; size_t max_num_blocks_per_seq; - size_t total_q_tokens; + size_t num_blocks; ptrdiff_t q_stride; ptrdiff_t q_head_stride; - ptrdiff_t kv_block_stride; - ptrdiff_t kv_head_stride; + ptrdiff_t k_batch_stride; + ptrdiff_t k_row_stride; + ptrdiff_t k_head_stride; + ptrdiff_t v_batch_stride; + ptrdiff_t v_row_stride; + ptrdiff_t v_head_stride; ptrdiff_t o_stride; + ptrdiff_t o_head_stride; + + ptrdiff_t block_table_batch_stride; static utils::Result create( infiniopTensorDescriptor_t out_desc, @@ -36,89 +46,161 @@ class PagedAttentionPrefillInfo { infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t block_tables_desc, - infiniopTensorDescriptor_t seq_lens_desc, - infiniopTensorDescriptor_t cum_seq_lens_q_desc, + infiniopTensorDescriptor_t total_kv_lens_desc, + infiniopTensorDescriptor_t cum_seqlens_q_desc, const std::optional &alibi_slopes_desc, float scale) { auto dtype = q_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); - + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) { + // q/out: [total_q, heads, head_dim] + if (q_desc->ndim() != 3 || out_desc->ndim() != 3) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + // FA2 paged KV layout: [num_blocks, page_block_size, kv_heads, head_dim] + if (k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (block_tables_desc->ndim() != 2 || total_kv_lens_desc->ndim() != 1 || cum_seqlens_q_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + // Index dtypes: allow I32/I64/U32 (v0.4 roadmap allows internal conversion to I32). 
+ const auto block_tables_dt = block_tables_desc->dtype(); + if (!((block_tables_dt == INFINI_DTYPE_I64) || (block_tables_dt == INFINI_DTYPE_I32) || (block_tables_dt == INFINI_DTYPE_U32))) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } + // Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now + // (matches current paged_attention_prefill signature). We will convert to int32 internally later. + if (total_kv_lens_desc->dtype() != INFINI_DTYPE_I64 || cum_seqlens_q_desc->dtype() != INFINI_DTYPE_I64) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) { + if (alibi_slopes_desc.value()->dtype() != INFINI_DTYPE_F32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (alibi_slopes_desc.value()->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + CHECK_OR_RETURN(alibi_slopes_desc.value()->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); } - auto k_shape = k_cache_desc->shape(); - auto v_shape = v_cache_desc->shape(); - auto block_tables_shape = block_tables_desc->shape(); - auto seq_lens_shape = seq_lens_desc->shape(); - auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape(); + const auto q_shape = q_desc->shape(); + const auto k_shape = k_cache_desc->shape(); + + const size_t total_q_tokens = q_shape[0]; + const size_t num_heads = q_shape[1]; + const size_t head_size = q_shape[2]; + + const size_t num_blocks = k_shape[0]; + const size_t page_block_size = k_shape[2]; + const size_t num_kv_heads = k_shape[1]; - if (k_shape.size() != 4 || v_shape.size() != 4) { + if (head_size != 64 && head_size != 128) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (num_heads % num_kv_heads != 0) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (block_tables_shape.size() != 2) { + // v_cache must match the inferred K layout. 
+ const auto v_shape = v_cache_desc->shape(); + if (v_shape[0] != num_blocks || v_shape[3] != head_size) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) { + if (v_shape[1] != num_kv_heads || v_shape[2] != page_block_size) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) { - return INFINI_STATUS_BAD_PARAM; + if (v_cache_desc->shape()[0] != k_shape[0] || v_cache_desc->shape()[3] != k_shape[3]) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; } - // Q shape: [total_tokens, heads, dim] - auto q_shape = q_desc->shape(); - if (q_shape.size() != 3) { + if (out_desc->shape()[0] != q_shape[0] || out_desc->shape()[1] != q_shape[1] || out_desc->shape()[2] != q_shape[2]) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - size_t total_q_tokens = q_shape[0]; - size_t num_heads = q_shape[1]; - size_t head_size = q_shape[2]; - if (head_size > 1024) { + const size_t num_seqs = total_kv_lens_desc->shape()[0]; + if (cum_seqlens_q_desc->shape()[0] != num_seqs + 1) { return INFINI_STATUS_BAD_PARAM; } - size_t num_seqs = seq_lens_shape[0]; - size_t num_kv_heads = k_shape[1]; - size_t block_size = k_shape[2]; - size_t max_num_blocks_per_seq = block_tables_shape[1]; - - ptrdiff_t q_stride = q_desc->stride(0); - ptrdiff_t q_head_stride = q_desc->stride(1); - ptrdiff_t kv_block_stride = k_cache_desc->stride(0); - ptrdiff_t kv_head_stride = k_cache_desc->stride(1); - ptrdiff_t o_stride = out_desc->stride(0); + const size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + + // Strides (in elements) + const ptrdiff_t q_stride = q_desc->stride(0); + const ptrdiff_t q_head_stride = q_desc->stride(1); + const ptrdiff_t o_stride = out_desc->stride(0); + const ptrdiff_t o_head_stride = out_desc->stride(1); + + const ptrdiff_t k_batch_stride = k_cache_desc->stride(0); + const ptrdiff_t k_row_stride = k_cache_desc->stride(2); + const ptrdiff_t k_head_stride = k_cache_desc->stride(1); + + const ptrdiff_t v_batch_stride = v_cache_desc->stride(0); + const ptrdiff_t v_row_stride = v_cache_desc->stride(2); + const ptrdiff_t v_head_stride = v_cache_desc->stride(1); + + const ptrdiff_t block_table_batch_stride = block_tables_desc->stride(0); + + if (const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_INFO")) { + static bool printed = false; + if (!printed && std::strcmp(dbg, "1") == 0) { + const auto bt_shape = block_tables_desc->shape(); + std::fprintf(stderr, + "[infiniop][flash_attention_prefill][info] k_shape=[%zu,%zu,%zu,%zu] k_strides=[%td,%td,%td,%td] (row_stride=%td head_stride=%td)\n", + static_cast(k_shape[0]), static_cast(k_shape[1]), + static_cast(k_shape[2]), static_cast(k_shape[3]), + k_cache_desc->stride(0), k_cache_desc->stride(1), k_cache_desc->stride(2), k_cache_desc->stride(3), + k_row_stride, k_head_stride); + std::fprintf(stderr, + "[infiniop][flash_attention_prefill][info] block_tables shape=[%zu,%zu] strides=[%td,%td]\n", + static_cast(bt_shape[0]), static_cast(bt_shape[1]), + block_tables_desc->stride(0), block_tables_desc->stride(1)); + printed = true; + } + } return utils::Result(PagedAttentionPrefillInfo{ dtype, + block_tables_dt, scale, num_seqs, + total_q_tokens, num_heads, num_kv_heads, head_size, - block_size, + page_block_size, max_num_blocks_per_seq, - total_q_tokens, + num_blocks, q_stride, q_head_stride, - kv_block_stride, - kv_head_stride, - o_stride}); + k_batch_stride, + k_row_stride, + k_head_stride, + v_batch_stride, + v_row_stride, + v_head_stride, + o_stride, + o_head_stride, + 
block_table_batch_stride, + }); } }; - } // namespace op::paged_attention_prefill #endif diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu index 90c4c94fc..e95268a84 100644 --- a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu @@ -1,56 +1,1237 @@ -#include -#include -#include -#include +#include + +#include +#include +#include +#include #include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../devices/nvidia/nvidia_kernel_common.cuh" -#include "../cuda/kernel.cuh" + +// #include "paged_attention_prefill_fa2.cuh" #include "paged_attention_prefill_nvidia.cuh" -template -infiniStatus_t launchPagedAttentionPrefill( - Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, - const int64_t *block_tables, - const int64_t *seq_lens, - const int64_t *cum_seq_lens_q, - const float *alibi_slopes, - const size_t num_heads, - const size_t num_seqs, - const size_t num_kv_heads, - const float scale, - const size_t max_num_blocks_per_seq, - const size_t block_size, - const size_t total_q_tokens, - const size_t head_size, - const ptrdiff_t kv_block_stride, - const ptrdiff_t kv_head_stride, - const ptrdiff_t q_stride, - const ptrdiff_t q_head_stride, +#include "../cuda/kernel_v2.cuh" + +namespace op::paged_attention_prefill::nvidia { + +namespace { +constexpr size_t ceilDiv(size_t a, size_t b) { + return (a + b - 1) / b; +} + +inline const char *default_prefill_kernel(const PagedAttentionPrefillInfo &info) { + // Heuristic auto-dispatch (v0.4): + // - Prefer the pipelined + tile-wise softmax kernel on FA2-compatible block_size=256. + // - Keep a conservative fallback for other shapes / older GPUs (cp.async is a no-op below SM80). + // + // Users can always override via INFINIOP_FLASH_PREFILL_KERNEL. + if (info.page_block_size == 256 && (info.dtype == INFINI_DTYPE_F16 || info.dtype == INFINI_DTYPE_BF16)) { + if (info.head_size == 128) { + return "warpcta8pipe"; + } + // For head_size=64 we keep the previous default until we have broader perf coverage. + } + return "warpcta8"; +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel). 
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride, + q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64Warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // Legacy per-seq launch (kept only as a wrapper; current "warp" impl uses a global-token kernel). + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, block_table_batch_stride, + q_stride, q_head_stride, k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 4 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 4 warps per CTA, one warp per query token. 
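+    // Matching launch geometry (see launch_prefill below): block = 4 * 32 = 128 threads and
+    // grid = (num_heads, num_seqs, ceilDiv(total_q_tokens, 4)), so each warp handles one query token.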
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8N128( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, tile_n=128 for fewer K stages. + // Note: we keep K in shared memory but load V from global to stay within the per-block shared limit. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelKOnly( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token. 
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, with cp.async pipelining. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Mma( + half *out, + const half *q, + const half *k_cache, + const half *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCta8MmaHd128Kernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 8 warps per CTA, one warp per query token, with cp.async pipelining. 
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelined( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + size_t total_q_tokens, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride) { + // Encode (split_idx, m_block) into blockIdx.z to allow a single kernel launch: + // blockIdx.z in [0, num_splits * num_m_blocks). + const int num_m_blocks = static_cast((total_q_tokens + 8 - 1) / 8); + const int bz = static_cast(blockIdx.z); + const int split_idx = bz / num_m_blocks; + const int m_block = bz - split_idx * num_m_blocks; + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( + partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + size_t total_q_tokens, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride) { + const int num_m_blocks = static_cast((total_q_tokens + 8 - 1) / 8); + const int bz = static_cast(blockIdx.z); + const int split_idx = bz / num_m_blocks; + const int m_block = bz - split_idx * num_m_blocks; + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( + partial_acc, partial_m, partial_l, split_idx, num_splits, m_block, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + size_t 
total_q_tokens, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64SplitKvCombine( + Tdata *out, + const float *partial_acc, + const float *partial_m, + const float *partial_l, + int num_splits, + size_t total_q_tokens, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillSplitKvCombineWarpKernel( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 16 warps per CTA, one warp per query token. + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride) { + // 16 warps per CTA, one warp per query token. 
+ op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpCtaKernel( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); +} + +template +infiniStatus_t launch_prefill_ref( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, cudaStream_t stream) { - if (total_q_tokens == 0 || num_heads == 0) { + const dim3 grid(static_cast(total_q_tokens), static_cast(num_heads), 1); + const dim3 block(static_cast(head_size), 1, 1); + + if (head_size == 64) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, num_seqs); + return INFINI_STATUS_SUCCESS; + } + + if (head_size == 128) { + op::paged_attention_prefill::cuda::PagedAttentionPrefillReferenceKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, num_seqs); + return INFINI_STATUS_SUCCESS; + } + + return INFINI_STATUS_BAD_TENSOR_SHAPE; +} + +template +infiniStatus_t launch_prefill_warp( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + const dim3 block(32, 1, 1); + // Global-token launch: + // - dramatically reduces grid size vs the legacy (num_seqs * total_q_tokens) launch + // - matches PagedAttention varlen (cu_seqlens) mental model better + const dim3 grid(static_cast(num_heads), + static_cast(total_q_tokens), + 1); + + switch (head_size) { + case 64: + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel + <<>>( + out, q, k_cache, v_cache, 
block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq, + page_block_size, block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + op::paged_attention_prefill::cuda::PagedAttentionPrefillWarpGlobalKernel + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, scale, max_num_blocks_per_seq, + page_block_size, block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: return INFINI_STATUS_BAD_TENSOR_SHAPE; } +} + +template +infiniStatus_t launch_prefill( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kWarps = 4; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t 
stream) { + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); - dim3 grid(total_q_tokens, num_heads); - dim3 block(head_size); + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta8 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} - op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel +template +infiniStatus_t launch_prefill_warpcta8pipe( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8Pipe + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta8Pipe + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8mma( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, 
+ float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + // Current WMMA kernel only supports fp16 + head_dim=128. + if constexpr (!std::is_same_v) { + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + + if (head_size != 128) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Guardrail: the current WMMA-score kernel is correctness-first and can be extremely slow on long prompts. + // Allow power users to force it via INFINIOP_FLASH_PREFILL_MMA_FORCE=1. + const char *force_env = std::getenv("INFINIOP_FLASH_PREFILL_MMA_FORCE"); + const bool force_mma = (force_env != nullptr) && (std::strcmp(force_env, "1") == 0); + const size_t seqlen_k_est = max_num_blocks_per_seq * page_block_size; + if (!force_mma && seqlen_k_est > 4096) { + static bool warned = false; + if (!warned) { + std::fprintf(stderr, + "[infiniop][paged_attention_prefill] warpcta8mma is experimental and very slow for long seqlen_k (est=%zu). " + "Falling back to warpcta8pipe. Set INFINIOP_FLASH_PREFILL_MMA_FORCE=1 to override.\n", + seqlen_k_est); + warned = true; + } + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + + // WMMA requires SM70+. If not supported (or if we can't query), fall back to the pipelined SIMT kernel. 
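+    // A lighter-weight alternative (sketch only, not used here) would query just the major
+    // compute capability rather than filling a full cudaDeviceProp:
+    //   int major = 0;
+    //   cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
+    //   if (major < 7) { /* fall back to warpcta8pipe as below */ }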
+ int device = 0; + cudaDeviceProp prop{}; + if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&prop, device) == cudaSuccess) { + if (prop.major < 7) { + return launch_prefill_warpcta8pipe( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_heads, num_seqs, num_kv_heads, total_q_tokens, head_size, scale, + max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride, stream); + } + } + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(16)))); + + PagedAttentionPrefillHd128WarpCta8Mma <<>>( - out, q, k_cache, v_cache, - block_tables, seq_lens, cum_seq_lens_q, alibi_slopes, - num_heads, num_kv_heads, scale, - max_num_blocks_per_seq, block_size, - kv_block_stride, kv_head_stride, + static_cast(out), + static_cast(q), + static_cast(k_cache), + static_cast(v_cache), + block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, q_stride, q_head_stride, - head_size, - num_seqs); + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; +} +template +infiniStatus_t launch_prefill_warpcta8pipe_splitkv( + float *partial_acc, + float *partial_m, + float *partial_l, + int num_splits, + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kMaxSplits = 8; + if (num_splits < 1) { + num_splits = 1; + } + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const size_t num_m_blocks = ceilDiv(total_q_tokens, static_cast(kWarps)); + // Single kernel launch with split_idx encoded in grid.z: + // blockIdx.z in [0, num_splits * num_m_blocks). 
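+    // Worked example of the encoding: with total_q_tokens = 20 and kWarps = 8,
+    // num_m_blocks = ceilDiv(20, 8) = 3; with num_splits = 4 the launch below uses grid.z = 12,
+    // and the wrapper decodes e.g. bz = 7 as split_idx = 7 / 3 = 2, m_block = 7 - 2 * 3 = 1.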
+ const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(num_m_blocks * static_cast(num_splits))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta8PipeSplitKv + <<>>( + partial_acc, partial_m, partial_l, num_splits, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); + break; + case 128: + PagedAttentionPrefillHd128WarpCta8PipeSplitKv + <<>>( + partial_acc, partial_m, partial_l, num_splits, total_q_tokens, + q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride); + break; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Combine: one warp per (token, head). + const dim3 block2(32); + const dim3 grid2(static_cast(num_heads), static_cast(total_q_tokens), 1); + switch (head_size) { + case 64: + PagedAttentionPrefillHd64SplitKvCombine + <<>>( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128SplitKvCombine + <<>>( + out, partial_acc, partial_m, partial_l, num_splits, total_q_tokens, o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} + +template +infiniStatus_t launch_prefill_warpcta8n128( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kWarps = 8; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + // Only meaningful for head_dim=128. 
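+    // Rough shared-memory budget behind the K-only design (see the wrapper comment above):
+    // a 128 x 128 fp16/bf16 K tile is 128 * 128 * 2 B = 32 KiB per stage, so also staging V
+    // tiles of the same size would exceed the default 48 KiB static shared-memory limit per block.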
+ if (head_size != 128) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + PagedAttentionPrefillHd128WarpCta8N128 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); return INFINI_STATUS_SUCCESS; } -namespace op::paged_attention_prefill::nvidia { +template +infiniStatus_t launch_prefill_warpcta16( + Tdata *out, + const Tdata *q, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const int64_t *total_kv_lens, + const int64_t *cu_seqlens_q, + const float *alibi_slopes, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t total_q_tokens, + size_t head_size, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + ptrdiff_t block_table_batch_stride, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_row_stride, + ptrdiff_t k_head_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_row_stride, + ptrdiff_t v_head_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + cudaStream_t stream) { + + constexpr int kWarps = 16; + constexpr int kThreads = kWarps * 32; + const dim3 block(kThreads); + const dim3 grid(static_cast(num_heads), + static_cast(num_seqs), + static_cast(ceilDiv(total_q_tokens, static_cast(kWarps)))); + + switch (head_size) { + case 64: + PagedAttentionPrefillHd64WarpCta16 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + case 128: + PagedAttentionPrefillHd128WarpCta16 + <<>>( + out, q, k_cache, v_cache, block_tables, total_kv_lens, cu_seqlens_q, alibi_slopes, + num_kv_heads, scale, max_num_blocks_per_seq, page_block_size, + block_table_batch_stride, + q_stride, q_head_stride, + k_batch_stride, k_row_stride, k_head_stride, + v_batch_stride, v_row_stride, v_head_stride, + o_stride, o_head_stride); + return INFINI_STATUS_SUCCESS; + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } +} +} // namespace struct Descriptor::Opaque { std::shared_ptr internal; @@ -68,22 +1249,87 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t block_tables_desc, - infiniopTensorDescriptor_t seq_lens_desc, - infiniopTensorDescriptor_t cum_seq_lens_q_desc, + infiniopTensorDescriptor_t total_kv_lens_desc, + infiniopTensorDescriptor_t cum_seqlens_q_desc, const std::optional &alibi_slopes_desc, float scale) { auto info = PagedAttentionPrefillInfo::create( out_desc, q_desc, k_cache_desc, v_cache_desc, - block_tables_desc, seq_lens_desc, - cum_seq_lens_q_desc, + block_tables_desc, total_kv_lens_desc, cum_seqlens_q_desc, alibi_slopes_desc, scale); - CHECK_RESULT(info); + // Optional split-kv prefill requires workspace for partial (m, l, acc). + // IMPORTANT: Unlike decode, prefill's total_q_tokens can be very large, so we must NOT reserve + // a huge workspace unless the user explicitly enables split-kv. 
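+    // Example of the cost (illustrative numbers, using the sizing formula below): with
+    // total_q_tokens = 8192, num_heads = 32, head_size = 128 and num_splits = 4, the partial
+    // (acc, m, l) buffers take 4 * (8192 * 32) * (128 + 2) * 4 B = 520 MiB.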
+ bool use_splitkv = false; + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { + use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 1; + if (use_splitkv) { + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) { + const int v = std::atoi(env); + if (v > 0) { + num_splits = v; + } + } else { + num_splits = 4; + } + constexpr int kMaxSplits = 8; + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + } + const size_t n = info->total_q_tokens * info->num_heads; + const size_t splitkv_workspace_bytes = use_splitkv ? (static_cast(num_splits) * n * (info->head_size + 2) * sizeof(float)) : 0; + + // FA2-style kernel needs a workspace scratch for: + // - converting block_tables + total_kv_lens to int32 + // - storing softmax LSE (only required to satisfy the upstream kernel contract) + // bool want_fa2 = false; + // if (const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL")) { + // want_fa2 = (std::strcmp(k_env, "fa2") == 0); + // } + // bool fa2_materialize_kv = false; + // if (const char *env = std::getenv("INFINIOP_FA2_MATERIALIZE_PAGED_KV")) { + // fa2_materialize_kv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + // } + // size_t fa2_workspace_bytes = 0; + // // FA2 prefill supports both fp16 and bf16 inputs (head_dim=128, block_size=256). + // // Workspace sizing is identical since both are 16-bit element types. + // if (want_fa2 && (info->dtype == INFINI_DTYPE_F16 || info->dtype == INFINI_DTYPE_BF16) && + // info->head_size == 128 && info->page_block_size == 256) { + // const size_t bt_bytes = info->num_seqs * info->max_num_blocks_per_seq * sizeof(int); + // const size_t len_bytes = info->num_seqs * sizeof(int); + // const size_t cuq_bytes = (info->num_seqs + 1) * sizeof(int); + // const size_t cuk_bytes = (info->num_seqs + 1) * sizeof(int); + // const size_t lse_bytes = info->num_heads * info->total_q_tokens * sizeof(float); + // // Add a small alignment slack since we sub-allocate with alignment. + // fa2_workspace_bytes = bt_bytes + len_bytes + cuq_bytes + cuk_bytes + lse_bytes + 64; + + // // Optional: materialize paged KV into the FA2-friendly physical layout + // // [num_blocks, page_block_size, kv_heads, head_dim] (token-major) to avoid + // // extremely strided loads when the framework stores KV as + // // [num_blocks, kv_heads, page_block_size, head_dim] (head-major). + // if (fa2_materialize_kv) { + // // Materialize per-seq contiguous KV in *sequence order*: + // // [num_seqs, max_num_blocks_per_seq * page_block_size, kv_heads, head_dim]. 
+ // const size_t kv_elems = + // info->num_seqs * info->max_num_blocks_per_seq * info->page_block_size * info->num_kv_heads * info->head_size; + // const size_t kv_bytes = kv_elems * sizeof(uint16_t); // 16-bit (fp16/bf16) + // // K + V + alignment slack + // fa2_workspace_bytes += 2 * kv_bytes + 64; + // } + // } + + const size_t workspace_bytes = splitkv_workspace_bytes; + // const size_t workspace_bytes = splitkv_workspace_bytes + fa2_workspace_bytes; + *desc_ptr = new Descriptor( new Opaque{reinterpret_cast(handle)->internal()}, - info.take(), 0, handle->device, handle->device_id); + info.take(), workspace_bytes, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } @@ -92,35 +1338,379 @@ infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache, const void *block_tables, - const void *seq_lens, - const void *cum_seq_lens_q, + const void *total_kv_lens, + const void *cum_seqlens_q, const void *alibi_slopes, void *stream_) const { + auto stream = static_cast(stream_); + + const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast(alibi_slopes); + const auto *total_kv_lens_i64 = static_cast(total_kv_lens); + const auto *cu_seqlens_q_i64 = static_cast(cum_seqlens_q); + + bool use_splitkv = false; + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { + use_splitkv = (std::strcmp(env, "1") == 0) || (std::strcmp(env, "true") == 0); + } + int num_splits = 1; + if (use_splitkv) { + if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS")) { + const int v = std::atoi(env); + if (v > 0) { + num_splits = v; + } + } else { + // Conservative default; users can override. + num_splits = 4; + } + constexpr int kMaxSplits = 8; + if (num_splits > kMaxSplits) { + num_splits = kMaxSplits; + } + const size_t n = _info.total_q_tokens * _info.num_heads; + const size_t required = static_cast(num_splits) * n * (_info.head_size + 2) * sizeof(float); + if (workspace_size < required) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + } + + if (use_splitkv) { + const size_t n = _info.total_q_tokens * _info.num_heads; + float *partial_acc = static_cast(workspace); + float *partial_m = partial_acc + static_cast(num_splits) * n * _info.head_size; + float *partial_l = partial_m + static_cast(num_splits) * n; + + // Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64. 
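+        // Workspace layout used above (one contiguous float buffer, sized in Descriptor::create):
+        //   [ partial_acc : num_splits * n * head_size | partial_m : num_splits * n | partial_l : num_splits * n ]
+        // which adds up to the num_splits * n * (head_size + 2) * sizeof(float) reservation.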
+#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \ + return launch_prefill_warpcta8pipe_splitkv( \ + partial_acc, partial_m, partial_l, num_splits, \ + static_cast(out), \ + static_cast(q), \ + static_cast(k_cache), \ + static_cast(v_cache), \ + static_cast(BT_PTR), \ + total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream) + + if (_info.dtype == INFINI_DTYPE_F16) { + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_SPLITKV(int64_t, half, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_SPLITKV(int32_t, half, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_SPLITKV(uint32_t, half, block_tables); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (_info.dtype == INFINI_DTYPE_BF16) { + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_SPLITKV(int64_t, __nv_bfloat16, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_SPLITKV(int32_t, __nv_bfloat16, block_tables); + } + if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_SPLITKV(uint32_t, __nv_bfloat16, block_tables); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; + +#undef DISPATCH_SPLITKV + } + +// Default to the fastest validated kernel for supported shapes. +// "ref" is still available for debugging/correctness bisecting. +#define DISPATCH_KERNEL(Tindex, Tdata, Tcompute) \ + do { \ + const char *k_env = std::getenv("INFINIOP_FLASH_PREFILL_KERNEL"); \ + const char *k = (k_env == nullptr) ? default_prefill_kernel(_info) : k_env; \ + if (k_env != nullptr) { \ + const bool known = (std::strcmp(k, "warp") == 0) || (std::strcmp(k, "warpcta") == 0) || (std::strcmp(k, "warpcta8") == 0) || (std::strcmp(k, "warpcta8pipe") == 0) || (std::strcmp(k, "warpcta8mma") == 0) || (std::strcmp(k, "warpcta8n128") == 0) || (std::strcmp(k, "warpcta16") == 0) || (std::strcmp(k, "ref") == 0); \ + if (!known) { \ + const char *fallback = default_prefill_kernel(_info); \ + std::fprintf(stderr, \ + "[infiniop][paged_attention_prefill] WARNING: unknown kernel '%s', falling back to '%s'\n", \ + k, fallback); \ + k = fallback; \ + } \ + } \ + const char *dbg = std::getenv("INFINIOP_DEBUG_PREFILL_DISPATCH"); \ + static bool printed_dispatch = false; \ + if (!printed_dispatch && dbg != nullptr && std::strcmp(dbg, "1") == 0) { \ + std::fprintf(stderr, \ + "[infiniop][paged_attention_prefill] kernel=%s (override=%s head_size=%zu block=%zu dtype=%zu)\n", \ + k, \ + (k_env == nullptr ? 
"auto" : "env"), \ + static_cast(_info.head_size), \ + static_cast(_info.page_block_size), \ + static_cast(_info.dtype)); \ + printed_dispatch = true; \ + } \ + if (std::strcmp(k, "warp") == 0) { \ + return launch_prefill_warp( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta") == 0) { \ + return launch_prefill( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta8") == 0) { \ + return launch_prefill_warpcta8( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta8pipe") == 0) { \ + return launch_prefill_warpcta8pipe( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if constexpr (std::is_same_v) { \ + if (std::strcmp(k, "warpcta8mma") == 0) { \ + return launch_prefill_warpcta8mma( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, 
_info.o_head_stride, stream); \ + } \ + } \ + if (std::strcmp(k, "warpcta8n128") == 0) { \ + return launch_prefill_warpcta8n128( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "warpcta16") == 0) { \ + return launch_prefill_warpcta16( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + if (std::strcmp(k, "ref") == 0) { \ + return launch_prefill_ref( \ + static_cast(out), static_cast(q), \ + static_cast(k_cache), static_cast(v_cache), \ + static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ + _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ + _info.block_table_batch_stride, \ + _info.q_stride, _info.q_head_stride, \ + _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ + _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ + _info.o_stride, _info.o_head_stride, stream); \ + } \ + return INFINI_STATUS_BAD_PARAM; \ + } while (false) - cudaStream_t stream = (cudaStream_t)stream_; - -#define LAUNCH_KERNEL(Tdata, Tcompute) \ - launchPagedAttentionPrefill( \ - (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \ - (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \ - (const float *)alibi_slopes, \ - _info.num_heads, _info.num_seqs, _info.num_kv_heads, \ - _info.scale, _info.max_num_blocks_per_seq, \ - _info.block_size, _info.total_q_tokens, \ - _info.head_size, \ - _info.kv_block_stride, _info.kv_head_stride, \ - _info.q_stride, _info.q_head_stride, \ - stream) - - if (_info.dtype == INFINI_DTYPE_F16) { - return LAUNCH_KERNEL(half, float); - } else if (_info.dtype == INFINI_DTYPE_BF16) { - return LAUNCH_KERNEL(__nv_bfloat16, float); - } else if (_info.dtype == INFINI_DTYPE_F32) { - return LAUNCH_KERNEL(float, float); +#define DISPATCH_INDEX(Tindex) \ + do { \ + if (_info.dtype == INFINI_DTYPE_F16) { \ + DISPATCH_KERNEL(Tindex, half, float); \ + } \ + if (_info.dtype == INFINI_DTYPE_BF16) { \ + DISPATCH_KERNEL(Tindex, __nv_bfloat16, float); \ + } \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } while (false) + + if (_info.index_dtype == INFINI_DTYPE_I64) { + DISPATCH_INDEX(int64_t); + } else if (_info.index_dtype == INFINI_DTYPE_I32) { + DISPATCH_INDEX(int32_t); + } else if (_info.index_dtype == INFINI_DTYPE_U32) { + DISPATCH_INDEX(uint32_t); } return 
INFINI_STATUS_BAD_TENSOR_DTYPE; } } // namespace op::paged_attention_prefill::nvidia + +// #include +// #include +// #include +// #include + +// #include "../../../devices/nvidia/nvidia_common.cuh" +// #include "../../../devices/nvidia/nvidia_kernel_common.cuh" +// #include "../cuda/kernel.cuh" +// #include "paged_attention_prefill_nvidia.cuh" + +// template +// infiniStatus_t launchPagedAttentionPrefill( +// Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, +// const int64_t *block_tables, +// const int64_t *seq_lens, +// const int64_t *cum_seq_lens_q, +// const float *alibi_slopes, +// const size_t num_heads, +// const size_t num_seqs, +// const size_t num_kv_heads, +// const float scale, +// const size_t max_num_blocks_per_seq, +// const size_t block_size, +// const size_t total_q_tokens, +// const size_t head_size, +// const ptrdiff_t kv_block_stride, +// const ptrdiff_t kv_head_stride, +// const ptrdiff_t q_stride, +// const ptrdiff_t q_head_stride, +// cudaStream_t stream) { + +// if (total_q_tokens == 0 || num_heads == 0) { +// return INFINI_STATUS_BAD_TENSOR_SHAPE; +// } + +// dim3 grid(total_q_tokens, num_heads); +// dim3 block(head_size); + +// op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel +// <<>>( +// out, q, k_cache, v_cache, +// block_tables, seq_lens, cum_seq_lens_q, alibi_slopes, +// num_heads, num_kv_heads, scale, +// max_num_blocks_per_seq, block_size, +// kv_block_stride, kv_head_stride, +// q_stride, q_head_stride, +// head_size, +// num_seqs); + +// return INFINI_STATUS_SUCCESS; +// } + +// namespace op::paged_attention_prefill::nvidia { + +// struct Descriptor::Opaque { +// std::shared_ptr internal; +// }; + +// Descriptor::~Descriptor() { +// delete _opaque; +// } + +// infiniStatus_t Descriptor::create( +// infiniopHandle_t handle, +// Descriptor **desc_ptr, +// infiniopTensorDescriptor_t out_desc, +// infiniopTensorDescriptor_t q_desc, +// infiniopTensorDescriptor_t k_cache_desc, +// infiniopTensorDescriptor_t v_cache_desc, +// infiniopTensorDescriptor_t block_tables_desc, +// infiniopTensorDescriptor_t seq_lens_desc, +// infiniopTensorDescriptor_t cum_seq_lens_q_desc, +// const std::optional &alibi_slopes_desc, +// float scale) { + +// auto info = PagedAttentionPrefillInfo::create( +// out_desc, q_desc, k_cache_desc, v_cache_desc, +// block_tables_desc, seq_lens_desc, +// cum_seq_lens_q_desc, +// alibi_slopes_desc, scale); + +// CHECK_RESULT(info); + +// *desc_ptr = new Descriptor( +// new Opaque{reinterpret_cast(handle)->internal()}, +// info.take(), 0, handle->device, handle->device_id); + +// return INFINI_STATUS_SUCCESS; +// } + +// infiniStatus_t Descriptor::calculate( +// void *workspace, size_t workspace_size, +// void *out, const void *q, const void *k_cache, const void *v_cache, +// const void *block_tables, +// const void *seq_lens, +// const void *cum_seq_lens_q, +// const void *alibi_slopes, +// void *stream_) const { + +// cudaStream_t stream = (cudaStream_t)stream_; + +// #define LAUNCH_KERNEL(Tdata, Tcompute) \ +// launchPagedAttentionPrefill( \ +// (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \ +// (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \ +// (const float *)alibi_slopes, \ +// _info.num_heads, _info.num_seqs, _info.num_kv_heads, \ +// _info.scale, _info.max_num_blocks_per_seq, \ +// _info.block_size, _info.total_q_tokens, \ +// _info.head_size, \ +// _info.kv_block_stride, _info.kv_head_stride, \ +// _info.q_stride, 
_info.q_head_stride, \ +// stream) + +// if (_info.dtype == INFINI_DTYPE_F16) { +// return LAUNCH_KERNEL(half, float); +// } else if (_info.dtype == INFINI_DTYPE_BF16) { +// return LAUNCH_KERNEL(__nv_bfloat16, float); +// } else if (_info.dtype == INFINI_DTYPE_F32) { +// return LAUNCH_KERNEL(float, float); +// } + +// return INFINI_STATUS_BAD_TENSOR_DTYPE; +// } + +// } // namespace op::paged_attention_prefill::nvidia diff --git a/test/infiniop/paged_attention.py b/test/infiniop/paged_attention.py index 882e9cfee..c1f10f9b7 100644 --- a/test/infiniop/paged_attention.py +++ b/test/infiniop/paged_attention.py @@ -100,13 +100,12 @@ def ref_single_query_cached_kv_attention( ] # Data types for testing -_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32] +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] # Tolerance map for different data types _TOLERANCE_MAP = { InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, - InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, } # Global flags for controlling test behavior diff --git a/test/infiniop/paged_attention_prefill.py b/test/infiniop/paged_attention_prefill.py index 4bbe762a8..65d843fae 100644 --- a/test/infiniop/paged_attention_prefill.py +++ b/test/infiniop/paged_attention_prefill.py @@ -32,10 +32,9 @@ (16, 128, 128, 128, 8, 16, 4), ] -_TENSOR_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] _TOLERANCE_MAP = { - InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 2e-2, "rtol": 2e-2}, }
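For local benchmarking of the dispatch paths above, the kernel choice can be pinned from a C++ test harness before the first prefill call. This is only a sketch: the environment-variable names and accepted values come from the dispatch code in this patch, while the helper itself and the use of POSIX setenv are assumptions.

    #include <cstdlib>

    // Hypothetical helper: pin the prefill kernel and print the dispatch decision once.
    static void forcePrefillKernelForBenchmark() {
        setenv("INFINIOP_FLASH_PREFILL_KERNEL", "warpcta8pipe", /*overwrite=*/1);
        setenv("INFINIOP_DEBUG_PREFILL_DISPATCH", "1", 1);
        // Optional: split-kv prefill; when enabled it takes precedence over the kernel
        // override and requires the workspace sized in Descriptor::create.
        setenv("INFINIOP_FLASH_PREFILL_SPLITKV", "1", 1);
        setenv("INFINIOP_FLASH_PREFILL_NUM_SPLITS", "4", 1); // clamped to 8
    }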