From c18b742641bd584be7a79e537934fb4b2f18f130 Mon Sep 17 00:00:00 2001
From: MoringLotus
Date: Sun, 21 Dec 2025 15:17:57 +0000
Subject: [PATCH 1/4] softmax

---
 .../ops/softmax/nvidia/softmax_nvidia.cu | 26 +++++++++++++++++++
 src/infiniop/ops/softmax/operator.cc     |  3 ++-
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/src/infiniop/ops/softmax/nvidia/softmax_nvidia.cu b/src/infiniop/ops/softmax/nvidia/softmax_nvidia.cu
index d87fe8167..892096246 100644
--- a/src/infiniop/ops/softmax/nvidia/softmax_nvidia.cu
+++ b/src/infiniop/ops/softmax/nvidia/softmax_nvidia.cu
@@ -107,6 +107,32 @@ infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
                 <<<grid_dim, block_dim, 0, stream>>>((float *)y, (const float *)x,
                                                      othersize, dimsize, stride);
         }
+    } else if (dtype == INFINI_DTYPE_BF16) {
+        if (dimsize > 1024) {
+            blockSoftmax<__nv_bfloat16, BLOCK_SIZE>
+                <<<num_blocks, BLOCK_SIZE, 0, stream>>>((__nv_bfloat16 *)y, (const __nv_bfloat16 *)x,
+                                                        dimsize, stride);
+        } else if (dimsize > 31) {
+            constexpr unsigned int BLOCK_SIZE_x = 32;
+            constexpr unsigned int BLOCK_SIZE_y = 32;
+            constexpr int numPerThreadx = 32;
+            int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
+            dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
+            dim3 grid_dim(num_block_x, 1, 1);
+            warpSoftmax<__nv_bfloat16, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
+                <<<grid_dim, block_dim, 0, stream>>>((__nv_bfloat16 *)y, (const __nv_bfloat16 *)x,
+                                                     othersize, dimsize, stride);
+        } else {
+            constexpr unsigned int BLOCK_SIZE_x = 16;
+            constexpr unsigned int BLOCK_SIZE_y = 32;
+            constexpr int numPerThreadx = 2;
+            int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
+            dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
+            dim3 grid_dim(num_block_x, 1, 1);
+            warpSoftmax<__nv_bfloat16, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
+                <<<grid_dim, block_dim, 0, stream>>>((__nv_bfloat16 *)y, (const __nv_bfloat16 *)x,
+                                                     othersize, dimsize, stride);
+        }
     } else {
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
diff --git a/src/infiniop/ops/softmax/operator.cc b/src/infiniop/ops/softmax/operator.cc
index 0a922888d..b22d874a5 100644
--- a/src/infiniop/ops/softmax/operator.cc
+++ b/src/infiniop/ops/softmax/operator.cc
@@ -20,7 +20,7 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
         reinterpret_cast<op::softmax::NAMESPACE::Descriptor **>(desc_ptr),    \
         y_desc,                                                               \
         x_desc, axis);
-
+    std::cout << "handle device " << handle->device << std::endl;
     switch (handle->device) {
 #ifdef ENABLE_NVIDIA_API
     CREATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
@@ -35,6 +35,7 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
     CREATE(INFINI_DEVICE_HYGON, nvidia);
 #endif
     }
+    std::cout << "Error In CREATE NVIDIA API" << std::endl;
     return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
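All three BF16 branches above compute the same row-wise softmax; they differ only in how rows are mapped onto the GPU, and the lane arithmetic explains the thresholds: the warp path gives each row 32 lanes x numPerThreadx = 32 elements (rows up to 1024 wide), the small-row path uses 16 lanes x 2 = 32 (rows up to 31 wide), and anything wider than 1024 falls back to one block per row. A minimal NumPy sketch of the math itself (illustrative only; the max subtraction is the standard overflow guard for softmax, not something visible in this hunk):

    import numpy as np

    def softmax_rows(x: np.ndarray) -> np.ndarray:
        # Subtract each row's max before exponentiating so exp() cannot
        # overflow; the result is mathematically identical.
        m = x.max(axis=-1, keepdims=True)
        e = np.exp(x.astype(np.float32) - m)
        return e / e.sum(axis=-1, keepdims=True)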
From b587d4d91c54425eadd5c184dc49818773265704 Mon Sep 17 00:00:00 2001
From: MoringLotus
Date: Sun, 21 Dec 2025 16:01:13 +0000
Subject: [PATCH 2/4] finish topk router

---
 .../ops/topkrouter/cpu/topkrouter_cpu.cc    | 22 +++--
 src/infiniop/ops/topkrouter/cuda/kernel.cuh | 32 +++++---
 .../topkrouter/nvidia/topkrouter_nvidia.cu  |  8 +-
 test/infiniop/topkrouter.py                 | 80 +++++++++++++++++--
 4 files changed, 118 insertions(+), 24 deletions(-)

diff --git a/src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc b/src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc
index 2d2b36d6b..27890a6e2 100644
--- a/src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc
+++ b/src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc
@@ -187,13 +187,23 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, flo
     size_t N = _info.N;
     size_t width = _info.width;
 
-    // The hyperparameters below come from DeepSeek's config.json
-    const size_t n_routed_experts = 256;
-    const size_t n_group = 8;
-    const size_t topk_group = 4;
-    const bool norm_topk_prob = true;
+    // Support different numbers of experts
+    size_t n_routed_experts, n_group, topk_group;
+    bool norm_topk_prob = true;
+
+    if (width == 256) {
+        n_routed_experts = 256;
+        n_group = 8;
+        topk_group = 4;
+    } else if (width == 64) {
+        n_routed_experts = 64;
+        n_group = 8;
+        topk_group = 4;
+    } else {
+        return INFINI_STATUS_BAD_PARAM;
+    }
 
-    if ((width != n_routed_experts) || (width % n_group != 0) || (256 != width)) {
+    if (width % n_group != 0) {
         return INFINI_STATUS_BAD_PARAM;
     }
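Both accepted widths keep n_group = 8 and topk_group = 4, so only the per-group expert count changes. A quick sanity check of the selection sizes, mirroring the branch above (illustrative sketch; router_config is a hypothetical helper, not part of the patch):

    def router_config(width: int, topk: int = 8):
        if width not in (256, 64):
            raise ValueError("unsupported expert count")
        n_group, topk_group = 8, 4
        experts_per_group = width // n_group          # 32 for width=256, 8 for width=64
        # Each group is scored by the sum of its top-2 expert scores;
        # topk_group groups survive, so the final top-k experts are drawn
        # from topk_group * experts_per_group candidates.
        candidates = topk_group * experts_per_group   # 128 or 32
        assert candidates >= topk
        return n_group, topk_group, experts_per_group, candidates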
diff --git a/src/infiniop/ops/topkrouter/cuda/kernel.cuh b/src/infiniop/ops/topkrouter/cuda/kernel.cuh
index 0832c5b93..23f288105 100644
--- a/src/infiniop/ops/topkrouter/cuda/kernel.cuh
+++ b/src/infiniop/ops/topkrouter/cuda/kernel.cuh
@@ -73,12 +73,16 @@ __global__ void topkrouter_kernel(float *values_topk, // output values
     // ------------------------------------------------------ //
     //            Apply sigmoid to the input data              //
     // ------------------------------------------------------ //
-    float value = sigmoid_func(data_input[tid]);
-
-    // ------------------------------------------------------ //
-    //       Add the correction bias to the input data         //
-    // ------------------------------------------------------ //
-    value += d_correction_bias[tid];
+    float value;
+    if (tid < width) {
+        value = sigmoid_func(data_input[tid]);
+        // ------------------------------------------------------ //
+        //       Add the correction bias to the input data         //
+        // ------------------------------------------------------ //
+        value += d_correction_bias[tid];
+    } else {
+        value = -FLT_MAX; // out-of-range threads get the minimum value so they are never selected into the top-k
+    }
 
     // ----------------------------------------------------------- //
     //   Each warp forms one group (8 groups in total); find the    //
     //   top two values within each group                           //
     // ----------------------------------------------------------- //
     {
         WarpMergeSortT(temp_storage[warp_id]).Sort(thread_values, thread_indices, CustomLess());
     }
     __syncthreads();
-    share_data[tid] = thread_values[0];
+    if (tid < width) {
+        share_data[tid] = thread_values[0];
+    }
 
     // ----------------------------------------------------------- //
     //        Sum of the top two values within each group           //
     // ----------------------------------------------------------- //
     __syncthreads();
     if (0 == lane_id) {
-        share_data_group[warp_id] = share_data[warp_id * warp_threads] + share_data[warp_id * warp_threads + 1];
+        int base_idx = warp_id * warp_threads;
+        // make sure we do not read out of bounds
+        float val1 = (base_idx < width) ? share_data[base_idx] : -FLT_MAX;
+        float val2 = (base_idx + 1 < width) ? share_data[base_idx + 1] : -FLT_MAX;
+        share_data_group[warp_id] = val1 + val2;
     }
     __syncthreads();
     // ----------------------------------------------------------- //
@@ -143,7 +153,11 @@ __global__ void topkrouter_kernel(float *values_topk, // output values
     value = 0.0f;
     if (tid < 8) {
         int index = thread_indices[0];
-        value = sigmoid_func(data_input[index]);
+        if (index < width) {
+            value = sigmoid_func(data_input[index]);
+        } else {
+            value = 0.0f;
+        }
     }
     {
         typedef cub::WarpReduce<float> WarpReduce;
diff --git a/src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu b/src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
index 495c3914b..3f0c2c8cd 100644
--- a/src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
+++ b/src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
@@ -43,10 +43,10 @@ template <int BLOCK_SIZE>
 infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
                                  const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
                                  cudaStream_t stream) {
-    const int block_threads = BLOCK_SIZE;
+    const int block_threads = width; // use width as the thread count to avoid out-of-bounds accesses
     dim3 blocks(N);
     dim3 threads(block_threads);
-
+    std::cout << "Launch Nvidia topk router" << std::endl;
     if (xtype == INFINI_DTYPE_F32) {
         topkrouter_kernel<BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
     } else if (xtype == INFINI_DTYPE_F16) {
@@ -86,7 +86,9 @@ infiniStatus_t Descriptor::calculate(
 
     if (256 == width) {
         launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
-    } else {
+    } else if (64 == width) {
+        launch_topkrouter<64>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
+    } else {
         return INFINI_STATUS_BAD_PARAM;
     }
diff --git a/test/infiniop/topkrouter.py b/test/infiniop/topkrouter.py
index ce8eda767..115a5cf44 100644
--- a/test/infiniop/topkrouter.py
+++ b/test/infiniop/topkrouter.py
@@ -29,12 +29,12 @@
 _TEST_CASES_ = [
     # x_shape, x_stride, topk, routed_scaling_factor
     ((1, 256), None, 8, 2.5),
+    ((24, 64), None, 8, 1.0),  # Add a 64-expert test case, matching router_logits_buf: [24, 64]
 ]
 
 # w (weight) types
 # Note: 'None' means the same as input dtype
-# _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
-_X_DTYPES = []  # CPU CI
+_X_DTYPES = [InfiniDtype.F32, InfiniDtype.F16]  # enable the tests
 
 # x types used for testing
 _VALUE_DTYPES = [InfiniDtype.F32]
@@ -121,8 +121,76 @@ def forward(self, router_logits):
         return topk_indices, topk_weights
 
 
-def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk):
-    lable_indices, lable_values = DeepseekV3TopkRouter(correction_bias, routed_scaling_factor, topk)(router_logits)
+class GeneralTopkRouter(nn.Module):
+    def __init__(self, correction_bias, routed_scaling_factor: float = 1.0, topk: int = 8, n_routed_experts: int = 256):
+        super().__init__()
+        self.top_k = topk
+        self.n_routed_experts = n_routed_experts
+        self.routed_scaling_factor = routed_scaling_factor
+
+        # Derive the grouping parameters from the number of experts
+        if n_routed_experts == 256:
+            self.n_group = 8  # 256/8 = 32 per group
+            self.topk_group = 4
+        elif n_routed_experts == 64:
+            self.n_group = 8  # 64/8 = 8 per group
+            self.topk_group = 4
+        else:
+            # Default configuration
+            self.n_group = 8
+            self.topk_group = 4
+
+        self.norm_topk_prob = True
+
+        self.e_score_correction_bias = torch.zeros(n_routed_experts, device=correction_bias.device)
+        self.e_score_correction_bias[:] = correction_bias[:]
+
+    @torch.no_grad()
+    def get_topk_indices(self, scores):
+        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
+        )
+
+        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1]
+        group_mask = torch.zeros_like(group_scores)
+        group_mask.scatter_(1, group_idx, 1)
+
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .reshape(-1, self.n_routed_experts)
+        )
+
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[1]
+
+        return topk_indices
+
+    def forward(self, router_logits):
+        scores = router_logits.sigmoid()
+        scores = scores.to(torch.float32)
+
+        topk_indices = self.get_topk_indices(scores)
+        topk_weights = scores.gather(1, topk_indices)
+
+        if self.norm_topk_prob:
+            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weights /= denominator
+        topk_weights = topk_weights * self.routed_scaling_factor
+
+        return topk_indices, topk_weights
+
+
+def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk, n_routed_experts=256):
+    if n_routed_experts == 256:
+        router = DeepseekV3TopkRouter(correction_bias, routed_scaling_factor, topk)
+    else:
+        router = GeneralTopkRouter(correction_bias, routed_scaling_factor, topk, n_routed_experts)
+
+    lable_indices, lable_values = router(router_logits)
     lable_indices = lable_indices.to(torch.int32)
     return lable_values, lable_indices
 
@@ -196,7 +264,7 @@ def lib_topkrouter():
 
     lib_topkrouter()
 
-    lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk)
+    lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk, width)
     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:
         debug(lable_values, values, atol=atol, rtol=rtol)
@@ -208,7 +276,7 @@ def lib_topkrouter():
     # Profiling workflow
     if PROFILE:
         # fmt: off
-        profile_operation("PyTorch", lambda: torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("PyTorch", lambda: torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk, width), device, NUM_PRERUN, NUM_ITERATIONS)
         profile_operation("   lib", lambda: lib_topkrouter(), device, NUM_PRERUN, NUM_ITERATIONS)
         # fmt: on
     check_error(LIBINFINIOP.infiniopDestroyTopkrouterDescriptor(descriptor))
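One detail worth keeping in mind before the next patch: the CUDA kernel treats each warp as one expert group (tid / 32 is the group, tid % 32 the slot), so the 256-expert layout fills exactly 8 warps, while a 64-thread launch only brings up warps 0 and 1 — two 32-wide groups rather than the CPU path's 8 groups of 8. A pure-arithmetic sketch of that mapping (illustrative only):

    WARP_THREADS = 32  # lanes per warp, matching warp_threads in the kernel

    def warp_mapping(tid: int) -> tuple[int, int]:
        # The kernel uses warp_id as the expert-group index and lane_id
        # as the expert slot inside the group.
        return tid // WARP_THREADS, tid % WARP_THREADS

    # width = 256 -> tids 0..255 span warp_ids 0..7: 8 groups of 32 experts.
    # width = 64  -> a width-sized launch spans warp_ids 0..1 only, so the
    # warp-per-group layout no longer matches the CPU reference's grouping.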
From d91e539ae8c1c07b1a72188d12e5f53698b090b5 Mon Sep 17 00:00:00 2001
From: MoringLotus
Date: Wed, 24 Dec 2025 08:18:02 +0000
Subject: [PATCH 3/4] Softmax & TopkRouter

---
 src/infiniop/ops/softmax/operator.cc        |  22 ++--
 .../ops/topkrouter/cpu/topkrouter_cpu.cc    |  22 ++--
 src/infiniop/ops/topkrouter/cuda/kernel.cuh |  43 ++------
 .../ops/topkrouter/metax/topkrouter_metax.h |   8 ++
 .../topkrouter/metax/topkrouter_metax.maca  | 101 ++++++++++++++++++
 .../topkrouter/nvidia/topkrouter_nvidia.cu  |  18 ++--
 src/infiniop/ops/topkrouter/operator.cc     |  15 +++
 7 files changed, 163 insertions(+), 66 deletions(-)
 create mode 100644 src/infiniop/ops/topkrouter/metax/topkrouter_metax.h
 create mode 100644 src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca

diff --git a/src/infiniop/ops/softmax/operator.cc b/src/infiniop/ops/softmax/operator.cc
index b22d874a5..68d60e94c 100644
--- a/src/infiniop/ops/softmax/operator.cc
+++ b/src/infiniop/ops/softmax/operator.cc
@@ -22,17 +22,17 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
         x_desc, axis);
     std::cout << "handle device " << handle->device << std::endl;
     switch (handle->device) {
-#ifdef ENABLE_NVIDIA_API
-    CREATE(INFINI_DEVICE_NVIDIA, nvidia)
-#endif
-#ifdef ENABLE_ILUVATAR_API
-    CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
-#endif
-#ifdef ENABLE_QY_API
-    CREATE(INFINI_DEVICE_QY, nvidia);
-#endif
-#ifdef ENABLE_HYGON_API
-    CREATE(INFINI_DEVICE_HYGON, nvidia);
+    #ifdef ENABLE_NVIDIA_API
+    CREATE(INFINI_DEVICE_NVIDIA, nvidia)
+    #endif
+    #ifdef ENABLE_ILUVATAR_API
+    CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
+    #endif
+    #ifdef ENABLE_QY_API
+    CREATE(INFINI_DEVICE_QY, nvidia);
+    #endif
+    #ifdef ENABLE_HYGON_API
+    CREATE(INFINI_DEVICE_HYGON, nvidia);
 #endif
     }
     std::cout << "Error In CREATE NVIDIA API" << std::endl;
diff --git a/src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc b/src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc
index 27890a6e2..2d2b36d6b 100644
--- a/src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc
+++ b/src/infiniop/ops/topkrouter/cpu/topkrouter_cpu.cc
@@ -187,23 +187,13 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, flo
     size_t N = _info.N;
     size_t width = _info.width;
 
-    // Support different numbers of experts
-    size_t n_routed_experts, n_group, topk_group;
-    bool norm_topk_prob = true;
-
-    if (width == 256) {
-        n_routed_experts = 256;
-        n_group = 8;
-        topk_group = 4;
-    } else if (width == 64) {
-        n_routed_experts = 64;
-        n_group = 8;
-        topk_group = 4;
-    } else {
-        return INFINI_STATUS_BAD_PARAM;
-    }
+    // The hyperparameters below come from DeepSeek's config.json
+    const size_t n_routed_experts = 256;
+    const size_t n_group = 8;
+    const size_t topk_group = 4;
+    const bool norm_topk_prob = true;
 
-    if (width % n_group != 0) {
+    if ((width != n_routed_experts) || (width % n_group != 0) || (256 != width)) {
         return INFINI_STATUS_BAD_PARAM;
     }
diff --git a/src/infiniop/ops/topkrouter/cuda/kernel.cuh b/src/infiniop/ops/topkrouter/cuda/kernel.cuh
index 23f288105..266fc61eb 100644
--- a/src/infiniop/ops/topkrouter/cuda/kernel.cuh
+++ b/src/infiniop/ops/topkrouter/cuda/kernel.cuh
@@ -1,21 +1,12 @@
 #ifndef _TOPKROUTER_KERNEL_CUH__
 #define _TOPKROUTER_KERNEL_CUH__
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
 
 template <typename T>
 inline __device__ float exp_func(T x) {
     float data;
     if constexpr (std::is_same_v<T, float>) {
         data = x;
-    } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
+    } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
         data = __bfloat162float(x);
     } else if constexpr (std::is_same_v<T, half>) {
         data = __half2float(x);
@@ -73,16 +64,12 @@ __global__ void topkrouter_kernel(float *values_topk, // output values
     // ------------------------------------------------------ //
     //            Apply sigmoid to the input data              //
     // ------------------------------------------------------ //
-    float value;
-    if (tid < width) {
-        value = sigmoid_func(data_input[tid]);
-        // ------------------------------------------------------ //
-        //       Add the correction bias to the input data         //
-        // ------------------------------------------------------ //
-        value += d_correction_bias[tid];
-    } else {
-        value = -FLT_MAX; // out-of-range threads get the minimum value so they are never selected into the top-k
-    }
+    float value = sigmoid_func(data_input[tid]);
+
+    // ------------------------------------------------------ //
+    //       Add the correction bias to the input data         //
+    // ------------------------------------------------------ //
+    value += d_correction_bias[tid];
 
     // ----------------------------------------------------------- //
     //   Each warp forms one group (8 groups in total); find the    //
     //   top two values within each group                           //
     // ----------------------------------------------------------- //
     {
         WarpMergeSortT(temp_storage[warp_id]).Sort(thread_values, thread_indices, CustomLess());
     }
     __syncthreads();
-    if (tid < width) {
-        share_data[tid] = thread_values[0];
-    }
+    share_data[tid] = thread_values[0];
 
     // ----------------------------------------------------------- //
     //        Sum of the top two values within each group           //
     // ----------------------------------------------------------- //
     __syncthreads();
     if (0 == lane_id) {
-        int base_idx = warp_id * warp_threads;
-        // make sure we do not read out of bounds
-        float val1 = (base_idx < width) ? share_data[base_idx] : -FLT_MAX;
-        float val2 = (base_idx + 1 < width) ? share_data[base_idx + 1] : -FLT_MAX;
-        share_data_group[warp_id] = val1 + val2;
+        share_data_group[warp_id] = share_data[warp_id * warp_threads] + share_data[warp_id * warp_threads + 1];
     }
     __syncthreads();
     // ----------------------------------------------------------- //
@@ -153,11 +134,7 @@ __global__ void topkrouter_kernel(float *values_topk, // output values
     value = 0.0f;
     if (tid < 8) {
         int index = thread_indices[0];
-        if (index < width) {
-            value = sigmoid_func(data_input[index]);
-        } else {
-            value = 0.0f;
-        }
+        value = sigmoid_func(data_input[index]);
     }
     {
         typedef cub::WarpReduce<float> WarpReduce;
diff --git a/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h
new file mode 100644
index 000000000..62f17dc6c
--- /dev/null
+++ b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.h
@@ -0,0 +1,8 @@
+#ifndef __TOPKROUTER_METAX_H__
+#define __TOPKROUTER_METAX_H__
+
+#include "../topkrouter.h"
+
+DESCRIPTOR(metax)
+
+#endif
diff --git a/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
new file mode 100644
index 000000000..45823ecea
--- /dev/null
+++ b/src/infiniop/ops/topkrouter/metax/topkrouter_metax.maca
@@ -0,0 +1,101 @@
+#include "../../../devices/metax/metax_common.h"
+#include "../../../devices/metax/metax_kernel_common.h"
+
+#include "topkrouter_metax.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../cuda/kernel.cuh"
+
+namespace op::topkrouter::metax {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::metax::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t correction_bias_desc) {
+    auto result = TopkrouterInfo::create(x_desc);
+    CHECK_RESULT(result);
+    auto info = result.take();
+
+    if (info.x_strides[1] != 1) {
+        return INFINI_STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    *desc_ptr = new Descriptor(
+        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
+        std::move(info),
+        0,
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+namespace {
+
+template <int BLOCK_SIZE>
+infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
+                                 const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
+                                 hcStream_t stream) {
+    const int block_threads = BLOCK_SIZE;
+    dim3 blocks(N);
+    dim3 threads(block_threads);
+
+    if (xtype == INFINI_DTYPE_F32) {
+        topkrouter_kernel<BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
+    } else if (xtype == INFINI_DTYPE_F16) {
+        topkrouter_kernel<BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
+    } else if (xtype == INFINI_DTYPE_BF16) {
+        topkrouter_kernel<BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+}; // namespace
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    float *values,
+    int *indices,
+    const void *x,
+    const float *correction_bias,
+    const float routed_scaling_factor,
+    const size_t topk,
+    void *stream) const {
+    if (workspace_size < _workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+
+    size_t N = _info.N;
+    size_t width = _info.width; // 256
+
+    // size_t n_routed_experts = 256;
+    // size_t n_group = 8;
+    // size_t topk_group = 4;
+    auto cuda_stream = reinterpret_cast<hcStream_t>(stream);
+
+    if (256 == width) {
+        launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
+    } else {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::topkrouter::metax
diff --git a/src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu b/src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
index 3f0c2c8cd..02ae569f3 100644
--- a/src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
+++ b/src/infiniop/ops/topkrouter/nvidia/topkrouter_nvidia.cu
@@ -2,9 +2,17 @@
 #include "../../../devices/nvidia/nvidia_common.cuh"
 #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
-#include "../cuda/kernel.cuh"
+
 #include "topkrouter_nvidia.cuh"
+
+#include
+#include
+#include
 #include
+#include
+#include
+
+#include "../cuda/kernel.cuh"
 
 namespace op::topkrouter::nvidia {
@@ -43,10 +51,10 @@ template <int BLOCK_SIZE>
 infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
                                  const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
                                  cudaStream_t stream) {
-    const int block_threads = width; // use width as the thread count to avoid out-of-bounds accesses
+    const int block_threads = BLOCK_SIZE;
     dim3 blocks(N);
     dim3 threads(block_threads);
-    std::cout << "Launch Nvidia topk router" << std::endl;
+
     if (xtype == INFINI_DTYPE_F32) {
         topkrouter_kernel<BLOCK_SIZE><<<blocks, threads, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
     } else if (xtype == INFINI_DTYPE_F16) {
@@ -86,9 +94,7 @@ infiniStatus_t Descriptor::calculate(
 
     if (256 == width) {
         launch_topkrouter<256>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
-    } else if (64 == width) {
-        launch_topkrouter<64>(values, indices, x, correction_bias, routed_scaling_factor, N, width, topk, _info.xtype, cuda_stream);
-    } else {
+    } else {
         return INFINI_STATUS_BAD_PARAM;
     }
diff --git a/src/infiniop/ops/topkrouter/operator.cc b/src/infiniop/ops/topkrouter/operator.cc
index 89555e9f9..73b6e9bcf 100644
--- a/src/infiniop/ops/topkrouter/operator.cc
+++ b/src/infiniop/ops/topkrouter/operator.cc
@@ -8,6 +8,9 @@
 #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
 #include "nvidia/topkrouter_nvidia.cuh"
 #endif
+#ifdef ENABLE_METAX_API
+#include "metax/topkrouter_metax.h"
+#endif
 #ifdef ENABLE_KUNLUN_API
 #include "kunlun/topkrouter_kunlun.h"
 #endif
@@ -30,6 +33,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
 #ifdef ENABLE_QY_API
     CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_METAX_API
+    CREATE(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_KUNLUN_API
     CREATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
     }
@@ -56,6 +62,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
 #ifdef ENABLE_QY_API
     GET(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_METAX_API
+    GET(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_KUNLUN_API
     GET(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
@@ -85,6 +94,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
 #ifdef ENABLE_QY_API
     CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_METAX_API
+    CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_KUNLUN_API
     CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
@@ -111,6 +123,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
 #ifdef ENABLE_QY_API
     DESTROY(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_METAX_API
+    DESTROY(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_KUNLUN_API
     DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
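The "bi" in the operator introduced by the next patch means bidirectional: where the existing attention operator runs its scores through a causal softmax (each query may only attend to positions at or before its own), infiniopBiAttention applies a plain softmax over all pos + seq_len positions, as its implementation shows. The difference in a few lines of PyTorch (sketch; shapes are arbitrary):

    import torch

    S, T, pos = 4, 8, 4          # S new queries over T = pos + S total positions
    scores = torch.randn(S, T)

    # Causal: query i (absolute position pos + i) cannot see j > pos + i.
    mask = torch.arange(T) > pos + torch.arange(S)[:, None]
    causal = scores.masked_fill(mask, float("-inf")).softmax(-1)

    # Bidirectional: no mask at all, every query sees every position.
    bidirectional = scores.softmax(-1)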
From ddc52ab2649b0f195730d54bc7957127037c20ee Mon Sep 17 00:00:00 2001
From: MoringLotus
Date: Thu, 25 Dec 2025 06:01:26 +0000
Subject: [PATCH 4/4] finish bi attention impl & test

---
 include/infinicore/ops/bi_attention.hpp       |  16 +
 include/infiniop.h                            |   1 +
 include/infiniop/ops/bi_attention.h           |  34 ++
 python/infinicore/ops/bi_attention.py         |  28 ++
 .../ops/bi_attention/bi_attention.cc          |  31 ++
 .../ops/bi_attention/bi_attention_infiniop.cc |  52 ++++
 src/infinicore/ops/src/infiniop/README.md     |   0
 src/infiniop/ops/bi_attention/bi_attention.h  |  37 +++
 src/infiniop/ops/bi_attention/operator.cc     | 291 ++++++++++++++++++
 test/infiniop/bi_attention.py                 | 271 ++++++++++++++++
 test/infiniop/libinfiniop/op_register.py      |  39 +++
 11 files changed, 800 insertions(+)
 create mode 100644 include/infinicore/ops/bi_attention.hpp
 create mode 100644 include/infiniop/ops/bi_attention.h
 create mode 100644 python/infinicore/ops/bi_attention.py
 create mode 100644 src/infinicore/ops/bi_attention/bi_attention.cc
 create mode 100644 src/infinicore/ops/bi_attention/bi_attention_infiniop.cc
 create mode 100644 src/infinicore/ops/src/infiniop/README.md
 create mode 100644 src/infiniop/ops/bi_attention/bi_attention.h
 create mode 100644 src/infiniop/ops/bi_attention/operator.cc
 create mode 100644 test/infiniop/bi_attention.py

diff --git a/include/infinicore/ops/bi_attention.hpp b/include/infinicore/ops/bi_attention.hpp
new file mode 100644
index 000000000..a02d13051
--- /dev/null
+++ b/include/infinicore/ops/bi_attention.hpp
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+class BiAttention {
+public:
+    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, size_t);
+    static void execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
+    static common::OpDispatcher &dispatcher();
+};
+
+Tensor bi_attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
+void bi_attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
+} // namespace infinicore::op
diff --git a/include/infiniop.h b/include/infiniop.h
index 92e6f5963..5532cd19e 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -30,6 +30,7 @@
 #include "infiniop/ops/topkrouter.h"
 #include "infiniop/ops/topksoftmax.h"
 #include "infiniop/ops/zeros.h"
+#include "infiniop/ops/bi_attention.h"
 #include "infiniop/tensor_descriptor.h"
 
 #endif // __INFINIOP_API_H__
diff --git a/include/infiniop/ops/bi_attention.h b/include/infiniop/ops/bi_attention.h
new file mode 100644
index 000000000..05c21e4d2
--- /dev/null
+++ b/include/infiniop/ops/bi_attention.h
@@ -0,0 +1,34 @@
+#ifndef __INFINIOP_BI_ATTENTION_API_H__
+#define __INFINIOP_BI_ATTENTION_API_H__
+
+#include "../operator_descriptor.h"
+#include "gemm.h"
+#include "swiglu.h"
+
+typedef struct InfiniopDescriptor *infiniopBiAttentionDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateBiAttentionDescriptor(infiniopHandle_t handle,
+                                                                infiniopBiAttentionDescriptor_t *desc_ptr,
+                                                                infiniopTensorDescriptor_t out_desc,
+                                                                infiniopTensorDescriptor_t q_desc,
+                                                                infiniopTensorDescriptor_t k_desc,
+                                                                infiniopTensorDescriptor_t v_desc,
+                                                                infiniopTensorDescriptor_t k_cache_desc,
+                                                                infiniopTensorDescriptor_t v_cache_desc,
+                                                                size_t pos);
+
+__C __export infiniStatus_t infiniopGetBiAttentionWorkspaceSize(infiniopBiAttentionDescriptor_t desc, size_t *size);
+
+__C __export infiniStatus_t infiniopBiAttention(infiniopBiAttentionDescriptor_t desc,
+                                                void *workspace,
+                                                size_t workspace_size,
+                                                void *out,
+                                                const void *q,
+                                                const void *k,
+                                                const void *v,
+                                                void *k_cache,
+                                                void *v_cache,
+                                                void *stream);
+
+__C __export infiniStatus_t infiniopDestroyBiAttentionDescriptor(infiniopBiAttentionDescriptor_t desc);
+#endif
diff --git a/python/infinicore/ops/bi_attention.py b/python/infinicore/ops/bi_attention.py
new file mode 100644
index 000000000..ca082d086
--- /dev/null
+++ b/python/infinicore/ops/bi_attention.py
@@ -0,0 +1,28 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def bi_attention(q, k, v, k_cache, v_cache, pos, *, out=None):
+    if out is None:
+        return Tensor(
+            _infinicore.bi_attention(
+                q._underlying,
+                k._underlying,
+                v._underlying,
+                k_cache._underlying,
+                v_cache._underlying,
+                pos,
+            )
+        )
+
+    _infinicore.bi_attention_(
+        out._underlying,
+        q._underlying,
+        k._underlying,
+        v._underlying,
+        k_cache._underlying,
+        v_cache._underlying,
+        pos,
+    )
+
+    return out
diff --git a/src/infinicore/ops/bi_attention/bi_attention.cc b/src/infinicore/ops/bi_attention/bi_attention.cc
new file mode 100644
index 000000000..53851905e
--- /dev/null
+++ b/src/infinicore/ops/bi_attention/bi_attention.cc
@@ -0,0 +1,31 @@
+#include "infinicore/ops/bi_attention.hpp"
+#include "../../utils.hpp"
+
+namespace infinicore::op {
+
+common::OpDispatcher &BiAttention::dispatcher() {
+    static common::OpDispatcher dispatcher_;
+    return dispatcher_;
+}
+
+void BiAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, k_cache, v_cache);
+    infinicore::context::setDevice(out->device());
+    dispatcher().lookup(out->device().getType())(out, q, k, v, k_cache, v_cache, pos);
+}
+
+Tensor bi_attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
+    size_t n_q_head = q->shape()[0];
+    size_t seq_len = q->shape()[1];
+    size_t head_dim = q->shape()[2];
+    Shape shape = {seq_len, n_q_head, head_dim};
+    auto out = Tensor::empty(shape, q->dtype(), q->device());
+    bi_attention_(out, q, k, v, k_cache, v_cache, pos);
+    return out;
+}
+
+void bi_attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
+    BiAttention::execute(out, q, k, v, k_cache, v_cache, pos);
+}
+
+} // namespace infinicore::op
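Condensed, the semantics the infiniop descriptor below implements: append the new k/v at position pos in the caches, then run grouped-query attention with no mask. A compact PyTorch sketch for orientation (the test added at the end of this patch is the authoritative reference):

    import torch

    def bi_attention_sketch(q, k, v, k_cache, v_cache, pos):
        # q: [n_q_head, seq_len, head_dim]; k, v: [n_kv_head, seq_len, head_dim]
        full_k = torch.cat([k_cache[:, :pos], k], dim=1)   # [n_kv_head, pos+seq_len, head_dim]
        full_v = torch.cat([v_cache[:, :pos], v], dim=1)
        n_group = q.shape[0] // k.shape[0]                 # GQA: query heads per KV head
        full_k = full_k.repeat_interleave(n_group, dim=0)
        full_v = full_v.repeat_interleave(n_group, dim=0)
        scores = q @ full_k.transpose(-1, -2) / q.shape[-1] ** 0.5
        weights = scores.softmax(dim=-1)                   # no causal mask
        return (weights @ full_v).permute(1, 0, 2)         # [seq_len, n_q_head, head_dim]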
diff --git a/src/infinicore/ops/bi_attention/bi_attention_infiniop.cc b/src/infinicore/ops/bi_attention/bi_attention_infiniop.cc
new file mode 100644
index 000000000..f8ac8a8d5
--- /dev/null
+++ b/src/infinicore/ops/bi_attention/bi_attention_infiniop.cc
@@ -0,0 +1,52 @@
+#include "../../utils.hpp"
+#include "infinicore/common/hash.hpp"
+#include "infinicore/ops/bi_attention.hpp"
+#include "infinicore/ops/common/cache.hpp"
+#include
+
+namespace infinicore::op::bi_attention_impl::infiniop {
+
+thread_local common::OpCache caches(
+    100, // capacity
+    [](infiniopBiAttentionDescriptor_t &desc) {
+        if (desc != nullptr) {
+            INFINICORE_CHECK_ERROR(infiniopDestroyBiAttentionDescriptor(desc));
+            desc = nullptr;
+        }
+    });
+
+void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
+    size_t seed = hash_combine(out, q, k, v, k_cache, v_cache, pos);
+
+    auto device = context::getDevice();
+    auto &cache = caches.getCache(device);
+
+    auto desc_opt = cache.get(seed);
+    infiniopBiAttentionDescriptor_t desc = nullptr;
+
+    if (!desc_opt) {
+        INFINICORE_CHECK_ERROR(infiniopCreateBiAttentionDescriptor(
+            context::getInfiniopHandle(device), &desc,
+            out->desc(), q->desc(), k->desc(), v->desc(),
+            k_cache->desc(), v_cache->desc(), pos));
+        cache.put(seed, desc);
+    } else {
+        desc = *desc_opt;
+    }
+
+    size_t workspace_size = 0;
+    INFINICORE_CHECK_ERROR(infiniopGetBiAttentionWorkspaceSize(desc, &workspace_size));
+    std::shared_ptr workspace = context::allocateMemory(workspace_size);
+
+    INFINICORE_CHECK_ERROR(infiniopBiAttention(
+        desc, workspace->data(), workspace_size,
+        out->data(), q->data(), k->data(), v->data(),
+        k_cache->data(), v_cache->data(), context::getStream()));
+}
+
+static bool registered = []() {
+    BiAttention::dispatcher().registerAll(&calculate, false);
+    return true;
+}();
+
+} // namespace infinicore::op::bi_attention_impl::infiniop
diff --git a/src/infinicore/ops/src/infiniop/README.md b/src/infinicore/ops/src/infiniop/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/infiniop/ops/bi_attention/bi_attention.h b/src/infiniop/ops/bi_attention/bi_attention.h
new file mode 100644
index 000000000..d4740d25e
--- /dev/null
+++ b/src/infiniop/ops/bi_attention/bi_attention.h
@@ -0,0 +1,37 @@
+#ifndef BI_ATTENTION_H
+#define BI_ATTENTION_H
+
+#include "../../operator.h"
+#include "info.h"
+
+#define DESCRIPTOR(NAMESPACE)                                    \
+                                                                 \
+    namespace op::bi_attention::NAMESPACE {                      \
+    class Descriptor final : public InfiniopDescriptor {         \
+        struct Opaque;                                           \
+        Opaque *_opaque;                                         \
+        size_t _workspace_size;                                  \
+                                                                 \
+        Descriptor(                                              \
+            Opaque *opaque,                                      \
+            size_t workspace_size,                               \
+            infiniDevice_t device_type,                          \
+            int device_id)                                       \
+            : InfiniopDescriptor{device_type, device_id},        \
+              _opaque(opaque),                                   \
+              _workspace_size(workspace_size) {}                 \
+                                                                 \
+    public:                                                      \
+        ~Descriptor();                                           \
+                                                                 \
+        size_t workspaceSize() const { return _workspace_size; } \
+                                                                 \
+        static infiniStatus_t create(                            \
+            infiniopHandle_t handle,                             \
+            Descriptor **desc_ptr,                               \
+            infiniopTensorDescriptor_t y_desc,                   \
+            infiniopTensorDescriptor_t x_desc);                  \
+    };                                                           \
+    }
+
+#endif // BI_ATTENTION_H
diff --git a/src/infiniop/ops/bi_attention/operator.cc b/src/infiniop/ops/bi_attention/operator.cc
new file mode 100644
index 000000000..22d2c14ed
--- /dev/null
+++ b/src/infiniop/ops/bi_attention/operator.cc
@@ -0,0 +1,291 @@
+#include "../../operator.h"
+#include "../../../utils.h"
+#include "../../../utils/check.h"
+#include "../../handle.h"
+#include "../../tensor.h"
+#include "infiniop/ops/bi_attention.h"
+#include "infiniop/ops/softmax.h"
+#include "infiniop/ops/gemm.h"
+#include "infiniop/ops/rearrange.h"
+
+#include
+#include
+
+struct InfiniopBiAttentionDescriptor {
+    InfiniopDescriptor _super;
+    infiniopRearrangeDescriptor_t rearrange_desc_k;
+    infiniopRearrangeDescriptor_t rearrange_desc_v;
+    infiniopRearrangeDescriptor_t rearrange_desc_q;
+    infiniopRearrangeDescriptor_t rearrange_desc_out;
+    infiniopGemmDescriptor_t matmul_desc1;
+    infiniopGemmDescriptor_t matmul_desc2;
+    infiniopSoftmaxDescriptor_t softmax_desc;
+    size_t workspace_size;
+    size_t op_workspace_offset;
+    size_t op_workspace_size;
+    size_t q_cont_offset;
+    size_t att_score_offset;
+    size_t att_val_offset;
+    size_t k_cache_offset;
+    size_t v_cache_offset;
+    float qk_alpha;
+};
+
+__C __export infiniStatus_t infiniopCreateBiAttentionDescriptor(infiniopHandle_t handle,
+                                                                infiniopBiAttentionDescriptor_t *desc_ptr,
+                                                                infiniopTensorDescriptor_t out_desc,
+                                                                infiniopTensorDescriptor_t q_desc,
+                                                                infiniopTensorDescriptor_t k_desc,
+                                                                infiniopTensorDescriptor_t v_desc,
+                                                                infiniopTensorDescriptor_t k_cache_desc,
+                                                                infiniopTensorDescriptor_t v_cache_desc,
+                                                                size_t pos) {
+    if (out_desc->ndim() != 3 || q_desc->ndim() != 3 || k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() != 3 || v_cache_desc->ndim() != 3) {
+        return INFINI_STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    if (!out_desc->isContiguous()) {
+        return INFINI_STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    if (q_desc->strides()[2] != 1 || k_desc->strides()[2] != 1 || v_desc->strides()[2] != 1 || k_cache_desc->strides()[2] != 1 || v_cache_desc->strides()[2] != 1) {
+        return INFINI_STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    size_t n_q_head = q_desc->shape()[0];
+    size_t seq_len = q_desc->shape()[1];
+    size_t head_dim = q_desc->shape()[2];
+    size_t n_kv_head = k_desc->shape()[0];
+    size_t total_seq_len = seq_len + pos;
+    size_t n_group = n_q_head / n_kv_head;
+    size_t alignment = 256;
+
+    if (out_desc->shape()[0] != seq_len || out_desc->shape()[1] != n_q_head || out_desc->shape()[2] != head_dim) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    // k: [n_kv_head, seq_len, head_dim]
+    if (k_desc->shape()[0] != n_kv_head || k_desc->shape()[1] != seq_len || k_desc->shape()[2] != head_dim) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    // v: [n_kv_head, seq_len, head_dim]
+    if (v_desc->shape()[0] != n_kv_head || v_desc->shape()[1] != seq_len || v_desc->shape()[2] != head_dim) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    // k_cache: [n_kv_head, _, head_dim]
+    if (k_cache_desc->shape()[0] != n_kv_head || k_cache_desc->shape()[1] < total_seq_len || k_cache_desc->shape()[2] != head_dim) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    // v_cache: [n_kv_head, _, head_dim]
+    if (v_cache_desc->shape()[0] != n_kv_head || v_cache_desc->shape()[1] < total_seq_len || v_cache_desc->shape()[2] != head_dim) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    // Rearrange k into k_cache
+    infiniopTensorDescriptor_t dst_k_desc;
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_k_desc, 3, k_desc->shape().data(), k_cache_desc->strides().data(), k_cache_desc->dtype()));
+    infiniopRearrangeDescriptor_t rearrange_desc_k;
+    CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_k, dst_k_desc, k_desc));
+
+    // Rearrange v into v_cache
+    infiniopTensorDescriptor_t dst_v_desc;
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_v_desc, 3, v_desc->shape().data(), v_cache_desc->strides().data(), v_cache_desc->dtype()));
+    infiniopRearrangeDescriptor_t rearrange_desc_v;
+    CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_v, dst_v_desc, v_desc));
+
+    infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr;
+    size_t q_cont_size = 0;
+    infiniopTensorDescriptor_t rearranged_q_desc;
+    // Rearrange q into contiguous
+    if (!q_desc->isContiguous(0, 1)) {
+        CHECK_STATUS(infiniopCreateTensorDescriptor(&rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
+        q_cont_size = utils::align(rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype()), alignment);
+        CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_q, rearranged_q_desc, q_desc));
+    }
+
+    // Matmul1: q * full_k
+    // q: [n_q_head, seq_len, head_dim] -> [n_kv_head, n_group * seq_len, head_dim]
+    infiniopTensorDescriptor_t reshaped_q_desc;
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&reshaped_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
+    TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimSplit(0, {n_kv_head, n_group}));
+    TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2));
+    // full_k: [n_kv_head, head_dim, total_seq_len]
+    infiniopTensorDescriptor_t full_k_desc;
+    size_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&full_k_desc, 3, full_k_shape, k_cache_desc->strides().data(), k_cache_desc->dtype()));
+    TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1}));
+    // qk: [n_kv_head, n_group * seq_len, total_seq_len]
+    infiniopTensorDescriptor_t qk_desc;
+    size_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&qk_desc, 3, qk_shape, nullptr, q_desc->dtype()));
+    // qk_alpha: the 1/sqrt(head_dim) scaling applied in the first GEMM
+    float qk_alpha = 1.f / std::sqrt(float(head_dim));
+    // matmul1_desc
+    infiniopGemmDescriptor_t matmul1_desc;
+    CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc));
+    // matmul1 workspace size
+    size_t matmul1_workspace_size;
+    CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size));
+    matmul1_workspace_size = utils::align(matmul1_workspace_size, alignment);
+    // attention score tensor size
+    size_t attn_score_size = utils::align(qk_desc->numel() * infiniSizeOf(qk_desc->dtype()), alignment);
+
+    // Softmax: softmax(qk)
+    // qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len]
+    TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(1, {n_group, seq_len}));
+    TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(0, 1));
+    infiniopSoftmaxDescriptor_t softmax_desc;
+    CHECK_STATUS(infiniopCreateSoftmaxDescriptor(handle, &softmax_desc, qk_desc, qk_desc, 2));
+    // softmax workspace size
+    size_t softmax_workspace_size;
+    CHECK_STATUS(infiniopGetSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size));
+    softmax_workspace_size = utils::align(softmax_workspace_size, alignment);
+
+    // Matmul2: softmax(qk) * full_v
+    // softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len]
+    // full_v: [n_kv_head, total_seq_len, head_dim]
+    TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group}));
+    TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2));
+    infiniopTensorDescriptor_t full_v_desc;
+    size_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&full_v_desc, 3, full_v_shape, v_cache_desc->strides().data(), v_cache_desc->dtype()));
+    // temp_out: [n_kv_head, n_group * seq_len, head_dim]
+    infiniopTensorDescriptor_t att_val_desc;
+    size_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&att_val_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
+    // matmul2_desc
+    infiniopGemmDescriptor_t matmul2_desc;
+    CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, att_val_desc, qk_desc, full_v_desc));
+    // matmul2 workspace size
+    size_t matmul2_workspace_size;
+    CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size));
+    matmul2_workspace_size = utils::align(matmul2_workspace_size, alignment);
+    // attention value tensor size
+    size_t att_val_size = utils::align(att_val_desc->numel() * infiniSizeOf(att_val_desc->dtype()), alignment);
+
+    // Rearrange temp_out into out
+    // out: [seq_len, n_q_head, head_dim]
+    // temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim]
+    TRANSFORM_TENSOR_DESC(att_val_desc, dimSplit(1, {n_group, seq_len}));
+    TRANSFORM_TENSOR_DESC(att_val_desc, dimMerge(0, 1));
+    TRANSFORM_TENSOR_DESC(att_val_desc, dimPermute({1, 0, 2}));
+    infiniopRearrangeDescriptor_t rearrange_desc_out;
+    CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, att_val_desc));
+
+    // workspace size
+    size_t op_workspace_size = utils::align(std::max(std::max(matmul1_workspace_size, matmul2_workspace_size), softmax_workspace_size), alignment);
+    size_t temp_tensors_size = attn_score_size + std::max(q_cont_size, att_val_size);
+    size_t workspace_size = temp_tensors_size + op_workspace_size;
+
+    // k_cache_offset
+    size_t k_cache_offset = 0;
+    if (pos > 0) {
+        k_cache_offset = pos * k_cache_desc->getByteStrides()[1];
+    }
+
+    // v_cache_offset
+    size_t v_cache_offset = 0;
+    if (pos > 0) {
+        v_cache_offset = pos * v_cache_desc->getByteStrides()[1];
+    }
+
+    // create attention descriptor
+    *(InfiniopBiAttentionDescriptor **)desc_ptr = new InfiniopBiAttentionDescriptor{
+        {handle->device, handle->device_id},
+        rearrange_desc_k,
+        rearrange_desc_v,
+        rearrange_desc_q,
+        rearrange_desc_out,
+        matmul1_desc,
+        matmul2_desc,
+        softmax_desc,
+        workspace_size,
+        temp_tensors_size, // op_workspace_offset
+        op_workspace_size,
+        attn_score_size,   // q_cont_offset
+        0,                 // att_score_offset
+        attn_score_size,   // att_val_offset (shares the region after the score matrix with q_cont)
+        k_cache_offset,
+        v_cache_offset,
+        qk_alpha,
+    };
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+__C __export infiniStatus_t infiniopGetBiAttentionWorkspaceSize(infiniopBiAttentionDescriptor_t desc, size_t *size) {
+    *size = ((InfiniopBiAttentionDescriptor *)desc)->workspace_size;
+    return INFINI_STATUS_SUCCESS;
+}
+
+__C __export infiniStatus_t infiniopBiAttention(infiniopBiAttentionDescriptor_t desc_,
+                                                void *workspace_,
+                                                size_t workspace_size_,
+                                                void *out,
+                                                void const *q,
+                                                void const *k,
+                                                void const *v,
+                                                void *k_cache,
+                                                void *v_cache,
+                                                void *stream) {
+    auto desc = (InfiniopBiAttentionDescriptor *)desc_;
+    if (workspace_size_ < desc->workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+    void *workspace = (char *)workspace_ + desc->op_workspace_offset;
+    size_t workspace_size = desc->op_workspace_size;
+    void *att_score = (char *)workspace_ + desc->att_score_offset;
+    void *att_val = (char *)workspace_ + desc->att_val_offset;
+    void const *q_ = q;
+    // concat k and v to k_cache and v_cache
+    CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_k,
+                                   (char *)k_cache + desc->k_cache_offset, k, stream));
+
+    CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_v,
+                                   (char *)v_cache + desc->v_cache_offset, v, stream));
+
+    // rearrange q into contiguous
+    if (desc->rearrange_desc_q) {
+        void *q_cont = (char *)workspace_ + desc->q_cont_offset;
+        CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, q_cont, q, stream));
+        q_ = q_cont;
+    }
+
+    // matmul1: q * full_k
+    CHECK_STATUS(infiniopGemm(desc->matmul_desc1,
+                              workspace, workspace_size,
+                              att_score, q_, k_cache, desc->qk_alpha, 0.0, stream));
+    // softmax(qk)
+    CHECK_STATUS(infiniopSoftmax(desc->softmax_desc,
+                                 workspace, workspace_size,
+                                 att_score, att_score, stream));
+    // matmul2: softmax(qk) * full_v
+    CHECK_STATUS(infiniopGemm(desc->matmul_desc2,
+                              workspace, workspace_size,
+                              att_val, att_score, v_cache, 1.0, 0.0, stream));
+    // rearrange out
+    CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, att_val, stream));
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+__C __export infiniStatus_t infiniopDestroyBiAttentionDescriptor(infiniopBiAttentionDescriptor_t desc_) {
+    auto desc = (InfiniopBiAttentionDescriptor *)desc_;
+    if (desc->rearrange_desc_q) {
+        CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_q));
+    }
+    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_k));
+    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_v));
+    CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_out));
+    CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc1));
+    CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc2));
+    CHECK_STATUS(infiniopDestroySoftmaxDescriptor(desc->softmax_desc));
+    delete desc;
+
+    return INFINI_STATUS_SUCCESS;
+}
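The workspace assembled by the create function packs three regions: the attention-score matrix at offset 0; a region at attn_score_size shared by the contiguous-q copy and the attention-value tensor (safe because q is consumed by the first GEMM before the second GEMM writes its output); and the GEMM/softmax scratch space at the end. In Python terms (sizes assumed pre-aligned, as utils::align does above):

    def workspace_layout(attn_score_size, q_cont_size, att_val_size, op_workspace_size):
        att_score_offset = 0
        q_cont_offset = att_val_offset = attn_score_size   # shared region
        op_workspace_offset = attn_score_size + max(q_cont_size, att_val_size)
        total = op_workspace_offset + op_workspace_size
        return att_score_offset, q_cont_offset, att_val_offset, op_workspace_offset, total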
diff --git a/test/infiniop/bi_attention.py b/test/infiniop/bi_attention.py
new file mode 100644
index 000000000..1f1146366
--- /dev/null
+++ b/test/infiniop/bi_attention.py
@@ -0,0 +1,271 @@
+from ctypes import c_uint64
+import ctypes
+import sys
+import os
+import torch
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from libinfiniop import (
+    LIBINFINIOP,
+    TestTensor,
+    get_test_devices,
+    check_error,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
+    TestWorkspace,
+    InfiniDtype,
+    InfiniDtypeNames,
+    InfiniDeviceNames,
+    infiniopOperatorDescriptor_t,
+)
+
+
+def softmax(x):
+    dtype = x.dtype
+    return torch.nn.functional.softmax(x.to(torch.float32), dim=-1).to(dtype)
+
+
+def bi_attention(q, k, v, k_cache, v_cache, pos):
+    dtype = q.dtype
+
+    n_q_head = q.shape[0]
+    n_kv_head = k.shape[0]
+
+    # Concatenate key and value caches
+    k_cache = k_cache[:, :pos, :]  # (n_kv_head, pos, head_dim)
+    v_cache = v_cache[:, :pos, :]  # (n_kv_head, pos, head_dim)
+    k = torch.cat([k_cache, k], dim=1)  # (n_kv_head, total_seq_len, head_dim)
+    v = torch.cat([v_cache, v], dim=1)  # (n_kv_head, total_seq_len, head_dim)
+
+    total_seq_len = k.shape[1]
+
+    head_dim = v.shape[-1]
+
+    if n_q_head != n_kv_head:
+        q = q.reshape(
+            n_kv_head, -1, head_dim
+        )  # (n_kv_head, n_group * seq_len, head_dim)
+
+    # Scaled dot-product attention
+    attn_scores = (
+        torch.einsum("hqd,hkd->hqk", q.to(torch.float32), k.to(torch.float32))
+        .to(dtype)
+        .reshape(n_q_head, -1, total_seq_len)
+    )  # (n_q_head, seq_len, total_seq_len)
+    attn_scores = attn_scores / (head_dim**0.5)
+
+    attn_weights = softmax(attn_scores).reshape(
+        n_kv_head, -1, total_seq_len
+    )  # (n_kv_head, seq_len, total_seq_len)
+
+    # Weighted sum of values
+    attn_output = (
+        torch.einsum(
+            "hqk,hkd->hqd", attn_weights.to(torch.float32), v.to(torch.float32)
+        )
+        .to(dtype)
+        .reshape(n_q_head, -1, head_dim)
+        .permute(1, 0, 2)
+    )  # ([seq_len, n_q_head, head_dim])
+
+    return attn_output
+
+
+def test(
+    handle,
+    device,
+    n_q_head,
+    n_kv_head,
+    seq_len,
+    head_dim,
+    pos,
+    k_cache_buf_len,
+    v_cache_buf_len,
+    q_stride=None,
+    k_stride=None,
+    v_stride=None,
+    k_cache_stride=None,
+    v_cache_stride=None,
+    dtype=InfiniDtype.F16,
+    sync=None,
+):
+    print(
+        f"Testing BiAttention on {InfiniDeviceNames[device]} with n_q_head:{n_q_head} n_kv_head:{n_kv_head} seq_len:{seq_len} head_dim:{head_dim} pos:{pos} "
+        f"dtype:{InfiniDtypeNames[dtype]} q_stride:{q_stride} k_stride:{k_stride} v_stride:{v_stride} k_cache_stride:{k_cache_stride} v_cache_stride:{v_cache_stride}"
+    )
+
+    out = TestTensor([seq_len, n_q_head, head_dim], None, dtype, device, mode="zeros")
+    q = TestTensor([n_q_head, seq_len, head_dim], q_stride, dtype, device, scale=0.1)
+    k = TestTensor([n_kv_head, seq_len, head_dim], k_stride, dtype, device, scale=0.1)
+    v = TestTensor([n_kv_head, seq_len, head_dim], v_stride, dtype, device, scale=0.1)
+    k_cache = TestTensor(
+        [n_kv_head, k_cache_buf_len, head_dim], k_cache_stride, dtype, device, scale=0.1
+    )
+    v_cache = TestTensor(
+        [n_kv_head, v_cache_buf_len, head_dim], v_cache_stride, dtype, device, scale=0.1
+    )
+
+    def torch_bi_attention():
+        return bi_attention(
+            q.torch_tensor(),
+            k.torch_tensor(),
+            v.torch_tensor(),
+            k_cache.torch_tensor(),
+            v_cache.torch_tensor(),
+            pos,
+        )
+
+    ans = torch_bi_attention()
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateBiAttentionDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            out.descriptor,
+            q.descriptor,
+            k.descriptor,
+            v.descriptor,
+            k_cache.descriptor,
+            v_cache.descriptor,
+            pos,
+        )
+    )
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+    for tensor in [out, q, k, v, k_cache, v_cache]:
+        tensor.destroy_desc()
+
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetBiAttentionWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
+    )
+    workspace = TestWorkspace(workspace_size.value, out.device)
+
+    def lib_bi_attention():
+        check_error(
+            LIBINFINIOP.infiniopBiAttention(
+                descriptor,
+                workspace.data(),
+                workspace_size.value,
+                out.data(),
+                q.data(),
+                k.data(),
+                v.data(),
+                k_cache.data(),
+                v_cache.data(),
+                None,
+            )
+        )
+
+    lib_bi_attention()
+
+    # Validate results
+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+    if DEBUG:
+        debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
+    assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)
+
+    # Profiling workflow
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: torch_bi_attention(), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("   lib", lambda: lib_bi_attention(), device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+    check_error(LIBINFINIOP.infiniopDestroyBiAttentionDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]
+
+    # Tolerance map for different data types
+    _TOLERANCE_MAP = {
+        InfiniDtype.F16: {"atol": 1e-4, "rtol": 1e-2},
+        InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-3},
+    }
+
+    DEBUG = False
+    PROFILE = False
+    NUM_PRERUN = 10
+    NUM_ITERATIONS = 1000
+    test_cases = [
+        # prefill
+        (
+            32,    # n_q_head
+            4,     # n_kv_head
+            5,     # seq_len
+            64,    # head_dim
+            0,     # pos
+            2048,  # k_cache_buf_len
+            2048,  # v_cache_buf_len
+            [64, 2560, 1],   # q_stride
+            [64, 2560, 1],   # k_stride
+            [64, 2560, 1],   # v_stride
+            [64, 11264, 1],  # k_cache_stride
+            [64, 11264, 1],  # v_cache_stride
+        ),
+        # decode
+        (
+            32,    # n_q_head
+            4,     # n_kv_head
+            1,     # seq_len
+            64,    # head_dim
+            3,     # pos
+            2048,  # k_cache_buf_len
+            2048,  # v_cache_buf_len
+            [64, 2560, 1],   # q_stride
+            [64, 2560, 1],   # k_stride
+            [64, 2560, 1],   # v_stride
+            [64, 11264, 1],  # k_cache_stride
+            [64, 11264, 1],  # v_cache_stride
+        ),
+        # for test
+        (
+            8,     # n_q_head
+            4,     # n_kv_head
+            2,     # seq_len
+            16,    # head_dim
+            1,     # pos
+            8,     # k_cache_buf_len
+            8,     # v_cache_buf_len
+            None,  # q_stride
+            None,  # k_stride
+            None,  # v_stride
+            None,  # k_cache_stride
+            None,  # v_cache_stride
+        ),
+        (
+            28,    # n_q_head
+            28,    # n_kv_head
+            15,    # seq_len
+            128,   # head_dim
+            0,     # pos
+            2048,  # k_cache_buf_len
+            2048,  # v_cache_buf_len
+            [128, 10752, 1],  # q_stride
+            [128, 10752, 1],  # k_stride
+            [128, 10752, 1],  # v_stride
+            [128, 3584, 1],   # k_cache_stride
+            [128, 3584, 1],   # v_cache_stride
+        ),
+    ]
+    args = get_args()
+
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    # Execute tests
+    for device in get_test_devices(args):
+        test_operator(device, test, test_cases, _TENSOR_DTYPES)
+    print("\033[92mTest passed!\033[0m")
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index 5b2974111..3d1f8acaf 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -95,6 +95,45 @@ def attention_(lib):
         infiniopOperatorDescriptor_t,
     ]
 
+
+@OpRegister.operator
+def bi_attention_(lib):
+    lib.infiniopCreateBiAttentionDescriptor.restype = c_int32
+    lib.infiniopCreateBiAttentionDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_size_t,
+    ]
+
+    lib.infiniopGetBiAttentionWorkspaceSize.restype = c_int32
+    lib.infiniopGetBiAttentionWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopBiAttention.restype = c_int32
+    lib.infiniopBiAttention.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroyBiAttentionDescriptor.restype = c_int32
+    lib.infiniopDestroyBiAttentionDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
+
 
 @OpRegister.operator
 def causal_softmax_(lib):