From 01a27728e664377097d39b1ad9fc68cd6f570fc5 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Mon, 28 Apr 2025 18:53:04 +0800 Subject: [PATCH 01/24] Fix synchronization issues --- .gitignore | 1 + csrc/kernels/splitkv_mla.cu | 10 +++++++--- csrc/kernels/traits.h | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 982daef..9b500a0 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist/ *perf.csv *.png /.vscode +compile_commands.json diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/kernels/splitkv_mla.cu index ff29305..5e1fded 100644 --- a/csrc/kernels/splitkv_mla.cu +++ b/csrc/kernels/splitkv_mla.cu @@ -1017,13 +1017,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params cudaGridDependencySynchronize(); int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + // We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race. 
+ int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); int begin_idx = tile_scheduler_metadata.x; int begin_seqlen = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; int end_seqlen = tile_scheduler_metadata.w; if (begin_idx >= params.b) return; - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); // Copy the first Q launch_q_copy(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q); @@ -1123,6 +1124,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params // Issue P0 = Q @ K0^T, wait warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); + // We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0 + NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized); cute::warpgroup_wait<0>(); #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \ @@ -1238,7 +1241,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params cute::tma_store_wait<0>(); } else { - int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + // Don't use __ldg because of PDL and instruction reordering + int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx; float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< diff --git a/csrc/kernels/traits.h b/csrc/kernels/traits.h index 31c1388..5f915a6 100644 --- a/csrc/kernels/traits.h +++ b/csrc/kernels/traits.h @@ -102,5 +102,6 @@ enum NamedBarriers : int { sScale0Ready = 0, sScale1Ready = 1, sP0Ready = 2, 
- rO1sP0sV0RIssued = 3 + rO1sP0sV0RIssued = 3, + sMInitialized = 4, }; From 9c5dfab6d1746b4a27af14f440e7afd5c01ece68 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 29 Apr 2025 12:02:57 +0800 Subject: [PATCH 02/24] update to cutlass 3.9 --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index afa1772..e94e888 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit e94e888df3551224738bfa505787b515eae8352f From 9edee0c022cd0938148a18e334203b0aab43aa19 Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 29 Apr 2025 12:03:15 +0800 Subject: [PATCH 03/24] update .gitignore --- .gitignore | 1 + setup.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 9b500a0..4535280 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.png /.vscode compile_commands.json +.cache diff --git a/setup.py b/setup.py index 131ceff..217f540 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,12 @@ IS_WINDOWS, ) + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return nvcc_extra_args + ["--threads", nvcc_threads] + def get_features_args(): features_args = [] DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"] From 41b611f7d7561790a2f5040ff89212e08c7b0011 Mon Sep 17 00:00:00 2001 From: Zeyu WANG Date: Fri, 1 Aug 2025 17:21:27 +0800 Subject: [PATCH 04/24] Add more GPU architctures support (#76) * Add more GPU architctures support * Merge fmha and mla runner * add varlen & non varlen support, and add incontiguous tensor support * update readme * add varlen api --------- Co-authored-by: dianzhangc --- README.md | 12 +- csrc/sm100/collective/fmha_common.hpp | 127 ++ csrc/sm100/collective/fmha_fusion.hpp | 396 ++++ ..._fmha_fwd_epilogue_tma_warpspecialized.hpp | 234 +++ ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 1218 
+++++++++++ .../sm100_fmha_load_tma_warpspecialized.hpp | 316 +++ ...a_mla_fwd_mainloop_tma_warpspecialized.hpp | 1225 +++++++++++ ...m100_fmha_mla_load_tma_warpspecialized.hpp | 340 +++ csrc/sm100/common/gather_tensor.hpp | 215 ++ csrc/sm100/common/helper.h | 72 + csrc/sm100/common/mask.cuh | 8 + csrc/sm100/common/pipeline_mla.hpp | 250 +++ csrc/sm100/common/pow_2.hpp | 92 + csrc/sm100/common/utils.hpp | 83 + csrc/sm100/device/fmha.hpp | 276 +++ csrc/sm100/device/fmha_device_bwd.hpp | 340 +++ csrc/sm100/fmha_cutlass_bwd_sm100.cu | 83 + csrc/sm100/fmha_cutlass_bwd_sm100.cuh | 200 ++ csrc/sm100/fmha_cutlass_fwd_sm100.cu | 81 + csrc/sm100/fmha_cutlass_fwd_sm100.cuh | 334 +++ .../kernel/fmha_causal_tile_scheduler.hpp | 197 ++ csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp | 153 ++ csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp | 161 ++ csrc/sm100/kernel/fmha_options.hpp | 85 + csrc/sm100/kernel/fmha_tile_scheduler.hpp | 162 ++ ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 1841 +++++++++++++++++ ...mha_bwd_mla_kernel_tma_warpspecialized.hpp | 1834 ++++++++++++++++ ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 619 ++++++ csrc/sm100/pybind.cu | 17 + csrc/{ => sm90}/flash_api.cpp | 0 csrc/{ => sm90}/kernels/config.h | 0 csrc/{ => sm90}/kernels/get_mla_metadata.cu | 0 csrc/{ => sm90}/kernels/get_mla_metadata.h | 0 csrc/{ => sm90}/kernels/mla_combine.cu | 0 csrc/{ => sm90}/kernels/mla_combine.h | 0 csrc/{ => sm90}/kernels/params.h | 0 csrc/{ => sm90}/kernels/splitkv_mla.cu | 0 csrc/{ => sm90}/kernels/splitkv_mla.h | 0 csrc/{ => sm90}/kernels/traits.h | 0 csrc/{ => sm90}/kernels/utils.h | 0 flash_mla/__init__.py | 3 + flash_mla/flash_mla_interface.py | 271 ++- setup.py | 61 +- ...st_flash_mla.py => test_flash_mla_sm90.py} | 0 tests/test_fmha_sm100.py | 199 ++ 45 files changed, 11489 insertions(+), 16 deletions(-) create mode 100644 csrc/sm100/collective/fmha_common.hpp create mode 100644 csrc/sm100/collective/fmha_fusion.hpp create mode 100644 
csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp create mode 100644 csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp create mode 100644 csrc/sm100/common/gather_tensor.hpp create mode 100644 csrc/sm100/common/helper.h create mode 100644 csrc/sm100/common/mask.cuh create mode 100644 csrc/sm100/common/pipeline_mla.hpp create mode 100644 csrc/sm100/common/pow_2.hpp create mode 100644 csrc/sm100/common/utils.hpp create mode 100644 csrc/sm100/device/fmha.hpp create mode 100644 csrc/sm100/device/fmha_device_bwd.hpp create mode 100644 csrc/sm100/fmha_cutlass_bwd_sm100.cu create mode 100644 csrc/sm100/fmha_cutlass_bwd_sm100.cuh create mode 100644 csrc/sm100/fmha_cutlass_fwd_sm100.cu create mode 100644 csrc/sm100/fmha_cutlass_fwd_sm100.cuh create mode 100644 csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp create mode 100644 csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp create mode 100644 csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp create mode 100644 csrc/sm100/kernel/fmha_options.hpp create mode 100644 csrc/sm100/kernel/fmha_tile_scheduler.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp create mode 100644 csrc/sm100/pybind.cu rename csrc/{ => sm90}/flash_api.cpp (100%) rename csrc/{ => sm90}/kernels/config.h (100%) rename csrc/{ => sm90}/kernels/get_mla_metadata.cu (100%) rename csrc/{ => sm90}/kernels/get_mla_metadata.h (100%) rename csrc/{ => sm90}/kernels/mla_combine.cu (100%) rename csrc/{ => sm90}/kernels/mla_combine.h (100%) rename csrc/{ => sm90}/kernels/params.h 
(100%) rename csrc/{ => sm90}/kernels/splitkv_mla.cu (100%) rename csrc/{ => sm90}/kernels/splitkv_mla.h (100%) rename csrc/{ => sm90}/kernels/traits.h (100%) rename csrc/{ => sm90}/kernels/utils.h (100%) rename tests/{test_flash_mla.py => test_flash_mla_sm90.py} (100%) create mode 100644 tests/test_fmha_sm100.py diff --git a/README.md b/README.md index 5d66f55..07e021a 100644 --- a/README.md +++ b/README.md @@ -28,13 +28,21 @@ Currently released: ### Install ```bash -python setup.py install +pip install -v . ``` ### Benchmark +#### Testing MLA Decoding + +```bash +python tests/test_flash_mla_sm90.py +``` + +#### Testing MLA Forward/Backward + ```bash -python tests/test_flash_mla.py +python tests/test_fmha_sm100.py ``` It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. diff --git a/csrc/sm100/collective/fmha_common.hpp b/csrc/sm100/collective/fmha_common.hpp new file mode 100644 index 0000000..c60d9e9 --- /dev/null +++ b/csrc/sm100/collective/fmha_common.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = decltype(atom.accumulate_)::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template 
+CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, prepend(make_layout(stages), _)); +} + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} + +template +CUTLASS_DEVICE +void warpgroup_reg_set() { + if constexpr (RegCount < 128) { + cutlass::arch::warpgroup_reg_dealloc(); + } + else { + cutlass::arch::warpgroup_reg_alloc(); + } +} + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/fmha_fusion.hpp b/csrc/sm100/collective/fmha_fusion.hpp new file mode 100644 index 0000000..1486767 --- /dev/null +++ b/csrc/sm100/collective/fmha_fusion.hpp @@ -0,0 +1,396 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct NoMask { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return ceil_div(get<1>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& 
problem_size) { + + return; + } +}; + +struct ResidualMask : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful if seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. 
+ CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<1>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct ResidualMaskForBackward : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful if seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (! 
elem_less(pos, select<0,1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +// There are two ways to do causal if N_Q != N_K +// (1) The Q is at the beginning of the matrix +// (2) The Q is at the end of the matrix +template +struct CausalMask : NoMask { + + using Base = NoMask; + + static constexpr bool IsQBegin = kIsQBegin; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } else { + const int offset_q = get<1>(problem_size) - get<0>(problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); + } else { + const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); + return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); + } + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + 
AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is the default setting. + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to set kIsQBegin=false + + if constexpr (IsQBegin) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } else { + const auto offset_q = get<1>(problem_size) - get<0>(problem_size); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } + } +}; + +template +struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { + + using Base = CausalMask; + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + int offset_q = 0; + if constexpr (!kIsQBegin) { + offset_q = get<1>(problem_size) - get<0>(problem_size); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size); + if 
(masked) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +struct VariableLength { + int max_length; + int* cumulative_length = nullptr; + int total_length = -1; + + CUTE_HOST_DEVICE operator int() const { + return max_length; + } +}; + +template struct is_variable_length_impl : std::false_type {}; +template<> struct is_variable_length_impl : std::true_type {}; +template constexpr bool is_variable_length_v = is_variable_length_impl>::value; + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Idx const& idx) { + return transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { + auto new_shape = apply_variable_length(shape, idx); + auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { + if constexpr (is_variable_length_v) { + return cute::make_tuple(c, s.cumulative_length[idx]); + } + else { + return c; + } + }); + return cute::make_tuple(new_shape, new_coord); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length_offset(Shape const& shape, Coord const& coord) { + auto idx = back(back(coord)); + auto result_shape = transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); + auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx]; + } + else { + return _0{}; + } + }); + return cute::make_tuple(result_shape, result_offset); +} + +} // namespace cutlass::fmha::collective + +namespace cute { + +template<> +struct is_integral : true_type {}; + +CUTE_HOST_DEVICE +void print(cutlass::fmha::collective::VariableLength a) { 
+ printf("Varlen<%d, %p>", a.max_length, a.cumulative_length); +} + +} diff --git a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000..616357c --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element, + class ElementAcc, + class TileShape, // Q, D, _ + class StrideO, // Q, D, B + class StrideLSE_, // Q, B + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdEpilogueTmaWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + +// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{}))); + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>()); +// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); + using SmemLayoutO_ = SmemLayoutO; + using StrideLSE = StrideLSE_; + using ElementOut = Element; + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_O; + StrideO 
dO; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + using TMA_O = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}), + SmemLayoutO{}(_,_,_0{}) + )); + + + struct Params { + TMA_O tma_store_o; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + // FMHA and MLA have different input ProblemShapes; + // get problem_shape_O according to the input ProblemShape. + template + CUTLASS_DEVICE static constexpr + auto get_problem_shape_O ( + ProblemShape const& problem_shape) { + if constexpr (rank_v(ProblemShape{}))> == 2) { + return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape)); + } else { + return select<0,2,3>(problem_shape); + } + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + + auto ptr_O = args.ptr_O; + StrideO dO = args.dO; + + auto problem_shape_O = get_problem_shape_O(problem_shape); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence length, the batch is in units of row_stride + get<2,1>(dO) = get<0>(dO); + get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); + // offset ptr by the amount we add back in later + ptr_O -= max_length_q * get<0>(dO); + } + } + + auto tma_store_o = make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(ptr_O, problem_shape_O, dO), + SmemLayoutO{}(_,_,_0{}) + ); + + return { + tma_store_o, + args.ptr_LSE, + args.dLSE + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {} + + template + CUTLASS_DEVICE auto + store( + 
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + + BlkCoord blk_coord = blk_coord_in; + uint32_t lane_predicate = cute::elect_one_sync(); + + using X = Underscore; + + int o0_index = 2 * get<0>(blk_coord); + int o1_index = 2 * get<0>(blk_coord) + 1; + + Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape)); + // offset mode 0 by (max_length - real_length) + // offset mode 3,1 by cumulative_length + real_length + // the ptr is already offset by - max_length + // so in total this achieves + int offs_0 = 0; + int offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + offs_0 = max_length_q - get<0>(problem_shape); + offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape); + get<2,1>(blk_coord) = 0; + } + } + + Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p); + + Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord)); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto block_tma = params.tma_store_o.get_slice(0); + Tensor tOsO = block_tma.partition_S(sO); + Tensor tOgO = block_tma.partition_D(gO); + + auto pipeline_release_state = pipeline_consumer_state; + + // O1 O2 + // one pipeline: O + // wait from corr, issue tma store on smem + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index)); + } + tma_store_arrive(); + + 
pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index)); + } + tma_store_arrive(); + + tma_store_wait<1>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + tma_store_wait<0>(); + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000..f39fd75 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1218 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class TileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and refers to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + // Since shared memory is sufficient for FMHA, there is no need to reuse shared memory. 
+ class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using TileShape = TileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountKV = sizeof(Element_) == 1 ? 4 : 3; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + // Reuse shared memory for V and O. 
+ static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaUmmaAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + + static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts 
must be of equal size"); + + using Load = Sm100FmhaLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + 
storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; 
+ tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + 
pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, 
tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, 
tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * V2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = 
::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if 
constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + 
pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, 
pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = 
thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor 
cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + 
CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + 
uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + 
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + 
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), 
epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp 
b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp new file mode 100644 index 0000000..1951056 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -0,0 +1,316 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape +> +struct Sm100FmhaLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + auto problem_shape_qk = problem_shape; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence length, the batch is in units of row_stride + get<2,1>(dQ) = get<0>(dQ); + get<3,1>(problem_shape_qk) = 
std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_Q -= max_length_q * get<0>(dQ); + } + } + + if constexpr (is_variable_length_v>) { + auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; + if (cumulative_length_kv != nullptr) { + int max_length_kv = get<1>(problem_shape).max_length; + // for variable sequence length, the batch is in units of row_stride + get<2,1>(dK) = get<0>(dK); + get<2,1>(dV) = get<0>(dV); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_K -= max_length_kv * get<0>(dK); + ptr_V -= max_length_kv * get<0>(dV); + } + } + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& 
pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); + + int q_offs_0 = 0; + int q_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + q_offs_0 = max_length_q - get<0>(problem_shape); + q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); + + int kv_offs_0 = 0; + int kv_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + int max_length = get<1>(params_problem_shape).max_length; + kv_offs_0 = 
max_length - get<1>(problem_shape); + kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord is decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + 
++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // V1 + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp 
b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000..bf41af9 --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1225 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" +#include "common/pipeline_mla.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class ComposedTileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and refers to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ComposedTileShape = ComposedTileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountK = 1; + static constexpr int StageCountV = 1; + static constexpr int StageCountKV = StageCountK + StageCountV; + // Support StageCountKV > 2 in the future. 
+ static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!"); + static_assert(std::is_same_v>, "Only support ThreadShape = Shape<_2, _1, _1>"); + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + static constexpr auto HeadDimLatent = size<2, 0>(ComposedTileShape{}); + static constexpr auto HeadDimRope = size<2, 1>(ComposedTileShape{}); + static constexpr auto HeadDimQK = HeadDimLatent + HeadDimRope; + static constexpr auto HeadDimPV = HeadDimLatent; + + using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{})); + using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{}))); + using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent)); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename 
CollectiveMmaPV::SmemLayoutB{}, Int{})); + + using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{}))); + + // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, + // we reuse shared memory for V and O to address this problem, + // and a barrier has been added to coordinate access to shared memory. + static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_o; // use as O0 + cute::array_aligned> smem_v; // use as V0 and O1 + }; + + struct TensorStorageQKV { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + + using TensorStorage = std::conditional_t; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaAsyncMla< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using 
PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + using Load = Sm100MlaFwdLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o 
+ }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), 
SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr 
(get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + 
pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + 
pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * V2 , ... + } + + // Runs one softmax iteration on the S tile of `stage`: waits on pipeline_s, loads S from + // tensor memory, applies the mask when need_apply_mask is set, and folds the tile into row_max. + // The old and new row max are stored to the V slot (V0/V1) and committed on pipeline_c; + // the exp2f-scaled values are converted and stored to the P slot (P0/P1) before pipeline_s is released. + // row_sum is rescaled and accumulated afterwards; on final_call pipeline_s is waited on again and + // the final row max and row sum are stored to the V slot. + template + CUTLASS_DEVICE auto + softmax_step( + bool need_apply_mask, + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ?
TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + 
copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_mask) { + if(need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + constexpr int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + 
tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in 
= make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + const int mask_trip_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + const int total_trip_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int trip_idx = total_trip_count; + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + constexpr bool NeedMask = !std::is_same_v; + + CUTLASS_PRAGMA_NO_UNROLL + for (; trip_idx > 0; trip_idx -= 1) { + softmax_step( + trip_idx <= mask_trip_count, + row_max, row_sum, stage, + trip_idx == 
1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + 
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 
4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count 
- 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + 
Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read the row-wise old and new max from V0 + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max)) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // read the row-wise old and new max from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + +
pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + 
pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor 
gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git 
a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp new file mode 100644 index 0000000..c2d3e2b --- /dev/null +++ b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename 
CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dQ) = get<0>(dQ); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_Q -= max_length_q * get<0>(dQ); + } + } + + if constexpr (is_variable_length_v>) { + auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; + if (cumulative_length_kv != nullptr) { + int max_length_kv = get<1>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dK) = get<0>(dK); + get<2,1>(dV) = get<0>(dV); + get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); + // offset ptr by the amount we add back in later + ptr_K -= max_length_kv * get<0>(dK); + ptr_V -= max_length_kv * get<0>(dV); + } + } + + auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, 
/*workspace=*/ nullptr); + + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape)); + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... 
+ // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); + + int q_offs_0 = 0; + int q_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + q_offs_0 = max_length_q - get<0>(problem_shape); + q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); + + int kv_offs_0 = 0; + int kv_offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + int max_length = get<1>(params_problem_shape).max_length; + kv_offs_0 = max_length - get<1>(problem_shape); + kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), 
Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 
2)); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + // V1 + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch vi + cute::prefetch(params.tma_load_v, tVgV(_, k_index)); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch ki+1 + if(mask_tile_count > 1) { + cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1)); + } + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git 
a/csrc/sm100/common/gather_tensor.hpp b/csrc/sm100/common/gather_tensor.hpp new file mode 100644 index 0000000..46fb640 --- /dev/null +++ b/csrc/sm100/common/gather_tensor.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +namespace example { + +using namespace cute; + +// Empty type used to disable gather/scatter for a GEMM argument +struct NoGather +{ + template + NoGather(Ts...) {}; +}; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr + IndexedGather(Index const *indices = {}): indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr + Index + operator()(I i) const { return indices_[i]; } + + CUTE_HOST_DEVICE friend + void + print(IndexedGather const &s) { + cute::print("Indexed"); + } + + Index const *indices_; +}; + +/// Function object that applies a stride to its argument +/// Example: StridedFunc gathers every other row/column +template +struct StridedGather +{ + CUTE_HOST_DEVICE constexpr + StridedGather(Stride stride = {}): stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(I i) const { return i * stride_; } + + CUTE_HOST_DEVICE friend + void + print(StridedGather const &s) { + cute::print("Strided{"); + print(s.stride_); + cute::print("}"); + } + + Stride stride_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr + CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; } + + CUTE_HOST_DEVICE friend + void + print(CustomStride const & s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE 
constexpr friend + auto + safe_div(CustomStride const &s, Div const &div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend + auto + make_layout(Shape const &shape, CustomStride const &stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE +auto +make_custom_stride_layout(Stride const &stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit stride with a custom gather stride + auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout(repeat_like(stride, _1{}), + replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE +auto +make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) +{ + if constexpr (not cutlass::platform::is_same, NoGather>::value) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); + } else { + return make_tensor(iter, shape, stride); + } +} + +} // namespace example + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + 
CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Offset,Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace cute diff --git a/csrc/sm100/common/helper.h b/csrc/sm100/common/helper.h new file mode 100644 index 0000000..e957c4e --- /dev/null +++ b/csrc/sm100/common/helper.h @@ -0,0 +1,72 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission.
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once + + #include "cuda_runtime.h" + #include + + /** + * Panic wrapper for unwinding CUTLASS errors + */ + #define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + + /** + * Panic wrapper for unwinding CUDA runtime errors + */ + #define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +#define FLASH_MLA_ASSERT(cond) \ +do { \ + if (!(cond)) { \ + std::cerr << "FLASH_MLA_ASSERT: " << #cond << " failed at " << __FILE__ << ":" << __LINE__ << std::endl; \ + std::abort(); \ + } \ +} while (0) + + \ No newline at end of file diff --git a/csrc/sm100/common/mask.cuh b/csrc/sm100/common/mask.cuh new file mode 100644 index 0000000..d118aab --- /dev/null +++ 
b/csrc/sm100/common/mask.cuh @@ -0,0 +1,8 @@ +#pragma once + +enum class MaskMode { + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask +}; + diff --git a/csrc/sm100/common/pipeline_mla.hpp b/csrc/sm100/common/pipeline_mla.hpp new file mode 100644 index 0000000..5bbeed9 --- /dev/null +++ b/csrc/sm100/common/pipeline_mla.hpp @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Support the producer to acquire specific bytes of data. +*/ + +#pragma once + +#include "cutlass/pipeline/sm100_pipeline.hpp" + +namespace cutlass { + +using namespace cute; + +template < + int Stages_, + class ClusterShape = Shape, + class AtomThrShape_MNK_ = Shape<_1,_1,_1> +> +class PipelineTmaAsyncMla { + +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; + +private: + using Impl = PipelineTmaUmmaAsync; + +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + + using McastDirection = McastDirection; + + // Helper function to initialize barriers + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + auto atom_thr_shape = AtomThrShape_MNK{}; + 
uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { + auto atom_thr_shape = AtomThrShape_MNK{}; + + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? + cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas + cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + auto cluster_layout = make_layout(cluster_shape); + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { + // Calculate consumer mask + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + auto cluster_layout = make_layout(cluster_shape); + if (mcast_direction == McastDirection::kRow) { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, 
block_id_in_cluster); + } + else { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + +public: + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape, mcast_direction); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape, mcast_direction); + } + } + + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + CUTLASS_DEVICE + void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); + if (barrier_token != BarrierStatus::WaitDone) { + empty_barrier_ptr_[stage].wait(phase); + } + + if (params_.is_leader) { + 
full_barrier_ptr_[stage].arrive_and_expect_tx(bytes); + } + #ifndef NDEBUG + if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { + asm volatile ("brkpt;\n" ::); + } + + // Most likely you have elected more than one leader + if (params_.is_leader && (threadIdx.x % 32 != 0)) { + asm volatile ("brkpt;\n" ::); + } + #endif + } + + CUTLASS_DEVICE + void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index(), false); + } + +private: + Impl impl_; + Params params_; + EmptyBarrier *empty_barrier_ptr_; + FullBarrier *full_barrier_ptr_; + uint16_t block_id_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. 
+ CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1 + if (!skip) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_); + } + } + else { + if (!skip) { + if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { + cutlass::arch::umma_arrive(smem_ptr); + } + else { + cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); + } + } + } + } +}; + +} diff --git a/csrc/sm100/common/pow_2.hpp b/csrc/sm100/common/pow_2.hpp new file mode 100644 index 0000000..eca9325 --- /dev/null +++ b/csrc/sm100/common/pow_2.hpp @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace cutlass::fmha { + +struct Pow2 { + int n; + int log2_n; + + explicit CUTE_DEVICE Pow2(int n) : n(n) { +#ifdef __CUDA_ARCH__ + log2_n = __ffs(n) - 1; +#endif + } + + template + CUTE_HOST_DEVICE T operator *(T const& b) const { + return n * b; + } + + template + CUTE_HOST_DEVICE auto operator *(Int const&) const { + if constexpr (N & (N - 1) == 0) { + return Pow2{n * N}; + } + return n * N; + } + +}; + +template +CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) { + return a >> b.log2_n; +} + +template +CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) { + return a & (b.n - 1); +} + +template +CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) { + return a < b.n; +} + +CUTE_HOST_DEVICE void print(Pow2 const& a) { + printf("2^%d", a.log2_n); +} + +} // end namespace cutlass::fmha + +namespace cute { + +template <> +struct is_integral : true_type {}; + +} // end namespace cute diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/common/utils.hpp new file mode 100644 index 0000000..f43770d --- /dev/null +++ b/csrc/sm100/common/utils.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include "cutlass/numeric_types.h" +#include "helper.h" + +template +struct cutlass_dtype { + using type = T; +}; + +template <> +struct cutlass_dtype { + using type = 
cutlass::half_t; +}; + +template <> +struct cutlass_dtype { + using type = cutlass::bfloat16_t; +}; + +template <> +struct cutlass_dtype<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; +}; + +template <> +struct cutlass_dtype<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; +}; + +template +using cutlass_dtype_t = typename cutlass_dtype::type; + +template +struct DeviceAllocation { + T* ptr_ = nullptr; + size_t offset_ = 0; + size_t size_ = 0; + + DeviceAllocation(DeviceAllocation const&) = delete; + DeviceAllocation& operator=(DeviceAllocation const&) = delete; + + DeviceAllocation() = default; + DeviceAllocation(size_t size) { reset(size); } + ~DeviceAllocation() { reset(); } + + void reset(size_t size, size_t offset=0) { + reset(); + auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset)); + assert(ret == cudaSuccess); + size_ = size; + offset_ = offset; + } + + T* get() { + return ptr_ + offset_; + } + + const T* get() const { + return ptr_ + offset_; + } + + void reset() { + if (ptr_ != nullptr) { + auto ret = cudaFree(ptr_); + assert(ret == cudaSuccess); + } + } + + size_t size() const { return size_; } + + size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } + + void copy_from_host(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } + + void copy_from_device(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } +}; \ No newline at end of file diff --git a/csrc/sm100/device/fmha.hpp b/csrc/sm100/device/fmha.hpp new file mode 100644 index 0000000..f8406d3 --- /dev/null +++ b/csrc/sm100/device/fmha.hpp @@ -0,0 +1,276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. 
+*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FMHA { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. 
+ Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. 
+  Status
+  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+    return run(args, workspace, stream);
+  }
+
+  /// Re-launches the kernel with the params struct already stored in this object.
+  Status
+  run(cudaStream_t stream = nullptr) {
+    return run(params_, stream);
+  }
+
+  /// Call-operator form of run(stream): re-launches with the stored params struct.
+  Status
+  operator()(cudaStream_t stream = nullptr) {
+    return run(params_, stream);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::fmha::device
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/sm100/device/fmha_device_bwd.hpp b/csrc/sm100/device/fmha_device_bwd.hpp
new file mode 100644
index 0000000..d2463ac
--- /dev/null
+++ b/csrc/sm100/device/fmha_device_bwd.hpp
@@ -0,0 +1,340 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/tensor.hpp" + +#include "../device/fmha.hpp" +#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class Element, + class ElementAccumulator, + class TileShape, + bool IsMla, + class Mask +> +class Sm100FmhaBwd { +public: + /// Argument structure: User API + struct Arguments { + // Q K D D_VO HB + ProblemShape problem_shape; + + const Element* ptr_Q; + cute::tuple> stride_Q; + 
const Element* ptr_K; + cute::tuple> stride_K; + const Element* ptr_V; + cute::tuple> stride_V; + + const Element* ptr_O; + cute::tuple> stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple> stride_LSE; + + const Element* ptr_dO; + cute::tuple> stride_dO; + + Element* ptr_dQ; + cute::tuple> stride_dQ; + Element* ptr_dK; + cute::tuple> stride_dK; + Element* ptr_dV; + cute::tuple> stride_dV; + + ElementAccumulator softmax_scale; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdSumOdO + >; + using OperationConvert = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdConvert + >; + + using OperationMha= cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using OperationMla = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using Operation = std::conditional_t; + + using Kernel = typename Operation::Kernel; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments( + Arguments const& args, + ElementAccumulator* sum_odo = nullptr, + ElementAccumulator* scaled_lse = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_sum_OdO = make_stride(_1{}, make_stride(Q, Q*H)); + auto stride_scaled_lse = make_stride(_1{}, make_stride(Q, Q*H)); + auto log2_e = log2f(expf(1.0f)); + return typename OperationSumOdO::Arguments { + args.problem_shape, + args.ptr_O, args.stride_O, + 
args.ptr_dO, args.stride_dO, + sum_odo, stride_sum_OdO, + args.ptr_LSE, args.stride_LSE, + scaled_lse, stride_scaled_lse, + -1.0f, -log2_e + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_src_dQ = make_stride(D, _1{}, make_stride(D*Q, D*Q*H)); + return typename OperationConvert::Arguments { + args.problem_shape, + src, stride_src_dQ, + nullptr, stride_src_dQ, + nullptr, stride_src_dQ, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.softmax_scale + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple> const& stride_sum_OdO = {}, + ElementAccumulator* scaled_lse = nullptr, cute::tuple> const& stride_scaled_lse = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple> const& stride_dQ = {}) { + + return typename Operation::Arguments{ + args.problem_shape, + { args.ptr_Q, args.stride_Q, + args.ptr_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + scaled_lse, stride_scaled_lse, + sum_OdO, stride_sum_OdO, + dQ_acc, stride_dQ, + args.softmax_scale }, + { args.ptr_dK, args.stride_dK, + args.ptr_dV, args.stride_dV }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. 
+  static Status
+  can_implement(Arguments const& args) {
+    Status status = Status::kSuccess;
+
+    status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args));
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    status = OperationConvert::can_implement(to_convert_arguments(args));
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    status = Operation::can_implement(to_bwd_arguments(args));
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    return status;
+  }
+
+  /// Gets the workspace size in bytes: two B*H*Q vectors (sum of O*dO and the
+  /// scaled LSE) followed by one B*H*Q*D accumulator for dQ, with Q and D
+  /// rounded up to a multiple of 8.
+  static size_t
+  get_workspace_size(Arguments const& args) {
+    auto [Q_, K, D, D_VO, HB] = args.problem_shape;
+    auto [H, B] = HB;
+    D = cutlass::round_up(D, 8);  // Alignment
+    int Q = cutlass::round_up(static_cast<int>(Q_), 8);  // Alignment
+    // Widen before multiplying. If the extents are 32-bit ints, the product
+    // B*H*Q*D would otherwise be formed in int and could overflow before it
+    // is converted to size_t.
+    size_t const num_rows = static_cast<size_t>(B) * static_cast<size_t>(H) * static_cast<size_t>(Q);
+    size_t workspace_bytes = 0;
+    // OdO vector
+    workspace_bytes += num_rows * sizeof(ElementAccumulator);
+    // scaled LSE vector
+    workspace_bytes += num_rows * sizeof(ElementAccumulator);
+    // FP32 versions of outputs that are churned (start off with Q only)
+    workspace_bytes += num_rows * static_cast<size_t>(D) * sizeof(ElementAccumulator);
+    return workspace_bytes;
+  }
+
+  /// Initializes state from arguments, using three caller-provided buffers:
+  /// the dQ accumulator, the sum(O*dO) vector and the scaled LSE vector.
+  /// Returns the first non-success status reported by any of the three
+  /// sub-operations, or kSuccess when all of them initialize.
+  Status
+  initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) {
+    CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ="
+      << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null"));
+
+    auto [Q_, K, D, D_VO, HB] = args.problem_shape;
+    auto [H, B] = HB;
+    D = cutlass::round_up(D, 8);  // Alignment
+    int Q = cutlass::round_up(static_cast<int>(Q_), 8);  // Alignment
+    ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_sum_OdO);
+    ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
+    ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
+    params_.dQ_acc = dQ_acc;
+    // Byte count of the dQ accumulator; run() zero-fills exactly this many bytes.
+    // Computed in size_t for the same overflow reason as get_workspace_size().
+    params_.dQ_acc_size = static_cast<size_t>(B) * static_cast<size_t>(H) * static_cast<size_t>(Q)
+      * static_cast<size_t>(D) * sizeof(ElementAccumulator);
+    auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
+    auto args_convert = to_convert_arguments(args, dQ_acc);
+
+    // Propagate failures instead of reporting success unconditionally.
+    Status status = params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
+    if (status != Status::kSuccess) {
+      return status;
+    }
+    status = params_.op_convert.initialize(args_convert, nullptr, stream);
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    auto args_bwd = to_bwd_arguments(
+      args, sum_OdO, args_sum_OdO.stride_sum_OdO,
+      scaled_lse, args_sum_OdO.stride_scaled_lse,
+      dQ_acc, args_convert.stride_src_dQ
+    );
+    return params_.op.initialize(args_bwd, nullptr, stream);
+  }
+
+  /// Initializes state from arguments.
+  Status
+  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+    CUTLASS_TRACE_HOST("Universal::initialize() - workspace "
+      << workspace << ", stream: " << (stream ?
"non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. 
+ Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cu b/csrc/sm100/fmha_cutlass_bwd_sm100.cu new file mode 100644 index 0000000..4ff745d --- /dev/null +++ b/csrc/sm100/fmha_cutlass_bwd_sm100.cu @@ -0,0 +1,83 @@ +#include +#include +#include +#include +#include +#include "common/mask.cuh" +#include "common/utils.hpp" + +#include "fmha_cutlass_bwd_sm100.cuh" + +template +void call_run_fmha_bwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, + [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, [[maybe_unused]] Mla mla, + at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int max_seqlen_q, int total_seqlen_kv) { + static constexpr bool IsVarlen = std::is_same_v; + static constexpr bool IsMla = std::is_same_v; + using TileShape = std::conditional_t, Shape<_128, _128, _128, _128>>; + run_fmha_bwd(workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, total_seqlen_kv); +} + + +void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor 
cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen) { + + const c10::cuda::OptionalCUDAGuard device_guard(q.device()); + + int head_dim_qk = q.size(-1); + int head_dim_vo = v.size(-1); + MaskMode mask_mode = static_cast(mask_mode_code); + auto scalar_type_in = q.scalar_type(); + auto scalar_type_out = o.scalar_type(); + + if(scalar_type_in == at::ScalarType::BFloat16 && scalar_type_out == at::ScalarType::BFloat16) { + using Element = cutlass::bfloat16_t; + using ElementOut = cutlass::bfloat16_t; + + auto apply_config = [&](auto fn) { + if (mask_mode == MaskMode::kCausal) { + if(is_varlen) { + fn(CausalForBackwardMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(CausalForBackwardMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + else { + if(is_varlen) { + fn(ResidualMaskForBackward{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(ResidualMaskForBackward{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + }; + + apply_config([&](auto mask, auto varlen, auto in, auto out) { + if (head_dim_qk == 192 && head_dim_vo == 128) { + call_run_fmha_bwd(mask, varlen, in, out, true_type{}, workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, max_seqlen_kv); + } else if (head_dim_qk == 128 && head_dim_vo == 128) { + call_run_fmha_bwd(mask, varlen, in, out, false_type{}, workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, max_seqlen_kv); } + else { + std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk << " head_dim_vo=" << head_dim_vo << std::endl; + } + }); + + } else { + FLASH_MLA_ASSERT(false); + } +} diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh new file mode 100644 
index 0000000..2b19be2 --- /dev/null +++ b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh @@ -0,0 +1,200 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +#include "common/utils.hpp" +#include "collective/fmha_fusion.hpp" +#include "device/fmha_device_bwd.hpp" + +using namespace cute; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha; +using namespace cutlass; + + +/* BwdRunner: host-side launcher that packs torch tensors into cutlass::fmha::device::Sm100FmhaBwd arguments and runs the op on the current CUDA stream. NOTE(review): several template argument lists in this text look stripped (for example the cute::tuple aliases and the static_cast targets); restore them from the upstream file before compiling. */ template< + class DType, + bool kIsVarlen, + bool kIsMla, + class TileShape, + class ActiveMask +> +struct BwdRunner { + + using Element = DType; + using ElementAccumulator = float; + + // Q K D D_VO (H B) + using ProblemShape = std::conditional_t< + kIsVarlen, + cute::tuple>, + cute::tuple> + >; + + using Operation = cutlass::fmha::device::Sm100FmhaBwd; + + using TensorStride = Stride>; + using StrideQ = TensorStride; // Seq DQK (H B) + using StrideK = TensorStride; // Seq DQK (H B) + using StrideV = TensorStride; // Seq DVO (H B) + using StrideO = TensorStride; // Seq DVO (H B) + using StrideLSE = Stride<_1, Stride>; // Seq (H B) + + // Backwards specific + using StrideDQ = TensorStride; + using StrideDK = TensorStride; // Seq DQK (H B) + using StrideDV = TensorStride; // Seq DVO (H B) + using StrideDO = TensorStride; + + /* run(): reads sizes and strides from the tensors, builds problem_shape and the stride tuples, then calls can_implement / initialize / run. workspace_buffer is not referenced in this body. */ static void run(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { + cutlass::KernelHardwareInfo hw_info; + /* NOTE(review): the SM count is queried for device 0 regardless of which device the tensors are on - confirm for multi-GPU use. */ hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + ProblemShape problem_shape; + cute::tuple> tensor_shape; + + + int d = q.size(-1); + int d_vo = v.size(-1); + int batch_size = cumulative_seqlen_q.size(0) - 1; + int num_qo_heads 
= q.size(1); + int total_seqlen_q = q.size(0); + int total_seqlen_kv = k.size(0); + + //varlen: q: [Q, H, D] + //fixedlen: q: [B, H, Q, D] + /* varlen: each sequence extent is VariableLength{max, cumulative-length pointer, total} and tensor_shape uses the packed totals with batch extent 1. Otherwise the per-batch lengths are total / batch_size (integer division). */ if constexpr (kIsVarlen) { + problem_shape = cute::make_tuple( + VariableLength{max_seqlen_q, static_cast(cumulative_seqlen_q.data_ptr()), total_seqlen_q}, + VariableLength{max_seqlen_kv, static_cast(cumulative_seqlen_kv.data_ptr()), total_seqlen_kv}, + d, d_vo, cute::make_tuple(num_qo_heads, batch_size)); + tensor_shape = make_shape(total_seqlen_q, total_seqlen_kv, d, d_vo, make_shape(num_qo_heads, 1)); + } else { + int q_len = total_seqlen_q / batch_size; + int kv_len = total_seqlen_kv / batch_size; + problem_shape = cute::make_tuple(q_len, kv_len, d, d_vo, cute::make_tuple(num_qo_heads, batch_size)); + tensor_shape = problem_shape; + } + + auto [Q, K, D, D_VO, HB] = tensor_shape; + auto [H, B] = HB; + + int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2); + int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2); + int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2); + int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2); + int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1); + int dq_stride0 = dq.stride(0), dq_stride1 = dq.stride(1), dq_stride2 = dq.stride(2); + int dk_stride0 = dk.stride(0), dk_stride1 = dk.stride(1), dk_stride2 = dk.stride(2); + int dv_stride0 = dv.stride(0), dv_stride1 = dv.stride(1), dv_stride2 = dv.stride(2); + int do_stride0 = d_o.stride(0), do_stride1 = d_o.stride(1), do_stride2 = d_o.stride(2); + /* Stride 2 of q/k/v/o/dq/dk/dv/d_o and stride 0 of lse must be 1. */ TORCH_CHECK(q_stride2 == 1); + TORCH_CHECK(k_stride2 == 1); + TORCH_CHECK(v_stride2 == 1); + TORCH_CHECK(o_stride2 == 1); + TORCH_CHECK(lse_stride0 == 1); + TORCH_CHECK(dq_stride2 == 1); + TORCH_CHECK(dk_stride2 == 1); + TORCH_CHECK(dv_stride2 == 1); + TORCH_CHECK(do_stride2 == 1); + + StrideQ stride_Q = make_stride(q_stride0, _1{}, make_stride(q_stride1, B == 1 ? 
0 : q_stride0*Q)); + /* In each stride tuple here the batch stride is 0 when B == 1 (B is 1 on the varlen path); otherwise it is derived from the per-batch sequence extent. */ StrideK stride_K = make_stride(k_stride0, _1{}, make_stride(k_stride1, B == 1 ? 0 : k_stride0*K)); + StrideV stride_V = make_stride(v_stride0, _1{}, make_stride(v_stride1, B == 1 ? 0 : v_stride0*K)); + StrideO stride_O = make_stride(o_stride0, _1{}, make_stride(o_stride1, B == 1 ? 0 : o_stride0*Q)); + StrideLSE stride_LSE = make_stride(_1{}, make_stride(lse_stride1, B == 1 ? 0 : Q)); + + StrideDQ stride_dQ = make_stride(dq_stride0, _1{}, make_stride(dq_stride1, B == 1 ? 0 : dq_stride0*Q)); + StrideDK stride_dK = make_stride(dk_stride0, _1{}, make_stride(dk_stride1, B == 1 ? 0 : dk_stride0*K)); + StrideDV stride_dV = make_stride(dv_stride0, _1{}, make_stride(dv_stride1, B == 1 ? 0 : dv_stride0*K)); + StrideDO stride_dO = make_stride(do_stride0, _1{}, make_stride(do_stride1, B == 1 ? 0 : do_stride0*Q)); + + typename Operation::Arguments arguments{ + problem_shape, + (static_cast(q.data_ptr())), stride_Q, + (static_cast(k.data_ptr())), stride_K, + (static_cast(v.data_ptr())), stride_V, + (static_cast(o.data_ptr())), stride_O, + (static_cast(lse.data_ptr())), stride_LSE, + (static_cast(d_o.data_ptr())), stride_dO, + (static_cast(dq.data_ptr())), stride_dQ, + (static_cast(dk.data_ptr())), stride_dK, + (static_cast(dv.data_ptr())), stride_dV, + static_cast(softmax_scale), + hw_info + }; + + Operation op; + + size_t workspace_size = 0; + workspace_size = Operation::get_workspace_size(arguments); + /* NOTE(review): a DeviceAllocation of workspace_size is constructed on every call, and workspace_ptr below is never read. */ DeviceAllocation workspace(workspace_size); + uint8_t* workspace_ptr = workspace.get(); + + CUTLASS_CHECK(op.can_implement(arguments)); + CUTLASS_CHECK(op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); + } + +}; + + +/* run_fmha_bwd: forwards all arguments to BwdRunner::run. NOTE(review): the last parameter is named total_seqlen_kv but is passed into the max_seqlen_kv slot of run(). */ template +void run_fmha_bwd(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + float softmax_scale, int 
max_seqlen_q, int total_seqlen_kv) { + BwdRunner::run(workspace_buffer, d_o, q, k, v, o, lse, + cumulative_seqlen_q, cumulative_seqlen_kv, + dq, dk, dv, + softmax_scale, max_seqlen_q, total_seqlen_kv); +} diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cu b/csrc/sm100/fmha_cutlass_fwd_sm100.cu new file mode 100644 index 0000000..e322709 --- /dev/null +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cu @@ -0,0 +1,81 @@ +#include "common/mask.cuh" +#include "common/utils.hpp" +#include "fmha_cutlass_fwd_sm100.cuh" + +#include +#include +#include +#include +#include + +/* call_run_fmha_fwd: turns the tag-type arguments into the compile-time flags IsVarlen / IsMla / IsCausalMask and forwards the tensors to run_fmha_fwd. */ template +void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_varlen, + [[maybe_unused]] Element in, [[maybe_unused]] ElementOut out, + [[maybe_unused]] Mla mla, at::Tensor workspace_buffer, at::Tensor q, + at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, + float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { + static constexpr bool IsVarlen = std::is_same_v; + static constexpr bool IsMla = std::is_same_v; + static constexpr bool IsCausalMask = std::is_same_v>; + using Option = std::conditional_t, + Option>; + + run_fmha_fwd( + workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, + softmax_scale, max_seqlen_q, max_seqlen_kv); +} + +/* FMHACutlassSM100FwdRun: runtime dispatch on dtype, mask mode, varlen flag and head dims. Only BFloat16 input with BFloat16 output reaches a kernel; any other pair hits FLASH_MLA_ASSERT(false). */ void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, + int mask_mode_code, float sm_scale, int max_seqlen_q, + int max_seqlen_kv, bool is_varlen) { + const c10::cuda::OptionalCUDAGuard device_guard(q.device()); + CHECK(q.scalar_type() == k.scalar_type()); + auto scalar_type_in = q.scalar_type(); + auto scalar_type_out = o.scalar_type(); + int head_dim_qk = q.size(-1); + int head_dim_vo = v.size(-1); + MaskMode mask_mode = static_cast(mask_mode_code); + + if (scalar_type_in == at::ScalarType::BFloat16 && + 
scalar_type_out == at::ScalarType::BFloat16) { + using Element = cutlass::bfloat16_t; + using ElementOut = cutlass::bfloat16_t; + + /* apply_config selects the mask tag (CausalMask or ResidualMask) and the varlen tag from the runtime flags and passes them to fn. */ auto apply_config = [&](auto fn) { + if (mask_mode == MaskMode::kCausal) { + if (is_varlen) { + fn(CausalMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(CausalMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } else { + if (is_varlen) { + fn(ResidualMask{}, cute::true_type{}, Element{}, ElementOut{}); + } else { + fn(ResidualMask{}, cute::false_type{}, Element{}, ElementOut{}); + } + } + }; + + apply_config([&](auto mask, auto varlen, auto in, auto out) { + /* (192, 128) head dims use the MLA path (true_type); (128, 128) use the regular path (false_type). */ if (head_dim_qk == 192 && head_dim_vo == 128) { + call_run_fmha_fwd(mask, varlen, in, out, true_type{}, workspace_buffer, q, k, v, + cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, + max_seqlen_q, max_seqlen_kv); + } else if (head_dim_qk == 128 && head_dim_vo == 128) { + call_run_fmha_fwd(mask, varlen, in, out, false_type{}, workspace_buffer, q, k, v, + cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, sm_scale, + max_seqlen_q, max_seqlen_kv); + } else { + /* NOTE(review): an unsupported head-dim pair only prints a message and returns normally. */ std::cout << "No kernel instantiated for head_dim_qk=" << head_dim_qk + << " head_dim_vo=" << head_dim_vo << std::endl; + } + }); + + } else { + FLASH_MLA_ASSERT(false); + } +} diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh new file mode 100644 index 0000000..71831bb --- /dev/null +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh @@ -0,0 +1,334 @@ +#pragma once + +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp" +#include "collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp" +#include "collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.h" +#include "device/fmha.hpp" +#include "kernel/fmha_causal_tile_scheduler.hpp" +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include 
"kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp" + +#include +#include + +using namespace cute; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::device; + +/* Problem sizes for the non-MLA path (single head dim d). */ struct FmhaOptions { + int b = 1; + int h = 1; + int h_k = 1; + int q = 256; + int k = 256; + int d = 128; +}; + +/* Problem sizes for the MLA path (head dim split into dl and dr). */ struct MlaOptions { + int b = 1; + int h = 1; + int h_k = 1; + int q = 256; + int k = 256; + int dl = 128; // headdim latent + int dr = 64; // headdim rope +}; + +/* FwdRunner: builds cutlass::fmha::device::FMHA arguments from torch tensors and launches the forward op. */ template +struct FwdRunner { + + using Element = Element_; + using ElementAccumulatorQK = float; + using ElementAccumulatorPV = float; + using ElementOut = ElementOut_; + + using HeadDimLatent = _128; + using HeadDim = Shape; + using TileShapeMla = Shape<_256, _128, HeadDim>; + using TileShapeFmha = Shape<_256, _128, _128>; + using TileShape = std::conditional_t; + + using ProblemShapeRegular = std::conditional_t< + kIsMla, + cute::tuple, cute::tuple, int>>, + cute::tuple, int>>>; + + using ProblemShapeVarlen = + std::conditional_t, + cute::tuple, int>>, + cute::tuple, int>>>; + + using ProblemShapeType = + std::conditional_t; + + using StrideQ = cute::tuple, int>>; + using StrideK = cute::tuple, int>>; + using StrideV = StrideK; + using StrideO = StrideQ; + using StrideLSE = cute::tuple<_1, cute::tuple, int>>; + + static constexpr bool kIsPersistent = + find_option_t::value; + + using TileScheduler = std::conditional_t< + kIsPersistent, + std::conditional_t> || + std::is_same_v>, + cutlass::fmha::kernel::CausalPersistentTileScheduler, + cutlass::fmha::kernel::PersistentTileScheduler>, + std::conditional_t>; + + static constexpr bool IsOrderLoadEpilogue = + kIsPersistent && (sizeof(Element) == sizeof(ElementOut)); + using OrderLoadEpilogue = std::conditional_t; + + using MainloopMla = cutlass::fmha::collective::Sm100MlaFwdMainloopTmaWarpspecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeMla, StrideQ, StrideK, + StrideV, ActiveMask, Shape<_2, 
_1, _1>, OrderLoadEpilogue>; + + using OperationMla = + cutlass::fmha::device::FMHA, + TileScheduler, cutlass::fmha::kernel::Sm100MlaFwdCtxKernelWarpspecializedSchedule>>; + + using MainloopFmha = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeFmha, StrideQ, StrideK, + StrideV, ActiveMask>; + + using OperationFmha = + cutlass::fmha::device::FMHA, + TileScheduler>>; + + using Mainloop = std::conditional_t; + using Operation = std::conditional_t; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + + /* initialize_varlen: returns (init shape, launch shape). The init shape uses the packed totals with batch extent 1; the launch shape carries VariableLength{max} for both sequence modes. NOTE(review): num_batches is read but not used afterwards. */ template + auto initialize_varlen(const ProblemShape &problem_size, int max_seqlen_q, int max_seqlen_kv, + int total_seqlen_q, int total_seqlen_kv) { + + int num_batches = get<3, 1>(problem_size); + + ProblemShape problem_size_for_init = problem_size; + get<3, 1>(problem_size_for_init) = 1; + get<0>(problem_size_for_init) = total_seqlen_q; + get<1>(problem_size_for_init) = total_seqlen_kv; + + ProblemShapeType problem_size_for_launch; + + get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; + get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; + get<2>(problem_size_for_launch) = get<2>(problem_size); + get<3>(problem_size_for_launch) = get<3>(problem_size); + + return cute::make_tuple(problem_size_for_init, problem_size_for_launch); + } + + /* get_problem_shape: (q, k, head dim(s), ((h / h_k, h_k), b)) built from the options struct. */ template + static constexpr auto get_problem_shape(const Options &options) { + int h_r = options.h / options.h_k; + if constexpr (std::is_same_v) { + return cute::make_tuple(options.q, options.k, cute::make_tuple(options.dl, options.dr), + cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); + } else { + return cute::make_tuple(options.q, options.k, options.d, + cute::make_tuple(cute::make_tuple(h_r, options.h_k), options.b)); + } + } + + /* initialize: builds the launch problem shape and, for varlen, attaches the cumulative-length pointers. */ template + ProblemShapeType initialize(const Options &options, int 
max_seqlen_q, int max_seqlen_kv, + int total_seqlen_q, int total_seqlen_kv, + void *cumulative_length_q, void *cumulative_length_kv) { + assert(options.h % options.h_k == 0); + auto problem_shape_in = get_problem_shape(options); + + ProblemShapeType problem_shape; + decltype(problem_shape_in) problem_size; + + if constexpr (kIsVarlen) { + auto [problem_shape_init, problem_shape_launch] = initialize_varlen( + problem_shape_in, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv); + problem_shape = problem_shape_launch; + problem_size = problem_shape_init; + } else { + problem_size = problem_shape_in; + problem_shape = problem_shape_in; + } + + /* NOTE(review): get_head_dimension is defined but never called in this function. */ auto get_head_dimension = [&]() { + if constexpr (rank_v(problem_shape))> == 2) { + return cute::make_tuple(size<2, 0>(problem_shape) + size<2, 1>(problem_shape), + size<2, 0>(problem_shape)); + } else { + return cute::make_tuple(size<2>(problem_size), size<2>(problem_size)); + } + }; + + + if constexpr (kIsVarlen) { + get<0>(problem_shape).cumulative_length = static_cast(cumulative_length_q); + get<1>(problem_shape).cumulative_length = static_cast(cumulative_length_kv); + } + + return problem_shape; + } + + /* get_arguments: assembles Operation::Arguments from the raw pointers and the stride_* members. */ auto get_arguments(const ProblemShapeType &problem_shape, + const cutlass::KernelHardwareInfo &hw_info, float scale_softmax, + void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr, + void *cumulative_length_q, void *cumulative_length_kv) { + auto problem_shape_ = problem_shape; + if constexpr (kIsVarlen) { + get<0>(problem_shape_).cumulative_length = static_cast(cumulative_length_q); + get<1>(problem_shape_).cumulative_length = static_cast(cumulative_length_kv); + } + + typename Operation::Arguments arguments{ + problem_shape_, + {static_cast(q_ptr), stride_Q, static_cast(k_ptr), stride_K, + static_cast(v_ptr), stride_V, scale_softmax}, + {static_cast(o_ptr), stride_O, + static_cast(lse_ptr), stride_LSE}, + hw_info}; + + return arguments; + } + + /* run: derives strides from the tensors, stores them in the stride_* members, then calls can_implement / initialize / run on the op. */ template + void run(const Options &options, const 
cutlass::KernelHardwareInfo &hw_info, at::Tensor q, + at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, float scale_softmax, + at::Tensor workspace, at::Tensor cumulative_seqlen_q, + at::Tensor cumulative_seqlen_kv, int max_seqlen_q, int max_seqlen_kv) { + + int total_seqlen_q = q.size(0); + int total_seqlen_kv = k.size(0); + ProblemShapeType problem_shape = + initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv, + cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); + + int SQ = size<0>(problem_shape); + int SK = size<1>(problem_shape); + int B = size<3, 1>(problem_shape); + int H = size<3, 0>(problem_shape); + int H_K = size<3, 0, 1>(problem_shape); + int H_Q = size<3, 0, 0>(problem_shape); + + int q_stride0 = q.stride(0), q_stride1 = q.stride(1), q_stride2 = q.stride(2); + int k_stride0 = k.stride(0), k_stride1 = k.stride(1), k_stride2 = k.stride(2); + int v_stride0 = v.stride(0), v_stride1 = v.stride(1), v_stride2 = v.stride(2); + int o_stride0 = o.stride(0), o_stride1 = o.stride(1), o_stride2 = o.stride(2); + int lse_stride0 = lse.stride(0), lse_stride1 = lse.stride(1); + TORCH_CHECK(q_stride2 == 1); + TORCH_CHECK(k_stride2 == 1); + TORCH_CHECK(v_stride2 == 1); + TORCH_CHECK(o_stride2 == 1); + TORCH_CHECK(lse_stride0 == 1); + + stride_Q = make_stride(q_stride0, _1{}, make_stride(make_stride(q_stride1, H_Q * q_stride1), SQ * q_stride0)); + stride_O = make_stride(o_stride0, _1{}, make_stride(make_stride(o_stride1, H_Q * o_stride1), SQ * o_stride0)); + /* K and V use stride _0 on the first head sub-mode. */ stride_K = make_stride(k_stride0, _1{}, make_stride(make_stride(_0{}, k_stride1), SK * k_stride0)); + stride_V = make_stride(v_stride0, _1{}, make_stride(make_stride(_0{}, v_stride1), SK * v_stride0)); + stride_LSE = make_stride(_1{}, make_stride(make_stride(lse_stride1, lse_stride1 * H_Q), SQ)); + + /* varlen: zero the batch stride of every tensor. */ if constexpr (kIsVarlen) { + get<2, 1>(stride_Q) = 0; + get<2, 1>(stride_K) = 0; + get<2, 1>(stride_V) = 0; + get<2, 1>(stride_O) = 0; + get<1, 1>(stride_LSE) = 0; + } + 
+ typename Operation::Arguments arguments = + get_arguments(problem_shape, hw_info, scale_softmax, q.data_ptr(), k.data_ptr(), + v.data_ptr(), o.data_ptr(), lse.data_ptr(), + cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); + + Operation op; + + // size_t workspace_size = 0; + // workspace_size = Operation::get_workspace_size(arguments); + + // todo: if use workspace, need check workspace size first. + // we don't use workspace in current version. + + CUTLASS_CHECK(op.can_implement(arguments)); + CUTLASS_CHECK(op.initialize(arguments, nullptr)); + CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); + } +}; + +/* run_fmha_fwd: fills an options struct from tensor sizes and runs one of two FwdRunner instantiations. NOTE(review): device_id is fixed to 0 for the SM-count query. */ template +void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, + at::Tensor lse, float scale_softmax, int max_seqlen_q, int max_seqlen_kv) { + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + auto get_options = [&]() { + if constexpr (kIsMla) { + MlaOptions options; + options.b = cumulative_seqlen_q.size(0) - 1; + options.h = q.size(1); + options.h_k = k.size(1); + options.q = q.size(0) / options.b; + options.k = k.size(0) / options.b; + options.dl = v.size(-1); + options.dr = q.size(-1) - v.size(-1); + return options; + } else { + FmhaOptions options; + options.b = cumulative_seqlen_q.size(0) - 1; + options.h = q.size(1); + options.h_k = k.size(1); + options.q = q.size(0) / options.b; + options.k = k.size(0) / options.b; + options.d = q.size(-1); + return options; + } + }; + + auto options = get_options(); + + if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 && + (!std::is_same_v)) { + FwdRunner runner; + runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, + cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); + } else { + FwdRunner 
runner; + runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, + cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); + } +} diff --git a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp new file mode 100644 index 0000000..572e67f --- /dev/null +++ b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +// Swizzle Q tile and H tile to improve L2 cache hit rate, +// and launch the longest main loop first to keep most SMs busy. + +/* Each CTA handles a single tile: operator++ only marks the scheduler invalid. */ struct CausalIndividualTileScheduler { + + static constexpr int TileQ = 16; + static constexpr int TileH = 8; + static constexpr int TileSize = TileQ * TileH; + + struct Params { + dim3 grid; + int tile_max_q; + FastDivmod divmod_tile_col; + FastDivmod divmod_tile_size; + FastDivmod divmod_tile_head; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + CausalIndividualTileScheduler(Params const& params) : params(params) {} + + /* grid = (size<3,0>, Q tiles rounded up to the cluster size, size<3,1>); tile_max_q is grid.y rounded down to a multiple of TileQ. */ template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + + dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size)); + // gridDim.x must be a multiple of TileH + const int tile_col_count = grid.x / TileH; + const int tile_max_q = grid.y / TileQ * TileQ; + return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH}; + } + + static 
dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + /* Rows below tile_max_q are remapped within TileQ x TileH groups; rows at or above it map directly. Both paths reverse the Q-tile index. */ auto get_block_coord() { + using namespace cute; + const int block_idx = blockIdx.y * gridDim.x + blockIdx.x; + + int tile_idx, tile_tail; + params.divmod_tile_size(tile_idx, tile_tail, block_idx); + + int tile_row_idx, tile_col_idx; + params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx); + + int row_offset_in_tail, col_offset_in_tail; + params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail); + + const int row_idx = tile_row_idx * TileQ + row_offset_in_tail; + const int col_idx = tile_col_idx * TileH + col_offset_in_tail; + + // last q tile launch first + if(blockIdx.y >= params.tile_max_q) { + return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z))); + } + + return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z))); + } + + CUTLASS_DEVICE + CausalIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +// Launch order: H Q B +/* Each CTA starts at blockIdx.x and advances by gridDim.x until block_idx reaches num_blocks. */ struct CausalPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_h; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" 
WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) }, + hw_info + }; + } + + /* Launches min(num_blocks, sm_count) CTAs. */ static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + /* Decodes block_idx with the head index fastest, then the M block, then the batch. */ auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_h(block_decode, bidh, block_decode); + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + CausalPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp b/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 0000000..32e007c --- /dev/null +++ b/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,153 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +/* FmhaKernelBwdConvert: multiplies the ElementAcc dQ/dK/dV source buffers by scale and stores the results as Element in the destination buffers. NOTE(review): template argument lists in this text look stripped (template header, tuple stride types, cast targets); restore them from the upstream file before compiling. */ template +struct FmhaKernelBwdConvert { + + struct Arguments { + ProblemShape problem_shape; + + const ElementAcc* ptr_src_dQ; + tuple> stride_src_dQ; + const ElementAcc* ptr_src_dK; + tuple> stride_src_dK; + const ElementAcc* ptr_src_dV; + tuple> stride_src_dV; + + Element* ptr_dest_dQ; + tuple> stride_dest_dQ; + Element* ptr_dest_dK; + tuple> stride_dest_dK; + Element* ptr_dest_dV; + tuple> stride_dest_dV; + + ElementAcc scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + /* Both head dims (problem-shape modes 2 and 3) must be multiples of kElementsPerLoad. */ static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + /* grid = (size<4,0>, size<4,1>, ceil(max(size<0>, size<1>) / kBlockSeq)). */ static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* 
workspace) { + return args; + } + + /* copy: blockIdx.x / blockIdx.y select the <2,0> / <2,1> stride modes and blockIdx.z selects a slab of kBlockSeq rows. For a variable-length count only the destination pointer is advanced by the cumulative offset. Each thread moves kElementsPerLoad contiguous elements per inner iteration. */ template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) { + auto ptr_src_bh = ptr_src + get<2,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<2,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + + int seqlen = count; + if constexpr (is_variable_length_v) { + int offset = count.cumulative_length[blockIdx.y]; + ptr_dest_bh += offset * get<0>(stride_dest); + seqlen = count.cumulative_length[blockIdx.y + 1] - offset; + } + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= seqlen) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAcc value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = static_cast(params.scale * value_src[v]); + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); + } + } + } + + /* Kernel entry point: each of dQ / dK / dV is converted only when its source pointer is non-null. The parameter is spelled &params so that it matches the params. uses in the body. */ CUTLASS_DEVICE void operator()(const Params &params, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape)); + 
} + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 0000000..bdcf1cb --- /dev/null +++ b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + ProblemShape problem_shape; + + const Element* ptr_O; + cute::tuple> stride_O; + const Element* ptr_dO; + cute::tuple> stride_dO; + + ElementAcc* ptr_sum_OdO; + cute::tuple> stride_sum_OdO; + + const ElementAcc* ptr_lse = nullptr; + cute::tuple> stride_lse; + + ElementAcc* ptr_scaled_lse = nullptr; + cute::tuple> stride_scaled_lse; + + ElementAcc sum_odo_scale = 1.0; + ElementAcc lse_scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments 
const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); + auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); + auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); + + auto problem_q = get<0>(params.problem_shape); + int seqlen_q = problem_q; + if constexpr (is_variable_length_v) { + int offset = problem_q.cumulative_length[blockIdx.z]; + ptr_O_bh += offset * get<0>(params.stride_O); + ptr_dO_bh += offset * get<0>(params.stride_dO); + ptr_lse_bh += offset * get<0>(params.stride_lse); + seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset; + } + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= seqlen_q) continue; + ElementAcc acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); + auto 
ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<0>(params.stride_sum_OdO); + auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); + auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = params.sum_odo_scale * acc; + if (params.ptr_scaled_lse) { + *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq; + } + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_options.hpp b/csrc/sm100/kernel/fmha_options.hpp new file mode 100644 index 0000000..d4faa8d --- /dev/null +++ b/csrc/sm100/kernel/fmha_options.hpp @@ -0,0 +1,85 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/fmha_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000..119f069 --- /dev/null +++ b/csrc/sm100/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE 
+ bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_h; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < 
params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000..59b410b --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1841 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeK = decltype(get<1>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<2>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + 
TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{}); + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kS + TileShapeQ{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaKQ = typename 
cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeKQ = typename CollectiveMmaKQ::TileShape; + using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma; + + // compute dP + using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeVDO = typename CollectiveMmaVDO::TileShape; + using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{})); + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename 
CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename 
PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + 
alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; + using TMA_V = typename CollectiveMmaVDO::Params::TMA_A; + using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B; + using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + 
)); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H <= 0 || B <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaKQ::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaKQ::Arguments { + args.mainloop.ptr_k, args.mainloop.stride_k, + args.mainloop.ptr_q, args.mainloop.stride_q, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( + make_shape(K, Q, D_VO, HB), + typename CollectiveMmaVDO::Arguments { + args.mainloop.ptr_v, args.mainloop.stride_v, + args.mainloop.ptr_do, args.mainloop.stride_do, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, 
_0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_a, + params_vdo.tma_load_a, + params_kq.tma_load_b, + params_vdo.tma_load_b, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto 
mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); + auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_A(gK); + auto tSTgQ = cta_mma_kq.partition_B(gQ); + auto tDPTgV = cta_mma_vdo.partition_A(gV); + auto tDPTgDO = cta_mma_vdo.partition_B(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = 
pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading 128 values of 32b each + // so 4*32b=128b + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + 
+ // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } 
+ + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename 
PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK); + Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ); + + Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV); + Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = 
TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + tDVrP.data() = TmemAllocation::kP; + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaKQ tiled_mma_kq; + TiledMmaVDO tiled_mma_vdo; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + 
pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + 
pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + 
pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + 
pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); + + + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); + + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); + } + + copy_if(copy_op, tCp_v, tCr_v, tCg_v); + + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = 
domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = 
local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV 
= split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& 
pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + auto store_op = []() { + if constexpr (sizeof(Element) == 1) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else { + return SM100_TMEM_STORE_32dp32b8x{}; + } + }(); + + Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), 
make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + + } + }; + + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + + auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); + tDVrP.data() = TmemAllocation::kP; + + auto tiled_r2t = make_tmem_copy(store_op, tDVrP); + auto thread_r2t = tiled_r2t.get_slice(dp_idx); + + auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); + auto tRT_cST_p = thread_r2t.partition_S(tDVcST); + auto tRT_cST = split_wg(tRT_cST_p); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + 
pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<1>(tTR_cST(i+1)), 
pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); + }); + + // notify for P + cutlass::arch::fence_view_async_tmem_store(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<1>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is 
negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_128{}, _128{}), + make_stride(_1{}, _0{}) + ); + + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + /* reduce(): for each of iter_count Q tiles starting at iter_index, waits on pipeline_mma_reduce_dq, copies the dQ tile from tensor memory into registers, then stages it through shared memory one slice at a time and issues the tma_red_dq TMA copy into the matching gDQ tile. */ template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState&
pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + /* NOTE(review): this modulo uses kNumComputeWarps while lane_predicate and the barriers below use kNumReduceWarps - confirm the two agree for the threadIdx range of the reduce warps */ int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + /* true only for threads whose index is a multiple of the reduce-group size; those threads acquire store stages and issue the TMA copies */ int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i <
size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + /* advanced by every thread, not only the issuing one, because all threads use index() to pick the smem stage they write */ ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + /* Kernel body: maps the warp to a role, constructs every pipeline over shared_storage.pipelines, derives the Q-tile range (iter_start, iter_count) for this block's K tile, then runs load / mma / compute / reduce according to the role. */ CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { +
pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + /* Load -> Mma pipeline for dO, sized by kTransactionsBytesLoadDO */ typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + /* Load -> Compute pipeline for LSE: one warp of producer arrivals, every compute thread as a consumer arrival */ typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, +
pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + /* Load -> Compute pipeline for sum_odo, with the same arrival counts as the LSE pipeline */ typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + /* Mma -> Compute pipeline for S; every compute thread is a consumer arrival */ typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + /* Mma -> Compute pipeline for dP */ typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp;
+ pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + /* Mma -> Reduce pipeline for dQ; consumer arrivals come from the reduce warps */ typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + /* Compute -> Mma pipeline for P: every compute thread is a producer arrival, with a single consumer arrival */ typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + /* Compute -> Mma pipeline for dS, configured like the P pipeline */ typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { +
pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + /* Mma -> Compute pipeline for the final dK/dV accumulators */ typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + /* default-constructed: unlike the pipelines above it is given no shared-storage object */ PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + /* the constructors above were passed false_type for mask calc, so the masks are initialised explicitly here */ pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState
pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + /* blockIdx.x selects the K tile and (blockIdx.y, blockIdx.z) form the batch coordinate, matching the (K tiles, H, B) grid from get_grid_shape */ auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count =
ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + /* this block's K tile starts at or beyond the K extent: nothing to do. This exit and the next happen before any TMEM allocation. */ if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + /* no Q tiles remain once iter_start is subtracted: run epilogue_clear instead of the main loop */ if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + /* the Mma warp requests the full Sm100TmemCapacityColumns; the allocator is handed the address of shared_storage.tmem_base_ptr */ tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { +
warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + /* once all compute threads have passed the barrier above, the compute warp with warp_idx % kNumComputeWarps == 0 frees the TMEM that the Mma warp allocated */ if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + /* Launch block shape: a 1-D block of MaxThreadsPerBlock threads. */ static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + /* Launch grid shape: x = number of K tiles (ceil_div(K, TileShapeK)), y = H, z = B, all read from params.problem_shape. */ static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp new file mode 100644
index 0000000..5a58157 --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1834 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +/* Warp-specialized SM100 kernel whose warps are split into Load, Mma, Compute and Reduce roles (see WarpRole and kWarpAssignment). TileShape supplies the Q, K, DQK and DVO tile extents in that order. */ template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + using TileShapeK = decltype(get<1>(TileShape{})); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<3>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + /* Offsets of each accumulator inside the tensor-memory allocation. kDP aliases kDQ and kP aliases kS. */ struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + /* NOTE(review): 65536 * 16 equals 16 << 16; presumably a lane offset of 16 carried in the upper half of the TMEM address - confirm and consider naming the constant */ static constexpr uint32_t kS = kDQ + 65536 * 16; + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kDQ + TileShapeDQK{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + /* Role tags; each value is the 4-bit nibble stored per warp in kWarpAssignment. */ enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + /* one hex nibble per warp with warp 0 in the lowest nibble: warps 0-3 Reduce, 4-11 Compute, 12 Mma, 13 Load, higher warps Empty */ static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + + static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp; + static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be
divisible by NumThreadsPerWarp"); + /* Returns the role nibble for warp_idx from kWarpAssignment. warp_idx must stay below 16 because the 64-bit constant is shifted by 4 * warp_idx. */ CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + /* Per-role register counts (presumably consumed by warpgroup_reg_set - confirm). The static_assert bounds one kWarpgroup0, two kWarpgroup1 and one kWarpgroup2 by 512. */ struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + /* the trailing 4 covers the Mma warp, the Load warp and two Empty warps of kWarpAssignment, giving 16 warps in total */ static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + + // compute dP + using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, +
Schedule>::CollectiveOp; + using TileShapeDOV = typename CollectiveMmaDOV::TileShape; + using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma; + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + producer role + consumer role + payload (e.g. PipelineLoadMmaQ: Load -> Mma, carrying Q) + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>;
+ using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + /* Storage for every pipeline declared above except PipelineReduceTmaStore, which has no member here. */ struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + /* Composes `layout` so that its fourth mode is re-expressed over make_layout(stages); used below to give each smem layout its stage count. */ template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO =
Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{})); + using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + /* Shared-memory tensors. Each union overlays a buffer with its `_t` counterpart (presumably the transposed-layout view - confirm), so both names address the same bytes. */ struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + union{ + alignas(2048) cute::array> smem_p; + alignas(2048) cute::array> smem_p_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + /* byte counts of one stage (first three layout modes) of each TMA-loaded tile */ static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) *
cute::sizeof_bits_v); + + /* Layout of the kernel's shared memory: tensors, pipeline storage, then tmem_base_ptr (slot for the TMEM base address; presumably filled by TmemAllocator::allocate - confirm in operator()). */ struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + /* Caller-supplied pointers and strides for Q, K, V, dO, LSE, sum_odo and the dQ accumulator, plus the softmax scale (default 1 / sqrt(TileShapeDQK)). */ struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaDOV::Params::TMA_B; + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A; + + /* dQ goes out through a TMA reduce-add (SM90_TMA_REDUCE_ADD); the null-pointer tensor here only supplies types for decltype */ using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + /* TMA objects built by to_underlying_arguments. */ struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + /* Output pointers and strides for dK and dV. */ struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + /* Arguments plus the derived MainloopParams; this is what operator() receives. */ struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; +
EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + /* Returns false unless every extent (Q, K, D, D_VO, H, B) is positive and both D and D_VO are multiples of Alignment. Pointers are not checked. */ static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + /* No workspace is used: ignores all arguments and returns kSuccess. */ static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + /* Builds Params from Arguments: TMA descriptors for Q/K via CollectiveMmaQK, for dO/V via CollectiveMmaDOV, and the reduce-add descriptor for the dQ accumulator. The workspace pointer is unused. */ static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaQK::to_underlying_arguments( + make_shape(Q, K, D, HB), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q, args.mainloop.stride_q, + args.mainloop.ptr_k, args.mainloop.stride_k, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaDOV::to_underlying_arguments( + make_shape(Q, K, D_VO, HB), + typename CollectiveMmaDOV::Arguments { + args.mainloop.ptr_do, args.mainloop.stride_do, + args.mainloop.ptr_v, args.mainloop.stride_v, + }, /*workspace=*/nullptr); + + /* NOTE(review): the dQ accumulator shape uses Q_ rather than the total_length-adjusted Q used for the loads above - confirm this is intended for variable-length problems */ TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + /* initializer order follows the MainloopParams fields: K, V, Q, dO, dQ */ MainloopParams{ + params_kq.tma_load_b, + params_vdo.tma_load_b, + params_kq.tma_load_a, + params_vdo.tma_load_a, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + /* Converts `input` through NumericArrayConverter in vectors of AlignmentS (4) elements and returns the converted tensor, which has the same shape as the input. */ template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + +
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step{}); + auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_kq = 
TiledMmaQK{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_B(gK); + auto tSTgQ = cta_mma_kq.partition_A(gQ); + auto tDPTgV = cta_mma_vdo.partition_B(gV); + auto tDPTgDO = cta_mma_vdo.partition_A(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, 
pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading kLoadPerThread * 32 values of 32b each + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = 
TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + 
mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& 
pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaQK::make_fragment_B(sK); + Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ); + + Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV); + Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP); + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaQK tiled_mma_qk; + TiledMmaDOV tiled_mma_dov; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor 
tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + 
pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + 
tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + 
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + 
} + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + auto tCg = thr_copy.partition_D(gmem); + auto tCr = thr_copy.partition_S(quantize(regs)); + auto tCc = thr_copy.partition_D(coord); + + + constexpr int R = decltype(tCr.layout())::rank; + auto tCg_v = group_modes<1, R>(tCg); + auto tCr_v = group_modes<1, R>(tCr); + auto tCc_v = group_modes<1, R>(tCc); + auto tCp_v = make_tensor(shape<1>(tCc_v)); + + for (int i = 0; i < size(tCp_v); ++i) { + tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); + } + + copy_if(copy_op, tCp_v, tCr_v, tCg_v); + + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, 
TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * 
NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + 
cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& 
pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_16dp32b32x{}; + + Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{})); + Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + } + }; + + 
Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cPT_p = thread_t2r.partition_D(cPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + 
} + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}) + (_, _, _, pipeline_compute_mma_p_producer_state.index()); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + auto sP_pi = as_position_independent_swizzle_tensor(sP); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p))); + auto sP_pi_slice = split_wg(sP_pi_slice_p); + copy_aligned(tRT_rST, sP_pi_slice); + }); + + // notify for P + 
cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto 
thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + /* reduce(): dQ reduction loop. Runs iter_count iterations starting at Q-tile index iter_index. Each iteration waits on pipeline_mma_reduce_dq, copies the dQ accumulator from tensor memory into registers and releases that pipeline stage. It then moves the registers through smem_dq one slice at a time (the shared buffer is too small to hold the whole tile), and the one thread for which lane_predicate is set issues the mainloop_params.tma_red_dq copy from shared memory to the global dQ tile, using pipeline_reduce_tma_store to recycle the shared-memory stages. NOTE(review): the parameter list after the template keyword is missing in this text, as are other angle-bracket arguments nearby; they appear to have been stripped during extraction, so restore them from the upstream header before compiling. */ template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_16dp32b16x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + /* bind the accumulator fragment to the TmemAllocation::kDQ location */ tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q,
D, HB)); + auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + 
cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + /* Kernel entry point. Maps the calling warp to a WarpRole and lets one elected lane of the Load warp prefetch the Q, K, V and dO TMA descriptors. It then constructs every inter-warp pipeline on top of the SharedStorage placed at smem, derives this block's coordinate from blockIdx and the first Q iteration (iter_start) from the mask type, and dispatches to load(), mma(), compute() or reduce() according to the role. A block whose K tile starts at or beyond the K extent returns immediately; a block with no Q iterations left only runs epilogue_clear(). The Mma warp allocates tensor memory; after the compute warps meet at EpilogueBarrier one of them frees it. */ CUTLASS_DEVICE void operator()(Params const& params, char* smem) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + /* every pipeline constructed below takes the next initializing_warp index */ int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask
calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + 
pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) 
{ + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + 
+ typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState 
pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + if (iter_count <= 0) { + 
epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + 
pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + /* once the compute warps have passed the barrier above, the warp whose index is a multiple of kNumComputeWarps frees the tensor memory recorded in shared_storage.tmem_base_ptr */ if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } + } + + /* Launch block shape: a one-dimensional block of MaxThreadsPerBlock threads. */ static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + /* Launch grid shape: x is ceil_div(K, TileShapeK) tiles along the K extent of params.problem_shape, y is H and z is B, where (H, B) is the unpacked last mode HB. This matches operator(), which builds blk_coord from blockIdx.x and (blockIdx.y, blockIdx.z). */ static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000..8fe503b --- /dev/null +++ b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,619 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1.
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "kernel/fmha_causal_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/fmha_common.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +/* Warp-specialization schedule for a 16-warp thread block (NumWarps). warp_idx_to_WarpRole() assigns warps 0-3 to Softmax0, 4-7 to Softmax1, 8-11 to Correction, warp 12 to MMA, warp 13 to Load, warp 14 to Epilogue and every other index to Empty. The NumRegsXxx constants are per-role register counts (presumably consumed by the kernel's register reallocation; confirm at the use site). Setting kDebugUsingPrintf lowers NumRegsCorrection by 16 and raises NumRegsOther by 16. */ struct Sm100FmhaCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 32 + (kDebugUsingPrintf ?
16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + + +/* Schedule variant that, judging by its name, targets the MLA forward kernel. The WarpRole enum, the warp-to-role mapping and the warp counts are the same as in Sm100FmhaCtxKernelWarpspecializedSchedule; only two register counts differ: NumRegsSoftmax is 184 instead of 192 and the base of NumRegsOther is 48 instead of 32. */ struct Sm100MlaFwdCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 184; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 48 + (kDebugUsingPrintf ?
16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule +> +struct Sm100FmhaFwdKernelTmaWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = ProblemShapeIn; + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue); + static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad); + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + static constexpr bool IsMla = std::is_same_v; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + using UnionType = union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + using StructType = struct { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + static constexpr bool IsPersistent = std::is_same_v || std::is_same_v; + using MainloopEpilogueStorage = std::conditional_t, + StructType>, + UnionType>; + + 
MainloopEpilogueStorage mainloop_epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct Arguments { + ProblemShape problem_shape; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 
1); + return block; + } + + /* Builds the kernel Params from host-side Arguments: problem_shape is forwarded unchanged, and the mainloop, epilogue and tile-scheduler members are produced by the to_underlying_arguments() of CollectiveMainloop, CollectiveEpilogue and TileScheduler. The same workspace pointer is handed to both collectives. */ static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + /* Returns apply_variable_length(params.problem_shape, batch_idx). The problem_shape parameter is not read by the body. NOTE(review): the pilcrow sequence in this signature and in the operator() signature that follows looks like a mis-decoded "&params" (entity damage from extraction); confirm against the upstream header. */ CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + return apply_variable_length(params.problem_shape, batch_idx); + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + auto get_epilogue_storage = [&]() { + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + return reinterpret_cast(shared_storage.mainloop_epilogue.mainloop.smem_o.data()); + } else { + return &shared_storage.mainloop_epilogue.epilogue; + } + }; + typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage(); + + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load); +
pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = 
NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + 
pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 
1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = 
cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue{params.epilogue}; + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? 
pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + mainloop.correction_empty( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.mma( + blk_coord, + params.mainloop, 
logical_problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = 
shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } + } + +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/sm100/pybind.cu b/csrc/sm100/pybind.cu new file mode 100644 index 0000000..7d4744d --- /dev/null +++ b/csrc/sm100/pybind.cu @@ -0,0 +1,17 @@ +#include + +void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor o, at::Tensor lse, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); + +void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, + at::Tensor v, at::Tensor o, at::Tensor lse, + at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, + at::Tensor dq, at::Tensor dk, at::Tensor dv, + int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &FMHACutlassSM100FwdRun); + m.def("bwd", &FMHACutlassSM100BwdRun); +} diff --git a/csrc/flash_api.cpp b/csrc/sm90/flash_api.cpp similarity index 100% rename from csrc/flash_api.cpp rename to csrc/sm90/flash_api.cpp diff --git a/csrc/kernels/config.h b/csrc/sm90/kernels/config.h similarity index 100% rename from csrc/kernels/config.h rename to csrc/sm90/kernels/config.h diff --git a/csrc/kernels/get_mla_metadata.cu b/csrc/sm90/kernels/get_mla_metadata.cu similarity index 100% rename from csrc/kernels/get_mla_metadata.cu rename to csrc/sm90/kernels/get_mla_metadata.cu diff --git a/csrc/kernels/get_mla_metadata.h b/csrc/sm90/kernels/get_mla_metadata.h similarity index 100% rename from csrc/kernels/get_mla_metadata.h rename to csrc/sm90/kernels/get_mla_metadata.h diff --git a/csrc/kernels/mla_combine.cu 
b/csrc/sm90/kernels/mla_combine.cu similarity index 100% rename from csrc/kernels/mla_combine.cu rename to csrc/sm90/kernels/mla_combine.cu diff --git a/csrc/kernels/mla_combine.h b/csrc/sm90/kernels/mla_combine.h similarity index 100% rename from csrc/kernels/mla_combine.h rename to csrc/sm90/kernels/mla_combine.h diff --git a/csrc/kernels/params.h b/csrc/sm90/kernels/params.h similarity index 100% rename from csrc/kernels/params.h rename to csrc/sm90/kernels/params.h diff --git a/csrc/kernels/splitkv_mla.cu b/csrc/sm90/kernels/splitkv_mla.cu similarity index 100% rename from csrc/kernels/splitkv_mla.cu rename to csrc/sm90/kernels/splitkv_mla.cu diff --git a/csrc/kernels/splitkv_mla.h b/csrc/sm90/kernels/splitkv_mla.h similarity index 100% rename from csrc/kernels/splitkv_mla.h rename to csrc/sm90/kernels/splitkv_mla.h diff --git a/csrc/kernels/traits.h b/csrc/sm90/kernels/traits.h similarity index 100% rename from csrc/kernels/traits.h rename to csrc/sm90/kernels/traits.h diff --git a/csrc/kernels/utils.h b/csrc/sm90/kernels/utils.h similarity index 100% rename from csrc/kernels/utils.h rename to csrc/sm90/kernels/utils.h diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py index 51b8600..d0e6faf 100644 --- a/flash_mla/__init__.py +++ b/flash_mla/__init__.py @@ -3,4 +3,7 @@ from flash_mla.flash_mla_interface import ( get_mla_metadata, flash_mla_with_kvcache, + flash_attn_varlen_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, ) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 47637f8..9c669ba 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -2,7 +2,9 @@ import torch -import flash_mla_cuda +import flash_mla_sm90 +import flash_mla_sm100 + def get_mla_metadata( @@ -20,10 +22,10 @@ def get_mla_metadata( tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. 
""" - return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) -def flash_mla_with_kvcache( +def flash_mla_with_kvcache_sm90( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, @@ -52,7 +54,7 @@ def flash_mla_with_kvcache( """ if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( + out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla( q, k_cache, head_dim_v, @@ -64,3 +66,264 @@ def flash_mla_with_kvcache( num_splits, ) return out, softmax_lse + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_sm100.fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: 
torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * qo_total_len * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_sm100.bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, 
+ max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ): + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert 
not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_mla_with_kvcache_sm100( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO + pass + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + capability = torch.cuda.get_device_capability(q.device.index) + if capability == (9, 0): + return flash_mla_with_kvcache_sm90( + q, k_cache, block_table, cache_seqlens, head_dim_v, + tile_scheduler_metadata, num_splits, + softmax_scale, causal, + ) + elif capability == (10, 0): + raise ValueError(f"Unsupported device capability: {capability}") + else: + raise ValueError(f"Unsupported device capability: {capability}") diff --git a/setup.py b/setup.py index 
217f540..58cf7b2 100644 --- a/setup.py +++ b/setup.py @@ -27,9 +27,13 @@ def get_features_args(): subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) -cc_flag = [] -cc_flag.append("-gencode") -cc_flag.append("arch=compute_90a,code=sm_90a") +cc_flag_sm90 = [] +cc_flag_sm90.append("-gencode") +cc_flag_sm90.append("arch=compute_90a,code=sm_90a") + +cc_flag_sm100 = [] +cc_flag_sm100.append("-gencode") +cc_flag_sm100.append("arch=compute_100a,code=sm_100a") this_dir = os.path.dirname(os.path.abspath(__file__)) @@ -41,12 +45,12 @@ def get_features_args(): ext_modules = [] ext_modules.append( CUDAExtension( - name="flash_mla_cuda", + name="flash_mla_sm90", sources=[ - "csrc/flash_api.cpp", - "csrc/kernels/get_mla_metadata.cu", - "csrc/kernels/mla_combine.cu", - "csrc/kernels/splitkv_mla.cu", + "csrc/sm90/flash_api.cpp", + "csrc/sm90/kernels/get_mla_metadata.cu", + "csrc/sm90/kernels/mla_combine.cu", + "csrc/sm90/kernels/splitkv_mla.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), @@ -66,12 +70,49 @@ def get_features_args(): "--use_fast_math", "--ptxas-options=-v,--register-usage-level=10" ] - + cc_flag + + cc_flag_sm90 ) + get_features_args(), }, include_dirs=[ - Path(this_dir) / "csrc", + Path(this_dir) / "csrc" / "sm90", + Path(this_dir) / "csrc" / "cutlass" / "include", + ], + ) +) + +ext_modules.append( + CUDAExtension( + name="flash_mla_sm100", + sources=[ + "csrc/sm100/pybind.cu", + "csrc/sm100/fmha_cutlass_fwd_sm100.cu", + "csrc/sm100/fmha_cutlass_bwd_sm100.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"], + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-DNDEBUG", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "-lineinfo", + 
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", + ] + + cc_flag_sm100 + ), + }, + include_dirs=[ + Path(this_dir) / "csrc" / "sm100", Path(this_dir) / "csrc" / "cutlass" / "include", + Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", ], ) ) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla_sm90.py similarity index 100% rename from tests/test_flash_mla.py rename to tests/test_flash_mla_sm90.py diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py new file mode 100644 index 0000000..832c9fb --- /dev/null +++ b/tests/test_fmha_sm100.py @@ -0,0 +1,199 @@ +import random + +import torch +from torch.utils.checkpoint import checkpoint +import triton + +from flash_mla import flash_attn_varlen_func + + +def get_window_size(causal, window): + if window > 0: + window_size = (window - 1, 0) if causal else (window - 1, window - 1) + else: + window_size = (-1, -1) + return window_size + + +def get_attn_bias(s_q, s_k, causal, window): + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32) + if causal: + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + if window > 0: + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q - window) + attn_bias.masked_fill_(temp_mask, float("-inf")) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q + window - 1) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + return attn_bias + + +def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + x, y = x.double(), y.double() + RMSE = ((x - y) * (x - y)).mean().sqrt().item() + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + amax_diff = (x - y).abs().max().item() + # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") + assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}" + + +def sdpa(query, key, value, 
attn_bias, softmax_scale=None): + key = key.repeat_interleave(h // h_k, dim=-3) + value = value.repeat_interleave(h // h_k, dim=-3) + if softmax_scale is None: + softmax_scale = query.shape[-1] ** (-0.5) + attn_weight = query @ key.transpose(-2, -1) * softmax_scale + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight.to(query.dtype) @ value, lse + + +def sdpa_checkpoint(*args, **kwargs): + return checkpoint(sdpa, *args, use_reentrant=False, **kwargs) + + +def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd): + print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}") + torch.manual_seed(0) + random.seed(0) + + seqlens_q = torch.full((b,), mean_sq, dtype=torch.int32) + seqlens_k = torch.full((b,), mean_sk, dtype=torch.int32) + + if varlen: + for i in range(b): + seqlens_q[i] = max(random.normalvariate(mean_sq, mean_sq / 2), 1) + for i in range(b): + seqlens_k[i] = max(random.normalvariate(mean_sk, mean_sk / 2), seqlens_q[i].item()) + cu_seqlens_q = torch.cumsum(torch.nn.functional.pad(seqlens_q, (1, 0)), 0, dtype=torch.int32) + cu_seqlens_k = torch.cumsum(torch.nn.functional.pad(seqlens_k, (1, 0)), 0, dtype=torch.int32) + total_q = seqlens_q.sum().item() + total_k = seqlens_k.sum().item() + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + total_attn_compute = sum([(get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), + causal, window) == 0).sum().item() for i in range(b)]) + # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") + + q = torch.randn(total_q, h, d) + k = torch.randn(total_k, h_k, d) + v = torch.randn(total_k, h_k, dv) + grad_out = torch.randn(total_q, h, dv) + softmax_scale = (d + 100) ** (-0.5) + + offst_q = total_q + offst_kv = total_k + + q1_with_buffer = 
torch.empty(total_q + total_q, h, d, device=device, dtype=dtype) + k1_with_buffer = torch.empty(offst_kv + total_k, h_k, d, device=device, dtype=dtype) + v1_with_buffer = torch.empty(offst_kv + total_k, h_k, dv, device=device, dtype=dtype) + q1_with_buffer[total_q:] = q + k1_with_buffer[offst_kv:] = k + v1_with_buffer[offst_kv:] = v + q1 = q1_with_buffer[offst_q:].requires_grad_() + k1 = k1_with_buffer[offst_kv:].requires_grad_() + v1 = v1_with_buffer[offst_kv:].requires_grad_() + + q2 = q.clone().requires_grad_() + k2 = k.clone().requires_grad_() + v2 = v.clone().requires_grad_() + + def flash_attn(): + q1.grad = k1.grad = v1.grad = None + kwargs = {} + if causal: + kwargs["causal"] = causal + if window != 0: + kwargs["window_size"] = get_window_size(causal, window) + return flash_attn_varlen_func(q1, k1, v1, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, + max_seqlen_k, softmax_scale=softmax_scale, is_varlen=varlen, **kwargs) + + def torch_attn(): + q2.grad = k2.grad = v2.grad = None + out = [] + lse = [] + for i in range(b): + OUT, LSE = sdpa_checkpoint( + q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()].float().transpose(-3, -2), + k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), + v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), + attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window), + softmax_scale=softmax_scale, + ) + out.append(OUT.transpose(-3, -2)) + lse.append(LSE.transpose(-2, -1)) + out = torch.cat(out) + lse = torch.cat(lse) + return out, lse + + out_flash, lse_flash = flash_attn() + out_torch, lse_torch = torch_attn() + assert_close(out_flash, out_torch, "out") + assert_close(lse_flash, lse_torch, "lse") + + if has_bwd: + out_flash.backward(grad_out, retain_graph=True) + out_torch.backward(grad_out, retain_graph=True) + assert_close(q1.grad, q2.grad, "dq") + assert_close(k1.grad, k2.grad, "dk") + assert_close(v1.grad, v2.grad, "dv") + dq1 = q1.grad.clone() 
+ dk1 = k1.grad.clone() + dv1 = v1.grad.clone() + + def forward(): + return flash_attn() + + def backward(): + q1.grad = k1.grad = v1.grad = None + out_flash.backward(grad_out, retain_graph=True) + + for _ in range(5): + out, lse = forward() + assert torch.equal(out, out_flash), "out deterministic check failed!" + assert torch.equal(lse, lse_flash), "lse deterministic check failed!" + if has_bwd: + backward() + # assert torch.equal(q1.grad, dq1), "dq deterministic check failed!" + assert torch.equal(k1.grad, dk1), "dk deterministic check failed!" + assert torch.equal(v1.grad, dv1), "dv deterministic check failed!" + + # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + # forward() + # if has_bwd: + # backward() + # print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120)) + + def timer(func, name): + t = triton.testing.do_bench(func, warmup=2, rep=3) + FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2))) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOP/s, name: {name}") + return t + + timer(forward, "fwd") + if has_bwd: + timer(backward, "bwd") + + +if __name__ == "__main__": + dtype = torch.bfloat16 + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + + b = 4 + window = 0 + has_bwd = False + + for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]: + for varlen in [False, True]: + for (h, h_k) in [(32, 32), (32, 4)]: + if h != h_k: + has_bwd = False + else: + has_bwd = True + for (d, dv) in [(128, 128), (192, 128)]: + for causal in [False, True]: + test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd) From ef5b1a69fc56fc8ea9405509cf67d2984acd205d Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Thu, 14 Aug 2025 09:34:17 +0800 Subject: [PATCH 05/24] Drop support for CUDA <12.8 --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 
deletions(-) diff --git a/README.md b/README.md index 07e021a..75bbb16 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,7 @@ Currently released: ## Requirements - Hopper GPUs -- CUDA 12.3 and above - - **But we highly recommend 12.8 or above for the best performance** +- CUDA 12.8 and above - PyTorch 2.0 and above ## Quick start From c7590278ce1bfb86440e514c697bc4190aecd19c Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Thu, 14 Aug 2025 09:37:44 +0800 Subject: [PATCH 06/24] Fix accuracy issue in sum_OdO kernel --- csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp index bdcf1cb..db6a9b4 100644 --- a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -140,7 +140,7 @@ struct FmhaKernelBwdSumOdO { *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); for (int v = 0; v < kElementsPerLoad; v++) { - acc += value_O[v] * value_dO[v]; + acc += ElementAcc(value_O[v]) * ElementAcc(value_dO[v]); } } From 2d291b0c31050ba259e87a4ae6fbf75a47824716 Mon Sep 17 00:00:00 2001 From: zhang Date: Mon, 25 Aug 2025 11:41:50 +0800 Subject: [PATCH 07/24] Remove tma padding for fwd inputs (#85) --- csrc/sm100/collective/fmha_fusion.hpp | 6 +-- .../sm100_fmha_load_tma_warpspecialized.hpp | 49 ++++++------------- ...m100_fmha_mla_load_tma_warpspecialized.hpp | 49 ++++++------------- csrc/sm100/fmha_cutlass_fwd_sm100.cu | 5 +- csrc/sm100/fmha_cutlass_fwd_sm100.cuh | 11 ++--- ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 29 ++++++++--- tests/test_fmha_sm100.py | 15 ++---- 7 files changed, 68 insertions(+), 96 deletions(-) diff --git a/csrc/sm100/collective/fmha_fusion.hpp b/csrc/sm100/collective/fmha_fusion.hpp index 1486767..8c09eaf 100644 --- a/csrc/sm100/collective/fmha_fusion.hpp +++ b/csrc/sm100/collective/fmha_fusion.hpp @@ -220,13 +220,13 @@ struct CausalMask : NoMask 
{ BlkCoord const& blk_coord, TileShape const& tile_shape, ProblemSize const& problem_size) { - + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); if constexpr (IsQBegin) { return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } else { - const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); - return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); + const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ; + return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count); } } diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp index 1951056..3606dcc 100644 --- a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized { auto dQ = args.dQ; auto dK = args.dK; auto dV = args.dV; - auto problem_shape_qk = problem_shape; + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(problem_shape).cumulative_length; - if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dQ) = get<0>(dQ); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_Q -= max_length_q * get<0>(dQ); - } - } - - if constexpr (is_variable_length_v>) { - auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; - if (cumulative_length_kv != nullptr) { - int max_length_kv = get<1>(problem_shape).max_length; - // for variable sequence lenght, 
the batch is in units of row_stride - get<2,1>(dK) = get<0>(dK); - get<2,1>(dV) = get<0>(dV); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_K -= max_length_kv * get<0>(dK); - ptr_V -= max_length_kv * get<0>(dV); + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); } + } else { + problem_shape_qk = problem_shape; } auto params_qk = CollectiveMmaQK::to_underlying_arguments( @@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized { Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); int q_offs_0 = 0; - int q_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(params_problem_shape).max_length; - q_offs_0 = max_length_q - get<0>(problem_shape); - q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; get<2,1>(blk_coord_q) = 0; } } - Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); @@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized { Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); int kv_offs_0 = 0; - int kv_offs_2_1 = 0; if constexpr 
(is_variable_length_v>) { auto cumulative_length = get<1>(params_problem_shape).cumulative_length; if (cumulative_length != nullptr) { - int max_length = get<1>(params_problem_shape).max_length; - kv_offs_0 = max_length - get<1>(problem_shape); - kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; get<2,1>(blk_coord_kv) = 0; } } - Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); @@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized { ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); - Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); diff --git a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp index c2d3e2b..c8fc13b 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { auto dQ = args.dQ; auto dK = args.dK; auto dV = args.dV; - auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; if constexpr (is_variable_length_v>) { auto cumulative_length_q = 
get<0>(problem_shape).cumulative_length; - if (cumulative_length_q != nullptr) { - int max_length_q = get<0>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dQ) = get<0>(dQ); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_Q -= max_length_q * get<0>(dQ); - } - } - - if constexpr (is_variable_length_v>) { - auto cumulative_length_kv = get<1>(problem_shape).cumulative_length; - if (cumulative_length_kv != nullptr) { - int max_length_kv = get<1>(problem_shape).max_length; - // for variable sequence lenght, the batch is in units of row_stride - get<2,1>(dK) = get<0>(dK); - get<2,1>(dV) = get<0>(dV); - get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape))); - // offset ptr by the amount we add back in later - ptr_K -= max_length_kv * get<0>(dK); - ptr_V -= max_length_kv * get<0>(dV); + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); } + } else { + problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));; } auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); @@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); int q_offs_0 = 0; - int q_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { - 
int max_length_q = get<0>(params_problem_shape).max_length; - q_offs_0 = max_length_q - get<0>(problem_shape); - q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape); + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; get<2,1>(blk_coord_q) = 0; } } - Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); @@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); int kv_offs_0 = 0; - int kv_offs_2_1 = 0; if constexpr (is_variable_length_v>) { auto cumulative_length = get<1>(params_problem_shape).cumulative_length; if (cumulative_length != nullptr) { - int max_length = get<1>(params_problem_shape).max_length; - kv_offs_0 = max_length - get<1>(problem_shape); - kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape); + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; get<2,1>(blk_coord_kv) = 0; } } - Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); @@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); - Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); Tensor gV_dkl = 
local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cu b/csrc/sm100/fmha_cutlass_fwd_sm100.cu index e322709..997886e 100644 --- a/csrc/sm100/fmha_cutlass_fwd_sm100.cu +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cu @@ -18,8 +18,9 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va static constexpr bool IsVarlen = std::is_same_v; static constexpr bool IsMla = std::is_same_v; static constexpr bool IsCausalMask = std::is_same_v>; - using Option = std::conditional_t, - Option>; + using Option = + std::conditional_t, + Option>; run_fmha_fwd( workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh index 71831bb..987a5f7 100644 --- a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh +++ b/csrc/sm100/fmha_cutlass_fwd_sm100.cuh @@ -143,8 +143,8 @@ struct FwdRunner { ProblemShapeType problem_size_for_launch; - get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; - get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; + get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q}; + get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv}; get<2>(problem_size_for_launch) = get<2>(problem_size); get<3>(problem_size_for_launch) = get<3>(problem_size); @@ -206,10 +206,6 @@ struct FwdRunner { void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr, void *cumulative_length_q, void *cumulative_length_kv) { auto problem_shape_ = problem_shape; - if constexpr (kIsVarlen) { - get<0>(problem_shape_).cumulative_length = static_cast(cumulative_length_q); - get<1>(problem_shape_).cumulative_length = static_cast(cumulative_length_kv); - } typename Operation::Arguments arguments{ problem_shape_, @@ -230,6 +226,7 @@ struct FwdRunner { int total_seqlen_q = q.size(0); int 
total_seqlen_kv = k.size(0); + ProblemShapeType problem_shape = initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv, cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); @@ -322,7 +319,7 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v auto options = get_options(); if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 && - (!std::is_same_v)) { + (std::is_same_v> || std::is_same_v>)) { FwdRunner runner; runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); diff --git a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index 8fe503b..43bb035 100644 --- a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -465,6 +465,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { else if (role == WarpRole::Correction) { cutlass::arch::warpgroup_reg_dealloc(); + bool has_valid = false; + CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); @@ -476,6 +478,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + has_valid = true; + if (get<1>(logical_problem_shape) == 0) { mainloop.correction_empty( blk_coord, @@ -505,16 +509,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { if constexpr (NumWarpsEpilogue == 0) { static_assert(NumWarpsCorrection == 1); - uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + if (has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } } else if (role == WarpRole::MMA) { warpgroup_reg_set(); - 
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); - __syncwarp(); + bool allocated = false; CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { @@ -527,6 +532,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + if (!allocated) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + allocated = true; + } + if (get<1>(logical_problem_shape) == 0) { continue; } @@ -580,6 +591,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { else if (role == WarpRole::Epilogue) { warpgroup_reg_set(); + bool has_valid = false; + CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); @@ -591,6 +604,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { continue; } + has_valid = true; + epilogue.store( blk_coord, logical_problem_shape, params.epilogue, params.problem_shape, @@ -602,8 +617,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { static_assert(NumWarpsEpilogue <= 1); if constexpr (NumWarpsEpilogue == 1) { - uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + if(has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } } } diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 832c9fb..7cb19a2 100644 --- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -82,18 +82,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win grad_out = torch.randn(total_q, h, dv) softmax_scale = (d + 100) ** (-0.5) - offst_q = total_q - offst_kv = total_k - - q1_with_buffer = torch.empty(total_q + total_q, h, d, device=device, dtype=dtype) - k1_with_buffer = torch.empty(offst_kv + total_k, h_k, d, device=device, dtype=dtype) - v1_with_buffer = 
torch.empty(offst_kv + total_k, h_k, dv, device=device, dtype=dtype) - q1_with_buffer[total_q:] = q - k1_with_buffer[offst_kv:] = k - v1_with_buffer[offst_kv:] = v - q1 = q1_with_buffer[offst_q:].requires_grad_() - k1 = k1_with_buffer[offst_kv:].requires_grad_() - v1 = v1_with_buffer[offst_kv:].requires_grad_() + q1 = q.clone().requires_grad_() + k1 = k.clone().requires_grad_() + v1 = v.clone().requires_grad_() q2 = q.clone().requires_grad_() k2 = k.clone().requires_grad_() From eb7583357f0a2ca44a00d528639e0fb374c4254a Mon Sep 17 00:00:00 2001 From: Li Xiang Date: Mon, 25 Aug 2025 13:44:30 +0800 Subject: [PATCH 08/24] Remove cudaMalloc and cudaFree in backward (#87) * get rid of cudaMalloc and cudaFree * minor fix --------- Co-authored-by: Jiashi Li --- csrc/sm100/common/utils.hpp | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/common/utils.hpp index f43770d..6815839 100644 --- a/csrc/sm100/common/utils.hpp +++ b/csrc/sm100/common/utils.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include "cutlass/numeric_types.h" #include "helper.h" @@ -36,18 +37,21 @@ struct DeviceAllocation { T* ptr_ = nullptr; size_t offset_ = 0; size_t size_ = 0; + torch::Tensor tensor; DeviceAllocation(DeviceAllocation const&) = delete; DeviceAllocation& operator=(DeviceAllocation const&) = delete; DeviceAllocation() = default; DeviceAllocation(size_t size) { reset(size); } - ~DeviceAllocation() { reset(); } + ~DeviceAllocation() {} void reset(size_t size, size_t offset=0) { - reset(); - auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset)); - assert(ret == cudaSuccess); + size_t num_element = sizeof(T) * (size + offset); + auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + + tensor = torch::empty(num_element, options); + ptr_ = tensor.data_ptr(); size_ = size; offset_ = offset; } @@ -60,24 +64,7 @@ struct DeviceAllocation { return ptr_ + offset_; } - void reset() { - 
if (ptr_ != nullptr) { - auto ret = cudaFree(ptr_); - assert(ret == cudaSuccess); - } - } - size_t size() const { return size_; } size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } - - void copy_from_host(const T* ptr, size_t sz) { - auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); - assert(ret == cudaSuccess); - } - - void copy_from_device(const T* ptr, size_t sz) { - auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); - assert(ret == cudaSuccess); - } -}; \ No newline at end of file +}; From 261330bb6dfacdff8ff4b67e126417863b31aa72 Mon Sep 17 00:00:00 2001 From: Zeyu WANG Date: Wed, 27 Aug 2025 19:59:57 +0800 Subject: [PATCH 09/24] fix calc space bug (#91) * fix calc space bug * use python code to allocate the buffer for backward kernel --- csrc/sm100/common/utils.hpp | 39 +-------------------------- csrc/sm100/device/fmha_device_bwd.hpp | 12 ++++----- csrc/sm100/fmha_cutlass_bwd_sm100.cuh | 7 ++--- flash_mla/flash_mla_interface.py | 2 +- 4 files changed, 10 insertions(+), 50 deletions(-) diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/common/utils.hpp index 6815839..fdaeff0 100644 --- a/csrc/sm100/common/utils.hpp +++ b/csrc/sm100/common/utils.hpp @@ -30,41 +30,4 @@ struct cutlass_dtype<__nv_fp8_e5m2> { }; template -using cutlass_dtype_t = typename cutlass_dtype::type; - -template -struct DeviceAllocation { - T* ptr_ = nullptr; - size_t offset_ = 0; - size_t size_ = 0; - torch::Tensor tensor; - - DeviceAllocation(DeviceAllocation const&) = delete; - DeviceAllocation& operator=(DeviceAllocation const&) = delete; - - DeviceAllocation() = default; - DeviceAllocation(size_t size) { reset(size); } - ~DeviceAllocation() {} - - void reset(size_t size, size_t offset=0) { - size_t num_element = sizeof(T) * (size + offset); - auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); - - tensor = torch::empty(num_element, options); - ptr_ = tensor.data_ptr(); - size_ = size; - offset_ = 
offset; - } - - T* get() { - return ptr_ + offset_; - } - - const T* get() const { - return ptr_ + offset_; - } - - size_t size() const { return size_; } - - size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } -}; +using cutlass_dtype_t = typename cutlass_dtype::type; \ No newline at end of file diff --git a/csrc/sm100/device/fmha_device_bwd.hpp b/csrc/sm100/device/fmha_device_bwd.hpp index d2463ac..76b7ed5 100644 --- a/csrc/sm100/device/fmha_device_bwd.hpp +++ b/csrc/sm100/device/fmha_device_bwd.hpp @@ -225,11 +225,11 @@ class Sm100FmhaBwd { int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment size_t workspace_bytes = 0; // OdO vector - workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + workspace_bytes += sizeof(ElementAccumulator) * B*H*Q; // scaled LSE vector - workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + workspace_bytes += sizeof(ElementAccumulator) * B*H*Q; // FP32 versions of outputs that are churned (start off with Q only) - workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + workspace_bytes += sizeof(ElementAccumulator) * B*H*Q*D; return workspace_bytes; } @@ -247,7 +247,7 @@ class Sm100FmhaBwd { ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); params_.dQ_acc = dQ_acc; - params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + params_.dQ_acc_size = sizeof(ElementAccumulator) * B*H*Q*D; auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); auto args_convert = to_convert_arguments(args, dQ_acc); params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); @@ -274,9 +274,9 @@ class Sm100FmhaBwd { int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment char* workspace_chr = reinterpret_cast(workspace); ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); - workspace_chr += B*H*Q * sizeof(ElementAccumulator); + workspace_chr += sizeof(ElementAccumulator) * B*H*Q; ElementAccumulator* 
scaled_lse = reinterpret_cast(workspace_chr); - workspace_chr += B*H*Q * sizeof(ElementAccumulator); + workspace_chr += sizeof(ElementAccumulator) * B*H*Q; ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); } diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh index 2b19be2..f4a1ce8 100644 --- a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh +++ b/csrc/sm100/fmha_cutlass_bwd_sm100.cuh @@ -174,13 +174,10 @@ struct BwdRunner { Operation op; - size_t workspace_size = 0; - workspace_size = Operation::get_workspace_size(arguments); - DeviceAllocation workspace(workspace_size); - uint8_t* workspace_ptr = workspace.get(); + uint8_t* workspace_ptr = static_cast(workspace_buffer.data_ptr()); CUTLASS_CHECK(op.can_implement(arguments)); - CUTLASS_CHECK(op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(op.initialize(arguments, workspace_ptr)); CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); } diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 9c669ba..084117e 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -154,7 +154,7 @@ def _flash_attn_varlen_backward( max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 bs = cu_seqlens_qo.shape[0] - 1 workspace_bytes = 0 - workspace_bytes += 4 * qo_total_len * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse if num_qo_heads != num_kv_heads: workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc From ebf30641e27b777c22c38968b8c0aa38da1bac19 Mon Sep 17 00:00:00 2001 From: zhang Date: Mon, 22 Sep 2025 17:08:22 +0800 Subject: [PATCH 10/24] Refine handling for q/v sequence length equals zero. 
(#92) --- .../sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp | 5 ++++- .../sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp | 4 ---- .../sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp | 3 +++ .../sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp | 4 ---- .../collective/sm100_fmha_mla_load_tma_warpspecialized.hpp | 3 +++ csrc/sm100/device/fmha.hpp | 5 +++++ csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp | 2 +- csrc/sm100/kernel/fmha_tile_scheduler.hpp | 2 +- tests/test_fmha_sm100.py | 3 +++ 9 files changed, 20 insertions(+), 11 deletions(-) diff --git a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp index 616357c..6f9bba3 100644 --- a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -118,12 +118,15 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { auto cumulative_length_q = get<0>(problem_shape).cumulative_length; if (cumulative_length_q != nullptr) { int max_length_q = get<0>(problem_shape).max_length; + get<0>(problem_shape_O).max_length = max(1, max_length_q); // for variable sequence lenght, the batch is in units of row_stride get<2,1>(dO) = get<0>(dO); - get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); + get<2,1>(problem_shape_O) = max(1, max_length_q * (1 + get<2,1>(problem_shape_O))); // offset ptr by the amount we add back in later ptr_O -= max_length_q * get<0>(dO); } + } else { + get<0>(problem_shape_O) = max(1, get<0>(problem_shape_O)); } auto tma_store_o = make_tma_copy( diff --git a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index f39fd75..56f571a 100644 --- a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -1155,10 +1155,6 
@@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { float lse = -INFINITY; int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); -#define DSHOW(x) print(#x ": "); print(x); print("\n") - if (threadIdx.x % 128 == 0 && block0()) { - DSHOW(sO); - } #if 1 using ElementOut = typename CollectiveEpilogue::ElementOut; diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp index 3606dcc..86e3149 100644 --- a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -112,6 +112,9 @@ struct Sm100FmhaLoadTmaWarpspecialized { problem_shape_qk = problem_shape; } + get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk)); + get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk)); + auto params_qk = CollectiveMmaQK::to_underlying_arguments( problem_shape_qk, typename CollectiveMmaQK::Arguments { diff --git a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp index bf41af9..994bd4e 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -1162,10 +1162,6 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { float lse = -INFINITY; int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); -#define DSHOW(x) print(#x ": "); print(x); print("\n") - if (threadIdx.x % 128 == 0 && block0()) { - DSHOW(sO); - } #if 1 using ElementOut = typename CollectiveEpilogue::ElementOut; diff --git a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp index c8fc13b..0b7d76f 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp +++ b/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -119,6 +119,9 @@ struct 
Sm100MlaFwdLoadTmaWarpspecialized { problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));; } + get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk)); + get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk)); + auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); auto params_qk = CollectiveMmaQK::to_underlying_arguments( diff --git a/csrc/sm100/device/fmha.hpp b/csrc/sm100/device/fmha.hpp index f8406d3..5dcb069 100644 --- a/csrc/sm100/device/fmha.hpp +++ b/csrc/sm100/device/fmha.hpp @@ -208,6 +208,11 @@ class FMHA { dim3 const block = Kernel::get_block_shape(); dim3 const grid = get_grid_shape(params); + // No need to launch the kernel + if(grid.x == 0 || grid.y == 0 || grid.z == 0) { + return Status::kSuccess; + } + // configure smem size and carveout int smem_size = Kernel::SharedStorageSize; diff --git a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp index 572e67f..c879fe6 100644 --- a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp +++ b/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp @@ -160,7 +160,7 @@ struct CausalPersistentTileScheduler { return Params { num_blocks, - { size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) }, + { size<3,0>(problem_size) }, { max(1, num_m_blocks) }, { size<3,1>(problem_size) }, hw_info }; } diff --git a/csrc/sm100/kernel/fmha_tile_scheduler.hpp b/csrc/sm100/kernel/fmha_tile_scheduler.hpp index 119f069..97e7962 100644 --- a/csrc/sm100/kernel/fmha_tile_scheduler.hpp +++ b/csrc/sm100/kernel/fmha_tile_scheduler.hpp @@ -123,7 +123,7 @@ struct PersistentTileScheduler { return Params { num_blocks, - { num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, + { max(1, num_m_blocks)}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, hw_info }; } diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 7cb19a2..2ba8b46 100644 
--- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -29,6 +29,9 @@ def get_attn_bias(s_q, s_k, causal, window): def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + close_tensor = torch.isclose(x.to(torch.float32), y.to(torch.float32), rtol=1e-5, atol=1e-5) + if close_tensor.all(): + return x, y = x.double(), y.double() RMSE = ((x - y) * (x - y)).mean().sqrt().item() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) From c28eca99dbc664dd2716415ed03492afe5fefade Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Wed, 24 Sep 2025 14:22:05 +0800 Subject: [PATCH 11/24] Reorganize files and add sparse prefill/decoding kernels on hopper --- .gitignore | 1 + README.md | 169 ++++- csrc/{sm90/kernels => }/params.h | 34 +- csrc/pybind.cpp | 442 +++++++++++ .../dense}/collective/fmha_common.hpp | 0 .../dense}/collective/fmha_fusion.hpp | 0 ..._fmha_fwd_epilogue_tma_warpspecialized.hpp | 0 ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 6 +- .../sm100_fmha_load_tma_warpspecialized.hpp | 4 +- ...a_mla_fwd_mainloop_tma_warpspecialized.hpp | 8 +- ...m100_fmha_mla_load_tma_warpspecialized.hpp | 4 +- .../dense}/common/gather_tensor.hpp | 0 .../sm100/{ => prefill/dense}/common/helper.h | 0 .../sm100/{ => prefill/dense}/common/mask.cuh | 0 .../dense}/common/pipeline_mla.hpp | 0 .../{ => prefill/dense}/common/pow_2.hpp | 0 .../{ => prefill/dense}/common/utils.hpp | 0 .../sm100/{ => prefill/dense}/device/fmha.hpp | 0 .../dense}/device/fmha_device_bwd.hpp | 0 .../dense}/fmha_cutlass_bwd_sm100.cu | 4 +- .../dense}/fmha_cutlass_bwd_sm100.cuh | 0 .../dense}/fmha_cutlass_fwd_sm100.cu | 11 +- .../dense}/fmha_cutlass_fwd_sm100.cuh | 0 .../{pybind.cu => prefill/dense/interface.h} | 9 +- .../kernel/fmha_causal_tile_scheduler.hpp | 0 .../dense}/kernel/fmha_kernel_bwd_convert.hpp | 7 + .../dense}/kernel/fmha_kernel_bwd_sum_OdO.hpp | 7 + .../dense}/kernel/fmha_options.hpp | 0 .../dense}/kernel/fmha_tile_scheduler.hpp | 0 
...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 9 +- ...mha_bwd_mla_kernel_tma_warpspecialized.hpp | 9 +- ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 17 +- csrc/sm90/{kernels => decode/dense}/config.h | 2 - .../{kernels => decode/dense}/splitkv_mla.cu | 43 +- csrc/sm90/decode/dense/splitkv_mla.h | 10 + csrc/sm90/{kernels => decode/dense}/traits.h | 0 .../decode/sparse_fp8/components/config.h | 121 +++ .../decode/sparse_fp8/components/dequant.h | 88 +++ .../decode/sparse_fp8/components/epilogue.h | 87 +++ .../decode/sparse_fp8/components/helpers.h | 86 +++ .../sparse_fp8/components/named_barriers.h | 10 + csrc/sm90/decode/sparse_fp8/splitkv_mla.cu | 614 +++++++++++++++ csrc/sm90/decode/sparse_fp8/splitkv_mla.h | 9 + csrc/sm90/flash_api.cpp | 216 ------ csrc/sm90/kernels/get_mla_metadata.h | 5 - csrc/sm90/kernels/mla_combine.h | 6 - csrc/sm90/kernels/splitkv_mla.h | 6 - csrc/sm90/prefill/sparse/fwd.cu | 709 ++++++++++++++++++ csrc/sm90/prefill/sparse/fwd.h | 9 + csrc/sm90/prefill/sparse/helpers.h | 177 +++++ .../kernels => smxx}/get_mla_metadata.cu | 28 +- csrc/smxx/get_mla_metadata.h | 5 + csrc/{sm90/kernels => smxx}/mla_combine.cu | 13 +- csrc/smxx/mla_combine.h | 6 + csrc/{sm90/kernels => }/utils.h | 34 + flash_mla/__init__.py | 1 + flash_mla/flash_mla_interface.py | 104 +-- setup.py | 115 +-- tests/lib.py | 73 ++ tests/quant.py | 68 ++ tests/test_flash_mla_decoding.py | 343 +++++++++ tests/test_flash_mla_prefill.py | 197 +++++ tests/test_flash_mla_sm90.py | 153 ---- tests/test_fmha_sm100.py | 73 +- 64 files changed, 3510 insertions(+), 642 deletions(-) rename csrc/{sm90/kernels => }/params.h (59%) create mode 100644 csrc/pybind.cpp rename csrc/sm100/{ => prefill/dense}/collective/fmha_common.hpp (100%) rename csrc/sm100/{ => prefill/dense}/collective/fmha_fusion.hpp (100%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp (100%) rename csrc/sm100/{ => 
prefill/dense}/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_load_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/common/gather_tensor.hpp (100%) rename csrc/sm100/{ => prefill/dense}/common/helper.h (100%) rename csrc/sm100/{ => prefill/dense}/common/mask.cuh (100%) rename csrc/sm100/{ => prefill/dense}/common/pipeline_mla.hpp (100%) rename csrc/sm100/{ => prefill/dense}/common/pow_2.hpp (100%) rename csrc/sm100/{ => prefill/dense}/common/utils.hpp (100%) rename csrc/sm100/{ => prefill/dense}/device/fmha.hpp (100%) rename csrc/sm100/{ => prefill/dense}/device/fmha_device_bwd.hpp (100%) rename csrc/sm100/{ => prefill/dense}/fmha_cutlass_bwd_sm100.cu (98%) rename csrc/sm100/{ => prefill/dense}/fmha_cutlass_bwd_sm100.cuh (100%) rename csrc/sm100/{ => prefill/dense}/fmha_cutlass_fwd_sm100.cu (98%) rename csrc/sm100/{ => prefill/dense}/fmha_cutlass_fwd_sm100.cuh (100%) rename csrc/sm100/{pybind.cu => prefill/dense/interface.h} (84%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_causal_tile_scheduler.hpp (100%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_kernel_bwd_convert.hpp (97%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_kernel_bwd_sum_OdO.hpp (97%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_options.hpp (100%) rename csrc/sm100/{ => prefill/dense}/kernel/fmha_tile_scheduler.hpp (100%) rename csrc/sm100/{ => prefill/dense}/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp (99%) rename csrc/sm100/{ => prefill/dense}/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp (98%) rename csrc/sm90/{kernels => decode/dense}/config.h (78%) rename 
csrc/sm90/{kernels => decode/dense}/splitkv_mla.cu (97%) create mode 100644 csrc/sm90/decode/dense/splitkv_mla.h rename csrc/sm90/{kernels => decode/dense}/traits.h (100%) create mode 100644 csrc/sm90/decode/sparse_fp8/components/config.h create mode 100644 csrc/sm90/decode/sparse_fp8/components/dequant.h create mode 100644 csrc/sm90/decode/sparse_fp8/components/epilogue.h create mode 100644 csrc/sm90/decode/sparse_fp8/components/helpers.h create mode 100644 csrc/sm90/decode/sparse_fp8/components/named_barriers.h create mode 100644 csrc/sm90/decode/sparse_fp8/splitkv_mla.cu create mode 100644 csrc/sm90/decode/sparse_fp8/splitkv_mla.h delete mode 100644 csrc/sm90/flash_api.cpp delete mode 100644 csrc/sm90/kernels/get_mla_metadata.h delete mode 100644 csrc/sm90/kernels/mla_combine.h delete mode 100644 csrc/sm90/kernels/splitkv_mla.h create mode 100644 csrc/sm90/prefill/sparse/fwd.cu create mode 100644 csrc/sm90/prefill/sparse/fwd.h create mode 100644 csrc/sm90/prefill/sparse/helpers.h rename csrc/{sm90/kernels => smxx}/get_mla_metadata.cu (64%) create mode 100644 csrc/smxx/get_mla_metadata.h rename csrc/{sm90/kernels => smxx}/mla_combine.cu (94%) create mode 100644 csrc/smxx/mla_combine.h rename csrc/{sm90/kernels => }/utils.h (71%) create mode 100644 tests/lib.py create mode 100644 tests/quant.py create mode 100644 tests/test_flash_mla_decoding.py create mode 100644 tests/test_flash_mla_prefill.py delete mode 100644 tests/test_flash_mla_sm90.py diff --git a/.gitignore b/.gitignore index 4535280..6b00da7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ dist/ /.vscode compile_commands.json .cache +/dev diff --git a/README.md b/README.md index 75bbb16..8cf01a3 100644 --- a/README.md +++ b/README.md @@ -1,69 +1,184 @@ # FlashMLA -## Performance Update (2025.04.22) +## Introduction -We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 
GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀 +FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](TODO) models. This repository contains the following implementations: -Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). +**Sparse Attention Kernels** -The new kernel primarily targets compute-intensive settings (where the number of q heads $\times$ the number of q tokens per request (if MTP is disabled then it's 1) $\ge 64$). For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. +*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](TODO).* -## Introduction +- Token-level sparse attention for the prefill stage +- Token-level sparse attention for the decoding stage, with FP8 KV cache -FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. +**Dense Attention Kernels** -Currently released: -- BF16, FP16 -- Paged kvcache with block size of 64 +- Dense attention for the prefill stage +- Dense attention for the decoding stage -## Requirements +## News -- Hopper GPUs -- CUDA 12.8 and above -- PyTorch 2.0 and above +- **2025.09.26(TODO) Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](TODO), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. 
+- **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell! +- **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). +- **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀 -## Quick start +## Performance -### Install +#### Test & benchmark MLA decoding (Sparse & Dense): ```bash -pip install -v . +python tests/test_flash_mla_decoding.py ``` -### Benchmark +The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8. -#### Testing MLA Decoding +#### Test & benchmark MHA prefill (Dense): ```bash -python tests/test_flash_mla_sm90.py +python tests/test_fmha_sm100.py ``` -#### Testing MLA Forward/Backward +It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation on B200, as reported by NVIDIA. + +#### Test & benchmark MLA prefill (Sparse): ```bash -python tests/test_fmha_sm100.py +python tests/test_flash_mla_prefill.py ``` -It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. +It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8. 
+ +## Requirements + +- Hopper / Blackwell GPUs (See the support matrix below) +- CUDA 12.8 and above (CUDA 12.9+ is required for Blackwell kernels) +- PyTorch 2.0 and above -Note. For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. +Support matrix: -### Usage +| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format | +| :---: | :---: | :---: | :---: | +| Dense Decoding | Hopper | MQA | BF16 | +| Sparse Decoding | Hopper | MQA | FP8 [1] | +| Dense Prefill | Blackwell | MHA | | +| Sparse Prefill | Hopper | MQA | | + +[1]: For more details on using the FP8 KV cache, see the "FP8 KV Cache" description in the Usage section below. + +[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](TODO). + +## Installation + +```bash +git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla +cd flash-mla +git submodule update --init --recursive +pip install -v . +``` + +## Usage + +### MLA Decoding + +To use the MLA decoding kernels, call `get_mla_metadata` once before the decoding loop to get the tile scheduler metadata. Then, call `flash_mla_with_kvcache` in each decoding step. For example: ```python from flash_mla import get_mla_metadata, flash_mla_with_kvcache -tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) +tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens, + s_q * h_q // h_kv, + h_kv, + h_q, + is_fp8, + topk, +) for i in range(num_layers): ...
o_i, lse_i = flash_mla_with_kvcache( q_i, kvcache_i, block_table, cache_seqlens, dv, - tile_scheduler_metadata, num_splits, causal=True, + tile_scheduler_metadata, num_splits, + is_causal, is_fp8_kvcache, indices, ) ... ``` +Where: + +- `s_q` is the number of q tokens per q sequence. If MTP (speculative decoding) is disabled, it should be 1. +- `h_kv` is the number of key-value heads. +- `h_q` is the number of query heads. + +**FP8 KV Cache:** +If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16. + +In the "FP8 with scale" format, each token's KV cache is 656 bytes, structured as: +- **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values. +- **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on. +- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is left unquantized to preserve accuracy. + +See `tests/quant.py` for quantization and dequantization details. + +**Sparse Attention (`indices` tensor):** +The `indices` tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens. + +- **Shape:** `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`. +- **Format:** `indices[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the k-th selected token for the j-th query token of the i-th batch element. Since the index of the page block has already been encoded into `indices`, the kernel does not require the `block_table` parameter. +- **Invalid entries:** Set invalid indices to `-1`. 
+ +**Return Values:** +The kernel returns `(out, lse)`, where: +- `out` is the attention result. +- `lse` is the log-sum-exp value of the attention scores for each query head. + +See `tests/test_flash_mla_decoding.py` for a complete example. + +### Sparse MLA Prefill + +For the sparse MLA prefill kernel, call `flash_mla_sparse_fwd` directly with the following parameters: +- `q`: Query tensor of shape `[s_q, h_q, d_qk]` +- `kv`: Key-Value tensor of shape `[s_kv, h_kv, d_qk]` +- `indices`: Indices tensor of shape `[s_q, h_kv, topk]` +- `sm_scale`: The scalar softmax scaling factor applied to the attention logits (see the equivalent code below) + +**Note on batching:** This kernel does not support a batch dimension. For multi-batch inference, reshape the input tensors and adjust the `indices` parameter to simulate batch processing. + +**Invalid indices:** Set invalid entries in `indices` to `-1` or any number `>= s_kv`. + +**Return Values and Equivalent PyTorch Code:** +The kernel returns `(out, max_logits, lse)`. This is equivalent to the following PyTorch operations: + +```python +Q: [s_q, h_q, d_qk], bfloat16 +kv: [s_kv, h_kv, d_qk], bfloat16 +indices: [s_q, h_kv, topk], int32 + +kv = kv.squeeze(1) # [s_kv, d_qk], h_kv must be 1 +indices = indices.squeeze(1) # [s_q, topk] +focused_kv = kv[indices] # For the i-th query token (0 <= i < s_q), the corresponding KV tokens are selected from the KV cache based on indices[i, :]. This operation results in a tensor of shape [s_q, topk, d_qk]. + +P = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e) # [s_q, h_q, topk] +max_logits = P.max(dim=-1) # [s_q, h_q] +lse = log2sumexp2(P, dim=-1, base=2) # [s_q, h_q]; "log2sumexp2" means that the exponentiation and logarithm are base-2 +S = exp2(P - lse) # [s_q, h_q, topk] +out = S @ focused_kv # [s_q, h_q, d_qk] + +return (out, max_logits, lse) +``` + +See `tests/test_flash_mla_prefill.py` for a complete example. + +### Dense MHA Prefill + +This kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. 
It can be called using: +- `flash_attn_varlen_func` +- `flash_attn_varlen_qkvpacked_func` +- `flash_attn_varlen_kvpacked_func` + +The usage is similar to the `flash_attn` package. See `tests/test_fmha_sm100.py` for a complete example. + ## Acknowledgement FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects. @@ -109,7 +224,7 @@ The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.c ```bibtex @misc{flashmla2025, - title={FlashMLA: Efficient MLA decoding kernels}, + title={FlashMLA: Efficient Multi-head Latent Attention Kernels}, author={Jiashi Li, Shengyu Liu}, year={2025}, publisher = {GitHub}, diff --git a/csrc/sm90/kernels/params.h b/csrc/params.h similarity index 59% rename from csrc/sm90/kernels/params.h rename to csrc/params.h index 3b4e254..baa2f7f 100644 --- a/csrc/sm90/kernels/params.h +++ b/csrc/params.h @@ -1,8 +1,8 @@ #pragma once -//////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/bfloat16.h" -struct Flash_fwd_mla_params { +struct DecodingParams { using index_t = int64_t; int b; // batch size @@ -14,11 +14,13 @@ struct Flash_fwd_mla_params { int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k bool is_causal; float scale_softmax, scale_softmax_log2; + int topk; void *__restrict__ q_ptr; void *__restrict__ k_ptr; void *__restrict__ o_ptr; void *__restrict__ softmax_lse_ptr; + int *__restrict__ indices_ptr; index_t q_batch_stride; index_t k_batch_stride; @@ -29,6 +31,8 @@ struct Flash_fwd_mla_params { index_t q_head_stride; index_t k_head_stride; index_t o_head_stride; + index_t indices_batch_stride; + index_t indices_row_stride; int *__restrict__ block_table; index_t block_table_batch_stride; @@ -45,9 +49,9 @@ struct Flash_fwd_mla_params { }; static constexpr int TileSchedulerMetaDataSize = 8; -// [begin_idx, begin_seqlen, end_idx, end_seqlen, 
begin_n_split_idx, _, _, _] +// [begin_idx (inclusive), begin_block_idx (inclusive), end_idx (inclusive), end_block_idx (exclusive), begin_n_split_idx, _, _, _] -struct Mla_metadata_params { +struct GetDecodingMetadataParams { int *__restrict__ seqlens_k_ptr; int *__restrict__ tile_scheduler_metadata_ptr; int *__restrict__ num_splits_ptr; @@ -55,4 +59,26 @@ struct Mla_metadata_params { int block_size_n; int fixed_overhead_num_blocks; int num_sm_parts; + int topk; +}; + +struct SparsePrefillParams { + int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk; + float sm_scale, sm_scale_div_log2; + + // Input tensors + cutlass::bfloat16_t* __restrict__ q; // [s_q, h_q, d_qk] + cutlass::bfloat16_t* __restrict__ kv; // [s_kv, h_kv, d_qk] + int* __restrict__ indices; // [s_q, h_kv, topk] + + int stride_q_s_q; int stride_q_h_q; + int stride_kv_s_kv; int stride_kv_h_kv; + int stride_indices_s_q; int stride_indices_h_kv; + + // Output tensors + cutlass::bfloat16_t* __restrict__ out; // [s_q, h_q, d_v] + float* __restrict__ max_logits; // [s_q, h_q] + float* __restrict__ lse; // [s_q, h_q] + + cudaStream_t stream; }; diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp new file mode 100644 index 0000000..b360c24 --- /dev/null +++ b/csrc/pybind.cpp @@ -0,0 +1,442 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include + +#include "params.h" +#include "smxx/get_mla_metadata.h" +#include "smxx/mla_combine.h" +#include "sm90/decode/dense/splitkv_mla.h" +#include "sm90/decode/sparse_fp8/splitkv_mla.h" +#include "sm90/prefill/sparse/fwd.h" +#include "sm100/prefill/dense/interface.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +struct Arch { + int major; + int minor; + + bool is_sm90() const { + return major == 9 && minor == 0; + } + + bool is_sm100() const { + return major == 10 && minor == 0; + } + + void assert_is_supported() const { + TORCH_CHECK(is_sm90() || is_sm100(), "Only SM90 and SM100 are supported"); + } +}; + +// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. Hopper Dense BF16, Hopper Sparse FP8, etc.) +struct DecodingAttnImplMeta { + int num_sm_parts; + int fixed_overhead_num_blocks; + int k_block_size; +}; + +DecodingAttnImplMeta get_attn_impl_meta( + Arch arch, + int sm_count, + int num_q_tokens_per_head_k, + int h_k, + std::optional h_q_, + bool is_fp8_kvcache, + bool is_sparse_attn +) { + if (arch.is_sm90()) { + if (is_sparse_attn) { + if (is_fp8_kvcache) { + TORCH_CHECK(h_q_.has_value()); + int h_q = h_q_.value(); + TORCH_CHECK(h_q % h_k == 0); + int s_q = num_q_tokens_per_head_k * h_k / h_q; + // FP8 + Sparse MLA + return { + std::max((sm_count/2) / h_k / (cutlass::ceil_div(h_q/h_k, 2*64) * s_q), 1), + 5, + 64 + }; + } else { + // Sparse BF16 MLA + TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM90"); + } + } else { + if (is_fp8_kvcache) { + // Dense FP8 MLA + TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); + } else { + // Dense BF16 MLA + return { + std::max(sm_count / h_k / cutlass::ceil_div(num_q_tokens_per_head_k, 64), 1), + 5, + 64 + }; + } + } + } else if (arch.is_sm100()) { + TORCH_CHECK(false, "Unsupported GPU architecture"); + } else { + TORCH_CHECK(false, "Unsupported GPU architecture"); + } +} + + +std::vector +get_mla_decoding_metadata( + at::Tensor &seqlens_k, + const int num_q_tokens_per_head_k, + const int h_k, + const std::optional h_q, + const bool is_fp8_kvcache, + const std::optional topk +) { 
+ bool is_sparse_attn = topk.has_value(); + CHECK_DEVICE(seqlens_k); + TORCH_CHECK(seqlens_k.is_contiguous()); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); + if (is_sparse_attn) + TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided"); + + int batch_size = seqlens_k.size(0); + int *seqlens_k_ptr = seqlens_k.data_ptr(); + auto options = seqlens_k.options(); + + auto dprops = at::cuda::getCurrentDeviceProperties(); + int sm_count = dprops->multiProcessorCount; + Arch arch = {dprops->major, dprops->minor}; + arch.assert_is_supported(); + DecodingAttnImplMeta attn_impl_meta = get_attn_impl_meta(arch, sm_count, num_q_tokens_per_head_k, h_k, h_q, is_fp8_kvcache, is_sparse_attn); + + auto tile_scheduler_metadata = torch::empty({attn_impl_meta.num_sm_parts, TileSchedulerMetaDataSize}, options); + auto num_splits = torch::empty({batch_size + 1}, options); + int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + int *num_splits_ptr = num_splits.data_ptr(); + + at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + GetDecodingMetadataParams params = {}; + params.seqlens_k_ptr = seqlens_k_ptr; + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; + params.num_splits_ptr = num_splits_ptr; + params.batch_size = batch_size; + params.block_size_n = attn_impl_meta.k_block_size; + params.fixed_overhead_num_blocks = attn_impl_meta.fixed_overhead_num_blocks; + params.num_sm_parts = attn_impl_meta.num_sm_parts; + params.topk = is_sparse_attn ? 
topk.value() : -1; + run_get_mla_metadata_kernel(params, stream); + + return {tile_scheduler_metadata, num_splits}; +} + +std::vector +fwd_kvcache_mla( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True) + const int head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const float softmax_scale, + bool is_causal, + const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize + const at::Tensor &num_splits, // batch_size + 1 + const bool &is_fp8, + const std::optional &indices // None, or batch_size x seqlen_q x topk +) { + bool is_sparse_attn = indices.has_value(); + int topk = is_sparse_attn ? indices->size(-1) : -1; + + // Check the architecture + auto dprops = at::cuda::getCurrentDeviceProperties(); + Arch arch = {dprops->major, dprops->minor}; + arch.assert_is_supported(); + + // Check data types + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); + + if (!is_fp8) { + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + } else { + TORCH_CHECK(kcache.dtype() == torch::kFloat8_e4m3fn || kcache.dtype() == torch::kInt8 || kcache.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn or int8 or uint8"); + } + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); + TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32"); + + // 
Check device + CHECK_DEVICE(q); + CHECK_DEVICE(kcache); + CHECK_DEVICE(seqlens_k); + CHECK_DEVICE(block_table); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_DEVICE(num_splits); + if (is_sparse_attn) CHECK_DEVICE(indices.value()); + + // Check layout + TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension"); + CHECK_CONTIGUOUS(seqlens_k); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + CHECK_CONTIGUOUS(num_splits); + TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension"); + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + const int seqlen_q_ori = sizes[1]; + const int num_heads_q = sizes[2]; + const int head_size_k = sizes[3]; + TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); + TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported"); + + const int max_num_blocks_per_seq = block_table.size(1); + const int num_blocks = kcache.size(0); + const int page_block_size = kcache.size(1); + const int num_heads_k = kcache.size(2); + TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (seqlen_q_ori == 1) { is_causal = false; } + + const int num_q_heads_per_hk = num_heads_q / num_heads_k; + const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; + const int num_heads = num_heads_k; + q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) + .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); + + CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); + if (!is_fp8) { + CHECK_SHAPE(kcache, num_blocks,
page_block_size, num_heads_k, head_size_k); + } else { + int bytes_per_token = 512 + 64*2 + (512/128)*4; + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, bytes_per_token); + TORCH_CHECK(num_heads_k == 1, "Currently the number of k heads must be 1 when is_fp8_kvcache is True"); + TORCH_CHECK(kcache.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True"); + } + CHECK_SHAPE(seqlens_k, batch_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_SHAPE(num_splits, batch_size+1); + if (is_sparse_attn) CHECK_SHAPE(indices.value(), batch_size, seqlen_q_ori, topk); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse); + + DecodingParams params = {}; + // Set the sizes. + params.b = batch_size; + params.s_q = seqlen_q_ori; + params.q_seq_per_hk = q_seq_per_hk; + params.seqlens_k_ptr = seqlens_k.data_ptr(); + params.h_q = num_heads_q; + params.h_k = num_heads_k; + params.num_blocks = num_blocks; + params.q_head_per_hk = num_q_heads_per_hk; + params.is_causal = is_causal; + params.d = head_size_k; + params.d_v = head_size_v; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); + params.topk = topk; + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = kcache.data_ptr(); + params.o_ptr = out.data_ptr(); + params.indices_ptr = is_sparse_attn ? indices->data_ptr() : nullptr; + params.softmax_lse_ptr = softmax_lse.data_ptr(); + // All stride are in elements, not bytes. 
+ params.q_batch_stride = q.stride(0); + params.k_batch_stride = kcache.stride(0); + params.o_batch_stride = out.stride(0); + params.q_row_stride = q.stride(-3); + params.k_row_stride = kcache.stride(1); + params.o_row_stride = out.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = kcache.stride(2); + params.o_head_stride = out.stride(-2); + params.indices_batch_stride = is_sparse_attn ? indices->stride(0) : 0; + params.indices_row_stride = is_sparse_attn ? indices->stride(1) : 0; + + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.page_block_size = page_block_size; + + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + params.num_sm_parts = tile_scheduler_metadata.size(0); + params.num_splits_ptr = num_splits.data_ptr(); + + const int total_num_splits = batch_size + params.num_sm_parts; + at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse_accum); + CHECK_CONTIGUOUS(out_accum); + params.total_num_splits = total_num_splits; + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(head_size_k == 576); + + if (q_dtype == torch::kHalf) { +#ifdef FLASH_MLA_DISABLE_FP16 + TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. 
Please remove this flag from your environment and re-compile FlashMLA."); +#endif + } + + if (arch.is_sm90()) { + if (is_sparse_attn) { + if (is_fp8) { + TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90"); + sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); + } else { + TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); + } + } else { + if (is_fp8) { + TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); + } else { + if (q_dtype == torch::kBFloat16) { + sm90::run_flash_splitkv_mla_kernel(params, stream); + } else if (q_dtype == torch::kHalf) { +#ifndef FLASH_MLA_DISABLE_FP16 + sm90::run_flash_splitkv_mla_kernel(params, stream); +#endif + } else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } + } + } + } else if (arch.is_sm100()) { + TORCH_CHECK(false, "Unsupported GPU architecture"); + } else { + TORCH_CHECK(false, "Unsupported GPU architecture"); + } + + if (q_dtype == torch::kBFloat16) { + run_flash_mla_combine_kernel(params, stream); + } else if (q_dtype == torch::kHalf) { +#ifndef FLASH_MLA_DISABLE_FP16 + run_flash_mla_combine_kernel(params, stream); +#endif + } else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } + + out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) + .reshape({batch_size, num_heads_q, seqlen_q_ori}); + + return {out, softmax_lse}; +} + + +inline int int64_stride_to_int(int64_t orig_stride) { + if (orig_stride > std::numeric_limits::max()) { + TORCH_CHECK(false, "[Sparse TopK Attention] Stride exceeds int32 limit: ", orig_stride); + } + return static_cast(orig_stride); +} + +std::vector sparse_prefill_fwd( + const at::Tensor &q, + const at::Tensor &kv, + const at::Tensor &indices, + float sm_scale, + int d_v +) 
{ + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9; + TORCH_CHECK(is_sm90, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 architectures"); + + CHECK_DEVICE(q); + CHECK_DEVICE(kv); + CHECK_DEVICE(indices); + + TORCH_CHECK(q.dtype() == torch::kBFloat16); + TORCH_CHECK(kv.dtype() == torch::kBFloat16); + TORCH_CHECK(indices.dtype() == torch::kInt32); + + int s_q = q.size(0); + int s_kv = kv.size(0); + int h_q = q.size(1); + int h_kv = kv.size(1); + int d_qk = q.size(2); + int topk = indices.size(2); + + CHECK_SHAPE(q, s_q, h_q, d_qk); + CHECK_SHAPE(kv, s_kv, h_kv, d_qk); + CHECK_SHAPE(indices, s_q, h_kv, topk); + + TORCH_CHECK(q.stride(-1) == 1); + TORCH_CHECK(kv.stride(-1) == 1); + TORCH_CHECK(indices.stride(-1) == 1); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + auto opts = q.options(); + at::Tensor out = torch::empty({s_q, h_q, d_v}, opts); + CHECK_CONTIGUOUS(out); + + at::Tensor buf_attn_score, max_logits, lse, p_sum; + max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); + lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); + CHECK_CONTIGUOUS(max_logits); + CHECK_CONTIGUOUS(lse); + + SparsePrefillParams params = { + s_q, s_kv, h_q, h_kv, d_qk, d_v, topk, + sm_scale, sm_scale * 1.44269504f, + + (cutlass::bfloat16_t*)q.data_ptr(), + (cutlass::bfloat16_t*)kv.data_ptr(), + (int*)indices.data_ptr(), + + int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), + int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)), + int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)), + + (cutlass::bfloat16_t*)out.data_ptr(), + (float*)max_logits.data_ptr(), + (float*)lse.data_ptr(), + + at::cuda::getCurrentCUDAStream().stream() + }; + + if (is_sm90) { + sm90::run_fwd_kernel(params); + } else { + TORCH_CHECK(false, "Unknown architecture"); + } + + return {out, max_logits, lse}; +} + + + 
 +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {  // Single Python module exposing every FlashMLA host entry point
+    m.doc() = "FlashMLA";
+    m.def("get_mla_decoding_metadata", &get_mla_decoding_metadata);  // returns {tile_scheduler_metadata, num_splits}
+    m.def("fwd_kvcache_mla", &fwd_kvcache_mla);  // returns {out, softmax_lse}
+    m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun);  // Python name differs from the C++ symbol
+    m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun);  // Python name differs from the C++ symbol
+    m.def("sparse_prefill_fwd", &sparse_prefill_fwd);  // returns {out, max_logits, lse}
+}
diff --git a/csrc/sm100/collective/fmha_common.hpp b/csrc/sm100/prefill/dense/collective/fmha_common.hpp similarity index 100% rename from csrc/sm100/collective/fmha_common.hpp rename to csrc/sm100/prefill/dense/collective/fmha_common.hpp diff --git a/csrc/sm100/collective/fmha_fusion.hpp b/csrc/sm100/prefill/dense/collective/fmha_fusion.hpp similarity index 100% rename from csrc/sm100/collective/fmha_fusion.hpp rename to csrc/sm100/prefill/dense/collective/fmha_fusion.hpp diff --git a/csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp similarity index 100% rename from csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp diff --git a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index f39fd75..4783a13 100644 --- a/csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -37,9 +37,9 @@ #include "cute/tensor.hpp" #include "cute/layout.hpp" -#include "collective/fmha_common.hpp" -#include "collective/fmha_fusion.hpp" -#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" +#include 
"../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/sm100_fmha_load_tma_warpspecialized.hpp" namespace cutlass::fmha::collective { diff --git a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp index 3606dcc..987ac22 100644 --- a/csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -36,8 +36,8 @@ #include "cute/tensor.hpp" #include "cute/layout.hpp" -#include "collective/fmha_common.hpp" -#include "collective/fmha_fusion.hpp" +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" namespace cutlass::fmha::collective { diff --git a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp index bf41af9..1e66d1a 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -37,10 +37,10 @@ #include "cute/tensor.hpp" #include "cute/layout.hpp" -#include "collective/fmha_common.hpp" -#include "collective/fmha_fusion.hpp" -#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" -#include "common/pipeline_mla.hpp" +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" +#include "../common/pipeline_mla.hpp" 
namespace cutlass::fmha::collective { diff --git a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp index c8fc13b..d161a99 100644 --- a/csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -36,8 +36,8 @@ #include "cute/tensor.hpp" #include "cute/layout.hpp" -#include "collective/fmha_common.hpp" -#include "collective/fmha_fusion.hpp" +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_fusion.hpp" namespace cutlass::fmha::collective { diff --git a/csrc/sm100/common/gather_tensor.hpp b/csrc/sm100/prefill/dense/common/gather_tensor.hpp similarity index 100% rename from csrc/sm100/common/gather_tensor.hpp rename to csrc/sm100/prefill/dense/common/gather_tensor.hpp diff --git a/csrc/sm100/common/helper.h b/csrc/sm100/prefill/dense/common/helper.h similarity index 100% rename from csrc/sm100/common/helper.h rename to csrc/sm100/prefill/dense/common/helper.h diff --git a/csrc/sm100/common/mask.cuh b/csrc/sm100/prefill/dense/common/mask.cuh similarity index 100% rename from csrc/sm100/common/mask.cuh rename to csrc/sm100/prefill/dense/common/mask.cuh diff --git a/csrc/sm100/common/pipeline_mla.hpp b/csrc/sm100/prefill/dense/common/pipeline_mla.hpp similarity index 100% rename from csrc/sm100/common/pipeline_mla.hpp rename to csrc/sm100/prefill/dense/common/pipeline_mla.hpp diff --git a/csrc/sm100/common/pow_2.hpp b/csrc/sm100/prefill/dense/common/pow_2.hpp similarity index 100% rename from csrc/sm100/common/pow_2.hpp rename to csrc/sm100/prefill/dense/common/pow_2.hpp diff --git a/csrc/sm100/common/utils.hpp b/csrc/sm100/prefill/dense/common/utils.hpp similarity index 
100% rename from csrc/sm100/common/utils.hpp rename to csrc/sm100/prefill/dense/common/utils.hpp diff --git a/csrc/sm100/device/fmha.hpp b/csrc/sm100/prefill/dense/device/fmha.hpp similarity index 100% rename from csrc/sm100/device/fmha.hpp rename to csrc/sm100/prefill/dense/device/fmha.hpp diff --git a/csrc/sm100/device/fmha_device_bwd.hpp b/csrc/sm100/prefill/dense/device/fmha_device_bwd.hpp similarity index 100% rename from csrc/sm100/device/fmha_device_bwd.hpp rename to csrc/sm100/prefill/dense/device/fmha_device_bwd.hpp diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cu b/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu similarity index 98% rename from csrc/sm100/fmha_cutlass_bwd_sm100.cu rename to csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu index 4ff745d..54d85db 100644 --- a/csrc/sm100/fmha_cutlass_bwd_sm100.cu +++ b/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu @@ -1,7 +1,7 @@ -#include +#include "interface.h" + #include #include -#include #include #include "common/mask.cuh" #include "common/utils.hpp" diff --git a/csrc/sm100/fmha_cutlass_bwd_sm100.cuh b/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh similarity index 100% rename from csrc/sm100/fmha_cutlass_bwd_sm100.cuh rename to csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cu b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu similarity index 98% rename from csrc/sm100/fmha_cutlass_fwd_sm100.cu rename to csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu index 997886e..ab66f0f 100644 --- a/csrc/sm100/fmha_cutlass_fwd_sm100.cu +++ b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu @@ -1,12 +1,13 @@ -#include "common/mask.cuh" -#include "common/utils.hpp" -#include "fmha_cutlass_fwd_sm100.cuh" +#include "interface.h" -#include #include #include #include -#include + +#include "common/mask.cuh" +#include "common/utils.hpp" + +#include "fmha_cutlass_fwd_sm100.cuh" template void call_run_fmha_fwd([[maybe_unused]] Mask mask, 
[[maybe_unused]] Varlen is_varlen, diff --git a/csrc/sm100/fmha_cutlass_fwd_sm100.cuh b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh similarity index 100% rename from csrc/sm100/fmha_cutlass_fwd_sm100.cuh rename to csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh diff --git a/csrc/sm100/pybind.cu b/csrc/sm100/prefill/dense/interface.h similarity index 84% rename from csrc/sm100/pybind.cu rename to csrc/sm100/prefill/dense/interface.h index 7d4744d..80ef2bc 100644 --- a/csrc/sm100/pybind.cu +++ b/csrc/sm100/prefill/dense/interface.h @@ -1,4 +1,6 @@ -#include +#pragma once + +#include void FMHACutlassSM100FwdRun(at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, @@ -10,8 +12,3 @@ void FMHACutlassSM100BwdRun(at::Tensor workspace_buffer, at::Tensor d_o, at::Ten at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, int mask_mode_code, float softmax_scale, int max_seqlen_q, int max_seqlen_kv, bool is_varlen); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fwd", &FMHACutlassSM100FwdRun); - m.def("bwd", &FMHACutlassSM100BwdRun); -} diff --git a/csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp b/csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp similarity index 100% rename from csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp similarity index 97% rename from csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp index 32e007c..9a25ff3 100644 --- a/csrc/sm100/kernel/fmha_kernel_bwd_convert.hpp +++ b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp @@ -34,6 +34,7 @@ #include "cutlass/cutlass.h" #include "cute/layout.hpp" +#include "utils.h" // 
for IS_SM100 namespace cutlass::fmha::kernel { @@ -138,6 +139,7 @@ struct FmhaKernelBwdConvert { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if IS_SM100 if (params.ptr_src_dQ != nullptr) { copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); } @@ -147,6 +149,11 @@ struct FmhaKernelBwdConvert { if (params.ptr_src_dV != nullptr) { copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape)); } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } }; diff --git a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp similarity index 97% rename from csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp index db6a9b4..07ae4f2 100644 --- a/csrc/sm100/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -34,6 +34,7 @@ #include "cutlass/cutlass.h" #include "cute/layout.hpp" +#include "utils.h" // for IS_SM100 namespace cutlass::fmha::kernel { @@ -104,6 +105,7 @@ struct FmhaKernelBwdSumOdO { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if IS_SM100 auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); @@ -155,6 +157,11 @@ struct FmhaKernelBwdSumOdO { } } } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } }; diff --git 
a/csrc/sm100/kernel/fmha_options.hpp b/csrc/sm100/prefill/dense/kernel/fmha_options.hpp similarity index 100% rename from csrc/sm100/kernel/fmha_options.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_options.hpp diff --git a/csrc/sm100/kernel/fmha_tile_scheduler.hpp b/csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp similarity index 100% rename from csrc/sm100/kernel/fmha_tile_scheduler.hpp rename to csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp similarity index 99% rename from csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 59b410b..057b45e 100644 --- a/csrc/sm100/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -41,7 +41,8 @@ #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "collective/fmha_common.hpp" +#include "utils.h" // for IS_SM100 +#include "../collective/fmha_common.hpp" #include @@ -1499,6 +1500,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if IS_SM100 int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); @@ -1823,6 +1825,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { /* no-op */ } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } static dim3 get_block_shape() { diff --git a/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp similarity index 99% rename from 
csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp index 5a58157..0d4af85 100644 --- a/csrc/sm100/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -41,7 +41,8 @@ #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "collective/fmha_common.hpp" +#include "utils.h" // for IS_SM100 +#include "../collective/fmha_common.hpp" #include @@ -1492,6 +1493,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if IS_SM100 int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); @@ -1816,6 +1818,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { /* no-op */ } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } static dim3 get_block_shape() { diff --git a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp similarity index 98% rename from csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp rename to csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index 43bb035..ef75280 100644 --- a/csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -37,11 +37,12 @@ #include "cutlass/pipeline/pipeline.hpp" #include "cute/arch/tmem_allocator_sm100.hpp" -#include "kernel/fmha_options.hpp" -#include "kernel/fmha_tile_scheduler.hpp" -#include "kernel/fmha_causal_tile_scheduler.hpp" -#include "collective/fmha_fusion.hpp" -#include "collective/fmha_common.hpp" +#include "utils.h" // for IS_SM100 
+#include "../kernel/fmha_options.hpp" +#include "../kernel/fmha_tile_scheduler.hpp" +#include "../kernel/fmha_causal_tile_scheduler.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/fmha_common.hpp" namespace cutlass::fmha::kernel { @@ -251,6 +252,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if IS_SM100 TileScheduler tile_scheduler{params.tile_scheduler}; @@ -629,6 +631,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { /* no-op, donate regs and exit */ } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100\n"); + } +#endif } }; diff --git a/csrc/sm90/kernels/config.h b/csrc/sm90/decode/dense/config.h similarity index 78% rename from csrc/sm90/kernels/config.h rename to csrc/sm90/decode/dense/config.h index c9ce159..e97e0bc 100644 --- a/csrc/sm90/kernels/config.h +++ b/csrc/sm90/decode/dense/config.h @@ -8,6 +8,4 @@ static constexpr int PAGE_BLOCK_SIZE = 64; static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_V = 512; -static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5; - } diff --git a/csrc/sm90/kernels/splitkv_mla.cu b/csrc/sm90/decode/dense/splitkv_mla.cu similarity index 97% rename from csrc/sm90/kernels/splitkv_mla.cu rename to csrc/sm90/decode/dense/splitkv_mla.cu index 5e1fded..cb2e476 100644 --- a/csrc/sm90/kernels/splitkv_mla.cu +++ b/csrc/sm90/decode/dense/splitkv_mla.cu @@ -1,20 +1,22 @@ #include -#include "params.h" #include "utils.h" + +#include "params.h" #include "config.h" #include "traits.h" using namespace cute; using cutlass::arch::NamedBarrier; +namespace sm90 { + // Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking // The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2) // so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM static constexpr float MAX_INIT_VAL_SM = -1e30f; static constexpr float 
MAX_INIT_VAL = -1e33f; - __forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a @@ -756,7 +758,7 @@ __forceinline__ __device__ void wg0_subroutine( TMABarrier barriers_K1[9], bool &cur_phase_K0, const TMAParams &tma_params, - const Flash_fwd_mla_params ¶ms, + const DecodingParams ¶ms, int* block_table_ptr, int seqlen_k, int block_idx, @@ -868,7 +870,7 @@ __forceinline__ __device__ void wg1_subroutine( TMABarrier barriers_K1[9], bool &cur_phase_K1, const TMAParams &tma_params, - const Flash_fwd_mla_params ¶ms, + const DecodingParams ¶ms, int* block_table_ptr, int seqlen_k, int block_idx, @@ -943,7 +945,7 @@ __forceinline__ __device__ void wg1_subroutine( } // A helper function for determining the length of the causal mask for one q token -__forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params ¶ms, int m_block_idx, int local_seq_q_idx) { +__forceinline__ __device__ int get_mask_len(const DecodingParams ¶ms, int m_block_idx, int local_seq_q_idx) { int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx; if (global_seq_q_idx < params.q_seq_per_hk) { int s_q_idx = global_seq_q_idx / params.q_head_per_hk; @@ -956,7 +958,7 @@ __forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params ¶ms, template __global__ void __launch_bounds__(T::NUM_THREADS, 1, 1) -flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params, __grid_constant__ const TmaParams tma_params) { +flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { // grid shape: [ // num_m_blocks 
(=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))), // num_kv_heads, @@ -966,6 +968,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params // If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx]) // For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file). +#if IS_SM90 const int m_block_idx = blockIdx.x; const int k_head_idx = blockIdx.y; const int partition_idx = blockIdx.z; @@ -1018,11 +1021,11 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; // We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race. - int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); + int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); int begin_idx = tile_scheduler_metadata.x; - int begin_seqlen = tile_scheduler_metadata.y; + int sched_begin_block_idx = tile_scheduler_metadata.y; int end_idx = tile_scheduler_metadata.z; - int end_seqlen = tile_scheduler_metadata.w; + int sched_end_block_idx = tile_scheduler_metadata.w; if (begin_idx >= params.b) return; int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); @@ -1034,9 +1037,9 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params constexpr int kBlockN = T::PAGE_BLOCK_SIZE; const int n_split_idx = batch_idx == begin_idx ? 
begin_n_split_idx : 0; int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); - const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0; - int end_block_idx = batch_idx == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); - const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN); + const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0; + int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(seqlen_k, kBlockN); + const bool is_no_split = __ldg(params.num_splits_ptr + batch_idx + 1) - __ldg(params.num_splits_ptr + batch_idx) == 1; int rRightBorderForQSeq[2]; if (params.is_causal) { @@ -1057,7 +1060,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params // Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks // NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1); - end_block_idx = batch_idx == end_idx ? cute::ceil_div(min(end_seqlen, seqlen_k-common_mask_len), kBlockN) : cute::ceil_div(seqlen_k-common_mask_len, kBlockN); + int last_block_in_seq = cute::ceil_div(seqlen_k-common_mask_len, kBlockN); + end_block_idx = batch_idx == end_idx ? 
min(sched_end_block_idx, last_block_in_seq) : last_block_in_seq; CUTLASS_PRAGMA_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { @@ -1267,11 +1271,16 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params if (batch_idx != end_idx) __syncthreads(); } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); + } +#endif } template -void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { +void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream) { using T = Traits; auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); auto tma_Q = cute::make_tma_copy( @@ -1347,8 +1356,10 @@ void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t str CHECK_CUDA_KERNEL_LAUNCH(); } -template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); #ifndef FLASH_MLA_DISABLE_FP16 -template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); #endif + +} diff --git a/csrc/sm90/decode/dense/splitkv_mla.h b/csrc/sm90/decode/dense/splitkv_mla.h new file mode 100644 index 0000000..6d45cfa --- /dev/null +++ b/csrc/sm90/decode/dense/splitkv_mla.h @@ -0,0 +1,10 @@ +#pragma once + +#include "params.h" + +namespace sm90 { + +template +void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); + +} diff --git a/csrc/sm90/kernels/traits.h b/csrc/sm90/decode/dense/traits.h similarity index 100% rename from csrc/sm90/kernels/traits.h rename to csrc/sm90/decode/dense/traits.h diff --git a/csrc/sm90/decode/sparse_fp8/components/config.h b/csrc/sm90/decode/sparse_fp8/components/config.h new file mode 100644 index 0000000..bdba0b8 --- /dev/null +++ 
b/csrc/sm90/decode/sparse_fp8/components/config.h @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include + +using bf16 = cutlass::bfloat16_t; +using fp8 = cutlass::float_e4m3_t; +using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; + +using namespace cute; + +static constexpr int NUM_THREADS = 128*3; +static constexpr int BLOCK_M = 64; +static constexpr int TOPK_BLOCK_SIZE = 64; +static constexpr int PAGE_BLOCK_SIZE = 64; +static constexpr int QUANT_TILE_SIZE = 128; + +static constexpr int HEAD_DIM_K = 576; +static constexpr int HEAD_DIM_V = 512; +static constexpr int HEAD_DIM_NOPE = HEAD_DIM_V; +static constexpr int HEAD_DIM_ROPE = HEAD_DIM_K - HEAD_DIM_V; +static constexpr int NUM_SCALES = HEAD_DIM_NOPE / QUANT_TILE_SIZE; +static constexpr int NUM_BYTES_PER_TOKEN = HEAD_DIM_NOPE + NUM_SCALES*sizeof(float) + HEAD_DIM_ROPE*sizeof(bf16); + +static constexpr int NUM_K_BUFS = 2; + +using SmemLayoutQTile = decltype(tile_to_shape( + GMMA::Layout_SW128_Atom{}, + Shape, Int<64>>{} +)); + +template +using SmemLayoutQTiles = decltype(tile_to_shape( + SmemLayoutQTile{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +)); + +using SmemLayoutQ = SmemLayoutQTiles<9>; + +using SmemLayoutKTile = decltype(tile_to_shape( + GMMA::Layout_INTER_Atom{}, + Shape, _64>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTiles = decltype(tile_to_shape( + SmemLayoutKTile{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTilesTransposed = decltype(composition( + SmemLayoutKTiles{}, + Layout, Int>, Stride, _1>>{} +)); + +using SmemLayoutOBuf = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +)); + +using SmemLayoutOAccumBuf = Layout< + Shape, Int>, + Stride, _1> // We use stride = 520 here to avoid bank conflict +>; + +using SmemLayoutK = SmemLayoutKTiles<9>; +using SmemLayoutV = SmemLayoutKTilesTransposed<8>; +using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>; + +using SmemLayoutS = 
decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +)); + +struct SharedMemoryPlan { + array_aligned> q; + union { + array_aligned> k[NUM_K_BUFS]; + array_aligned> oBuf; + array_aligned> oAccumBuf; + } u; + array_aligned> s; + bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE]; + + float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M]; + transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS]; +}; + +template< + typename Shape_Q, typename TMA_Q, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + Shape_O shape_O; TMA_O tma_O; +}; + +using TiledMMA_QK = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_SS{}, + Layout>{} +)); + +using TiledMMA_QK_rQ = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_LocalP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_SS{}, + Layout>{} +)); diff --git a/csrc/sm90/decode/sparse_fp8/components/dequant.h b/csrc/sm90/decode/sparse_fp8/components/dequant.h new file mode 100644 index 0000000..c3efc05 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/dequant.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include + +struct fp8x8 { + __nv_fp8x4_e4m3 lo; + __nv_fp8x4_e4m3 hi; +}; + +struct fp8x16 { + fp8x8 lo; + fp8x8 hi; +}; + +struct bf16x8 { + __nv_bfloat162 a, b, c, d; +}; + +__device__ __forceinline__ +bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) { + __nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale); + + #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \ + { \ + float4 fp32x4 = (float4)(FP8x4); \ + OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \ + OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \ + } + + bf16x8 result; + DEQUANT_FP8x4(result.a, result.b, 
inputs.lo); + DEQUANT_FP8x4(result.c, result.d, inputs.hi); + + return result; +} + +enum class L1CacheHint { + NO_ALLOCATE, + EVICT_FIRST, + EVICT_NORMAL, + EVICT_LAST +}; + +enum class L2PrefetchHint { + B64, + B128, + B256 +}; + +template< + typename T, + L1CacheHint l1_cache_hint, + L2PrefetchHint l2_prefetch_hint +> +__device__ __forceinline__ +T load_128b_from_gmem(const void* addr) { + static_assert(sizeof(T) == 128/8); + int4 ret; + + #define EXEC(L1_HINT_STR, L2_HINT_STR) { \ + asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v4.s32 {%0, %1, %2, %3}, [%4];" \ + : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) \ + : "l"(addr)); \ + } + + #define DISPATCH_L2(L1_HINT_STR) { \ + if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \ + EXEC(L1_HINT_STR, "64B") \ + else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \ + EXEC(L1_HINT_STR, "128B") \ + else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \ + EXEC(L1_HINT_STR, "256B") \ + } + + if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE) + DISPATCH_L2("no_allocate") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST) + DISPATCH_L2("evict_first") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL) + DISPATCH_L2("evict_normal") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST) + DISPATCH_L2("evict_last") + + #undef EXEC + #undef DISPATCH_L2 + return *reinterpret_cast(&ret); +} diff --git a/csrc/sm90/decode/sparse_fp8/components/epilogue.h b/csrc/sm90/decode/sparse_fp8/components/epilogue.h new file mode 100644 index 0000000..038cbfd --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/epilogue.h @@ -0,0 +1,87 @@ +#pragma once + +#include "named_barriers.h" + +// Store O / OAccum +template< + bool IS_NO_SPLIT, + typename TMAParams, + typename Tensor0, + typename Tensor1, + typename Tensor2, + typename Tensor3 +> +__forceinline__ __device__ void store_o( + Tensor0 &rO, // ((2, 2, 32), 1, 1) + Tensor1 &gOorAccum, // (BLOCK_SIZE_M, 
HEAD_DIM_V) + Tensor2 &sOutputBuf, + Tensor3 &sOutputAccumBuf, + float rL[2], + TMAParams &tma_params, + int batch_idx, + int s_q_idx, + int head_block_idx, + int num_valid_seq_q, + int warpgroup_idx, + int idx_in_warpgroup +) { + using cutlass::arch::NamedBarrier; + if constexpr (IS_NO_SPLIT) { + // Should convert the output to bfloat16 / float16, and save it to O + Tensor rOb = make_tensor_like(rO); + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); ++idx) { + rOb(idx) = (bf16)(rO(idx) / rL[idx%4 >= 2]); + } + + Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); + TiledCopy r2s_tiled_copy = make_tiled_copy_C( + Copy_Atom{}, + TiledMMA_PV_LocalP{} + ); + ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup); + Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb); + Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf); + cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf); + cutlass::arch::fence_view_async_shared(); + + NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready); + + if (threadIdx.x == 0) { + Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx); + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, head_block_idx, _0{}); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sOutputBuf), + thr_tma.partition_D(my_tma_gO) + ); + cute::tma_store_arrive(); + } + } else { + // Should save the result to OAccum + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); idx += 2) { + int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 
8 : 0); + int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; + *(float2*)(&(sOutputAccumBuf(row, col))) = float2 { + rO(idx) / rL[idx%4 >= 2], + rO(idx+1) / rL[idx%4 >= 2], + }; + } + cutlass::arch::fence_view_async_shared(); + + NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready); + + if (elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) { + int row = local_row * (256/32) + (threadIdx.x / 32); + if (row < num_valid_seq_q) { + SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float)); + } + } + cute::tma_store_arrive(); + } + } +} diff --git a/csrc/sm90/decode/sparse_fp8/components/helpers.h b/csrc/sm90/decode/sparse_fp8/components/helpers.h new file mode 100644 index 0000000..8a336ea --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/helpers.h @@ -0,0 +1,86 @@ +#pragma once + +// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. 
This function converts the local_row_idx (0~1) to the actual row_idx +// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a +__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { + int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); + return row_idx; +} + + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +template< + typename TMA, + typename Tensor0, + typename Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + const Tensor0 &src, + Tensor1 &dst, + 
transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL, + const uint16_t &multicast_mask = 0 +) { + auto thr_tma = tma_copy.get_slice(_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), multicast_mask, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +template +CUTE_DEVICE +static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { + long2 data_long2 = *reinterpret_cast(&data); + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + asm volatile ( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" + : + : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) + ); +} + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK (16777216 == 1 << 24). NOTE: not sure whether this value is the same on all GPUs +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} diff --git a/csrc/sm90/decode/sparse_fp8/components/named_barriers.h b/csrc/sm90/decode/sparse_fp8/components/named_barriers.h new file mode 100644 index 0000000..b91cb22 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/components/named_barriers.h @@ -0,0 +1,10 @@ +#pragma once + +enum NamedBarriers : uint32_t { + sScale_and_sS_ready = 0, + sScale_and_sS_free = 1, + oBuf_free_and_sL_ready = 2, + epilogue_r2s_ready = 3, + batch_loop_sync = 4, + warpgroup0_sync = 5 +}; diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu new file mode 100644 index 0000000..3283413 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu @@ -0,0 +1,614 @@ +#include "splitkv_mla.h" + +#include +#include +#include +#include + +#include "utils.h" +#include "components/config.h" +#include "components/epilogue.h" +#include "components/helpers.h" +#include "components/named_barriers.h" +#include "components/dequant.h" +using 
namespace cute; + +namespace sm90 { + +static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::NamedBarrier; + +// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction +template< + typename Tensor0, + typename Tensor1 +> +__forceinline__ __device__ void save_rPb_to_sP( + Tensor0 const &rPb, + Tensor1 const &sP, + int idx_in_warpgroup +) { + auto r2s_copy = make_tiled_copy_C( + Copy_Atom{}, + TiledMMA_QK{} + ); + ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_rPb = thr_copy.retile_S(rPb); + Tensor thr_copy_sP = thr_copy.partition_D(sP); + cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); +} + + +// Retrieve rPb (64x64, bfloat16) from sP using the ldmatrix instruction +template< + typename Tensor0, + typename Tensor1 +> +__forceinline__ __device__ void retrieve_rP_from_sP( + Tensor0 &rPb, + Tensor1 const &sP, + int idx_in_warpgroup +) { + TiledCopy s2r_copy = make_tiled_copy_A( + Copy_Atom{}, + TiledMMA_PV_LocalP{} + ); + ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_sP = thr_copy.partition_S(sP); + Tensor thr_copy_rPb = thr_copy.retile_D(rPb); + cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb); +} + + +template< + typename Tensor0, + typename Tensor1, + typename Tensor2 +> +__forceinline__ __device__ void scale_softmax( + Tensor0 &rP, + Tensor1 &rS, + Tensor2 &rO, + float scale_softmax_log2, + float sScale[], + float rM[2], + float rL[2], + bool is_kv_valid[], + int block_idx, + int idx_in_warpgroup +) { + float scale_for_olds[2]; + CUTE_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + Tensor cur_rP = flatten(rP(make_coord(_, local_row_idx, _), _, _)); + Tensor cur_rS = flatten(rS(make_coord(_, local_row_idx, _), _, _)); + Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); + + float cur_max = -INFINITY; + CUTE_UNROLL + for (int i = 0; i < size(cur_rP); ++i) { + if 
(!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2]) + cur_rP(i) = -INFINITY; + cur_max = max(cur_max, cur_rP(i)); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + + cur_max *= scale_softmax_log2; + float old_max = rM[local_row_idx]; + rM[local_row_idx] = max(cur_max, old_max); + float scale_for_old = exp2f(old_max - rM[local_row_idx]); + scale_for_olds[local_row_idx] = scale_for_old; + + CUTE_UNROLL + for (int i = 0; i < size(cur_rO); ++i) { + cur_rO(i) *= scale_for_old; + } + + float cur_sum = 0; + CUTE_UNROLL + for (int i = 0; i < size(cur_rP); ++i) { + cur_rP(i) = exp2f(cur_rP(i)*scale_softmax_log2 - rM[local_row_idx]); + cur_rS(i) = (bf16)cur_rP(i); + cur_sum += cur_rP(i); + } + rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; + } + if (idx_in_warpgroup%4 == 0) + *(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds); +} + +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 2) +flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { +#if IS_SM90 + const int head_block_idx = blockIdx.x; + const int s_q_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int idx_in_cluster = head_block_idx % 2; + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int idx_in_warpgroup = threadIdx.x % 128; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{}); + Tensor sOBuf = make_tensor(make_smem_ptr(plan.u.oBuf.data()), SmemLayoutOBuf{}); + Tensor sOAccumBuf = make_tensor(make_smem_ptr(plan.u.oAccumBuf.data()), SmemLayoutOAccumBuf{}); + Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); + float* sM = plan.sM; + float* sL = 
plan.sL; + float* sScale = plan.sScale; + + // Prefetch TMA descriptors + if (warp_idx == 0 && elect_one_sync()) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + } + + // Initialize TMA barriers + if (warp_idx == 0 && elect_one_sync()) { + plan.bar_q.init(1); + CUTE_UNROLL + for (int i = 0; i < NUM_K_BUFS; ++i) { + plan.bar_k_local_ready[i].init(128); + plan.bar_k_remote_ready[i].init(1); + plan.bar_k_avail[i].init(4); + } + fence_view_async_shared(); + } + cute::cluster_arrive(); + + bool bar_phase_q = 0; + int bar_phase_k = 0; // Don't use array here to prevent using local memory + + // Programmatic Dependent Launch: Wait for the previous kernel to finish + // Don't use PDL because of compiler bugs! + // cudaGridDependencySynchronize(); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int sched_begin_block_idx = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int sched_end_block_idx = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + if (warp_idx == 0 && elect_one_sync()) { + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, begin_idx), + Tile, Int>{} + )(_, _, head_block_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); + } + + cute::cluster_wait(); // Wait for barriers from the other CTA to be ready + + auto get_cur_req_info = [&](int batch_idx) -> std::tuple { + constexpr int kBlockN = TOPK_BLOCK_SIZE; + const int start_block_idx = batch_idx == begin_idx ? 
sched_begin_block_idx : 0; + // NOTE TopK attention has nothing to do with causal mask and sliding window + int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(params.topk, kBlockN); + const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(params.topk, kBlockN); + return {start_block_idx, end_block_idx, is_no_split}; + }; + + if (warpgroup_idx == 0) { + cutlass::arch::warpgroup_reg_alloc<192>(); + + TiledMMA tiled_mma_QK = TiledMMA_QK{}; + ThrMMA thr_mma_QK = tiled_mma_QK.get_slice(idx_in_warpgroup); + TiledMMA tiled_mma_PV = TiledMMA_PV_LocalP{}; + ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); + + float rL[2], rM[2]; + Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); + Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); + Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + + rL[0] = rL[1] = 0.0f; + rM[0] = rM[1] = MAX_INIT_VAL; + cute::fill(rO, 0.); + + // Wait for Q + plan.bar_q.wait(bar_phase_q); + bar_phase_q ^= 1; + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; + Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{}); + + // Wait, issue WGMMA + plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + + gemm( + tiled_mma_QK, + thr_mma_QK.partition_fragment_A(sQ), + thr_mma_QK.partition_fragment_B(sK), + rP + ); + + bar_phase_k ^= 1<(); + + // Calculate S = softmax(mask(scale(P))) + if (block_idx != start_block_idx) + NamedBarrier::arrive_and_wait(256, 
NamedBarriers::sScale_and_sS_free); // Make sure that sScale and sS are free + + // Since TOPK_BLOCK_SIZE == BLOCK_M in our case, we only need to do OOB checking for the last 2 blocks + scale_softmax(rP, rS, rO, params.scale_softmax_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup); + + // Store S into shared, inform warpgroup 1 + save_rPb_to_sP(rS, sS, idx_in_warpgroup); + fence_view_async_shared(); + + // Issue O += S @ V + gemm( + tiled_mma_PV, + rS, + thr_mma_PV.partition_fragment_B(sV), + rO + ); + + NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_ready); + + cute::warpgroup_wait<0>(); + + plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); + plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); + } + + // Copy the next q + if (warp_idx == 0 && elect_one_sync()) { + if (batch_idx != end_idx) { + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1), + Tile, Int>{} + )(_, _, head_block_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); + } else { + cudaTriggerProgrammaticLaunchCompletion(); + } + } + + // Reduce L across the 4 lanes that share a row, then publish L and M to shared memory (sL/sM) + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + if (idx_in_warpgroup%4 == 0) { + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + int row = get_AorC_row_idx(i, idx_in_warpgroup); + sL[row] = rL[i]; + sM[row] = rM[i]; + } + } + + // This is a synchronization point for warpgroup 0/1. 
+ // Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free + // Warpgroup 1 should wait wg 0 for sL to be ready + NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); + + CUTE_UNROLL + for (int i = 0; i < 2; ++i) + rL[i] = rL[i] == 0.0f ? 1.0f : rL[i]; + + int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M); + int start_seq_idx = s_q_idx*params.q_head_per_hk + head_block_idx*BLOCK_M; + if (is_no_split) { + bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1) + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.o_row_stride, _1{}) + )); + float* gSoftmaxLse = (float*)params.softmax_lse_ptr + batch_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1) + + store_o(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL[i]; + gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / (float)M_LOG2E; + } + + cute::tma_store_wait<0>(); + } else { + int n_split_idx = batch_idx == begin_idx ? 
begin_n_split_idx : 0; + int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) + float* gSoftmaxLseAccum = (float*)params.softmax_lseaccum_ptr + split_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< + Shape, Int>, + Stride, _1> + >{}); + store_o(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL[i]; + gSoftmaxLseAccum[i] = cur_L == 0.0f ? -INFINITY : log2f(cur_L) + sM[i]; + } + + cute::tma_store_wait<0>(); + } + + cute::cluster_sync(); // Must use arrive_and_wait here to prevent overwriting sL while WG1 is writing back its result + } + } else if (warpgroup_idx == 1) { + cutlass::arch::warpgroup_reg_dealloc<160>(); + + TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{}; + ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); + Tensor rO = partition_fragment_C(tiled_mma_PV, Shape, Int>{}); + float rL[2]; + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + cute::fill(rO, 0.); + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; + Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{}); + + // Wait for S and sScale + NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready); + + // Scale O + float cur_scales[2]; + *(float2*)cur_scales = *(float2*)(sScale + (idx_in_warpgroup/4)*2); + CUTE_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + Tensor 
cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); + CUTE_UNROLL + for (int i = 0; i < size(cur_rO); ++i) { + cur_rO(i) *= cur_scales[local_row_idx]; + } + } + + // Issue O += S @ V, and wait + gemm( + tiled_mma_PV, + thr_mma_PV.partition_fragment_A(sS), + thr_mma_PV.partition_fragment_B(sV), + rO + ); + cute::warpgroup_wait<0>(); + + plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); + plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); + + if (block_idx != end_block_idx-1) + NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free); // Tell WG0 that sScale and sS are available + } + + NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + int row = get_AorC_row_idx(i, idx_in_warpgroup); + rL[i] = sL[row]; + } + + CUTE_UNROLL + for (int i = 0; i < 2; ++i) + rL[i] = rL[i] == 0.0f ? 1.0f : rL[i]; + + int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M); + int start_seq_idx = s_q_idx*params.q_head_per_hk+head_block_idx*BLOCK_M; + if (is_no_split) { + bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1) + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.o_row_stride, _1{}) + )); + + store_o(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + cute::tma_store_wait<0>(); + } else { + int n_split_idx = batch_idx == begin_idx ? 
begin_n_split_idx : 0; + int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< + Shape, Int>, + Stride, _1> + >{}); + store_o(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + cute::tma_store_wait<0>(); + } + + cute::cluster_sync(); // We must use arrive_and_wait instead of arrive here to create an order between "every warp in WG1 has finished writing back O" and "warp 2 signals `bar_k_avail`" + } + } else { + // Producer warpgroup + cutlass::arch::warpgroup_reg_dealloc<152>(); + + int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0); // NOTE TPBNO + int lane_idx = idx_in_warpgroup % 32; + int my_token_idx = warp_idx*8 + lane_idx%8; + + CUTE_NO_UNROLL + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1) + + #define GET_TOKEN_INDEX(block_idx) __ldg(gIndices + (block_idx)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx) + int nxt_token_index = GET_TOKEN_INDEX(start_block_idx); + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; + + // Define shared and global tensors + bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE; + bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base); + + transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx])); + int token_index = 
nxt_token_index; + if (block_idx+1 != end_block_idx) + nxt_token_index = GET_TOKEN_INDEX(block_idx+1); + int block_index = token_index/PAGE_BLOCK_SIZE; + int rel_idx_in_block = (token_index+PAGE_BLOCK_SIZE) % PAGE_BLOCK_SIZE; // NOTE When token_index is -1, -1/PAGE_BLOCK_SIZE = 0 and (-1+PAGE_BLOCK_SIZE)%PAGE_BLOCK_SIZE = 63, so there will be no illegal-memory-access error + fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride; + float4 scales = load_128b_from_gmem((float*)(gK_base+HEAD_DIM_NOPE)); + + // Wait for the nope buffer to be available + plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1); + bar_phase_k ^= 1 << buf_idx; + + // Copy block #block_index + if (idx_in_warpgroup == 0) { + plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16)); + } + + // Collectively copy from global memory and dequant + // For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py + + fp8* gK_nope = gK_base + (lane_idx/8)*16; + if (token_index == -1) { + scales = {0.0f, 0.0f, 0.0f, 0.0f}; + } + CUTE_UNROLL + for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) { + fp8x16 cur_fp8x16 = load_128b_from_gmem(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B + float scale = dim_idx < 4 ? (dim_idx < 2 ? scales.x : scales.y) : (dim_idx < 6 ? 
scales.z : scales.w); + auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) { + int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE; + bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, scale); + *(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; + st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); + }; + if (token_index == -1) + *(uint128_t*)(&cur_fp8x16) = uint128_t(); + dequant_and_save_bf16x8(cur_fp8x16.lo, 0); + dequant_and_save_bf16x8(cur_fp8x16.hi, 8); + } + + bf16* gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8; + bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE; + bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base); + + CUTE_UNROLL + for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) { + bf16x8 cur_bf16x8 = load_128b_from_gmem(gK_rope + dim_idx*32); + if (token_index == -1) + *(uint128_t*)(&cur_bf16x8) = uint128_t(); + int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE; + *(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; + st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); + } + + fence_view_async_shared(); + + if (idx_in_warpgroup < 32) { + // We put this after fence_view_async_shared() since this won't be read by async proxy + int2 indices = __ldg((int2*)(gIndices + block_idx*TOPK_BLOCK_SIZE + lane_idx*2)); + *(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = {indices.x != -1, indices.y != -1}; + } + + // Signal the barrier + plan.bar_k_local_ready[buf_idx].arrive(); + } + + cute::cluster_sync(); + } + } + + if (begin_idx > end_idx) { + cute::cluster_sync(); // Don't need a cluster_sync() when begin_idx <= end_idx, since the loop will execute at least once and the final statement is cluster_sync() + } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only 
supports sm90"); + } +#endif + +} + + +void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream) { // Host-side launcher: builds the TMA descriptors for Q and O, then launches the sparse FP8 split-KV decoding kernel on clusters of 2 CTAs + FLASH_ASSERT(params.h_k == 1); // Only a single KV head is supported + FLASH_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0); // topk must be a whole number of TOPK_BLOCK_SIZE blocks + + auto shape_Q = make_shape(params.q_head_per_hk, params.d, params.s_q, params.b); // (head, dim, query position, batch) + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q_ptr), + make_layout( + shape_Q, + make_stride(params.q_row_stride, _1{}, params.q_head_per_hk*params.q_row_stride, params.q_batch_stride) + ) + ), + SmemLayoutQ{} + ); + + auto shape_O = make_shape(params.q_head_per_hk, params.d_v, params.s_q, params.b); // Same ordering as shape_Q, with d_v as the feature dimension + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((bf16*)params.o_ptr), + make_layout( + shape_O, + make_stride(params.o_row_stride, _1{}, params.q_head_per_hk*params.o_row_stride, params.o_batch_stride) + ) + ), + SmemLayoutOBuf{} + ); + + TmaParams< + decltype(shape_Q), decltype(tma_Q), + decltype(shape_O), decltype(tma_O) + > tma_params = { + shape_Q, tma_Q, + shape_O, tma_O + }; + auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Opt in to the dynamic shared memory size required by SharedMemoryPlan + + const int num_m_block = cute::ceil_div(params.q_head_per_hk, 2*BLOCK_M) * 2; // ceil_div by 2*BLOCK_M, then *2: always an even block count, matching the cluster width of 2 used below + // NOTE Don't use PDL because of potential compiler bugs!
+ // cudaLaunchAttribute mla_kernel_attributes[1]; + // mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + // mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; + // cudaLaunchConfig_t mla_kernel_config = { + // dim3(num_m_block, params.h_k, params.num_sm_parts), + // dim3(NUM_THREADS, 1, 1), + // smem_size, + // stream, + // mla_kernel_attributes, + // 1 + // }; + // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); + cutlass::ClusterLaunchParams launch_params = { + dim3(num_m_block, params.s_q, params.num_sm_parts), + dim3(NUM_THREADS, 1, 1), + dim3(2, 1, 1), // Cluster dimensions: two CTAs along x + smem_size, + stream + }; + cutlass::launch_kernel_on_cluster( + launch_params, (void*)mla_kernel, params, tma_params + ); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +} diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.h b/csrc/sm90/decode/sparse_fp8/splitkv_mla.h new file mode 100644 index 0000000..daa21a3 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/splitkv_mla.h @@ -0,0 +1,9 @@ +#pragma once + +#include "params.h" + +namespace sm90 { + +void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream); + +} diff --git a/csrc/sm90/flash_api.cpp b/csrc/sm90/flash_api.cpp deleted file mode 100644 index a87e1ab..0000000 --- a/csrc/sm90/flash_api.cpp +++ /dev/null @@ -1,216 +0,0 @@ -// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#include -#include -#include -#include - -#include - -#include "kernels/config.h" -#include "kernels/get_mla_metadata.h" -#include "kernels/mla_combine.h" -#include "kernels/params.h" -#include "kernels/splitkv_mla.h" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...)
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -std::vector -get_mla_metadata( - at::Tensor &seqlens_k, - const int num_heads_per_head_k, - const int num_heads_k -) { - CHECK_DEVICE(seqlens_k); - TORCH_CHECK(seqlens_k.is_contiguous()); - TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); - - int batch_size = seqlens_k.size(0); - int *seqlens_k_ptr = seqlens_k.data_ptr(); - auto options = seqlens_k.options(); - - auto dprops = at::cuda::getCurrentDeviceProperties(); - int sm_count = dprops->multiProcessorCount; - int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M); - - auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); - auto num_splits = torch::empty({batch_size + 1}, options); - int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); - int *num_splits_ptr = num_splits.data_ptr(); - - at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - Mla_metadata_params params = {}; - params.seqlens_k_ptr = seqlens_k_ptr; - params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; - params.num_splits_ptr = num_splits_ptr; - params.batch_size = batch_size; - params.block_size_n = Config::PAGE_BLOCK_SIZE; - params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS; - params.num_sm_parts = num_sm_parts; - run_get_mla_metadata_kernel(params, stream); - - return {tile_scheduler_metadata, num_splits}; -} - -std::vector -mha_fwd_kvcache_mla( - at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size - const int head_size_v, - const at::Tensor &seqlens_k, // batch_size - const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq - const float 
softmax_scale, - bool is_causal, - const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize - const at::Tensor &num_splits // batch_size + 1 -) { - // Check the architecture - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - TORCH_CHECK(is_sm90); - - // Check data types - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); - TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); - TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); - TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); - - // Check device - CHECK_DEVICE(q); - CHECK_DEVICE(kcache); - CHECK_DEVICE(seqlens_k); - CHECK_DEVICE(block_table); - CHECK_DEVICE(tile_scheduler_metadata); - CHECK_DEVICE(num_splits); - - // Check layout - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - CHECK_CONTIGUOUS(seqlens_k); - TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); - CHECK_CONTIGUOUS(tile_scheduler_metadata); - CHECK_CONTIGUOUS(num_splits); - - const auto sizes = q.sizes(); - const int batch_size = sizes[0]; - const int seqlen_q_ori = sizes[1]; - const int num_heads_q = sizes[2]; - const int head_size_k = sizes[3]; - TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); - TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported"); - - const int max_num_blocks_per_seq = block_table.size(1); - const int num_blocks = kcache.size(0); - const int page_block_size = kcache.size(1); - 
const int num_heads_k = kcache.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be postive"); - TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - if (seqlen_q_ori == 1) { is_causal = false; } - - const int num_q_heads_per_hk = num_heads_q / num_heads_k; - const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; - const int num_heads = num_heads_k; - q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) - .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); - - CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); - CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); - CHECK_SHAPE(seqlens_k, batch_size); - CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); - TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); - CHECK_SHAPE(num_splits, batch_size+1); - - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - auto opts = q.options(); - at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); - at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); - CHECK_CONTIGUOUS(softmax_lse); - - Flash_fwd_mla_params params = {}; - // Set the sizes. - params.b = batch_size; - params.s_q = seqlen_q_ori; - params.q_seq_per_hk = q_seq_per_hk; - params.seqlens_k_ptr = seqlens_k.data_ptr(); - params.h_q = num_heads_q; - params.h_k = num_heads_k; - params.num_blocks = num_blocks; - params.q_head_per_hk = num_q_heads_per_hk; - params.is_causal = is_causal; - params.d = head_size_k; - params.d_v = head_size_v; - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); - // Set the pointers and strides. 
- params.q_ptr = q.data_ptr(); - params.k_ptr = kcache.data_ptr(); - params.o_ptr = out.data_ptr(); - params.softmax_lse_ptr = softmax_lse.data_ptr(); - // All stride are in elements, not bytes. - params.q_batch_stride = q.stride(0); - params.k_batch_stride = kcache.stride(0); - params.o_batch_stride = out.stride(0); - params.q_row_stride = q.stride(-3); - params.k_row_stride = kcache.stride(-3); - params.o_row_stride = out.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = kcache.stride(-2); - params.o_head_stride = out.stride(-2); - - params.block_table = block_table.data_ptr(); - params.block_table_batch_stride = block_table.stride(0); - params.page_block_size = page_block_size; - - params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); - params.num_sm_parts = tile_scheduler_metadata.size(0); - params.num_splits_ptr = num_splits.data_ptr(); - - const int total_num_splits = batch_size + params.num_sm_parts; - at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); - at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); - CHECK_CONTIGUOUS(softmax_lse_accum); - CHECK_CONTIGUOUS(out_accum); - params.total_num_splits = total_num_splits; - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_ptr = out_accum.data_ptr(); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(head_size_k == 576); - if (q_dtype == torch::kBFloat16) { - run_flash_splitkv_mla_kernel(params, stream); - run_flash_mla_combine_kernel(params, stream); - } else if (q_dtype == torch::kHalf) { -#ifdef FLASH_MLA_DISABLE_FP16 - TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. 
Please remove this flag from your environment and re-compile FlashMLA."); -#else - run_flash_splitkv_mla_kernel(params, stream); - run_flash_mla_combine_kernel(params, stream); -#endif - } else { - TORCH_CHECK(false, "Unsupported tensor dtype for query"); - } - - out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) - .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); - softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) - .reshape({batch_size, num_heads_q, seqlen_q_ori}); - - return {out, softmax_lse}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashMLA"; - m.def("get_mla_metadata", &get_mla_metadata); - m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); -} diff --git a/csrc/sm90/kernels/get_mla_metadata.h b/csrc/sm90/kernels/get_mla_metadata.h deleted file mode 100644 index 5130581..0000000 --- a/csrc/sm90/kernels/get_mla_metadata.h +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "params.h" - -void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream); diff --git a/csrc/sm90/kernels/mla_combine.h b/csrc/sm90/kernels/mla_combine.h deleted file mode 100644 index 69035e9..0000000 --- a/csrc/sm90/kernels/mla_combine.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#include "params.h" - -template -void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); diff --git a/csrc/sm90/kernels/splitkv_mla.h b/csrc/sm90/kernels/splitkv_mla.h deleted file mode 100644 index 479fb50..0000000 --- a/csrc/sm90/kernels/splitkv_mla.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#include "params.h" - -template -void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); diff --git a/csrc/sm90/prefill/sparse/fwd.cu b/csrc/sm90/prefill/sparse/fwd.cu new file mode 100644 index 0000000..084e0e2 --- /dev/null +++ b/csrc/sm90/prefill/sparse/fwd.cu @@ -0,0 +1,709 @@ +#include "fwd.h" +
+#include +#include +#include +#include +#include +#include + +#include "utils.h" +#include "helpers.h" + +namespace sm90 { + +using namespace cute; + +constexpr int D_Q = 576; +constexpr int D_K = 576; +constexpr int D_V = 512; + +constexpr int B_H = 64; +constexpr int B_TOPK = 64; // TopK block size +constexpr int NUM_THREADS = 128*3; +static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) + +template +using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTilesTransposed = decltype(composition( + SmemLayoutKTiles{}, + Layout, Int>, Stride, _1>>{} +)); + +using SmemLayoutQ = SmemLayoutQTiles<9>; +using SmemLayoutO = SmemLayoutOTiles<8>; +using SmemLayoutK = SmemLayoutKTiles<9>; +using SmemLayoutV = SmemLayoutKTilesTransposed<8>; +using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>; + +using SmemLayoutS = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +), Shape<_1, _1>{})); + +struct SharedMemoryPlan { + union { + array_aligned> q; + array_aligned> o; + } q_o; + array_aligned> k[2]; + array_aligned> s; + + bool is_kv_valid[2][B_TOPK]; + float2 sM[32]; + float2 sL[64]; // For reduction across WG0/1 in epilogue + float final_max_logits[64], final_lse[64]; + transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready; +}; + +using TiledMMA_QK = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_SS{}, + Layout>{} +)); + +using TiledMMA_PV_LocalP = 
decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_SS{}, + Layout>{} +)); + +template< + typename Shape_Q, typename TMA_Q +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + CUtensorMap tensor_map_O; +}; + +enum NamedBarriers : uint32_t { + wg0_bunch_0_ready = 0, + wg1_bunch_0_ready = 1, + wg0_s0_ready = 2, + wg1_s1_ready = 3, + sL_ready = 4, + warpgroup0_sync = 5, + warpgroup1_sync = 6 +}; + +// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction +template< + typename Tensor0, + typename Tensor1 +> +__forceinline__ __device__ void save_rS_to_sS( + Tensor0 const &rPb, + Tensor1 const &sP, + int idx_in_warpgroup +) { + auto r2s_copy = make_tiled_copy_C( + Copy_Atom{}, + TiledMMA_QK{} + ); + ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_rPb = thr_copy.retile_S(rPb); + Tensor thr_copy_sP = thr_copy.partition_D(sP); + cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); +} + + +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 1) +sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) { + // NOTE This kernel uses a similar schedule to Flash MLA - 0422. 
For a detailed explanation, please refer to https://github.com/deepseek-ai/FlashMLA/blob/main/docs/20250422-new-kernel-deep-dive.md +#if IS_SM90 + const int q_h_idx = blockIdx.x % (params.h_q/B_H); + const int s_q_idx = blockIdx.x / (params.h_q/B_H); + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int warp_idx = cutlass::canonical_warp_idx_sync(); + const int idx_in_warpgroup = threadIdx.x % 128; + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{}); + Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{}); + Tensor sS0 = make_tensor(make_smem_ptr(plan.k[0].data()+64*512), SmemLayoutS{}); // Overlap with sK0's RoPE part + Tensor sS1 = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); + + if (warp_idx == 0 && elect_one_sync()) { + // Prefetch TMA descriptors + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_O); + + // Initialize barriers + plan.bar_q.init(1); + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + plan.bar_k0_free[i].init(128); + plan.bar_k0_ready[i].init(128); + plan.bar_k1_free[i].init(128); + plan.bar_k1_ready[i].init(128); + } + plan.bar_is_kv_valid_ready.init(16); + fence_barrier_init(); + } + + __syncthreads(); + + const int num_topk_blocks = params.topk / B_TOPK; + if (warpgroup_idx == 0 || warpgroup_idx == 1) { + cutlass::arch::warpgroup_reg_alloc<216>(); + + if (warp_idx == 0 && elect_one_sync()) { + // Load Q + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx), + Tile, Int>{} + )(_, _, q_h_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16)); + } + + float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / 
rL calculation + float rL[2] = {0.0f, 0.0f}; + Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); + Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); + Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); + cute::fill(rO, 0.0f); + + // Wait for Q + plan.bar_q.wait(0); + + bool cur_bar_wait_phase = 0; + + struct Warpgroup0 {}; + struct Warpgroup1 {}; + + auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) { + constexpr bool IS_WG1 = std::is_same_v; + TiledMMA tiled_mma_QK = TiledMMA_QK{}; + Tensor sQ_tile = flat_divide(sQ, Tile, Int<64>>{})(_, _, _0{}, tile_idx); + Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{}); + gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup); + }; + + auto mask_rP = [&](auto warpgroup_idx) { + constexpr bool IS_WG1 = std::is_same_v; + plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); + CUTE_UNROLL + for (int row_idx = 0; row_idx < 2; ++row_idx) { + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + int col = 8*(i/4) + (idx_in_warpgroup%4)*2; + if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY; + if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY; + } + } + }; + + auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) { + plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); + constexpr bool IS_WG1 = std::is_same_v; + const float scale = params.sm_scale_div_log2; + float r_sM[2]; + if constexpr (IS_WG1) { + *(float2*)r_sM = plan.sM[idx_in_warpgroup/4]; + } + float new_maxs[2]; + CUTE_UNROLL + for (int row_idx = 0; row_idx < 2; ++row_idx) { + // Get rowwise max + float cur_max = -INFINITY; + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + cur_max = max(cur_max, max(rP(i), rP(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); 
+ cur_max *= scale; + + // Get new max and scale + // For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round) + new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max); + + // Scale O + float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]); + CUTE_UNROLL + for (int i = row_idx*2; i < size(rO); i += 4) { + rO(i) *= scale_for_o; + rO(i+1) *= scale_for_o; + } + + // Get rS + float cur_sum = 0; + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]); + rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]); + rS(i) = (bf16)rP(i); + rS(i+1) = (bf16)rP(i+1); + cur_sum += rP(i) + rP(i+1); + } + rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum; + } + __syncwarp(); + if (idx_in_warpgroup%4 == 0) { + plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs; + } + rM[0] = new_maxs[0]; + rM[1] = new_maxs[1]; + }; + + auto reduce_L = [&]() { + // Reduce L + // For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131 + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + if (idx_in_warpgroup%4 == 0) + plan.sL[threadIdx.x/4] = *(float2*)(rL); + NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready); + float2 peer_L = plan.sL[(threadIdx.x/4)^32]; + rL[0] += peer_L.x; + rL[1] += peer_L.y; + }; + + auto store_O = [&]() { + float scale_factors[2]; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) + scale_factors[i] = rL[i] == 0.0f ? 
1.0f : 1.0f / rL[i]; + + Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{}); + bf16* stsm_addrs[4]; + int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16); + CUTE_UNROLL + for (int i = 0; i < 64/16; ++i) { + stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i); + } + bool s2g_pred = warp_idx%4 == 0 && elect_one_sync(); + + warpgroup_wait<0>(); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) { + // Convert + constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size + bf16 cur_rOb[NUM_ELEMS_EACH_TILE]; + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) { + cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]); + } + // R -> S + CUTE_UNROLL + for (int i = 0; i < 64/16; ++i) { + SM90_U32x4_STSM_N::copy( + *reinterpret_cast(cur_rOb + i*8 + 0), + *reinterpret_cast(cur_rOb + i*8 + 2), + *reinterpret_cast(cur_rOb + i*8 + 4), + *reinterpret_cast(cur_rOb + i*8 + 6), + *reinterpret_cast(stsm_addrs[i] + tile_idx*(B_H*64)) + ); + } + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, warpgroup_idx ? 
NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync); + // S -> G + if (s2g_pred) { + int g_tile_idx = warpgroup_idx*4 + tile_idx; + SM90_TMA_STORE_3D::copy( + &tma_params.tensor_map_O, + plan.q_o.o.data() + g_tile_idx*(B_H*64), + g_tile_idx*64, + q_h_idx*B_H, + s_q_idx + ); + } + } + cute::tma_store_arrive(); + }; + + + if (warpgroup_idx == 0) { + // Warpgroup 0 + + auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) { + plan.bar_k0_ready[0].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup0{}, 0, true); + qkt_gemm_one_tile(Warpgroup0{}, 1, false); + qkt_gemm_one_tile(Warpgroup0{}, 2, false); + qkt_gemm_one_tile(Warpgroup0{}, 3, false); + warpgroup_commit_batch(); + }; + + auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) { + plan.bar_k0_ready[1].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup0{}, 4, false); + qkt_gemm_one_tile(Warpgroup0{}, 5, false); + qkt_gemm_one_tile(Warpgroup0{}, 6, false); + qkt_gemm_one_tile(Warpgroup0{}, 7, false); + qkt_gemm_one_tile(Warpgroup0{}, 8, false); + warpgroup_commit_batch(); + }; + + auto scale_rS = [&](float scales[2]) { + CUTE_UNROLL + for (int row = 0; row < 2; ++row) { + CUTE_UNROLL + for (int i = row*2; i < size(rP); i += 4) { + rS(i) = (bf16)(rP(i) * scales[row]); + rS(i+1) = (bf16)(rP(i+1) * scales[row]); + } + } + }; + + auto rescale_rO = [&](float scales[2]) { + CUTE_UNROLL + for (int row = 0; row < 2; ++row) { + CUTE_UNROLL + for (int i = row*2; i < size(rO); i += 4) { + rO(i) *= scales[row]; + rO(i+1) *= scales[row]; + } + rL[row] *= scales[row]; + } + }; + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{}); + Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{}); + + if (block_idx == 0) { + // NOTE We put these code here to avoid register spilling + 
pipelined_wait_and_qkt_gemm_l(); + pipelined_wait_and_qkt_gemm_r(); + warpgroup_wait<0>(); + } + + // Online softmax, inform WG1 + mask_rP(Warpgroup0{}); + + online_softmax_and_rescale_o(Warpgroup0{}); + NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready); + + // Issue rO0 += rS0 @ sV0l + gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Mark V0L as free + warpgroup_wait<0>(); + plan.bar_k0_free[0].arrive(); + + // Wait for new sM, scale rS, save, inform WG1 + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready); + float new_rM[2], scale_factors[2]; + *(float2*)new_rM = plan.sM[idx_in_warpgroup/4]; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + scale_factors[i] = exp2f(rM[i] - new_rM[i]); + rM[i] = new_rM[i]; + } + scale_rS(scale_factors); + save_rS_to_sS(rS, sS0, idx_in_warpgroup); + fence_view_async_shared(); + NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready); + + // Wait for sS1 + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready); + + // Rescale rO0, Issue rO0 += sS1 @ sV1L + rescale_rO(scale_factors); + gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + cur_bar_wait_phase ^= 1; + + if (block_idx+2 < num_topk_blocks) { + // Launch the next QK^T GEMM + pipelined_wait_and_qkt_gemm_l(); + + // Mark V1L as free + warpgroup_wait<1>(); + plan.bar_k1_free[0].arrive(); + pipelined_wait_and_qkt_gemm_r(); + + // Wait for rP0 = sQ @ sK0 + warpgroup_wait<0>(); + } else { + // Mark V1L as free + warpgroup_wait<0>(); + plan.bar_k1_free[0].arrive(); + } + } + + reduce_L(); + store_O(); + } else { + // Warpgroup 1 + + auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) { + plan.bar_k1_ready[1].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup1{}, 4, true); + qkt_gemm_one_tile(Warpgroup1{}, 5, false); + qkt_gemm_one_tile(Warpgroup1{}, 6, false); + qkt_gemm_one_tile(Warpgroup1{}, 7, false); + 
qkt_gemm_one_tile(Warpgroup1{}, 8, false); + plan.bar_k1_ready[0].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup1{}, 0, false); + qkt_gemm_one_tile(Warpgroup1{}, 1, false); + qkt_gemm_one_tile(Warpgroup1{}, 2, false); + qkt_gemm_one_tile(Warpgroup1{}, 3, false); + warpgroup_commit_batch(); + }; + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{}); + Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{}); + + // Issue rP1 = sQ @ sK1, and wait + pipelined_wait_and_qkt_gemm(); + warpgroup_wait<0>(); + + mask_rP(Warpgroup1{}); + + // Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready) + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready); + online_softmax_and_rescale_o(Warpgroup1{}); + NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready); + + + // Issue rO1 += rS1 @ sV1R + gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R + save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready); + gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Save rS1, inform WG0 + fence_view_async_shared(); + NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready); + + // Wait for GEMM, and inform that sV1R is free + warpgroup_wait<1>(); + plan.bar_k1_free[1].arrive(); + + // Wait for GEMM, and inform that sV0R is free + warpgroup_wait<0>(); + plan.bar_k0_free[1].arrive(); + + cur_bar_wait_phase ^= 1; + } + + reduce_L(); + store_O(); + + // Save lse + if (idx_in_warpgroup%4 == 0) { + for (int row = 0; row < 2; ++row) { + int real_row = get_AorC_row_idx(row, idx_in_warpgroup); + bool is_no_valid_tokens = rL[row] == 0.0f; + 
plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row]; + plan.final_lse[real_row] = is_no_valid_tokens ? -INFINITY : log2f(rL[row]) + rM[row]; + } + fence_view_async_shared(); + } + + NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync); + if (idx_in_warpgroup == 0) { + int g_offset = s_q_idx*params.h_q + q_h_idx*B_H; + SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float)); + SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float)); + cute::tma_store_arrive(); + } + } + } else { + // Producer warpgroup + cutlass::arch::warpgroup_reg_dealloc<72>(); + + constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE; + constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS; + int idx_in_group = idx_in_warpgroup % GROUP_SIZE; + int group_idx = idx_in_warpgroup / GROUP_SIZE; + int* gIndices = params.indices + s_q_idx*params.topk; // [topk] + + bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8)); + bf16* my_gKV_base = params.kv + idx_in_group*8; + + int64_t token_indices[2][NUM_ROWS_PER_GROUP]; + bool is_token_valid[2][NUM_ROWS_PER_GROUP]; + auto load_token_indices = [&](int block_idx) { + CUTE_UNROLL + for (int buf_idx = 0; buf_idx < 2; ++buf_idx) { + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { + int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx; + int t = __ldg(gIndices + offs); + token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster + is_token_valid[buf_idx][local_row] = t >= 0 && t < params.s_kv; + } + } + }; + + int64_t cache_policy = createpolicy_evict_last(); + auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) { + // Copy some K/V tiles from global memory to shared memory + // A tile has a shape of 64 (B_TOPK) x 64 + // `buf_idx` 
is the index of the shared memory buffer, 0 or 1 + // `tile_start` and `tile_end` give the half-open range [tile_start, tile_end) of tile indices to load; valid tile indices run from 0 to D_K/64-1 = 8. NOTE(review): `block_idx` is not read inside this lambda - the rows come from `token_indices[buf_idx]`, filled by load_token_indices() + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { + int64_t token_index = token_indices[buf_idx][local_row]; + CUTE_UNROLL + for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) { + cp_async_cacheglobal_l2_prefetch_256B( + my_gKV_base + token_index + tile_idx*64, + my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64), + is_token_valid[buf_idx][local_row], + cache_policy + ); + } + } + }; + + auto commit_to_mbar = [&](transac_bar_t &bar) { + cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar)); + }; + + int cur_bar_wait_phase = 1; // NOTE(review): starting at phase 1 presumably lets the first wait on each *_free barrier pass before any consumer has arrived - confirm against mbarrier parity semantics + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + load_token_indices(block_idx); + + // V0L + plan.bar_k0_free[0].wait(cur_bar_wait_phase); + copy_tiles(block_idx+0, 0, 0, 4); + commit_to_mbar(plan.bar_k0_ready[0]); + + // V1R + plan.bar_k1_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 4, 9); + commit_to_mbar(plan.bar_k1_ready[1]); + + // V0R + plan.bar_k0_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+0, 0, 4, 9); + commit_to_mbar(plan.bar_k0_ready[1]); + + // V1L + plan.bar_k1_free[0].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 0, 4); + commit_to_mbar(plan.bar_k1_ready[0]); + + // Valid mask + // NOTE V1R's finish implies maskings of the last round have finished + if (idx_in_group == 0) { + CUTE_UNROLL + for (int buf_idx = 0; buf_idx < 2; ++buf_idx) + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) + plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row]; + plan.bar_is_kv_valid_ready.arrive(); + } + + cur_bar_wait_phase ^= 1; + } + } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); + } +#endif +} + + +void run_fwd_kernel(const
SparsePrefillParams& params) { + FLASH_ASSERT(params.h_kv == 1); + FLASH_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundary checks + FLASH_ASSERT(params.topk > 0); + FLASH_ASSERT(params.h_q % B_H == 0); + + auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q), + make_layout( + shape_Q, + make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) + ) + ), + SmemLayoutQ{} + ); + + CUtensorMap tensor_map_O; + { + uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q}; + uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)}; + uint32_t box_size[3] = {64, B_H, 1}; + uint32_t elem_stride[3] = {1, 1, 1}; + CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tensor_map_O, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + 3, + params.out, + size, + stride, + box_size, + elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + FLASH_ASSERT(res == CUresult::CUDA_SUCCESS); + } + + TmaParams< + decltype(shape_Q), decltype(tma_Q) + > tma_params = { + shape_Q, tma_Q, + tensor_map_O + }; + auto kernel = &sparse_attn_fwd_kernel; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + cutlass::ClusterLaunchParams launch_params = { + dim3((params.h_q/B_H)*params.s_q, 1, 1), // NOTE We put s_q on the first dim since it can be larger than 65535 (the maximum size of gridDim.y and gridDim.z) + dim3(NUM_THREADS, 1, 1), + dim3(1, 1, 1), + smem_size, + params.stream + }; + cutlass::launch_kernel_on_cluster( + launch_params, (void*)kernel, params, tma_params + ); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +} 
diff --git a/csrc/sm90/prefill/sparse/fwd.h b/csrc/sm90/prefill/sparse/fwd.h new file mode 100644 index 0000000..60cb624 --- /dev/null +++ b/csrc/sm90/prefill/sparse/fwd.h @@ -0,0 +1,9 @@ +#pragma once + +#include "params.h" + +namespace sm90 { + +void run_fwd_kernel(const SparsePrefillParams& params); + +} diff --git a/csrc/sm90/prefill/sparse/helpers.h b/csrc/sm90/prefill/sparse/helpers.h new file mode 100644 index 0000000..fd68c36 --- /dev/null +++ b/csrc/sm90/prefill/sparse/helpers.h @@ -0,0 +1,177 @@ +#pragma once + +#include +#include +#include + +namespace sm90 { + +using bf16 = cutlass::bfloat16_t; +using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::fence_barrier_init; +using cutlass::arch::NamedBarrier; + +__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" + :: "r"(dst_addr), + "l"(src), + "n"(16)); +} + +__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + asm volatile("cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\n" + :: "r"(dst_addr), + "l"(src), + "r"(pred?16:0), + "l"(cache_policy)); +} + +__forceinline__ __device__ int64_t createpolicy_evict_last() { + int64_t res; + asm volatile( + "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t" + : "=l"(res) + : + ); + return res; +} + +__forceinline__ __device__ int64_t createpolicy_evict_first() { + int64_t res; + asm volatile( + "createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \n\t" + : "=l"(res) + : + ); + return res; +} + + +__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { + // In the layout of fragment A and fragment C during WGMMA, 
data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx + // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a + int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); + return row_idx; +} + +__forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) { + int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1); + return col_idx; +} + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h +// * Copyright (c) 2024, Tri Dao. +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + using namespace cute; + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { 
warpgroup_fence_operand(const_cast(tCrA)); } +} + +// A simpler version of gemm +template +__forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { + using namespace cute; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor sA_frag = thr_mma.partition_fragment_A(sA); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(sA_frag) == size<2>(sB_frag)); + + warpgroup_fence_operand(rC_frag); + warpgroup_arrive(); + tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(sA_frag); ++k) { + cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_fence_operand(rC_frag); +} + +template +__forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { + using namespace cute; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(rA_frag) == size<2>(sB_frag)); + + warpgroup_fence_operand(const_cast(rA_frag)); + warpgroup_fence_operand(rC_frag); + warpgroup_arrive(); + tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(rA_frag); ++k) { + cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_fence_operand(rC_frag); + warpgroup_fence_operand(const_cast(rA_frag)); +} + + +__forceinline__ __device__ uint32_t get_sm_id() { + uint32_t ret; + asm("mov.u32 %0, %smid;" : "=r"(ret)); + return ret; +} + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 
Not sure whether this value is the same on all GPUs +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} + +template< + typename TMA, + typename Tensor0, + typename Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + Tensor0 src, + Tensor1 dst, + transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL +) { + auto thr_tma = tma_copy.get_slice(cute::_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), 0, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +} diff --git a/csrc/sm90/kernels/get_mla_metadata.cu b/csrc/smxx/get_mla_metadata.cu similarity index 64% rename from csrc/sm90/kernels/get_mla_metadata.cu rename to csrc/smxx/get_mla_metadata.cu index 6b78f9b..9b5be62 100644 --- a/csrc/sm90/kernels/get_mla_metadata.cu +++ b/csrc/smxx/get_mla_metadata.cu @@ -6,7 +6,7 @@ #include "utils.h" __global__ void __launch_bounds__(32, 1, 1) -get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { +get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params) { int *seqlens_k_ptr = params.seqlens_k_ptr; int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; int *num_splits_ptr = params.num_splits_ptr; @@ -18,12 +18,26 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { extern __shared__ int shared_mem[]; int* num_blocks_shared = shared_mem; // [batch_size] int* num_splits_shared = shared_mem + batch_size; // [batch_size+1] + int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size] + int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size] + int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size] int total_num_blocks = 0; for (int i = threadIdx.x; i < batch_size; i += 32) { - int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + int cur_s_k = params.topk == -1 ? 
__ldg(seqlens_k_ptr + i) : params.topk; + seqlens_k_shared[i] = cur_s_k; + int first_token_idx = 0; + int last_token_idx = max(cur_s_k-1, 0); + int cur_first_block_idx = first_token_idx / block_size_n; + int cur_last_block_idx = last_token_idx / block_size_n; + // NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx] + // NOTE Before clamping, first_token_idx <= last_token_idx always holds, so after clamping, first_token_idx <= last_token_idx still holds. + // NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel. + int num_blocks = cur_last_block_idx - cur_first_block_idx + 1; total_num_blocks += num_blocks + fixed_overhead_num_blocks; num_blocks_shared[i] = num_blocks; + first_block_idx_shared[i] = cur_first_block_idx; + last_block_idx_shared[i] = cur_last_block_idx; } for (int offset = 16; offset >= 1; offset /= 2) { total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); @@ -31,14 +45,14 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { __syncwarp(); if (threadIdx.x == 0) { - int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks); + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; num_splits_shared[0] = 0; for (int i = 0; i < num_sm_parts; ++i) { int tile_scheduler_metadata0[4], tile_scheduler_metadata1; tile_scheduler_metadata0[0] = now_idx; - tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata0[1] = now_block + first_block_idx_shared[now_idx]; tile_scheduler_metadata1 = now_n_split_idx; int remain_payload = payload; while (now_idx < batch_size) { @@ -61,7 +75,7 @@ 
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { } } tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; - tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + tile_scheduler_metadata0[3] = now_block > 0 ? now_block + first_block_idx_shared[now_idx] : (seqlens_k_shared[now_idx-1] == 0 ? 0 : last_block_idx_shared[now_idx-1] + 1); *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; } @@ -74,8 +88,8 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { } } -void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream) { - int smem_size = sizeof(int) * (params.batch_size*2+1); +void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream) { + int smem_size = sizeof(int) * (params.batch_size*5+1); CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params); CHECK_CUDA_KERNEL_LAUNCH(); diff --git a/csrc/smxx/get_mla_metadata.h b/csrc/smxx/get_mla_metadata.h new file mode 100644 index 0000000..7a1d1c4 --- /dev/null +++ b/csrc/smxx/get_mla_metadata.h @@ -0,0 +1,5 @@ +#pragma once + +#include "params.h" + +void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream); diff --git a/csrc/sm90/kernels/mla_combine.cu b/csrc/smxx/mla_combine.cu similarity index 94% rename from csrc/sm90/kernels/mla_combine.cu rename to csrc/smxx/mla_combine.cu index b6ba8f8..ff609bf 100644 --- a/csrc/sm90/kernels/mla_combine.cu +++ b/csrc/smxx/mla_combine.cu @@ -7,13 +7,12 @@ #include "params.h" #include "utils.h" -#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V using namespace cute; template __global__ void 
__launch_bounds__(NUM_THREADS) -flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { +flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) { // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M] // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m @@ -176,12 +175,14 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params template -void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) { +void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream) { + static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA + FLASH_ASSERT(params.d_v == HEAD_DIM_V); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { constexpr int BLOCK_SIZE_M = 8; constexpr int NUM_THREADS = BLOCK_SIZE_M*32; constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float); - auto combine_kernel = &flash_fwd_mla_combine_kernel; + auto combine_kernel = &flash_fwd_mla_combine_kernel; CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) cudaLaunchAttribute attribute[1]; @@ -200,8 +201,8 @@ void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t str CHECK_CUDA_KERNEL_LAUNCH(); } -template void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); +template void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream); #ifndef FLASH_MLA_DISABLE_FP16 -template void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream); +template void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream); #endif \ No newline at end of file diff --git a/csrc/smxx/mla_combine.h 
b/csrc/smxx/mla_combine.h new file mode 100644 index 0000000..eca7501 --- /dev/null +++ b/csrc/smxx/mla_combine.h @@ -0,0 +1,6 @@ +#pragma once + +#include "params.h" + +template +void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream); diff --git a/csrc/sm90/kernels/utils.h b/csrc/utils.h similarity index 71% rename from csrc/sm90/kernels/utils.h rename to csrc/utils.h index ae9d0fc..571412f 100644 --- a/csrc/sm90/kernels/utils.h +++ b/csrc/utils.h @@ -30,3 +30,37 @@ } while(0) #define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); } + +template +__inline__ __host__ __device__ T ceil_div(const T &a, const T &b) { + return (a + b - 1) / b; +} + +#ifndef TRAP_ONLY_DEVICE_ASSERT +#define TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +// For development, we define both IS_SM100 and IS_SM90 when using CLion or VSCode IDEs so code highlighting will be correct. +#if defined(__CLION_IDE__) || defined(__VSCODE_IDE__) +#define IS_SM100 1 +#define IS_SM90 1 +#else + +// We define the following macros to detect the CUDA architecture, so that we can enable/disable certain kernels that depend on specific architectures. 
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000) +#define IS_SM100 1 +#else +#define IS_SM100 0 +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900) +#define IS_SM90 1 +#else +#define IS_SM90 0 +#endif + +#endif // defined(__CLION_IDE__) || defined(__VSCODE_IDE__) \ No newline at end of file diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py index d0e6faf..66f1986 100644 --- a/flash_mla/__init__.py +++ b/flash_mla/__init__.py @@ -6,4 +6,5 @@ flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func, + flash_mla_sparse_fwd ) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 084117e..4d27621 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -2,30 +2,33 @@ import torch -import flash_mla_sm90 -import flash_mla_sm100 - - +import flash_mla.cuda as flash_mla_cuda def get_mla_metadata( cache_seqlens: torch.Tensor, - num_heads_per_head_k: int, + num_q_tokens_per_head_k: int, num_heads_k: int, + num_heads_q: Optional[int] = None, + is_fp8_kvcache: bool = False, + topk: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: cache_seqlens: (batch_size), dtype torch.int32. - num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. - num_heads_k: num_heads_k. + num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. + num_heads_k: The number of k heads. + num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled + is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. + topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. Returns: tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. 
""" - return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk) -def flash_mla_with_kvcache_sm90( +def flash_mla_with_kvcache( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, @@ -35,6 +38,8 @@ def flash_mla_with_kvcache_sm90( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -47,6 +52,8 @@ def flash_mla_with_kvcache_sm90( num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. + is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md + indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). @@ -54,7 +61,9 @@ def flash_mla_with_kvcache_sm90( """ if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla( + if indices is not None: + assert causal == False, "causal must be `false` if sparse attention is enabled." 
+ out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( q, k_cache, head_dim_v, @@ -64,10 +73,42 @@ def flash_mla_with_kvcache_sm90( causal, tile_scheduler_metadata, num_splits, + is_fp8_kvcache, + indices ) return out, softmax_lse +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + + Returns: + (output, max_logits, lse) + About the definition of output, max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v + ) + return results + + def _flash_attn_varlen_forward( q: torch.Tensor, k: torch.Tensor, @@ -96,7 +137,7 @@ def _flash_attn_varlen_forward( lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) - flash_mla_sm100.fwd( + flash_mla_cuda.dense_prefill_fwd( workspace_buffer, q, k, @@ -159,7 +200,7 @@ def _flash_attn_varlen_backward( if num_qo_heads != num_kv_heads: workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) - flash_mla_sm100.bwd( + flash_mla_cuda.dense_prefill_bwd( workspace_buffer, do, q, @@ -195,7 +236,7 @@ def forward( causal: bool = False, softmax_scale: Optional[float] = None, is_varlen: bool = True, - ): + ) -> Tuple[torch.Tensor, torch.Tensor]: out, lse = _flash_attn_varlen_forward( q, k, v, cu_seqlens_qo, 
cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, @@ -290,40 +331,3 @@ def flash_attn_varlen_kvpacked_func( cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, causal, softmax_scale, is_varlen, ) - - -def flash_mla_with_kvcache_sm100( - q: torch.Tensor, - k_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - head_dim_v: int, - softmax_scale: Optional[float] = None, - causal: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - # TODO - pass - - -def flash_mla_with_kvcache( - q: torch.Tensor, - k_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - head_dim_v: int, - tile_scheduler_metadata: Optional[torch.Tensor] = None, - num_splits: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - capability = torch.cuda.get_device_capability(q.device.index) - if capability == (9, 0): - return flash_mla_with_kvcache_sm90( - q, k_cache, block_table, cache_seqlens, head_dim_v, - tile_scheduler_metadata, num_splits, - softmax_scale, causal, - ) - elif capability == (10, 0): - raise ValueError(f"Unsupported device capability: {capability}") - else: - raise ValueError(f"Unsupported device capability: {capability}") diff --git a/setup.py b/setup.py index 58cf7b2..338117f 100644 --- a/setup.py +++ b/setup.py @@ -12,29 +12,31 @@ ) -def append_nvcc_threads(nvcc_extra_args): - nvcc_threads = os.getenv("NVCC_THREADS") or "32" - return nvcc_extra_args + ["--threads", nvcc_threads] - +def is_flag_set(flag: str) -> bool: + return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"] def get_features_args(): features_args = [] - DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"] - if DISABLE_FP16: + if is_flag_set("FLASH_MLA_DISABLE_FP16"): features_args.append("-DFLASH_MLA_DISABLE_FP16") return features_args +def get_arch_flags(): + DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") + DISABLE_SM90 = 
is_flag_set("FLASH_MLA_DISABLE_SM90") + arch_flags = [] + if not DISABLE_SM100: + arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) + if not DISABLE_SM90: + arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"]) + return arch_flags + +def get_nvcc_thread_args(): + nvcc_threads = os.getenv("NVCC_THREADS") or "32" + return ["--threads", nvcc_threads] subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) -cc_flag_sm90 = [] -cc_flag_sm90.append("-gencode") -cc_flag_sm90.append("arch=compute_90a,code=sm_90a") - -cc_flag_sm100 = [] -cc_flag_sm100.append("-gencode") -cc_flag_sm100.append("arch=compute_100a,code=sm_100a") - this_dir = os.path.dirname(os.path.abspath(__file__)) if IS_WINDOWS: @@ -45,79 +47,44 @@ def get_features_args(): ext_modules = [] ext_modules.append( CUDAExtension( - name="flash_mla_sm90", + name="flash_mla.cuda", sources=[ - "csrc/sm90/flash_api.cpp", - "csrc/sm90/kernels/get_mla_metadata.cu", - "csrc/sm90/kernels/mla_combine.cu", - "csrc/sm90/kernels/splitkv_mla.cu", + "csrc/pybind.cpp", + "csrc/smxx/get_mla_metadata.cu", + "csrc/smxx/mla_combine.cu", + "csrc/sm90/decode/dense/splitkv_mla.cu", + "csrc/sm90/decode/sparse_fp8/splitkv_mla.cu", + "csrc/sm90/prefill/sparse/fwd.cu", + "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu", + "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-DNDEBUG", - "-D_USE_MATH_DEFINES", - "-Wno-deprecated-declarations", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v,--register-usage-level=10" - ] - + cc_flag_sm90 - ) + get_features_args(), + "nvcc": [ + "-O3", + "-std=c++17", + "-DNDEBUG", + "-D_USE_MATH_DEFINES", + "-Wno-deprecated-declarations", + 
"-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v,--register-usage-level=10" + ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(), }, include_dirs=[ + Path(this_dir) / "csrc", Path(this_dir) / "csrc" / "sm90", Path(this_dir) / "csrc" / "cutlass" / "include", - ], - ) -) - -ext_modules.append( - CUDAExtension( - name="flash_mla_sm100", - sources=[ - "csrc/sm100/pybind.cu", - "csrc/sm100/fmha_cutlass_fwd_sm100.cu", - "csrc/sm100/fmha_cutlass_bwd_sm100.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"], - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-DNDEBUG", - "-Wno-deprecated-declarations", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "-lineinfo", - "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", - ] - + cc_flag_sm100 - ), - }, - include_dirs=[ - Path(this_dir) / "csrc" / "sm100", - Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", ], ) ) - try: cmd = ['git', 'rev-parse', '--short', 'HEAD'] rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() diff --git a/tests/lib.py b/tests/lib.py new file mode 100644 index 0000000..f884721 --- /dev/null +++ b/tests/lib.py @@ -0,0 +1,73 @@ +from typing import List + +import torch + +def cdiv(x: int, y: int): + return (x+y-1) // y + +def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7): + """ + Check if two tensors are close enough + """ + def get_cos_diff(x: 
torch.Tensor, y: torch.Tensor) -> float: + """ + Calculate the cosine diff between two tensors + """ + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum().item() + if denominator == 0: + return 0 + sim = 2 * (x * y).sum().item() / denominator + return 1 - sim + assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}" + + ans = ans.clone().to(torch.float) + ref = ref.clone().to(torch.float) + + # Deal with anomalies + def deal_with_anomalies(val: float): + ref_mask = (ref == val) if (val == val) else (ref != ref) + ans_mask = (ans == val) if (val == val) else (ans != ans) + ref[ref_mask] = 0.0 + ans[ans_mask] = 0.0 + if not torch.equal(ref_mask, ans_mask): + print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref") + return False + return True + + anomalies_check_passed = True + anomalies_check_passed &= deal_with_anomalies(float("inf")) + anomalies_check_passed &= deal_with_anomalies(float("-inf")) + anomalies_check_passed &= deal_with_anomalies(float("nan")) + + if not anomalies_check_passed: + return False + + cos_diff = get_cos_diff(ans, ref) + raw_abs_err = torch.abs(ans-ref) + raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6)) + rel_err = raw_rel_err.masked_fill(raw_abs_err List[int]: + result = [] + for size in t.shape[::-1]: + result.append(pos % size) + pos = pos // size + assert pos == 0 + return result[::-1] + print(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}") + print(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}") + print(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)") + print(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})") 
+ return False + else: + if abs(cos_diff) > cos_diff_tol: + print(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})") + return False + return True \ No newline at end of file diff --git a/tests/quant.py b/tests/quant.py new file mode 100644 index 0000000..afee4b2 --- /dev/null +++ b/tests/quant.py @@ -0,0 +1,68 @@ +import enum + +import torch + +def quantize_k_cache( + input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d) + dv: int, + tile_size: int = 128, +) -> torch.Tensor: + """ + Quantize the k-cache + Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size() + For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md + """ + assert dv % tile_size == 0 + num_tiles = dv // tile_size + num_blocks, block_size, h_k, d = input_k_cache.shape + assert h_k == 1 + input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] + input_elem_size = input_k_cache.element_size() + + result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) + result_k_nope_part = result[..., :dv] + result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32) + result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype) + result_k_rope_part[:] = input_k_cache[..., dv:] + + for tile_idx in range(0, num_tiles): + cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] + result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv + + cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] + cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) + result_k_nope_part[..., 
tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope + + result = result.view(num_blocks, block_size, 1, -1) + return result + + +def dequantize_k_cache( + quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token) + dv: int = 512, + tile_size: int = 128, + d: int = 576 +) -> torch.Tensor: + """ + De-quantize the k-cache + """ + assert dv % tile_size == 0 + num_tiles = dv // tile_size + num_blocks, block_size, h_k, _ = quant_k_cache.shape + assert h_k == 1 + result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device) + + quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) + + input_nope = quant_k_cache[..., :dv] + input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32) + input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16) + result[..., dv:] = input_rope + + for tile_idx in range(0, num_tiles): + cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32) + cur_scales = input_scale[..., tile_idx].unsqueeze(-1) + result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales + + result = result.view(num_blocks, block_size, 1, d) + return result diff --git a/tests/test_flash_mla_decoding.py b/tests/test_flash_mla_decoding.py new file mode 100644 index 0000000..64ddf72 --- /dev/null +++ b/tests/test_flash_mla_decoding.py @@ -0,0 +1,343 @@ +import argparse +import math +import random +import dataclasses +from typing import Optional, Tuple, List + +import torch +import triton + +import quant +import flash_mla +from lib import cdiv, check_is_allclose + +@dataclasses.dataclass +class TestParam: + b: int # Batch size + s_q: int # Number of queries for one request + s_k: int # Seq len, or mean seq len if varlen == True + is_varlen: bool + is_causal: bool + is_fp8: bool + topk: Optional[int] = None + test_performance: bool = True + is_all_indices_invalid: bool = False + have_zero_seqlen_k: bool = False + block_size: 
int = 64 + h_q: int = 128 # Number of q heads + h_kv: int = 1 # Number of kv heads + d: int = 576 # Q/K head dim (= dv + RoPE dim) + dv: int = 512 # V head dim + seed: int = 0 + + +def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Generate test data from a given configuration + Return: [cache_seqlens, q, block_table, blocked_k] + Pay attention: This function changes the random seed + """ + random.seed(t.seed) + torch.manual_seed(t.seed) + torch.cuda.manual_seed(t.seed) + torch.backends.cudnn.deterministic = True + + assert t.h_q % t.h_kv == 0 + + cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device='cpu') + if t.is_varlen: + for i in range(t.b): + cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q) + + if t.have_zero_seqlen_k: + zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0 + cache_seqlens_cpu[zeros_mask] = 0 + + max_seqlen = cache_seqlens_cpu.max().item() + max_seqlen_pad = cdiv(max_seqlen, 256) * 256 + cache_seqlens = cache_seqlens_cpu.cuda() + + q = torch.randn(t.b, t.s_q, t.h_q, t.d) + q.clamp_(min=-1.0, max=1.0) + + block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size) + block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(t.b, -1) + blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10 + blocked_k.clamp_(min=-1.0, max=1.0) + + if t.topk is None: + for i in range(t.b): + cur_len = cache_seqlens_cpu[i].item() + cur_num_blocks = cdiv(cur_len, t.block_size) + blocked_k[block_table[i][cur_num_blocks:]] = float("nan") + if cur_len % t.block_size != 0: + blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan") + block_table[i][cur_num_blocks:] = 2147480000 + return cache_seqlens, q, block_table, blocked_k, None, None + else: + block_table_cpu = 
block_table.cpu() + abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu") + indices_in_kvcache = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu") + for i in range(t.b): + # Generate indices + for j in range(t.s_q): + cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk] + cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size) + if len(cur_abs_indices) < t.topk: + pad_len = t.topk - len(cur_abs_indices) + cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')]) + cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')]) + + # Mask KV + perm = torch.randperm(t.topk, device='cpu') + cur_abs_indices = cur_abs_indices[perm] + cur_blocked_indices = cur_blocked_indices[perm] + + # Fill it with invalid indices if needed + if t.is_all_indices_invalid: + cur_abs_indices.fill_(-1) + cur_blocked_indices.fill_(-1) + + abs_indices[i, j, :] = cur_abs_indices + indices_in_kvcache[i, j, :] = cur_blocked_indices + + # Mask nonused KV as NaN + all_indices = indices_in_kvcache.flatten().tolist() + all_indices = list(set(all_indices)) + if -1 in all_indices: + all_indices.remove(-1) + all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu') + + blocked_k = blocked_k.view(-1, t.h_kv, t.d) + nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu') + nonused_indices_mask[all_indices] = False + blocked_k[nonused_indices_mask, :, :] = float("nan") + blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d) + + abs_indices = abs_indices.to(q.device) + indices_in_kvcache = indices_in_kvcache.to(q.device) + + return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache + + +def reference_torch( + cache_seqlens: torch.Tensor, # [batch_size] + block_table: torch.Tensor, # [batch_size, ?] 
+ q: torch.Tensor, # [batch_size, s_q, h_q, d] + blocked_k: torch.Tensor, # [?, block_size, h_kv, d] + dv: int, + is_causal: bool, + indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A reference implementation in PyTorch + """ + def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): + mask = torch.zeros(s_q, s_k, dtype=torch.bool) + for i in range(s_q): + cur_indices = indices[i] + valid_indices = cur_indices[cur_indices != -1] + mask[i, valid_indices] = True + return mask + + def scaled_dot_product_attention( + batch_idx: int, + query: torch.Tensor, # [h_q, s_q, d] + kv: torch.Tensor, # [h_kv, s_k, d] + dv: int, + is_causal, + indices: Optional[torch.Tensor], # [s_q, topk] + ) -> Tuple[torch.Tensor, torch.Tensor]: + h_q = query.size(0) + h_kv = kv.size(0) + s_q = query.shape[-2] + s_k = kv.shape[-2] + query = query.float() + kv = kv.float() + if h_kv != 1: + kv = kv.repeat_interleave(h_q // h_kv, dim=0) + kv[kv != kv] = 0.0 + attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] + if (is_causal and query.size(1) > 1) or indices is not None: + mask = torch.ones(s_q, s_k, dtype=torch.bool) + if is_causal: + assert indices is None + mask = mask.tril(diagonal=s_k - s_q) + if indices is not None: + mask &= get_topk_attn_mask(s_q, s_k, indices) + attn_bias = torch.zeros(s_q, s_k, dtype=torch.float) + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + attn_weight += attn_bias.to(q.dtype) + attn_weight /= math.sqrt(query.size(-1)) + lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv] + # Correct for q tokens which has no attendable k + lonely_q_mask = (lse == float("-inf")) + output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 + lse[lonely_q_mask] = float("+inf") + + return output, lse + + b, s_q, h_q, d = q.size() + block_size = 
blocked_k.size(1) + h_kv = blocked_k.size(2) + cache_seqlens_cpu = cache_seqlens.cpu() + out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + cur_len = cache_seqlens_cpu[i].item() + cur_num_blocks = cdiv(cur_len, block_size) + cur_block_indices = block_table[i][0: cur_num_blocks] + cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...] + cur_out, cur_lse = scaled_dot_product_attention( + i, + q[i].transpose(0, 1), + cur_kv.transpose(0, 1), + dv, + is_causal, + indices[i] if indices is not None else None + ) + out_ref[i] = cur_out.transpose(0, 1) + lse_ref[i] = cur_lse + out_ref = out_ref.to(torch.bfloat16) + return out_ref, lse_ref + + +@torch.inference_mode() +def test_flash_mla(t: TestParam): + print('-------------------------------') + print(f"Running on {t}...") + + # Generating test data + torch.cuda.synchronize() + cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache = generate_test_data(t) + + if t.is_fp8: + # The quantization error may be too large to be distinguished from wrong kernels + # So we quantize and de-quantize kv-cache here to mitigate quantization error + blocked_k_quantized = quant.quantize_k_cache(blocked_k, t.dv, 128) + blocked_k_dequantized = quant.dequantize_k_cache(blocked_k_quantized) + blocked_k = blocked_k_dequantized + + # Get schedule metadata + torch.cuda.synchronize() + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata( + cache_seqlens, + t.s_q * t.h_q // t.h_kv, + t.h_kv, + t.h_q, + t.is_fp8, + t.topk + ) + torch.cuda.synchronize() + + def run_flash_mla(): + return flash_mla.flash_mla_with_kvcache( + q, + blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore + block_table, + cache_seqlens, + t.dv, + tile_scheduler_metadata, + num_splits, + causal=t.is_causal, + is_fp8_kvcache=t.is_fp8, + indices=indices_in_kvcache + ) + + out_ans, lse_ans = run_flash_mla() + out_ref, lse_ref = 
reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) + assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6) + assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) + + if t.test_performance: + time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore + mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk + compute_volume_flop = t.b*t.h_q*t.s_q*sum([ + 2*t.d*mean_attended_seqlens, # Q * K^T + 2*mean_attended_seqlens*t.dv, # attention * V + ]) + q_elem_size = torch.bfloat16.itemsize + kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize + memory_volume_B = t.b*sum([ + t.s_q*t.h_q*(t.d*q_elem_size), # Q + (t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V + t.s_q*t.h_q*(t.dv*q_elem_size), # Output + ]) + achieved_tflops = compute_volume_flop / time_usage / 1e12 + achieved_gBps = memory_volume_B / time_usage / 1e9 + + print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") + + +def main(torch_dtype): + device = torch.device("cuda:0") + torch.set_default_dtype(torch_dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + + correctness_cases = [ + TestParam(b, s_q, s_k, is_varlen, is_causal, is_fp8, topk, test_performance=False) + for b in [1, 2, 6, 64] + for s_q in [1, 2, 4] + for s_k in [20, 140, 4096] + for is_varlen in [False, True] + for is_causal in [False, True] + for (is_fp8, topk) in [ + (False, None), + (True, 128), + (True, 2048) + ] + if not (is_causal and topk is not None) + ] + + corner_cases = [ + # Cases where all topk indices are invalid + TestParam(128, 2, 4096, is_varlen=True, is_causal=False, is_fp8=True, topk=topk, test_performance=False, is_all_indices_invalid=True) + for topk in [128, 2048, 4096] + ] + [ + # Cases where some kv cache have zero length + TestParam(128, 2, 4096, 
is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=False, have_zero_seqlen_k=True) + for (is_causal, is_fp8, topk) in [ + (False, False, None), + (True, False, None), + (False, True, 128), + (False, True, 2048), + ] + ] + + performance_cases = [ + TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=True) + for (is_causal, is_fp8, topk) in [ + (False, False, None), + (True, False, None), + (False, True, 2048), + ] + for s_q in [1, 2] + for s_k in [4096, 8192, 16384, 32768] + ] + + testcases = correctness_cases + corner_cases + performance_cases + + for testcase in testcases: + test_flash_mla(testcase) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype", + type=str, + choices=["bf16", "fp16"], + default="bf16", + help="Data type to use for testing (bf16 or fp16)", + ) + + args = parser.parse_args() + + torch_dtype = torch.bfloat16 + if args.dtype == "fp16": + torch_dtype = torch.float16 + + main(torch_dtype) diff --git a/tests/test_flash_mla_prefill.py b/tests/test_flash_mla_prefill.py new file mode 100644 index 0000000..f85f2d6 --- /dev/null +++ b/tests/test_flash_mla_prefill.py @@ -0,0 +1,197 @@ +import math +import time +from typing import Tuple +import random +import dataclasses + +import torch +import triton + +from flash_mla import flash_mla_sparse_fwd +from lib import check_is_allclose + +@dataclasses.dataclass +class TestParam: + b: int + s_q: int + s_kv: int + topk: int + h_q: int = 128 + h_kv: int = 1 + d_qk: int = 576 + d_v: int = 512 + seed: int = 0 + check_correctness: bool = True + benchmark: bool = True + +@dataclasses.dataclass +class Testcase: + t: TestParam + q: torch.Tensor + kv: torch.Tensor + indices: torch.Tensor + +def generate_testcase(t: TestParam) -> Testcase: + torch.manual_seed(t.seed) + torch.cuda.manual_seed(t.seed) + random.seed(t.seed) + q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), 
dtype=torch.bfloat16)/10 + kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((t.b, t.s_q, t.h_kv, t.topk), t.s_kv, dtype=torch.int32) + for b in range(t.b): + for s in range(t.s_q): + for h in range(t.h_kv): + # TODO Comment + near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 + cur_indices = torch.randperm(t.s_kv)[:t.topk] + cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) + if len(cur_indices) < t.topk: + cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)]) + cur_indices = cur_indices[torch.randperm(t.topk)] + indices[b, s, h] = cur_indices + indices = indices.to(q.device) + + return Testcase( + t=t, + q=q, + kv=kv, + indices=indices + ) + +def get_flop(p: TestParam) -> float: + flop = 2 * sum([ + p.h_q * p.d_qk * p.topk, + p.h_q * p.d_v * p.topk + ]) * p.b * p.s_q + return flop + +def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: + return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e) + + assert p.b == 1 + indices = t.indices[0, :, 0, :] # [s_q, topk] + invalid_indices_mask = (indices < 0) | (indices >= p.s_kv) + qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk] + kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk] + + kvs = torch.index_select(kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()).view(p.s_q, p.topk, p.d_qk) # [s_q, topk, d_qk] + attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk] + attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf')) + attn_score *= sm_scale * math.log2(math.e) + max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q] + lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q] + attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk] + result = attn_score @ 
kvs[:, :, :p.d_v] + return (max_logits, lse, result) + +@torch.inference_mode() +def run_test(p: TestParam) -> bool: + print("================") + print(f"Running on {p}") + torch.cuda.empty_cache() + assert p.b == 1 + + t = generate_testcase(p) + sm_scale = 1 / math.sqrt(p.d_qk) + torch.cuda.synchronize() + + def run_ans(): + return flash_mla_sparse_fwd( + t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale + ) + + ans_out, ans_max_logits, ans_lse = run_ans() + torch.cuda.synchronize() + + if p.benchmark: + flop = get_flop(p) + prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore + prefill_flops = flop/prefill_ans_time/1e12 + print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops") + + if p.check_correctness: + torch.cuda.synchronize() + ref_max_logits, ref_lse, ref_out = reference_torch(p, t, sm_scale) + torch.cuda.synchronize() + + is_correct = True + is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6) + is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) + is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) + + return is_correct + else: + return True + + +if __name__ == '__main__': + device = torch.device("cuda:0") + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + + correctness_cases = [ + # Regular shapes + TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) + for s_kv, topk in [ + # Regular shapes + (128, 128), + (256, 256), + (512, 512), + + # Irregular shapes + (592, 128), + (1840, 256), + (1592, 384), + (1521, 512), + + # Irregular shapes with OOB TopK + (95, 128), + (153, 256), + (114, 384), + ] + for s_q in [ + 1, 62 + ] + ] + + corner_cases = [ + # In these cases, some blocks may not have any valid topk 
indices + TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) + for s_kv, topk in [ + (32, 2048), + (64, 8192) + ] + for s_q in [1, 1024] + ] + + performance_cases = [ + TestParam(1, s_q, s_kv, topk, h_q=128) + for s_q in [4096] + for s_kv in [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072] + for topk in [2048] + ] + + testcases = correctness_cases + corner_cases + performance_cases + + failed_cases = [] + for test in testcases: + if test.benchmark: + time.sleep(0.2) + is_correct = run_test(test) + if not is_correct: + failed_cases.append(test) + + if len(failed_cases) > 0: + print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") + for case in failed_cases: + print(f" {case}") + else: + print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") + diff --git a/tests/test_flash_mla_sm90.py b/tests/test_flash_mla_sm90.py deleted file mode 100644 index 67c9d93..0000000 --- a/tests/test_flash_mla_sm90.py +++ /dev/null @@ -1,153 +0,0 @@ -import argparse -import math -import random - -import torch -import triton - -from flash_mla import flash_mla_with_kvcache, get_mla_metadata - - -def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): - query = query.float() - key = key.float() - value = value.float() - key = key.repeat_interleave(h_q // h_kv, dim=0) - value = value.repeat_interleave(h_q // h_kv, dim=0) - attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) - if is_causal: - s_q = query.shape[-2] - s_k = key.shape[-2] - attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - attn_weight += attn_bias - lse = attn_weight.logsumexp(dim=-1) - attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) - return attn_weight @ value, lse - - -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: 
str) -> None: - x, y = x.double(), y.double() - RMSE = ((x - y) * (x - y)).mean().sqrt().item() - cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) - amax_diff = (x - y).abs().max().item() - # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5 - - -@torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): - print( - f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" - ) - - cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) - if varlen: - for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) - total_seqlens = cache_seqlens.sum().item() - mean_seqlens = cache_seqlens.float().mean().int().item() - max_seqlen = cache_seqlens.max().item() - max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 - # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") - - q = torch.randn(b, s_q, h_q, d) - block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32 - ).view(b, max_seqlen_pad // block_size) - blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( - float("nan") - ) - blocked_v = blocked_k[..., :dv] - - tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens, s_q * h_q // h_kv, h_kv - ) - - def flash_mla(): - return flash_mla_with_kvcache( - q, - blocked_k, - block_table, - cache_seqlens, - dv, - tile_scheduler_metadata, - num_splits, - causal=causal, - ) - - def ref_mla(): - out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) - lse = torch.empty(b, h_q, s_q, dtype=torch.float32) - for i in range(b): - begin = i * max_seqlen_pad - end = begin + cache_seqlens[i] - O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, 
dv)[begin:end].transpose(0, 1), - h_q=h_q, - h_kv=h_kv, - is_causal=causal, - ) - out[i] = O.transpose(0, 1) - lse[i] = LSE - return out, lse - - out_flash, lse_flash = flash_mla() - out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") - cal_diff(lse_flash, lse_torch, "lse") - - t = triton.testing.do_bench(flash_mla) - FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(q.dtype).bits // 8 - ) - print( - f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" - ) - - -def main(torch_dtype): - device = torch.device("cuda:0") - torch.set_default_dtype(torch_dtype) - torch.set_default_device(device) - torch.cuda.set_device(device) - torch.manual_seed(0) - random.seed(0) - - h_kv = 1 - d, dv = 576, 512 - causal = True - - for b in [128]: - for s in [4096, 8192, 16384]: - for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 - for s_q in [1, 2]: # MTP = 1, 2 - for varlen in [False, True]: - test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--dtype", - type=str, - choices=["bf16", "fp16"], - default="bf16", - help="Data type to use for testing (bf16 or fp16)", - ) - - args = parser.parse_args() - - torch_dtype = torch.bfloat16 - if args.dtype == "fp16": - torch_dtype = torch.float16 - - main(torch_dtype) diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 7cb19a2..6b2ba45 100644 --- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -6,6 +6,7 @@ from flash_mla import flash_attn_varlen_func +from lib import check_is_allclose def get_window_size(causal, window): if window > 0: @@ -28,21 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window): return attn_bias -def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None: - x, y = x.double(), y.double() - RMSE = ((x - y) * (x - y)).mean().sqrt().item() - 
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) - amax_diff = (x - y).abs().max().item() - # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}" - - def sdpa(query, key, value, attn_bias, softmax_scale=None): + query = query.float().transpose(-3, -2) + key = key.float().transpose(-3, -2) + value = value.float().transpose(-3, -2) key = key.repeat_interleave(h // h_k, dim=-3) value = value.repeat_interleave(h // h_k, dim=-3) if softmax_scale is None: softmax_scale = query.shape[-1] ** (-0.5) - attn_weight = query @ key.transpose(-2, -1) * softmax_scale + attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale attn_weight += attn_bias lse = attn_weight.logsumexp(dim=-1) attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) @@ -53,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs): return checkpoint(sdpa, *args, use_reentrant=False, **kwargs) -def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd): - print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}") +def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, check_correctness: bool = True): + print(f"{b=}, {mean_sq=}, {mean_sk=}, {varlen=}, {h=}, {h_k=}, {d=}, {dv=}, {causal=}, {has_bwd=}, {check_correctness=}") torch.manual_seed(0) random.seed(0) @@ -76,19 +71,20 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win causal, window) == 0).sum().item() for i in range(b)]) # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") - q = torch.randn(total_q, h, d) - k = torch.randn(total_k, h_k, d) - v = torch.randn(total_k, h_k, dv) - grad_out = torch.randn(total_q, h, dv) + q = torch.randn(total_q, h, d)/10 + k = torch.randn(total_k, h_k, d)/10 + v = torch.randn(total_k, 
h_k, dv)/10 + grad_out = torch.randn(total_q, h, dv)/10 softmax_scale = (d + 100) ** (-0.5) q1 = q.clone().requires_grad_() k1 = k.clone().requires_grad_() v1 = v.clone().requires_grad_() - q2 = q.clone().requires_grad_() - k2 = k.clone().requires_grad_() - v2 = v.clone().requires_grad_() + if check_correctness: + q2 = q.clone().requires_grad_() + k2 = k.clone().requires_grad_() + v2 = v.clone().requires_grad_() def flash_attn(): q1.grad = k1.grad = v1.grad = None @@ -106,9 +102,9 @@ def torch_attn(): lse = [] for i in range(b): OUT, LSE = sdpa_checkpoint( - q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()].float().transpose(-3, -2), - k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), - v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()].float().transpose(-3, -2), + q2[cu_seqlens_q[i].item(): cu_seqlens_q[i + 1].item()], + k2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()], + v2[cu_seqlens_k[i].item(): cu_seqlens_k[i + 1].item()], attn_bias=get_attn_bias(seqlens_q[i].item(), seqlens_k[i].item(), causal, window), softmax_scale=softmax_scale, ) @@ -119,20 +115,23 @@ def torch_attn(): return out, lse out_flash, lse_flash = flash_attn() - out_torch, lse_torch = torch_attn() - assert_close(out_flash, out_torch, "out") - assert_close(lse_flash, lse_torch, "lse") - if has_bwd: out_flash.backward(grad_out, retain_graph=True) - out_torch.backward(grad_out, retain_graph=True) - assert_close(q1.grad, q2.grad, "dq") - assert_close(k1.grad, k2.grad, "dk") - assert_close(v1.grad, v2.grad, "dv") dq1 = q1.grad.clone() dk1 = k1.grad.clone() dv1 = v1.grad.clone() + if check_correctness: + out_torch, lse_torch = torch_attn() + assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536) + + if has_bwd: + out_torch.backward(grad_out, retain_graph=True) + assert check_is_allclose("dq", q1.grad, q2.grad, 
abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + def forward(): return flash_attn() @@ -150,12 +149,6 @@ def backward(): assert torch.equal(k1.grad, dk1), "dk deterministic check failed!" assert torch.equal(v1.grad, dv1), "dv deterministic check failed!" - # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: - # forward() - # if has_bwd: - # backward() - # print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120)) - def timer(func, name): t = triton.testing.do_bench(func, warmup=2, rep=3) FLOPS = total_attn_compute * h * 2 * ((d + dv) if name == "fwd" else ((d * 3 + dv * 2))) @@ -173,18 +166,20 @@ def timer(func, name): device = torch.device("cuda:0") torch.set_default_device(device) torch.cuda.set_device(device) + torch.set_float32_matmul_precision("high") - b = 4 + b = 2 window = 0 has_bwd = False for (mean_sq, mean_sk) in [(4096, 4096), (8192, 8192)]: for varlen in [False, True]: - for (h, h_k) in [(32, 32), (32, 4)]: + for (h, h_k) in [(128, 128), (32, 4)]: if h != h_k: has_bwd = False else: has_bwd = True for (d, dv) in [(128, 128), (192, 128)]: for causal in [False, True]: - test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd) + skip_correctness_check = mean_sq == 8192 and mean_sk == 8192 and h == 128 + test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, window, has_bwd, not skip_correctness_check) From 87709cf4cce80392c67befe132fd338dd3049bc2 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Wed, 24 Sep 2025 14:13:06 +0800 Subject: [PATCH 12/24] Add a comment --- tests/test_flash_mla_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flash_mla_prefill.py b/tests/test_flash_mla_prefill.py index 
f85f2d6..19a6dbe 100644 --- a/tests/test_flash_mla_prefill.py +++ b/tests/test_flash_mla_prefill.py @@ -45,7 +45,7 @@ def generate_testcase(t: TestParam) -> Testcase: for b in range(t.b): for s in range(t.s_q): for h in range(t.h_kv): - # TODO Comment + # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 cur_indices = torch.randperm(t.s_kv)[:t.topk] cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) From 7232d69d5e269db902d69d7718a5a55efaab4be8 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Mon, 29 Sep 2025 15:11:24 +0800 Subject: [PATCH 13/24] Fill in link to DSv3.2 paper --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 8cf01a3..2f6b8db 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ ## Introduction -FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](TODO) models. This repository contains the following implementations: +FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. 
This repository contains the following implementations: **Sparse Attention Kernels** -*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](TODO).* +*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).* - Token-level sparse attention for the prefill stage - Token-level sparse attention for the decoding stage, with FP8 KV cache @@ -18,7 +18,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee ## News -- **2025.09.26(TODO) Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](TODO), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. +- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. - **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell! - **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). - **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 
🚀🚀🚀 @@ -66,7 +66,7 @@ Support matrix: [1]: For more details on using FP8 KV cache, see documents below. -[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](TODO). +[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp). ## Installation From fd249aacce56327affecd16f89e035b12691974f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 29 Sep 2025 02:21:37 -0700 Subject: [PATCH 14/24] Add Sparse Decoding Kernel and Sparse Prefill Kernel for Blackwell Signed-off-by: simon-mo --- README.md | 8 +- csrc/pybind.cpp | 38 +- csrc/sm100/decode/sparse_fp8/dequant.h | 61 ++ csrc/sm100/decode/sparse_fp8/splitkv_mla.cu | 592 +++++++++++++++ csrc/sm100/decode/sparse_fp8/splitkv_mla.h | 10 + csrc/sm100/defines.h | 30 + csrc/sm100/helpers.h | 97 +++ csrc/sm100/intrinsics.h | 461 ++++++++++++ csrc/sm100/prefill/sparse/fwd.cu | 785 ++++++++++++++++++++ csrc/sm100/prefill/sparse/fwd.h | 9 + csrc/sm100/prefill/sparse/helpers.h | 104 +++ csrc/sm100/prefill/sparse/intrinsics.h | 638 ++++++++++++++++ csrc/sm100/prefill/sparse/ws_gemm.h | 328 ++++++++ csrc/sm100/tma_cta_group2_nosplit.h | 281 +++++++ csrc/sm100/ws_gemm.h | 426 +++++++++++ setup.py | 16 + tests/test_flash_mla_decoding.py | 5 + 17 files changed, 3882 insertions(+), 7 deletions(-) create mode 100644 csrc/sm100/decode/sparse_fp8/dequant.h create mode 100644 
csrc/sm100/decode/sparse_fp8/splitkv_mla.cu create mode 100644 csrc/sm100/decode/sparse_fp8/splitkv_mla.h create mode 100644 csrc/sm100/defines.h create mode 100644 csrc/sm100/helpers.h create mode 100644 csrc/sm100/intrinsics.h create mode 100644 csrc/sm100/prefill/sparse/fwd.cu create mode 100644 csrc/sm100/prefill/sparse/fwd.h create mode 100644 csrc/sm100/prefill/sparse/helpers.h create mode 100644 csrc/sm100/prefill/sparse/intrinsics.h create mode 100644 csrc/sm100/prefill/sparse/ws_gemm.h create mode 100644 csrc/sm100/tma_cta_group2_nosplit.h create mode 100644 csrc/sm100/ws_gemm.h diff --git a/README.md b/README.md index 2f6b8db..354cdde 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ python tests/test_flash_mla_decoding.py The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8. +For Blackwell GPUs, the token-level sparse MLA decoding kernel can achieve up to 350 TFlops (on B200) which is not really optimized yet. + #### Test & benchmark MHA prefill (Dense): ```bash @@ -47,7 +49,7 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation python tests/test_flash_mla_prefill.py ``` -It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8. +It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9. 
## Requirements @@ -60,9 +62,9 @@ Support matrix: | Kernel | GPU Architecture | MLA Mode [2] | KVCache Format | | :---: | :---: | :---: | :---: | | Dense Decoding | Hopper | MQA | BF16 | -| Sparse Decoding | Hopper | MQA | FP8 [1] | +| Sparse Decoding | Hopper & Blackwell | MQA | FP8 [1] | | Dense Prefill | Blackwell | MHA | | -| Sparse Prefill | Hopper | MQA | | +| Sparse Prefill | Hopper & Blackwell | MQA | | [1]: For more details on using FP8 KV cache, see documents below. diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index b360c24..6ec3f21 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -16,7 +16,9 @@ #include "sm90/decode/dense/splitkv_mla.h" #include "sm90/decode/sparse_fp8/splitkv_mla.h" #include "sm90/prefill/sparse/fwd.h" +#include "sm100/decode/sparse_fp8/splitkv_mla.h" #include "sm100/prefill/dense/interface.h" +#include "sm100/prefill/sparse/fwd.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -31,7 +33,7 @@ struct Arch { } bool is_sm100() const { - return major == 10 && minor == 0; + return major == 10; } void assert_is_supported() const { @@ -86,7 +88,31 @@ DecodingAttnImplMeta get_attn_impl_meta( } } } else if (arch.is_sm100()) { - TORCH_CHECK(false, "Unsupported GPU architecture"); + if (is_sparse_attn) { + if (is_fp8_kvcache) { + TORCH_CHECK(h_q_.has_value()); + int h_q = h_q_.value(); + TORCH_CHECK(h_q % h_k == 0); + int s_q = num_q_tokens_per_head_k * h_k / h_q; + // FP8 + Sparse MLA + return { + std::max(sm_count / h_k / (cutlass::ceil_div(h_q/h_k, 64) * s_q), 1), + 5, + 64 + }; + } else { + // Sparse BF16 MLA + TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM100"); + } + } else { + if (is_fp8_kvcache) { + // FP8 MLA + TORCH_CHECK(false, "FP8 Dence MLA is not supported on SM100"); + } else { + // Normal BF16 MLA + TORCH_CHECK(false, "BF16 Dence MLA is not supported on 
SM100"); + } + } } else { TORCH_CHECK(false, "Unsupported GPU architecture"); } @@ -326,7 +352,8 @@ fwd_kvcache_mla( } } } else if (arch.is_sm100()) { - TORCH_CHECK(false, "Unsupported GPU architecture"); + TORCH_CHECK(is_fp8 && is_sparse_attn, "Only FP8 + Sparse attention is supported on SM100"); + sm100::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); } else { TORCH_CHECK(false, "Unsupported GPU architecture"); } @@ -366,7 +393,8 @@ std::vector sparse_prefill_fwd( ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9; - TORCH_CHECK(is_sm90, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 architectures"); + bool is_sm100 = dprops->major == 10; + TORCH_CHECK(is_sm90 || is_sm100, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 or SM100 architectures"); CHECK_DEVICE(q); CHECK_DEVICE(kv); @@ -423,6 +451,8 @@ std::vector sparse_prefill_fwd( if (is_sm90) { sm90::run_fwd_kernel(params); + } else if (is_sm100) { + sm100::run_fwd_kernel(params); } else { TORCH_CHECK(false, "Unknown architecture"); } diff --git a/csrc/sm100/decode/sparse_fp8/dequant.h b/csrc/sm100/decode/sparse_fp8/dequant.h new file mode 100644 index 0000000..3ed46e1 --- /dev/null +++ b/csrc/sm100/decode/sparse_fp8/dequant.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include "sm100/defines.h" + +namespace sm100 { + +struct fp8x8 { + __nv_fp8x4_e4m3 lo; + __nv_fp8x4_e4m3 hi; +}; + +struct fp8x32 { + fp8x8 a0, a1, a2, a3; +}; + +struct fp8x16 { + fp8x8 a0, a1; +}; + +__device__ __forceinline__ +bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) { + __nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale); + + #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \ + { \ + float4 fp32x4 = (float4)(FP8x4); \ + OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \ + OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \ + } + 
+ bf16x8 result; + DEQUANT_FP8x4(result.a01, result.a23, inputs.lo); + DEQUANT_FP8x4(result.a45, result.a67, inputs.hi); + + return result; +} + +__device__ __forceinline__ +fp8x32 ldg_256_fp8x32(void* src_ptr) { + int32x8_t val; + asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" + : "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3), + "=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7) + : "l"(src_ptr) + ); + return *reinterpret_cast(&val); +} + +__device__ __forceinline__ +fp8x16 ldg_128_fp8x16(void* src_ptr) { + int4 ret; + asm volatile("ld.global.nc.L1::evict_first.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) + : "l"(src_ptr)); + return *reinterpret_cast(&ret); +} + +} diff --git a/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu new file mode 100644 index 0000000..068e9fd --- /dev/null +++ b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu @@ -0,0 +1,592 @@ +#include "splitkv_mla.h" + +#include +#include +#include +#include +#include + +#include "utils.h" +#include "dequant.h" +#include "sm100/defines.h" +#include "sm100/helpers.h" +#include "sm100/intrinsics.h" +#include "sm100/ws_gemm.h" + +namespace sm100 { + +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::NamedBarrier; +using namespace cute; + +constexpr int B_H = 64; +constexpr int B_TOPK = 64; +constexpr int D_K = 576; +constexpr int D_V = 512; +constexpr int NUM_BUFS = 2; +constexpr int NUM_THREADS = 128*3; +constexpr int NUM_WORKING_THREADS = 128 + 128 + 32; +constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN + +template< + typename Shape_Q, typename TMA_Q, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + Shape_O shape_O; TMA_O tma_O; +}; + +namespace tmem_addr { + constexpr int o = 0; // o: [0, 256] + constexpr int p = 256; // p: [256, 288] +}; + +using SmemLayoutQ 
= decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutOBuf = decltype(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, // TODO This may lead to TMA double traffic + Shape, Int>{} +)); + +using SmemLayoutOAccumBuf = Layout< + Shape, Int>, + Stride, _1> // We use stride = 520 here to avoid bank conflict +>; + +using SmemLayoutS = decltype(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTilesTransposed = decltype(composition( + SmemLayoutKTiles{}, + Layout< + Shape, Int>, + Stride, _1> + >{} +)); + +using SmemLayoutK = SmemLayoutKTiles<9>; +using SmemLayoutV = SmemLayoutKTilesTransposed<8>; + +struct SharedMemoryPlan { + array_aligned> q; + union { + array_aligned> o_buf; + array_aligned> o_accum_buf; + array_aligned> k[NUM_BUFS]; + } u; + array_aligned> s; + transac_bar_t bar_q; + transac_bar_t bar_k_ready[NUM_BUFS], bar_k_free[NUM_BUFS]; + transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS]; + float rowwise_max_buf[128], rowwise_li_buf[128]; + bool is_token_valid[NUM_BUFS][B_TOPK]; + array_aligned tmem_start_addr; +}; + +using TiledMMA_QK = decltype(make_tiled_mma( + SM100_MMA_F16BF16_WS_SS_NOELECT{}, + Layout>{} +)); // TODO Use TS? 
+ +using TiledMMA_SV = decltype(make_tiled_mma( + SM100_MMA_F16BF16_WS_SS_NOELECT{}, + Layout>{}, + Tile, Int>{} +)); + +template +CUTE_DEVICE +void store_128b(void* smem_ptr, const T &data) { + static_assert(sizeof(T) == 16); + *(__int128*)smem_ptr = *(__int128*)&data; +} + +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 1) +flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { +#if IS_SM100 + const int head_block_idx = blockIdx.x; + const int s_q_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int idx_in_warpgroup = threadIdx.x % 128; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{}); + + if (warp_idx == 0 && elect_one_sync()) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + } + + if (warp_idx == 0) { + if (elect_one_sync()) { + plan.bar_q.init(1); + for (int i = 0; i < NUM_BUFS; ++i) { + plan.bar_k_ready[i].init(128); + plan.bar_k_free[i].init(1); + plan.bar_qk_done[i].init(1); + plan.bar_so_ready[i].init(128); + } + cutlass::arch::fence_barrier_init(); + } + cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data()); + TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); + cute::TMEM::Allocator1Sm().release_allocation_lock(); + } + __syncthreads(); + + int bar_phase_k = 0; + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int sched_begin_block_idx = 
tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int sched_end_block_idx = tile_scheduler_metadata.w; + if (begin_idx >= params.b) { + if (warp_idx == 0) { + cute::TMEM::Allocator1Sm().free(0, 512); + } + return; + } + + auto get_cur_req_info = [&](int batch_idx) -> std::tuple { + int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0; + int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : params.topk / B_TOPK; + bool is_no_split = start_block_idx == 0 && end_block_idx == params.topk / B_TOPK; + return {start_block_idx, end_block_idx, is_no_split}; + }; + + if (warpgroup_idx == 0) { + // Producer warpgroup + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1) + + constexpr int GROUP_SIZE = 4, NUM_GROUPS = 128 / GROUP_SIZE; + constexpr int ROWS_PER_GROUP = B_TOPK / NUM_GROUPS; + int group_idx = idx_in_warpgroup / GROUP_SIZE; + int idx_in_group = idx_in_warpgroup % GROUP_SIZE; + + NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = block_idx % NUM_BUFS; + + // Wait for buffer to be available + plan.bar_k_free[buf_idx].wait(bar_phase_k>>buf_idx&1^1); + + // Load + Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); + + CUTE_UNROLL + for (int local_row = 0; local_row < ROWS_PER_GROUP; ++local_row) { + int smem_row = group_idx + local_row*NUM_GROUPS; + int token_index = __ldg(gIndices + block_idx*B_TOPK + smem_row); + bool is_token_invalid = token_index == -1; + if (idx_in_group == 0) + plan.is_token_valid[buf_idx][smem_row] = !is_token_invalid; + if (is_token_invalid) { + uint128_t zeros = uint128_t{}; + CUTE_UNROLL + for 
(int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) { + int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16; + store_128b(&sK(smem_row, col_base ), zeros); + store_128b(&sK(smem_row, col_base+8), zeros); + } + CUTE_UNROLL + for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) { + int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8; + store_128b(&sK(smem_row, D_V+col_base), zeros); + } + } else { + int block_index = token_index/B_TOPK; + int rel_idx_in_block = (token_index+B_TOPK) % B_TOPK; // NOTE When token_index is -1, -1/B_TOPK = 0 and (-1+B_TOPK)%B_TOPK = 63, so there will be no illegal-memory-access error. However, masking is necessary to prevent NaN (TODO Skip some rows instead?) TODO Masking + fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride; + float4 scales = __ldg((float4*)(gK_base + D_V)); + + CUTE_UNROLL + for (int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) { + int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16; + fp8x16 cur_fp8s = ldg_128_fp8x16(gK_base + col_base); + float cur_scale = local_col < (256/(GROUP_SIZE*16)) ? + (local_col < (128/(GROUP_SIZE*16)) ? scales.x : scales.y) : + (local_col < (384/(GROUP_SIZE*16)) ? 
scales.z : scales.w); + store_128b(&sK(smem_row, col_base ), cvt_fp8x8_bf16x8(cur_fp8s.a0, cur_scale)); + store_128b(&sK(smem_row, col_base+8), cvt_fp8x8_bf16x8(cur_fp8s.a1, cur_scale)); + } + + CUTE_UNROLL + for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) { + int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8; + fp8x16 cur_k_rope_fp8s = ldg_128_fp8x16(gK_base + D_V + 4*sizeof(float) + col_base*sizeof(bf16)); + bf16x8 cur_k_rope = *reinterpret_cast(&cur_k_rope_fp8s); + store_128b(&sK(smem_row, D_V+col_base), cur_k_rope); + } + } + } + + fence_view_async_shared(); + + // Signal + plan.bar_k_ready[buf_idx].arrive(); + + bar_phase_k ^= 1<(); + + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + + NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); + + float li = 0.0f; + float mi = MAX_INIT_VAL; + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { + int buf_idx = block_idx % NUM_BUFS; + + // Wait for P + plan.bar_qk_done[buf_idx].wait(bar_phase_k>>buf_idx&1); + tcgen05_after_thread_sync(); + + // Load P from TMEM + float p[B_TOPK/2]; + float2* p_float2 = reinterpret_cast(p); + tmem_ld_32dp32bNx(tmem_addr::p, p); + cutlass::arch::fence_view_async_tmem_load(); + + // Get rowwise max + float cur_max = -INFINITY; + CUTE_UNROLL + for (int i = 0; i < B_TOPK/2; ++i) { + if (!plan.is_token_valid[buf_idx][(idx_in_warpgroup/64)*(B_TOPK/2)+i]) p[i] = -INFINITY; + cur_max = max(cur_max, p[i]); + } + cur_max *= params.scale_softmax_log2; + + NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers + plan.rowwise_max_buf[idx_in_warpgroup] = cur_max; + NamedBarrier::arrive_and_wait(128, 0); + cur_max = max(cur_max, plan.rowwise_max_buf[idx_in_warpgroup ^ 64]); + + float new_max = max(mi, cur_max); + 
float scale_for_old = exp2f(mi - new_max); + float2 scale_for_old_float2 = {scale_for_old, scale_for_old}; + + // Get S + float2 scale_softmax_log2_float2 = {params.scale_softmax_log2, params.scale_softmax_log2}; + float2 neg_new_max_float2 = {-new_max, -new_max}; + bf16 s[B_TOPK/2]; + float2 cur_sum = {0.0f, 0.0f}; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; ++i) { + float2 t = float2_fma(p_float2[i], scale_softmax_log2_float2, neg_new_max_float2); + t.x = exp2(t.x); + t.y = exp2(t.y); + *(__nv_bfloat162*)&s[i*2] = __float22bfloat162_rn(t); + cur_sum = float2_add(cur_sum, t); + } + + // Save S + // NOTE We don't need a barrier here, since the current QK^T has finished implies that the previous SV has finished + bf16* sS_base = plan.s.data() + (idx_in_warpgroup/64)*(B_H*B_TOPK/2) + (idx_in_warpgroup%64) * 8; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/8; i += 1) { + store_128b(sS_base + i*8*B_H, *((bf16x8*)s + i)); + } + fence_view_async_shared(); + + // Rescale O + if (block_idx != start_block_idx) { + constexpr int B_SCALE_O = 64; + float2 o[B_SCALE_O/2]; + CUTE_UNROLL + for (int b = 0; b < (D_V/2)/B_SCALE_O; ++b) { + tmem_ld_32dp32bNx(tmem_addr::o + b*B_SCALE_O, o); + cutlass::arch::fence_view_async_tmem_load(); + CUTE_UNROLL + for (int i = 0; i < B_SCALE_O/2; ++i) + o[i] = float2_mul(o[i], scale_for_old_float2); + tmem_st_32dp32bNx(tmem_addr::o + b*B_SCALE_O, o); + cutlass::arch::fence_view_async_tmem_store(); + } + } + plan.bar_so_ready[buf_idx].arrive(); + + // Update mi and li + mi = new_max; + li = li * scale_for_old + cur_sum.x + cur_sum.y; + + bar_phase_k ^= 1<>((end_block_idx-1)%NUM_BUFS)&1^1); + tcgen05_after_thread_sync(); + + // Save O + float o_scale = li == 0.0f ? 
0.0f : 1.0f / li; + float2 o_scale_float2 = {o_scale, o_scale}; + if (is_no_split) { + constexpr int B_EPI = 32; + float2 o[B_EPI/2]; + __nv_bfloat162 o_bf16[B_EPI/2]; + Tensor sO = make_tensor(make_smem_ptr(plan.u.o_buf.data()), SmemLayoutOBuf{}); + bf16* sO_base = plan.u.o_buf.data() + ((idx_in_warpgroup/64)*128)*B_H + (idx_in_warpgroup%64)*8; + CUTE_UNROLL + for (int i = 0; i < (D_V/2) / B_EPI; ++i) { + // Load + tmem_ld_32dp32bNx(tmem_addr::o + i*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + // Scale & Convert + CUTE_UNROLL + for (int j = 0; j < B_EPI/2; ++j) { + o[j] = float2_mul(o[j], o_scale_float2); + o_bf16[j] = __float22bfloat162_rn(o[j]); + } + // Store + int col_base = (i*B_EPI>=D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4)); + CUTE_UNROLL + for (int j = 0; j < B_EPI / 8; ++j) + store_128b(sO_base + (col_base+j*8)*B_H, *reinterpret_cast(&o_bf16[j*4])); + } + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, 0); + if (warp_idx == 4 && elect_one_sync()) { + Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx); + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, head_block_idx, _0{}); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sO), + thr_tma.partition_D(my_tma_gO) + ); + cute::tma_store_arrive(); + } + } else { + constexpr int B_EPI = 64; + float2 o[B_EPI/2]; + Tensor sO = make_tensor(make_smem_ptr(plan.u.o_accum_buf.data()), SmemLayoutOAccumBuf{}); + CUTE_UNROLL + for (int i = 0; i < (D_V/2) / B_EPI; ++i) { + // Load + tmem_ld_32dp32bNx(tmem_addr::o + i*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + // Scale & Convert + CUTE_UNROLL + for (int j = 0; j < B_EPI/2; ++j) + o[j] = float2_mul(o[j], o_scale_float2); + // Store + int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? 
D_V/2 : 0) + (i*B_EPI%(D_V/4)); + CUTE_UNROLL + for (int j = 0; j < B_EPI / 4; ++j) + store_128b(&sO(idx_in_warpgroup%64, col_base + j*4), *reinterpret_cast(&o[j*2])); + } + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, 0); + if (elect_one_sync()) { + CUTE_UNROLL + for (int local_row = 0; local_row < B_H/4; ++local_row) { + int smem_row = local_row*4 + (warp_idx-4); + if (smem_row < num_valid_heads) { + SM90_BULK_COPY_S2G::copy( + &sO(smem_row, _0{}), + (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx + smem_row)*D_V, + D_V*sizeof(float) + ); + } + } + cute::tma_store_arrive(); + } + } + + cute::tma_store_wait<0>(); + } + + if (warp_idx == 4) { + cute::TMEM::Allocator1Sm().free(0, 512); + } + } else { + cutlass::arch::warpgroup_reg_dealloc<96>(); + if (warp_idx == 8) { + // UTCMMA warp + + bool bar_phase_q = 0; + TiledMMA tiled_mma_qk = TiledMMA_QK{}; + TiledMMA tiled_mma_sv = TiledMMA_SV{}; + Tensor tP = partition_fragment_C(tiled_mma_qk, Shape, Int>{}); + Tensor tO = partition_fragment_C(tiled_mma_sv, Shape, Int>{}); + tO.data().get() = tmem_addr::o; + tP.data().get() = tmem_addr::p; + Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); + + if (elect_one_sync()) { + // Copy Q + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx), + Tile, Int>{} + )(_, _, head_block_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16)); + } + + NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); + + if (elect_one_sync()) { + // Wait for Q + plan.bar_q.wait(bar_phase_q); + bar_phase_q ^= 1; + tcgen05_after_thread_sync(); + + CUTE_NO_UNROLL + for (int block_idx = start_block_idx; block_idx < 
end_block_idx; block_idx++) { + int buf_idx = block_idx % NUM_BUFS; + + // Wait for K + plan.bar_k_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + tcgen05_after_thread_sync(); + Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); + + // Issue P = Q @ K^T + utcmma_ss(tiled_mma_qk, sQ, sK, tP, true); + umma_arrive_noelect(plan.bar_qk_done[buf_idx]); + + // Wait for S + plan.bar_so_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + tcgen05_after_thread_sync(); + Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutV{}); + + // Issue O += S @ V + utcmma_ss(tiled_mma_sv, sS, sV, tO, block_idx == start_block_idx); + umma_arrive_noelect(plan.bar_k_free[buf_idx]); + + bar_phase_k ^= 1< tma_params = { + shape_Q, tma_Q, + shape_O, tma_O + }; + auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + const int num_m_blocks = cute::ceil_div(params.q_head_per_hk, B_H); + // NOTE Don't use PDL because of potential compiler bugs! 
+ mla_kernel<<>>(params, tma_params); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +} \ No newline at end of file diff --git a/csrc/sm100/decode/sparse_fp8/splitkv_mla.h b/csrc/sm100/decode/sparse_fp8/splitkv_mla.h new file mode 100644 index 0000000..cc8c6da --- /dev/null +++ b/csrc/sm100/decode/sparse_fp8/splitkv_mla.h @@ -0,0 +1,10 @@ +#pragma once + +#include "params.h" + +namespace sm100 { + +void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams ¶ms, cudaStream_t stream); + +} + diff --git a/csrc/sm100/defines.h b/csrc/sm100/defines.h new file mode 100644 index 0000000..0e779a3 --- /dev/null +++ b/csrc/sm100/defines.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +namespace sm100 { + +using bf16 = cutlass::bfloat16_t; +using fp8 = cutlass::float_e4m3_t; +using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::fence_barrier_init; +using cutlass::arch::NamedBarrier; + +struct int32x8_t { + int a0, a1, a2, a3, a4, a5, a6, a7; +}; + +struct float8 { + float2 a01, a23, a45, a67; +}; + +struct bf16x8 { + __nv_bfloat162 a01; + __nv_bfloat162 a23; + __nv_bfloat162 a45; + __nv_bfloat162 a67; +}; + +} diff --git a/csrc/sm100/helpers.h b/csrc/sm100/helpers.h new file mode 100644 index 0000000..9195b33 --- /dev/null +++ b/csrc/sm100/helpers.h @@ -0,0 +1,97 @@ +#pragma once + +#include +#include "defines.h" + +namespace sm100 { + +using namespace cute; + +using _72 = Int<72>; +using _576 = Int<576>; + +template< + typename TMA, + typename Tensor0, + typename Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + Tensor0 src, + Tensor1 dst, + transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL +) { + auto thr_tma = tma_copy.get_slice(_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), 0, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +template< + typename TiledMMA, + typename TensorA, + 
typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma_ss( + TiledMMA &tiled_mma, + TensorA sA, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sA_frag = thr_mma.partition_fragment_A(sA); + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(sA_frag) == size<2>(sB_frag)); + static_assert(size<1>(sA_frag) == size<1>(tC_frag)); + static_assert(size<1>(sB_frag) == size<2>(tC_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(sA_frag); ++k) { + cute::gemm( + tiled_mma, + sA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma_ts( + TiledMMA &tiled_mma, + TensorA tA_frag, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(tA_frag) == size<2>(sB_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tA_frag); ++k) { + cute::gemm( + tiled_mma, + tA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +} diff --git a/csrc/sm100/intrinsics.h b/csrc/sm100/intrinsics.h new file mode 100644 index 0000000..c2402ee --- /dev/null +++ b/csrc/sm100/intrinsics.h @@ -0,0 +1,461 @@ +#pragma once + +#include +#include + +#include "defines.h" + +namespace sm100 { + +using namespace cute; + +__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" + :: "r"(dst_addr), + "l"(src), + "n"(16)); +} + +CUTE_DEVICE +int64_t createpolicy_evict_last() { + int64_t res; + asm volatile( + "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t" + : "=l"(res) + : + ); + return res; +} + +template +CUTE_DEVICE +static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { + static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b."); + long2 data_long2 = *reinterpret_cast(&data); + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + asm volatile ( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" + : + : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) + ); +} + + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm 
volatile("tcgen05.fence::after_thread_sync;"); +} + +CUTE_DEVICE +void umma_arrive_multicast_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); +} + +CUTE_DEVICE +void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); +} + +CUTE_DEVICE +void umma_arrive_noelect(transac_bar_t &smem_ptr) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + : + :"r"(bar_intptr)); +} + +CUTE_DEVICE +void umma_arrive_2x1SM_noelect(transac_bar_t &smem_ptr) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); + asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];" + : + :"r"(bar_intptr)); +} + +CUTE_DEVICE +float2 float2_add(const float2 &a, const float2 &b) { + float2 res; + cute::add(res, a, b); + return res; +} + +CUTE_DEVICE +float2 float2_mul(const float2 &a, const float2 &b) { + float2 res; + cute::mul(res, a, b); + return res; +} + +CUTE_DEVICE +float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) { + // return a*b+c + float2 res; + cute::fma(res, a, b, c); + return res; +} + +CUTE_DEVICE +float2 float2_neg(const float2 &a) { + float2 t = {-1.0f, -1.0f}; + return float2_mul(a, t); +} + +template +CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) { + uint32_t smem_addr = 
cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + if constexpr (USE_CTA0_MBAR) { + mbar_addr &= Sm100MmaPeerBitMask; + } + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbar_addr), "l"(uint64_t(cache_hint)) + : "memory" + ); +} + +// 32 data path lanes, 32-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); + uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + 
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), 
"=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), 
"=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + 
"=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile ("trap"); + } +} + +// 32 data path lanes, 32-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); + uint32_t* src_ptr = reinterpret_cast(src_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32" + "[%1], {%0};\n" + : + : "r"(src_ptr[0]), + "r"(dst_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32" + "[%2], {%0, %1};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), + "r"(dst_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32" + "[%4], {%0, %1, %2, %3};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), + "r"(dst_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32" + "[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), + "r"(dst_addr)); + } else if constexpr (N == 16) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32" + "[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), + "r"(dst_addr)); + } else if constexpr (N == 32) { + asm 
volatile("tcgen05.st.sync.aligned.32x32b.x32.b32" + "[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), + "r"(dst_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.st.sync.aligned.32x32b.x64.b32" + "[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), + "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), + "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), 
+ "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), + "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), + "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), + "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), + "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), + "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), + "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), + "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), + "r"(src_ptr[63]), + "r"(dst_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.st.sync.aligned.32x32b.x128.b32" + "[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), + "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), + "r"(src_ptr[36]), 
"r"(src_ptr[37]), "r"(src_ptr[38]), + "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), + "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), + "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), + "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), + "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), + "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), + "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), + "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), + "r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]), + "r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]), + "r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]), + "r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]), + "r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]), + "r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]), + "r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]), + "r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]), + "r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]), + "r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]), + "r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]), + "r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]), + "r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]), + "r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]), + "r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]), + "r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]), + "r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]), + "r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]), + "r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]), + "r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]), + "r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]), + "r"(src_ptr[126]), "r"(src_ptr[127]), + "r"(dst_addr)); + } else { + asm volatile ("trap"); + } +} + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 
Not sure whether this value is the same on all GPUs +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} + + +} diff --git a/csrc/sm100/prefill/sparse/fwd.cu b/csrc/sm100/prefill/sparse/fwd.cu new file mode 100644 index 0000000..963ac78 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd.cu @@ -0,0 +1,785 @@ +#include "fwd.h" + +#include +#include +#include +#include +#include +#include + +#include "params.h" +#include "utils.h" +#include "sm100/ws_gemm.h" +#include "sm100/helpers.h" +#include "sm100/intrinsics.h" +#include "sm100/tma_cta_group2_nosplit.h" + +namespace sm100 { + +using namespace cute; + +CUTE_DEVICE int32x8_t ldg_256_indices(void* src_ptr) { + int32x8_t val; + asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" + : "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3), + "=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7) + : "l"(src_ptr) + ); + return val; +} + +template< + typename Shape_Q, typename TMA_Q, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + Shape_O shape_O; TMA_O tma_O; + CUtensorMap tensor_map_kv; +}; + +struct float2x2 { + float2 lo, hi; +}; + +constexpr int D_Q = 576; +constexpr int D_K = 576; +constexpr int D_V = 512; +constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan + +constexpr int B_H = 128; // For 2 CTAs +constexpr int B_TOPK = 128; // For 2 CTAs +constexpr int NUM_BUFS = 2; +constexpr int NUM_THREADS = 256 + 128 + 128; // 4 warpgroups: 128 scale & exp threads, 2x128 TMA producer threads (K and V), 128 threads for the UTCMMA warpgroup + +constexpr int D_sQ = 256, NUM_sQ_TILES = D_sQ / 64; +constexpr int D_tQ = D_Q - D_sQ, NUM_tQ_TILES = D_tQ / 64; +static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q); + +// Tensor memory columns +namespace tmem_cols { + // 0 ~ 256: output + // 256 ~ 320: P + // 352 ~ 512: Q[256:576] + constexpr int o = 0; + constexpr int p = 
256; + constexpr int q = 512 - D_tQ/2; + static_assert(p+64 <= q); +} + +template +using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutO = SmemLayoutOTiles<8>; + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutV = decltype(coalesce(tile_to_shape( + UMMA::Layout_MN_SW128_Atom{}, + Shape, Int>{}, + Step<_2, _1>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutSTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +struct SharedMemoryPlan { + union { + array_aligned>> q_full; + struct { + array_aligned>> sq; + array_aligned> v; + // NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q + array_aligned>> k; + } s; + array_aligned> o; + } u; + array_aligned>> s; + char is_k_valid[NUM_BUFS][B_TOPK/8]; + transac_bar_t bar_prologue_q, bar_prologue_utccp; + transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free) + transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. 
Vi free) + transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS]; + transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready + transac_bar_t bar_p_free[NUM_BUFS]; + transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready + transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS]; + array_aligned tmem_start_addr; + float rowwise_max_buf[128], rowwise_li_buf[128]; +}; + +using TiledMMA_P_tQ = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_TS_NOELECT{} +)); + +using TiledMMA_P_sQ = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_SS_NOELECT{} +)); + +using TiledMMA_O = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_SS_NOELECT{}, + Layout>{}, + Tile, Layout, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512] +)); + +/* +Pipeline Overview: + +| Copy | MMA | Scale & Exp | + +K0 +V0 + P0 = QK0^T +K1 S0 = exp(P0) + scale(O) w.r.t P0 + P1 = QK1^T +K2 S1 = exp(P1) + O += S0V0 +V1 scale(O) w.r.t P1 + P2 = QK2^T +K3 S2 = exp(P2) + O += S1V1 +V2 scale(O) w.r.t P2 + P3 = QK3^T +K4 S3 = exp(P3) + O += S2V2 +V3 scale(O) w.r.t P3 + +... 
+ + O += S(n-3)V(n-3) +V(n-2) scale(O) w.r.t P(n-2) + P(n-1) = QK(n-1)^T + S(n-1) = exp(P(n-1)) + O += S(n-2)V(n-2) +V(n-1) scale(O) w.r.t P(n-1) + O += S(n-1)V(n-1) +*/ + +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 2) +sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) { +#if IS_SM100 + const int cta_idx = blockIdx.x % 2; + const int s_q_idx = blockIdx.x / 2; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + const int lane_idx = threadIdx.x % 32; + const int num_k_blocks = params.topk / B_TOPK; + const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const int idx_in_warpgroup = threadIdx.x % 128; + + // Prefetch TMA descriptors + if (threadIdx.x == 0) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv)); + } + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles<9>{}); + + int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk] + + // Allocate tmem tensors + TiledMMA tiled_mma_P_tQ = TiledMMA_P_tQ{}; + TiledMMA tiled_mma_P_sQ = TiledMMA_P_sQ{}; + TiledMMA tiled_mma_O = TiledMMA_O{}; + Tensor tP = partition_fragment_C(tiled_mma_P_tQ, Shape, Int>{}); + Tensor tQr = tiled_mma_P_tQ.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P_tQ, Shape, Int>{}) + ); + Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{}); + tP.data().get() = tmem_cols::p; + tQr.data().get() = tmem_cols::q; + tO.data().get() = tmem_cols::o; + + if (warp_idx == 0) { + if (elect_one_sync()) { + // Initialize barriers + plan.bar_prologue_q.init(1); + plan.bar_prologue_utccp.init(1); + CUTE_UNROLL + for (int i = 0; i < NUM_BUFS; ++i) { + 
plan.bar_qk_part_done[i].init(1); + plan.bar_qk_done[i].init(1); + plan.bar_sv_part_done[i].init(1); + plan.bar_sv_done[i].init(1); + plan.bar_k_part0_ready[i].init(1); + plan.bar_k_part1_ready[i].init(1); + plan.bar_v_part0_ready[i].init(1); + plan.bar_v_part1_ready[i].init(1); + plan.bar_p_free[i].init(128*2); + plan.bar_so_ready[i].init(128*2); + plan.bar_k_valid_ready[i].init(16); + plan.bar_k_valid_free[i].init(128); + } + fence_barrier_init(); + } + } + + cute::cluster_sync(); // We must add a cluster_sync() here, or TMA from CTA1 may launch before barrier initialization in CTA0 + + if (warp_idx == 0) { + if (elect_one_sync()) { + // Copy Q + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx), + Tile>{} + )(_, cta_idx, _); + launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST); + } + + // Initialize TMEM + // We put this before cluster_arrive to make sure that the TMEM allocation is done before UTCCP + cute::TMEM::Allocator2Sm().allocate(512, plan.tmem_start_addr.data()); + TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); + cute::TMEM::Allocator2Sm().release_allocation_lock(); + __syncwarp(); + } + + if (warpgroup_idx == 0) { + cutlass::arch::warpgroup_reg_alloc<144>(); + // Scale & Exp warps + + // The following three numbers are + // - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V) + // - li: sumexp, i.e. li := sum(exp(Pi*scale - mi)) + // - real_mi: real max logits, i.e. real_mi := max(Pi*scale) + // where Pi is the i-th row of P, P := QK^T + // mi and real_mi are always consistent within the two threads that + // controls one row (i.e. thread 0+64, 1+65, 2+66, ...) 
after every update + float mi = MAX_INIT_VAL; + float li = 0.0f; + float real_mi = -CUDART_INF_F; + + const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2}; + uint128_t* sS_base = (uint128_t*)plan.s.data() + idx_in_warpgroup%64 + 64*((idx_in_warpgroup/64)*8); + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + // Wait for P + plan.bar_qk_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + + // Load P + float2 p[(B_TOPK/2)/2]; + tmem_ld_32dp32bNx(tmem_cols::p, p); + cutlass::arch::fence_view_async_tmem_load(); + tcgen05_before_thread_sync(); + plan.bar_p_free[k%NUM_BUFS].arrive(0u); + + // Mask + plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1); + // The following code enables NVCC to use R2P instruction + // Although we perform 2x LDS.32 instructions here, don't worry, NVCC will + // convert them to one LDS.64 instruction. However, if we write LDS.64 + // here, NVCC won't use R2P. + uint32_t is_k_valid_lo = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0)); + uint32_t is_k_valid_hi = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0) + 4); + float* p_float = (float*)p; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; i += 1) { + if (!(is_k_valid_lo >> i & 1)) + p_float[i] = -CUDART_INF_F; + } + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; i += 1) { + if (!(is_k_valid_hi >> i & 1)) + p_float[i+(B_TOPK/2)/2] = -CUDART_INF_F; + } + + // Get rowwise max of Pi + float cur_pi_max = -CUDART_INF_F; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2); i += 1) { + cur_pi_max = max(cur_pi_max, p_float[i]); + } + cur_pi_max *= params.sm_scale_div_log2; + + plan.bar_k_valid_free[k%NUM_BUFS].arrive(); + + NamedBarrier::arrive_and_wait(128, 0); // Wait for rowwise_max_buf and sP to be ready + plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max; + NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers + cur_pi_max = max(cur_pi_max, 
plan.rowwise_max_buf[idx_in_warpgroup^64]); + real_mi = max(real_mi, cur_pi_max); + bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f); + // By this point: + // - cur_pi_max, real_mi, and mi are identical within each row (i.e. thread 0+64, 1+65, ...) + // - should_scale_o is identical among threads 0~31+64~95; and is identical among threads 32~63+96~127 + + // Calc scale factor, and scale li + float new_max, scale_for_old; + if (!should_scale_o) { + // Don't scale O + scale_for_old = 1.0f; + new_max = mi; + } else { + new_max = max(cur_pi_max, mi); + scale_for_old = exp2f(mi - new_max); + } + mi = new_max; // mi is still identical within each row + li *= scale_for_old; + + // Calculate S + __nv_bfloat162 s[(B_TOPK/2)/2]; + float2 neg_new_max = float2 {-new_max, -new_max}; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; i += 1) { + float2 d = float2_fma(p[i], scale, neg_new_max); + d.x = exp2f(d.x); + d.y = exp2f(d.y); + li += d.x + d.y; // NOTE Theoretically we could use FFMA2 here, but in practice this is faster...
+ s[i] = __float22bfloat162_rn(d); + } + + // Wait for last SV gemm, write S + if (k > 0) { + plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + CUTE_UNROLL + for (int i = 0; i < B_TOPK/2/8; i += 1) { + sS_base[64*i] = *(uint128_t*)(s + i*4); + } + + // Scale O + if (k > 0 && should_scale_o) { + float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old}; + // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before + tcgen05_after_thread_sync(); + + static constexpr int CHUNK_SIZE = 32; + float2 o[CHUNK_SIZE/2]; + CUTE_UNROLL + for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) { + // Load O + tmem_ld_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o); + cutlass::arch::fence_view_async_tmem_load(); + + // Mult + for (int i = 0; i < CHUNK_SIZE/2; ++i) { + o[i] = float2_mul(o[i], scale_for_old_float2); + } + + // Store O + tmem_st_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o); + cutlass::arch::fence_view_async_tmem_store(); + } + tcgen05_before_thread_sync(); + } + + fence_view_async_shared(); + plan.bar_so_ready[k%NUM_BUFS].arrive(0u); + } + + // Epilogue + + if (real_mi == -CUDART_INF_F) { + // real_mi == -CUDART_INF_F <=> No valid TopK indices + // We set li to 0 to fit the definition that li := exp(x[i] - mi) + li = 0.0f; + mi = -CUDART_INF_F; + } + + // Exchange li + plan.rowwise_li_buf[idx_in_warpgroup] = li; + NamedBarrier::arrive_and_wait(128, 0); + li += plan.rowwise_li_buf[idx_in_warpgroup^64]; + + // Store mi and li + if (idx_in_warpgroup < 64) { + int global_index = s_q_idx*params.h_q + cta_idx*(B_H/2) + idx_in_warpgroup; + float cur_lse = log2f(li) + mi; + params.max_logits[global_index] = real_mi; + params.lse[global_index] = cur_lse; + } + + // Wait for the last GEMM + plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + + // Store O + float output_scale = __fdividef(1.0f, li); + Tensor sO = 
make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{}); + constexpr int B_EPI = 64; + Tensor tma_gO = flat_divide( + tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx), + Shape, Int>{} + )(_, _, cta_idx, _); + Tensor sO_divided = flat_divide( + sO, + Shape, Int>{} + )(_, _, _0{}, _); + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + + float2 o[B_EPI/2]; + bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during tmem_ld + if (!have_valid_indices) { + // If there are no valid indices, we set o[i] to 0 and don't load from TMEM + CUTE_UNROLL + for (int i = 0; i < B_EPI/2; ++i) + o[i].x = o[i].y = 0.0f; + output_scale = 1.0f; + } + + float2 output_scale_float2 = make_float2(output_scale, output_scale); + + CUTE_UNROLL + for (int k = 0; k < (D_V/2)/B_EPI; ++k) { + // Load O from tO + if (have_valid_indices) { + tmem_ld_32dp32bNx(tmem_cols::o + k*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + } + + // Convert and store + CUTE_UNROLL + for (int i = 0; i < B_EPI/8; ++i) { + __nv_bfloat162 o_bf16[4]; + CUTE_UNROLL + for (int j = 0; j < 4; ++j) { + float2 d = float2_mul(o[i*4+j], output_scale_float2); + o_bf16[j] = __float22bfloat162_rn(d); + } + int smem_row = idx_in_warpgroup % 64; + int smem_col = (idx_in_warpgroup/64)*(D_V/2) + k*B_EPI + i*8; + *(uint128_t*)(&sO(smem_row, smem_col)) = *(uint128_t*)(o_bf16); + } + + // Sync + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, 0); + + if (warp_idx == 0 && elect_one_sync()) { + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sO_divided(_, _, k)), + thr_tma.partition_D(tma_gO(_, _, k)) + ); + } + if (warp_idx == 1 && elect_one_sync()) { + int k2 = k + (D_V/B_EPI/2); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sO_divided(_, _, k2)), + thr_tma.partition_D(tma_gO(_, _, k2)) + ); + } + } + + if (warp_idx == 0) { + cute::TMEM::Allocator2Sm().free(0, 512); + } + } else if 
(warpgroup_idx == 1) { + // Producer warp for K + cutlass::arch::warpgroup_reg_dealloc<96>(); + int warp_idx = cutlass::canonical_warp_idx_sync() - 4; + constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/2)/4/NUM_WARPS; + if (elect_one_sync()) { + bf16* sK_base = plan.u.s.k.data() + warp_idx*4*64; + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + int4 indices[NUM_LOCAL_ROWS_PER_WARP]; + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) + indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK + cta_idx*(B_TOPK/2)) + local_row*NUM_WARPS + warp_idx); + + auto load_part_ki = [&](transac_bar_t* bar, int local_col_start, int local_col_end) { + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) { + CUTE_UNROLL + for (int local_col = local_col_start; local_col < local_col_end; ++local_col) + tma_gather4( + &(tma_params.tensor_map_kv), + bar, + sK_base + local_row*(4*NUM_WARPS)*64 + local_col*((B_TOPK/2)*64), + local_col*64, + indices[local_row], + TMA::CacheHintSm90::EVICT_LAST + ); + } + }; + + int cur_buf = k%NUM_BUFS; + if (k > 0) { + plan.bar_qk_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + load_part_ki(plan.bar_k_part0_ready+cur_buf, 0, D_sQ/64); + + if (k > 0) { + plan.bar_qk_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + load_part_ki(plan.bar_k_part1_ready+cur_buf, D_sQ/64, D_K/64); + } + } + } else if (warpgroup_idx == 2) { + // Producer warps for V + cutlass::arch::warpgroup_reg_dealloc<96>(); + int warp_idx = cutlass::canonical_warp_idx_sync() - 8; + constexpr int NUM_WARPS = 4; + + if (elect_one_sync()) { + // Wait for UTCCP + plan.bar_prologue_utccp.wait(0); + + bf16* sV_base = plan.u.s.v.data() + warp_idx*4*64; + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + auto load_part_vi = [&](transac_bar_t* bar, int local_row_start, int local_row_end) { + CUTE_UNROLL + for (int local_row = local_row_start; local_row < local_row_end; 
++local_row) { + int4 token_idxs = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx); + CUTE_UNROLL + for (int local_col = 0; local_col < (D_V/2)/64; ++local_col) + tma_gather4( + &(tma_params.tensor_map_kv), + bar, + sV_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64), + local_col*64 + (cta_idx?256:0), + token_idxs, + TMA::CacheHintSm90::EVICT_LAST + ); + } + }; + + int cur_buf = k%NUM_BUFS; + if (k > 0) { + plan.bar_sv_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + load_part_vi(plan.bar_v_part0_ready+cur_buf, 0, (B_TOPK/2)/4/NUM_WARPS); + + if (k > 0) { + plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + load_part_vi(plan.bar_v_part1_ready+cur_buf, (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS); + } + } + } else { + cutlass::arch::warpgroup_reg_alloc<168>(); + + // MMA warp + if (cta_idx == 0 && warp_idx == 12 && elect_one_sync()) { + // S -> T copy for Q + UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc( + make_tensor( + make_smem_ptr(plan.u.q_full.data() + (B_H/2)*D_sQ), + tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64>>{} + ) + ) + ); + plan.bar_prologue_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16)); + plan.bar_prologue_q.wait(0); + tcgen05_after_thread_sync(); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < NUM_tQ_TILES; ++tile_idx) { + // A tile is 64 rows * 64 cols (128B) + CUTE_UNROLL + for (int subtile_idx = 0; subtile_idx < 8; ++subtile_idx) { + // A subtile is 64 rows * 8 cols (128b) + SM100_UTCCP_2x64dp128bitlw0213_2cta::copy( + sQ_desc + tile_idx*((B_H/2)*128/16) + subtile_idx*(16/16), // Remember that 4 LSBs are not included + tmem_cols::q + tile_idx*32 + subtile_idx*4 + ); + } + } + umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2); + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks+1; ++k) { + if (k < num_k_blocks) { + // Pi = QKi^T + int cur_buf = k%NUM_BUFS; + Tensor sQl = make_tensor(make_smem_ptr(plan.u.s.sq.data()), SmemLayoutQTiles{}); + Tensor sKl = 
make_tensor(make_smem_ptr(plan.u.s.k.data()), SmemLayoutKTiles{}); + Tensor sKr = make_tensor(make_smem_ptr(plan.u.s.k.data()+64*D_sQ), SmemLayoutKTiles{}); + + // Wait for K (part0) + plan.bar_k_part0_ready[cur_buf].arrive_and_expect_tx(B_TOPK*D_sQ*sizeof(bf16)); + plan.bar_k_part0_ready[cur_buf].wait((k/NUM_BUFS)&1); + if (k > 0) { + plan.bar_p_free[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + tcgen05_after_thread_sync(); + + utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true); + umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2); + + // Wait for K (part1) + plan.bar_k_part1_ready[cur_buf].arrive_and_expect_tx(B_TOPK*(D_K-D_sQ)*sizeof(bf16)); + plan.bar_k_part1_ready[cur_buf].wait((k/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + + utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false); + umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2); + } + if (k > 0) { + // O += S(i-1)V(i-1) + int cur_buf = (k-1)%NUM_BUFS; + + Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutSTiles<2>{}); + Tensor sV = make_tensor(make_smem_ptr(plan.u.s.v.data()), SmemLayoutV{}); + Tensor sS_divided = flat_divide(sS, Tile, _64>{})(_, _, _0{}, _); // (B_H/2, 64, 2) + Tensor sV_divided = flat_divide(sV, Tile, _64>{})(_, _, _0{}, _); // (D_V/2, 64, 2) + + // Wait for S(i-1) and O to be scaled + plan.bar_so_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1); + + // Wait for V (part0), and issue O += sS @ sV + plan.bar_v_part0_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16)); + plan.bar_v_part0_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1); + tcgen05_after_thread_sync(); + + utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1); + umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2); + + // Wait for V (part1), and issue O += sS @ sV + plan.bar_v_part1_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16)); + plan.bar_v_part1_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1); + tcgen05_after_thread_sync(); 
+ utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false); + umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2); + } + } + } else if (warp_idx == 13) { + // KV valid loading warp: lanes 0-15 each load 8 indices per k-block and pack their in-range flags (0 <= index < s_kv) into one mask byte + static_assert(B_TOPK == 128); + if (lane_idx < 16) { + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + int cur_buf = k%NUM_BUFS; + int32x8_t indices = ldg_256_indices(gIndices + k*B_TOPK + lane_idx*8); + auto is_valid = [&](int index) -> char { + return index >= 0 && index < params.s_kv; + }; + char is_ks_valid_mask = \ + is_valid(indices.a7) << 7 | + is_valid(indices.a6) << 6 | + is_valid(indices.a5) << 5 | + is_valid(indices.a4) << 4 | + is_valid(indices.a3) << 3 | + is_valid(indices.a2) << 2 | + is_valid(indices.a1) << 1 | + is_valid(indices.a0) << 0; + + plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1); + plan.is_k_valid[cur_buf][lane_idx] = is_ks_valid_mask; + plan.bar_k_valid_ready[cur_buf].arrive(); + } + } + } + } + +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100"); + } +#endif +} + +void run_fwd_kernel(const SparsePrefillParams& params) { + FLASH_ASSERT(params.h_kv == 1); + FLASH_ASSERT(params.topk % B_TOPK == 0); // To save some boundary checks + FLASH_ASSERT(params.h_q == B_H); // To save some calculation + + auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); + auto tma_Q = cute::make_tma_copy( + SM100_TMA_2SM_LOAD_NOSPLIT{}, + make_tensor( + make_gmem_ptr((bf16*)params.q), + make_layout( + shape_Q, + make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) + ) + ), + SmemLayoutQTiles<9>{} + ); + + auto shape_O = make_shape(params.h_q, params.d_v, params.s_q); + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((bf16*)params.out), + make_layout( + shape_O, + make_stride(params.d_v, _1{}, params.h_q*params.d_v) + ) + ), + SmemLayoutOTiles<1>{} + ); + + CUtensorMap tensor_map_kv; + { + uint64_t size[2] = {D_K, (unsigned
long)params.s_kv}; + uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)}; + uint32_t box_size[2] = {64, 1}; + uint32_t elem_stride[2] = {1, 1}; + CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tensor_map_kv, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + 2, + params.kv, + size, + stride, + box_size, + elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + FLASH_ASSERT(res == CUresult::CUDA_SUCCESS); + } + + TmaParams< + decltype(shape_Q), decltype(tma_Q), + decltype(shape_O), decltype(tma_O) + > tma_params = { + shape_Q, tma_Q, + shape_O, tma_O, + tensor_map_kv + }; + auto kernel = &sparse_attn_fwd_kernel; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + cutlass::ClusterLaunchParams launch_params = { + dim3(2*params.s_q, 1, 1), + dim3(NUM_THREADS, 1, 1), + dim3(2, 1, 1), + smem_size, + params.stream + }; + cutlass::launch_kernel_on_cluster( + launch_params, (void*)kernel, params, tma_params + ); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +} diff --git a/csrc/sm100/prefill/sparse/fwd.h b/csrc/sm100/prefill/sparse/fwd.h new file mode 100644 index 0000000..6558e80 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd.h @@ -0,0 +1,9 @@ +#pragma once + +#include "params.h" + +namespace sm100 { + +void run_fwd_kernel(const SparsePrefillParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/helpers.h b/csrc/sm100/prefill/sparse/helpers.h new file mode 100644 index 0000000..991b40d --- /dev/null +++ b/csrc/sm100/prefill/sparse/helpers.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include "sm100/defines.h" + +namespace sm100 { + +using namespace cute; + +using _72 = Int<72>; +using _576 = Int<576>; + +template< + typename TMA, + 
typename Tensor0, + typename Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + Tensor0 src, + Tensor1 dst, + transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL +) { + auto thr_tma = tma_copy.get_slice(_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), 0, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma( + TiledMMA &tiled_mma, + TensorA sA, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sA_frag = thr_mma.partition_fragment_A(sA); + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(sA_frag) == size<2>(sB_frag)); + static_assert(size<1>(sA_frag) == size<1>(tC_frag)); + static_assert(size<1>(sB_frag) == size<2>(tC_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(sA_frag); ++k) { + cute::gemm( + tiled_mma, + sA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma_ts( + TiledMMA &tiled_mma, + TensorA tA_frag, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(tA_frag) == size<2>(sB_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tA_frag); ++k) { + cute::gemm( + tiled_mma, + tA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +struct bf16x8 { + __nv_bfloat162 a01; + __nv_bfloat162 a23; + __nv_bfloat162 a45; + __nv_bfloat162 a67; +}; + +} diff --git a/csrc/sm100/prefill/sparse/intrinsics.h b/csrc/sm100/prefill/sparse/intrinsics.h new file mode 100644 index 0000000..85a8203 --- /dev/null +++ b/csrc/sm100/prefill/sparse/intrinsics.h @@ -0,0 +1,638 @@ +#pragma once + +#include +#include "defines.h" + +namespace sm100 { + +using namespace cute; + +struct int32x8_t { + int a0, a1, a2, a3, a4, a5, a6, a7; +}; + +struct float8 { + float2 a01, a23, a45, a67; +}; + +__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" + :: "r"(dst_addr), + "l"(src), + "n"(16)); +} + +template +CUTE_DEVICE +static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { + static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b."); + long2 data_long2 = *reinterpret_cast(&data); + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + asm volatile ( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" + : + : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) + ); +} + +CUTE_DEVICE +void umma_arrive_multicast_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) { + uint32_t bar_intptr = 
cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); +} + +CUTE_DEVICE +void umma_arrive_multicast_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) { + umma_arrive_multicast_noelect((uint64_t*)smem_ptr, cta_mask); +} + +CUTE_DEVICE +void umma_arrive_multicast_2x1SM_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" + "}" + : + :"r"(bar_intptr), "h"(cta_mask)); +} + +CUTE_DEVICE +void umma_arrive_multicast_2x1SM_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) { + umma_arrive_multicast_2x1SM_noelect((uint64_t*)smem_ptr, cta_mask); +} + +CUTE_DEVICE +int64_t createpolicy_evict_last() { + int64_t res; + asm volatile( + "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t" + : "=l"(res) + : + ); + return res; +} + +CUTE_DEVICE +void atomicadd_f32x4_with_policy(void* global_addr, const float4 &data, int64_t cache_policy) { + asm volatile( + "red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t" + : + : "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w), + "l"((int64_t)global_addr), "l"(cache_policy) + ); +} + +CUTE_DEVICE +void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) { + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(src_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" + : + : "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +} + +CUTE_DEVICE +float2 float2_add(const float2 &a, const float2 &b) { + float2 res; + cute::add(res, a, b); + return res; +} + +CUTE_DEVICE +float2 float2_mul(const float2 &a, const float2 &b) { + float2 res; + cute::mul(res, a, b); + 
return res; +} + +CUTE_DEVICE +float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) { + // return a*b+c + float2 res; + cute::fma(res, a, b, c); + return res; +} + +CUTE_DEVICE +float2 float2_neg(const float2 &a) { + float2 t = {-1.0f, -1.0f}; + return float2_mul(a, t); +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +template +CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + if constexpr (USE_CTA0_MBAR) { + mbar_addr &= Sm100MmaPeerBitMask; + } + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbar_addr), "l"(uint64_t(cache_hint)) + : "memory" + ); +} + +// 32 data path lanes, 32-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); + uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," 
+ "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, 
%14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, 
%67, %68, %69, %70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), 
"=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile ("trap"); + } +} + +// 16 data path lanes, 256-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_ld_16dp256bNx(uint32_t const &src_addr, T* dst_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : 
"=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + 
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), 
+ "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), 
"=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } +} + + +// 32 data path lanes, 32-bit pattern, repeated N times +template +CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); + uint32_t* src_ptr = reinterpret_cast(src_ptr_); + + if constexpr (N == 1) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32" + "[%1], {%0};\n" + : + : "r"(src_ptr[0]), + "r"(dst_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32" + "[%2], {%0, %1};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), + "r"(dst_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32" + "[%4], {%0, %1, %2, %3};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), + "r"(dst_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32" + "[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), + "r"(dst_addr)); + } else if constexpr (N == 16) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32" + "[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), 
"r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), + "r"(dst_addr)); + } else if constexpr (N == 32) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32" + "[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), + "r"(dst_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.st.sync.aligned.32x32b.x64.b32" + "[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + "r"(src_ptr[30]), "r"(src_ptr[31]), 
"r"(src_ptr[32]), + "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), + "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), + "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), + "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), + "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), + "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), + "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), + "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), + "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), + "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), + "r"(src_ptr[63]), + "r"(dst_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.st.sync.aligned.32x32b.x128.b32" + "[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127};\n" + : + : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), + "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), + "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), + "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), + "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), + "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), + "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), + "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), + "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), + "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), + 
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), + "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), + "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), + "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), + "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), + "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), + "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), + "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), + "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), + "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), + "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), + "r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]), + "r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]), + "r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]), + "r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]), + "r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]), + "r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]), + "r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]), + "r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]), + "r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]), + "r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]), + "r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]), + "r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]), + "r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]), + "r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]), + "r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]), + "r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]), + "r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]), + "r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]), + "r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]), + "r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]), + "r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]), + "r"(src_ptr[126]), "r"(src_ptr[127]), + "r"(dst_addr)); + } else { + asm volatile ("trap"); + } +} + + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ 
PEER_ADDR_MASK. Not sure whether this value is the same on all GPUs +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} + +} diff --git a/csrc/sm100/prefill/sparse/ws_gemm.h b/csrc/sm100/prefill/sparse/ws_gemm.h new file mode 100644 index 0000000..78c9005 --- /dev/null +++ b/csrc/sm100/prefill/sparse/ws_gemm.h @@ -0,0 +1,328 @@ +#pragma once + +#include + +namespace cute { + +// Extensions to CuTe +// CuTe doesn't support UTCMMA with .ws, so we add it here + +template +struct SM100_MMA_F16BF16_WS_SS_NOELECT +{ + static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, +
Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[8] = {0, 0, 0, 
0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + + +// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync() +template +struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + :
"r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +// template +// struct MMA_Traits> : MMA_Traits> {}; +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +} \ No newline at end of file diff --git a/csrc/sm100/tma_cta_group2_nosplit.h b/csrc/sm100/tma_cta_group2_nosplit.h new file mode 100644 index 0000000..12e65b5 --- /dev/null +++ b/csrc/sm100/tma_cta_group2_nosplit.h @@ -0,0 +1,281 @@ +#pragma once + +#include + +namespace cute { + +// Extensions to CuTe +// CuTe's built-in SM100_TMA_2SM_LOAD_1D family requires exactly 2 participating threads (using ThrID = Layout<_2>;) and also splits the data, which makes it very awkward to use, so we provide our own variant. In addition, to stay consistent with the other code paths that use SM90 TMA, this variant accepts TMA::CacheHintSm90 instead of TMA::CacheHintSm100 + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_1D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + //
Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_2D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_3D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_4D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. 
Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_5D_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_NOSPLIT +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM100_TMA_2SM_LOAD_1D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM100_TMA_2SM_LOAD_2D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM100_TMA_2SM_LOAD_3D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM100_TMA_2SM_LOAD_4D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + 
return SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; + + + +struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {}; + +// The non-executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_NOSPLIT arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; + } + + // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_2SM_LOAD_NOSPLIT before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_NOSPLIT arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) + : opargs_(desc, mbar, cache) {} +}; + + +} diff --git a/csrc/sm100/ws_gemm.h b/csrc/sm100/ws_gemm.h new file mode 100644 index 0000000..54edd3d --- /dev/null +++ b/csrc/sm100/ws_gemm.h @@ -0,0 +1,426 @@ +#pragma once + +#include + +namespace cute { + +// Extensions to CuTe +// CuTe don't support UTCMMA with .ws, so we add it here + +template +struct SM100_MMA_F16BF16_WS_SS_NOELECT +{ + static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT 
M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C.
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +using namespace cute; +template +struct SM100_MMA_F16BF16_WS_TS_NOELECT +{ + static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 64, 128 or 256"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16,
"SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A 
from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + + +// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync() +template +struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : 
"r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +// template +// struct MMA_Traits> : MMA_Traits> {}; +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +} // namespace cute diff --git a/setup.py b/setup.py index 338117f..15fa671 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ BuildExtension, CUDAExtension, IS_WINDOWS, + CUDA_HOME ) @@ -22,8 +23,21 @@ def get_features_args(): return features_args def get_arch_flags(): + # Check NVCC Version + # NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py` + assert CUDA_HOME is not None, "PyTorch must be compiled with CUDA support" + nvcc_version = subprocess.check_output( + [os.path.join(CUDA_HOME, "bin", "nvcc"), '--version'], stderr=subprocess.STDOUT + ).decode('utf-8') + nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip() + major, minor = map(int, nvcc_version_number.split('.')) + print(f'Compiling using NVCC {major}.{minor}') + DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90") + if major < 12 or (major == 12 and minor <= 8): + assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. 
Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." + arch_flags = [] if not DISABLE_SM100: arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) @@ -55,8 +69,10 @@ def get_nvcc_thread_args(): "csrc/sm90/decode/dense/splitkv_mla.cu", "csrc/sm90/decode/sparse_fp8/splitkv_mla.cu", "csrc/sm90/prefill/sparse/fwd.cu", + "csrc/sm100/decode/sparse_fp8/splitkv_mla.cu", "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu", "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu", + "csrc/sm100/prefill/sparse/fwd.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), diff --git a/tests/test_flash_mla_decoding.py b/tests/test_flash_mla_decoding.py index 64ddf72..d6c1f81 100644 --- a/tests/test_flash_mla_decoding.py +++ b/tests/test_flash_mla_decoding.py @@ -319,6 +319,11 @@ def main(torch_dtype): ] testcases = correctness_cases + corner_cases + performance_cases + + # Prune out unsupported cases + cc_major, cc_minor = torch.cuda.get_device_capability() + if cc_major == 10: + testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)] for testcase in testcases: test_flash_mla(testcase) From 472477e875b746c731eb63669bd81b7def9679db Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Mon, 29 Sep 2025 18:23:20 +0800 Subject: [PATCH 15/24] Add Deep-Dive Blog for the New Sparse Decoding Kernel on Hopper (#100) --- csrc/sm100/tma_cta_group2_nosplit.h | 2 +- docs/250929-hopper-fp8-sparse-deep-dive.md | 52 ++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 docs/250929-hopper-fp8-sparse-deep-dive.md diff --git a/csrc/sm100/tma_cta_group2_nosplit.h b/csrc/sm100/tma_cta_group2_nosplit.h index 12e65b5..045456d 100644 --- a/csrc/sm100/tma_cta_group2_nosplit.h +++ b/csrc/sm100/tma_cta_group2_nosplit.h @@ -5,7 +5,7 @@ namespace cute { // Extensions to CuTe -// CuTe 自带的 SM100_TMA_2SM_LOAD_1D 系列要求参与的 thread 数量为 2(using ThrID = Layout<_2>;),还会对数据进行切分,用起来太恶心了,所以我们自己改一版。另外,为了和其他使用 
SM90 TMA 的部分统一,这里我们让它接受 TMA::CacheHintSm90 而不是 TMA::CacheHintSm100 +// CuTe's SM100_TMA_2SM_LOAD_1D requires two threads to perform this operation cooperatively (using ThrID = Layout<_2>;), which doesn't fit our use case. //////////////////////////////////////////////////////////////////////////////////////////////////// /// TMA_LOAD : Initiates a TMA copy from global memory to shared memory diff --git a/docs/250929-hopper-fp8-sparse-deep-dive.md b/docs/250929-hopper-fp8-sparse-deep-dive.md new file mode 100644 index 0000000..cd71346 --- /dev/null +++ b/docs/250929-hopper-fp8-sparse-deep-dive.md @@ -0,0 +1,52 @@ +# A Deep Dive Into The Flash MLA FP8 Decoding Kernel on Hopper + +With the release of DeepSeek-V3.2, we have doubled the context length of our models from 64K tokens to 128K tokens. This puts significant pressure on GPU memory (a single request with 128K tokens requires a KVCache of size $576 \times 2 \times 62 \times 128 \times 1024 = 8.72\ \mathrm{GiB}$), which can lead to out-of-memory (OOM) errors or under-utilized GPUs due to small batch sizes. To address this, we introduced FP8 KVCache for DeepSeek-V3.2. + +However, writing a high-performance decoding kernel is challenging due to the need for dequantization and its sparse memory access patterns. In this blog, we share the story behind our new FP8 sparse decoding kernel for Hopper GPUs. We will first explain our FP8 KVCache format, then provide a theoretical analysis of clock cycles, and finally detail the techniques used in our new kernel. + +## The FP8 KVCache Format + +Recall that the decoding phase of the Multi-head Latent Attention (MLA) algorithm operates similarly to Multi-Query Attention (MQA), with 128 query heads and 1 key head, where `head_dim_k = 576` and `head_dim_v = 512` respectively. To reduce the size of the KVCache while maintaining accuracy, we use a fine-grained quantization method. 
Specifically, we apply tile-level quantization (with a tile size of $1 \times 128$) to the first 512 elements in each token's KV Cache. This results in 512 `float8_e4m3` values and 4 `float32` scale factors. For the remaining 64 elements (the RoPE part), we do not apply quantization as they are sensitive to precision loss. Therefore, in GPU memory, each token's KVCache occupies 656 bytes, consisting of 512 `float8_e4m3`s, 4 `float32`s, and 64 `bfloat16`s. + +Inside the kernel, we first dequantize the 512 `float8_e4m3` values into 512 `bfloat16`s. We then concatenate them with the 64 original `bfloat16` values from the RoPE part. Finally, we perform the MQA calculation using matrix multiplication-add (MMA) operations in `bfloat16` precision (i.e., the inputs to the MMAs are in `bfloat16` and the outputs are in `float32`. This applies to both the QK gemm and the attention-score-V gemm). + +## Theoretical Analysis of Clock Cycles + +The main challenge is that Tensor Cores (which handle MMA calculations) are extremely fast, while the dequantization process, performed on CUDA Cores, struggles to keep up. + +The basic unit on an NVIDIA GPU is the Stream Multiprocessor (SM). You can think of each SM as an independent core on the GPU. For simplicity, let's focus on a single SM. Each SM can process 4096 MMA Flops per clock cycle (calculated as `989 TFlops / 1830 MHz / 132 SMs` on H800). If we assign each CTA (CUDA Thread Block) to process 64 query heads, it only requires $64 \times (576+512) \times 2 / 4096 \approx 34$ cycles for MMA operations per K/V token. + +However, because the H800 cannot directly cast `float8_e4m3` to `bfloat16`, dequantizing the KVCache for one token requires the following steps: +1. Convert `float8_e4m3` to `half` +2. Convert `half` to `float32` +3. Convert `float32` to `bfloat16` +4. 
Multiply the converted `bfloat16` value by the `float32` scale factor + +According to [NVIDIA's documentation](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#throughput-of-native-arithmetic-instructions), we need at least $(\frac{1}{64} + \frac{1}{64} + \frac{1}{16} + \frac{1}{256}) \times 512 \approx 50$ cycles for dequantizing each token! This is significantly more than the 34 cycles required for the MMA operations, meaning the kernel is **dequantization-bound**. If left unaddressed, dequantization would become the performance bottleneck, leaving the powerful Tensor Cores underutilized. + +## Crossover + +Before we continue, it's important to note a key fact: every query head within the same query token attends to the same key heads, because this is Multi-Query Attention (MQA). + +Recall that each CTA processes 64 query heads, while DeepSeek-V3.2 has a total of 128 query heads. If we can find a way to "share" the dequantized K/V values between two CTAs that are processing different sets of query heads, then each CTA would only need to dequantize **half** of the KV cache – which is fantastic! We call this method "crossover", since the idea was actually inspired by [Chromosomal crossover](https://en.wikipedia.org/wiki/Chromosomal_crossover) during [Meiosis](https://en.wikipedia.org/wiki/Meiosis). + +The next question is, how do we implement this in CUDA? Before NVIDIA's Hopper architecture, the only options for data exchange between CTAs were global memory or the L2 cache, which are slow. However, the powerful Distributed Shared Memory gave us a new solution. + +## Distributed Shared Memory to the Rescue + +Distributed Shared Memory (DSM) is a new feature introduced with the Hopper architecture, alongside the CTA Cluster (thread block cluster). CTAs within the same cluster can directly access each other's shared memory. 
For more details, you can refer to [NVIDIA Hopper Architecture In-Depth](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/). + +Here is how we use it: We launch CTAs in clusters of size 2. Each CTA within a cluster is responsible for 64 query heads from the same query token. Each CTA performs the following steps: +1. Loads *half* of the quantized K/V from global memory. We use a wide `__ldg` load with a width of 128 bits to improve performance. +2. Dequantizes its assigned half on the CUDA Cores. +3. Stores the dequantized K/V into its own shared memory. +4. Simultaneously uses `st.async` to write the dequantized K/V into the shared memory of the other CTA in the cluster. + +For synchronization between these operations, we rely on the [cluster transaction barrier](https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/), another powerful programming primitive available in CTA Clusters. After the data exchange is complete, each CTA has the *full* set of dequantized K and V values available in its own shared memory, which it can then use to perform the MMA operations. + +## Performance +Using these techniques, we achieved 410 TFLOPS in a compute-bound configuration (batch_size=128, num_heads=128, s_q=2, topk=2048) on H800 SXM5 GPUs. This is a significant improvement over the 250 TFLOPS achieved by our previous FP8 sparse decoding kernel without the crossover technique. + +This number is still below the 640 TFLOPS peak of our previous bfloat16 dense decoding kernel. One reason is that it's a **sparse** kernel, and its topk is only 2048. With a smaller topk, the relative overhead of the kernel's prologue and epilogue becomes larger compared with dense decoding with long context length. If we set topk to a larger value, such as 32768, this kernel can achieve up to 460 TFLOPS.
+ +From another perspective, the execution time of this kernel in the configuration mentioned above is comparable to that of the dense decoding kernel when the sequence length is around 3000. When the sequence length exceeds 3000, the performance advantage of our new kernel becomes even more significant. This also highlights the effectiveness of our DeepSeek Sparse Attention algorithm. From 42f3c5789db65b5ff1eadea0fe4ce3805483a8e8 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Mon, 29 Sep 2025 18:29:18 +0800 Subject: [PATCH 16/24] Rename deep dive blog --- ...parse-deep-dive.md => 20250929-hopper-fp8-sparse-deep-dive.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/{250929-hopper-fp8-sparse-deep-dive.md => 20250929-hopper-fp8-sparse-deep-dive.md} (100%) diff --git a/docs/250929-hopper-fp8-sparse-deep-dive.md b/docs/20250929-hopper-fp8-sparse-deep-dive.md similarity index 100% rename from docs/250929-hopper-fp8-sparse-deep-dive.md rename to docs/20250929-hopper-fp8-sparse-deep-dive.md From e9b67321b17e53b3743cbe0e180973a943c4b217 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Tue, 30 Sep 2025 18:21:54 +0800 Subject: [PATCH 17/24] Update blog and README --- README.md | 2 +- docs/20250929-hopper-fp8-sparse-deep-dive.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 354cdde..df021de 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee ## News -- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. 
+- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. We also release a deep-dive blog for our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md). - **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell! - **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). - **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀 diff --git a/docs/20250929-hopper-fp8-sparse-deep-dive.md b/docs/20250929-hopper-fp8-sparse-deep-dive.md index cd71346..cf3c166 100644 --- a/docs/20250929-hopper-fp8-sparse-deep-dive.md +++ b/docs/20250929-hopper-fp8-sparse-deep-dive.md @@ -14,7 +14,7 @@ Inside the kernel, we first dequantize the 512 `float8_e4m3` values into 512 `bf The main challenge is that Tensor Cores (which handle MMA calculations) are extremely fast, while the dequantization process, performed on CUDA Cores, struggles to keep up. -The basic unit on an NVIDIA GPU is the Stream Multiprocessor (SM). You can think of each SM as an independent core on the GPU. For simplicity, let's focus on a single SM. 
Each SM can process 4096 MMA Flops per clock cycle (calculated as `989 TFlops / 1830 MHz / 132 SMs` on H800). If we assign each CTA (CUDA Thread Block) to process 64 query heads, it only requires $64 \times (576+512) \times 2 / 4096 \approx 34$ cycles for MMA operations per K/V token. +The basic unit on an NVIDIA GPU is the Streaming Multiprocessor (SM). You can think of each SM as an independent core on the GPU. For simplicity, let's focus on a single SM. Each SM can process 4096 MMA Flops per clock cycle (calculated as `989 TFlops / 1830 MHz / 132 SMs` on H800). In our kernel, each CTA runs on one SM, and each SM is only mapped to one CTA. If we assign each CTA (CUDA Thread Block) to process 64 query heads, it only requires $64 \times (576+512) \times 2 / 4096 \approx 34$ cycles for MMA operations per K/V token. However, because the H800 cannot directly cast `float8_e4m3` to `bfloat16`, dequantizing the KVCache for one token requires the following steps: 1. Convert `float8_e4m3` to `half` From 7f55c7151acfeaacfd610c022aaa26f836c9fac1 Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Tue, 30 Sep 2025 23:29:18 +0800 Subject: [PATCH 18/24] Fix error message --- csrc/pybind.cpp | 6 +++--- .../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp | 2 +- .../sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 6ec3f21..13541d4 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -41,7 +41,7 @@ struct Arch { } }; -// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. Hopper Dense BF16, Hopper Sparse FP8, etc.) +// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. SM90 Dense BF16, SM90 Sparse FP8, etc.)
struct DecodingAttnImplMeta { int num_sm_parts; int fixed_overhead_num_blocks; @@ -334,7 +334,7 @@ fwd_kvcache_mla( TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90"); sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); } else { - TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); + TORCH_CHECK(false, "Only FP8 kvcahe is supported for sparse MLA on SM90"); } } else { if (is_fp8) { @@ -347,7 +347,7 @@ fwd_kvcache_mla( sm90::run_flash_splitkv_mla_kernel(params, stream); #endif } else { - TORCH_CHECK(false, "Unsupported tensor dtype for query"); + TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); } } } diff --git a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index 057b45e..c34713b 100644 --- a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -949,7 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TensorC const& coord, TensorShape const& tensor_shape) { - //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + // TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version. 
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); auto copy_op = make_cotiled_copy( diff --git a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp index 0d4af85..c25d638 100644 --- a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -953,7 +953,8 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { TensorR const& regs, TensorC const& coord, TensorShape const& tensor_shape) { - //TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. + + // TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version. // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); auto copy_op = make_cotiled_copy( From 1858932afd3bd4cf2d3f91bfdaa9f8d96f2afe14 Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Tue, 30 Sep 2025 23:33:43 +0800 Subject: [PATCH 19/24] Code format --- tests/quant.py | 26 ++++++------- tests/test_flash_mla_decoding.py | 64 ++++++++++++++++---------------- tests/test_flash_mla_prefill.py | 27 +++++++------- tests/test_fmha_sm100.py | 19 +++++----- 4 files changed, 66 insertions(+), 70 deletions(-) diff --git a/tests/quant.py b/tests/quant.py index afee4b2..0624759 100644 --- a/tests/quant.py +++ b/tests/quant.py @@ -1,5 +1,3 @@ -import enum - import torch def quantize_k_cache( @@ -19,20 +17,20 @@ def quantize_k_cache( input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] input_elem_size = input_k_cache.element_size() - result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) + result = torch.empty((num_blocks, block_size, dv 
+ num_tiles * 4 + input_elem_size * (d - dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) result_k_nope_part = result[..., :dv] - result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32) - result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype) + result_k_scale_factor = result[..., dv:dv + num_tiles * 4].view(torch.float32) + result_k_rope_part = result[..., dv + num_tiles * 4:].view(input_k_cache.dtype) result_k_rope_part[:] = input_k_cache[..., dv:] for tile_idx in range(0, num_tiles): - cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] + cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] - cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) - result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope - + cur_quantized_nope = (input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) + result_k_nope_part[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_quantized_nope + result = result.view(num_blocks, block_size, 1, -1) return result @@ -55,14 +53,14 @@ def dequantize_k_cache( quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) input_nope = quant_k_cache[..., :dv] - input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32) - input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16) + input_scale = quant_k_cache[..., dv:dv + num_tiles * 4].view(torch.float32) + input_rope = quant_k_cache[..., dv + num_tiles * 
4:].view(torch.bfloat16) result[..., dv:] = input_rope for tile_idx in range(0, num_tiles): - cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32) + cur_nope = input_nope[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].to(torch.float32) cur_scales = input_scale[..., tile_idx].unsqueeze(-1) - result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales - + result[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_nope * cur_scales + result = result.view(num_blocks, block_size, 1, d) return result diff --git a/tests/test_flash_mla_decoding.py b/tests/test_flash_mla_decoding.py index d6c1f81..dc140d7 100644 --- a/tests/test_flash_mla_decoding.py +++ b/tests/test_flash_mla_decoding.py @@ -2,20 +2,20 @@ import math import random import dataclasses -from typing import Optional, Tuple, List +from typing import Optional, Tuple import torch import triton -import quant import flash_mla +import quant from lib import cdiv, check_is_allclose @dataclasses.dataclass class TestParam: - b: int # Batch size - s_q: int # Number of queries for one request - s_k: int # Seq len, or mean seq len if varlen == True + b: int # Batch size + s_q: int # Number of queries for one request + s_k: int # Seq len, or mean seq len if varlen == True is_varlen: bool is_causal: bool is_fp8: bool @@ -24,8 +24,8 @@ class TestParam: is_all_indices_invalid: bool = False have_zero_seqlen_k: bool = False block_size: int = 64 - h_q: int = 128 # Number of q heads - h_kv: int = 1 # Number of kv heads + h_q: int = 128 # Number of q heads + h_kv: int = 1 # Number of kv heads d: int = 576 # Q/K head dim (= dv + RoPE dim) dv: int = 512 # V head dim seed: int = 0 @@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. 
cur_num_blocks = cdiv(cur_len, t.block_size) blocked_k[block_table[i][cur_num_blocks:]] = float("nan") if cur_len % t.block_size != 0: - blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan") + blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan") block_table[i][cur_num_blocks:] = 2147480000 return cache_seqlens, q, block_table, blocked_k, None, None else: @@ -82,12 +82,12 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. # Generate indices for j in range(t.s_q): cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk] - cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size) + cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size) if len(cur_abs_indices) < t.topk: pad_len = t.topk - len(cur_abs_indices) cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')]) cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')]) - + # Mask KV perm = torch.randperm(t.topk, device='cpu') cur_abs_indices = cur_abs_indices[perm] @@ -100,7 +100,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. abs_indices[i, j, :] = cur_abs_indices indices_in_kvcache[i, j, :] = cur_blocked_indices - + # Mask nonused KV as NaN all_indices = indices_in_kvcache.flatten().tolist() all_indices = list(set(all_indices)) @@ -109,11 +109,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. 
all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu') blocked_k = blocked_k.view(-1, t.h_kv, t.d) - nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu') + nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu') nonused_indices_mask[all_indices] = False blocked_k[nonused_indices_mask, :, :] = float("nan") blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d) - + abs_indices = abs_indices.to(q.device) indices_in_kvcache = indices_in_kvcache.to(q.device) @@ -139,7 +139,7 @@ def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): valid_indices = cur_indices[cur_indices != -1] mask[i, valid_indices] = True return mask - + def scaled_dot_product_attention( batch_idx: int, query: torch.Tensor, # [h_q, s_q, d] @@ -157,7 +157,7 @@ def scaled_dot_product_attention( if h_kv != 1: kv = kv.repeat_interleave(h_q // h_kv, dim=0) kv[kv != kv] = 0.0 - attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] + attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] if (is_causal and query.size(1) > 1) or indices is not None: mask = torch.ones(s_q, s_k, dtype=torch.bool) if is_causal: @@ -169,14 +169,14 @@ def scaled_dot_product_attention( attn_bias.masked_fill_(mask.logical_not(), float("-inf")) attn_weight += attn_bias.to(q.dtype) attn_weight /= math.sqrt(query.size(-1)) - lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] + lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q] attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv] # Correct for q tokens which has no attendable k lonely_q_mask = (lse == float("-inf")) output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0 lse[lonely_q_mask] = float("+inf") - + return output, lse b, s_q, h_q, d = q.size() @@ -202,7 +202,7 @@ def scaled_dot_product_attention( lse_ref[i] = cur_lse out_ref = out_ref.to(torch.bfloat16) 
return out_ref, lse_ref - + @torch.inference_mode() def test_flash_mla(t: TestParam): @@ -235,7 +235,7 @@ def test_flash_mla(t: TestParam): def run_flash_mla(): return flash_mla.flash_mla_with_kvcache( q, - blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore + blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore block_table, cache_seqlens, t.dv, @@ -248,27 +248,27 @@ def run_flash_mla(): out_ans, lse_ans = run_flash_mla() out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) - assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6) - assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) + assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) + assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536) if t.test_performance: - time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore + time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk - compute_volume_flop = t.b*t.h_q*t.s_q*sum([ - 2*t.d*mean_attended_seqlens, # Q * K^T - 2*mean_attended_seqlens*t.dv, # attention * V + compute_volume_flop = t.b * t.h_q * t.s_q * sum([ + 2 * t.d * mean_attended_seqlens, # Q * K^T + 2 * mean_attended_seqlens * t.dv, # attention * V ]) q_elem_size = torch.bfloat16.itemsize - kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize - memory_volume_B = t.b*sum([ - t.s_q*t.h_q*(t.d*q_elem_size), # Q - (t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V - t.s_q*t.h_q*(t.dv*q_elem_size), # Output + kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize + memory_volume_B = t.b * sum([ + t.s_q * t.h_q * (t.d * q_elem_size), # Q + (t.s_q if t.topk is not None else 1) 
* mean_attended_seqlens * t.h_kv * kv_token_size, # K/V + t.s_q * t.h_q * (t.dv * q_elem_size), # Output ]) achieved_tflops = compute_volume_flop / time_usage / 1e12 achieved_gBps = memory_volume_B / time_usage / 1e9 - print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") + print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s") def main(torch_dtype): @@ -324,7 +324,7 @@ def main(torch_dtype): cc_major, cc_minor = torch.cuda.get_device_capability() if cc_major == 10: testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)] - + for testcase in testcases: test_flash_mla(testcase) diff --git a/tests/test_flash_mla_prefill.py b/tests/test_flash_mla_prefill.py index 19a6dbe..d2f5b7e 100644 --- a/tests/test_flash_mla_prefill.py +++ b/tests/test_flash_mla_prefill.py @@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase: torch.manual_seed(t.seed) torch.cuda.manual_seed(t.seed) random.seed(t.seed) - q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 - kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10 + kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10 q.clamp_(-10, 10) kv.clamp_(-10, 10) @@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase: # NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 cur_indices = torch.randperm(t.s_kv)[:t.topk] - cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),)) + cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),)) if len(cur_indices) < t.topk: cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)]) 
cur_indices = cur_indices[torch.randperm(t.topk)] @@ -72,9 +72,9 @@ def get_flop(p: TestParam) -> float: def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e) - + assert p.b == 1 - indices = t.indices[0, :, 0, :] # [s_q, topk] + indices = t.indices[0, :, 0, :] # [s_q, topk] invalid_indices_mask = (indices < 0) | (indices >= p.s_kv) qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk] kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk] @@ -104,15 +104,15 @@ def run_ans(): return flash_mla_sparse_fwd( t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale ) - + ans_out, ans_max_logits, ans_lse = run_ans() torch.cuda.synchronize() if p.benchmark: flop = get_flop(p) - prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore - prefill_flops = flop/prefill_ans_time/1e12 - print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops") + prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore + prefill_flops = flop / prefill_ans_time / 1e12 + print(f"Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops") if p.check_correctness: torch.cuda.synchronize() @@ -120,9 +120,9 @@ def run_ans(): torch.cuda.synchronize() is_correct = True - is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6) - is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) - is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) + is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6) + is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536) 
+ is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536) return is_correct else: @@ -187,11 +187,10 @@ def run_ans(): is_correct = run_test(test) if not is_correct: failed_cases.append(test) - + if len(failed_cases) > 0: print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") for case in failed_cases: print(f" {case}") else: print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") - diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 6b2ba45..62e3344 100644 --- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -5,7 +5,6 @@ import triton from flash_mla import flash_attn_varlen_func - from lib import check_is_allclose def get_window_size(causal, window): @@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win causal, window) == 0).sum().item() for i in range(b)]) # print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}") - q = torch.randn(total_q, h, d)/10 - k = torch.randn(total_k, h_k, d)/10 - v = torch.randn(total_k, h_k, dv)/10 - grad_out = torch.randn(total_q, h, dv)/10 + q = torch.randn(total_q, h, d) / 10 + k = torch.randn(total_k, h_k, d) / 10 + v = torch.randn(total_k, h_k, dv) / 10 + grad_out = torch.randn(total_q, h, dv) / 10 softmax_scale = (d + 100) ** (-0.5) q1 = q.clone().requires_grad_() @@ -123,14 +122,14 @@ def torch_attn(): if check_correctness: out_torch, lse_torch = torch_attn() - assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) - assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01/65536) + assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) + assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536) if has_bwd: out_torch.backward(grad_out, 
retain_graph=True) - assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) - assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) - assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01/128, cos_diff_tol=7e-6) + assert check_is_allclose("dq", q1.grad, q2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) + assert check_is_allclose("dk", k1.grad, k2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) + assert check_is_allclose("dv", v1.grad, v2.grad, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) def forward(): return flash_attn() From 1408756a88e52a25196b759eaf8db89d2b51b5a1 Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Wed, 1 Oct 2025 00:04:36 +0800 Subject: [PATCH 20/24] Update README --- README.md | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index df021de..f08d888 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## Introduction -FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations: +FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations: **Sparse Attention Kernels** @@ -19,7 +19,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee ## News - **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. 
These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. We also release a deep-dive blog for our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md). -- **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell! +- **2025.08.01 Kernels for MHA on SM100**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on SM100! - **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). - **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀 @@ -31,9 +31,7 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee python tests/test_flash_mla_decoding.py ``` -The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8. - -For Blackwell GPUs, the token-level sparse MLA decoding kernel can achieve up to 350 TFlops (on B200) which is not really optimized yet. +The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5 with CUDA 12.8. 
The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configuration on H800 SXM5 with CUDA 12.8, and achieves up to 350 TFlops on B200 (which is not really optimized yet). #### Test & benchmark MHA prefill (Dense): @@ -49,22 +47,22 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation python tests/test_flash_mla_prefill.py ``` -It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9. +It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9. ## Requirements -- Hopper / Blackwell GPUs (See the support matrix below) -- CUDA 12.8 and above (CUDA 12.9+ is required for Blackwell kernels) +- SM90 / SM100 (See the support matrix below) +- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels) - PyTorch 2.0 and above Support matrix: | Kernel | GPU Architecture | MLA Mode [2] | KVCache Format | | :---: | :---: | :---: | :---: | -| Dense Decoding | Hopper | MQA | BF16 | -| Sparse Decoding | Hopper & Blackwell | MQA | FP8 [1] | -| Dense Prefill | Blackwell | MHA | | -| Sparse Prefill | Hopper & Blackwell | MQA | | +| Dense Decoding | SM90 | MQA | BF16 | +| Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] | +| Dense Prefill | SM100 | MHA | | +| Sparse Prefill | SM90 & SM100 | MQA | | [1]: For more details on using FP8 KV cache, see documents below. 
From 082094b793fcc7452977d0a71a00e266a2e3061e Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Fri, 16 Jan 2026 17:04:20 +0800 Subject: [PATCH 21/24] Multiple updates and refactorings (#150) * Multiple updates and refactorings * Remove dead code --- .gitignore | 1 + README.md | 5 +- csrc/api/api.cpp | 15 + csrc/api/common.h | 229 ++++ csrc/api/dense_decode.h | 225 ++++ csrc/api/dense_fwd.h | 5 + csrc/api/sparse_decode.h | 495 ++++++++ csrc/api/sparse_fwd.h | 243 ++++ csrc/cutlass | 2 +- csrc/{sm100 => }/defines.h | 4 - .../kerutils/include/kerutils/common/common.h | 8 + .../kerutils/include/kerutils/device/common.h | 70 ++ .../include/kerutils/device/device.cuh | 13 + .../include/kerutils/device/sm100/gemm.cuh} | 308 ++++- .../include/kerutils/device/sm100/helpers.cuh | 137 ++ .../kerutils/device/sm100/intrinsics.cuh | 382 ++++++ .../device/sm100/tma_cta_group2_nosplit.cuh} | 27 +- .../include/kerutils/device/sm80/helpers.cuh | 55 + .../kerutils/device/sm80/intrinsics.cuh | 146 +++ .../include/kerutils/device/sm90/helpers.cuh | 110 ++ .../kerutils/device/sm90/intrinsics.cuh | 107 ++ csrc/kerutils/include/kerutils/host/host.h | 155 +++ csrc/kerutils/include/kerutils/kerutils.cuh | 4 + .../kerutils/supplemental/torch_tensors.h | 71 ++ csrc/params.h | 132 +- csrc/pybind.cpp | 472 ------- csrc/sm100/decode/head128/README.md | 1 + csrc/sm100/decode/head64/config.h | 212 ++++ .../decode/head64/instantiations/model1.cu | 8 + .../sm100/decode/head64/instantiations/v32.cu | 8 + csrc/sm100/decode/head64/kernel.cuh | 968 ++++++++++++++ csrc/sm100/decode/head64/kernel.h | 11 + csrc/sm100/decode/sparse_fp8/dequant.h | 61 - csrc/sm100/decode/sparse_fp8/splitkv_mla.cu | 592 --------- csrc/sm100/decode/sparse_fp8/splitkv_mla.h | 10 - csrc/sm100/helpers.h | 94 +- csrc/sm100/intrinsics.h | 461 ------- .../dense/kernel/fmha_kernel_bwd_convert.hpp | 4 +- .../dense/kernel/fmha_kernel_bwd_sum_OdO.hpp | 4 +- ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 27 +- 
...mha_bwd_mla_kernel_tma_warpspecialized.hpp | 29 +- ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 4 +- csrc/sm100/prefill/sparse/common_subroutine.h | 208 ++++ csrc/sm100/prefill/sparse/fwd.h | 9 - .../sm100/prefill/sparse/fwd/head128/config.h | 140 +++ .../fwd/head128/instantiations/phase1_k512.cu | 8 + .../fwd/head128/instantiations/phase1_k576.cu | 8 + .../sparse/{fwd.cu => fwd/head128/phase1.cuh} | 311 ++--- .../sm100/prefill/sparse/fwd/head128/phase1.h | 10 + csrc/sm100/prefill/sparse/fwd/head64/config.h | 157 +++ .../fwd/head64/instantiations/phase1_k512.cu | 8 + .../fwd/head64/instantiations/phase1_k576.cu | 8 + .../prefill/sparse/fwd/head64/phase1.cuh | 673 ++++++++++ csrc/sm100/prefill/sparse/fwd/head64/phase1.h | 10 + .../fwd_for_small_topk/head128/config.h | 140 +++ .../instantiations/phase1_decode_k512.cu | 8 + .../instantiations/phase1_prefill_k512.cu | 8 + .../fwd_for_small_topk/head128/phase1.cuh | 1107 +++++++++++++++++ .../fwd_for_small_topk/head128/phase1.h | 10 + csrc/sm100/prefill/sparse/helpers.h | 104 -- csrc/sm100/prefill/sparse/intrinsics.h | 638 ---------- csrc/sm100/prefill/sparse/ws_gemm.h | 328 ----- csrc/sm90/decode/dense/instantiations/bf16.cu | 8 + csrc/sm90/decode/dense/instantiations/fp16.cu | 10 + .../dense/{splitkv_mla.cu => splitkv_mla.cuh} | 58 +- csrc/sm90/decode/dense/splitkv_mla.h | 2 +- .../decode/sparse_fp8/components/config.h | 110 +- .../decode/sparse_fp8/components/dequant.h | 57 +- .../decode/sparse_fp8/components/epilogue.h | 87 -- .../decode/sparse_fp8/components/helpers.h | 27 +- .../sparse_fp8/components/named_barriers.h | 10 - csrc/sm90/decode/sparse_fp8/config.h | 279 +++++ .../instantiations/model1_persistent_h128.cu | 7 + .../instantiations/model1_persistent_h64.cu | 8 + .../instantiations/v32_persistent_h128.cu | 7 + .../instantiations/v32_persistent_h64.cu | 7 + csrc/sm90/decode/sparse_fp8/splitkv_mla.cu | 614 --------- csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh | 787 ++++++++++++ 
csrc/sm90/decode/sparse_fp8/splitkv_mla.h | 6 +- csrc/sm90/{prefill/sparse => }/helpers.h | 21 +- csrc/sm90/prefill/sparse/config.h | 147 +++ csrc/sm90/prefill/sparse/fwd.cu | 711 +---------- csrc/sm90/prefill/sparse/fwd.h | 2 +- .../sparse/instantiations/phase1_k512.cu | 10 + .../instantiations/phase1_k512_topklen.cu | 10 + .../sparse/instantiations/phase1_k576.cu | 8 + .../instantiations/phase1_k576_topklen.cu | 8 + csrc/sm90/prefill/sparse/phase1.cuh | 646 ++++++++++ csrc/sm90/prefill/sparse/phase1.h | 10 + .../combine/combine.cu} | 154 +-- csrc/smxx/decode/combine/combine.h | 10 + .../get_decoding_sched_meta.cu} | 66 +- .../get_decoding_sched_meta.h | 9 + csrc/smxx/get_mla_metadata.h | 5 - csrc/smxx/mla_combine.h | 6 - csrc/utils.h | 50 +- flash_mla/__init__.py | 9 + flash_mla/flash_mla_interface.py | 211 +++- setup.py | 58 +- tests/kernelkit/.gitignore | 9 + tests/kernelkit/__init__.py | 11 + tests/kernelkit/bench.py | 205 +++ tests/kernelkit/compare.py | 95 ++ tests/kernelkit/generate.py | 25 + tests/kernelkit/precision.py | 30 + tests/kernelkit/utils.py | 50 + tests/lib.py | 456 ++++++- tests/quant.py | 160 ++- tests/ref.py | 103 ++ ...ng.py => test_flash_mla_dense_decoding.py} | 190 +-- tests/test_flash_mla_prefill.py | 196 --- tests/test_flash_mla_sparse_decoding.py | 319 +++++ tests/test_flash_mla_sparse_prefill.py | 180 +++ tests/test_fmha_sm100.py | 8 +- 114 files changed, 10785 insertions(+), 5295 deletions(-) create mode 100644 csrc/api/api.cpp create mode 100644 csrc/api/common.h create mode 100644 csrc/api/dense_decode.h create mode 100644 csrc/api/dense_fwd.h create mode 100644 csrc/api/sparse_decode.h create mode 100644 csrc/api/sparse_fwd.h rename csrc/{sm100 => }/defines.h (96%) create mode 100644 csrc/kerutils/include/kerutils/common/common.h create mode 100644 csrc/kerutils/include/kerutils/device/common.h create mode 100644 csrc/kerutils/include/kerutils/device/device.cuh rename csrc/{sm100/ws_gemm.h => 
kerutils/include/kerutils/device/sm100/gemm.cuh} (65%) create mode 100644 csrc/kerutils/include/kerutils/device/sm100/helpers.cuh create mode 100644 csrc/kerutils/include/kerutils/device/sm100/intrinsics.cuh rename csrc/{sm100/tma_cta_group2_nosplit.h => kerutils/include/kerutils/device/sm100/tma_cta_group2_nosplit.cuh} (95%) create mode 100644 csrc/kerutils/include/kerutils/device/sm80/helpers.cuh create mode 100644 csrc/kerutils/include/kerutils/device/sm80/intrinsics.cuh create mode 100644 csrc/kerutils/include/kerutils/device/sm90/helpers.cuh create mode 100644 csrc/kerutils/include/kerutils/device/sm90/intrinsics.cuh create mode 100644 csrc/kerutils/include/kerutils/host/host.h create mode 100644 csrc/kerutils/include/kerutils/kerutils.cuh create mode 100644 csrc/kerutils/include/kerutils/supplemental/torch_tensors.h delete mode 100644 csrc/pybind.cpp create mode 100644 csrc/sm100/decode/head128/README.md create mode 100644 csrc/sm100/decode/head64/config.h create mode 100644 csrc/sm100/decode/head64/instantiations/model1.cu create mode 100644 csrc/sm100/decode/head64/instantiations/v32.cu create mode 100644 csrc/sm100/decode/head64/kernel.cuh create mode 100644 csrc/sm100/decode/head64/kernel.h delete mode 100644 csrc/sm100/decode/sparse_fp8/dequant.h delete mode 100644 csrc/sm100/decode/sparse_fp8/splitkv_mla.cu delete mode 100644 csrc/sm100/decode/sparse_fp8/splitkv_mla.h delete mode 100644 csrc/sm100/intrinsics.h create mode 100644 csrc/sm100/prefill/sparse/common_subroutine.h delete mode 100644 csrc/sm100/prefill/sparse/fwd.h create mode 100644 csrc/sm100/prefill/sparse/fwd/head128/config.h create mode 100644 csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu create mode 100644 csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu rename csrc/sm100/prefill/sparse/{fwd.cu => fwd/head128/phase1.cuh} (72%) create mode 100644 csrc/sm100/prefill/sparse/fwd/head128/phase1.h create mode 100644 
csrc/sm100/prefill/sparse/fwd/head64/config.h create mode 100644 csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu create mode 100644 csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu create mode 100644 csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh create mode 100644 csrc/sm100/prefill/sparse/fwd/head64/phase1.h create mode 100644 csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/config.h create mode 100644 csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu create mode 100644 csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu create mode 100644 csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh create mode 100644 csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h delete mode 100644 csrc/sm100/prefill/sparse/helpers.h delete mode 100644 csrc/sm100/prefill/sparse/intrinsics.h delete mode 100644 csrc/sm100/prefill/sparse/ws_gemm.h create mode 100644 csrc/sm90/decode/dense/instantiations/bf16.cu create mode 100644 csrc/sm90/decode/dense/instantiations/fp16.cu rename csrc/sm90/decode/dense/{splitkv_mla.cu => splitkv_mla.cuh} (95%) delete mode 100644 csrc/sm90/decode/sparse_fp8/components/epilogue.h delete mode 100644 csrc/sm90/decode/sparse_fp8/components/named_barriers.h create mode 100644 csrc/sm90/decode/sparse_fp8/config.h create mode 100644 csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu create mode 100644 csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu create mode 100644 csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu create mode 100644 csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu delete mode 100644 csrc/sm90/decode/sparse_fp8/splitkv_mla.cu create mode 100644 csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh rename csrc/sm90/{prefill/sparse => }/helpers.h (90%) create mode 100644 csrc/sm90/prefill/sparse/config.h create mode 100644 
csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu create mode 100644 csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu create mode 100644 csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu create mode 100644 csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu create mode 100644 csrc/sm90/prefill/sparse/phase1.cuh create mode 100644 csrc/sm90/prefill/sparse/phase1.h rename csrc/smxx/{mla_combine.cu => decode/combine/combine.cu} (55%) create mode 100644 csrc/smxx/decode/combine/combine.h rename csrc/smxx/{get_mla_metadata.cu => decode/get_decoding_sched_meta/get_decoding_sched_meta.cu} (55%) create mode 100644 csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h delete mode 100644 csrc/smxx/get_mla_metadata.h delete mode 100644 csrc/smxx/mla_combine.h create mode 100644 tests/kernelkit/.gitignore create mode 100644 tests/kernelkit/__init__.py create mode 100644 tests/kernelkit/bench.py create mode 100644 tests/kernelkit/compare.py create mode 100644 tests/kernelkit/generate.py create mode 100644 tests/kernelkit/precision.py create mode 100644 tests/kernelkit/utils.py create mode 100644 tests/ref.py rename tests/{test_flash_mla_decoding.py => test_flash_mla_dense_decoding.py} (50%) delete mode 100644 tests/test_flash_mla_prefill.py create mode 100644 tests/test_flash_mla_sparse_decoding.py create mode 100644 tests/test_flash_mla_sparse_prefill.py diff --git a/.gitignore b/.gitignore index 6b00da7..036d436 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ dist/ compile_commands.json .cache /dev +/.clangd diff --git a/README.md b/README.md index f08d888..1945725 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,8 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee #### Test & benchmark MLA decoding (Sparse & Dense): ```bash -python tests/test_flash_mla_decoding.py +python tests/test_flash_mla_dense_decoding.py +python tests/test_flash_mla_sparse_decoding.py ``` The dense MLA decoding 
kernel achieves up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configuration on H800 SXM5 with CUDA 12.8, and achieves up to 350 TFlops on B200 (which is not really optimized yet). @@ -44,7 +45,7 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation #### Test & benchmark MLA prefill (Sparse): ```bash -python tests/test_flash_mla_prefill.py +python tests/test_flash_mla_sparse_prefill.py ``` It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9. diff --git a/csrc/api/api.cpp b/csrc/api/api.cpp new file mode 100644 index 0000000..f43f2a0 --- /dev/null +++ b/csrc/api/api.cpp @@ -0,0 +1,15 @@ +#include + +#include "sparse_fwd.h" +#include "sparse_decode.h" +#include "dense_decode.h" +#include "dense_fwd.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashMLA"; + m.def("sparse_decode_fwd", &sparse_attn_decode_interface); + m.def("dense_decode_fwd", &dense_attn_decode_interface); + m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); + m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun); + m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun); +} diff --git a/csrc/api/common.h b/csrc/api/common.h new file mode 100644 index 0000000..6beeab4 --- /dev/null +++ b/csrc/api/common.h @@ -0,0 +1,229 @@ +#pragma once + +#include +#include +#include +#include + +#include + +static constexpr float LOG_2_E = 1.44269504f; + +// Instantiation for tensor.data_ptr() +template<> +inline cutlass::bfloat16_t* at::TensorBase::data_ptr() const { + return reinterpret_cast(this->data_ptr()); +} + +// A struct that holds the architecture information of the current GPU. 
+struct Arch { + int major; + int minor; + int num_sms; + cudaDeviceProp* device_prop; + + Arch() { + device_prop = at::cuda::getCurrentDeviceProperties(); + major = device_prop->major; + minor = device_prop->minor; + num_sms = device_prop->multiProcessorCount; + } + + bool is_sm90a() const { + return major == 9 && minor == 0; + } + + bool is_sm100f() const { + return major == 10; + } +}; + +// Convert int64_t stride to int32_t, with overflow check. +inline int int64_stride_to_int(int64_t orig_stride) { + if (orig_stride > std::numeric_limits::max()) { + TORCH_CHECK(false, "[FlashMLA] Stride exceeds int32 limit: ", orig_stride); + } + return static_cast(orig_stride); +} + +#define DISPATCH_NUM_HEADS(NUM_HEADS, CONSTEXPR_NAME, ...) \ + [&] () { \ + if (NUM_HEADS == 128) { \ + static constexpr int CONSTEXPR_NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_HEADS == 64) { \ + static constexpr int CONSTEXPR_NAME = 64; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \ + } \ + } (); + +#define DISPATCH_HEAD_DIM(HEAD_DIM, CONSTEXPR_NAME, ...) \ +[&] () { \ + if (HEAD_DIM == 576) { \ + static constexpr int CONSTEXPR_NAME = 576; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM == 512) { \ + static constexpr int CONSTEXPR_NAME = 512; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported head_dim_qk: ", HEAD_DIM); \ + } \ +} (); + +#define DISPATCH_BOOLEAN_FLAG(FLAG, CONSTEXPR_NAME, ...) \ + [&] () { \ + if (FLAG) { \ + static constexpr bool CONSTEXPR_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONSTEXPR_NAME = false; \ + return __VA_ARGS__(); \ + } \ + } (); + +#define DISPATCH_MODEL_TYPE(MODEL_TYPE, CONSTEXPR_NAME, ...) 
\ +[&] () { \ + if (MODEL_TYPE == ModelType::V32) { \ + static constexpr ModelType CONSTEXPR_NAME = ModelType::V32; \ + return __VA_ARGS__(); \ + } else if (MODEL_TYPE == ModelType::MODEL1) { \ + static constexpr ModelType CONSTEXPR_NAME = ModelType::MODEL1; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported model type: ", (int)MODEL_TYPE); \ + } \ +} (); + +// The following code is adapted from https://ykiko.me/en/articles/680412313/, which converts enum values to string names. +template +constexpr auto get_static_enum_name(){ + std::string_view name; +#if __GNUC__ || __clang__ + name = __PRETTY_FUNCTION__; + std::size_t start = name.find('=') + 2; + std::size_t end = name.size() - 1; + name = std::string_view{ name.data() + start, end - start }; + start = name.find("::"); +#elif _MSC_VER + name = __FUNCSIG__; + std::size_t start = name.find('<') + 1; + std::size_t end = name.rfind(">("); + name = std::string_view{ name.data() + start, end - start }; + start = name.rfind("::"); +#endif + return start == std::string_view::npos ? name : std::string_view { + name.data() + start + 2, name.size() - start - 2 + }; +} + +template +static constexpr std::size_t get_enum_max(){ + constexpr T value = static_cast(N); + if constexpr (get_static_enum_name().find(")") == std::string_view::npos) + return get_enum_max(); + else + return N; +} + +template requires std::is_enum_v +static constexpr std::string get_dynamic_enum_name(T value){ + constexpr std::size_t num = get_enum_max(); + constexpr auto names = [](std::index_sequence){ + return std::array{ + get_static_enum_name(Is)>()... + }; + }(std::make_index_sequence{}); + return (std::string)names[static_cast(value)]; +} + +// A shortcut macro to declare supported features in an implementation class. +#define DECLARE_SUPPORTED_FEATURES(...) 
\ +protected: \ + static constexpr FeatureT features[] = { __VA_ARGS__ }; \ + constexpr inline std::span get_supported_features() const override { \ + return features; \ + } + +/* +ImplBase - The base class for every implementation. + +Every implementation should inherit from this class and implement the pure virtual functions, including: +- `run_`: The function that runs the implementation. +- `get_supported_features`: The function that returns the supported features of the implementation. You may use `DECLARE_SUPPORTED_FEATURES` to declare the supported features in a concise way. + +The dispatcher will invoke `ImplBase::run()`, which checks if all required features are supported by the implementation, and then calls `run_`. +*/ +template< + typename RunArgT_, + typename FeatureT_ +> +class ImplBase { +protected: + using RunArgT = RunArgT_; + using FeatureT = FeatureT_; + + virtual inline void run_(const RunArgT ¶ms, const std::vector &required_features) = 0; + + constexpr virtual inline std::span get_supported_features() const = 0; + + virtual ~ImplBase() = default; + +public: + inline bool check_if_all_features_are_supported(const std::vector &required_features) { + for (const auto &required_feature : required_features) { + bool is_supported = false; + for (const auto &supported_feature : get_supported_features()) { + if (required_feature == supported_feature) { + is_supported = true; + break; + } + } + if (!is_supported) { + return false; + } + } + return true; + } + + inline void check_if_all_features_are_supported_and_abort(const std::vector &required_features) { + if (!check_if_all_features_are_supported(required_features)) { + fprintf(stderr, "[FlashMLA] Error: The chosen implementation does not support all required features.\n"); + fprintf(stderr, "Required features:\n"); + for (const auto &f : required_features) { + fprintf(stderr, " - %3d: %s\n", static_cast(f), get_dynamic_enum_name(f).c_str()); + } + fprintf(stderr, "\n"); + fprintf(stderr, "Supported 
features:\n"); + for (const auto &supported_feature : get_supported_features()) { + fprintf(stderr, " - %3d: %s\n", static_cast(supported_feature), get_dynamic_enum_name(supported_feature).c_str()); + } + fprintf(stderr, "\n"); + fprintf(stderr, "Features that are required but not supported:\n"); + for (const auto &required_feature : required_features) { + bool is_supported = false; + for (const auto &supported_feature : get_supported_features()) { + if (required_feature == supported_feature) { + is_supported = true; + break; + } + } + if (!is_supported) { + fprintf(stderr, " - %3d: %s\n", static_cast(required_feature), get_dynamic_enum_name(required_feature).c_str()); + } + } + fprintf(stderr, "\n"); + Arch cur_gpu_arch = Arch(); + fprintf(stderr, "Current GPU: %s, SM %d.%d with %d SMs\n", cur_gpu_arch.device_prop->name, cur_gpu_arch.major, cur_gpu_arch.minor, cur_gpu_arch.num_sms); + fprintf(stderr, "This means that the dispatcher has chosen an implementation that does not support all required features. Maybe there is a bug in the dispatcher, or you have requested an invalid combination of features.\n"); + TORCH_CHECK(false, "The chosen implementation does not support all required features. 
See message above for details."); + } + } + + inline void run(const RunArgT ¶ms, const std::vector &required_features) { + check_if_all_features_are_supported_and_abort(required_features); + run_(params, required_features); + } +}; + diff --git a/csrc/api/dense_decode.h b/csrc/api/dense_decode.h new file mode 100644 index 0000000..7df178a --- /dev/null +++ b/csrc/api/dense_decode.h @@ -0,0 +1,225 @@ +#pragma once + +#include +#include + +#include "common.h" +#include "params.h" + +#include "sm90/decode/dense/splitkv_mla.h" +#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" +#include "smxx/decode/combine/combine.h" + +static std::tuple, std::optional> +dense_attn_decode_interface( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True) + const int head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const float softmax_scale, + bool is_causal, + std::optional &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4) + std::optional &num_splits // batch_size + 1 +) { + // Check arch + Arch arch = Arch(); + if (!arch.is_sm90a()) { + TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture"); + } + + // Check data types + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); + + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + + // Check device + KU_CHECK_DEVICE(q); + KU_CHECK_DEVICE(kcache); + KU_CHECK_DEVICE(seqlens_k); + KU_CHECK_DEVICE(block_table); + 
KU_CHECK_DEVICE(tile_scheduler_metadata); + KU_CHECK_DEVICE(num_splits); + + // Check layout + TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension"); + KU_CHECK_CONTIGUOUS(seqlens_k); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); + KU_CHECK_CONTIGUOUS(num_splits); + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + const int seqlen_q_ori = sizes[1]; + const int num_heads_q = sizes[2]; + const int head_size_k = sizes[3]; + TORCH_CHECK(head_size_k == 576 || head_size_k == 512, "Only head_size_k == 576 or 512 is supported"); + TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported"); + + const int max_num_blocks_per_seq = block_table.size(1); + const int num_blocks = kcache.size(0); + const int page_block_size = kcache.size(1); + const int num_heads_k = kcache.size(2); + TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (seqlen_q_ori == 1) { is_causal = false; } + + const int num_q_heads_per_hk = num_heads_q / num_heads_k; + const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; + const int num_heads = num_heads_k; + q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) + .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); + int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 64), 1); + + KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); + KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); + KU_CHECK_SHAPE(seqlens_k, batch_size); + KU_CHECK_SHAPE(block_table, 
batch_size, max_num_blocks_per_seq); + KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int)); + KU_CHECK_SHAPE(num_splits, batch_size+1); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts); + at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + KU_CHECK_CONTIGUOUS(out); + KU_CHECK_CONTIGUOUS(lse); + + if (!tile_scheduler_metadata.has_value()) { + tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32)); + num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32)); + KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); + KU_CHECK_CONTIGUOUS(num_splits); + + GetDecodeSchedMetaParams get_sched_meta_params = { + batch_size, seqlen_q_ori, + 64, + 5, + -1, -1, + nullptr, nullptr, + seqlens_k.data_ptr(), + (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(), + num_splits->data_ptr(), + num_sm_parts, + at::cuda::getCurrentCUDAStream().stream() + }; + smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); + } else { + KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); + KU_CHECK_DTYPE(num_splits, torch::kInt32); + KU_CHECK_DEVICE(tile_scheduler_metadata); + KU_CHECK_DEVICE(num_splits); + KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); + KU_CHECK_CONTIGUOUS(num_splits); + KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int)); + KU_CHECK_SHAPE(num_splits, batch_size+1); + } + + // Set the sizes + DenseAttnDecodeParams params; + params.b = batch_size; + params.s_q = seqlen_q_ori; + params.q_seq_per_hk = q_seq_per_hk; + params.seqlens_k_ptr = seqlens_k.data_ptr(); + params.h_q = num_heads_q; + params.h_k = num_heads_k; + params.num_blocks = num_blocks; + params.q_head_per_hk = num_q_heads_per_hk; + params.is_causal = is_causal; + params.d = head_size_k; + 
params.d_v = head_size_v; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = kcache.data_ptr(); + params.o_ptr = out.data_ptr(); + params.softmax_lse_ptr = lse.data_ptr(); + // All stride are in elements, not bytes. + params.q_batch_stride = q.stride(0); + params.k_batch_stride = kcache.stride(0); + params.o_batch_stride = out.stride(0); + params.q_row_stride = q.stride(1); + params.k_row_stride = kcache.stride(1); + params.o_row_stride = out.stride(2); + params.q_head_stride = q.stride(2); + params.k_head_stride = kcache.stride(2); + params.o_head_stride = out.stride(1); + + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.page_block_size = page_block_size; + + params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(); + params.num_sm_parts = num_sm_parts; + params.num_splits_ptr = num_splits->data_ptr(); + + const int total_num_splits = batch_size + params.num_sm_parts; + at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); + KU_CHECK_CONTIGUOUS(lse_accum); + KU_CHECK_CONTIGUOUS(out_accum); + params.total_num_splits = total_num_splits; + params.softmax_lseaccum_ptr = lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + + params.stream = at::cuda::getCurrentCUDAStream().stream(); + + if (q_dtype == torch::kBFloat16) { + sm90::run_flash_splitkv_mla_kernel(params); + } else if (q_dtype == torch::kHalf) { +#ifdef FLASH_MLA_DISABLE_FP16 + TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. 
Please remove this flag from your environment and re-compile FlashMLA."); +#else + sm90::run_flash_splitkv_mla_kernel(params); +#endif + } else { + TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); + } + + CombineParams combine_params = { + batch_size, seqlen_q_ori, + num_heads_q, head_size_v, + + params.softmax_lse_ptr, + params.o_ptr, + num_heads*q_seq_per_hk, num_heads_q, + num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v, + + params.softmax_lseaccum_ptr, + params.oaccum_ptr, + num_heads*q_seq_per_hk, num_heads_q, + num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v, + + params.tile_scheduler_metadata_ptr, + params.num_splits_ptr, + params.num_sm_parts, + + nullptr, + at::cuda::getCurrentCUDAStream().stream() + }; + + if (q_dtype == torch::kBFloat16) { + smxx::decode::run_flash_mla_combine_kernel(combine_params); + } else if (q_dtype == torch::kHalf) { +#ifndef FLASH_MLA_DISABLE_FP16 + smxx::decode::run_flash_mla_combine_kernel(combine_params); +#endif + } else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } + + out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2) + .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); + lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) + .reshape({batch_size, num_heads_q, seqlen_q_ori}); + + return {out, lse, tile_scheduler_metadata, num_splits}; +} diff --git a/csrc/api/dense_fwd.h b/csrc/api/dense_fwd.h new file mode 100644 index 0000000..c5a4acf --- /dev/null +++ b/csrc/api/dense_fwd.h @@ -0,0 +1,5 @@ +#pragma once + +#include "common.h" + +#include "sm100/prefill/dense/interface.h" diff --git a/csrc/api/sparse_decode.h b/csrc/api/sparse_decode.h new file mode 100644 index 0000000..6df5c84 --- /dev/null +++ b/csrc/api/sparse_decode.h @@ -0,0 +1,495 @@ +#pragma once + +#include "common.h" + +#include "params.h" + +#include 
"sm90/decode/sparse_fp8/splitkv_mla.h" +#include "sm100/decode/head64/kernel.h" +#include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h" +#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" +#include "smxx/decode/combine/combine.h" + +// Feature set of sparse decoding kernels +enum class DecodeFeatures : int { + HEAD_64, + HEAD_128, + + HEAD_DIM_576, + HEAD_DIM_512, + + V32_KVCACHE_FORMAT, + MODEL1_KVCACHE_FORMAT, + + ATTN_SINK, + TOPK_LENGTH, + EXTRA_KVCACHE, + EXTRA_TOPK_LENGTH +}; + +struct DecodeImplMeta { + int num_sm_parts; + int fixed_overhead_num_blocks; + int block_size_topk; +}; + +class DecodeImplBase : public ImplBase< + SparseAttnDecodeParams, + DecodeFeatures +> { +public: + virtual DecodeImplMeta get_meta(int h_q, int s_q) = 0; +}; + +class Decode_Sm90_Impl : public DecodeImplBase { + DECLARE_SUPPORTED_FEATURES( + DecodeFeatures::HEAD_64, + DecodeFeatures::HEAD_128, + DecodeFeatures::HEAD_DIM_512, + DecodeFeatures::HEAD_DIM_576, + DecodeFeatures::V32_KVCACHE_FORMAT, + DecodeFeatures::MODEL1_KVCACHE_FORMAT, + DecodeFeatures::ATTN_SINK, + DecodeFeatures::TOPK_LENGTH, + DecodeFeatures::EXTRA_KVCACHE, + DecodeFeatures::EXTRA_TOPK_LENGTH + ) + +public: + DecodeImplMeta get_meta(int h_q, int s_q) override { + Arch arch = Arch(); + return { + std::max(arch.num_sms / s_q / (h_q/64), 1), + 5, + 64 + }; + } + +protected: + void run_(const SparseAttnDecodeParams ¶ms, const std::vector &required_features) override { + DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() { + DISPATCH_NUM_HEADS(params.h_q, NUM_HEADS, [&]() { + sm90::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel(params); + }); + }); + } +}; + +class Decode_Sm100_Head64_Impl : public DecodeImplBase { + DECLARE_SUPPORTED_FEATURES( + DecodeFeatures::HEAD_64, + DecodeFeatures::HEAD_DIM_512, + DecodeFeatures::HEAD_DIM_576, + DecodeFeatures::V32_KVCACHE_FORMAT, + DecodeFeatures::MODEL1_KVCACHE_FORMAT, + DecodeFeatures::ATTN_SINK, + 
DecodeFeatures::TOPK_LENGTH, + DecodeFeatures::EXTRA_KVCACHE, + DecodeFeatures::EXTRA_TOPK_LENGTH + ) + +public: + DecodeImplMeta get_meta(int h_q, int s_q) override { + Arch arch = Arch(); + return { + std::max(arch.num_sms / s_q, 1), + 5, + 64 + }; + } + +protected: + void run_(const SparseAttnDecodeParams ¶ms, const std::vector &required_features) override { + DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() { + sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel(params); + }); + } +}; + + +// An implementation that calls the head64 kernel twice to process head128 +// Necessary for running V3.2 shape (i.e. h = 128, d_qk = 576) on SM100f +class Decode_Sm100_Head64x2_Impl : public DecodeImplBase { + DECLARE_SUPPORTED_FEATURES( + DecodeFeatures::HEAD_128, + DecodeFeatures::HEAD_DIM_512, + DecodeFeatures::HEAD_DIM_576, + DecodeFeatures::V32_KVCACHE_FORMAT, + DecodeFeatures::MODEL1_KVCACHE_FORMAT, + DecodeFeatures::ATTN_SINK, + DecodeFeatures::TOPK_LENGTH, + DecodeFeatures::EXTRA_KVCACHE, + DecodeFeatures::EXTRA_TOPK_LENGTH + ) + +public: + DecodeImplMeta get_meta(int h_q, int s_q) override { + Arch arch = Arch(); + return { + std::max(arch.num_sms / s_q, 1), + 5, + 64 + }; + } + +protected: + void run_(const SparseAttnDecodeParams ¶ms, const std::vector &required_features) override { + DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() { + for (int start_head_idx = 0; start_head_idx < 128; start_head_idx += 64) { + SparseAttnDecodeParams cur_params = params; + cur_params.q += start_head_idx * params.stride_q_h_q; + if (cur_params.attn_sink) { + cur_params.attn_sink += start_head_idx; + } + cur_params.lse += start_head_idx; + cur_params.out += start_head_idx * params.stride_o_h_q; + cur_params.lse_accum += start_head_idx; + cur_params.o_accum += start_head_idx * params.stride_o_accum_h_q; + cur_params.h_q = 64; + sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel(cur_params); + } + }); + } +}; + + +class Decode_Sm100_Head128_Impl 
: public DecodeImplBase { + DECLARE_SUPPORTED_FEATURES( + DecodeFeatures::HEAD_128, + DecodeFeatures::HEAD_DIM_512, + DecodeFeatures::MODEL1_KVCACHE_FORMAT, + DecodeFeatures::ATTN_SINK, + DecodeFeatures::TOPK_LENGTH, + DecodeFeatures::EXTRA_KVCACHE, + DecodeFeatures::EXTRA_TOPK_LENGTH + ) + +public: + DecodeImplMeta get_meta(int h_q, int s_q) override { + Arch arch = Arch(); + return { + std::max(arch.num_sms / s_q / 2, 1), + 3, + 64 + }; + } + +protected: + void run_(const SparseAttnDecodeParams ¶ms, const std::vector &required_features) override { + sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel(params); + } +}; + +static std::tuple, std::optional> +sparse_attn_decode_interface( + const at::Tensor &q, // [b, s_q, h_q, d_qk] + const at::Tensor &kv, // [num_blocks, page_block_size, h_k, d_qk] + const at::Tensor &indices, // [b, s_q, topk] + const std::optional &topk_length, // [b, s_q] + const std::optional &attn_sink, // [h_q] + std::optional &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4) + std::optional &num_splits, // batch_size + 1 + const std::optional &extra_kv, + const std::optional &extra_indices, + const std::optional &extra_topk_length, + int d_v, + float sm_scale +) { + using bf16 = cutlass::bfloat16_t; + + // Check the architecture + Arch arch = Arch(); + + KU_CHECK_NDIM(q, 4); + KU_CHECK_NDIM(kv, 4); + KU_CHECK_NDIM(indices, 3); + + int b = q.size(0); + int s_q = q.size(1); + int h_q = q.size(2); + int d_qk = q.size(3); + int num_blocks = kv.size(0); + int page_block_size = kv.size(1); + int h_kv = kv.size(2); + int topk = indices.size(2); + + bool have_topk_length = topk_length.has_value(); + bool have_extra_kcache = extra_kv.has_value(); + bool have_extra_topk_length = extra_topk_length.has_value(); + bool have_attn_sink = attn_sink.has_value(); + + int extra_num_blocks = 0, extra_page_block_size = 0, extra_topk = 0; + if (have_extra_kcache) { + extra_num_blocks = extra_kv->size(0); + 
extra_page_block_size = extra_kv->size(1); + } + if (extra_indices.has_value()) { + extra_topk = extra_indices->size(-1); + } + + // metadata sanity check + TORCH_CHECK(b > 0); + TORCH_CHECK(s_q > 0); + TORCH_CHECK(h_q > 0); + TORCH_CHECK(h_kv == 1, "Currently only MQA (i.e. h_kv == 1) is supported for sparse decoding"); + TORCH_CHECK(d_qk == 576 || d_qk == 512, "Only head_size_k == 576 or 512 is supported for sparse decoding"); + TORCH_CHECK(d_v == 512, "Only head_size_v == 512 is supported for sparse decoding"); + TORCH_CHECK(topk > 0); + + if (have_extra_kcache) { + TORCH_CHECK(extra_indices.has_value(), "extra_indices_in_kvcache must be provided when extra_kcache is provided for sparse attention"); + } else { + TORCH_CHECK(!extra_indices.has_value(), "extra_indices_in_kvcache must not be provided when extra_k_cache is not provided"); + TORCH_CHECK(!extra_topk_length.has_value(), "extra_topk_length must not be provided when extra_k_cache is not provided"); + } + + // Check device + KU_CHECK_DEVICE(q); + KU_CHECK_DEVICE(kv); + KU_CHECK_DEVICE(indices); + KU_CHECK_DEVICE(topk_length); + KU_CHECK_DEVICE(attn_sink); + KU_CHECK_DEVICE(tile_scheduler_metadata); + KU_CHECK_DEVICE(num_splits); + KU_CHECK_DEVICE(extra_kv); + KU_CHECK_DEVICE(extra_indices); + KU_CHECK_DEVICE(extra_topk_length); + + // Check data type + KU_CHECK_DTYPE(q, torch::kBFloat16); + TORCH_CHECK(kv.dtype() == torch::kFloat8_e4m3fn || kv.dtype() == torch::kInt8 || kv.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn, int8 or uint8"); + if (extra_kv.has_value()) { + TORCH_CHECK(extra_kv->dtype() == torch::kFloat8_e4m3fn || extra_kv->dtype() == torch::kInt8 || extra_kv->dtype() == torch::kUInt8, "extra k cache must have dtype fp8_e4m3fn, int8 or uint8"); + } + KU_CHECK_DTYPE(indices, torch::kInt32); + KU_CHECK_DTYPE(topk_length, torch::kInt32); + KU_CHECK_DTYPE(attn_sink, torch::kFloat32); + KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); + KU_CHECK_DTYPE(num_splits, torch::kInt32); + 
KU_CHECK_DTYPE(extra_indices, torch::kInt32); + KU_CHECK_DTYPE(extra_topk_length, torch::kInt32); + + // Check layout + KU_CHECK_LAST_DIM_CONTIGUOUS(q); + KU_CHECK_LAST_DIM_CONTIGUOUS(kv); + KU_CHECK_LAST_DIM_CONTIGUOUS(indices); + KU_CHECK_CONTIGUOUS(topk_length); + KU_CHECK_CONTIGUOUS(attn_sink); + + KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); + KU_CHECK_CONTIGUOUS(num_splits); + + KU_CHECK_LAST_DIM_CONTIGUOUS(extra_kv); + KU_CHECK_LAST_DIM_CONTIGUOUS(extra_indices); + KU_CHECK_CONTIGUOUS(extra_topk_length); + + // Check shape + KU_CHECK_SHAPE(q, b, s_q, h_q, d_qk); + { + int bytes_per_token; + if (d_qk == 576 && d_v == 512) { + // V3.2 style + bytes_per_token = 512 + 64*2 + (512/128)*4; + } else if (d_qk == 512 && d_v == 512) { + // MODEL1 style + bytes_per_token = 448 + 64*2 + (448/64)*1 + 1; + } else { + TORCH_CHECK(false, "Unsupported head sizes for is_fp8_kvcache == True"); + } + KU_CHECK_SHAPE(kv, num_blocks, page_block_size, h_kv, bytes_per_token); + KU_CHECK_SHAPE(extra_kv, extra_num_blocks, extra_page_block_size, h_kv, bytes_per_token); + TORCH_CHECK(kv.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True for kv cache"); + if (extra_kv.has_value()) { + TORCH_CHECK(extra_kv->stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True for extra kv cache"); + } + } + KU_CHECK_SHAPE(indices, b, s_q, topk); + KU_CHECK_SHAPE(topk_length, b); + KU_CHECK_SHAPE(attn_sink, h_q); + KU_CHECK_SHAPE(extra_indices, b, s_q, extra_topk); + KU_CHECK_SHAPE(extra_topk_length, b); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + auto opts = q.options(); + + at::Tensor out = torch::empty({b, s_q, h_q, d_v}, opts); + at::Tensor lse = torch::empty({b, s_q, h_q}, opts.dtype(at::kFloat)); + + ModelType model_type; + if (d_qk == 576) { + model_type = ModelType::V32; + } else if (d_qk == 512) { + model_type = ModelType::MODEL1; + } else { + TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); + } + + 
std::vector features; + if (h_q == 64) { + features.push_back(DecodeFeatures::HEAD_64); + } else if (h_q == 128) { + features.push_back(DecodeFeatures::HEAD_128); + } else { + TORCH_CHECK(false, "Unsupported h_q: ", h_q); + } + if (d_qk == 576) { + features.push_back(DecodeFeatures::HEAD_DIM_576); + } else if (d_qk == 512) { + features.push_back(DecodeFeatures::HEAD_DIM_512); + } else { + TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); + } + if (model_type == ModelType::V32) { + features.push_back(DecodeFeatures::V32_KVCACHE_FORMAT); + } else if (model_type == ModelType::MODEL1) { + features.push_back(DecodeFeatures::MODEL1_KVCACHE_FORMAT); + } else { + TORCH_CHECK(false, "Unsupported model type: ", (int)model_type); + } + if (have_attn_sink) { + features.push_back(DecodeFeatures::ATTN_SINK); + } + if (have_topk_length) { + features.push_back(DecodeFeatures::TOPK_LENGTH); + } + if (have_extra_kcache) { + features.push_back(DecodeFeatures::EXTRA_KVCACHE); + } + if (have_extra_topk_length) { + features.push_back(DecodeFeatures::EXTRA_TOPK_LENGTH); + } + + DecodeImplBase* impl; + if (arch.is_sm100f()) { + if (h_q == 64) { + impl = new Decode_Sm100_Head64_Impl(); + } else if (h_q == 128) { + if (d_qk == 576) { + impl = new Decode_Sm100_Head64x2_Impl(); + } else if (d_qk == 512) { + impl = new Decode_Sm100_Head128_Impl(); + } else { + TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); + } + } else { + TORCH_CHECK(false, "Unsupported h_q: ", h_q); + } + } else if (arch.is_sm90a()) { + impl = new Decode_Sm90_Impl(); + } else { + TORCH_CHECK(false, "Unsupported architecture for sparse decode fwd"); + } + + DecodeImplMeta impl_meta = impl->get_meta(h_q, s_q); + + SparseAttnDecodeParams params = { + b, s_q, h_q, h_kv, d_qk, d_v, + sm_scale, sm_scale * LOG_2_E, + num_blocks, page_block_size, topk, + model_type, + + (bf16*)q.data_ptr(), + (bf16*)kv.data_ptr(), + (int*)indices.data_ptr(), + ku::get_optional_tensor_ptr(topk_length), + ku::get_optional_tensor_ptr(attn_sink), + 
(float*)lse.data_ptr(), + (bf16*)out.data_ptr(), + + extra_num_blocks, extra_page_block_size, extra_topk, + ku::get_optional_tensor_ptr(extra_kv), + ku::get_optional_tensor_ptr(extra_indices), + ku::get_optional_tensor_ptr(extra_topk_length), + + int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), int64_stride_to_int(q.stride(2)), + int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)), + int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)), + int64_stride_to_int(lse.stride(0)), int64_stride_to_int(lse.stride(1)), + int64_stride_to_int(out.stride(0)), int64_stride_to_int(out.stride(1)), int64_stride_to_int(out.stride(2)), + + have_extra_kcache ? int64_stride_to_int(extra_kv->stride(0)) : 0, + have_extra_kcache ? int64_stride_to_int(extra_kv->stride(1)) : 0, + have_extra_kcache ? int64_stride_to_int(extra_indices->stride(0)) : 0, + have_extra_kcache ? int64_stride_to_int(extra_indices->stride(1)) : 0, + at::cuda::getCurrentCUDAStream().stream() + }; + + // Get MLA metadata if necessary + at::Tensor o_accum, lse_accum; + if (!tile_scheduler_metadata.has_value()) { + tile_scheduler_metadata = torch::empty({impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32)); + num_splits = torch::empty({b+1}, opts.dtype(torch::kInt32)); + KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); + KU_CHECK_CONTIGUOUS(num_splits); + + GetDecodeSchedMetaParams get_sched_meta_params = { + b, s_q, + impl_meta.block_size_topk, + impl_meta.fixed_overhead_num_blocks, + topk, + extra_topk, + ku::get_optional_tensor_ptr(topk_length), + ku::get_optional_tensor_ptr(extra_topk_length), + nullptr, + (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(), + num_splits->data_ptr(), + impl_meta.num_sm_parts, + at::cuda::getCurrentCUDAStream().stream() + }; + smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); + } + // Stick the metadata pointers to `params` + KU_CHECK_DEVICE(tile_scheduler_metadata); + 
KU_CHECK_DEVICE(num_splits); + KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32); + KU_CHECK_DTYPE(num_splits, torch::kInt32); + KU_CHECK_CONTIGUOUS(tile_scheduler_metadata); + KU_CHECK_CONTIGUOUS(num_splits); + KU_CHECK_SHAPE(tile_scheduler_metadata, impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int)); + KU_CHECK_SHAPE(num_splits, b+1); + params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(); + params.num_splits_ptr = num_splits->data_ptr(); + params.num_sm_parts = impl_meta.num_sm_parts; + + // Allocate intermediate buffers for split-KV + const int total_num_splits = b + impl_meta.num_sm_parts; + lse_accum = torch::empty({total_num_splits, s_q, h_q}, opts.dtype(at::kFloat)); + o_accum = torch::empty({total_num_splits, s_q, h_q, d_v}, opts.dtype(at::kFloat)); + KU_CHECK_CONTIGUOUS(lse_accum); + KU_CHECK_CONTIGUOUS(o_accum); + params.lse_accum = lse_accum.data_ptr(); + params.o_accum = o_accum.data_ptr(); + params.stride_lse_accum_split = int64_stride_to_int(lse_accum.stride(0)); + params.stride_lse_accum_s_q = int64_stride_to_int(lse_accum.stride(1)); + params.stride_o_accum_split = int64_stride_to_int(o_accum.stride(0)); + params.stride_o_accum_s_q = int64_stride_to_int(o_accum.stride(1)); + params.stride_o_accum_h_q = int64_stride_to_int(o_accum.stride(2)); + + impl->run(params, features); + + CombineParams combine_params = { + b, s_q, h_q, d_v, + + params.lse, + params.out, + params.stride_lse_b, params.stride_lse_s_q, + params.stride_o_b, params.stride_o_s_q, params.stride_o_h_q, + + params.lse_accum, + params.o_accum, + params.stride_lse_accum_split, params.stride_lse_accum_s_q, + params.stride_o_accum_split, params.stride_o_accum_s_q, params.stride_o_accum_h_q, + + params.tile_scheduler_metadata_ptr, + params.num_splits_ptr, + params.num_sm_parts, + + ku::get_optional_tensor_ptr(attn_sink), + at::cuda::getCurrentCUDAStream().stream() + }; + smxx::decode::run_flash_mla_combine_kernel(combine_params); + + 
delete impl; + + return {out, lse.transpose(1, 2), tile_scheduler_metadata, num_splits}; +} diff --git a/csrc/api/sparse_fwd.h b/csrc/api/sparse_fwd.h new file mode 100644 index 0000000..66d7111 --- /dev/null +++ b/csrc/api/sparse_fwd.h @@ -0,0 +1,243 @@ +#pragma once + +#include "common.h" + +#include "params.h" + +#include "sm90/prefill/sparse/phase1.h" +#include "sm100/prefill/sparse/fwd/head128/phase1.h" +#include "sm100/prefill/sparse/fwd/head64/phase1.h" +#include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h" + +enum class FwdFeatures : int { + HEAD_64, + HEAD_128, + + HEAD_DIM_576, + HEAD_DIM_512, + + ATTN_SINK, + SINK_LSE, + TOPK_LENGTH +}; + +class FwdImplBase : public ImplBase< + SparseAttnFwdParams, + FwdFeatures +> {}; + +class Fwd_Sm90_Impl : public FwdImplBase { + DECLARE_SUPPORTED_FEATURES( + FwdFeatures::HEAD_64, + FwdFeatures::HEAD_128, + FwdFeatures::HEAD_DIM_512, + FwdFeatures::HEAD_DIM_576, + FwdFeatures::ATTN_SINK, + FwdFeatures::SINK_LSE, + FwdFeatures::TOPK_LENGTH + ) + +protected: + void run_(const SparseAttnFwdParams ¶ms, const std::vector &required_features) override { + DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { + DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() { + sm90::fwd::run_fwd_phase1_kernel(params); + }); + }); + } +}; + +class Fwd_Sm100_Head64_Impl : public FwdImplBase { + DECLARE_SUPPORTED_FEATURES( + FwdFeatures::HEAD_64, + FwdFeatures::HEAD_DIM_512, + FwdFeatures::HEAD_DIM_576, + FwdFeatures::ATTN_SINK, + FwdFeatures::SINK_LSE, + FwdFeatures::TOPK_LENGTH + ) + +protected: + void run_(const SparseAttnFwdParams ¶ms, const std::vector &required_features) override { + DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { + sm100::fwd::head64::run_fwd_phase1_kernel(params); + }); + } +}; + +class Fwd_Sm100_Head128_Impl : public FwdImplBase { + DECLARE_SUPPORTED_FEATURES( + FwdFeatures::HEAD_128, + FwdFeatures::HEAD_DIM_512, + FwdFeatures::HEAD_DIM_576, + FwdFeatures::ATTN_SINK, + 
FwdFeatures::SINK_LSE, + FwdFeatures::TOPK_LENGTH + ) + +protected: + void run_(const SparseAttnFwdParams &params, const std::vector &required_features) override { + DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { + sm100::fwd::head128::run_fwd_phase1_kernel(params); + }); + } +}; + +class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase { + DECLARE_SUPPORTED_FEATURES( + FwdFeatures::HEAD_128, + FwdFeatures::HEAD_DIM_512, + FwdFeatures::ATTN_SINK, + FwdFeatures::SINK_LSE, + FwdFeatures::TOPK_LENGTH + ) + +protected: + void run_(const SparseAttnFwdParams &params, const std::vector &required_features) override { + sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel(params); + } +}; + +static std::vector sparse_attn_prefill_interface( + const at::Tensor &q, + const at::Tensor &kv, + const at::Tensor &indices, + float sm_scale, + int d_v, + const std::optional &attn_sink, + const std::optional &topk_length +) { + using bf16 = cutlass::bfloat16_t; + + Arch arch = Arch(); + bool is_sm90a = arch.is_sm90a(); + bool is_sm100f = arch.is_sm100f(); + TORCH_CHECK(is_sm90a || is_sm100f, "Sparse Attention Forward Kernel is only supported on SM90a and SM100f architectures."); + + KU_CHECK_NDIM(q, 3); + KU_CHECK_NDIM(kv, 3); + KU_CHECK_NDIM(indices, 3); + KU_CHECK_NDIM(attn_sink, 1); + KU_CHECK_NDIM(topk_length, 1); + + int s_q = q.size(0); + int s_kv = kv.size(0); + int h_q = q.size(1); + int h_kv = kv.size(1); + int d_qk = q.size(2); + int topk = indices.size(2); + bool have_topk_length = topk_length.has_value(); + + TORCH_CHECK(d_qk == 576 || d_qk == 512, "Invalid d_qk: ", d_qk); + TORCH_CHECK(d_v == 512, "Invalid d_v: ", d_v); + + KU_CHECK_DEVICE(q); + KU_CHECK_DEVICE(kv); + KU_CHECK_DEVICE(indices); + KU_CHECK_DEVICE(attn_sink); + KU_CHECK_DEVICE(topk_length); + + KU_CHECK_DTYPE(q, torch::kBFloat16); + KU_CHECK_DTYPE(kv, torch::kBFloat16); + KU_CHECK_DTYPE(indices, torch::kInt32); + KU_CHECK_DTYPE(attn_sink, torch::kFloat32); + KU_CHECK_DTYPE(topk_length, 
torch::kInt32); + + KU_CHECK_SHAPE(q, s_q, h_q, d_qk); + KU_CHECK_SHAPE(kv, s_kv, h_kv, d_qk); + KU_CHECK_SHAPE(indices, s_q, h_kv, topk); + KU_CHECK_SHAPE(attn_sink, h_q); + KU_CHECK_SHAPE(topk_length, s_q); + + KU_CHECK_LAST_DIM_CONTIGUOUS(q); + KU_CHECK_LAST_DIM_CONTIGUOUS(kv); + KU_CHECK_LAST_DIM_CONTIGUOUS(indices); + KU_CHECK_LAST_DIM_CONTIGUOUS(attn_sink); + KU_CHECK_LAST_DIM_CONTIGUOUS(topk_length); + + // Allocate results and buffers + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + auto opts = q.options(); + + at::Tensor out = torch::empty({s_q, h_q, d_v}, opts); + at::Tensor lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); + at::Tensor max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); + KU_CHECK_CONTIGUOUS(out); + KU_CHECK_CONTIGUOUS(lse); + KU_CHECK_CONTIGUOUS(max_logits); + + SparseAttnFwdParams params = { + s_q, s_kv, h_q, h_kv, d_qk, d_v, topk, + sm_scale, sm_scale * LOG_2_E, + + (bf16*)q.data_ptr(), + (bf16*)kv.data_ptr(), + (int*)indices.data_ptr(), + ku::get_optional_tensor_ptr(attn_sink), + ku::get_optional_tensor_ptr(topk_length), + + int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), + int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)), + int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)), + + (bf16*)out.data_ptr(), + (float*)max_logits.data_ptr(), + (float*)lse.data_ptr(), + + arch.num_sms, + at::cuda::getCurrentCUDAStream().stream() + }; + + std::vector required_features; + if (h_q == 64) { + required_features.push_back(FwdFeatures::HEAD_64); + } else if (h_q == 128) { + required_features.push_back(FwdFeatures::HEAD_128); + } else { + TORCH_CHECK(false, "Unsupported h_q: ", h_q); + } + if (d_qk == 576) { + required_features.push_back(FwdFeatures::HEAD_DIM_576); + } else if (d_qk == 512) { + required_features.push_back(FwdFeatures::HEAD_DIM_512); + } else { + TORCH_CHECK(false, "Unsupported d_qk: ", d_qk); + } + if (attn_sink.has_value()) { + 
required_features.push_back(FwdFeatures::ATTN_SINK); + } + if (have_topk_length) { + required_features.push_back(FwdFeatures::TOPK_LENGTH); + } + + if (is_sm90a) { + Fwd_Sm90_Impl fwd_impl; + fwd_impl.run(params, required_features); + } else if (is_sm100f) { + if (h_q == 64) { + Fwd_Sm100_Head64_Impl fwd_impl; + fwd_impl.run(params, required_features); + } else if (h_q == 128) { + Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl; + Fwd_Sm100_Head128_Impl regular_impl; + bool use_small_topk_impl = false; + if ( + (topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) || + !regular_impl.check_if_all_features_are_supported(required_features) + ) { + use_small_topk_impl = true; + } + if (use_small_topk_impl) { + small_topk_impl.run(params, required_features); + } else { + regular_impl.run(params, required_features); + } + } else { + TORCH_CHECK(false, "Unsupported h_q: ", h_q); + } + } else { + TORCH_CHECK(false, "Unsupported architecture"); + } + + return {out, max_logits, lse}; +} diff --git a/csrc/cutlass b/csrc/cutlass index e94e888..147f567 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit e94e888df3551224738bfa505787b515eae8352f +Subproject commit 147f5673d0c1c3dcf66f78d677fd647e4a020219 diff --git a/csrc/sm100/defines.h b/csrc/defines.h similarity index 96% rename from csrc/sm100/defines.h rename to csrc/defines.h index 0e779a3..91b96f0 100644 --- a/csrc/sm100/defines.h +++ b/csrc/defines.h @@ -3,8 +3,6 @@ #include #include -namespace sm100 { - using bf16 = cutlass::bfloat16_t; using fp8 = cutlass::float_e4m3_t; using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; @@ -26,5 +24,3 @@ struct bf16x8 { __nv_bfloat162 a45; __nv_bfloat162 a67; }; - -} diff --git a/csrc/kerutils/include/kerutils/common/common.h b/csrc/kerutils/include/kerutils/common/common.h new file mode 100644 index 0000000..92459b7 --- /dev/null +++ b/csrc/kerutils/include/kerutils/common/common.h @@ -0,0 +1,8 @@ +#pragma once + 
+namespace kerutils {} + +#define KU_PRINTLN(fmt, ...) { cute::print(fmt, ##__VA_ARGS__); cute::print("\n"); } + +namespace ku = kerutils; + diff --git a/csrc/kerutils/include/kerutils/device/common.h b/csrc/kerutils/include/kerutils/device/common.h new file mode 100644 index 0000000..d13e7a9 --- /dev/null +++ b/csrc/kerutils/include/kerutils/device/common.h @@ -0,0 +1,70 @@ +/* +Common data types and macros that are used across the kerutils library. +*/ +#pragma once + +#include +#include + +#include +#include +#include // For CUTE_DEVICE + +namespace kerutils { + +// Cache hints +enum class CacheHint { + EVICT_FIRST, + EVICT_NORMAL, + EVICT_LAST, + EVICT_UNCHANGED, + NO_ALLOCATE +}; + +// Prefetch size +enum class PrefetchSize { + B64, + B128, + B256 +}; + +using nvbf16 = __nv_bfloat16; +using nvbf16x2 = __nv_bfloat162; +using nve4m3 = __nv_fp8_e4m3; +using nve4m3x2 = __nv_fp8x2_e4m3; +using nve4m3x4 = __nv_fp8x4_e4m3; + +using bf16 = cutlass::bfloat16_t; +using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; + +} + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define KERUTILS_ENABLE_SM80 +#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) +static_assert(false, "kerutils doesn't support SM architectures below SM80"); +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#define KERUTILS_ENABLE_SM90 +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000)) +#define KERUTILS_ENABLE_SM90A +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) +#define KERUTILS_ENABLE_SM100 +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) +#define KERUTILS_ENABLE_SM100A +#endif + +#if (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) +#define KERUTILS_ENABLE_SM80 +#define KERUTILS_ENABLE_SM90 +#define KERUTILS_ENABLE_SM90A +#define KERUTILS_ENABLE_SM100 +#define KERUTILS_ENABLE_SM100A +#endif diff --git a/csrc/kerutils/include/kerutils/device/device.cuh 
b/csrc/kerutils/include/kerutils/device/device.cuh new file mode 100644 index 0000000..1673621 --- /dev/null +++ b/csrc/kerutils/include/kerutils/device/device.cuh @@ -0,0 +1,13 @@ +#pragma once + +#include "kerutils/common/common.h" + +#include "common.h" +#include "sm80/intrinsics.cuh" +#include "sm80/helpers.cuh" +#include "sm90/intrinsics.cuh" +#include "sm90/helpers.cuh" +#include "sm100/intrinsics.cuh" +#include "sm100/helpers.cuh" +#include "sm100/gemm.cuh" +#include "sm100/tma_cta_group2_nosplit.cuh" diff --git a/csrc/sm100/ws_gemm.h b/csrc/kerutils/include/kerutils/device/sm100/gemm.cuh similarity index 65% rename from csrc/sm100/ws_gemm.h rename to csrc/kerutils/include/kerutils/device/sm100/gemm.cuh index 54edd3d..8af4edc 100644 --- a/csrc/sm100/ws_gemm.h +++ b/csrc/kerutils/include/kerutils/device/sm100/gemm.cuh @@ -2,19 +2,22 @@ #include +#include + namespace cute { // Extensions to CuTe // CuTe don't support UTCMMA with .ws, so we add it here +// Besides, CuTe's UTCMMA has an `elect_one_sync()` inside which is really disgusting, so we have our own variant without `elect_one_sync()` here template -struct SM100_MMA_F16BF16_WS_SS_NOELECT +struct SM100_MMA_F16BF16_WS_TS_NOELECT { - static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); + static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); static_assert(N == 64 || N == 128 || N == 256, - "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 32, 64 or 128"); + "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 64, 128 or 256"); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -22,7 +25,7 @@ struct SM100_MMA_F16BF16_WS_SS_NOELECT using CRegisters = uint32_t[1]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& tmem_a, uint64_t const& desc_b, uint32_t const& tmem_c, uint32_t const& scaleC, @@ 
-32,32 +35,32 @@ struct SM100_MMA_F16BF16_WS_SS_NOELECT "{\n\t" ".reg .pred p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \n\t" "}\n" : - : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); } }; template -struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types"); - static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types"); - - using FrgTypeA = UMMA::smem_desc; + using FrgTypeA = UMMA::tmem_frg_1sm; // Actually this should be "duplicated", however, our great CuTe doesn't allow us to set it to "duplicated", so we just set it to NonInterleaved for a correct address calculation using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_ws_1sm; - // Logical shape-K is always 256bits, transform to units of elements + // Logical shape-K is always 256 bits; transform to units of elements static constexpr int K = 256 / cute::sizeof_bits::value; using Shape_MNK = Shape,Int,Int>; @@ -69,12 +72,12 @@ struct MMA_Traits,Int>>, Stride<_0,Stride< _1,Int>>>; - UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< - a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); - // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + template const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - uint64_t desc_a = A[0]; + uint32_t tmem_a = raw_pointer_cast(A.data()); uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); - SM100_MMA_F16BF16_WS_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + SM100_MMA_F16BF16_WS_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; -using namespace cute; template -struct SM100_MMA_F16BF16_WS_TS_NOELECT +struct SM100_MMA_F16BF16_WS_SS_NOELECT { - static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); + static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); static_assert(N == 64 || N == 128 || N == 256, - "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 32, 64 or 128"); + "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256"); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -119,7 +122,7 @@ struct SM100_MMA_F16BF16_WS_TS_NOELECT using CRegisters = uint32_t[1]; CUTE_HOST_DEVICE static void - fma(uint32_t const& tmem_a, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t const& tmem_c, uint32_t const& scaleC, @@ -129,32 +132,32 @@ struct SM100_MMA_F16BF16_WS_TS_NOELECT "{\n\t" ".reg .pred 
p;\n\t" "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" "}\n" : - : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); } }; template -struct MMA_Traits> { using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; - static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types"); - using FrgTypeA = UMMA::tmem_frg_1sm; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; - using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; - // Logical shape-K is always 256 bits; transform to units of elements + // Logical shape-K is always 256bits, transform to units of elements static constexpr int K = 256 / cute::sizeof_bits::value; using Shape_MNK = Shape,Int,Int>; @@ -166,12 +169,12 @@ struct MMA_Traits,Int>>, Stride<_0,Stride< _1,Int>>>; - // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] - UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; - UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + template const& C) { static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_a = A[0]; uint64_t desc_b = B[0]; uint32_t tmem_c = raw_pointer_cast(D.data()); uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); - SM100_MMA_F16BF16_WS_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + SM100_MMA_F16BF16_WS_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); } }; @@ -209,7 +211,7 @@ template -// struct MMA_Traits> : MMA_Traits> {}; template +struct SM100_MMA_F16BF16_TS_NOELECT +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_TS_NOELECT M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16_TS_NOELECT N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_TS_NOELECT A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, 
{%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_TS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct SM100_MMA_F16BF16_SS_NOELECT +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_NOELECT M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16_SS_NOELECT N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), 
"r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS_NOELECT supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +} diff --git a/csrc/kerutils/include/kerutils/device/sm100/helpers.cuh b/csrc/kerutils/include/kerutils/device/sm100/helpers.cuh new file mode 100644 index 0000000..b6719de --- /dev/null +++ b/csrc/kerutils/include/kerutils/device/sm100/helpers.cuh @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include "kerutils/device/common.h" + +namespace kerutils { + +// Perform SS UTCMMA +// sA and sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tC_frag should be tmem fragment +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma_ss( + TiledMMA &tiled_mma, + TensorA sA, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + using namespace cute; + tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sA_frag = thr_mma.partition_fragment_A(sA); + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(sA_frag) == size<2>(sB_frag)); + static_assert(size<1>(sA_frag) == size<1>(tC_frag)); + static_assert(size<1>(sB_frag) == size<2>(tC_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(sA_frag); ++k) { + cute::gemm( + tiled_mma, + sA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +// Perform TS UTCMMA +// sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tA_frag and tC_frag should be tmem fragment +template< + typename TiledMMA, + typename TensorA, + typename TensorB, + typename TensorFragC +> +CUTE_DEVICE +void utcmma_ts( + TiledMMA &tiled_mma, + TensorA tA_frag, + TensorB sB, + TensorFragC tC_frag, + bool clear_accum +) { + using namespace cute; + tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter + auto sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(tA_frag) == size<2>(sB_frag)); + CUTE_UNROLL + for (int k = 0; k < size<2>(tA_frag); ++k) { + cute::gemm( + tiled_mma, + tA_frag(_, _, k), + sB_frag(_, _, k), + tC_frag + ); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } +} + +template +static constexpr auto make_umma_canonical_k_major_layout() { + using namespace cute; + using base_atom_type = \ + std::conditional_t, + std::conditional_t, + std::conditional_t, + std::conditional_t, + void + > + > + > + >; + static_assert(!std::is_same_v, "Invalid SWIZZLE value"); + return coalesce(tile_to_shape( + base_atom_type{}, + Shape, Int>{}, + Step<_1, _2>{} + ), Shape<_1, _1>{}); +} + +template +static constexpr auto make_umma_canonical_mn_major_layout() { + using namespace cute; + using base_atom_type = \ + std::conditional_t, + std::conditional_t, + std::conditional_t, + std::conditional_t, + void + > + > + > + >; + static_assert(!std::is_same_v, "Invalid SWIZZLE value"); + return coalesce(tile_to_shape( + base_atom_type{}, + Shape, Int>{}, + Step<_2, _1>{} + ), Shape<_1, _1>{}); +} + +template +auto make_umma_canonical_layout() { + if constexpr (MAJOR == cute::UMMA::Major::K) { + return make_umma_canonical_k_major_layout(); + } else { + return make_umma_canonical_mn_major_layout(); + } +} + +} diff --git a/csrc/kerutils/include/kerutils/device/sm100/intrinsics.cuh b/csrc/kerutils/include/kerutils/device/sm100/intrinsics.cuh new file mode 100644 index 0000000..d6a8b52 --- /dev/null +++ b/csrc/kerutils/include/kerutils/device/sm100/intrinsics.cuh @@ -0,0 +1,382 @@ +#pragma once + +#include "kerutils/device/common.h" + +namespace kerutils { + +// tma gather4 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor) +// 
Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios +CUTE_DEVICE +void tma_gather4(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbar_addr), "l"(cache_hint) + : "memory" + ); +} + +// tma gather4 prefetch (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor) +// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios +CUTE_DEVICE +void tma_gather4_prefetch(const void* desc_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) { + asm volatile( + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint [%0, {%1, %2, %3, %4, %5}], %6;\n" + : + : "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "l"(cache_hint) + ); +} + +// tma gather4 with cta_group::2, allowing for synchronization across CTAs within a pair of CTAs (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor) +template +CUTE_DEVICE void tma_gather4_cta_group_2(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr); + if constexpr (USE_CTA0_MBAR) { + mbar_addr &= cute::Sm100MmaPeerBitMask; + } + asm volatile( + 
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbar_addr), "l"(cache_hint) + : "memory" + ); +} + +// Vectorized addition for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add) +CUTE_DEVICE +float2 float2_add(const float2 &a, const float2 &b) { + float2 c; + asm volatile( + "add.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)) + ); + return c; +} + +// Vectorized multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-mul) +CUTE_DEVICE +float2 float2_mul(const float2 &a, const float2 &b) { + float2 c; + asm volatile( + "mul.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); + return c; +} + +// Vectorized fused multiply-add for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma) +CUTE_DEVICE +float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) { + // return a*b+c + float2 d; + asm volatile( + "fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); + return d; +} + +// Vectorized negation for float32 +CUTE_DEVICE +float2 float2_neg(const float2 &a) { + float2 t = {-1.0f, -1.0f}; + return float2_mul(a, t); +} + +// st.bulk (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk) +CUTE_DEVICE +void st_bulk(void* dst_ptr, int64_t size) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + asm volatile ( + "st.bulk.weak.shared::cta [%0], %1, 0;\n" + : + : "r"(dst_addr), "l"(size) + : 
"memory" + ); +} + +struct CUTE_ALIGNAS(16) CLCResponseObj { + // An opaque 16B value + char opaque[16]; +}; + +struct CLCResult { + int is_valid; + int x, y, z; +}; + +// Issue a CLC try_cancel query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel) +CUTE_DEVICE +void issue_clc_query(transac_bar_t &bar, CLCResponseObj &response_obj) { + uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar); + asm volatile( + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [%0], [%1];\n" + : + : "r"(response_addr), "r"(mbarrier_addr) + ); +} + +// Issue a CLC try_cancel query with .multicast::cluster::all (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel) +CUTE_DEVICE +void issue_clc_query_multicast_cluster_all(transac_bar_t &bar, CLCResponseObj &response_obj) { + uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar); + asm volatile( + "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n" + : + : "r"(response_addr), "r"(mbarrier_addr) + ); +} + +// Get the result of a CLC query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel) +// In this function, we separate get_first_ctaid::x/y/z and hope PTXAS's dead code elimination can remove unnecessary instructions +template +CUTE_DEVICE +CLCResult get_clc_query_response(CLCResponseObj &response_obj) { + uint32_t response_addr = cute::cast_smem_ptr_to_uint(&response_obj); + CLCResult result; + #define EMIT_ASM(LD_MODIFIER) \ + asm volatile( \ + "{\n" \ + ".reg .pred p1;\n\t" \ + ".reg .b128 
clc_result;\n\t" \ + "ld" LD_MODIFIER ".shared.b128 clc_result, [%4];\n\t" \ + "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" \ + "selp.u32 %3, 1, 0, p1;\n\t" \ + "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 %0, clc_result;\n\t" \ + "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::y.b32.b128 %1, clc_result;\n\t" \ + "@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::z.b32.b128 %2, clc_result;\n\t" \ + "}\n" \ + : "=r"(result.x), "=r"(result.y), "=r"(result.z), "=r"(result.is_valid) \ + : "r"(response_addr) \ + : "memory" \ + ); + if constexpr (USE_LD_ACQUIRE) { + EMIT_ASM(".acquire.cta"); + } else { + EMIT_ASM(""); + } + return result; +} + +// LDG.256 or LDG.256 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld) +// We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function +// NC_STR should be either "" or ".nc" +// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate" +// L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last" +// L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B" +#define KU_LDG_256(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \ + { \ + static_assert(std::is_pointer_v || std::is_array_v, "`global_addr` must be a pointer"); \ + static_assert(std::is_pointer_v || std::is_array_v, "`result` must be a pointer"); \ + uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \ + asm volatile( \ + "ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v4.u64 {%0, %1, %2, %3}, [%4];\n" \ + : "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]), \ + "=l"(result_as_uint64_ptr[2]), "=l"(result_as_uint64_ptr[3]) \ + : "l"(global_addr) \ + ); \ + } + +// 
STG.256 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st) +// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate" +// L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last" +#define KU_STG_256(global_addr, src, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR) \ + { \ + static_assert(std::is_pointer_v || std::is_array_v, "`global_addr` must be a pointer"); \ + static_assert(std::is_pointer_v || std::is_array_v, "`src` must be a pointer"); \ + uint64_t const* src_as_uint64_ptr = (uint64_t const*)(src); \ + asm volatile( \ + "st.global.L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".v4.u64 [%0], {%1, %2, %3, %4};\n" \ + : \ + : "l"(global_addr), "l"(src_as_uint64_ptr[0]), "l"(src_as_uint64_ptr[1]), \ + "l"(src_as_uint64_ptr[2]), "l"(src_as_uint64_ptr[3]) \ + ); \ + } + +} + +namespace kerutils { + +// tcgen05.commit.cta_group::1 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) +CUTE_DEVICE +void umma_arrive_noelect(transac_bar_t &bar) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar); + asm volatile( + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];\n" + : + :"r"(bar_intptr) + ); +} + +// tcgen05.commit.cta_group::1, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) +CUTE_DEVICE +void umma_arrive_multicast_noelect(transac_bar_t &bar, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar); + asm volatile( + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n" + : + :"r"(bar_intptr), "h"(cta_mask) + ); +} + +// tcgen05.commit.cta_group::2 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) +CUTE_DEVICE +void umma_arrive_2x1SM_noelect(transac_bar_t &bar) { + uint32_t bar_intptr = 
cute::cast_smem_ptr_to_uint(&bar); + asm volatile( + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];\n" + : + :"r"(bar_intptr) + ); +} + +// tcgen05.commit.cta_group::2, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit) +CUTE_DEVICE +void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &bar, uint16_t cta_mask) { + uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar); + asm volatile( + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n" + : + :"r"(bar_intptr), "h"(cta_mask) + ); +} + +// tcgen05.fence::before_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence) +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +// tcgen05.fence::after_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence) +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + + +// Load from tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld) +template +__device__ __forceinline__ +void tmem_ld_32dp32bNx(uint32_t tmem_start, void* data_) { + uint32_t* data = (uint32_t*)data_; + static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements"); + // NOTE The following code crashes VSCode intellisense engine, so we disable it +#ifndef __VSCODE_IDE__ + [&](cute::index_sequence) { + if constexpr (kNumElements == 1) { + cute::SM100_TMEM_LOAD_32dp32b1x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumElements == 2) { + cute::SM100_TMEM_LOAD_32dp32b2x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumElements == 4) { + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumElements == 8) { + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumElements == 16) { + cute::SM100_TMEM_LOAD_32dp32b16x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumElements == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumElements == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumElements == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, data[Is]...); + } + }(cute::make_index_sequence{}); +#endif +} + +// Load from tensor memory, 16 data path lanes, 128-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld) +template +__device__ __forceinline__ +void tmem_ld_16dp128bNx(uint32_t tmem_start, void* data_) { + uint32_t* data = (uint32_t*)data_; + static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32 || kNumReplications == 64, "Invalid kNumReplications"); +#ifndef __VSCODE_IDE__ + [&](cute::index_sequence) { + if constexpr (kNumReplications == 1) { + cute::SM100_TMEM_LOAD_16dp128b1x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 2) { + cute::SM100_TMEM_LOAD_16dp128b2x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 4) { + cute::SM100_TMEM_LOAD_16dp128b4x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 8) { + cute::SM100_TMEM_LOAD_16dp128b8x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 16) { + cute::SM100_TMEM_LOAD_16dp128b16x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 32) { + cute::SM100_TMEM_LOAD_16dp128b32x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 64) { + cute::SM100_TMEM_LOAD_16dp128b64x::copy(tmem_start, data[Is]...); + } + }(cute::make_index_sequence{}); +#endif +} + +// Load from tensor memory, 16 data path lanes, 256-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld) +template +__device__ __forceinline__ +void tmem_ld_16dp256bNx(uint32_t tmem_start, void* data_) { + uint32_t* data = (uint32_t*)data_; + static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32, "Invalid kNumReplications"); +#ifndef __VSCODE_IDE__ + [&](cute::index_sequence) { + if constexpr (kNumReplications == 1) { + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 2) { + cute::SM100_TMEM_LOAD_16dp256b2x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 4) { + cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 8) { + cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 16) { + cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start, data[Is]...); + } else if constexpr (kNumReplications == 32) { + cute::SM100_TMEM_LOAD_16dp256b32x::copy(tmem_start, data[Is]...); + } + }(cute::make_index_sequence{}); +#endif +} + +// Store into tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st) +template +__device__ __forceinline__ +void tmem_st_32dp32bNx(uint32_t tmem_start, void const* data_) { + uint32_t const* data = (uint32_t const*)data_; + static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements"); +#ifndef __VSCODE_IDE__ + [&](cute::index_sequence) { + if constexpr (kNumElements == 1) { + cute::SM100_TMEM_STORE_32dp32b1x::copy(data[Is]..., tmem_start); + } else if constexpr (kNumElements == 2) { + cute::SM100_TMEM_STORE_32dp32b2x::copy(data[Is]..., tmem_start); + } else if constexpr (kNumElements == 4) { + cute::SM100_TMEM_STORE_32dp32b4x::copy(data[Is]..., tmem_start); + } else if constexpr (kNumElements == 8) { + cute::SM100_TMEM_STORE_32dp32b8x::copy(data[Is]..., tmem_start); + } else if constexpr (kNumElements == 16) { + cute::SM100_TMEM_STORE_32dp32b16x::copy(data[Is]..., tmem_start); + } else if constexpr (kNumElements == 32) { + cute::SM100_TMEM_STORE_32dp32b32x::copy(data[Is]..., tmem_start); + } else if constexpr (kNumElements == 64) { + cute::SM100_TMEM_STORE_32dp32b64x::copy(data[Is]..., tmem_start); + } else if constexpr (kNumElements == 128) { + cute::SM100_TMEM_STORE_32dp32b128x::copy(data[Is]..., tmem_start); + } + }(cute::make_index_sequence{}); +#endif +} + +} diff --git a/csrc/sm100/tma_cta_group2_nosplit.h b/csrc/kerutils/include/kerutils/device/sm100/tma_cta_group2_nosplit.cuh similarity index 95% rename from csrc/sm100/tma_cta_group2_nosplit.h rename to csrc/kerutils/include/kerutils/device/sm100/tma_cta_group2_nosplit.cuh index 045456d..b731e34 100644 --- a/csrc/sm100/tma_cta_group2_nosplit.h +++ b/csrc/kerutils/include/kerutils/device/sm100/tma_cta_group2_nosplit.cuh @@ -2,15 +2,16 @@ #include +#include + namespace cute { // Extensions to CuTe -// CuTe's SM100_TMA_2SM_LOAD_1D requires two 
threads to perform this operation cooperatively (using ThrID = Layout<_2>;), which doesn't fit our use case. +// CuTe's built-in SM100_TMA_2SM_LOAD_1D series requires the number of participating threads to be 2 (using ThrID = Layout<_2>;) and also splits the data, which is really annoying to use, so we modified our own version. Additionally, to keep it consistent with other parts that use SM90 TMA, we made it accept TMA::CacheHintSm90 instead of TMA::CacheHintSm100. //////////////////////////////////////////////////////////////////////////////////////////////////// /// TMA_LOAD : Initiates a TMA copy from global memory to shared memory //////////////////////////////////////////////////////////////////////////////////////////////////// - struct SM100_TMA_2SM_LOAD_1D_NOSPLIT { CUTE_HOST_DEVICE static void @@ -36,7 +37,6 @@ struct SM100_TMA_2SM_LOAD_1D_NOSPLIT #endif } }; - struct SM100_TMA_2SM_LOAD_2D_NOSPLIT { CUTE_HOST_DEVICE static void @@ -62,7 +62,6 @@ struct SM100_TMA_2SM_LOAD_2D_NOSPLIT #endif } }; - struct SM100_TMA_2SM_LOAD_3D_NOSPLIT { CUTE_HOST_DEVICE static void @@ -88,7 +87,6 @@ struct SM100_TMA_2SM_LOAD_3D_NOSPLIT #endif } }; - struct SM100_TMA_2SM_LOAD_4D_NOSPLIT { CUTE_HOST_DEVICE static void @@ -114,7 +112,6 @@ struct SM100_TMA_2SM_LOAD_4D_NOSPLIT #endif } }; - struct SM100_TMA_2SM_LOAD_5D_NOSPLIT { CUTE_HOST_DEVICE static void @@ -140,7 +137,6 @@ struct SM100_TMA_2SM_LOAD_5D_NOSPLIT #endif } }; - struct SM100_TMA_2SM_LOAD_NOSPLIT { CUTE_HOST_DEVICE static void @@ -178,14 +174,9 @@ struct SM100_TMA_2SM_LOAD_NOSPLIT { return SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); } - using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; }; - - - struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {}; - // The non-executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and no tma_mbar // Use .with(tma_mbar) to construct an executable version template @@ -198,19 +189,16 @@ struct Copy_Traits 
using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - // SM100_TMA_2SM_LOAD_NOSPLIT arguments TmaDescriptor tma_desc_; using AuxParams = AuxParams_; AuxParams aux_params_; - // Return TmaDescriptor/TensorMap CUTE_HOST_DEVICE constexpr TmaDescriptor const* get_tma_descriptor() const { return &tma_desc_; } - // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar CUTE_HOST_DEVICE constexpr Copy_Traits @@ -221,7 +209,6 @@ struct Copy_Traits // We accept multicast_mask here to keep the API for both atoms consistent return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; } - // Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) CUTE_HOST_DEVICE constexpr Copy_Traits @@ -233,16 +220,14 @@ struct Copy_Traits // We accept multicast_mask here to keep the API for both atoms consistent return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; } - // Generate the TMA coord tensor template CUTE_HOST_DEVICE constexpr auto get_tma_tensor(GShape const& g_shape) const { static_assert(is_congruent::value); - return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); } - // Don't try to execute a copy with SM100_TMA_2SM_LOAD_NOSPLIT before calling .with() template @@ -251,7 +236,6 @@ struct Copy_Traits Tensor const& src, Tensor & dst) = delete; }; - // The executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and tma_mbar template struct Copy_Traits @@ -264,18 +248,15 @@ struct Copy_Traits using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - // SM100_TMA_2SM_LOAD_NOSPLIT arguments tuple< TmaDescriptor const*, uint64_t*, // smem mbarrier uint64_t // cache hint > const opargs_; - CUTE_HOST_DEVICE Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) : opargs_(desc, mbar, cache) {} }; - } diff --git 
a/csrc/kerutils/include/kerutils/device/sm80/helpers.cuh b/csrc/kerutils/include/kerutils/device/sm80/helpers.cuh new file mode 100644 index 0000000..551b0f1 --- /dev/null +++ b/csrc/kerutils/include/kerutils/device/sm80/helpers.cuh @@ -0,0 +1,55 @@ +#pragma once + +#include "kerutils/device/common.h" +#include "kerutils/device/sm80/intrinsics.cuh" + +namespace kerutils { + +// Retrieve the value of `%smid` and check its range +CUTE_DEVICE +uint32_t get_sm_id_with_range_check(uint32_t num_physical_sms) { + uint32_t sm_id = get_sm_id(); + if (!(sm_id < num_physical_sms)) { + trap(); + } + return sm_id; +} + +#ifndef KU_TRAP_ONLY_DEVICE_ASSERT +#define KU_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +// Construct a `float2` from a single `float` by duplicating the value +CUTE_DEVICE +float2 float2float2(const float &x) { + return float2 {x, x}; +} + +CUTE_DEVICE +void st_shared(void* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +CUTE_DEVICE +void st_shared(void* ptr, float4 val) { + st_shared(ptr, *(__int128_t*)&val); +} + +CUTE_DEVICE +__int128_t ld_shared(void* ptr) { + __int128_t val; + asm volatile("ld.shared.b128 %0, [%1];" : "=q"(val) : "l"(__cvta_generic_to_shared(ptr))); + return val; +} + +CUTE_DEVICE +float4 ld_shared_float4(void* ptr) { + __int128_t temp = ld_shared(ptr); + return *(float4*)&temp; +} + +} diff --git a/csrc/kerutils/include/kerutils/device/sm80/intrinsics.cuh b/csrc/kerutils/include/kerutils/device/sm80/intrinsics.cuh new file mode 100644 index 0000000..7039a0b --- /dev/null +++ b/csrc/kerutils/include/kerutils/device/sm80/intrinsics.cuh @@ -0,0 +1,146 @@ +#pragma once + +#include "kerutils/device/common.h" + +namespace kerutils { + +// cp.async.cg (cache global) with prefetch and predicate (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async) +template 
+CUTE_DEVICE +void cp_async_cacheglobal(const void* src, void* dst, bool pred=true) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); + if constexpr (PREFETCH_SIZE == PrefetchSize::B64) { + asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16, %2;\n" + :: "r"(dst_addr), + "l"(src), + "r"(pred?16:0)); + } else if constexpr (PREFETCH_SIZE == PrefetchSize::B128) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16, %2;\n" + :: "r"(dst_addr), + "l"(src), + "r"(pred?16:0)); + } else if constexpr (PREFETCH_SIZE == PrefetchSize::B256) { + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16, %2;\n" + :: "r"(dst_addr), + "l"(src), + "r"(pred?16:0)); + } else { + static_assert(PREFETCH_SIZE == PrefetchSize::B64 || + PREFETCH_SIZE == PrefetchSize::B128 || + PREFETCH_SIZE == PrefetchSize::B256, + "Unsupported prefetch size for cp_async_cacheglobal."); + } +} + +// Create fraction-based cache policy (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-createpolicy) +template +CUTE_DEVICE +int64_t create_fraction_based_cache_policy(float fraction = 1.0f) { + int64_t result; + #define EMIT(PRIMARY_PRIORITY_STR, SECONDARY_PRIORITY_STR) \ + asm volatile( \ + "createpolicy.fractional.L2::" PRIMARY_PRIORITY_STR ".L2::" SECONDARY_PRIORITY_STR ".b64 %0, %1;\n" \ + : "=l"(result) \ + : "f"(fraction) \ + ); + #define EMIT2(PRIMARY_PRIORITY_STR) \ + { \ + if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_FIRST) { \ + EMIT(PRIMARY_PRIORITY_STR, "evict_first") \ + } else if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED) { \ + EMIT(PRIMARY_PRIORITY_STR, "evict_unchanged") \ + } else { \ + static_assert(SECONDARY_PRIORITY == CacheHint::EVICT_FIRST || \ + SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED, \ + "Unsupported secondary cache hint for create_fraction_based_cache_policy."); \ + } \ + } + if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_FIRST) { + EMIT2("evict_first"); + } 
else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL) { + EMIT2("evict_normal"); + } else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_LAST) { + EMIT2("evict_last"); + } else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED) { + EMIT2("evict_unchanged"); + } else { + static_assert(PRIMARY_PRIORITY == CacheHint::EVICT_FIRST || + PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL || + PRIMARY_PRIORITY == CacheHint::EVICT_LAST || + PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED, + "Unsupported primary cache hint for create_fraction_based_cache_policy."); + } + #undef EMIT + #undef EMIT2 + return result; +} + +// Create a simple cache policy (equivalent to create_fraction_based_cache_policy(1.0f)) +// The same as cute::TMA::CacheHintSmXX +template +CUTE_DEVICE +constexpr int64_t create_simple_cache_policy() { + if constexpr (CACHE_HINT == CacheHint::EVICT_FIRST) { + return 0x12F0000000000000; // Result of createpolicy.fractional.L2::evict_first.b64 + } else if constexpr (CACHE_HINT == CacheHint::EVICT_NORMAL) { + return 0x1000000000000000; // Copied from CuTe. Unsure about the exact meaning. (TODO Change to 0x16F0000000000000?) 
+ } else if constexpr (CACHE_HINT == CacheHint::EVICT_LAST) { + return 0x14F0000000000000; // Result of createpolicy.fractional.L2::evict_last.b64 + } else { + static_assert(CACHE_HINT == CacheHint::EVICT_FIRST || + CACHE_HINT == CacheHint::EVICT_NORMAL || + CACHE_HINT == CacheHint::EVICT_LAST, + "Unsupported cache hint for create_simple_cache_policy."); + } +} + +// AtomicAdd (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red) +CUTE_DEVICE +void atomicadd_f32_with_policy_and_pred(void* global_addr, const float &data, int64_t cache_policy, uint32_t pred = true) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %3, 1;\n\t" + "@p red.relaxed.gpu.global.add.L2::cache_hint.f32 [%1], %0, %2; \n\t" + "}" + : + : "f"(data), + "l"((int64_t)global_addr), "l"(cache_policy), "r"(pred) + ); +} + +// Get the id of the current SM +// About %smid (https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-smid): PTX document says that %smid ranges from 0 to %nsmid-1, while "The SM identifier numbering is not guaranteed to be contiguous, so %nsmid may be larger than the physical number of SMs in the device.". However, result shows that, at least for sm90 and sm100f, %nsmid is the number of physical SMs - 1. For the sake of safety, I recommend you to check the return of get_sm_id manually or call `get_sm_id_with_range_check()` defined in `device/sm80/helpers.cuh`. 
+// Besides, PTX document also says that this number may change due to preemption, but currently this never happens according to [DATEN GELÖSCHT] +CUTE_DEVICE +uint32_t get_sm_id() { + uint32_t ret; + asm volatile("mov.u32 %0, %%smid;\n" : "=r"(ret)); + return ret; +} + +// trap (https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-trap) +CUTE_DEVICE +void trap() { + asm volatile("trap;\n"); +} + +// LDG.128 or LDG.128 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld) +// We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function +// NC_STR should be either "" or ".nc" +// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate" +// L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B" +// L2 cache hint is not supported since it's only supported for LDG.256 +#define KU_LDG_128(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \ + { \ + static_assert(std::is_pointer_v || std::is_array_v, "`global_addr` must be a pointer"); \ + static_assert(std::is_pointer_v || std::is_array_v, "`result` must be a pointer"); \ + uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \ + asm volatile( \ + "ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v2.u64 {%0, %1}, [%2];\n" \ + : "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]) \ + : "l"(global_addr) \ + ); \ + } + +} diff --git a/csrc/kerutils/include/kerutils/device/sm90/helpers.cuh b/csrc/kerutils/include/kerutils/device/sm90/helpers.cuh new file mode 100644 index 0000000..3e7bb14 --- /dev/null +++ b/csrc/kerutils/include/kerutils/device/sm90/helpers.cuh @@ -0,0 +1,110 @@ +#pragma once + +#include + +#include "kerutils/device/common.h" + +namespace kerutils { + +template< + typename TMA, + typename Tensor0, + typename 
Tensor1 +> +CUTE_DEVICE +void launch_tma_copy( + const TMA &tma_copy, + Tensor0 src, + Tensor1 dst, + transac_bar_t &bar, + const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL +) { + auto thr_tma = tma_copy.get_slice(cute::_0{}); + cute::copy( + tma_copy.with(reinterpret_cast(bar), 0, cache_hint), + thr_tma.partition_S(src), + thr_tma.partition_D(dst) + ); +} + +// In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. This function converts the local_row_idx (0 or 1) to the actual row_idx +// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a +CUTE_DEVICE +int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { + int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); + return row_idx; +} + +// In the layout of fragment A and fragment C during WGMMA, each thread holds several elements. This function converts the local_elem_idx to the actual col_idx +// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a +CUTE_DEVICE +int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) { + int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1); + return col_idx; +} + +template +CUTE_DEVICE +void wgmma(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC, bool zero_init) { + using namespace cute; + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + warpgroup_arrive(); + tiled_mma.accumulate_ = zero_init ? 
GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +template +CUTE_DEVICE +void wgmma_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { + using namespace cute; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor sA_frag = thr_mma.partition_fragment_A(sA); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(sA_frag) == size<2>(sB_frag)); + + warpgroup_fence_operand(rC_frag); + warpgroup_arrive(); + tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(sA_frag); ++k) { + cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_fence_operand(rC_frag); +} + +template +CUTE_DEVICE +void wgmma_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { + using namespace cute; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + static_assert(size<2>(rA_frag) == size<2>(sB_frag)); + + warpgroup_fence_operand(const_cast(rA_frag)); + warpgroup_fence_operand(rC_frag); + warpgroup_arrive(); + tiled_mma.accumulate_ = clear_accum ? 
GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(rA_frag); ++k) { + cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_fence_operand(rC_frag); + warpgroup_fence_operand(const_cast(rA_frag)); +} + +} diff --git a/csrc/kerutils/include/kerutils/device/sm90/intrinsics.cuh b/csrc/kerutils/include/kerutils/device/sm90/intrinsics.cuh new file mode 100644 index 0000000..07c5f9a --- /dev/null +++ b/csrc/kerutils/include/kerutils/device/sm90/intrinsics.cuh @@ -0,0 +1,107 @@ +#pragma once + +#include "kerutils/device/common.h" + +namespace kerutils { + +// st.async (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-async) +template +CUTE_DEVICE +static void st_async(void* dst_ptr, const T& data, transac_bar_t &mbar) { + static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async."); + long2 data_long2 = *reinterpret_cast(&data); + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar); + asm volatile ( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" + : + : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) + ); +} + +static constexpr int PEER_ADDR_MASK = 16777216; + +// Given an address in the current CTA, return the corresponding address in the peer CTA +template +CUTE_DEVICE +T* get_peer_addr(const T* p) { + return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); +} + +// Given an address in the current CTA, return the corresponding address in the peer CTA (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0) +template +CUTE_DEVICE +T* get_cta0_addr(const T* p) { + constexpr int CTA0_ADDR_MASK = 0xFEFFFFFF; + return (T*)((int64_t)(p) & CTA0_ADDR_MASK); +} + +// TMA bulk reduce add (cp.reduce.async.bulk), shared to global, float32, add. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk) +CUTE_DEVICE +void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) { + uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(src_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" + : + : "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +} + +// Cluster barrier arrive with .release modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster) +CUTE_DEVICE +void barrier_cluster_arrive_release() { + asm volatile("barrier.cluster.arrive.release;" : : : "memory"); +} + +// Cluster barrier arrive with .relaxed modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster) +CUTE_DEVICE +void barrier_cluster_arrive_relaxed() { + asm volatile("barrier.cluster.arrive.relaxed;" : : :); +} + +// Cluster barrier wait with .acquire modifier. 
(https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster) +CUTE_DEVICE +void barrier_cluster_wait_acquire() { + asm volatile("barrier.cluster.wait.acquire;" : : : "memory"); +} + +// mbarrier.arrive with .relaxed.cluster qualifier (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-arrive) +CUTE_DEVICE +void mbarrier_arrive_relaxed_cluster(transac_bar_t &mbar) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&mbar); + asm volatile( + "{\n\t" + "mbarrier.arrive.relaxed.cluster.shared::cta.b64 _, [%0];\n\t" + "}" + : + : "r"(smem_addr)); +} + +// AtomicAdd with v4.f32 type (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red) +CUTE_DEVICE +void atomicadd_f32x4_with_policy_and_pred(void* global_addr, const float4 &data, int64_t cache_policy, uint32_t pred = true) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %6, 1;\n\t" + "@p red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t" + "}" + : + : "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w), + "l"((int64_t)global_addr), "l"(cache_policy), "r"(pred) + ); +} + +// cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk) +CUTE_DEVICE +void cp_async_bulk_shared_cta_to_shared_cluster(void* dst_ptr, const void* src_ptr, int32_t load_bytes, transac_bar_t &mbar) { + uint32_t dst_smem_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t src_smem_addr = cute::cast_smem_ptr_to_uint(src_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar); + asm volatile( + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n" + : + : "r"(dst_smem_addr), "r"(src_smem_addr), "r"(load_bytes), "r"(mbar_addr) + ); +} + +} diff 
--git a/csrc/kerutils/include/kerutils/host/host.h b/csrc/kerutils/include/kerutils/host/host.h new file mode 100644 index 0000000..3bdd124 --- /dev/null +++ b/csrc/kerutils/include/kerutils/host/host.h @@ -0,0 +1,155 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include + +#include "kerutils/common/common.h" + +namespace kerutils { + +class KUException final : public std::exception { + std::string message = {}; + +public: + template + explicit KUException(const char *name, const char* file, const int line, Args&&... args) { + std::ostringstream oss; + + oss << name << " error (" << file << ":" << line << "): "; + (oss << ... << args); + message = oss.str(); + } + + const char *what() const noexcept override { + return message.c_str(); + } +}; + +#define THROW_KU_EXCEPTION(name, ...) \ + throw kerutils::KUException(name, __FILE__, __LINE__, __VA_ARGS__) + +#define KU_CUDA_CHECK(call) \ +do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + THROW_KU_EXCEPTION("CUDA", "CUDA error: ", cudaGetErrorString(status_)); \ + } \ +} while(0) + +#define KU_CUTLASS_CHECK(call) \ +do { \ + cutlass::Status status_ = call; \ + if (status_ != cutlass::Status::kSuccess) { \ + fprintf(stderr, "CUTLASS error (%s:%d): %d\n", __FILE__, __LINE__, static_cast(status_)); \ + THROW_KU_EXCEPTION("CUTLASS", "CUTLASS error: ", static_cast(status_)); \ + } \ +} while(0) + +// This `KU_ASSERT` is triggered no matter if the code is compiled with `-DNDEBUG` or not. +#define KU_ASSERT(cond, ...) 
\ + do { \ + if (not (cond)) { \ + fprintf(stderr, "Assertion `%s` failed (%s:%d): ", #cond, __FILE__, __LINE__); \ + if constexpr (sizeof(#__VA_ARGS__) > 1) { \ + fprintf(stderr, ", " __VA_ARGS__); \ + } \ + fprintf(stderr, "\n"); \ + THROW_KU_EXCEPTION("Assertion", "Assertion `", #cond, "` failed."); \ + } \ + } while(0) + +#define KU_CHECK_KERNEL_LAUNCH() KU_CUDA_CHECK(cudaGetLastError()) + +template +inline __host__ __device__ constexpr T ceil_div(const T &a, const T &b) { + return (a + b - 1) / b; +} + +template +inline __host__ __device__ constexpr T ceil(const T &a, const T &b) { + return (a + b - 1) / b * b; +} + +// A wrapper for cuTensorMapEncodeTiled. On failure, dumps all arguments to stderr and raises via KU_ASSERT +static inline CUtensorMap make_tensor_map( + const std::vector &size, + const std::vector &strides, // PAY ATTENTION: In BYTES + const std::vector &box_size, + void* global_ptr, + CUtensorMapDataType data_type, + CUtensorMapSwizzle swizzle_mode, + CUtensorMapL2promotion l2_promotion, + CUtensorMapInterleave interleave_mode = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapFloatOOBfill oob_fill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, + const std::vector &element_strides_ = {} +) { + int dim = size.size(); + KU_ASSERT(dim >= 1); + + std::vector element_strides; + if (element_strides_.empty()) { + for (int i = 0; i < dim; ++i) + element_strides.push_back(1); + } else { + element_strides = element_strides_; + } + KU_ASSERT(strides.size() == (uint32_t)dim-1 && box_size.size() == (uint32_t)dim && element_strides.size() == (uint32_t)dim); + + CUtensorMap result; + CUresult ret_code = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &result, + data_type, + dim, + global_ptr, + size.data(), + strides.data(), + box_size.data(), + element_strides.data(), + interleave_mode, + swizzle_mode, + l2_promotion, + oob_fill + ); + if (ret_code != CUresult::CUDA_SUCCESS) { + auto print_vector = [&](auto t, const char* fmt, const char end='\n') { + for (auto elem : t) { + fprintf(stderr, fmt, 
elem); + } + fprintf(stderr, "%c", end); + }; + fprintf(stderr, "Failed to create tensormap\n"); + fprintf(stderr, "Dim: %d\n", dim); + fprintf(stderr, "size: "); print_vector(size, "%lu "); + fprintf(stderr, "strides: "); print_vector(strides, "%lu "); + fprintf(stderr, "box_size: "); print_vector(box_size, "%u "); + fprintf(stderr, "element_strides: "); print_vector(element_strides, "%u "); + fprintf(stderr, "global ptr: 0x%lx\n", (int64_t)global_ptr); + fprintf(stderr, "data_type: %d\n", (int)data_type); + fprintf(stderr, "swizzle_mode: %d\n", (int)swizzle_mode); + fprintf(stderr, "l2_promotion: %d\n", (int)l2_promotion); + fprintf(stderr, "interleave_mode: %d\n", (int)interleave_mode); + fprintf(stderr, "oob_fill: %d\n", (int)oob_fill); + KU_ASSERT(false); + } + return result; +} + +// Given strides (in number of elements), this function converts them to uint64_t and then multiplies each by elem_size (yielding strides in bytes) +template +static inline std::vector make_stride_helper(const std::vector &strides_in_elems, size_t elem_size) { + std::vector res; + for (auto stride : strides_in_elems) { + res.push_back(((uint64_t)stride) * elem_size); + } + return res; +} + +} \ No newline at end of file diff --git a/csrc/kerutils/include/kerutils/kerutils.cuh b/csrc/kerutils/include/kerutils/kerutils.cuh new file mode 100644 index 0000000..bc818f0 --- /dev/null +++ b/csrc/kerutils/include/kerutils/kerutils.cuh @@ -0,0 +1,4 @@ +#pragma once + +#include "host/host.h" +#include "device/device.cuh" diff --git a/csrc/kerutils/include/kerutils/supplemental/torch_tensors.h b/csrc/kerutils/include/kerutils/supplemental/torch_tensors.h new file mode 100644 index 0000000..5b4e564 --- /dev/null +++ b/csrc/kerutils/include/kerutils/supplemental/torch_tensors.h @@ -0,0 +1,71 @@ +#pragma once + +#include + +#include + +#include "kerutils/common/common.h" + +namespace kerutils { + +// Check whether the given tensor or optional tensor satisfies the given condition +// If tensor_or_opt is a tensor, check_fn is applied directly +// If tensor_or_opt is an optional tensor, check_fn is applied only when 
the optional has value +template +static inline bool _check_optional_tensor(const T& tensor_or_opt, const std::function& check_fn) { + if constexpr (std::is_same::value) { + return check_fn(tensor_or_opt); + } else { + if (tensor_or_opt.has_value()) { + return check_fn(tensor_or_opt.value()); + } else { + return true; + } + } +} + +// Get the pointer of the given tensor +// Return (PtrT*)tensor.data_ptr() if the tensor has a backend storage, nullptr otherwise +template +static inline PtrT* get_tensor_ptr(const at::Tensor& tensor) { + if (tensor.has_storage()) { + return (PtrT*)tensor.data_ptr(); + } else { + return nullptr; + } +} + +// Get the pointer of the given tensor or optional tensor +// Return (PtrT*)tensor.data_ptr() if tensor_or_opt has value and points to a valid tensor, return nullptr otherwise +template +static inline PtrT* get_optional_tensor_ptr(const T& tensor_or_opt) { + if constexpr (std::is_same::value) { + return get_tensor_ptr(tensor_or_opt); + } else { + if (tensor_or_opt.has_value()) { + return get_tensor_ptr(*tensor_or_opt); + } else { + return nullptr; + } + } +} + +} + +// Check whether the given tensor (or optional) is on cuda +#define KU_CHECK_DEVICE(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_cuda(); }), #tensor " must be on CUDA") + +// Check whether the given tensor (or optional) has the given number of dimensions +#define KU_CHECK_NDIM(tensor, ndim) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.dim() == (ndim); }), #tensor " must have " #ndim " dimensions") + +// Check whether the given tensor (or optional) has the given shape +#define KU_CHECK_SHAPE(tensor, ...) 
TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.sizes() == torch::IntArrayRef({__VA_ARGS__}); }), #tensor " must have shape (" #__VA_ARGS__ ")") + +// Check whether the given tensor (or optional) is contiguous +#define KU_CHECK_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_contiguous(); }), #tensor " must be contiguous") + +// Check whether the last dimension of the given tensor (or optional) is contiguous +#define KU_CHECK_LAST_DIM_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.size(-1) == 1 || t.stride(-1) == 1; }), #tensor " must have contiguous last dimension") + +// Check whether the given tensor (or optional) has the specified dtype +#define KU_CHECK_DTYPE(tensor, target_dtype) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.dtype() == (target_dtype); }), #tensor " must have dtype " #target_dtype) diff --git a/csrc/params.h b/csrc/params.h index baa2f7f..4433e8d 100644 --- a/csrc/params.h +++ b/csrc/params.h @@ -2,7 +2,21 @@ #include "cutlass/bfloat16.h" -struct DecodingParams { +enum class ModelType { + V32, + MODEL1 +}; + +struct __align__(4*8) DecodingSchedMeta { + int begin_req_idx, end_req_idx; // Both inclusive + int begin_block_idx, end_block_idx; // Inclusive, exclusive + int begin_split_idx; + int is_first_req_splitted, is_last_req_splitted; + int _pad[1]; +}; +static constexpr int DecodingSchedMetaSize = sizeof(DecodingSchedMeta); + +struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams using index_t = int64_t; int b; // batch size @@ -14,13 +28,11 @@ struct DecodingParams { int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k bool is_causal; float scale_softmax, scale_softmax_log2; - int topk; void *__restrict__ q_ptr; void *__restrict__ k_ptr; void *__restrict__ o_ptr; - void *__restrict__ softmax_lse_ptr; - int *__restrict__ indices_ptr; + 
float *__restrict__ softmax_lse_ptr; index_t q_batch_stride; index_t k_batch_stride; @@ -31,38 +43,106 @@ struct DecodingParams { index_t q_head_stride; index_t k_head_stride; index_t o_head_stride; - index_t indices_batch_stride; - index_t indices_row_stride; int *__restrict__ block_table; index_t block_table_batch_stride; int page_block_size; int *__restrict__ seqlens_k_ptr; - int *__restrict__ tile_scheduler_metadata_ptr; + DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr; int num_sm_parts; int *__restrict__ num_splits_ptr; int total_num_splits; - void *__restrict__ softmax_lseaccum_ptr; - void *__restrict__ oaccum_ptr; + float *__restrict__ softmax_lseaccum_ptr; + float *__restrict__ oaccum_ptr; + + cudaStream_t stream; }; -static constexpr int TileSchedulerMetaDataSize = 8; -// [begin_idx (inclusive), begin_block_idx (inclusive), end_idx (inclusive), end_block_idx (exclusive), begin_n_split_idx, _, _, _] +struct SparseAttnDecodeParams { + int b, s_q; + int h_q, h_kv; + int d_qk, d_v; + float sm_scale, sm_scale_div_log2; + int num_blocks, page_block_size, topk; + ModelType model_type; -struct GetDecodingMetadataParams { - int *__restrict__ seqlens_k_ptr; - int *__restrict__ tile_scheduler_metadata_ptr; - int *__restrict__ num_splits_ptr; - int batch_size; + cutlass::bfloat16_t* __restrict__ q; // [b, s_q, h_q, d_qk] + cutlass::bfloat16_t* __restrict__ kv; // [num_blocks, page_block_size, d_qk] + int* __restrict__ indices; // [b, s_q, topk] + int* __restrict__ topk_length; // [b], may be nullptr + float* __restrict__ attn_sink; // [h_q], may be nullptr + + float* __restrict__ lse; // [b, s_q, h_q] + cutlass::bfloat16_t* __restrict__ out; // [b, s_q, h_q, d_v] + + int extra_num_blocks, extra_page_block_size, extra_topk; + cutlass::bfloat16_t* __restrict__ extra_kv; // [extra_num_blocks, extra_page_block_size, d_qk] + int* __restrict__ extra_indices; // [b, s_q, extra_topk] + int* __restrict__ extra_topk_length; // [b], may be nullptr + + int 
stride_q_b, stride_q_s_q, stride_q_h_q; + int stride_kv_block, stride_kv_row; + int stride_indices_b, stride_indices_s_q; + int stride_lse_b, stride_lse_s_q; + int stride_o_b, stride_o_s_q, stride_o_h_q; + int stride_extra_kv_block, stride_extra_kv_row; + int stride_extra_indices_b, stride_extra_indices_s_q; + + cudaStream_t stream; + + // SplitKV-related parameters + float* __restrict__ lse_accum; // [num_splits, s_q, h_q] + float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v] + int stride_lse_accum_split, stride_lse_accum_s_q; + int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q; + DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous + int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous + int num_sm_parts; +}; + +struct CombineParams { + int b, s_q, h_q, d_v; + + float* __restrict__ lse; // [b, s_q, h_q] + void* __restrict__ out; // [b, s_q, h_q, d_v] + int stride_lse_b, stride_lse_s_q; + int stride_o_b, stride_o_s_q, stride_o_h_q; + + float* __restrict__ lse_accum; // [num_splits, s_q, h_q] + float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v] + int stride_lse_accum_split, stride_lse_accum_s_q; + int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q; + + DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous + int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous + int num_sm_parts; + + float* attn_sink; // [h_q], may be nullptr + + cudaStream_t stream; +}; + +struct GetDecodeSchedMetaParams { + int b; // batch size + int s_q; int block_size_n; int fixed_overhead_num_blocks; + + int topk, extra_topk; // -1 if sparse attention (or extra topk) is disabled + int *__restrict__ topk_length, *__restrict__ extra_topk_length; + + int *__restrict__ seqlens_k_ptr; // Only necessary for dense attention + + DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr; + int *__restrict__ num_splits_ptr; int num_sm_parts; - int topk; 
+ + cudaStream_t stream; }; -struct SparsePrefillParams { +struct SparseAttnFwdParams { int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk; float sm_scale, sm_scale_div_log2; @@ -70,7 +150,10 @@ struct SparsePrefillParams { cutlass::bfloat16_t* __restrict__ q; // [s_q, h_q, d_qk] cutlass::bfloat16_t* __restrict__ kv; // [s_kv, h_kv, d_qk] int* __restrict__ indices; // [s_q, h_kv, topk] + float* __restrict__ attn_sink; // [h_q], may be nullptr + int* __restrict__ topk_length; // [s_q], may be nullptr + // Strides int stride_q_s_q; int stride_q_h_q; int stride_kv_s_kv; int stride_kv_h_kv; int stride_indices_s_q; int stride_indices_h_kv; @@ -80,5 +163,18 @@ struct SparsePrefillParams { float* __restrict__ max_logits; // [s_q, h_q] float* __restrict__ lse; // [s_q, h_q] + int num_sm; cudaStream_t stream; }; + +// We have some kernels that implement both prefill and decode modes in a single kernel (with different template instantiations). The following enum helps to distinguish the modes. +enum class SparseAttnFwdMode { + Prefill, // Normal prefill mode + DecodeWithSplitKV, // To trigger decoding mode for kernels that support both prefill and decode +}; + +template +inline constexpr bool is_decode_v = std::bool_constant::value; + +template +using SparseFwdArgT = std::conditional_t, SparseAttnDecodeParams, SparseAttnFwdParams>; diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp deleted file mode 100644 index 13541d4..0000000 --- a/csrc/pybind.cpp +++ /dev/null @@ -1,472 +0,0 @@ -// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. 
- ******************************************************************************/ - -#include -#include -#include -#include - -#include - -#include "params.h" -#include "smxx/get_mla_metadata.h" -#include "smxx/mla_combine.h" -#include "sm90/decode/dense/splitkv_mla.h" -#include "sm90/decode/sparse_fp8/splitkv_mla.h" -#include "sm90/prefill/sparse/fwd.h" -#include "sm100/decode/sparse_fp8/splitkv_mla.h" -#include "sm100/prefill/dense/interface.h" -#include "sm100/prefill/sparse/fwd.h" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -struct Arch { - int major; - int minor; - - bool is_sm90() const { - return major == 9 && minor == 0; - } - - bool is_sm100() const { - return major == 10; - } - - void assert_is_supported() const { - TORCH_CHECK(is_sm90() || is_sm100(), "Only SM90 and SM100 are supported"); - } -}; - -// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. SM90 Dense BF16, SM90 Sparse FP8, etc.) 
-struct DecodingAttnImplMeta { - int num_sm_parts; - int fixed_overhead_num_blocks; - int k_block_size; -}; - -DecodingAttnImplMeta get_attn_impl_meta( - Arch arch, - int sm_count, - int num_q_tokens_per_head_k, - int h_k, - std::optional h_q_, - bool is_fp8_kvcache, - bool is_sparse_attn -) { - if (arch.is_sm90()) { - if (is_sparse_attn) { - if (is_fp8_kvcache) { - TORCH_CHECK(h_q_.has_value()); - int h_q = h_q_.value(); - TORCH_CHECK(h_q % h_k == 0); - int s_q = num_q_tokens_per_head_k * h_k / h_q; - // FP8 + Sparse MLA - return { - std::max((sm_count/2) / h_k / (cutlass::ceil_div(h_q/h_k, 2*64) * s_q), 1), - 5, - 64 - }; - } else { - // Sparse BF16 MLA - TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM90"); - } - } else { - if (is_fp8_kvcache) { - // Dense FP8 MLA - TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); - } else { - // Dense BF16 MLA - return { - std::max(sm_count / h_k / cutlass::ceil_div(num_q_tokens_per_head_k, 64), 1), - 5, - 64 - }; - } - } - } else if (arch.is_sm100()) { - if (is_sparse_attn) { - if (is_fp8_kvcache) { - TORCH_CHECK(h_q_.has_value()); - int h_q = h_q_.value(); - TORCH_CHECK(h_q % h_k == 0); - int s_q = num_q_tokens_per_head_k * h_k / h_q; - // FP8 + Sparse MLA - return { - std::max(sm_count / h_k / (cutlass::ceil_div(h_q/h_k, 64) * s_q), 1), - 5, - 64 - }; - } else { - // Sparse BF16 MLA - TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM100"); - } - } else { - if (is_fp8_kvcache) { - // FP8 MLA - TORCH_CHECK(false, "FP8 Dence MLA is not supported on SM100"); - } else { - // Normal BF16 MLA - TORCH_CHECK(false, "BF16 Dence MLA is not supported on SM100"); - } - } - } else { - TORCH_CHECK(false, "Unsupported GPU architecture"); - } -} - - -std::vector -get_mla_decoding_metadata( - at::Tensor &seqlens_k, - const int num_q_tokens_per_head_k, - const int h_k, - const std::optional h_q, - const bool is_fp8_kvcache, - const std::optional topk -) { - bool is_sparse_attn = topk.has_value(); - 
CHECK_DEVICE(seqlens_k); - TORCH_CHECK(seqlens_k.is_contiguous()); - TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); - if (is_sparse_attn) - TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided"); - - int batch_size = seqlens_k.size(0); - int *seqlens_k_ptr = seqlens_k.data_ptr(); - auto options = seqlens_k.options(); - - auto dprops = at::cuda::getCurrentDeviceProperties(); - int sm_count = dprops->multiProcessorCount; - Arch arch = {dprops->major, dprops->minor}; - arch.assert_is_supported(); - DecodingAttnImplMeta attn_impl_meta = get_attn_impl_meta(arch, sm_count, num_q_tokens_per_head_k, h_k, h_q, is_fp8_kvcache, is_sparse_attn); - - auto tile_scheduler_metadata = torch::empty({attn_impl_meta.num_sm_parts, TileSchedulerMetaDataSize}, options); - auto num_splits = torch::empty({batch_size + 1}, options); - int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); - int *num_splits_ptr = num_splits.data_ptr(); - - at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - GetDecodingMetadataParams params = {}; - params.seqlens_k_ptr = seqlens_k_ptr; - params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; - params.num_splits_ptr = num_splits_ptr; - params.batch_size = batch_size; - params.block_size_n = attn_impl_meta.k_block_size; - params.fixed_overhead_num_blocks = attn_impl_meta.fixed_overhead_num_blocks; - params.num_sm_parts = attn_impl_meta.num_sm_parts; - params.topk = is_sparse_attn ? 
topk.value() : -1; - run_get_mla_metadata_kernel(params, stream); - - return {tile_scheduler_metadata, num_splits}; -} - -std::vector -fwd_kvcache_mla( - at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True) - const int head_size_v, - const at::Tensor &seqlens_k, // batch_size - const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq - const float softmax_scale, - bool is_causal, - const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize - const at::Tensor &num_splits, // batch_size + 1 - const bool &is_fp8, - const std::optional &indices // None, or batch_size x seqlen_q x topk -) { - bool is_sparse_attn = indices.has_value(); - int topk = is_sparse_attn ? indices->size(-1) : -1; - - // Check the architecture - auto dprops = at::cuda::getCurrentDeviceProperties(); - Arch arch = {dprops->major, dprops->minor}; - arch.assert_is_supported(); - - // Check data types - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); - - if (!is_fp8) { - TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); - } else { - TORCH_CHECK(kcache.dtype() == torch::kFloat8_e4m3fn || kcache.dtype() == torch::kInt8 || kcache.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn or int8 or uint8"); - } - TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); - TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); - TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); - TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32"); - - // 
Check device - CHECK_DEVICE(q); - CHECK_DEVICE(kcache); - CHECK_DEVICE(seqlens_k); - CHECK_DEVICE(block_table); - CHECK_DEVICE(tile_scheduler_metadata); - CHECK_DEVICE(num_splits); - if (is_sparse_attn) CHECK_DEVICE(indices.value()); - - // Check layout - TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension"); - TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension"); - CHECK_CONTIGUOUS(seqlens_k); - TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); - CHECK_CONTIGUOUS(tile_scheduler_metadata); - CHECK_CONTIGUOUS(num_splits); - TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension"); - - const auto sizes = q.sizes(); - const int batch_size = sizes[0]; - const int seqlen_q_ori = sizes[1]; - const int num_heads_q = sizes[2]; - const int head_size_k = sizes[3]; - TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); - TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported"); - - const int max_num_blocks_per_seq = block_table.size(1); - const int num_blocks = kcache.size(0); - const int page_block_size = kcache.size(1); - const int num_heads_k = kcache.size(2); - TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64"); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - if (seqlen_q_ori == 1) { is_causal = false; } - - const int num_q_heads_per_hk = num_heads_q / num_heads_k; - const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; - const int num_heads = num_heads_k; - q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) - .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); - - CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); - if (!is_fp8) { - CHECK_SHAPE(kcache, num_blocks, 
page_block_size, num_heads_k, head_size_k); - } else { - int bytes_per_token = 512 + 64*2 + (512/128)*4; - CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, bytes_per_token); - TORCH_CHECK(num_heads_k == 1, "Currently the number of k heads must be 1 when is_fp8_kvcache is True"); - TORCH_CHECK(kcache.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True"); - } - CHECK_SHAPE(seqlens_k, batch_size); - CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); - TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); - CHECK_SHAPE(num_splits, batch_size+1); - if (is_sparse_attn) CHECK_SHAPE(indices.value(), batch_size, seqlen_q_ori, topk); - - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - - auto opts = q.options(); - at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); - at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); - CHECK_CONTIGUOUS(softmax_lse); - - DecodingParams params = {}; - // Set the sizes. - params.b = batch_size; - params.s_q = seqlen_q_ori; - params.q_seq_per_hk = q_seq_per_hk; - params.seqlens_k_ptr = seqlens_k.data_ptr(); - params.h_q = num_heads_q; - params.h_k = num_heads_k; - params.num_blocks = num_blocks; - params.q_head_per_hk = num_q_heads_per_hk; - params.is_causal = is_causal; - params.d = head_size_k; - params.d_v = head_size_v; - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); - params.topk = topk; - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = kcache.data_ptr(); - params.o_ptr = out.data_ptr(); - params.indices_ptr = is_sparse_attn ? indices->data_ptr() : nullptr; - params.softmax_lse_ptr = softmax_lse.data_ptr(); - // All stride are in elements, not bytes. 
- params.q_batch_stride = q.stride(0); - params.k_batch_stride = kcache.stride(0); - params.o_batch_stride = out.stride(0); - params.q_row_stride = q.stride(-3); - params.k_row_stride = kcache.stride(1); - params.o_row_stride = out.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = kcache.stride(2); - params.o_head_stride = out.stride(-2); - params.indices_batch_stride = is_sparse_attn ? indices->stride(0) : 0; - params.indices_row_stride = is_sparse_attn ? indices->stride(1) : 0; - - params.block_table = block_table.data_ptr(); - params.block_table_batch_stride = block_table.stride(0); - params.page_block_size = page_block_size; - - params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); - params.num_sm_parts = tile_scheduler_metadata.size(0); - params.num_splits_ptr = num_splits.data_ptr(); - - const int total_num_splits = batch_size + params.num_sm_parts; - at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); - at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); - CHECK_CONTIGUOUS(softmax_lse_accum); - CHECK_CONTIGUOUS(out_accum); - params.total_num_splits = total_num_splits; - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_ptr = out_accum.data_ptr(); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(head_size_k == 576); - - if (q_dtype == torch::kHalf) { -#ifdef FLASH_MLA_DISABLE_FP16 - TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. 
Please remove this flag from your environment and re-compile FlashMLA."); -#endif - } - - if (arch.is_sm90()) { - if (is_sparse_attn) { - if (is_fp8) { - TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90"); - sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); - } else { - TORCH_CHECK(false, "Only FP8 kvcahe is supported for sparse MLA on SM90"); - } - } else { - if (is_fp8) { - TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); - } else { - if (q_dtype == torch::kBFloat16) { - sm90::run_flash_splitkv_mla_kernel(params, stream); - } else if (q_dtype == torch::kHalf) { -#ifndef FLASH_MLA_DISABLE_FP16 - sm90::run_flash_splitkv_mla_kernel(params, stream); -#endif - } else { - TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90"); - } - } - } - } else if (arch.is_sm100()) { - TORCH_CHECK(is_fp8 && is_sparse_attn, "Only FP8 + Sparse attention is supported on SM100"); - sm100::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); - } else { - TORCH_CHECK(false, "Unsupported GPU architecture"); - } - - if (q_dtype == torch::kBFloat16) { - run_flash_mla_combine_kernel(params, stream); - } else if (q_dtype == torch::kHalf) { -#ifndef FLASH_MLA_DISABLE_FP16 - run_flash_mla_combine_kernel(params, stream); -#endif - } else { - TORCH_CHECK(false, "Unsupported tensor dtype for query"); - } - - out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) - .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); - softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) - .reshape({batch_size, num_heads_q, seqlen_q_ori}); - - return {out, softmax_lse}; -} - - -inline int int64_stride_to_int(int64_t orig_stride) { - if (orig_stride > std::numeric_limits::max()) { - TORCH_CHECK(false, "[Sparse TopK Attention] Stride exceeds int32 limit: ", orig_stride); - } - return static_cast(orig_stride); -} - -std::vector 
sparse_prefill_fwd( - const at::Tensor &q, - const at::Tensor &kv, - const at::Tensor &indices, - float sm_scale, - int d_v -) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm90 = dprops->major == 9; - bool is_sm100 = dprops->major == 10; - TORCH_CHECK(is_sm90 || is_sm100, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 or SM100 architectures"); - - CHECK_DEVICE(q); - CHECK_DEVICE(kv); - CHECK_DEVICE(indices); - - TORCH_CHECK(q.dtype() == torch::kBFloat16); - TORCH_CHECK(kv.dtype() == torch::kBFloat16); - TORCH_CHECK(indices.dtype() == torch::kInt32); - - int s_q = q.size(0); - int s_kv = kv.size(0); - int h_q = q.size(1); - int h_kv = kv.size(1); - int d_qk = q.size(2); - int topk = indices.size(2); - - CHECK_SHAPE(q, s_q, h_q, d_qk); - CHECK_SHAPE(kv, s_kv, h_kv, d_qk); - CHECK_SHAPE(indices, s_q, h_kv, topk); - - TORCH_CHECK(q.stride(-1) == 1); - TORCH_CHECK(kv.stride(-1) == 1); - TORCH_CHECK(indices.stride(-1) == 1); - - at::cuda::CUDAGuard device_guard{(char)q.get_device()}; - auto opts = q.options(); - at::Tensor out = torch::empty({s_q, h_q, d_v}, opts); - CHECK_CONTIGUOUS(out); - - at::Tensor buf_attn_score, max_logits, lse, p_sum; - max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); - lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); - CHECK_CONTIGUOUS(max_logits); - CHECK_CONTIGUOUS(lse); - - SparsePrefillParams params = { - s_q, s_kv, h_q, h_kv, d_qk, d_v, topk, - sm_scale, sm_scale * 1.44269504f, - - (cutlass::bfloat16_t*)q.data_ptr(), - (cutlass::bfloat16_t*)kv.data_ptr(), - (int*)indices.data_ptr(), - - int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), - int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)), - int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)), - - (cutlass::bfloat16_t*)out.data_ptr(), - (float*)max_logits.data_ptr(), - (float*)lse.data_ptr(), - - at::cuda::getCurrentCUDAStream().stream() - }; - - 
if (is_sm90) { - sm90::run_fwd_kernel(params); - } else if (is_sm100) { - sm100::run_fwd_kernel(params); - } else { - TORCH_CHECK(false, "Unknown architecture"); - } - - return {out, max_logits, lse}; -} - - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashMLA"; - m.def("get_mla_decoding_metadata", &get_mla_decoding_metadata); - m.def("fwd_kvcache_mla", &fwd_kvcache_mla); - m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun); - m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun); - m.def("sparse_prefill_fwd", &sparse_prefill_fwd); -} diff --git a/csrc/sm100/decode/head128/README.md b/csrc/sm100/decode/head128/README.md new file mode 100644 index 0000000..6cd9062 --- /dev/null +++ b/csrc/sm100/decode/head128/README.md @@ -0,0 +1 @@ +Head128 decoding kernels are located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu` (for k_dim = 512) or simulated using 2x head64 kernel \ No newline at end of file diff --git a/csrc/sm100/decode/head64/config.h b/csrc/sm100/decode/head64/config.h new file mode 100644 index 0000000..401f3ac --- /dev/null +++ b/csrc/sm100/decode/head64/config.h @@ -0,0 +1,212 @@ +#pragma once + +#include "kernel.h" + +#include +#include +#include + +#include + +#include "defines.h" +#include "params.h" + +namespace sm100::decode::head64 { + +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::NamedBarrier; +using e8m0 = __nv_fp8_e8m0; +using e4m3 = cutlass::float_e4m3_t; +using namespace cute; + +enum NamedBarriers : uint32_t { + main_loop_sync = 0, + wg0_sync = 1, + wg0_warp02_sync = 2, + wg0_warp13_sync = 3, + everyone_sync = 4 +}; + +template +struct KernelTemplate { + +static constexpr int D_Q = MODEL_TYPE == ModelType::V32 ? 576 : 512; +static constexpr int D_K = D_Q; +static constexpr int D_V = 512; +static constexpr int D_NOPE = MODEL_TYPE == ModelType::V32 ? 
512 : 448; +static constexpr int D_ROPE = 64; +static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64; +static constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true; +static constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8; // Padding is included +static constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE; // Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens and 2) be large enough to cover the entire KV cache. Since TMA copy coordinates can only be 32-bit signed integers, this number must be >= 128, preferably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks. +static_assert(D_NOPE + D_ROPE == D_Q); +static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V)); + +static constexpr int B_H = 64; +static constexpr int B_TOPK = 64; +static constexpr int NUM_BUFS = 2; +static constexpr int NUM_INDEX_BUFS = 4; // Number of buffers for indices (tma_coords) & is_token_valid & scales +static constexpr int NUM_THREADS = 128*3; // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant +static constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN + +static constexpr int D_Q_SW128 = 512; +static constexpr int D_Q_SW64 = MODEL_TYPE == ModelType::V32 ? 64 : 0; +static_assert(D_Q_SW128 + D_Q_SW64 == D_Q); +static constexpr int K_ROPE_SW = MODEL_TYPE == ModelType::V32 ? 
64 : 128; // RoPE part stored in SW64 (for V32) or SW128 (for MODEL1), in bytes + +template< + typename Shape_Q_SW128, typename TMA_Q_SW128, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q_SW128 shape_Q_SW128; TMA_Q_SW128 tma_Q_SW128; + Shape_O shape_O; TMA_O tma_O; + CUtensorMap tensor_map_q_sw64; // Invalid if D_Q_SW64 == 0 + CUtensorMap tensor_map_kv_nope; + CUtensorMap tensor_map_kv_rope; + CUtensorMap tensor_map_extra_kv_nope; + CUtensorMap tensor_map_extra_kv_rope; +}; + +// Tensor memory columns +struct tmem_cols { + // 0 ~ 256: output + // 256 ~ 256 + 64*D_Q/256: Q + // 400 ~ 464: P + static constexpr int O = 0; + static constexpr int Q = 256; + static constexpr int Q_Tail = 256 + B_H*D_NOPE/2/128; + static constexpr int P = 400; +}; + +template +using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutQ_SW128 = SmemLayoutQTiles; + +using SmemLayoutOBuf = decltype(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +)); + +using SmemLayoutOBuf_TMA = decltype(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64>>{} +)); // A TMA tile + +static_assert(D_V == 512); +using SmemLayoutOAccumBuf = Layout< + Shape, Int>, + Stride, _1> // We use stride = 520 here to avoid bank conflict +>; + +using SmemLayoutS = decltype(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTiles_SW128 = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTiles_DualGemm_SW128 = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTilesTransposed_SW128 = decltype(composition( + SmemLayoutKTiles_SW128{}, + Layout< + Shape, Int>, + Stride, _1> + >{} +)); + 
+template +using SmemLayoutKTiles_SW64 = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW64_Atom{}, + Shape, Int<32*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTiles_DualGemm_SW64 = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW64_Atom{}, + Shape, Int<32*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTilesTransposed_SW64 = decltype(composition( + SmemLayoutKTiles_SW64{}, + Layout< + Shape, Int>, + Stride, _1> + >{} +)); + +struct SharedMemoryPlan { + union { + struct { + array_aligned> q; + bf16 q_sw64[B_H*D_Q_SW64]; // NOTE D_Q_SW64 may be 0 but array_aligned will have a size of 16, so we use array here. The former tensor (`q`) promises its alignment. + union { + array_aligned> o_buf; + array_aligned> o_accum_buf; + } o; + } qo; + struct { + struct { + array_aligned nope; // NoPE part, dequantized + array_aligned rope; // RoPE part, dequantized. SW64 in v32 mode, SW128 in MODEL1 mode + } dequant[NUM_BUFS]; + static_assert(sizeof(dequant) >= sizeof(bf16) * (B_H*D_Q)); // So that Q does not cover raw_nope + array_aligned raw_nope[NUM_BUFS]; // Raw (quantized) NoPE part + } kv; + } u; + union { + float4 p_exchange_buf[4][16 * B_TOPK / 4]; + array_aligned> s; + } s_p; + CUTE_ALIGNAS(16) float rowwise_max_buf[128]; + char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8]; + int tma_coord[NUM_INDEX_BUFS][B_TOPK]; + e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN]; + array_aligned tmem_start_addr; + transac_bar_t bar_last_store_done; + transac_bar_t bar_q_tma, bar_q_utccp; + transac_bar_t bar_rope_ready[NUM_BUFS]; + transac_bar_t bar_nope_ready[NUM_BUFS]; + transac_bar_t bar_raw_ready[NUM_BUFS], bar_raw_free[NUM_BUFS]; + transac_bar_t bar_valid_coord_scale_ready[NUM_INDEX_BUFS], bar_valid_coord_scale_free[NUM_INDEX_BUFS]; + transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS], bar_sv_done[NUM_BUFS]; +}; + +using TiledMMA_P = decltype(make_tiled_mma( + 
SM100_MMA_F16BF16_WS_TS_NOELECT{} +)); // *2 for dual gemm + +using TiledMMA_O = decltype(make_tiled_mma( + SM100_MMA_F16BF16_WS_SS_NOELECT{} +)); + +template +static __device__ void +flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams ¶ms, const TmaParam &tma_params); + +static void run(const SparseAttnDecodeParams ¶ms); + +}; + +} \ No newline at end of file diff --git a/csrc/sm100/decode/head64/instantiations/model1.cu b/csrc/sm100/decode/head64/instantiations/model1.cu new file mode 100644 index 0000000..868ff0c --- /dev/null +++ b/csrc/sm100/decode/head64/instantiations/model1.cu @@ -0,0 +1,8 @@ +#include "../kernel.cuh" + +namespace sm100::decode::head64 { + +template +void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); + +} diff --git a/csrc/sm100/decode/head64/instantiations/v32.cu b/csrc/sm100/decode/head64/instantiations/v32.cu new file mode 100644 index 0000000..08ce093 --- /dev/null +++ b/csrc/sm100/decode/head64/instantiations/v32.cu @@ -0,0 +1,8 @@ +#include "../kernel.cuh" + +namespace sm100::decode::head64 { + +template +void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); + +} diff --git a/csrc/sm100/decode/head64/kernel.cuh b/csrc/sm100/decode/head64/kernel.cuh new file mode 100644 index 0000000..7c46921 --- /dev/null +++ b/csrc/sm100/decode/head64/kernel.cuh @@ -0,0 +1,968 @@ +#include "kernel.h" + +#include +#include +#include +#include +#include +#include + +#include "kerutils/kerutils.cuh" + +#include "utils.h" +#include "sm100/helpers.h" + +#include "config.h" + +namespace sm100::decode::head64 { + +template +template +__device__ void +KernelTemplate +::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams ¶ms, const TmaParam &tma_params) { +#if defined(KERUTILS_ENABLE_SM100A) + const int s_q_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int idx_in_warpgroup 
= threadIdx.x % 128; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + const int lane_idx = threadIdx.x % 32; + + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + + if (warp_idx == 0 && elect_one_sync()) { + cute::prefetch_tma_descriptor(tma_params.tma_Q_SW128.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_q_sw64); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope); + } + + if (warp_idx == 0) { + if (elect_one_sync()) { + plan.bar_last_store_done.init(128); + plan.bar_q_tma.init(1); + plan.bar_q_utccp.init(1); + for (int i = 0; i < NUM_BUFS; ++i) { + plan.bar_rope_ready[i].init(1); + plan.bar_nope_ready[i].init(128); + plan.bar_raw_ready[i].init(1); + plan.bar_raw_free[i].init(128); + plan.bar_qk_done[i].init(1); + plan.bar_so_ready[i].init(128); + plan.bar_sv_done[i].init(1); + } + for (int i = 0; i < NUM_INDEX_BUFS; ++i) { + plan.bar_valid_coord_scale_ready[i].init(32); + plan.bar_valid_coord_scale_free[i].init(128+128+1+1); + } + cutlass::arch::fence_barrier_init(); + } + cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data()); + KU_TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); + cute::TMEM::Allocator1Sm().release_allocation_lock(); + } + __syncthreads(); + + struct MainLoopArgs { + int batch_idx, start_block_idx, end_block_idx; + bool is_no_split; int n_split_idx; + bool bar_phase_batch_rel; // Bar phase of barriers that are used once per batch + int topk_length, extra_topk_length, num_orig_kv_blocks; + bool is_last_batch; + }; + + auto run_main_loop = [&](auto f) { + // NOTE Putting the following code outside the warpgroup specialization switch results in register spilling. 
+ // [[maybe_unused]] int begin_req_idx, end_req_idx, sched_begin_block_idx, sched_end_block_idx, begin_n_split_idx, is_first_req_splitted, is_last_req_splitted; + DecodingSchedMeta sched_meta; + KU_LDG_256( + params.tile_scheduler_metadata_ptr + partition_idx, + &sched_meta, + ".nc", + "no_allocate", + "evict_normal", + "256B" + ); + + if (sched_meta.begin_req_idx >= params.b) { + return; + } + + bool bar_phase_batch_rel = 0; + #pragma unroll 1 + for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx, bar_phase_batch_rel ^= 1) { + int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk; + int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK); + int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk; + int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0 + int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; + int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK; + bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false); + int n_split_idx = batch_idx == sched_meta.begin_req_idx ? 
(__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx); + + MainLoopArgs args = { + batch_idx, start_block_idx, end_block_idx, + !is_split, n_split_idx, + bar_phase_batch_rel, + topk_length, extra_topk_length, + orig_topk_padded / B_TOPK, + batch_idx == sched_meta.end_req_idx + }; + + f(args); + NamedBarrier(NUM_THREADS, NamedBarriers::everyone_sync).arrive_and_wait_unaligned(); + } + }; + + struct RingState { + int buf_idx = 0; + bool bar_phase = 0; + int index_buf_idx = 0; + bool index_bar_phase = 0; + CUTE_DEVICE void update() { + bar_phase ^= (buf_idx == NUM_BUFS-1); + buf_idx = (buf_idx+1) % NUM_BUFS; + index_bar_phase ^= (index_buf_idx == NUM_INDEX_BUFS-1); + index_buf_idx = (index_buf_idx+1) % NUM_INDEX_BUFS; + } + }; + RingState rs; + + if (warpgroup_idx == 0) { + // Scale & Exp warpgroup + // The same technique (and highly similar code) as the sm100 sparse prefill head64 kernel + cutlass::arch::warpgroup_reg_alloc<224>(); + + constexpr int B_EPI = 64; // Must be equal to the size of the swizzle atom + Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_buf.data()), SmemLayoutOBuf{}); + bf16* sO_bases[B_EPI/8]; // 64 is the size of the swizzle atom (in number of elements) while 8 is the width of each write + CUTE_UNROLL + for (int i = 0; i < B_EPI/8; ++i) + sO_bases[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*128 + i*8); + + const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2}; + bf16* sS_base = plan.s_p.s.data() + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2); + + float attn_sink = params.attn_sink == nullptr ? 
-CUDART_INF_F : __ldg((float*)params.attn_sink + (idx_in_warpgroup%64)) * CUDART_L2E_F; + + run_main_loop([&](const MainLoopArgs &args) { + cute::tma_store_wait<0>(); + plan.bar_last_store_done.arrive(); + + float mi = MAX_INIT_VAL; + float li = 0.0f; + float real_mi = -CUDART_INF_F; + + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { + NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // Make sure all intermediate buffers (including p_exchange_buf and rowwise_max_buf) are free + plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); // Put the barrier wait here for more code reordering space + plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase); + ku::tcgen05_after_thread_sync(); + + // Load P + float p[B_TOPK/2], p_peer[B_TOPK/2]; + if (warp_idx < 2) { + ku::tmem_ld_32dp32bNx(tmem_cols::P, p); + ku::tmem_ld_32dp32bNx(tmem_cols::P+32, p_peer); + } else { + ku::tmem_ld_32dp32bNx(tmem_cols::P, p_peer); + ku::tmem_ld_32dp32bNx(tmem_cols::P+32, p); + } + cutlass::arch::fence_view_async_tmem_load(); + ku::tcgen05_before_thread_sync(); + + // Reduce within shared mem + { + // Store + // Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (col 0 ~ 31) part + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/4; ++i) + plan.s_p.p_exchange_buf[warp_idx^2][i*32 + lane_idx] = *(float4*)(p_peer + i*4); + NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1)); // Synchronize between warp 0 and warp 2, as well as between warp 1 and warp 3 + // Load + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/4; ++i) { + float2 t[2]; + *(float4*)t = plan.s_p.p_exchange_buf[warp_idx][i*32 + lane_idx]; + float2* cur_p = (float2*)(p + i*4); + cur_p[0] = ku::float2_add(cur_p[0], t[0]); + cur_p[1] = ku::float2_add(cur_p[1], t[1]); + } + } + + // Since dual gemm is utilized, the layout of P in registers now looks like: + // + // 32 32 + // +-------+-------+ + // | | | + // 32 | 
Warp0 | Warp2 | + // | | | + // +-------+-------+ + // | | | + // 32 | Warp1 | Warp3 | + // | | | + // +-------+-------+ + + // Mask + uint32_t valid_mask = *((uint32_t*)plan.is_token_valid[rs.index_buf_idx] + (idx_in_warpgroup>=64?1:0)); + CUTE_UNROLL + for (int i = 0; i < B_TOPK/2; i += 1) { + if (!(valid_mask>>i&1)) + p[i] = -CUDART_INF_F; + } + + // Get rowwise max of Pi + float cur_pi_max = -CUDART_INF_F; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2); i += 1) { + cur_pi_max = max(cur_pi_max, p[i]); + } + cur_pi_max *= params.sm_scale_div_log2; + + plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max; + NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // This also separates "reading p_exchange_buf" and "writing S" + plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive(); + cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]); + real_mi = max(real_mi, cur_pi_max); + bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f); + // By this point: + // - cur_pi_max, real_mi, and mi are identical within each row (i.e. across thread pairs 0/64, 1/65, ...) + // - should_scale_o is uniform within every warp, and is identical among threads that control the same rows (i.e. 
among threads 0~31+64~95; and is identical among threads 32~63+96~127) + + // Calc scale factor, and scale li + float new_max, scale_for_old; + if (!should_scale_o) { + // Don't scale O + scale_for_old = 1.0f; + new_max = mi; + } else { + new_max = max(cur_pi_max, mi); + scale_for_old = exp2f(mi - new_max); + } + mi = new_max; // mi is still identical within each row + + // Calculate S + __nv_bfloat162 s[(B_TOPK/2)/2]; + float2 neg_new_max = float2 {-new_max, -new_max}; + float2 cur_sum = float2 {0.0f, 0.0f}; + CUTE_UNROLL + for (int i = 0; i < (B_TOPK/2)/2; i += 1) { + float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale, neg_new_max); + d.x = exp2f(d.x); + d.y = exp2f(d.y); + cur_sum = ku::float2_add(cur_sum, d); + s[i] = __float22bfloat162_rn(d); + } + li = fma(li, scale_for_old, (cur_sum.x + cur_sum.y)); + + // Write S + CUTE_UNROLL + for (int i = 0; i < B_TOPK/2/8; i += 1) { + *(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4); + } + + // Scale O + if (block_idx != args.start_block_idx && should_scale_o) { + float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old}; + ku::tcgen05_after_thread_sync(); + + static constexpr int CHUNK_SIZE = 64; + float2 o[CHUNK_SIZE/2]; + CUTE_UNROLL + for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) { + // Load O + ku::tmem_ld_32dp32bNx(tmem_cols::O + chunk_idx*CHUNK_SIZE, o); + cutlass::arch::fence_view_async_tmem_load(); + + // Mult + for (int i = 0; i < CHUNK_SIZE/2; ++i) { + o[i] = ku::float2_mul(o[i], scale_for_old_float2); + } + + // Store O + ku::tmem_st_32dp32bNx(tmem_cols::O + chunk_idx*CHUNK_SIZE, o); + cutlass::arch::fence_view_async_tmem_store(); + } + ku::tcgen05_before_thread_sync(); + } + + fence_view_async_shared(); + plan.bar_so_ready[rs.buf_idx].arrive(); + + if (block_idx != args.end_block_idx-1) { + rs.update(); // Don't update rs for the last round since we want to wait for the last SV gemm + } + } + + if (real_mi == -CUDART_INF_F) { + // real_mi == -CUDART_INF_F <=> 
No valid TopK indices + // We set li to 0 to fit the definition that li := exp(x[i] - mi) + li = 0.0f; + mi = -CUDART_INF_F; + } + + // Exchange li + plan.rowwise_max_buf[idx_in_warpgroup] = li; + NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); + li += plan.rowwise_max_buf[idx_in_warpgroup^64]; + + // Store li + if (idx_in_warpgroup < B_H) { + if (args.is_no_split) { + float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li)); + cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse; + float* gSoftmaxLse = (float*)params.lse + args.batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + idx_in_warpgroup; + *gSoftmaxLse = cur_lse; + } else { + float cur_lse = log2f(li) + mi; + float* gSoftmaxLseAccum = (float*)params.lse_accum + args.n_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + idx_in_warpgroup; + *gSoftmaxLseAccum = cur_lse; + } + } + + plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase); + rs.update(); + ku::tcgen05_after_thread_sync(); + + if (args.is_last_batch) { + cudaTriggerProgrammaticLaunchCompletion(); + } + + if (args.is_no_split) { + Tensor tma_gO = flat_divide( + tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, args.batch_idx), + Shape, Int<64>>{} + )(_, _, _0{}, _); + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + Tensor tma_sO = flat_divide( + sO, + Shape, Int<64>>{} + )(_, _, _0{}, _); + + float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li + exp2f(attn_sink - mi)); + float2 o_scale_float2 = {o_scale, o_scale}; + float2 o[B_EPI/2]; + __nv_bfloat162 o_bf16[B_EPI/2]; + CUTE_UNROLL + for (int i = 0; i < (D_V/2) / B_EPI; ++i) { + // Load + ku::tmem_ld_32dp32bNx(tmem_cols::O + i*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + // Scale & Convert + CUTE_UNROLL + for (int j = 0; j < B_EPI/2; ++j) { + o[j] = ku::float2_mul(o[j], o_scale_float2); + o_bf16[j] = __float22bfloat162_rn(o[j]); + } + // Store + int col_base = (i*B_EPI>=D_V/4 ? 
D_V/2 : 0) + (i*B_EPI%(D_V/4)); + CUTE_UNROLL + for (int j = 0; j < B_EPI / 8; ++j) + *(__int128_t*)(sO_bases[j] + col_base*B_H) = *(__int128_t*)(&o_bf16[j*4]); + // Sync + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); + // S -> G + if (warp_idx == 0 && elect_one_sync()) { + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(tma_sO(_, _, col_base/64)), + thr_tma.partition_D(tma_gO(_, _, col_base/64)) + ); + } + if (warp_idx == 1 && elect_one_sync()) { + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(tma_sO(_, _, col_base/64 + (D_V/4)/64)), + thr_tma.partition_D(tma_gO(_, _, col_base/64 + (D_V/4)/64)) + ); + } + } + cute::tma_store_arrive(); + } else { + float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li); // Here we leave attn_sink to the combine kernel, otherwise attn_sink would be applied multiple times + float2 o_scale_float2 = {o_scale, o_scale}; + constexpr int B_EPI = 64; + float2 o[B_EPI/2]; + Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_accum_buf.data()), SmemLayoutOAccumBuf{}); + CUTE_UNROLL + for (int i = 0; i < (D_V/2) / B_EPI; ++i) { + // Load + ku::tmem_ld_32dp32bNx(tmem_cols::O + i*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + // Scale (no conversion here: the accumulator is stored as fp32) + CUTE_UNROLL + for (int j = 0; j < B_EPI/2; ++j) + o[j] = ku::float2_mul(o[j], o_scale_float2); + // Store + int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? 
D_V/2 : 0) + (i*B_EPI%(D_V/4)); + CUTE_UNROLL + for (int j = 0; j < B_EPI / 4; ++j) + *(__int128_t*)&sO(idx_in_warpgroup%64, col_base + j*4) = *(__int128_t*)(&o[j*2]); + } + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); + if (elect_one_sync()) { + CUTE_UNROLL + for (int local_row = 0; local_row < B_H/4; ++local_row) { + int smem_row = local_row*4 + warp_idx; + SM90_BULK_COPY_S2G::copy( + &sO(smem_row, _0{}), + (float*)params.o_accum + args.n_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + smem_row*params.stride_o_accum_h_q, + D_V*sizeof(float) + ); + } + cute::tma_store_arrive(); + } + } + }); + + if (warp_idx == 0) { + cute::TMEM::Allocator1Sm().free(0, 512); + } + } else if (warpgroup_idx == 1) { + cutlass::arch::warpgroup_reg_dealloc<72>(); + const int warp_idx = cutlass::canonical_warp_idx_sync(); // Missing this leads to reg spilling + + if (warp_idx == 4 && elect_one_sync()) { + + // MMA Warp + run_main_loop([&](const MainLoopArgs &args) { + if (args.start_block_idx >= args.end_block_idx) { + ku::trap(); + } + // Issue Q (SW128) G->S + { + Tensor gQ = tma_params.tma_Q_SW128.get_tma_tensor(tma_params.shape_Q_SW128)(_, _, s_q_idx, args.batch_idx); + Tensor sQ = make_tensor(make_smem_ptr(plan.u.qo.q.data()), SmemLayoutQ_SW128{}); + ku::launch_tma_copy( + tma_params.tma_Q_SW128, + gQ, + sQ, + plan.bar_q_tma, + TMA::CacheHintSm90::EVICT_FIRST + ); + } + // Issue Q (SW64) G -> S + if constexpr (D_Q_SW64 > 0) { + cute::SM90_TMA_LOAD_5D::copy( + &tma_params.tensor_map_q_sw64, + (uint64_t*)&plan.bar_q_tma, + (uint64_t)TMA::CacheHintSm90::EVICT_FIRST, + plan.u.qo.q_sw64, + 0, 0, 0, + s_q_idx, args.batch_idx + ); + } + plan.bar_q_tma.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16)); + plan.bar_q_tma.wait(args.bar_phase_batch_rel); + ku::tcgen05_after_thread_sync(); + // Issue Q (SW128) UTCCP + { + UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc( + make_tensor( + make_smem_ptr(plan.u.qo.q.data()), + 
tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64>>{} // *2 to leverage dual GEMM + ) + ) + ); + static_assert(D_Q_SW128%128 == 0); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < D_Q_SW128/128; ++tile_idx) { + // Each tile: 64 x (64*2) logically, 128 x 64 bf16 on TMEM + CUTE_UNROLL + for (int subtile_idx = 0; subtile_idx < 64/16; ++subtile_idx) { + // Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM + SM100_UTCCP_128dp256bit_1cta::copy( + sQ_desc + (tile_idx*(B_H*128) + subtile_idx*16) * 2 / 16, + tmem_cols::Q + tile_idx*32 + subtile_idx*8 + ); + } + } + } + // Issue Q (SW64) UTCCP + if constexpr (D_Q_SW64 > 0) { + UMMA::SmemDescriptor sQ_SW64_desc = UMMA::make_umma_desc( + make_tensor( + make_smem_ptr(plan.u.qo.q_sw64), + tile_to_shape( + UMMA::Layout_K_SW64_Atom{}, + Shape, Int<32>>{} // *2 to leverage dual GEMM + ) + ) + ); + static_assert(D_Q_SW64%64 == 0); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < D_Q_SW64/64; ++tile_idx) { + // Each tile: 64 x (32*2) logically, 128 x 32 bf16 on TMEM + CUTE_UNROLL + for (int subtile_idx = 0; subtile_idx < 32/16; ++subtile_idx) { + // Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM + SM100_UTCCP_128dp256bit_1cta::copy( + sQ_SW64_desc + (tile_idx*(B_H*64) + subtile_idx*16) * 2 / 16, + tmem_cols::Q + (B_H*D_Q_SW128/2/128) + tile_idx*16 + subtile_idx*8 + ); + } + } + } + ku::umma_arrive_noelect(plan.bar_q_utccp); + + // Allocate tmem tensors + TiledMMA tiled_mma_P = TiledMMA_P{}; + TiledMMA tiled_mma_O = TiledMMA_O{}; + // NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm) + Tensor tP = partition_fragment_C(tiled_mma_P, Shape, _128>{}); + Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{}); + tP.data().get() = tmem_cols::P; + tO.data().get() = tmem_cols::O; + + // Wait for UTCCP + plan.bar_q_utccp.wait(args.bar_phase_batch_rel); + ku::tcgen05_after_thread_sync(); + + // Mainloop + 
CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { + if constexpr (MODEL_TYPE == ModelType::V32) { + // V3.2: RoPE behaves like an extra block with size 64, so we can do RoPE first + // QK RoPE + plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase); + ku::tcgen05_after_thread_sync(); + Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P, Shape, Int>{}) + ); + tQ_rope.data().get() = tmem_cols::Q_Tail; + Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].rope.data()), SmemLayoutKTiles_DualGemm_SW64<2/2>{}); + ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true); + + // QK NoPE + plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase); + ku::tcgen05_after_thread_sync(); + Tensor tQ_nope = tiled_mma_P.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P, Shape, Int>{}) + ); + tQ_nope.data().get() = tmem_cols::Q; + Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{}); + ku::utcmma_ts(tiled_mma_P, tQ_nope, sK_nope, tP, false); + } else { + // MODEL1: RoPE is the last 64 dims within the full 512 dim, which couples with the last 64 dim from the NoPE part when performing dual GEMM. i.e. 
+ // + // logical view: |0|1|2|3|4|5|6|7| (where 7 is the RoPE part) + // dual gemm's view: + // |0|2|4|6| + // |1|3|5|7| + // + // So we must wait for both the NoPE and the RoPE part, and then perform dual GEMM + plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase); + plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase); + ku::tcgen05_after_thread_sync(); + + Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P, Shape, Int>{}) + ); + tQ.data().get() = tmem_cols::Q; + Tensor sK = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{}); + ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true); + } + ku::umma_arrive_noelect(plan.bar_qk_done[rs.buf_idx]); + + // SV + plan.bar_so_ready[rs.buf_idx].wait(rs.bar_phase); + ku::tcgen05_after_thread_sync(); + Tensor sS = make_tensor(make_smem_ptr(plan.s_p.s.data()), SmemLayoutS{}); + Tensor sV = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTilesTransposed_SW128{}); // NOTE: For MODEL1, it "expands" to the RoPE part. 
+ ku::utcmma_ss(tiled_mma_O, sS, sV, tO, block_idx == args.start_block_idx); + ku::umma_arrive_noelect(plan.bar_sv_done[rs.buf_idx]); + + rs.update(); + } + }); + } else if (warp_idx == 5 && elect_one_sync()) { + // Raw KV NoPE retrieval warp + run_main_loop([&](const MainLoopArgs &args) { + plan.bar_q_utccp.wait(args.bar_phase_batch_rel); + plan.bar_last_store_done.wait(args.bar_phase_batch_rel); + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { + plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); + plan.bar_raw_free[rs.buf_idx].wait(rs.bar_phase^1); + int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0); + int4 nxt_cur_indices; + CUTE_UNROLL + for (int row = 0; row < B_TOPK; row += 4) { + if (row+4 < B_TOPK) + nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4); + ku::tma_gather4( + block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope, + plan.bar_raw_ready[rs.buf_idx], + plan.u.kv.raw_nope[rs.buf_idx].data() + D_NOPE*row, + 0, + cur_indices, + (int64_t)TMA::CacheHintSm90::EVICT_LAST + ); + cur_indices = nxt_cur_indices; + } + plan.bar_raw_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_NOPE*sizeof(e4m3)); + plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive(); + rs.update(); + } + }); + } else if (warp_idx == 6 && elect_one_sync()) { + // KV RoPE retrieval warp + run_main_loop([&](const MainLoopArgs &args) { + plan.bar_q_utccp.wait(args.bar_phase_batch_rel); + plan.bar_last_store_done.wait(args.bar_phase_batch_rel); + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { + plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); + if constexpr (MODEL_TYPE == ModelType::V32) { + plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase^1); + } else { + plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1); + } + int4 cur_indices = 
*(int4*)(plan.tma_coord[rs.index_buf_idx] + 0); + int4 nxt_cur_indices; + CUTE_UNROLL + for (int row = 0; row < B_TOPK; row += 4) { + if (row+4 < B_TOPK) + nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4); + CUTE_UNROLL + for (int t = 0; t < D_ROPE/(K_ROPE_SW/2); ++t) { + ku::tma_gather4( + block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope, + plan.bar_rope_ready[rs.buf_idx], + plan.u.kv.dequant[rs.buf_idx].rope.data() + (K_ROPE_SW/2)*row + t*B_TOPK*(K_ROPE_SW/2), + t*(K_ROPE_SW/2), + cur_indices, + (int64_t)TMA::CacheHintSm90::EVICT_LAST + ); + } + cur_indices = nxt_cur_indices; + } + plan.bar_rope_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16)); + plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive(); + rs.update(); + } + }); + } else if (warp_idx == 7) { + // Indices transformation warp + // Responsible for generating: TMA coordinates, scale factors, and valid masks + static_assert(B_TOPK == 64); + static constexpr int tma_coords_step_per_token = MODEL_TYPE == ModelType::V32 ? 656/TMA_K_STRIDE : 576/TMA_K_STRIDE; + int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE > 512 + int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE; + uint8_t* k_scales_ptr = + MODEL_TYPE == ModelType::V32 ? + (uint8_t*)params.kv + D_NOPE : + (uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE); + uint8_t* extra_k_scales_ptr = + MODEL_TYPE == ModelType::V32 ? 
+ (uint8_t*)params.extra_kv + D_NOPE : + (uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE); + + run_main_loop([&](const MainLoopArgs &args) { + int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*s_q_idx; + int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*s_q_idx; + + struct IsOrigBlock {}; + struct IsExtraBlock {}; + auto process_one_block = [&](int block_idx, auto is_extra_block_t) { + static constexpr bool IS_EXTRA_BLOCK = std::is_same_v; + int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size; + int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block; + [[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row; + uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr; + int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block; + + int abs_pos, my_indices[2]; + if (!IS_EXTRA_BLOCK) { + abs_pos = block_idx*B_TOPK + lane_idx*2; + *(int2*)my_indices = __ldg((int2*)(indices + abs_pos)); + } else { + abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2; + *(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos)); + } + plan.bar_valid_coord_scale_free[rs.index_buf_idx].wait(rs.index_bar_phase^1); + + int tma_coords[2]; + e8m0 scales[2*NUM_SCALES_EACH_TOKEN]; + char valid_mask = 0; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + int block_idx, idx_in_block; + block_idx = (unsigned int)my_indices[i] / cur_block_size; + idx_in_block = (unsigned int)my_indices[i] % cur_block_size; + bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length)); + valid_mask |= is_token_valid << i; + tma_coords[i] = is_token_valid ? 
block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because its topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN. + if constexpr (MODEL_TYPE == ModelType::V32) { + int64_t offset = is_token_valid ? block_idx*cur_k_block_stride + idx_in_block*cur_k_row_stride : 0; + float4 cur_scale_fp32 = __ldg((float4*)(cur_k_scales_ptr + offset)); + e8m0 res[4]; + *(__nv_fp8x2_storage_t*)(res+0) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.x, cur_scale_fp32.y}, __NV_NOSAT, cudaRoundZero); + *(__nv_fp8x2_storage_t*)(res+2) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.z, cur_scale_fp32.w}, __NV_NOSAT, cudaRoundZero); + if (!is_token_valid) *(uint32_t*)res = (uint32_t)0; + *(uint32_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = *(uint32_t*)(res); + } else { + int64_t offset = block_idx*cur_k_block_stride + idx_in_block*8; // Each token has 7 scale factors with an extra 1B padding + uint64_t scalesx8 = is_token_valid ? 
__ldg((uint64_t*)(cur_k_scales_ptr + offset)) : 0; + *(uint64_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = scalesx8; + } + } + valid_mask <<= lane_idx%4*2; + valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1); + valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2); + if constexpr (MODEL_TYPE == ModelType::V32) { + *(uint64_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(uint64_t*)scales; + } else { + *(__int128_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(__int128_t*)scales; + } + *(int2*)(plan.tma_coord[rs.index_buf_idx] + lane_idx*2) = *(int2*)tma_coords; + if (lane_idx%4 == 0) + plan.is_token_valid[rs.index_buf_idx][lane_idx/4] = valid_mask; + + plan.bar_valid_coord_scale_ready[rs.index_buf_idx].arrive(); + rs.update(); + }; + + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) { + process_one_block(block_idx, IsOrigBlock{}); + } + + CUTE_NO_UNROLL + for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) { + process_one_block(block_idx, IsExtraBlock{}); + } + }); + } else { + run_main_loop([&](const MainLoopArgs &args) {}); + } + } else { + // Dequant warpgroup + cutlass::arch::warpgroup_reg_alloc<208>(); + + // 8 threads per token + constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = D_NOPE/(GROUP_SIZE*8); + int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE; + Tensor nope0 = make_tensor(make_smem_ptr(plan.u.kv.dequant[0].nope.data()), SmemLayoutKTiles_SW128{}); + bf16* nope0_base = &nope0(group_idx, idx_in_group*8); + bf16* nope1_base = nope0_base + (plan.u.kv.dequant[1].nope.data() - plan.u.kv.dequant[0].nope.data()); + e4m3* raw_nope0_base = plan.u.kv.raw_nope[rs.buf_idx].data() + group_idx*D_NOPE + idx_in_group*8; // rs.buf_idx is still 0 here: rs has not been updated yet in this warpgroup + e4m3* raw_nope1_base = raw_nope0_base + B_H*D_NOPE; // NOTE(review): assumes raw_nope[1] starts exactly B_H*D_NOPE elements after raw_nope[0]; the gather warp fills B_TOPK rows of D_NOPE per buffer, so this relies on B_H == B_TOPK - confirm, or derive the offset from raw_nope[1].data() as done for nope1_base + + run_main_loop([&](const 
MainLoopArgs &args) { + // plan.bar_last_store_done.wait(args.bar_phase_batch_rel); // No need to wait since the raw nope producer must wait + plan.bar_q_utccp.wait(args.bar_phase_batch_rel); + + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { + plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); + plan.bar_raw_ready[rs.buf_idx].wait(rs.bar_phase); + plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1); + uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(rs.buf_idx == 0 ? nope0_base : nope1_base); + e4m3* raw_nope_base = rs.buf_idx == 0 ? raw_nope0_base : raw_nope1_base; + auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) { + asm volatile ("st.weak.shared::cta.b128 [%0], %1;\n" + : + : "r"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), "q"(data) // 2 for sizeof(bf16) + ); // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS + }; + auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t { + return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*D_NOPE + local_col_idx*(GROUP_SIZE*8)); + }; + // The following code suffers from a 2-way bank conflict when reading from SMEM. 
+ if constexpr (MODEL_TYPE == ModelType::V32) { + CUTE_UNROLL + for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) { + int row_idx = local_row_idx*NUM_GROUPS + group_idx; + bf16 scales[4]; + e8m0 scales_e8m0[4]; + *(uint32_t*)scales_e8m0 = *(uint32_t*)plan.scales[rs.index_buf_idx][row_idx]; + *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0)); + *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2)); + + uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0); + CUTE_UNROLL + for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) { + ku::nve4m3x2 data_fp8[4]; + ku::nvbf16x2 data_bf16[4]; + *(uint64_t*)data_fp8 = cur_data_fp8x8; + if (local_col_idx+1 < COLS_PER_GROUP) + cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1); + bf16 scale = scales[local_col_idx / (D_NOPE/(GROUP_SIZE*8)/4)]; + CUTE_UNROLL + for (int i = 0; i < 4; ++i) { + data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale)); + } + st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16); + } + } + } else { + CUTE_UNROLL + for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) { + int row_idx = local_row_idx*NUM_GROUPS + group_idx; + bf16 scales[8]; + e8m0 scales_e8m0[8]; + *(uint64_t*)scales_e8m0 = *(uint64_t*)plan.scales[rs.index_buf_idx][row_idx]; + *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0)); + *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2)); + *(__nv_bfloat162_raw*)(scales+4) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+4)); + *(__nv_bfloat162_raw*)(scales+6) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+6)); + + uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0); + CUTE_UNROLL + for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) { + 
ku::nve4m3x2 data_fp8[4]; + ku::nvbf16x2 data_bf16[4]; + *(uint64_t*)data_fp8 = cur_data_fp8x8; + if (local_col_idx+1 < COLS_PER_GROUP) + cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1); + bf16 scale = scales[local_col_idx]; + CUTE_UNROLL + for (int i = 0; i < 4; ++i) { + data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale)); + } + st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16); + } + } + } + cutlass::arch::fence_view_async_shared(); + plan.bar_nope_ready[rs.buf_idx].arrive(); + plan.bar_raw_free[rs.buf_idx].arrive(); + plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive(); + rs.update(); + } + }); + } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100 ~ sm119"); + } +#endif +} + +template +__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1) +flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) { + Kernel::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(params, tma_params); +} + +template +void KernelTemplate::run(const SparseAttnDecodeParams ¶ms) { + KU_ASSERT(params.topk % B_TOPK == 0, "topk (%d) mod B_TOPK (%d) must be 0", params.topk, B_TOPK); + KU_ASSERT(params.extra_topk % B_TOPK == 0, "extra_topk (%d) mod B_TOPK (%d) must be 0", params.extra_topk, B_TOPK); + KU_ASSERT(params.h_q == B_H); + KU_ASSERT(params.h_kv == 1); + KU_ASSERT(params.d_qk == D_Q); + KU_ASSERT(params.d_v == D_V); + if constexpr (MODEL_TYPE == ModelType::MODEL1) { + constexpr int BYTES_PER_TOKEN = D_NOPE + 2*D_ROPE + 8; + KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous + } + + auto shape_Q_SW128 = make_shape(B_H, D_Q, params.s_q, params.b); + auto tma_Q_SW128 = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q), + 
make_layout( + shape_Q_SW128, + make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b) + ) + ), + SmemLayoutQ_SW128{} + ); + + auto shape_O = make_shape(B_H, D_V, params.s_q, params.b); + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((bf16*)params.out), + make_layout( + shape_O, + make_stride(params.stride_o_h_q, _1{}, params.stride_o_s_q, params.stride_o_b) + ) + ), + SmemLayoutOBuf_TMA{} + ); + + CUtensorMap tensor_map_q_sw64{}; + if constexpr (D_Q_SW64 > 0) { + tensor_map_q_sw64 = ku::make_tensor_map( + {D_Q_SW64, (uint64_t)params.h_q, D_Q_SW64/32, (uint64_t)params.s_q, (uint64_t)params.b}, + ku::make_stride_helper(std::vector{params.stride_q_h_q, (int64_t)32, params.stride_q_s_q, params.stride_q_b}, sizeof(bf16)), + {32, B_H, D_Q_SW64/32, 1, 1}, + (bf16*)params.q + D_Q_SW128, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B + ); + } + + auto get_nope_rope_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t k_batch_stride) -> std::pair { + static_assert(D_NOPE%8 == 0); + KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra?"extra_":"", k_ptr); + KU_ASSERT(k_batch_stride % TMA_K_STRIDE == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. 
Padding might be necessary", is_extra?"extra_":"", k_batch_stride, TMA_K_STRIDE); + CUtensorMap tensor_map_kv_nope = ku::make_tensor_map( + {D_NOPE/8, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)}, + {TMA_K_STRIDE}, + {D_NOPE/8, 1}, + k_ptr, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT64, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B + ); // NOTE We combine 8 float8 into 1 int64 since boxdim cannot > 256 + CUtensorMap tensor_map_kv_rope = ku::make_tensor_map( + {D_ROPE, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)}, + {TMA_K_STRIDE}, + {K_ROPE_SW/2, 1}, + (uint8_t*)k_ptr + (MODEL_TYPE == ModelType::V32 ? (D_NOPE+16) : D_NOPE), + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + K_ROPE_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B + ); + return {tensor_map_kv_nope, tensor_map_kv_rope}; + }; + + auto [tensor_map_kv_nope, tensor_map_kv_rope] = get_nope_rope_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block); + CUtensorMap tensor_map_extra_kv_nope{}, tensor_map_extra_kv_rope{}; + if (params.extra_topk > 0) { + std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_nope_rope_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block); + } + + TmaParams< + decltype(shape_Q_SW128), decltype(tma_Q_SW128), + decltype(shape_O), decltype(tma_O) + > tma_params = { + shape_Q_SW128, tma_Q_SW128, + shape_O, tma_O, + tensor_map_q_sw64, + tensor_map_kv_nope, + tensor_map_kv_rope, + tensor_map_extra_kv_nope, + tensor_map_extra_kv_rope + }; + auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel, decltype(tma_params)>; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + static_assert(smem_size < 227*1024); + KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 
smem_size)); + + // NOTE Don't use PDL because of potential compiler bugs! + mla_kernel<<>>(params, tma_params); + KU_CHECK_KERNEL_LAUNCH(); +} + +template +void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms) { + KernelTemplate::run(params); +} + +} diff --git a/csrc/sm100/decode/head64/kernel.h b/csrc/sm100/decode/head64/kernel.h new file mode 100644 index 0000000..0b3c63c --- /dev/null +++ b/csrc/sm100/decode/head64/kernel.h @@ -0,0 +1,11 @@ +#pragma once + +#include "params.h" + +namespace sm100::decode::head64 { + +template +void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); + +} + diff --git a/csrc/sm100/decode/sparse_fp8/dequant.h b/csrc/sm100/decode/sparse_fp8/dequant.h deleted file mode 100644 index 3ed46e1..0000000 --- a/csrc/sm100/decode/sparse_fp8/dequant.h +++ /dev/null @@ -1,61 +0,0 @@ -#pragma once - -#include -#include - -#include "sm100/defines.h" - -namespace sm100 { - -struct fp8x8 { - __nv_fp8x4_e4m3 lo; - __nv_fp8x4_e4m3 hi; -}; - -struct fp8x32 { - fp8x8 a0, a1, a2, a3; -}; - -struct fp8x16 { - fp8x8 a0, a1; -}; - -__device__ __forceinline__ -bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) { - __nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale); - - #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \ - { \ - float4 fp32x4 = (float4)(FP8x4); \ - OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \ - OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \ - } - - bf16x8 result; - DEQUANT_FP8x4(result.a01, result.a23, inputs.lo); - DEQUANT_FP8x4(result.a45, result.a67, inputs.hi); - - return result; -} - -__device__ __forceinline__ -fp8x32 ldg_256_fp8x32(void* src_ptr) { - int32x8_t val; - asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" - : "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3), - "=r"(val.a4), "=r"(val.a5), "=r"(val.a6), 
"=r"(val.a7) - : "l"(src_ptr) - ); - return *reinterpret_cast(&val); -} - -__device__ __forceinline__ -fp8x16 ldg_128_fp8x16(void* src_ptr) { - int4 ret; - asm volatile("ld.global.nc.L1::evict_first.v4.s32 {%0, %1, %2, %3}, [%4];" - : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) - : "l"(src_ptr)); - return *reinterpret_cast(&ret); -} - -} diff --git a/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu deleted file mode 100644 index 068e9fd..0000000 --- a/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu +++ /dev/null @@ -1,592 +0,0 @@ -#include "splitkv_mla.h" - -#include -#include -#include -#include -#include - -#include "utils.h" -#include "dequant.h" -#include "sm100/defines.h" -#include "sm100/helpers.h" -#include "sm100/intrinsics.h" -#include "sm100/ws_gemm.h" - -namespace sm100 { - -using cutlass::arch::fence_view_async_shared; -using cutlass::arch::NamedBarrier; -using namespace cute; - -constexpr int B_H = 64; -constexpr int B_TOPK = 64; -constexpr int D_K = 576; -constexpr int D_V = 512; -constexpr int NUM_BUFS = 2; -constexpr int NUM_THREADS = 128*3; -constexpr int NUM_WORKING_THREADS = 128 + 128 + 32; -constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN - -template< - typename Shape_Q, typename TMA_Q, - typename Shape_O, typename TMA_O -> -struct TmaParams { - Shape_Q shape_Q; TMA_Q tma_Q; - Shape_O shape_O; TMA_O tma_O; -}; - -namespace tmem_addr { - constexpr int o = 0; // o: [0, 256] - constexpr int p = 256; // p: [256, 288] -}; - -using SmemLayoutQ = decltype(coalesce(tile_to_shape( - UMMA::Layout_K_SW128_Atom{}, - Shape, Int>{}, - Step<_1, _2>{} -), Shape<_1, _1>{})); - -using SmemLayoutOBuf = decltype(tile_to_shape( - UMMA::Layout_K_INTER_Atom{}, // TODO This may lead to TMA double traffic - Shape, Int>{} -)); - -using SmemLayoutOAccumBuf = Layout< - Shape, Int>, - Stride, _1> // We use stride = 520 here to avoid bank conflict ->; - -using SmemLayoutS = decltype(tile_to_shape( - 
UMMA::Layout_K_INTER_Atom{}, - Shape, Int>{}, - Step<_1, _2>{} -)); - -template -using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( - UMMA::Layout_K_INTER_Atom{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -), Shape<_1, _1>{})); - -template -using SmemLayoutKTilesTransposed = decltype(composition( - SmemLayoutKTiles{}, - Layout< - Shape, Int>, - Stride, _1> - >{} -)); - -using SmemLayoutK = SmemLayoutKTiles<9>; -using SmemLayoutV = SmemLayoutKTilesTransposed<8>; - -struct SharedMemoryPlan { - array_aligned> q; - union { - array_aligned> o_buf; - array_aligned> o_accum_buf; - array_aligned> k[NUM_BUFS]; - } u; - array_aligned> s; - transac_bar_t bar_q; - transac_bar_t bar_k_ready[NUM_BUFS], bar_k_free[NUM_BUFS]; - transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS]; - float rowwise_max_buf[128], rowwise_li_buf[128]; - bool is_token_valid[NUM_BUFS][B_TOPK]; - array_aligned tmem_start_addr; -}; - -using TiledMMA_QK = decltype(make_tiled_mma( - SM100_MMA_F16BF16_WS_SS_NOELECT{}, - Layout>{} -)); // TODO Use TS? 
- -using TiledMMA_SV = decltype(make_tiled_mma( - SM100_MMA_F16BF16_WS_SS_NOELECT{}, - Layout>{}, - Tile, Int>{} -)); - -template -CUTE_DEVICE -void store_128b(void* smem_ptr, const T &data) { - static_assert(sizeof(T) == 16); - *(__int128*)smem_ptr = *(__int128*)&data; -} - -template -__global__ void __launch_bounds__(NUM_THREADS, 1, 1) -flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { -#if IS_SM100 - const int head_block_idx = blockIdx.x; - const int s_q_idx = blockIdx.y; - const int partition_idx = blockIdx.z; - const int warpgroup_idx = cutlass::canonical_warp_group_idx(); - const int idx_in_warpgroup = threadIdx.x % 128; - const int warp_idx = cutlass::canonical_warp_idx_sync(); - - // Define shared tensors - extern __shared__ char wksp_buf[]; - SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); - Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{}); - - if (warp_idx == 0 && elect_one_sync()) { - cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); - cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); - } - - if (warp_idx == 0) { - if (elect_one_sync()) { - plan.bar_q.init(1); - for (int i = 0; i < NUM_BUFS; ++i) { - plan.bar_k_ready[i].init(128); - plan.bar_k_free[i].init(1); - plan.bar_qk_done[i].init(1); - plan.bar_so_ready[i].init(128); - } - cutlass::arch::fence_barrier_init(); - } - cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data()); - TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); - cute::TMEM::Allocator1Sm().release_allocation_lock(); - } - __syncthreads(); - - int bar_phase_k = 0; - - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); - int begin_idx = tile_scheduler_metadata.x; - int sched_begin_block_idx = 
tile_scheduler_metadata.y; - int end_idx = tile_scheduler_metadata.z; - int sched_end_block_idx = tile_scheduler_metadata.w; - if (begin_idx >= params.b) { - if (warp_idx == 0) { - cute::TMEM::Allocator1Sm().free(0, 512); - } - return; - } - - auto get_cur_req_info = [&](int batch_idx) -> std::tuple { - int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0; - int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : params.topk / B_TOPK; - bool is_no_split = start_block_idx == 0 && end_block_idx == params.topk / B_TOPK; - return {start_block_idx, end_block_idx, is_no_split}; - }; - - if (warpgroup_idx == 0) { - // Producer warpgroup - - #pragma unroll 1 - for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { - auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); - int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1) - - constexpr int GROUP_SIZE = 4, NUM_GROUPS = 128 / GROUP_SIZE; - constexpr int ROWS_PER_GROUP = B_TOPK / NUM_GROUPS; - int group_idx = idx_in_warpgroup / GROUP_SIZE; - int idx_in_group = idx_in_warpgroup % GROUP_SIZE; - - NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); - - CUTE_NO_UNROLL - for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { - int buf_idx = block_idx % NUM_BUFS; - - // Wait for buffer to be available - plan.bar_k_free[buf_idx].wait(bar_phase_k>>buf_idx&1^1); - - // Load - Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); - - CUTE_UNROLL - for (int local_row = 0; local_row < ROWS_PER_GROUP; ++local_row) { - int smem_row = group_idx + local_row*NUM_GROUPS; - int token_index = __ldg(gIndices + block_idx*B_TOPK + smem_row); - bool is_token_invalid = token_index == -1; - if (idx_in_group == 0) - plan.is_token_valid[buf_idx][smem_row] = !is_token_invalid; - if (is_token_invalid) { - uint128_t zeros = uint128_t{}; - CUTE_UNROLL - for 
(int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) { - int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16; - store_128b(&sK(smem_row, col_base ), zeros); - store_128b(&sK(smem_row, col_base+8), zeros); - } - CUTE_UNROLL - for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) { - int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8; - store_128b(&sK(smem_row, D_V+col_base), zeros); - } - } else { - int block_index = token_index/B_TOPK; - int rel_idx_in_block = (token_index+B_TOPK) % B_TOPK; // NOTE When token_index is -1, -1/B_TOPK = 0 and (-1+B_TOPK)%B_TOPK = 63, so there will be no illegal-memory-access error. However, masking is necessary to prevent NaN (TODO Skip some rows instead?) TODO Masking - fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride; - float4 scales = __ldg((float4*)(gK_base + D_V)); - - CUTE_UNROLL - for (int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) { - int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16; - fp8x16 cur_fp8s = ldg_128_fp8x16(gK_base + col_base); - float cur_scale = local_col < (256/(GROUP_SIZE*16)) ? - (local_col < (128/(GROUP_SIZE*16)) ? scales.x : scales.y) : - (local_col < (384/(GROUP_SIZE*16)) ? 
scales.z : scales.w); - store_128b(&sK(smem_row, col_base ), cvt_fp8x8_bf16x8(cur_fp8s.a0, cur_scale)); - store_128b(&sK(smem_row, col_base+8), cvt_fp8x8_bf16x8(cur_fp8s.a1, cur_scale)); - } - - CUTE_UNROLL - for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) { - int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8; - fp8x16 cur_k_rope_fp8s = ldg_128_fp8x16(gK_base + D_V + 4*sizeof(float) + col_base*sizeof(bf16)); - bf16x8 cur_k_rope = *reinterpret_cast(&cur_k_rope_fp8s); - store_128b(&sK(smem_row, D_V+col_base), cur_k_rope); - } - } - } - - fence_view_async_shared(); - - // Signal - plan.bar_k_ready[buf_idx].arrive(); - - bar_phase_k ^= 1<(); - - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); - - #pragma unroll 1 - for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { - auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); - - NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); - - float li = 0.0f; - float mi = MAX_INIT_VAL; - - CUTE_NO_UNROLL - for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { - int buf_idx = block_idx % NUM_BUFS; - - // Wait for P - plan.bar_qk_done[buf_idx].wait(bar_phase_k>>buf_idx&1); - tcgen05_after_thread_sync(); - - // Load P from TMEM - float p[B_TOPK/2]; - float2* p_float2 = reinterpret_cast(p); - tmem_ld_32dp32bNx(tmem_addr::p, p); - cutlass::arch::fence_view_async_tmem_load(); - - // Get rowwise max - float cur_max = -INFINITY; - CUTE_UNROLL - for (int i = 0; i < B_TOPK/2; ++i) { - if (!plan.is_token_valid[buf_idx][(idx_in_warpgroup/64)*(B_TOPK/2)+i]) p[i] = -INFINITY; - cur_max = max(cur_max, p[i]); - } - cur_max *= params.scale_softmax_log2; - - NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers - plan.rowwise_max_buf[idx_in_warpgroup] = cur_max; - NamedBarrier::arrive_and_wait(128, 0); - cur_max = max(cur_max, plan.rowwise_max_buf[idx_in_warpgroup ^ 64]); - - float new_max = max(mi, cur_max); - 
float scale_for_old = exp2f(mi - new_max); - float2 scale_for_old_float2 = {scale_for_old, scale_for_old}; - - // Get S - float2 scale_softmax_log2_float2 = {params.scale_softmax_log2, params.scale_softmax_log2}; - float2 neg_new_max_float2 = {-new_max, -new_max}; - bf16 s[B_TOPK/2]; - float2 cur_sum = {0.0f, 0.0f}; - CUTE_UNROLL - for (int i = 0; i < (B_TOPK/2)/2; ++i) { - float2 t = float2_fma(p_float2[i], scale_softmax_log2_float2, neg_new_max_float2); - t.x = exp2(t.x); - t.y = exp2(t.y); - *(__nv_bfloat162*)&s[i*2] = __float22bfloat162_rn(t); - cur_sum = float2_add(cur_sum, t); - } - - // Save S - // NOTE We don't need a barrier here, since the current QK^T has finished implies that the previous SV has finished - bf16* sS_base = plan.s.data() + (idx_in_warpgroup/64)*(B_H*B_TOPK/2) + (idx_in_warpgroup%64) * 8; - CUTE_UNROLL - for (int i = 0; i < (B_TOPK/2)/8; i += 1) { - store_128b(sS_base + i*8*B_H, *((bf16x8*)s + i)); - } - fence_view_async_shared(); - - // Rescale O - if (block_idx != start_block_idx) { - constexpr int B_SCALE_O = 64; - float2 o[B_SCALE_O/2]; - CUTE_UNROLL - for (int b = 0; b < (D_V/2)/B_SCALE_O; ++b) { - tmem_ld_32dp32bNx(tmem_addr::o + b*B_SCALE_O, o); - cutlass::arch::fence_view_async_tmem_load(); - CUTE_UNROLL - for (int i = 0; i < B_SCALE_O/2; ++i) - o[i] = float2_mul(o[i], scale_for_old_float2); - tmem_st_32dp32bNx(tmem_addr::o + b*B_SCALE_O, o); - cutlass::arch::fence_view_async_tmem_store(); - } - } - plan.bar_so_ready[buf_idx].arrive(); - - // Update mi and li - mi = new_max; - li = li * scale_for_old + cur_sum.x + cur_sum.y; - - bar_phase_k ^= 1<>((end_block_idx-1)%NUM_BUFS)&1^1); - tcgen05_after_thread_sync(); - - // Save O - float o_scale = li == 0.0f ? 
0.0f : 1.0f / li; - float2 o_scale_float2 = {o_scale, o_scale}; - if (is_no_split) { - constexpr int B_EPI = 32; - float2 o[B_EPI/2]; - __nv_bfloat162 o_bf16[B_EPI/2]; - Tensor sO = make_tensor(make_smem_ptr(plan.u.o_buf.data()), SmemLayoutOBuf{}); - bf16* sO_base = plan.u.o_buf.data() + ((idx_in_warpgroup/64)*128)*B_H + (idx_in_warpgroup%64)*8; - CUTE_UNROLL - for (int i = 0; i < (D_V/2) / B_EPI; ++i) { - // Load - tmem_ld_32dp32bNx(tmem_addr::o + i*B_EPI, o); - cutlass::arch::fence_view_async_tmem_load(); - // Scale & Convert - CUTE_UNROLL - for (int j = 0; j < B_EPI/2; ++j) { - o[j] = float2_mul(o[j], o_scale_float2); - o_bf16[j] = __float22bfloat162_rn(o[j]); - } - // Store - int col_base = (i*B_EPI>=D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4)); - CUTE_UNROLL - for (int j = 0; j < B_EPI / 8; ++j) - store_128b(sO_base + (col_base+j*8)*B_H, *reinterpret_cast(&o_bf16[j*4])); - } - fence_view_async_shared(); - NamedBarrier::arrive_and_wait(128, 0); - if (warp_idx == 4 && elect_one_sync()) { - Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx); - auto thr_tma = tma_params.tma_O.get_slice(_0{}); - Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, head_block_idx, _0{}); - cute::copy( - tma_params.tma_O, - thr_tma.partition_S(sO), - thr_tma.partition_D(my_tma_gO) - ); - cute::tma_store_arrive(); - } - } else { - constexpr int B_EPI = 64; - float2 o[B_EPI/2]; - Tensor sO = make_tensor(make_smem_ptr(plan.u.o_accum_buf.data()), SmemLayoutOAccumBuf{}); - CUTE_UNROLL - for (int i = 0; i < (D_V/2) / B_EPI; ++i) { - // Load - tmem_ld_32dp32bNx(tmem_addr::o + i*B_EPI, o); - cutlass::arch::fence_view_async_tmem_load(); - // Scale & Convert - CUTE_UNROLL - for (int j = 0; j < B_EPI/2; ++j) - o[j] = float2_mul(o[j], o_scale_float2); - // Store - int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? 
D_V/2 : 0) + (i*B_EPI%(D_V/4)); - CUTE_UNROLL - for (int j = 0; j < B_EPI / 4; ++j) - store_128b(&sO(idx_in_warpgroup%64, col_base + j*4), *reinterpret_cast(&o[j*2])); - } - fence_view_async_shared(); - NamedBarrier::arrive_and_wait(128, 0); - if (elect_one_sync()) { - CUTE_UNROLL - for (int local_row = 0; local_row < B_H/4; ++local_row) { - int smem_row = local_row*4 + (warp_idx-4); - if (smem_row < num_valid_heads) { - SM90_BULK_COPY_S2G::copy( - &sO(smem_row, _0{}), - (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx + smem_row)*D_V, - D_V*sizeof(float) - ); - } - } - cute::tma_store_arrive(); - } - } - - cute::tma_store_wait<0>(); - } - - if (warp_idx == 4) { - cute::TMEM::Allocator1Sm().free(0, 512); - } - } else { - cutlass::arch::warpgroup_reg_dealloc<96>(); - if (warp_idx == 8) { - // UTCMMA warp - - bool bar_phase_q = 0; - TiledMMA tiled_mma_qk = TiledMMA_QK{}; - TiledMMA tiled_mma_sv = TiledMMA_SV{}; - Tensor tP = partition_fragment_C(tiled_mma_qk, Shape, Int>{}); - Tensor tO = partition_fragment_C(tiled_mma_sv, Shape, Int>{}); - tO.data().get() = tmem_addr::o; - tP.data().get() = tmem_addr::p; - Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); - - #pragma unroll 1 - for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { - auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); - - if (elect_one_sync()) { - // Copy Q - Tensor gQ = flat_divide( - tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx), - Tile, Int>{} - )(_, _, head_block_idx, _0{}); - launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); - plan.bar_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16)); - } - - NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1); - - if (elect_one_sync()) { - // Wait for Q - plan.bar_q.wait(bar_phase_q); - bar_phase_q ^= 1; - tcgen05_after_thread_sync(); - - CUTE_NO_UNROLL - for (int block_idx = start_block_idx; block_idx < 
end_block_idx; block_idx++) { - int buf_idx = block_idx % NUM_BUFS; - - // Wait for K - plan.bar_k_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); - tcgen05_after_thread_sync(); - Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); - - // Issue P = Q @ K^T - utcmma_ss(tiled_mma_qk, sQ, sK, tP, true); - umma_arrive_noelect(plan.bar_qk_done[buf_idx]); - - // Wait for S - plan.bar_so_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); - tcgen05_after_thread_sync(); - Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutV{}); - - // Issue O += S @ V - utcmma_ss(tiled_mma_sv, sS, sV, tO, block_idx == start_block_idx); - umma_arrive_noelect(plan.bar_k_free[buf_idx]); - - bar_phase_k ^= 1< tma_params = { - shape_Q, tma_Q, - shape_O, tma_O - }; - auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel; - - constexpr size_t smem_size = sizeof(SharedMemoryPlan); - CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - const int num_m_blocks = cute::ceil_div(params.q_head_per_hk, B_H); - // NOTE Don't use PDL because of potential compiler bugs! 
- mla_kernel<<>>(params, tma_params); - CHECK_CUDA_KERNEL_LAUNCH(); -} - -} \ No newline at end of file diff --git a/csrc/sm100/decode/sparse_fp8/splitkv_mla.h b/csrc/sm100/decode/sparse_fp8/splitkv_mla.h deleted file mode 100644 index cc8c6da..0000000 --- a/csrc/sm100/decode/sparse_fp8/splitkv_mla.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -#include "params.h" - -namespace sm100 { - -void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams ¶ms, cudaStream_t stream); - -} - diff --git a/csrc/sm100/helpers.h b/csrc/sm100/helpers.h index 9195b33..a566625 100644 --- a/csrc/sm100/helpers.h +++ b/csrc/sm100/helpers.h @@ -1,97 +1,35 @@ #pragma once #include +#include +#include + #include "defines.h" namespace sm100 { using namespace cute; -using _72 = Int<72>; -using _576 = Int<576>; - -template< - typename TMA, - typename Tensor0, - typename Tensor1 -> CUTE_DEVICE -void launch_tma_copy( - const TMA &tma_copy, - Tensor0 src, - Tensor1 dst, - transac_bar_t &bar, - const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL -) { - auto thr_tma = tma_copy.get_slice(_0{}); - cute::copy( - tma_copy.with(reinterpret_cast(bar), 0, cache_hint), - thr_tma.partition_S(src), - thr_tma.partition_D(dst) - ); +int int4_max(int4 t) { + return max(max(t.x, t.y), max(t.z, t.w)); } -template< - typename TiledMMA, - typename TensorA, - typename TensorB, - typename TensorFragC -> CUTE_DEVICE -void utcmma_ss( - TiledMMA &tiled_mma, - TensorA sA, - TensorB sB, - TensorFragC tC_frag, - bool clear_accum -) { - tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; - ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter - auto sA_frag = thr_mma.partition_fragment_A(sA); - auto sB_frag = thr_mma.partition_fragment_B(sB); - static_assert(size<2>(sA_frag) == size<2>(sB_frag)); - static_assert(size<1>(sA_frag) == size<1>(tC_frag)); - static_assert(size<1>(sB_frag) == size<2>(tC_frag)); - CUTE_UNROLL - for (int k = 0; k < size<2>(sA_frag); ++k) { - cute::gemm( - tiled_mma, - sA_frag(_, _, k), - sB_frag(_, _, k), - tC_frag - ); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; - } +int int4_min(int4 t) { + return min(min(t.x, t.y), min(t.z, t.w)); } -template< - typename TiledMMA, - typename TensorA, - typename TensorB, - typename TensorFragC -> +// Convert 2x fp8_e4m3 to 2x bf16 with scaling CUTE_DEVICE -void utcmma_ts( - TiledMMA &tiled_mma, - TensorA tA_frag, - TensorB sB, - TensorFragC tC_frag, - bool clear_accum -) { - tiled_mma.accumulate_ = clear_accum ? 
UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; - ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter - auto sB_frag = thr_mma.partition_fragment_B(sB); - static_assert(size<2>(tA_frag) == size<2>(sB_frag)); - CUTE_UNROLL - for (int k = 0; k < size<2>(tA_frag); ++k) { - cute::gemm( - tiled_mma, - tA_frag(_, _, k), - sB_frag(_, _, k), - tC_frag - ); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; - } +nv_bfloat162 fp8x2_to_bf16x2_with_scale(__nv_fp8x2_e4m3 data, nv_bfloat16 scale) { + // TODO Use native conversion for CUDA >= 13.1 + float2 data_float2 = (float2)data; + nv_bfloat162 data_bf16x2 = __float22bfloat162_rn(data_float2); + return nv_bfloat162 { + data_bf16x2.x * scale, + data_bf16x2.y * scale + }; } } diff --git a/csrc/sm100/intrinsics.h b/csrc/sm100/intrinsics.h deleted file mode 100644 index c2402ee..0000000 --- a/csrc/sm100/intrinsics.h +++ /dev/null @@ -1,461 +0,0 @@ -#pragma once - -#include -#include - -#include "defines.h" - -namespace sm100 { - -using namespace cute; - -__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { - uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); - asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" - :: "r"(dst_addr), - "l"(src), - "n"(16)); -} - -CUTE_DEVICE -int64_t createpolicy_evict_last() { - int64_t res; - asm volatile( - "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t" - : "=l"(res) - : - ); - return res; -} - -template -CUTE_DEVICE -static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { - static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b."); - long2 data_long2 = *reinterpret_cast(&data); - uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); - uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); - asm volatile ( - "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, 
[%3]; \n" - : - : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) - ); -} - - -__device__ __forceinline__ void tcgen05_before_thread_sync() { - asm volatile("tcgen05.fence::before_thread_sync;"); -} - -__device__ __forceinline__ void tcgen05_after_thread_sync() { - asm volatile("tcgen05.fence::after_thread_sync;"); -} - -CUTE_DEVICE -void umma_arrive_multicast_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) { - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); - asm volatile( - "{\n\t" - "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" - "}" - : - :"r"(bar_intptr), "h"(cta_mask)); -} - -CUTE_DEVICE -void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) { - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); - asm volatile( - "{\n\t" - "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" - "}" - : - :"r"(bar_intptr), "h"(cta_mask)); -} - -CUTE_DEVICE -void umma_arrive_noelect(transac_bar_t &smem_ptr) { - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); - asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" - : - :"r"(bar_intptr)); -} - -CUTE_DEVICE -void umma_arrive_2x1SM_noelect(transac_bar_t &smem_ptr) { - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr); - asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];" - : - :"r"(bar_intptr)); -} - -CUTE_DEVICE -float2 float2_add(const float2 &a, const float2 &b) { - float2 res; - cute::add(res, a, b); - return res; -} - -CUTE_DEVICE -float2 float2_mul(const float2 &a, const float2 &b) { - float2 res; - cute::mul(res, a, b); - return res; -} - -CUTE_DEVICE -float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) { - // return a*b+c - float2 res; - cute::fma(res, a, b, c); - return res; -} - -CUTE_DEVICE -float2 float2_neg(const 
float2 &a) { - float2 t = {-1.0f, -1.0f}; - return float2_mul(a, t); -} - -template -CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) { - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); - if constexpr (USE_CTA0_MBAR) { - mbar_addr &= Sm100MmaPeerBitMask; - } - asm volatile( - "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" - : - : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), - "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), - "r"(mbar_addr), "l"(uint64_t(cache_hint)) - : "memory" - ); -} - -// 32 data path lanes, 32-bit pattern, repeated N times -template -CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); - uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); - - if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" - "{%0}," - "[%1];\n" - : "=r"(dst_ptr[0]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" - "{%0, %1}," - "[%2];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) - : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" - "{%0, %1, %2, %3}," - "[%4];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]) - : "r"(src_addr)); - } else if constexpr (N == 8) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "[%8];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) - : "r"(src_addr)); - } else if constexpr (N == 16) { - asm 
volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15}," - "[%16];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]) - : "r"(src_addr)); - } else if constexpr (N == 32) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31}," - "[%32];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) - : "r"(src_addr)); - } else if constexpr (N == 64) { - asm volatile( - "tcgen05.ld.sync.aligned.32x32b.x64.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " - "%57, %58, %59, %60, %61, %62, %63}," - "[%64];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - 
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]) - : "r"(src_addr)); - } else if constexpr (N == 128) { - asm volatile( - "tcgen05.ld.sync.aligned.32x32b.x128.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " - "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127}," - "[%128];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), 
"=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), - "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), - "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), - "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), - "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), - "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), - "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), - "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), - "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), - "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), - "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), - "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), - "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), - "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), - 
"=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), - "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), - "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), - "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), - "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), - "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), - "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), - "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) - : "r"(src_addr)); - } else { - asm volatile ("trap"); - } -} - -// 32 data path lanes, 32-bit pattern, repeated N times -template -CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); - uint32_t* src_ptr = reinterpret_cast(src_ptr_); - - if constexpr (N == 1) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32" - "[%1], {%0};\n" - : - : "r"(src_ptr[0]), - "r"(dst_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32" - "[%2], {%0, %1};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), - "r"(dst_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32" - "[%4], {%0, %1, %2, %3};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), - "r"(dst_addr)); - } else if constexpr (N == 8) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32" - "[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), - "r"(dst_addr)); - } else if constexpr (N == 16) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32" - "[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), 
"r"(src_ptr[8]), - "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), - "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), - "r"(src_ptr[15]), - "r"(dst_addr)); - } else if constexpr (N == 32) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32" - "[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), - "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), - "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), - "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), - "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), - "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), - "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), - "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), - "r"(src_ptr[30]), "r"(src_ptr[31]), - "r"(dst_addr)); - } else if constexpr (N == 64) { - asm volatile( - "tcgen05.st.sync.aligned.32x32b.x64.b32" - "[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " - "%57, %58, %59, %60, %61, %62, %63};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), - "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), - "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), - "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), - "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), - "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), - "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), - 
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), - "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), - "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), - "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), - "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), - "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), - "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), - "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), - "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), - "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), - "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), - "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), - "r"(src_ptr[63]), - "r"(dst_addr)); - } else if constexpr (N == 128) { - asm volatile( - "tcgen05.st.sync.aligned.32x32b.x128.b32" - "[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " - "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), - "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), - "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), - "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), - "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), - "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), - "r"(src_ptr[24]), 
"r"(src_ptr[25]), "r"(src_ptr[26]), - "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), - "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), - "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), - "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), - "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), - "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), - "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), - "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), - "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), - "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), - "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), - "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), - "r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]), - "r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]), - "r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]), - "r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]), - "r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]), - "r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]), - "r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]), - "r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]), - "r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]), - "r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]), - "r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]), - "r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]), - "r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]), - "r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]), - "r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]), - "r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]), - "r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]), - "r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]), - "r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]), - "r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]), - "r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]), - "r"(src_ptr[126]), "r"(src_ptr[127]), - "r"(dst_addr)); - } else { - asm volatile 
("trap"); - } -} - -static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字 -template -CUTE_DEVICE -T* get_peer_addr(const T* p) { - return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); -} - - -} diff --git a/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp index 9a25ff3..4034b1f 100644 --- a/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp +++ b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_convert.hpp @@ -34,7 +34,7 @@ #include "cutlass/cutlass.h" #include "cute/layout.hpp" -#include "utils.h" // for IS_SM100 +#include // for KERUTILS_ENABLE_SM100A namespace cutlass::fmha::kernel { @@ -139,7 +139,7 @@ struct FmhaKernelBwdConvert { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { -#if IS_SM100 +#if defined(KERUTILS_ENABLE_SM100A) if (params.ptr_src_dQ != nullptr) { copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); } diff --git a/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp index 07ae4f2..cb98b27 100644 --- a/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp +++ b/csrc/sm100/prefill/dense/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -34,7 +34,7 @@ #include "cutlass/cutlass.h" #include "cute/layout.hpp" -#include "utils.h" // for IS_SM100 +#include // for KERUTILS_ENABLE_SM100A namespace cutlass::fmha::kernel { @@ -105,7 +105,7 @@ struct FmhaKernelBwdSumOdO { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { -#if IS_SM100 +#if defined(KERUTILS_ENABLE_SM100A) auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); auto ptr_sum_OdO_bh = params.ptr_sum_OdO + 
blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); diff --git a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index c34713b..95cc33c 100644 --- a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -41,7 +41,7 @@ #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "utils.h" // for IS_SM100 +#include // for KERUTILS_ENABLE_SM100A #include "../collective/fmha_common.hpp" #include @@ -949,8 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TensorC const& coord, TensorShape const& tensor_shape) { - // TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version. - // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); auto copy_op = make_cotiled_copy( Copy_Atom, Element>{}, @@ -960,23 +959,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { auto thr_copy = copy_op.get_slice(_0{}); Tensor quantized_regs = quantize(regs); - auto tCg = thr_copy.partition_D(gmem); - auto tCr = thr_copy.partition_S(quantize(regs)); - auto tCc = thr_copy.partition_D(coord); - - - constexpr int R = decltype(tCr.layout())::rank; - auto tCg_v = group_modes<1, R>(tCg); - auto tCr_v = group_modes<1, R>(tCr); - auto tCc_v = group_modes<1, R>(tCc); - auto tCp_v = make_tensor(shape<1>(tCc_v)); - - for (int i = 0; i < size(tCp_v); ++i) { - tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); - } - - copy_if(copy_op, tCp_v, tCr_v, tCg_v); + Tensor tCr = thr_copy.partition_S(quantized_regs); + Tensor tCg = thr_copy.partition_D(gmem); + Tensor tPc = thr_copy.partition_D(preds); + 
copy_if(copy_op, tPc, tCr, tCg); } @@ -1500,7 +1487,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { -#if IS_SM100 +#if defined(KERUTILS_ENABLE_SM100A) int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); diff --git a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp index c25d638..7a3b944 100644 --- a/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -41,7 +41,7 @@ #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "utils.h" // for IS_SM100 +#include // for KERUTILS_ENABLE_SM100A #include "../collective/fmha_common.hpp" #include @@ -954,8 +954,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { TensorC const& coord, TensorShape const& tensor_shape) { - // TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version. 
- // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); auto copy_op = make_cotiled_copy( Copy_Atom, Element>{}, @@ -965,23 +964,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { auto thr_copy = copy_op.get_slice(_0{}); Tensor quantized_regs = quantize(regs); - auto tCg = thr_copy.partition_D(gmem); - auto tCr = thr_copy.partition_S(quantize(regs)); - auto tCc = thr_copy.partition_D(coord); - - - constexpr int R = decltype(tCr.layout())::rank; - auto tCg_v = group_modes<1, R>(tCg); - auto tCr_v = group_modes<1, R>(tCr); - auto tCc_v = group_modes<1, R>(tCc); - auto tCp_v = make_tensor(shape<1>(tCc_v)); - - for (int i = 0; i < size(tCp_v); ++i) { - tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); - } - - copy_if(copy_op, tCp_v, tCr_v, tCg_v); - + Tensor tCr = thr_copy.partition_S(quantized_regs); + Tensor tCg = thr_copy.partition_D(gmem); + Tensor tPc = thr_copy.partition_D(preds); + + copy_if(copy_op, tPc, tCr, tCg); } @@ -1494,7 +1481,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { CUTLASS_DEVICE void operator()(Params const& params, char* smem) { -#if IS_SM100 +#if defined(KERUTILS_ENABLE_SM100A) int warp_idx = cutlass::canonical_warp_idx_sync(); auto role = warp_idx_to_role(warp_idx); uint32_t lane_predicate = cute::elect_one_sync(); diff --git a/csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index ef75280..dc5ad45 100644 --- a/csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/csrc/sm100/prefill/dense/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -37,7 +37,7 @@ #include "cutlass/pipeline/pipeline.hpp" #include "cute/arch/tmem_allocator_sm100.hpp" -#include "utils.h" // for IS_SM100 +#include // for KERUTILS_ENABLE_SM100A #include 
"../kernel/fmha_options.hpp" #include "../kernel/fmha_tile_scheduler.hpp" #include "../kernel/fmha_causal_tile_scheduler.hpp" @@ -252,7 +252,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { } CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { -#if IS_SM100 +#if defined(KERUTILS_ENABLE_SM100A) TileScheduler tile_scheduler{params.tile_scheduler}; diff --git a/csrc/sm100/prefill/sparse/common_subroutine.h b/csrc/sm100/prefill/sparse/common_subroutine.h new file mode 100644 index 0000000..36ddab4 --- /dev/null +++ b/csrc/sm100/prefill/sparse/common_subroutine.h @@ -0,0 +1,208 @@ +#pragma once + +#include +#include + +namespace sm100 { + +/* +Load K/V indices from global memory, and generate validity mask +Each thread loads 8 indices +Should be called by lanes 0 ~ (BLOCK_TOPK/8) +*/ +CUTE_DEVICE +char load_indices_and_generate_mask( + int lane_idx, + int* gIndices, + int s_kv, + int abs_pos_start, + int topk_length +) { + int indices[8]; + KU_LDG_256( + gIndices + lane_idx*8, + indices, + ".nc", + "no_allocate", + "evict_normal", + "256B" + ); + auto is_valid = [&](int rel_pos_in_lane, int index) -> char { + int abs_pos = abs_pos_start + lane_idx*8 + rel_pos_in_lane; + return index >= 0 && index < s_kv && abs_pos < topk_length; + }; + char is_ks_valid_mask = \ + is_valid(7, indices[7]) << 7 | + is_valid(6, indices[6]) << 6 | + is_valid(5, indices[5]) << 5 | + is_valid(4, indices[4]) << 4 | + is_valid(3, indices[3]) << 3 | + is_valid(2, indices[2]) << 2 | + is_valid(1, indices[1]) << 1 | + is_valid(0, indices[0]) << 0; + return is_ks_valid_mask; +} + + +/* +Get P from Tensor Memory, reduce P within shared memory, perform masking, and store back if necessary + +Initially, since dual gemm is used, we have two P pieces in Tensor Memory, one occupying rows 0 ~ 63 while the other occupying rows 64 ~ 127. 
We'd like to have them reduced into one single P piece, stored in registers with layout: + + N N --- (topk) + +-------+-------+ + | | | +32 | Warp0 | Warp2 | + | | | + +-------+-------+ + | | | +32 | Warp1 | Warp3 | + | | | + +-------+-------+ +| +(head) + +where N = NUM_ELEMS_PER_THREAD +*/ +template< + int NUM_ELEMS_PER_THREAD, + int TMEM_COL_START, + int BARRIER_WARP02_SYNC_ID, + int BARRIER_WARP13_SYNC_ID, + bool STORE_BACK_P +> +CUTE_DEVICE +void retrieve_mask_and_reduce_p( + char* k_validness_base, + int local_warp_idx, + int lane_idx, + auto slot_bar_P_empty_arrival, + float p_exchange_buf[4][32*NUM_ELEMS_PER_THREAD], + float p[NUM_ELEMS_PER_THREAD] +) { + using namespace cute; + using cutlass::arch::NamedBarrier; + static_assert(BARRIER_WARP13_SYNC_ID == BARRIER_WARP02_SYNC_ID+1); + + float p_peer[NUM_ELEMS_PER_THREAD]; + if (local_warp_idx < 2) { + ku::tmem_ld_32dp32bNx(TMEM_COL_START, p); + ku::tmem_ld_32dp32bNx(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p_peer); + } else { + ku::tmem_ld_32dp32bNx(TMEM_COL_START, p_peer); + ku::tmem_ld_32dp32bNx(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p); + } + cutlass::arch::fence_view_async_tmem_load(); + ku::tcgen05_before_thread_sync(); + slot_bar_P_empty_arrival(); + + // Mask invalid tokens + // We put masking before reduction, since (-inf) + anything (except nan and +inf) is (-inf), which guarantees correctness, and this can overlap with smem load + static_assert(NUM_ELEMS_PER_THREAD == 32); + uint32_t is_k_valid = *(uint32_t*)(k_validness_base + (local_warp_idx>=2?NUM_ELEMS_PER_THREAD/8:0)); + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_PER_THREAD; i += 1) { + if (!(is_k_valid >> i & 1)) + p[i] = -CUDART_INF_F; + } + + // Reduce P within the cluster + { + // Store + // Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (row 0 ~ 31) part + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) { + ku::st_shared(&p_exchange_buf[local_warp_idx^2][i*32*4 + lane_idx*4], 
*(float4*)(p_peer + i*4)); + } + NamedBarrier::arrive_and_wait(64, BARRIER_WARP02_SYNC_ID + (local_warp_idx&1)); + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) { + float2 t[2]; + *(float4*)t = *(float4*)(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4]); + float2* cur_p = (float2*)(p + i*4); + cur_p[0] = ku::float2_add(cur_p[0], t[0]); + cur_p[1] = ku::float2_add(cur_p[1], t[1]); + } + } + + if constexpr (STORE_BACK_P) { + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) { + ku::st_shared(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4], *(float4*)(p+i*4)); + } + } +} + +/* +Rescale O in Tensor Memory. + +O should occupy 128 rows x (D_V/2) columns in Tensor Memory. +*/ +template< + int D_V, + int CHUNK_SIZE, + int TMEM_COL_START +> +CUTE_DEVICE +void rescale_O( + float scale_factor +) { + float2 scale_factor_float2 = {scale_factor, scale_factor}; + float2 o[CHUNK_SIZE/2]; + + CUTE_UNROLL + for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) { + // Load O + ku::tmem_ld_32dp32bNx(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o); + cutlass::arch::fence_view_async_tmem_load(); + + // Mult + for (int i = 0; i < CHUNK_SIZE/2; ++i) { + o[i] = ku::float2_mul(o[i], scale_factor_float2); + } + + // Store O + ku::tmem_st_32dp32bNx(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o); + cutlass::arch::fence_view_async_tmem_store(); + } +} + +template +CUTE_DEVICE +float get_max( + float p[NUM_ELEMS_PER_THREAD] +) { + float local_max = -CUDART_INF_F; + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_PER_THREAD; ++i) { + local_max = max(local_max, p[i]); + } + return local_max; +} + +/* +Calculate s := exp2f(p*scale - new_max) and its sum +*/ +template +CUTE_DEVICE +float get_s_from_p( + nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2], + float p[NUM_ELEMS_PER_THREAD], + float scale, + float new_max +) { + float2 cur_sum = float2 {0.0f, 0.0f}; + float2 neg_new_max_float2 = float2 {-new_max, -new_max}; + float2 scale_float2 = float2 {scale, scale}; + 
CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_PER_THREAD/2; i += 1) { + float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale_float2, neg_new_max_float2); + d.x = exp2f(d.x); + d.y = exp2f(d.y); + cur_sum = ku::float2_add(cur_sum, d); + s[i] = __float22bfloat162_rn(d); + } + return cur_sum.x + cur_sum.y; +} + +} diff --git a/csrc/sm100/prefill/sparse/fwd.h b/csrc/sm100/prefill/sparse/fwd.h deleted file mode 100644 index 6558e80..0000000 --- a/csrc/sm100/prefill/sparse/fwd.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include "params.h" - -namespace sm100 { - -void run_fwd_kernel(const SparsePrefillParams& params); - -} diff --git a/csrc/sm100/prefill/sparse/fwd/head128/config.h b/csrc/sm100/prefill/sparse/fwd/head128/config.h new file mode 100644 index 0000000..6c846bb --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd/head128/config.h @@ -0,0 +1,140 @@ +#pragma once + +#include +#include +#include + +#include "params.h" +#include "defines.h" + +namespace sm100::fwd::head128 { + +using namespace cute; + +template< + typename Shape_Q, typename TMA_Q, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + Shape_O shape_O; TMA_O tma_O; + CUtensorMap tensor_map_kv; +}; + +struct float2x2 { + float2 lo, hi; +}; + +template +struct KernelTemplate { + +static constexpr int D_Q = D_QK; +static constexpr int D_K = D_QK; +static constexpr int D_V = 512; +static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan + +static constexpr int B_H = 128; // For 2 CTAs +static constexpr int B_TOPK = 128; // For 2 CTAs +static constexpr int NUM_BUFS = 2; +static constexpr int NUM_THREADS = 256 + 128 + 128; // 128 scale & exp threads, 128x2 TMA threads, 32 UTCMMA threads + + +static constexpr int D_tQ = 384, NUM_tQ_TILES = D_tQ / 64; +static constexpr int D_sQ = D_QK-D_tQ, NUM_sQ_TILES = D_sQ / 64; +static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q); + 
+// Tensor memory columns +struct tmem_cols { + // 0 ~ 256: output + // 256 ~ 320: P + // 320 ~ 512: Q[D_QK-D_tQ:] + static constexpr int o = 0; + static constexpr int p = 256; + static constexpr int q = 512 - D_tQ/2; + static_assert(p+64 <= q); +}; + +template +using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutO = SmemLayoutOTiles<8>; + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutV = decltype(coalesce(tile_to_shape( + UMMA::Layout_MN_SW128_Atom{}, + Shape, Int>{}, + Step<_2, _1>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutSTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +struct SharedMemoryPlan { + union { + array_aligned>> q_full; + struct { + array_aligned>> sq; + array_aligned> v; + // NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q + static_assert(cosize_v> <= cosize_v> + cosize_v); + array_aligned>> k; + } s; + array_aligned> o; + } u; + array_aligned>> s; + float p[(B_H/2)*B_TOPK]; + char is_k_valid[NUM_BUFS][B_TOPK/8]; + transac_bar_t bar_prologue_q, bar_prologue_utccp; + transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free) + transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. 
Vi free) + transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS]; + transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready + transac_bar_t bar_p_free[NUM_BUFS]; + transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready + transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS]; + array_aligned tmem_start_addr; + float rowwise_max_buf[128], rowwise_li_buf[128]; +}; + +using TiledMMA_P_tQ = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_TS_NOELECT{} +)); + +using TiledMMA_P_sQ = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_SS_NOELECT{} +)); + +using TiledMMA_O = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_SS_NOELECT{}, + Layout>{}, + Tile, Layout, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512] +)); + +template +static __device__ void +sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams ¶ms, const TmaParams &tma_params); + +}; + +} diff --git a/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu b/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu new file mode 100644 index 0000000..5dcec83 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu @@ -0,0 +1,8 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm100::fwd::head128 { + +template void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu b/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu new file mode 100644 index 0000000..bfd01ce --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu @@ -0,0 +1,8 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm100::fwd::head128 { + +template void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/fwd.cu 
b/csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh similarity index 72% rename from csrc/sm100/prefill/sparse/fwd.cu rename to csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh index 963ac78..ec1192b 100644 --- a/csrc/sm100/prefill/sparse/fwd.cu +++ b/csrc/sm100/prefill/sparse/fwd/head128/phase1.cuh @@ -1,4 +1,5 @@ -#include "fwd.h" +#pragma once +#include "phase1.h" #include #include @@ -9,12 +10,11 @@ #include "params.h" #include "utils.h" -#include "sm100/ws_gemm.h" #include "sm100/helpers.h" -#include "sm100/intrinsics.h" -#include "sm100/tma_cta_group2_nosplit.h" -namespace sm100 { +#include "config.h" + +namespace sm100::fwd::head128 { using namespace cute; @@ -28,120 +28,6 @@ CUTE_DEVICE int32x8_t ldg_256_indices(void* src_ptr) { return val; } -template< - typename Shape_Q, typename TMA_Q, - typename Shape_O, typename TMA_O -> -struct TmaParams { - Shape_Q shape_Q; TMA_Q tma_Q; - Shape_O shape_O; TMA_O tma_O; - CUtensorMap tensor_map_kv; -}; - -struct float2x2 { - float2 lo, hi; -}; - -constexpr int D_Q = 576; -constexpr int D_K = 576; -constexpr int D_V = 512; -constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan - -constexpr int B_H = 128; // For 2 CTAs -constexpr int B_TOPK = 128; // For 2 CTAs -constexpr int NUM_BUFS = 2; -constexpr int NUM_THREADS = 256 + 128 + 128; // 128 TMA threads, 128 scale & exp threads, 32 UTCMMA threads - -constexpr int D_sQ = 256, NUM_sQ_TILES = D_sQ / 64; -constexpr int D_tQ = D_Q - D_sQ, NUM_tQ_TILES = D_tQ / 64; -static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q); - -// Tensor memory columns -namespace tmem_cols { - // 0 ~ 256: output - // 256 ~ 320: P - // 320 ~ 512: Q[192:576] - constexpr int o = 0; - constexpr int p = 256; - constexpr int q = 512 - D_tQ/2; - static_assert(p+64 <= q); -} - -template -using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( - UMMA::Layout_K_SW128_Atom{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, 
_2>{} -), Shape<_1, _1>{})); - -template -using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( - UMMA::Layout_K_SW128_Atom{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -), Shape<_1, _1>{})); - -using SmemLayoutO = SmemLayoutOTiles<8>; - -template -using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( - UMMA::Layout_K_SW128_Atom{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -), Shape<_1, _1>{})); - -using SmemLayoutV = decltype(coalesce(tile_to_shape( - UMMA::Layout_MN_SW128_Atom{}, - Shape, Int>{}, - Step<_2, _1>{} -), Shape<_1, _1>{})); - -template -using SmemLayoutSTiles = decltype(coalesce(tile_to_shape( - UMMA::Layout_K_INTER_Atom{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -), Shape<_1, _1>{})); - -struct SharedMemoryPlan { - union { - array_aligned>> q_full; - struct { - array_aligned>> sq; - array_aligned> v; - // NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q - array_aligned>> k; - } s; - array_aligned> o; - } u; - array_aligned>> s; - char is_k_valid[NUM_BUFS][B_TOPK/8]; - transac_bar_t bar_prologue_q, bar_prologue_utccp; - transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free) - transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. 
Vi free) - transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS]; - transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready - transac_bar_t bar_p_free[NUM_BUFS]; - transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready - transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS]; - array_aligned tmem_start_addr; - float rowwise_max_buf[128], rowwise_li_buf[128]; -}; - -using TiledMMA_P_tQ = decltype(make_tiled_mma( - SM100_MMA_F16BF16_2x1SM_TS_NOELECT{} -)); - -using TiledMMA_P_sQ = decltype(make_tiled_mma( - SM100_MMA_F16BF16_2x1SM_SS_NOELECT{} -)); - -using TiledMMA_O = decltype(make_tiled_mma( - SM100_MMA_F16BF16_2x1SM_SS_NOELECT{}, - Layout>{}, - Tile, Layout, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512] -)); - /* Pipeline Overview: @@ -176,15 +62,17 @@ V(n-1) scale(O) w.r.t P(n-1) O += S(n-1)V(n-1) */ +template template -__global__ void __launch_bounds__(NUM_THREADS, 1, 2) -sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) { -#if IS_SM100 +__device__ void +KernelTemplate::sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams &params, const TmaParams &tma_params) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) const int cta_idx = blockIdx.x % 2; const int s_q_idx = blockIdx.x / 2; const int warp_idx = cutlass::canonical_warp_idx_sync(); const int lane_idx = threadIdx.x % 32; - const int num_k_blocks = params.topk / B_TOPK; + const int topk_length = params.topk_length != nullptr ? 
__ldg(params.topk_length + s_q_idx) : params.topk; + const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1 const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); const int idx_in_warpgroup = threadIdx.x % 128; @@ -198,7 +86,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri // Define shared tensors extern __shared__ char wksp_buf[]; SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); - Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles<9>{}); + Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles{}); int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk] @@ -248,17 +136,17 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx), Tile>{} )(_, cta_idx, _); - launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST); + ku::launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST); } // Initialize TMEM - // We put this before cluster_arrive to make sure that the TMEM allocation is done before UTCCP cute::TMEM::Allocator2Sm().allocate(512, plan.tmem_start_addr.data()); TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); cute::TMEM::Allocator2Sm().release_allocation_lock(); - __syncwarp(); } + __syncthreads(); // Wait for TMEM allocation + if (warpgroup_idx == 0) { cutlass::arch::warpgroup_reg_alloc<144>(); // Scale & Exp warps @@ -276,18 +164,19 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2}; uint128_t* sS_base = (uint128_t*)plan.s.data() + idx_in_warpgroup%64 + 64*((idx_in_warpgroup/64)*8); + float* sP_base = plan.p + idx_in_warpgroup%64*4 + 
(idx_in_warpgroup/64)*((B_H/2)*(B_TOPK/2)); CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks; ++k) { // Wait for P plan.bar_qk_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1); - tcgen05_after_thread_sync(); + ku::tcgen05_after_thread_sync(); // Load P float2 p[(B_TOPK/2)/2]; - tmem_ld_32dp32bNx(tmem_cols::p, p); + ku::tmem_ld_32dp32bNx(tmem_cols::p, p); cutlass::arch::fence_view_async_tmem_load(); - tcgen05_before_thread_sync(); + ku::tcgen05_before_thread_sync(); plan.bar_p_free[k%NUM_BUFS].arrive(0u); // Mask @@ -330,6 +219,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri // - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...) // - should_scale_o is identical among threads 0~31+64~95; and is identical among threads 32~63+96~127 + // Calc scale factor, and scale li float new_max, scale_for_old; if (!should_scale_o) { @@ -348,10 +238,10 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri float2 neg_new_max = float2 {-new_max, -new_max}; CUTE_UNROLL for (int i = 0; i < (B_TOPK/2)/2; i += 1) { - float2 d = float2_fma(p[i], scale, neg_new_max); + float2 d = ku::float2_fma(p[i], scale, neg_new_max); d.x = exp2f(d.x); d.y = exp2f(d.y); - li += d.x + d.y; // NOTE Theorically we can have use FFMA2 here but actually this is faster... + li += d.x + d.y; // NOTE: Theoretically we could use FFMA2 here but actually this is faster... 
s[i] = __float22bfloat162_rn(d); } @@ -367,27 +257,27 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri // Scale O if (k > 0 && should_scale_o) { float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old}; - // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before - tcgen05_after_thread_sync(); + // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE: We have waited for last SV gemm before + ku::tcgen05_after_thread_sync(); static constexpr int CHUNK_SIZE = 32; float2 o[CHUNK_SIZE/2]; CUTE_UNROLL for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) { // Load O - tmem_ld_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o); + ku::tmem_ld_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o); cutlass::arch::fence_view_async_tmem_load(); // Mult for (int i = 0; i < CHUNK_SIZE/2; ++i) { - o[i] = float2_mul(o[i], scale_for_old_float2); + o[i] = ku::float2_mul(o[i], scale_for_old_float2); } // Store O - tmem_st_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o); + ku::tmem_st_32dp32bNx(tmem_cols::o + chunk_idx*CHUNK_SIZE, o); cutlass::arch::fence_view_async_tmem_store(); } - tcgen05_before_thread_sync(); + ku::tcgen05_before_thread_sync(); } fence_view_async_shared(); @@ -411,17 +301,19 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri // Store mi and li if (idx_in_warpgroup < 64) { int global_index = s_q_idx*params.h_q + cta_idx*(B_H/2) + idx_in_warpgroup; - float cur_lse = log2f(li) + mi; - params.max_logits[global_index] = real_mi; + float cur_lse = logf(li) + mi*CUDART_LN2_F; + cur_lse = cur_lse == -CUDART_INF_F ? 
+CUDART_INF_F : cur_lse; + params.max_logits[global_index] = real_mi*CUDART_LN2_F; params.lse[global_index] = cur_lse; } // Wait for the last GEMM plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1); - tcgen05_after_thread_sync(); + ku::tcgen05_after_thread_sync(); // Store O - float output_scale = __fdividef(1.0f, li); + float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + cta_idx*B_H/2 + (idx_in_warpgroup%64))*CUDART_L2E_F; + float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi)); Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{}); constexpr int B_EPI = 64; Tensor tma_gO = flat_divide( @@ -435,7 +327,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri auto thr_tma = tma_params.tma_O.get_slice(_0{}); float2 o[B_EPI/2]; - bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during tmem_ld + bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld if (!have_valid_indices) { // If there are no valid indices, we set o[i] to 0 and don't load from TMEM CUTE_UNROLL @@ -450,7 +342,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri for (int k = 0; k < (D_V/2)/B_EPI; ++k) { // Load O from tO if (have_valid_indices) { - tmem_ld_32dp32bNx(tmem_cols::o + k*B_EPI, o); + ku::tmem_ld_32dp32bNx(tmem_cols::o + k*B_EPI, o); cutlass::arch::fence_view_async_tmem_load(); } @@ -460,7 +352,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri __nv_bfloat162 o_bf16[4]; CUTE_UNROLL for (int j = 0; j < 4; ++j) { - float2 d = float2_mul(o[i*4+j], output_scale_float2); + float2 d = ku::float2_mul(o[i*4+j], output_scale_float2); o_bf16[j] = __float22bfloat162_rn(d); } int smem_row = idx_in_warpgroup % 64; @@ 
-503,22 +395,28 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks; ++k) { int4 indices[NUM_LOCAL_ROWS_PER_WARP]; + int max_indices = -1, min_indices = params.s_kv; CUTE_UNROLL - for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) + for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) { indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK + cta_idx*(B_TOPK/2)) + local_row*NUM_WARPS + warp_idx); + max_indices = max(max_indices, int4_max(indices[local_row])); + min_indices = min(min_indices, int4_min(indices[local_row])); + } + bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1; + bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS; - auto load_part_ki = [&](transac_bar_t* bar, int local_col_start, int local_col_end) { + auto load_part_ki = [&](transac_bar_t &bar, int local_col_start, int local_col_end) { CUTE_UNROLL for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) { CUTE_UNROLL for (int local_col = local_col_start; local_col < local_col_end; ++local_col) - tma_gather4( + ku::tma_gather4_cta_group_2( &(tma_params.tensor_map_kv), bar, sK_base + local_row*(4*NUM_WARPS)*64 + local_col*((B_TOPK/2)*64), local_col*64, indices[local_row], - TMA::CacheHintSm90::EVICT_LAST + (int64_t)TMA::CacheHintSm90::EVICT_LAST ); } }; @@ -527,12 +425,23 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri if (k > 0) { plan.bar_qk_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); } - load_part_ki(plan.bar_k_part0_ready+cur_buf, 0, D_sQ/64); + if (!should_skip_tma) { + load_part_ki(plan.bar_k_part0_ready[cur_buf], 0, D_sQ/64); + } else { + // NOTE: TMA has performance issues when all indices are the same (even if those indices are invalid), so we detect whether all indices in our block are invalid (by inspecting their MIN and MAX, for performance reasons), and skip the copy if 
all indices are invalid. + // NOTE: We can also skip the initial zero-fill procedure (which prevents NaN from appearing in K/V buf if the first TMA copy is skipped) by disabling skipping on the first NUM_BUFS TMAs. + // NOTE: We only do this for K to save some checking overhead, since after doing this for K, cases where topk indices are all invalid are faster than the other cases + plan.bar_k_part0_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_sQ*sizeof(bf16), 1u); + } if (k > 0) { plan.bar_qk_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); } - load_part_ki(plan.bar_k_part1_ready+cur_buf, D_sQ/64, D_K/64); + if (!should_skip_tma) { + load_part_ki(plan.bar_k_part1_ready[cur_buf], D_sQ/64, D_K/64); + } else { + plan.bar_k_part1_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_tQ*sizeof(bf16), 1u); + } } } } else if (warpgroup_idx == 2) { @@ -549,19 +458,19 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks; ++k) { - auto load_part_vi = [&](transac_bar_t* bar, int local_row_start, int local_row_end) { + auto load_part_vi = [&](transac_bar_t &bar, int local_row_start, int local_row_end) { CUTE_UNROLL for (int local_row = local_row_start; local_row < local_row_end; ++local_row) { int4 token_idxs = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx); CUTE_UNROLL for (int local_col = 0; local_col < (D_V/2)/64; ++local_col) - tma_gather4( + ku::tma_gather4_cta_group_2( &(tma_params.tensor_map_kv), bar, sV_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64), local_col*64 + (cta_idx?256:0), token_idxs, - TMA::CacheHintSm90::EVICT_LAST + (int64_t)TMA::CacheHintSm90::EVICT_LAST ); } }; @@ -570,12 +479,12 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri if (k > 0) { plan.bar_sv_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); } - load_part_vi(plan.bar_v_part0_ready+cur_buf, 0, 
(B_TOPK/2)/4/NUM_WARPS); + load_part_vi(plan.bar_v_part0_ready[cur_buf], 0, (B_TOPK/2)/4/NUM_WARPS); if (k > 0) { plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); } - load_part_vi(plan.bar_v_part1_ready+cur_buf, (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS); + load_part_vi(plan.bar_v_part1_ready[cur_buf], (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS); } } } else { @@ -595,7 +504,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri ); plan.bar_prologue_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16)); plan.bar_prologue_q.wait(0); - tcgen05_after_thread_sync(); + ku::tcgen05_after_thread_sync(); CUTE_UNROLL for (int tile_idx = 0; tile_idx < NUM_tQ_TILES; ++tile_idx) { // A tile is 64 rows * 64 cols (128B) @@ -608,7 +517,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri ); } } - umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2); + ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2); CUTE_NO_UNROLL for (int k = 0; k < num_k_blocks+1; ++k) { @@ -625,18 +534,18 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri if (k > 0) { plan.bar_p_free[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); } - tcgen05_after_thread_sync(); + ku::tcgen05_after_thread_sync(); - utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true); - umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2); + ku::utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true); + ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2); // Wait for K (part1) plan.bar_k_part1_ready[cur_buf].arrive_and_expect_tx(B_TOPK*(D_K-D_sQ)*sizeof(bf16)); plan.bar_k_part1_ready[cur_buf].wait((k/NUM_BUFS)&1); - tcgen05_after_thread_sync(); + ku::tcgen05_after_thread_sync(); - utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false); - umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2); + ku::utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false); + 
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2); } if (k > 0) { // O += S(i-1)V(i-1) @@ -653,17 +562,17 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri // Wait for V (part0), and issue O += sS @ sV plan.bar_v_part0_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16)); plan.bar_v_part0_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1); - tcgen05_after_thread_sync(); + ku::tcgen05_after_thread_sync(); - utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1); - umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2); + ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1); + ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2); // Wait for V (part1), and issue O += sS @ sV plan.bar_v_part1_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16)); plan.bar_v_part1_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1); - tcgen05_after_thread_sync(); - utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false); - umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2); + ku::tcgen05_after_thread_sync(); + ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false); + ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2); } } } else if (warp_idx == 13) { @@ -674,18 +583,19 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri for (int k = 0; k < num_k_blocks; ++k) { int cur_buf = k%NUM_BUFS; int32x8_t indices = ldg_256_indices(gIndices + k*B_TOPK + lane_idx*8); - auto is_valid = [&](int index) -> char { - return index >= 0 && index < params.s_kv; + auto is_valid = [&](int rel_pos_in_lane, int index) -> char { + int abs_pos = k*B_TOPK + lane_idx*8 + rel_pos_in_lane; + return index >= 0 && index < params.s_kv && abs_pos < topk_length; }; char is_ks_valid_mask = \ - is_valid(indices.a7) << 7 | - 
is_valid(indices.a6) << 6 | - is_valid(indices.a5) << 5 | - is_valid(indices.a4) << 4 | - is_valid(indices.a3) << 3 | - is_valid(indices.a2) << 2 | - is_valid(indices.a1) << 1 | - is_valid(indices.a0) << 0; + is_valid(7, indices.a7) << 7 | + is_valid(6, indices.a6) << 6 | + is_valid(5, indices.a5) << 5 | + is_valid(4, indices.a4) << 4 | + is_valid(3, indices.a3) << 3 | + is_valid(2, indices.a2) << 2 | + is_valid(1, indices.a1) << 1 | + is_valid(0, indices.a0) << 0; plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1); plan.is_k_valid[cur_buf][lane_idx] = is_ks_valid_mask; @@ -695,6 +605,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri } } + #else if (cute::thread0()) { CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100"); @@ -702,10 +613,21 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri #endif } -void run_fwd_kernel(const SparsePrefillParams& params) { - FLASH_ASSERT(params.h_kv == 1); - FLASH_ASSERT(params.topk % B_TOPK == 0); // To save some boundry checkings - FLASH_ASSERT(params.h_q == B_H); // To save some calculation +template +__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2) +sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) { + Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params); +} + +template +void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) { + static_assert(D_QK == 576 || D_QK == 512); + using Kernel = KernelTemplate; + + KU_ASSERT(params.h_kv == 1); + KU_ASSERT(params.topk % Kernel::B_TOPK == 0); // To save some boundary checks + KU_ASSERT(params.h_q == Kernel::B_H); // To save some calculation + KU_ASSERT(params.d_qk == D_QK); auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); auto tma_Q = cute::make_tma_copy( @@ -717,7 +639,7 @@ void run_fwd_kernel(const SparsePrefillParams& params) { make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) ) ), - 
SmemLayoutQTiles<9>{} + (typename Kernel::template SmemLayoutQTiles){} ); auto shape_O = make_shape(params.h_q, params.d_v, params.s_q); @@ -730,12 +652,12 @@ void run_fwd_kernel(const SparsePrefillParams& params) { make_stride(params.d_v, _1{}, params.h_q*params.d_v) ) ), - SmemLayoutOTiles<1>{} + (typename Kernel::template SmemLayoutOTiles<1>){} ); CUtensorMap tensor_map_kv; { - uint64_t size[2] = {D_K, (unsigned long)params.s_kv}; + uint64_t size[2] = {D_QK, (unsigned long)params.s_kv}; uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)}; uint32_t box_size[2] = {64, 1}; uint32_t elem_stride[2] = {1, 1}; @@ -753,7 +675,7 @@ void run_fwd_kernel(const SparsePrefillParams& params) { CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE ); - FLASH_ASSERT(res == CUresult::CUDA_SUCCESS); + KU_ASSERT(res == CUresult::CUDA_SUCCESS); } TmaParams< @@ -764,22 +686,21 @@ void run_fwd_kernel(const SparsePrefillParams& params) { shape_O, tma_O, tensor_map_kv }; - auto kernel = &sparse_attn_fwd_kernel; + auto kernel = &sparse_attn_fwd_kernel; - constexpr size_t smem_size = sizeof(SharedMemoryPlan); - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + constexpr size_t smem_size = sizeof(typename Kernel::SharedMemoryPlan); + KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); cutlass::ClusterLaunchParams launch_params = { dim3(2*params.s_q, 1, 1), - dim3(NUM_THREADS, 1, 1), + dim3(Kernel::NUM_THREADS, 1, 1), dim3(2, 1, 1), smem_size, params.stream }; - cutlass::launch_kernel_on_cluster( + KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster( launch_params, (void*)kernel, params, tma_params - ); - CHECK_CUDA_KERNEL_LAUNCH(); + )); } } diff --git a/csrc/sm100/prefill/sparse/fwd/head128/phase1.h b/csrc/sm100/prefill/sparse/fwd/head128/phase1.h new file mode 100644 index 0000000..b105780 --- /dev/null +++ 
b/csrc/sm100/prefill/sparse/fwd/head128/phase1.h @@ -0,0 +1,10 @@ +#pragma once + +#include "params.h" + +namespace sm100::fwd::head128 { + +template +void run_fwd_phase1_kernel(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/fwd/head64/config.h b/csrc/sm100/prefill/sparse/fwd/head64/config.h new file mode 100644 index 0000000..8d6eb77 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd/head64/config.h @@ -0,0 +1,157 @@ +#pragma once + +#include +#include + +#include "defines.h" + +namespace sm100::fwd::head64 { + +using namespace cute; + +template< + typename Shape_Q_NoPE, typename TMA_Q_NoPE, + typename Shape_Q_RoPE, typename TMA_Q_RoPE, + typename Shape_O, typename TMA_O +> +struct TmaParams { + Shape_Q_NoPE shape_Q_nope; TMA_Q_NoPE tma_Q_nope; + Shape_Q_RoPE shape_Q_rope; TMA_Q_RoPE tma_Q_rope; + Shape_O shape_O; TMA_O tma_O; + CUtensorMap tensor_map_kv_nope; +}; + +struct float2x2 { + float2 lo, hi; +}; + +constexpr int D_Q = 576; +constexpr int D_K = 576; +constexpr int D_V = 512; +constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan + +constexpr int B_H = 64; +constexpr int B_TOPK = 64; +constexpr int NUM_BUFS = 3; +constexpr int NUM_THREADS = 128 + 128 + 128; // 128 scale & exp threads, 128 TMA threads, 32 UTCMMA threads + + +// Tensor memory columns +namespace tmem_cols { + // 0 ~ 256: output + // 256 ~ 400: Q + // 400 ~ 464: P + constexpr int O = 0; + constexpr int Q = 256; + constexpr int Q_RoPE = 256 + 128; + constexpr int P = 400; +} + +using SmemLayoutQNoPE = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutQRoPE = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW64_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, 
Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutO = SmemLayoutOTiles<8>; + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutKNoPE = SmemLayoutKTiles<8>; +using SmemLayoutV = decltype(coalesce( + composition( + SmemLayoutKNoPE{}, + Layout, Int>, Stride, _1>>{} + ) +, Shape<_1, _1>{})); + +using SmemLayoutKRoPE = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW64_Atom{}, + Shape, Int<64>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutKNoPE_TiledMMA = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); // Re-view K-NoPE as B_TOPK*2 x D_V/2 for dual gemm + +using SmemLayoutKRoPE_TiledMMA = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_SW64_Atom{}, + Shape, Int<64/2>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +using SmemLayoutS = decltype(coalesce(tile_to_shape( + UMMA::Layout_K_INTER_Atom{}, + Shape, Int>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + + +struct SharedMemoryPlan { + union { + struct { + array_aligned> _k_rope_pad; + array_aligned> _k_pad[2]; // So that q_nope covers k[2] + array_aligned> q_nope; + } q_full; + struct { + array_aligned> k_rope; + array_aligned> k_nope[NUM_BUFS]; + } k; + array_aligned> o; + } u; + float p_exchange_buf[4][32 * (B_TOPK/2)]; + union { + bf16 s[B_H*B_TOPK]; + array_aligned> q_rope; + } s_q_rope; + char is_k_valid[NUM_BUFS][B_TOPK/8]; + transac_bar_t bar_prologue_q_nope, bar_prologue_q_rope, bar_prologue_utccp_nope, bar_prologue_utccp_rope; + transac_bar_t bar_qk_nope_done[NUM_BUFS], bar_qk_rope_done; // Pi = QKi^T (the nope part) done + transac_bar_t bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. 
O, Si and Vi are free) + transac_bar_t bar_kv_nope_ready[NUM_BUFS][2], bar_kv_rope_ready; + transac_bar_t bar_p_free; + transac_bar_t bar_so_ready; // S and O are ready + transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS]; + array_aligned tmem_start_addr; + float rowwise_max_buf[128], rowwise_li_buf[128]; +}; + +using TiledMMA_P = decltype(make_tiled_mma( + SM100_MMA_F16BF16_WS_TS_NOELECT{} // Here we use N = 128 = 2*B_TOPK since we're going to use implicit dual gemm: +)); + +using TiledMMA_O = decltype(make_tiled_mma( + SM100_MMA_F16BF16_WS_SS_NOELECT{} +)); + +enum NamedBarriers : int { + wg0_sync = 0, + wg0_warp02_sync = 1, + wg0_warp13_sync = 2, + pepi_sync = 3, +}; + + +} diff --git a/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu b/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu new file mode 100644 index 0000000..e1c87be --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu @@ -0,0 +1,8 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm100::fwd::head64 { + +template void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu b/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu new file mode 100644 index 0000000..1bd214e --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu @@ -0,0 +1,8 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm100::fwd::head64 { + +template void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh b/csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh new file mode 100644 index 0000000..b510b27 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd/head64/phase1.cuh @@ -0,0 +1,673 @@ +#pragma once +#include "phase1.h" + +#include +#include +#include +#include +#include + +#include + +#include "params.h" 
+#include "utils.h" +#include "sm100/helpers.h" +#include "sm100/prefill/sparse/common_subroutine.h" +#include "config.h" + +namespace sm100::fwd::head64 { + +using namespace cute; + +/* +Pipeline Overview: + +| Copy | MMA | Scale & Exp | + +KV0 +KV1 +KV2 + P0 = QK0^T + S0 = exp(P0) + scale(O) w.r.t P0 + P1 = QK1^T + S1 = exp(P1) + O += S0V0 +KV3 scale(O) w.r.t P1 + P2 = QK2^T + S2 = exp(P2) + O += S1V1 +KV4 scale(O) w.r.t P2 + P3 = QK3^T + S3 = exp(P3) + O += S2V2 +KV5 scale(O) w.r.t P3 + +... + + O += S(n-3)V(n-3) + scale(O) w.r.t P(n-2) + P(n-1) = QK(n-1)^T + S(n-1) = exp(P(n-1)) + O += S(n-2)V(n-2) + scale(O) w.r.t P(n-1) + O += S(n-1)V(n-1) +*/ + +using FwdMode = SparseAttnFwdMode; +/* sparse_attn_fwd_kernel: one CTA per query position (s_q_idx = blockIdx.x). Roles are dispatched on warpgroup_idx / warp_idx below: warpgroup 0 (threads 0-127) runs the scale & exp stage and the epilogue (writes max_logits, lse and O); warpgroup 1 is the KV (NoPE) producer using ku::tma_gather4; in the remaining warps, warp 8 copies Q from shared to tensor memory and issues the MMAs, warp 9 builds the per-block K validity mask, and warps 10-11 load K (RoPE) with cp.async when HAVE_ROPE is set. Only compiled for sm100 (see the __CUDA_ARCH__ guard); other targets hit CUTE_INVALID_CONTROL_PATH. */ +template +__global__ void __launch_bounds__(NUM_THREADS, 1, 1) +sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) + // Grid shape: [s_q, 1, 1] + + const int s_q_idx = blockIdx.x; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + const int lane_idx = threadIdx.x % 32; + const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const int idx_in_warpgroup = threadIdx.x % 128; + const int topk_length = params.topk_length != nullptr ? 
__ldg(params.topk_length + s_q_idx) : params.topk; + const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1 + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + + int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk] + + // Allocate tmem tensors + TiledMMA tiled_mma_P = TiledMMA_P{}; + TiledMMA tiled_mma_O = TiledMMA_O{}; + // NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm) + Tensor tP = partition_fragment_C(tiled_mma_P, Shape, _128>{}); + Tensor tQ_nope_part0 = tiled_mma_P.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P, Shape, Int<(D_V/2)/2>>{}) + ); + Tensor tQ_nope_part1 = tiled_mma_P.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P, Shape, Int<(D_V/2)/2>>{}) + ); + Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P, Shape, Int<64/2>>{}) + ); + Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{}); + tP.data().get() = tmem_cols::P; + tQ_nope_part0.data().get() = tmem_cols::Q; + tQ_nope_part1.data().get() = tmem_cols::Q + 64; + tQ_rope.data().get() = tmem_cols::Q_RoPE; + tO.data().get() = tmem_cols::O; + + if (warp_idx == 0) { + if (elect_one_sync()) { + // Copy Q + if constexpr (HAVE_ROPE) { + cute::prefetch_tma_descriptor(tma_params.tma_Q_rope.get_tma_descriptor()); + } + cute::prefetch_tma_descriptor(tma_params.tma_Q_nope.get_tma_descriptor()); + + plan.bar_prologue_q_nope.init(1); + plan.bar_prologue_q_rope.init(1); + fence_barrier_init(); + + if constexpr (HAVE_ROPE) { + Tensor gQ_rope = tma_params.tma_Q_rope.get_tma_tensor(tma_params.shape_Q_rope)(_, _, s_q_idx); + Tensor sQ_rope = make_tensor(make_smem_ptr(plan.s_q_rope.q_rope.data()), SmemLayoutQRoPE{}); + ku::launch_tma_copy(tma_params.tma_Q_rope, gQ_rope, sQ_rope, plan.bar_prologue_q_rope, 
TMA::CacheHintSm90::EVICT_FIRST); + } + + Tensor gQ_nope = tma_params.tma_Q_nope.get_tma_tensor(tma_params.shape_Q_nope)(_, _, s_q_idx); + Tensor sQ_nope = make_tensor(make_smem_ptr(plan.u.q_full.q_nope.data()), SmemLayoutQNoPE{}); + ku::launch_tma_copy(tma_params.tma_Q_nope, gQ_nope, sQ_nope, plan.bar_prologue_q_nope, TMA::CacheHintSm90::EVICT_FIRST); + + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv_nope)); + + // Initialize other barriers + plan.bar_prologue_utccp_rope.init(1); + plan.bar_prologue_utccp_nope.init(1); + CUTE_UNROLL + for (int i = 0; i < NUM_BUFS; ++i) { + plan.bar_qk_nope_done[i].init(1); + plan.bar_sv_done[i].init(1); + plan.bar_kv_nope_ready[i][0].init(1); + plan.bar_kv_nope_ready[i][1].init(1); + plan.bar_k_valid_ready[i].init(B_TOPK/8); + plan.bar_k_valid_free[i].init(128); + } + plan.bar_p_free.init(128); + plan.bar_so_ready.init(128); + plan.bar_qk_rope_done.init(1); + plan.bar_kv_rope_ready.init(64); + fence_barrier_init(); + } + + // Initialize TMEM + cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data()); + TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0); + cute::TMEM::Allocator1Sm().release_allocation_lock(); + } + + __syncthreads(); + + if (warpgroup_idx == 0) { + // Scale & Exp warps + + // The following three numbers are + // - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V) + // - li: sumexp, i.e. li := sum(exp(Pi*scale - mi)) + // - real_mi: real max logits, i.e. real_mi := max(Pi*scale) + // where Pi is the i-th row of P, P := QK^T + // mi and real_mi are always consistent within the two threads that + // control one row (i.e. thread 0+64, 1+65, 2+66, ...) 
after every update + float mi = MAX_INIT_VAL; + float li = 0.0f; + float real_mi = -CUDART_INF_F; + + bf16* sS_base = plan.s_q_rope.s + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2); + static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2; + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + // Wait for P + NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1)); + plan.bar_qk_nope_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1); + plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1); // Put the barrier wait here for more code reordering space + ku::tcgen05_after_thread_sync(); + + // Load P + float p[NUM_ELEMS_PER_THREAD]; + retrieve_mask_and_reduce_p< + NUM_ELEMS_PER_THREAD, + tmem_cols::P, + NamedBarriers::wg0_warp02_sync, + NamedBarriers::wg0_warp13_sync, + false + >( + plan.is_k_valid[k%NUM_BUFS], + warp_idx, lane_idx, + [&]() {plan.bar_p_free.arrive();}, + plan.p_exchange_buf, + p + ); + plan.bar_k_valid_free[k%NUM_BUFS].arrive(); + + // Get rowwise max of Pi + float cur_pi_max = get_max(p); + cur_pi_max *= params.sm_scale_div_log2; + + plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max; + NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); + cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]); + real_mi = max(real_mi, cur_pi_max); + bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f); + // By this point: + // - cur_pi_max, real_mi, and mi are identical within each row (i.e. thread 0+64, 1+65, ...) + // - should_scale_o is identical within each warp, and is identical among threads that control the same row (i.e. 
among threads 0~31+64~95; and is identical among threads 32~63+96~127) + + + // Calc scale factor, and scale li + float new_max, scale_for_old; + if (!should_scale_o) { + // Don't scale O + scale_for_old = 1.0f; + new_max = mi; + } else { + new_max = max(cur_pi_max, mi); + scale_for_old = exp2f(mi - new_max); + } + mi = new_max; // mi is still identical within each row + + // Calculate S + nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2]; + float cur_sum = get_s_from_p(s, p, params.sm_scale_div_log2, new_max); + li = fma(li, scale_for_old, cur_sum); + + // Wait for last SV gemm, write S + if (k > 0) { + plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); + } + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; i += 1) { + *(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4); + } + + // Scale O + if (k > 0 && should_scale_o) { + // plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before + ku::tcgen05_after_thread_sync(); + rescale_O(scale_for_old); + ku::tcgen05_before_thread_sync(); + } + + fence_view_async_shared(); + plan.bar_so_ready.arrive(); + } + + // Epilogue + + if (real_mi == -CUDART_INF_F) { + // real_mi == -CUDART_INF_F <=> No valid TopK indices + // We set li to 0 to fit the definition that li := exp(x[i] - mi) + li = 0.0f; + mi = -CUDART_INF_F; + } + + // Exchange li + plan.rowwise_li_buf[idx_in_warpgroup] = li; + NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); + li += plan.rowwise_li_buf[idx_in_warpgroup^64]; + + // Store mi and li + if (idx_in_warpgroup < 64) { + int global_index = s_q_idx*params.h_q + idx_in_warpgroup; + float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li)); + cur_lse = cur_lse == -CUDART_INF_F ? 
+CUDART_INF_F : cur_lse; + params.max_logits[global_index] = real_mi*CUDART_LN2_F; + params.lse[global_index] = cur_lse; + } + + // Wait for the last GEMM + plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1); + ku::tcgen05_after_thread_sync(); + + // Fetch dO if necessary + + // Store O + float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + (idx_in_warpgroup%64))*CUDART_L2E_F; + float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi)); + Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{}); + constexpr int B_EPI = 64; + Tensor tma_gO = flat_divide( + tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx), + Shape, Int>{} + )(_, _, _0{}, _); + Tensor sO_divided = flat_divide( + sO, + Shape, Int>{} + )(_, _, _0{}, _); + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + + float2 o[B_EPI/2]; + bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld + if (!have_valid_indices) { + // If there are no valid indices, we set o[i] to 0 and don't load from TMEM + CUTE_UNROLL + for (int i = 0; i < B_EPI/2; ++i) + o[i].x = o[i].y = 0.0f; + output_scale = 1.0f; + } + + float2 output_scale_float2 = make_float2(output_scale, output_scale); + + bf16* sO_addrs[8]; + CUTE_UNROLL + for (int i = 0; i < B_EPI/8; ++i) { + sO_addrs[i] = &sO(idx_in_warpgroup%64, i*8); + } + + CUTE_UNROLL + for (int c = 0; c < 2; ++c) { + // Each tile: 64 x 256 + CUTE_UNROLL + for (int k = 0; k < (D_V/4)/B_EPI; ++k) { + // Load O from tO + if (have_valid_indices) { + ku::tmem_ld_32dp32bNx(tmem_cols::O + c*128 + k*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + } + + // Convert and store + CUTE_UNROLL + for (int i = 0; i < B_EPI/8; ++i) { + nv_bfloat162 o_bf16[4]; + CUTE_UNROLL + for (int j = 0; j < 4; ++j) { + o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2); + o_bf16[j] = 
__float22bfloat162_rn(o[i*4+j]); + } + *(uint128_t*)(sO_addrs[i] + (c*(D_V/2) + (idx_in_warpgroup/64)*(D_V/4) + k*B_EPI)*64) = *(uint128_t*)(o_bf16); + } + + // Sync + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); + + if (warp_idx == 0 && elect_one_sync()) { + int epi_chunk_idx = c*(D_V/2/B_EPI) + k; + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)), + thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx)) + ); + } + if (warp_idx == 1 && elect_one_sync()) { + int epi_chunk_idx = c*(D_V/2/B_EPI) + (D_V/B_EPI/4) + k; + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)), + thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx)) + ); + } + } + } + + + if (warp_idx == 0) { + cute::TMEM::Allocator1Sm().free(0, 512); + } + } else if (warpgroup_idx == 1) { + // Producer warp for KV + int warp_idx = cutlass::canonical_warp_idx_sync() - 4; + constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/4)/NUM_WARPS; + if (elect_one_sync()) { + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + int4 indices[NUM_LOCAL_ROWS_PER_WARP]; + int max_indices = -1, min_indices = params.s_kv; + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) { + indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx); + max_indices = max(max_indices, int4_max(indices[local_row])); + min_indices = min(min_indices, int4_min(indices[local_row])); + } + bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1; + bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS; + + if (k == 2) { + plan.bar_prologue_utccp_nope.wait(0); // Since q_nope coincides with k[2] + } + + // Copy NoPE + int cur_buf = k%NUM_BUFS; + plan.bar_sv_done[cur_buf].wait((k/NUM_BUFS)&1^1); + bf16* sK_nope_base = plan.u.k.k_nope[cur_buf].data() + warp_idx*4*64; + + auto load_kv_nope_part = [&](int part_idx) { + CUTE_UNROLL + 
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) { + CUTE_UNROLL + for (int local_col = part_idx*(D_V/2/64); local_col < (part_idx+1)*(D_V/2/64); ++local_col) { + ku::tma_gather4( + &(tma_params.tensor_map_kv_nope), + plan.bar_kv_nope_ready[cur_buf][part_idx], + sK_nope_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64), + local_col*64, + indices[local_row], + (int64_t)TMA::CacheHintSm90::EVICT_LAST + ); + } + } + }; + + if (!should_skip_tma) { + load_kv_nope_part(0); + load_kv_nope_part(1); + } else { + // NOTE See head128/phase1.cuh for this TMA skipping technique + CUTE_UNROLL + for (int part_idx = 0; part_idx < 2; ++part_idx) + plan.bar_kv_nope_ready[cur_buf][part_idx].complete_transaction(NUM_LOCAL_ROWS_PER_WARP*4*D_V/2*sizeof(bf16)); + } + } + } + } else { + // MMA warp + if (warp_idx == 8 && elect_one_sync()) { + // S -> T copy for Q + UMMA::SmemDescriptor sQ_nope_desc = UMMA::make_umma_desc( + make_tensor( + make_smem_ptr(plan.u.q_full.q_nope.data()), + tile_to_shape( + UMMA::Layout_K_SW128_Atom{}, + Shape, Int<64>>{} // We use this shape for dual gemm (TODO Link) + ) + ) + ); + UMMA::SmemDescriptor sQ_rope_desc = UMMA::make_umma_desc( + make_tensor( + make_smem_ptr(plan.s_q_rope.q_rope.data()), + tile_to_shape( + UMMA::Layout_K_SW64_Atom{}, + Shape, Int<32>>{} + ) + ) + ); + + if constexpr (HAVE_ROPE) { + // Copy the RoPE tile: 128 rows * 32 cols (64B) (in UTCCP's view), or 64 rows * 64 cols (in our view) + plan.bar_prologue_q_rope.arrive_and_expect_tx(B_H*(D_Q-D_V)*sizeof(bf16)); + plan.bar_prologue_q_rope.wait(0); + ku::tcgen05_after_thread_sync(); + CUTE_UNROLL + for (int subtile_idx = 0; subtile_idx < 2; ++subtile_idx) { + // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view) + SM100_UTCCP_128dp256bit_1cta::copy( + sQ_rope_desc + (subtile_idx*32) / 16, + tmem_cols::Q_RoPE + subtile_idx*8 + ); + } + ku::umma_arrive_noelect(plan.bar_prologue_utccp_rope); + } + + 
plan.bar_prologue_q_nope.arrive_and_expect_tx(B_H*D_V*sizeof(bf16)); + plan.bar_prologue_q_nope.wait(0); + ku::tcgen05_after_thread_sync(); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < D_V/64/2; ++tile_idx) { + // A tile is 128 rows * 64 cols (128B) (in UTCCP's view), or 64 rows * 128 cols (in our view) + CUTE_UNROLL + for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) { + // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view) + SM100_UTCCP_128dp256bit_1cta::copy( + sQ_nope_desc + (tile_idx*(B_H*128*2) + subtile_idx*32) / 16, // Remember that 4 LSBs are not included + tmem_cols::Q + tile_idx*32 + subtile_idx*8 + ); + } + } + ku::umma_arrive_noelect(plan.bar_prologue_utccp_nope); + + if constexpr (HAVE_ROPE) { + plan.bar_prologue_utccp_rope.wait(0); + } + + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks+1; ++k) { + if (k < num_k_blocks) { + // Pi = QKi^T + int cur_buf = k%NUM_BUFS; + Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutKNoPE_TiledMMA{}); + Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE_TiledMMA{}); + + plan.bar_p_free.wait(k&1^1); + ku::tcgen05_after_thread_sync(); + + // Wait for K (RoPE) + // P = Q(rope) @ K(rope)^T + if constexpr (HAVE_ROPE) { + plan.bar_kv_rope_ready.wait(k&1); + ku::tcgen05_after_thread_sync(); + ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true); + ku::umma_arrive_noelect(plan.bar_qk_rope_done); + } + + // Wait for K (NoPE) + if (k == 0) { + plan.bar_prologue_utccp_nope.wait(0); + } + Tensor sK_nope_divided = flat_divide(sK_nope, Tile, Int>{})(_, _, _0{}, _); + CUTE_UNROLL + for (int kv_nope_part_idx = 0; kv_nope_part_idx < 2; ++kv_nope_part_idx) { + plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].arrive_and_expect_tx(B_TOPK*D_V/2*sizeof(bf16)); + plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].wait((k/NUM_BUFS)&1); + ku::tcgen05_after_thread_sync(); + + // P += Q(nope) @ K(nope)^T + 
bool clear_accum = (!HAVE_ROPE) && kv_nope_part_idx == 0; + ku::utcmma_ts(tiled_mma_P, kv_nope_part_idx ? tQ_nope_part1 : tQ_nope_part0, sK_nope_divided(_, _, kv_nope_part_idx), tP, clear_accum); + } + ku::umma_arrive_noelect(plan.bar_qk_nope_done[cur_buf]); + } + if (k > 0) { + // O += S(i-1)V(i-1) + int cur_buf = (k-1)%NUM_BUFS; + + Tensor sS = make_tensor(make_smem_ptr(plan.s_q_rope.s), SmemLayoutS{}); + Tensor sV = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutV{}); + + // Wait for S(i-1) and O to be scaled + plan.bar_so_ready.wait((k-1)&1); + ku::tcgen05_after_thread_sync(); + + // O += sS @ sV + ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == 1); + ku::umma_arrive_noelect(plan.bar_sv_done[cur_buf]); + } + } + } else if (warp_idx == 9) { + // KV valid loading warp + if (lane_idx < B_TOPK/8) { + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + char k_validness_mask = load_indices_and_generate_mask( + lane_idx, + gIndices + k*B_TOPK, + params.s_kv, + k*B_TOPK, + topk_length + ); + + int cur_buf = k%NUM_BUFS; + plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1); + plan.is_k_valid[cur_buf][lane_idx] = k_validness_mask; + plan.bar_k_valid_ready[cur_buf].arrive(); + } + } + } else if (warp_idx == 10 || warp_idx == 11) { + if constexpr (HAVE_ROPE) { + int thread_idx = threadIdx.x - 10*32; + constexpr int GROUP_SIZE = 8, NUM_GROUPS = 64/GROUP_SIZE, ROWS_PER_THREAD = B_TOPK/NUM_GROUPS; + int group_idx = thread_idx / GROUP_SIZE, idx_in_group = thread_idx % GROUP_SIZE; + Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE{}); + bf16* sK_rope_base = &sK_rope(group_idx, idx_in_group*8); + CUTE_NO_UNROLL + for (int k = 0; k < num_k_blocks; ++k) { + int indices[ROWS_PER_THREAD]; + CUTE_UNROLL + for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row) + indices[local_row] = __ldg(gIndices + k*B_TOPK + group_idx + local_row*NUM_GROUPS); + plan.bar_qk_rope_done.wait(k&1^1); + CUTE_UNROLL + for (int 
local_row = 0; local_row < ROWS_PER_THREAD; ++local_row) { + int index = indices[local_row]; + ku::cp_async_cacheglobal( + params.kv + (int64_t)index*params.stride_kv_s_kv + 512 + idx_in_group*8, + sK_rope_base + local_row*NUM_GROUPS*32, + index >= 0 && index < params.s_kv + ); // NOTE Using cp.async instead of TMA is faster here + // NOTE Here we only consider the range of `index` instead of also checking against topk_length, as it's noted that under this scenario (i.e. there exists a valid index among indices[topk_length: ] that points to a token that has NaN inside) [NOTE(review): this sentence appears to be truncated in the source; its conclusion is missing] + } + cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)&(plan.bar_kv_rope_ready)); + } + } + } + } + + +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100"); + } +#endif +} +/* Host-side launcher: checks the parameter preconditions asserted below (h_kv == 1, topk % B_TOPK == 0, h_q == B_H, d_qk == D_QK), builds the TMA descriptors for Q (NoPE and RoPE parts), O and the KV NoPE tensor map, then launches sparse_attn_fwd_kernel. */ +template +void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) { + KU_ASSERT(params.h_kv == 1); + KU_ASSERT(params.topk % B_TOPK == 0); // To save some boundary checks + KU_ASSERT(params.h_q == B_H); // To save some calculation + KU_ASSERT(params.d_qk == D_QK); + static_assert(D_QK == 576 || D_QK == 512); + + auto shape_Q_nope = make_shape(params.h_q, D_V, params.s_q); + auto tma_Q_nope = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q), + make_layout( + shape_Q_nope, + make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) + ) + ), + SmemLayoutQNoPE{} + ); + + auto shape_Q_rope = make_shape(params.h_q, D_Q-D_V, params.s_q); + auto tma_Q_rope = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q + D_V), + make_layout( + shape_Q_rope, + 
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) + ) + ), + SmemLayoutQRoPE{} + ); + + auto shape_O = make_shape(params.h_q, params.d_v, params.s_q); + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((bf16*)params.out), + make_layout( + shape_O, + make_stride(params.d_v, _1{}, params.h_q*params.d_v) + ) + ), + 
SmemLayoutOTiles<1>{} + ); + +/* 2-D tiled tensor map over params.kv: extents {D_V, s_kv}, row pitch stride_kv_s_kv*sizeof(bf16) bytes, box {64, 1}, 128B swizzle. It is handed to the kernel as tma_params.tensor_map_kv_nope. */ + CUtensorMap tensor_map_kv_nope; + { + uint64_t size[2] = {D_V, (unsigned long)params.s_kv}; + uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)}; + uint32_t box_size[2] = {64, 1}; + uint32_t elem_stride[2] = {1, 1}; + CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tensor_map_kv_nope, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + 2, + params.kv, + size, + stride, + box_size, + elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + KU_ASSERT(res == CUresult::CUDA_SUCCESS); + } + + TmaParams< + decltype(shape_Q_nope), decltype(tma_Q_nope), + decltype(shape_Q_rope), decltype(tma_Q_rope), + decltype(shape_O), decltype(tma_O) + > tma_params = { + shape_Q_nope, tma_Q_nope, + shape_Q_rope, tma_Q_rope, + shape_O, tma_O, + tensor_map_kv_nope + }; + auto kernel = &sparse_attn_fwd_kernel; +/* Dynamic shared memory is sized to the whole SharedMemoryPlan; cudaFuncSetAttribute raises the kernel's dynamic shared-memory limit to that size before the launch. */ + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); +/* NOTE(review): the launch configuration between the chevrons (and template argument lists elsewhere in this text) appears to have been stripped from this copy - confirm against the original file. */ + kernel<<>>(params, tma_params); + KU_CHECK_KERNEL_LAUNCH(); +} + +} diff --git a/csrc/sm100/prefill/sparse/fwd/head64/phase1.h b/csrc/sm100/prefill/sparse/fwd/head64/phase1.h new file mode 100644 index 0000000..2962389 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd/head64/phase1.h @@ -0,0 +1,10 @@ +#pragma once + +#include "params.h" + +namespace sm100::fwd::head64 { + +template +void run_fwd_phase1_kernel(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/config.h b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/config.h new file mode 100644 index 0000000..e488007 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/config.h @@ -0,0 +1,140 @@ +#pragma once +#include 
"phase1.h" + +#include +#include +#include +#include + +#include "defines.h" +#include "params.h" + +namespace sm100::fwd_for_small_topk::head128 { + +using namespace cute; +/* KernelTemplate: compile-time configuration for the small-topk head128 sparse attention kernel (dimensions, buffer counts, tensor-memory column map, shared-memory plan, tiled MMA types, named-barrier ids), plus the declarations of the device kernel body (sparse_attn_fwd_kernel_devfunc) and the host entry point (run). IS_DECODE / IS_PREFILL select between TmaParamsForDecode and TmaParamsForPrefill. */ +template +struct KernelTemplate { + +using ArgT = SparseFwdArgT; +static constexpr bool IS_DECODE = is_decode_v; +static constexpr bool IS_PREFILL = !IS_DECODE; +using fp8_e4m3 = cutlass::float_e4m3_t; +using fp8_e8m0 = __nv_fp8_e8m0; + +struct TmaParamsForPrefill { + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_o; +}; + +struct TmaParamsForDecode { + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_o; + CUtensorMap tensor_map_o_accum; + CUtensorMap tensor_map_kv_nope; + CUtensorMap tensor_map_kv_rope; + CUtensorMap tensor_map_extra_kv_nope; // Only available if extra_kv is enabled + CUtensorMap tensor_map_extra_kv_rope; +}; + +using TmaParams = std::conditional_t< + IS_DECODE, + TmaParamsForDecode, + TmaParamsForPrefill +>; + +static_assert(D_QK == 512); + +static constexpr int D_Q = D_QK; +static constexpr int D_K = D_QK; +static constexpr int D_V = 512; +static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan + +static constexpr int H_Q = 128; // For 2 CTAs +static constexpr int B_TOPK = 64; // For 2 CTAs +static constexpr int NUM_THREADS = 128*4; +static constexpr int NUM_WORKER_THREADS = IS_PREFILL ? (128 + 4 + (B_TOPK/8) + 1 + 128)*2 + 1 : (128 + 128 + 1 + 32 + 2 + 128)*2; + +// For non-decode mode, we have 4 (half-)KV buffers +// For decode mode, we have 3 (half-)KV buffers with two raw KV buffers +static constexpr int NUM_K_BUFS = IS_DECODE ? 3 : 4; +static constexpr int NUM_RAW_K_BUFS = IS_DECODE ? 2 : 0; +static constexpr int NUM_INDEX_BUFS = IS_DECODE ? 
4 : 4; + +static constexpr int D_NOPE = 448; +static constexpr int D_ROPE = 64; +static constexpr int TMA_K_STRIDE_FOR_DECODING = D_NOPE + 2*D_ROPE; +static constexpr int NUM_SCALES_EACH_TOKEN = 8; // 7 scales + 1 padding + +static constexpr int B_EPI = 64; // Epilogue block size for normal case (i.e. prefill or non-splitkv decoding) +static constexpr int B_EPI_SPLITKV = 32; // Epilogue block size for splitkv decoding +static constexpr int NUM_EPI_SPLITKV_BUFS = 4; // The number of epilogue buffers for splitkv decoding +static_assert((H_Q/2)*D_Q*sizeof(bf16) >= NUM_EPI_SPLITKV_BUFS*(H_Q/2)*(B_EPI_SPLITKV*2)*sizeof(float)); + +// Tensor memory columns +struct tmem_cols { + // 0 ~ 256: Output accumulator + // 256 ~ 384: Q + // 384 ~ 448: P + static constexpr int O = 0; + static constexpr int Q = 256; + static constexpr int P = 384; +}; + +struct SharedMemoryPlan { + array_aligned Q; // Will be output for epilogue + array_aligned K[NUM_K_BUFS]; + array_aligned K_raw[NUM_RAW_K_BUFS]; + array_aligned S; + float P_exchange[4][(H_Q/2/2)*(B_TOPK/2)]; + float rowwise_max_buf[128], rowwise_li_buf[128]; + + CUTE_ALIGNAS(16) char is_k_valid[NUM_INDEX_BUFS][B_TOPK/8]; + CUTE_ALIGNAS(16) int tma_coord[NUM_INDEX_BUFS][B_TOPK]; + CUTE_ALIGNAS(16) fp8_e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN/2]; + + transac_bar_t bar_sQ_full, bar_tQ_empty, bar_tQ_full; + transac_bar_t bar_tOut_full, bar_tOut_empty; + transac_bar_t bar_KV_full[NUM_K_BUFS], bar_KV_empty[NUM_K_BUFS]; + transac_bar_t bar_P_empty; + transac_bar_t bar_QK_done, bar_SV_done; + transac_bar_t bar_S_O_full; + transac_bar_t bar_li_full, bar_li_empty; + + // The following barriers are prefill-only + transac_bar_t bar_clc_full, bar_clc_empty; + + // The following barriers are decode-only + transac_bar_t bar_raw_KV_full[NUM_RAW_K_BUFS], bar_raw_KV_empty[NUM_RAW_K_BUFS]; + transac_bar_t bar_valid_coord_scales_full[NUM_INDEX_BUFS], bar_valid_coord_scales_empty[NUM_INDEX_BUFS]; + + ku::CLCResponseObj 
clc_response_obj; + array_aligned tmem_start_addr; +}; + +using TiledMMA_P = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_TS_NOELECT{} +)); // *2 for dual gemm + +using TiledMMA_O = decltype(make_tiled_mma( + SM100_MMA_F16BF16_2x1SM_SS_NOELECT{}, + Layout>{}, + Tile, Layout, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 take V[:, 0:256] and CTA1 take V[:, 256:512] +)); + +struct barrier_ids { + static constexpr int WG0_SYNC = 0; + static constexpr int WG2_SYNC = 1; + static constexpr int WG2_WARP02_SYNC = 2; + static constexpr int WG2_WARP13_SYNC = 3; +}; + +static __device__ void +sparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params); + +static void run(const ArgT& params); + +}; + +} diff --git a/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu new file mode 100644 index 0000000..a8b4895 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu @@ -0,0 +1,8 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm100::fwd_for_small_topk::head128 { + +template void run_fwd_for_small_topk_phase1_kernel(const SparseAttnDecodeParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu new file mode 100644 index 0000000..2f17fed --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu @@ -0,0 +1,8 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm100::fwd_for_small_topk::head128 { + +template void run_fwd_for_small_topk_phase1_kernel(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh 
b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh new file mode 100644 index 0000000..388abc8 --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.cuh @@ -0,0 +1,1107 @@ +#pragma once +#include "phase1.h" + +#include +#include +#include +#include +#include + +#include "params.h" +#include "utils.h" +#include "sm100/prefill/sparse/common_subroutine.h" +#include "sm100/helpers.h" + +#include "config.h" + +namespace sm100::fwd_for_small_topk::head128 { + +using namespace cute; +using FwdMode = SparseAttnFwdMode; + +template +__device__ void +KernelTemplate::sparse_attn_fwd_kernel_devfunc(const ArgT ¶ms, const TmaParams &tma_params) { +#ifdef KERUTILS_ENABLE_SM100A + // Grid shape: [2*s_q, 1, 1] for prefilling, [2*s_q, num_sm_parts, 1] for decoding + // Cluster shape: [2, 1, 1] + const int warp_idx = cutlass::canonical_warp_idx_sync(); + const int lane_idx = threadIdx.x % 32; + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int idx_in_warpgroup = threadIdx.x % 128; + const int cta_idx = block_id_in_cluster().x; + + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &smem = *reinterpret_cast(wksp_buf); + + if (warp_idx == 0 && elect_one_sync()) { + cute::prefetch_tma_descriptor(&tma_params.tensor_map_q); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_o); + if constexpr (IS_DECODE) { + cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope); + } else { + cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv); + } + } else if (warp_idx == 1 && elect_one_sync()) { + smem.bar_sQ_full.init(1); + smem.bar_tQ_empty.init(1); + smem.bar_tQ_full.init(1); + smem.bar_tOut_full.init(1); + smem.bar_tOut_empty.init(256); + smem.bar_P_empty.init(256); + smem.bar_QK_done.init(1); + smem.bar_SV_done.init(1); + smem.bar_S_O_full.init(256); + smem.bar_li_full.init(H_Q/2); + smem.bar_li_empty.init(128); + if constexpr (FWD_MODE != 
FwdMode::DecodeWithSplitKV) { + smem.bar_clc_full.init(1); + smem.bar_clc_empty.init(NUM_WORKER_THREADS); + } + fence_barrier_init(); + } else if (warp_idx == 2) { + cute::TMEM::Allocator2Sm().allocate(512, smem.tmem_start_addr.data()); + KU_TRAP_ONLY_DEVICE_ASSERT(smem.tmem_start_addr.data()[0] == 0); + cute::TMEM::Allocator2Sm().release_allocation_lock(); + } else if (warp_idx == 3 && elect_one_sync()) { + CUTE_UNROLL + for (int i = 0; i < NUM_K_BUFS; ++i) { + smem.bar_KV_full[i].init(IS_PREFILL ? 1 : (128/32)*2+1); + smem.bar_KV_empty[i].init(1); + } + CUTE_UNROLL + for (int i = 0; i < NUM_INDEX_BUFS; ++i) { + smem.bar_valid_coord_scales_full[i].init(IS_PREFILL ? B_TOPK/8 : 32); + smem.bar_valid_coord_scales_empty[i].init(IS_PREFILL ? 128 : (128 + (cta_idx==1) + 2 + 128)); + } + if constexpr (IS_DECODE) { + CUTE_UNROLL + for (int i = 0; i < NUM_RAW_K_BUFS; ++i) { + smem.bar_raw_KV_full[i].init(1); + smem.bar_raw_KV_empty[i].init(128); + } + } + fence_barrier_init(); + } + + ku::barrier_cluster_arrive_relaxed(); + ku::barrier_cluster_wait_acquire(); + + struct OuterloopArgs { + bool outer_loop_phase; + int batch_idx, s_q_idx; + int start_block_idx, end_block_idx; + int topk_length; + + int extra_topk_length, num_orig_kv_blocks; // extra-KV related + bool is_no_split; int n_split_idx; // splitkv related + }; + + auto run_outer_loop = [&](auto loop_body) -> bool { + int outer_loop_phase = false; + if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) { + int s_q_idx = blockIdx.x / 2; + DecodingSchedMeta sched_meta; + KU_LDG_256( + params.tile_scheduler_metadata_ptr + blockIdx.y, + &sched_meta, + ".nc", + "no_allocate", + "evict_normal", + "256B" + ); + if (sched_meta.begin_req_idx >= params.b) { + return 0; + } + + #pragma unroll 1 + for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { + int topk_length = params.topk_length ? 
__ldg(params.topk_length + batch_idx) : params.topk; + int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK); + int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk; + int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0 + int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; + int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK; + bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false); + int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx); + + // start_block_idx = 0; + // end_block_idx = total_topk_padded / B_TOPK; + // is_split = false; + // n_split_idx = 0; + + OuterloopArgs args = { + (bool)outer_loop_phase, + batch_idx, s_q_idx, + start_block_idx, end_block_idx, + topk_length, + + extra_topk_length, orig_topk_padded / B_TOPK, + !is_split, n_split_idx + }; + + loop_body(args); + outer_loop_phase ^= 1; + } + } else { + // Prefill mode. Use CLC to allocate different s_q (for decoding, different batches + s_q) to different workers + ku::CLCResult next_job = {true, (int)blockIdx.x, IS_PREFILL ? 0 : (int)blockIdx.y, 0}; + CUTE_NO_UNROLL + while (next_job.is_valid) { + int s_q_idx = next_job.x / 2; + int batch_idx = IS_PREFILL ? 0 : next_job.y; + int topk_length = params.topk_length != nullptr ? 
__ldg(params.topk_length + (IS_PREFILL?s_q_idx:batch_idx)) : params.topk; + + if constexpr (IS_PREFILL) { + int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1 + OuterloopArgs args = { + (bool)outer_loop_phase, + 0, s_q_idx, + 0, num_k_blocks, + topk_length + }; + loop_body(args); + } else { + int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK); + int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk; + int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0 + + OuterloopArgs args = { + (bool)outer_loop_phase, + batch_idx, s_q_idx, + 0, total_topk_padded / B_TOPK, + topk_length, + + extra_topk_length, orig_topk_padded / B_TOPK, + false, 0 + }; + loop_body(args); + } + + smem.bar_clc_full.wait(outer_loop_phase); + next_job = ku::get_clc_query_response(smem.clc_response_obj); + smem.bar_clc_empty.arrive(0u); + + outer_loop_phase ^= 1; + } + } + return outer_loop_phase; + }; + + if (warpgroup_idx == 0) { + // Q fetching and O writing back warpgroup + cutlass::arch::warpgroup_reg_alloc<176>(); + + bf16* sO_addrs[B_EPI/8]; + CUTE_UNROLL + for (int i = 0; i < B_EPI/8; ++i) { + Tensor sO = make_tensor(make_smem_ptr(smem.Q.data()), ku::make_umma_canonical_k_major_layout()); + sO_addrs[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*(D_V/2) + i*8); + } + + float* sO_accum_addrs[B_EPI_SPLITKV/4]; + if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) { + // If split-KV is enabled, we need to store back O in float32 + // We view Q buffer (with shape 64 x 512, bf16) as 4 buffers with shape (H_Q/2) x (B_EPI_SPLITKV*2), float32 + Tensor sO_accum = make_tensor(make_smem_ptr((float*)smem.Q.data()), ku::make_umma_canonical_k_major_layout()); + CUTE_UNROLL + for (int i = 0; i < B_EPI_SPLITKV/4; ++i) { + sO_accum_addrs[i] = &sO_accum(idx_in_warpgroup%64, i*4) + (idx_in_warpgroup >= 64 ? 
(H_Q/2)*B_EPI_SPLITKV : 0); + } + } + + auto perform_o_copy_out = [&](const OuterloopArgs &args, bool is_last_o) { + // outer_loop_phase is the loop phase corresponding to s_q_idx + + // Get li (output_scale actually) + smem.bar_li_full.wait(args.outer_loop_phase); + float output_scale = smem.rowwise_li_buf[idx_in_warpgroup%64]; + float2 output_scale_float2 = float2 {output_scale, output_scale}; + smem.bar_li_empty.arrive(); + + // Retrieve and store O, and calculate delta := sum(O*dO, dim=-1) if FWD_MODE is Recompute + smem.bar_tOut_full.wait(args.outer_loop_phase); + if (is_last_o && elect_one_sync()) { + cudaTriggerProgrammaticLaunchCompletion(); + } + + if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) { + CUTE_UNROLL + for (int k = 0; k < (D_V/2)/B_EPI; ++k) { + float2 o[B_EPI/2]; + ku::tmem_ld_32dp32bNx(tmem_cols::O + k*B_EPI, o); + cutlass::arch::fence_view_async_tmem_load(); + if (k == (D_V/2)/B_EPI-1) { + smem.bar_tOut_empty.arrive(0u); + } + CUTE_UNROLL + for (int i = 0; i < B_EPI/8; ++i) { + nv_bfloat162 o_bf16[4]; + CUTE_UNROLL + for (int j = 0; j < 4; ++j) { + o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2); + o_bf16[j] = __float22bfloat162_rn(o[i*4+j]); + } + bf16* o_do_addr = sO_addrs[i] + k*B_EPI*(H_Q/2); + if (k == 0 && i == 0) { + smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o); // Wait for sQ's availability + } + ku::st_shared(o_do_addr, *(__int128_t*)o_bf16); + } + } + + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); + if (warp_idx == 0 && elect_one_sync()) { + SM90_TMA_STORE_5D::copy( + &tma_params.tensor_map_o, + smem.Q.data(), + 0, cta_idx*(H_Q/2), 0, args.s_q_idx, IS_DECODE ? 
args.batch_idx : 0 + ); + cute::tma_store_arrive(); + } + } else { + CUTE_UNROLL + for (int k = 0; k < (D_V/2)/B_EPI_SPLITKV; ++k) { + int cur_buf_idx = k % NUM_EPI_SPLITKV_BUFS; + if (k == 0) { + cute::tma_store_wait<0>(); + } else { + cute::tma_store_wait(); + } + NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); + + float o[B_EPI_SPLITKV]; + ku::tmem_ld_32dp32bNx(tmem_cols::O + k*B_EPI_SPLITKV, o); + cutlass::arch::fence_view_async_tmem_load(); + if (k == (D_V/2)/B_EPI_SPLITKV-1) { + smem.bar_tOut_empty.arrive(0u); + } + CUTE_UNROLL + for (int i = 0; i < B_EPI_SPLITKV/4; ++i) { + CUTE_UNROLL + for (int j = 0; j < 4; j += 2) { + *(float2*)(o + i*4 + j) = ku::float2_mul(float2 {o[i*4+j], o[i*4+j+1]}, output_scale_float2); + } + if (k == 0 && i == 0) { + smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o); // Wait for sQ's availability + } + ku::st_shared( + sO_accum_addrs[i] + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2), + *(__int128_t*)(o + i*4) + ); + } + + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); + if constexpr (IS_DECODE) { // Otherwise nvcc complains about `tma_params` doesn't have `tensor_map_o_accum` + float* cur_buf_base = (float*)smem.Q.data() + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2); + if (warp_idx == 0 && elect_one_sync()) { + SM90_TMA_STORE_5D::copy( + &tma_params.tensor_map_o_accum, + cur_buf_base, + 0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32), args.s_q_idx, args.n_split_idx + ); + cute::tma_store_arrive(); + } else if (warp_idx == 1 && elect_one_sync()) { + SM90_TMA_STORE_5D::copy( + &tma_params.tensor_map_o_accum, + cur_buf_base + (H_Q/2)*B_EPI_SPLITKV, + 0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32) + (D_V/2)/32, args.s_q_idx, args.n_split_idx + ); + cute::tma_store_arrive(); + } + } + } + } + }; + + OuterloopArgs last_args; + last_args.batch_idx = -1; + + bool final_outer_loop_phase = \ + run_outer_loop([&](const OuterloopArgs &args) { + // Copy Q for this round + if constexpr (FWD_MODE == 
FwdMode::DecodeWithSplitKV) { + cute::tma_store_wait<0>(); + NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); // Since we use two warps to issue TMA during FwdMode::DecodeWithSplitKV + } + if (warp_idx == 0 && elect_one_sync()) { + // Wait for sQ to become empty, and issue G -> S copy for Q + if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) { + cute::tma_store_wait<0>(); // This thread must be the same one as o copy out thread (since `elect_one_sync()` always returns the same thread for the same `mask`, according to PTX document) + } + int stride_q_b_div_stride_q_s_q = 0; + if constexpr (IS_DECODE) { + stride_q_b_div_stride_q_s_q = params.stride_q_b / params.stride_q_s_q; + } + SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy( + &tma_params.tensor_map_q, + (uint64_t*)&smem.bar_sQ_full, + (uint64_t)TMA::CacheHintSm90::EVICT_FIRST, + smem.Q.data(), + 0, cta_idx*(H_Q/2), 0, 0, (IS_DECODE ? args.batch_idx*stride_q_b_div_stride_q_s_q : 0) + args.s_q_idx + ); + + // Wait for sQ to be ready, and issue S -> T copy for Q + if (cta_idx == 0) { + smem.bar_sQ_full.arrive_and_expect_tx(H_Q*D_Q*sizeof(bf16)); + smem.bar_sQ_full.wait(args.outer_loop_phase); + + smem.bar_tQ_empty.wait(args.outer_loop_phase^1); + ku::tcgen05_after_thread_sync(); + UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc( + make_tensor( + make_smem_ptr(smem.Q.data()), + ku::make_umma_canonical_k_major_layout<(H_Q/2)*2, 64, 128>() + ) + ); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < D_Q/64/2; ++tile_idx) { + // A tile is 128 rows * 64 cols in UTCCP's view, or 64 rows * 128 cols in our view + CUTE_UNROLL + for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) { + // A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view) + // NOTE Using `sQ_desc+((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4)` leads to IMA, doesn't know why + UMMA::SmemDescriptor cur_sQ_desc = sQ_desc; + cur_sQ_desc.lo += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4); + // 
uint64_t cur_sQ_desc = sQ_desc; + // cur_sQ_desc += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4); + SM100_UTCCP_128dp256bit_2cta::copy( + cur_sQ_desc, + tmem_cols::Q + tile_idx*32 + subtile_idx*8 + ); + } + } + ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tQ_full, 1|2); + } + } + + if (last_args.batch_idx != -1) { + perform_o_copy_out(last_args, false); + } else { + smem.bar_tQ_full.wait(args.outer_loop_phase); // To prevent double arrive + } + last_args = args; + }); + if (last_args.batch_idx != -1) { + cute::tma_store_wait<0>(); + NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); + perform_o_copy_out(last_args, true); + } + + if (warp_idx == 0) { + cute::TMEM::Allocator2Sm().free(0, 512); + } + } else if (warpgroup_idx == 1) { + // KV fetching threads for prefill, dequant threads for decoding + cutlass::arch::warpgroup_reg_dealloc<80>(); + RingBufferState rs; + + if constexpr (!IS_DECODE) { + const int warp_idx = cutlass::canonical_warp_idx(); // Using `warp_idx` without `__shfl_sync` is faster + if (elect_one_sync()) { + // KV fetching threads + run_outer_loop([&](const OuterloopArgs &args) { + int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q; + int64_t cache_hint = ku::create_simple_cache_policy(); + + static constexpr int NUM_ROWS_PER_THREAD = B_TOPK / 4; + + CUTE_NO_UNROLL + for (int k = args.start_block_idx; k < args.end_block_idx; ++k) { + auto [k_buf_idx, k_bar_phase] = rs.get(); + + int cur_indices[NUM_ROWS_PER_THREAD]; + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/8; local_row += 1) { + int row = local_row*(4*8) + (warp_idx-4)*8; + KU_LDG_256( + gIndices + k*B_TOPK + row, + cur_indices + local_row*8, + ".nc", + "no_allocate", + "evict_first", + "256B" + ); + } + smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1); + + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/4; local_row += 1) { + int row = (warp_idx-4)*8 + (local_row/2)*(4*8) + (local_row%2)*4; + int4 indices = 
*(int4*)(cur_indices+local_row*4); + static_assert(D_K == 512); + CUTE_UNROLL + for (int local_col = 0; local_col < (D_K/64)/2; ++local_col) { + ku::tma_gather4_cta_group_2( + &tma_params.tensor_map_kv, + smem.bar_KV_full[k_buf_idx], + smem.K[k_buf_idx].data() + row*64 + local_col*64*B_TOPK, + local_col*64 + cta_idx*(D_K/2), + indices, + cache_hint + ); + } + } + rs.update(); + } + }); + } + + } else { + // 8 threads per token + struct IsCTA0 {}; + struct IsCTA1 {}; + + auto launch_dequant_wg = [&](auto cta_id_t) { + static constexpr bool IS_CTA1 = std::is_same::value; + constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = (IS_CTA1 ? 256-64 : 256) / (GROUP_SIZE*8); + int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE; + Tensor nope0 = make_tensor(make_smem_ptr(smem.K[0].data()), ku::make_umma_canonical_k_major_layout()); + bf16* nope0_base = &nope0(group_idx, idx_in_group*8); + fp8_e4m3* raw_nope0_base = smem.K_raw[0].data() + group_idx*(D_K/2) + idx_in_group*8; + run_outer_loop([&](const OuterloopArgs &args) { + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { + auto [k_buf_idx, k_bar_phase] = rs.get(); + auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get(); + auto [index_buf_idx, index_bar_phase] = rs.get(); + fp8_e4m3* raw_nope_base = raw_nope0_base + raw_k_buf_idx * (B_TOPK*(D_K/2)); + auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t { + return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*(D_K/2) + local_col_idx*(GROUP_SIZE*8)); + }; + bf16* nope_base = nope0_base + k_buf_idx * (B_TOPK*(D_K/2)); + uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(nope_base); + auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) { + asm volatile ("st.weak.shared::cta.b128 [%0], %1;\n" + : + : "r"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + 
local_col_idx*B_TOPK*64)), "q"(data) // 2 for sizeof(bf16) + ); // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS + }; + + smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase); + smem.bar_raw_KV_full[raw_k_buf_idx].wait(raw_k_bar_phase); + + CUTE_UNROLL + for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) { + int row_idx = local_row_idx*NUM_GROUPS + group_idx; + bf16 scales[4]; + fp8_e8m0 scales_e8m0[4]; + *(uint32_t*)scales_e8m0 = *(uint32_t*)(smem.scales[index_buf_idx][row_idx]); + *(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0)); + *(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2)); + + uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0); + CUTE_UNROLL + for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) { + ku::nve4m3x2 data_fp8[4]; + ku::nvbf16x2 data_bf16[4]; + *(uint64_t*)data_fp8 = cur_data_fp8x8; + if (local_col_idx+1 < COLS_PER_GROUP) + cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1); + bf16 scale = scales[local_col_idx]; + CUTE_UNROLL + for (int i = 0; i < 4; ++i) { + data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale)); + } + if (local_row_idx == 0 && local_col_idx == 0) { + smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1); + } + st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16); + } + } + + fence_view_async_shared(); // NOTE Should we use shared::cluster here? 
+ __syncwarp(); + smem.bar_valid_coord_scales_empty[index_buf_idx].arrive(); + smem.bar_raw_KV_empty[raw_k_buf_idx].arrive(); + if (elect_one_sync()) { + smem.bar_KV_full[k_buf_idx].arrive(0u); + } + rs.update(); + } + }); + }; + if (cta_idx == 0) { + launch_dequant_wg(IsCTA0{}); + } else { + launch_dequant_wg(IsCTA1{}); + } + } + } else if (warpgroup_idx == 2) { + cutlass::arch::warpgroup_reg_dealloc<80>(); + + RingBufferState rs; + if (warp_idx == 8 && cta_idx == 0 && elect_one_sync()) { + // UMMA thread + TiledMMA tiled_mma_P = TiledMMA_P{}; + TiledMMA tiled_mma_O = TiledMMA_O{}; + Tensor tP = partition_fragment_C(tiled_mma_P, Shape, Int>{}); + Tensor tO = partition_fragment_C(tiled_mma_O, Shape, Int>{}); + Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A( + partition_shape_A(tiled_mma_P, Shape, Int>{}) + ); + tP.data().get() = tmem_cols::P; + tO.data().get() = tmem_cols::O; + tQ.data().get() = tmem_cols::Q; + + run_outer_loop([&](const OuterloopArgs &args) { + smem.bar_tQ_full.wait(args.outer_loop_phase); + + // Issue P = Q K^T + auto issue_P = [&](int k, int rs_offset) { + auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get(); + auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>(); + smem.bar_P_empty.wait(bar_phase^1); + if constexpr (IS_PREFILL) { + smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_K*sizeof(bf16)); + } else { + // RoPE only + smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16)); + } + smem.bar_KV_full[k_buf_idx].wait(k_bar_phase); + ku::tcgen05_after_thread_sync(); + Tensor sK = make_tensor( + make_smem_ptr(smem.K[k_buf_idx].data()), + ku::make_umma_canonical_k_major_layout() + ); + ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true); + ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_QK_done, 1|2); + }; + + // Issue O += S V + auto issue_O = [&](int k, int rs_offset) { + auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get(); + auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>(); + 
smem.bar_S_O_full.wait(bar_phase); + if (k == args.start_block_idx) { + smem.bar_tOut_empty.wait(args.outer_loop_phase^1); + } + ku::tcgen05_after_thread_sync(); + Tensor sS = make_tensor( + make_smem_ptr(smem.S.data()), + ku::make_umma_canonical_k_major_layout() + ); + Tensor sV = make_tensor( + make_smem_ptr(smem.K[k_buf_idx].data()), + ku::make_umma_canonical_mn_major_layout() + ); + ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == args.start_block_idx); + ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_SV_done, 1|2); + ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_KV_empty[k_buf_idx], 1|2); + }; + + CUTE_NO_UNROLL + for (int k = args.start_block_idx; k < args.end_block_idx+1; ++k) { + if (k < args.end_block_idx) { + issue_P(k, 0); + } + if (k == args.end_block_idx-1) { + ku::umma_arrive_2x1SM_noelect(smem.bar_tQ_empty); + } + + if (k > args.start_block_idx) { + issue_O(k-1, -1); + } + + if (k != args.end_block_idx) { + rs.update(); + } + } + ku::tcgen05_before_thread_sync(); + ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tOut_full, 1|2); + }); + } else if (warp_idx == 8 && cta_idx == 1 && elect_one_sync()) { + if constexpr (IS_DECODE) { + // KV RoPE fetching warp + run_outer_loop([&](const OuterloopArgs &args) { + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { + auto [index_buf_idx, index_bar_phase] = rs.get(); + auto [k_buf_idx, k_bar_phase] = rs.get(); + smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase); + smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1); + CUTE_UNROLL + for (int row = 0; row < B_TOPK; row += 4) { + int4 cur_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row); + ku::tma_gather4_cta_group_2( + block_idx >= args.num_orig_kv_blocks ? 
&tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope, + smem.bar_KV_full[k_buf_idx], + smem.K[k_buf_idx].data() + (D_NOPE-D_K/2)*B_TOPK + row*D_ROPE, + 0, + cur_indices, + (int64_t)TMA::CacheHintSm90::EVICT_LAST + ); + } + smem.bar_valid_coord_scales_empty[index_buf_idx].arrive(); + rs.update(); + } + }); + } + } else if (warp_idx == 9) { + // KV validness loading warp (for prefill), Indices transformation warp (for decode, Responsible for generating: TMA coordinates, scale factors, and valid masks) + if constexpr (IS_PREFILL) { + if (lane_idx < B_TOPK/8) { + run_outer_loop([&](const OuterloopArgs &args) { + int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q; + CUTE_NO_UNROLL + for (int k = args.start_block_idx; k < args.end_block_idx; ++k) { + char k_validness_mask = load_indices_and_generate_mask( + lane_idx, + gIndices + k*B_TOPK, + params.s_kv, + k*B_TOPK, + args.topk_length + ); + + auto [indices_buf_idx, indices_bar_phase] = rs.get(); + smem.bar_valid_coord_scales_empty[indices_buf_idx].wait(indices_bar_phase^1); + smem.is_k_valid[indices_buf_idx][lane_idx] = k_validness_mask; + smem.bar_valid_coord_scales_full[indices_buf_idx].arrive(); + + rs.update(); + } + }); + } + } else { + static_assert(B_TOPK == 64); + // Each thread is responsible for 2 tokens + static constexpr int tma_coords_step_per_token = 576/TMA_K_STRIDE_FOR_DECODING; + int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE_FOR_DECODING; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE_FOR_DECODING > 512 + int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE_FOR_DECODING; + uint8_t* k_scales_ptr = (uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE); + uint8_t* extra_k_scales_ptr = (uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE); + + run_outer_loop([&](const OuterloopArgs &args) { + int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + 
params.stride_indices_s_q*args.s_q_idx; + int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*args.s_q_idx; + + struct IsOrigBlock {}; + struct IsExtraBlock {}; + auto process_one_block = [&](int block_idx, auto is_extra_block_t) { + auto [index_buf_idx, index_bar_phase] = rs.get(); + static constexpr bool IS_EXTRA_BLOCK = std::is_same_v; + int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size; + int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block; + [[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row; + uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr; + int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block; + + int abs_pos, my_indices[2]; + if (!IS_EXTRA_BLOCK) { + abs_pos = block_idx*B_TOPK + lane_idx*2; + *(int2*)my_indices = __ldg((int2*)(indices + abs_pos)); + } else { + abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2; + *(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos)); + } + smem.bar_valid_coord_scales_empty[index_buf_idx].wait(index_bar_phase^1); + + int tma_coords[2]; + fp8_e8m0 scales[2*(NUM_SCALES_EACH_TOKEN/2)]; + char valid_mask = 0; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + int block_idx, idx_in_block; + block_idx = (unsigned int)my_indices[i] / cur_block_size; + idx_in_block = (unsigned int)my_indices[i] % cur_block_size; + bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length)); + valid_mask |= is_token_valid << i; + tma_coords[i] = is_token_valid ? 
block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because it topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN. + + int64_t offset = block_idx*cur_k_block_stride + (idx_in_block*8 + (cta_idx == 1 ? 4 : 0)); // Each token has 7 scale factors with an extra 1B padding + uint32_t scalesx4 = is_token_valid ? __ldg((uint32_t*)(cur_k_scales_ptr + offset)) : 0; + *(uint32_t*)(scales+i*(NUM_SCALES_EACH_TOKEN/2)) = scalesx4; + } + valid_mask <<= lane_idx%4*2; + valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1); + valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2); + *(uint64_t*)(smem.scales[index_buf_idx] + lane_idx*2) = *(uint64_t*)scales; + *(int2*)(smem.tma_coord[index_buf_idx] + lane_idx*2) = *(int2*)tma_coords; + if (lane_idx%4 == 0) + smem.is_k_valid[index_buf_idx][lane_idx/4] = valid_mask; + + smem.bar_valid_coord_scales_full[index_buf_idx].arrive(); + rs.update(); + }; + + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) { + process_one_block(block_idx, IsOrigBlock{}); + } + + CUTE_NO_UNROLL + for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) { + process_one_block(block_idx, IsExtraBlock{}); + } + }); + } + } else if (warp_idx >= 10 && elect_one_sync()) { + if constexpr (IS_PREFILL) { + if (warp_idx == 10) { + // CLC Producer thread + run_outer_loop([&](const OuterloopArgs &args) { + if (cta_idx == 0) { + smem.bar_clc_empty.wait(args.outer_loop_phase^1); + ku::issue_clc_query_multicast_cluster_all(smem.bar_clc_full, smem.clc_response_obj); + } + smem.bar_clc_full.arrive_and_expect_tx(sizeof(smem.clc_response_obj)); + }); + } + } else { + // Raw KV NoPE Producer thread + run_outer_loop([&](const OuterloopArgs &args) { + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx 
< args.end_block_idx; ++block_idx) { + auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get(); + auto [index_buf_idx, index_bar_phase] = rs.get(); + smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase); + smem.bar_raw_KV_empty[raw_k_buf_idx].wait(raw_k_bar_phase^1); + + int4 nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + (warp_idx == 10 ? 0 : 4)); + CUTE_UNROLL + for (int row = (warp_idx == 10 ? 0 : 4); row < B_TOPK; row += 8) { + int4 cur_indices = nxt_indices; + if (row+8 < B_TOPK) + nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row + 8); + ku::tma_gather4( + block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope, + smem.bar_raw_KV_full[raw_k_buf_idx], + smem.K_raw[raw_k_buf_idx].data() + row*(D_K/2), + cta_idx*(D_K/2), + cur_indices, + (int64_t)TMA::CacheHintSm90::EVICT_LAST + ); + } + if (warp_idx == 10) { + smem.bar_raw_KV_full[raw_k_buf_idx].arrive_and_expect_tx(B_TOPK*(D_K/2)*sizeof(fp8_e4m3)); + } + smem.bar_valid_coord_scales_empty[index_buf_idx].arrive(); + rs.update(); + } + }); + } + } + } else { + // Scale & Exp threads + cutlass::arch::warpgroup_reg_alloc<176>(); + + int local_warp_idx = warp_idx - 12; + bf16* sS_base = smem.S.data() + (local_warp_idx >= 2 ? 
(H_Q/2)*(B_TOPK/2) : 0) + (idx_in_warpgroup%64)*8; + + RingBufferState rs; + run_outer_loop([&](const OuterloopArgs &args) { + // For definition and consistency about `mi`, `li`, and `real_mi`, plz refer to head64 prefill + float mi = MAX_INIT_VAL; + float li = 0.0f; + float real_mi = -CUDART_INF_F; + static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2; + + CUTE_NO_UNROLL + for (int k = args.start_block_idx; k < args.end_block_idx; ++k) { + auto [k_buf_idx, k_bar_phase] = rs.get(); + auto [indices_buf_idx, indices_bar_phase] = rs.get(); + auto [_, bar_phase] = rs.get<1>(); + // NOTE We don't need to sync for Prefill mode, since we have two synchronizations inside the loop body (one for p_exchange_buf sync, another one for rowwise_max_buf sync). The latter one guarantees the emptyness of p_exchange_buf and the former one guarantees the emptyness of rowwise_max_buf + smem.bar_valid_coord_scales_full[indices_buf_idx].wait(indices_bar_phase); + + // Get P from TMEM + float p[NUM_ELEMS_PER_THREAD]; + smem.bar_QK_done.wait(bar_phase); + ku::tcgen05_after_thread_sync(); + retrieve_mask_and_reduce_p< + NUM_ELEMS_PER_THREAD, + tmem_cols::P, + barrier_ids::WG2_WARP02_SYNC, + barrier_ids::WG2_WARP13_SYNC, + false + >( + smem.is_k_valid[indices_buf_idx], + local_warp_idx, + lane_idx, + [&]() {smem.bar_P_empty.arrive(0u);}, + smem.P_exchange, + p + ); + + // Get rowwise max of P + float cur_pi_max = get_max(p); + cur_pi_max *= params.sm_scale_div_log2; + + smem.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max; + NamedBarrier::arrive_and_wait(64, barrier_ids::WG2_WARP02_SYNC + (local_warp_idx&1)); + cur_pi_max = max(cur_pi_max, smem.rowwise_max_buf[idx_in_warpgroup^64]); + real_mi = max(real_mi, cur_pi_max); + bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f); + + + // Calc scale factor, and scale li + float new_max, scale_for_old; + if (!should_scale_o) { + // Don't scale O + scale_for_old = 1.0f; + new_max = mi; + } else { + new_max = max(cur_pi_max, mi); + 
scale_for_old = exp2f(mi - new_max); + } + mi = new_max; // mi is still identical within each row + + // Calculate S + nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2]; + float cur_sum = get_s_from_p(s, p, params.sm_scale_div_log2, new_max); + li = fmaf(li, scale_for_old, cur_sum); + + // Store S + smem.bar_SV_done.wait(bar_phase^1); + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; ++i) { + ku::st_shared(sS_base + i*8*(H_Q/2), *(__int128_t*)(s + i*4)); + } + + // Rescale O + if (k > 0 && should_scale_o) { + ku::tcgen05_after_thread_sync(); + rescale_O(scale_for_old); + ku::tcgen05_before_thread_sync(); + } + + fence_view_async_shared(); + smem.bar_S_O_full.arrive(0u); + smem.bar_valid_coord_scales_empty[indices_buf_idx].arrive(); + + rs.update(); + } + + if (real_mi == -CUDART_INF_F) { + // real_mi == -CUDART_INF_F <=> No valid TopK indices + // We set li to 0 to fit the definition that li := exp(x[i] - mi) + li = 0.0f; + mi = -CUDART_INF_F; + } + + // Reduce li + smem.bar_li_empty.wait(args.outer_loop_phase^1); + smem.rowwise_li_buf[idx_in_warpgroup^64] = li; + NamedBarrier::arrive_and_wait(128, barrier_ids::WG2_SYNC); + li += smem.rowwise_li_buf[idx_in_warpgroup]; + + if (idx_in_warpgroup < H_Q/2) { + // Calculate output_scale and save + int head_idx = cta_idx*(H_Q/2) + idx_in_warpgroup; + float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + head_idx); + float output_scale; + if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) { + output_scale = __fdividef(1.0f, li + exp2f(fmaf(attn_sink, CUDART_L2E_F, -mi))); + } else { + output_scale = __fdividef(1.0f, li); + } + smem.rowwise_li_buf[idx_in_warpgroup] = li == 0.0f ? 0.0f : output_scale; + smem.bar_li_full.arrive(); + + float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li)); + cur_lse = cur_lse == -CUDART_INF_F ? 
+CUDART_INF_F : cur_lse; + if constexpr (IS_PREFILL) { + int global_index = args.s_q_idx*params.h_q + head_idx; + params.max_logits[global_index] = real_mi*CUDART_LN2_F; + params.lse[global_index] = cur_lse; + } else { + if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) { + params.lse[args.batch_idx*params.stride_lse_b + args.s_q_idx*params.stride_lse_s_q + head_idx] = cur_lse; + } else { + float cur_lse_2base = log2f(li) + mi; + params.lse_accum[args.n_split_idx*params.stride_lse_accum_split + args.s_q_idx*params.stride_lse_accum_s_q + head_idx] = cur_lse_2base; + } + } + + } + }); + } + + ku::barrier_cluster_arrive_relaxed(); + ku::barrier_cluster_wait_acquire(); + +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100"); + } +#endif +} + +// We have two launchers with different kernel names to distinguish prefill and decode. Both only forward to Kernel::sparse_attn_fwd_kernel_devfunc; KernelTemplate::run picks the first one when IS_PREFILL and the second one otherwise. + +template +static __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2) +sparse_attn_fwd_for_small_topk_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) { + Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params); +} + +template +static __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2) +flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) { + Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params); +} + /* Host-side launcher: checks the parameter preconditions, builds the TMA tensor maps, raises the dynamic shared-memory limit, and launches the selected kernel with a cluster shape of (2, 1, 1). */ +template +void KernelTemplate::run(const ArgT& params) { + static_assert(D_QK == 576 || D_QK == 512); + + KU_ASSERT(params.h_kv == 1); + KU_ASSERT(params.topk % B_TOPK == 0); // To save some boundary checks + KU_ASSERT(params.h_q == H_Q); // To save some calculation + KU_ASSERT(params.d_qk == D_QK); + + static_assert(D_Q == 512); + CUtensorMap tensor_map_q; + if constexpr (IS_DECODE) { + KU_ASSERT(params.stride_q_b % params.stride_q_s_q == 0, "In decode mode for MODEL1 sparse fp8 decoding 
on sm100f, q.stride(0) (on the batch dimension) must be divisible by q.stride(1) (on the sequence dimension)."); + tensor_map_q = ku::make_tensor_map( + {64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.b * (params.stride_q_b / params.stride_q_s_q)}, + ku::make_stride_helper({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)), + {64, H_Q/2, 2, (D_Q/64)/2, 1}, + params.q, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CU_TENSOR_MAP_SWIZZLE_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + ); + } else { + tensor_map_q = ku::make_tensor_map( + {64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.s_q}, + ku::make_stride_helper({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)), + {64, H_Q/2, 2, (D_Q/64)/2, 1}, + params.q, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CU_TENSOR_MAP_SWIZZLE_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + ); // We use this layout to group Q[0:64] and Q[256:256+64] together, for UTCCP for dual gemm + } + /* KV tensor maps. Decode builds separate NoPE (uint8) and RoPE (bf16) maps, plus the extra-KV pair when params.extra_topk > 0; the other mode builds a single bf16 map. */ + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_nope, tensor_map_kv_rope, tensor_map_extra_kv_nope = {}, tensor_map_extra_kv_rope = {}; + if constexpr (IS_DECODE) { + auto get_kv_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t stride_kv_block, int64_t stride_kv_row) -> std::pair { + KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra?"extra_":"", k_ptr); + KU_ASSERT(stride_kv_block % TMA_K_STRIDE_FOR_DECODING == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. 
Padding might be necessary", is_extra?"extra_":"", stride_kv_block, TMA_K_STRIDE_FOR_DECODING); + CUtensorMap tensor_map_kv_nope = ku::make_tensor_map( + {D_NOPE + D_ROPE*2, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)}, + {TMA_K_STRIDE_FOR_DECODING}, + {D_K/2, 1}, + k_ptr, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B + ); // NOTE: Here we use `D_NOPE+D_ROPE*2` as the box shape instead of D_NOPE because it's actually faster. I think that's because, if we use `D_NOPE+D_ROPE*2`, we can prefetch part of the RoPE part of the selected tokens. + CUtensorMap tensor_map_kv_rope = ku::make_tensor_map( + {D_ROPE, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)}, + {TMA_K_STRIDE_FOR_DECODING}, + {64, 1}, + (uint8_t*)k_ptr + D_NOPE, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B + ); + return {tensor_map_kv_nope, tensor_map_kv_rope}; + }; + std::tie(tensor_map_kv_nope, tensor_map_kv_rope) = get_kv_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block, params.stride_kv_row); + if (params.extra_topk > 0) + std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_kv_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block, params.stride_extra_kv_row); + } else { + tensor_map_kv = ku::make_tensor_map( + {D_QK, (unsigned long)params.s_kv}, + {(unsigned long)params.stride_kv_s_kv*sizeof(bf16)}, + {64, 1}, + params.kv, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CU_TENSOR_MAP_SWIZZLE_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + ); + } + /* Output tensor map (bf16). In decode the outermost extent is params.b and the strides come from params; otherwise the extent is 1 and the strides are derived from D_V and H_Q. */ + CUtensorMap tensor_map_o; + if constexpr (IS_DECODE) { + tensor_map_o = ku::make_tensor_map( + {64, H_Q, D_V/64, (unsigned long)params.s_q, (unsigned long)params.b}, + ku::make_stride_helper({params.stride_o_h_q, 64, 
params.stride_o_s_q, params.stride_o_b}, sizeof(bf16)), + {64, H_Q/2, D_V/64, 1, 1}, + params.out, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CU_TENSOR_MAP_SWIZZLE_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + ); + } else { + tensor_map_o = ku::make_tensor_map( + {64, H_Q, D_V/64, (unsigned long)params.s_q, 1ul}, + ku::make_stride_helper({D_V, 64, H_Q*D_V, H_Q*D_V}, sizeof(bf16)), + {64, H_Q/2, D_V/64, 1, 1}, + params.out, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + CU_TENSOR_MAP_SWIZZLE_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + ); + } + + /* Float32 partial-output map, only built for FwdMode::DecodeWithSplitKV (left zero-initialised otherwise); its outermost extent is params.num_sm_parts + params.b. */ + CUtensorMap tensor_map_o_accum = {}; + if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) { + tensor_map_o_accum = ku::make_tensor_map( + {32, H_Q, D_V/32, (unsigned long)params.s_q, (unsigned long)params.num_sm_parts + params.b}, + ku::make_stride_helper({params.stride_o_accum_h_q, 32, params.stride_o_accum_s_q, params.stride_o_accum_split}, sizeof(float)), + {32, H_Q/2, B_EPI_SPLITKV/32, 1, 1}, + params.o_accum, + CU_TENSOR_MAP_DATA_TYPE_FLOAT32, + CU_TENSOR_MAP_SWIZZLE_128B, + CU_TENSOR_MAP_L2_PROMOTION_L2_256B + ); + } + + TmaParams tma_params; + if constexpr (IS_DECODE) { + tma_params = { + tensor_map_q, + tensor_map_o, + tensor_map_o_accum, + tensor_map_kv_nope, + tensor_map_kv_rope, + tensor_map_extra_kv_nope, + tensor_map_extra_kv_rope + }; + } else { + tma_params = { + tensor_map_q, + tensor_map_kv, + tensor_map_o + }; + } + /* Select the launcher and opt in to sizeof(SharedMemoryPlan) bytes of dynamic shared memory. */ + auto kernel = IS_PREFILL ? &sparse_attn_fwd_for_small_topk_kernel> : &flash_fwd_splitkv_mla_fp8_sparse_kernel>; + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + /* Grid: x = 2*s_q; y = num_sm_parts for split-KV decode, b for other decode modes, 1 otherwise. */ + dim3 grid_shape; + if constexpr (IS_DECODE) { + grid_shape = dim3(2*params.s_q, FWD_MODE == FwdMode::DecodeWithSplitKV ? 
params.num_sm_parts : params.b, 1); + } else { + grid_shape = dim3(2*params.s_q, 1, 1); + } + /* Cluster launch: NUM_THREADS threads per block, cluster shape (2, 1, 1), on params.stream. */ + cutlass::ClusterLaunchParams launch_params = { + grid_shape, + dim3(NUM_THREADS, 1, 1), + dim3(2, 1, 1), + smem_size, + params.stream + }; + KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster( + launch_params, (void*)kernel, params, tma_params + )); +} + /* Entry point declared in phase1.h: forwards to KernelTemplate::run. */ +template +void run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT& params) { + using Kernel = KernelTemplate; + Kernel::run(params); +} + +} diff --git a/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h new file mode 100644 index 0000000..d1a092a --- /dev/null +++ b/csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h @@ -0,0 +1,10 @@ +#pragma once + +#include "params.h" + +namespace sm100::fwd_for_small_topk::head128 { + +template +void run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT& params); + +} diff --git a/csrc/sm100/prefill/sparse/helpers.h b/csrc/sm100/prefill/sparse/helpers.h deleted file mode 100644 index 991b40d..0000000 --- a/csrc/sm100/prefill/sparse/helpers.h +++ /dev/null @@ -1,104 +0,0 @@ -#pragma once - -#include -#include "sm100/defines.h" - -namespace sm100 { - -using namespace cute; - -using _72 = Int<72>; -using _576 = Int<576>; - -template< - typename TMA, - typename Tensor0, - typename Tensor1 -> -CUTE_DEVICE -void launch_tma_copy( - const TMA &tma_copy, - Tensor0 src, - Tensor1 dst, - transac_bar_t &bar, - const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL -) { - auto thr_tma = tma_copy.get_slice(_0{}); - cute::copy( - tma_copy.with(reinterpret_cast(bar), 0, cache_hint), - thr_tma.partition_S(src), - thr_tma.partition_D(dst) - ); -} - -template< - typename TiledMMA, - typename TensorA, - typename TensorB, - typename TensorFragC -> -CUTE_DEVICE -void utcmma( - TiledMMA &tiled_mma, - TensorA sA, - TensorB sB, - TensorFragC tC_frag, - bool clear_accum -) { - 
tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; - ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter - auto sA_frag = thr_mma.partition_fragment_A(sA); - auto sB_frag = thr_mma.partition_fragment_B(sB); - static_assert(size<2>(sA_frag) == size<2>(sB_frag)); - static_assert(size<1>(sA_frag) == size<1>(tC_frag)); - static_assert(size<1>(sB_frag) == size<2>(tC_frag)); - CUTE_UNROLL - for (int k = 0; k < size<2>(sA_frag); ++k) { - cute::gemm( - tiled_mma, - sA_frag(_, _, k), - sB_frag(_, _, k), - tC_frag - ); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; - } -} - -template< - typename TiledMMA, - typename TensorA, - typename TensorB, - typename TensorFragC -> -CUTE_DEVICE -void utcmma_ts( - TiledMMA &tiled_mma, - TensorA tA_frag, - TensorB sB, - TensorFragC tC_frag, - bool clear_accum -) { - tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; - ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter - auto sB_frag = thr_mma.partition_fragment_B(sB); - static_assert(size<2>(tA_frag) == size<2>(sB_frag)); - CUTE_UNROLL - for (int k = 0; k < size<2>(tA_frag); ++k) { - cute::gemm( - tiled_mma, - tA_frag(_, _, k), - sB_frag(_, _, k), - tC_frag - ); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; - } -} - -struct bf16x8 { - __nv_bfloat162 a01; - __nv_bfloat162 a23; - __nv_bfloat162 a45; - __nv_bfloat162 a67; -}; - -} diff --git a/csrc/sm100/prefill/sparse/intrinsics.h b/csrc/sm100/prefill/sparse/intrinsics.h deleted file mode 100644 index 85a8203..0000000 --- a/csrc/sm100/prefill/sparse/intrinsics.h +++ /dev/null @@ -1,638 +0,0 @@ -#pragma once - -#include -#include "defines.h" - -namespace sm100 { - -using namespace cute; - -struct int32x8_t { - int a0, a1, a2, a3, a4, a5, a6, a7; -}; - -struct float8 { - float2 a01, a23, a45, a67; -}; - -__forceinline__ __device__ void 
cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { - uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); - asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" - :: "r"(dst_addr), - "l"(src), - "n"(16)); -} - -template -CUTE_DEVICE -static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { - static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b."); - long2 data_long2 = *reinterpret_cast(&data); - uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); - uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); - asm volatile ( - "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" - : - : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) - ); -} - -CUTE_DEVICE -void umma_arrive_multicast_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) { - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" - "}" - : - :"r"(bar_intptr), "h"(cta_mask)); -} - -CUTE_DEVICE -void umma_arrive_multicast_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) { - umma_arrive_multicast_noelect((uint64_t*)smem_ptr, cta_mask); -} - -CUTE_DEVICE -void umma_arrive_multicast_2x1SM_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) { - uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr); - asm volatile( - "{\n\t" - "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t" - "}" - : - :"r"(bar_intptr), "h"(cta_mask)); -} - -CUTE_DEVICE -void umma_arrive_multicast_2x1SM_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) { - umma_arrive_multicast_2x1SM_noelect((uint64_t*)smem_ptr, cta_mask); -} - -CUTE_DEVICE -int64_t createpolicy_evict_last() { - int64_t res; - asm volatile( - "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; 
\n\t" - : "=l"(res) - : - ); - return res; -} - -CUTE_DEVICE -void atomicadd_f32x4_with_policy(void* global_addr, const float4 &data, int64_t cache_policy) { - asm volatile( - "red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t" - : - : "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w), - "l"((int64_t)global_addr), "l"(cache_policy) - ); -} - -CUTE_DEVICE -void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) { - uint32_t smem_int_ptr = cast_smem_ptr_to_uint(src_ptr); - asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" - : - : "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes) - : "memory"); -} - -CUTE_DEVICE -float2 float2_add(const float2 &a, const float2 &b) { - float2 res; - cute::add(res, a, b); - return res; -} - -CUTE_DEVICE -float2 float2_mul(const float2 &a, const float2 &b) { - float2 res; - cute::mul(res, a, b); - return res; -} - -CUTE_DEVICE -float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) { - // return a*b+c - float2 res; - cute::fma(res, a, b, c); - return res; -} - -CUTE_DEVICE -float2 float2_neg(const float2 &a) { - float2 t = {-1.0f, -1.0f}; - return float2_mul(a, t); -} - -__device__ __forceinline__ void tcgen05_before_thread_sync() { - asm volatile("tcgen05.fence::before_thread_sync;"); -} - -__device__ __forceinline__ void tcgen05_after_thread_sync() { - asm volatile("tcgen05.fence::after_thread_sync;"); -} - -template -CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) { - uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); - uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); - if constexpr (USE_CTA0_MBAR) { - mbar_addr &= Sm100MmaPeerBitMask; - } - asm volatile( - "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, 
%6}], [%7], %8;\n" - : - : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), - "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), - "r"(mbar_addr), "l"(uint64_t(cache_hint)) - : "memory" - ); -} - -// 32 data path lanes, 32-bit pattern, repeated N times -template -CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); - uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); - - if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" - "{%0}," - "[%1];\n" - : "=r"(dst_ptr[0]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" - "{%0, %1}," - "[%2];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) - : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" - "{%0, %1, %2, %3}," - "[%4];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]) - : "r"(src_addr)); - } else if constexpr (N == 8) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "[%8];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) - : "r"(src_addr)); - } else if constexpr (N == 16) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15}," - "[%16];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]) - : "r"(src_addr)); - } else if constexpr (N == 32) { - asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, 
%9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31}," - "[%32];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) - : "r"(src_addr)); - } else if constexpr (N == 64) { - asm volatile( - "tcgen05.ld.sync.aligned.32x32b.x64.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " - "%57, %58, %59, %60, %61, %62, %63}," - "[%64];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - 
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]) - : "r"(src_addr)); - } else if constexpr (N == 128) { - asm volatile( - "tcgen05.ld.sync.aligned.32x32b.x128.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " - "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127}," - "[%128];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), 
"=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), - "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), - "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), - "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), - "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), - "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), - "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), - "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), - "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), - "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), - "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), - "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), - "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), - "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), - "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), - "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), - "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), - "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), - "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), - "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), - "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), - "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) - : "r"(src_addr)); - } else { - asm volatile ("trap"); - } -} - -// 16 data path lanes, 256-bit pattern, repeated N times 
-template -CUTE_DEVICE void tmem_ld_16dp256bNx(uint32_t const &src_addr, T* dst_ptr_) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, - "N must be a power of 2 and lies between 1 ~ 32"); - uint32_t* dst_ptr = reinterpret_cast(dst_ptr_); - - if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" - "{%0, %1, %2, %3}," - "[%4];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]) - : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "[%8];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) - : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x4.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15}," - "[%16];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]) - : "r"(src_addr)); - } else if constexpr (N == 8) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x8.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31}," - "[%32];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), 
"=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) - : "r"(src_addr)); - } else if constexpr (N == 16) { - asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x16.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63}," - "[%64];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]) - : "r"(src_addr)); - } else if constexpr (N == 32) { - asm volatile( - 
"tcgen05.ld.sync.aligned.16x256b.x32.b32" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " - "%28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " - "%42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " - "%56, " - "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " - "%70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " - "%84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " - "%98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127}," - "[%128];\n" - : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), - "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), - "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), - "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), - "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), - "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), - "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), - "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), - "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), - "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), - "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), - "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), - "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), - "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), - "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), - "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), - "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), - "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), - "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), - "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), - "=r"(dst_ptr[60]), 
"=r"(dst_ptr[61]), "=r"(dst_ptr[62]), - "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), - "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), - "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), - "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), - "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), - "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), - "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), - "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), - "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), - "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), - "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), - "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), - "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), - "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), - "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), - "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), - "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), - "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), - "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), - "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), - "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), - "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) - : "r"(src_addr)); - } else { - asm volatile("trap"); - } -} - - -// 32 data path lanes, 32-bit pattern, repeated N times -template -CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128"); - uint32_t* src_ptr = reinterpret_cast(src_ptr_); - - if constexpr (N == 1) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32" - "[%1], {%0};\n" - : - : "r"(src_ptr[0]), - "r"(dst_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32" - "[%2], {%0, %1};\n" - : - : "r"(src_ptr[0]), 
"r"(src_ptr[1]), - "r"(dst_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32" - "[%4], {%0, %1, %2, %3};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), - "r"(dst_addr)); - } else if constexpr (N == 8) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32" - "[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), - "r"(dst_addr)); - } else if constexpr (N == 16) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32" - "[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), - "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), - "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), - "r"(src_ptr[15]), - "r"(dst_addr)); - } else if constexpr (N == 32) { - asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32" - "[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " - "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " - "%26, %27, %28, %29, %30, %31};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), - "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), - "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), - "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), - "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), - "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), - "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), - "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), - "r"(src_ptr[30]), "r"(src_ptr[31]), - "r"(dst_addr)); - } else if constexpr (N == 64) { - asm volatile( - 
"tcgen05.st.sync.aligned.32x32b.x64.b32" - "[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " - "%57, %58, %59, %60, %61, %62, %63};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), - "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), - "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), - "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), - "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), - "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), - "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), - "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), - "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), - "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), - "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), - "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), - "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), - "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), - "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), - "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), - "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), - "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), - "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), - "r"(src_ptr[63]), - "r"(dst_addr)); - } else if constexpr (N == 128) { - asm volatile( - "tcgen05.st.sync.aligned.32x32b.x128.b32" - "[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " - "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, " - "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, " - "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, " - 
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, " - "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, " - "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, " - "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " - "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " - "%121, %122, %123, %124, %125, %126, %127};\n" - : - : "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]), - "r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]), - "r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]), - "r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]), - "r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]), - "r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]), - "r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]), - "r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]), - "r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]), - "r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]), - "r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]), - "r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]), - "r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]), - "r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]), - "r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]), - "r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]), - "r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]), - "r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]), - "r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]), - "r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]), - "r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]), - "r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]), - "r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]), - "r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]), - "r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]), - "r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]), - "r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]), - "r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]), - "r"(src_ptr[84]), 
"r"(src_ptr[85]), "r"(src_ptr[86]), - "r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]), - "r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]), - "r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]), - "r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]), - "r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]), - "r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]), - "r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]), - "r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]), - "r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]), - "r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]), - "r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]), - "r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]), - "r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]), - "r"(src_ptr[126]), "r"(src_ptr[127]), - "r"(dst_addr)); - } else { - asm volatile ("trap"); - } -} - - -static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字 -template -CUTE_DEVICE -T* get_peer_addr(const T* p) { - return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); -} - -} diff --git a/csrc/sm100/prefill/sparse/ws_gemm.h b/csrc/sm100/prefill/sparse/ws_gemm.h deleted file mode 100644 index 78c9005..0000000 --- a/csrc/sm100/prefill/sparse/ws_gemm.h +++ /dev/null @@ -1,328 +0,0 @@ -#pragma once - -#include - -namespace cute { - -// Extensions to CuTe -// CuTe don't support UTCMMA with .ws, so we add it here - -template -struct SM100_MMA_F16BF16_WS_SS_NOELECT -{ - static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA."); - static_assert(N == 64 || N == 128 || N == 256, - "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 32, 64 or 128"); - - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[1]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - 
uint32_t const& tmem_c, - uint32_t const& scaleC, - uint64_t const& idescE) - { - asm volatile( - "{\n\t" - ".reg .pred p;\n\t" - "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" - "}\n" - : - : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC)); - } -}; - -template -struct MMA_Traits> -{ - using ValTypeD = c_type; - using ValTypeA = a_type; - using ValTypeB = b_type; - using ValTypeC = c_type; - - static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types"); - - using FrgTypeA = UMMA::smem_desc; - using FrgTypeB = UMMA::smem_desc; - using FrgTypeC = UMMA::tmem_frg_ws_1sm; - - // Logical shape-K is always 256bits, transform to units of elements - static constexpr int K = 256 / cute::sizeof_bits::value; - - using Shape_MNK = Shape,Int,Int>; - using ThrID = Layout<_1>; - using ALayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using BLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - using CLayout = Layout,Int>>, - Stride<_0,Stride< _1,Int>>>; - - UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< - a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); - - // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] - UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; - - template - CUTE_HOST_DEVICE constexpr friend - void - mma_unpack(MMA_Traits const& traits, - Tensor & D, - Tensor const& A, - Tensor const& B, - Tensor const& C) - { - static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); - static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - - uint64_t desc_a = A[0]; - uint64_t desc_b = B[0]; - uint32_t tmem_c = raw_pointer_cast(D.data()); - uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); - - SM100_MMA_F16BF16_WS_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); - } -}; - - -template -struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT -{ - static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); - static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed"); - - using DRegisters = void; - using ARegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[1]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& tmem_a, - uint64_t const& desc_b, - uint32_t const& tmem_c, - uint32_t const& scaleC, - uint64_t const& idescE) - { -#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) - uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - asm volatile( - "{\n\t" - ".reg .pred p;\n\t" - "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" - "}\n" - : - : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), - "r"(mask[0]), 
"r"(mask[1]), "r"(mask[2]), "r"(mask[3]), - "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); -#endif - } -}; - - -template -struct MMA_Traits> -{ - using ValTypeD = c_type; - using ValTypeA = a_type; - using ValTypeB = b_type; - using ValTypeC = c_type; - static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types"); - - using FrgTypeA = UMMA::tmem_frg_2sm; - using FrgTypeB = UMMA::smem_desc; - using FrgTypeC = UMMA::tmem_frg_2sm; - - // Size of instructions' K extent is always 256 bits; convert to units of element - constexpr static int K = 256 / cute::sizeof_bits::value; - - using Shape_MNK = Shape,Int,Int>; - using ThrID = Layout<_2>; - using ALayout = Layout,Int>>, - Stride,Stride< _1,Int>>>; - using BLayout = Layout,Int>>, - Stride,Stride< _1,Int>>>; - using CLayout = Layout,Int>>, - Stride,Stride< _1,Int>>>; - - // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] - UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; - - UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< - a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); - - template - CUTE_HOST_DEVICE constexpr friend - void - mma_unpack(MMA_Traits const& traits, - Tensor & D, - Tensor const& A, - Tensor const& B, - Tensor const& C) - { - static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); - static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - - uint64_t tmem_a = raw_pointer_cast(A.data()); - uint64_t desc_b = B[0]; - uint32_t tmem_c = raw_pointer_cast(D.data()); - uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); - - SM100_MMA_F16BF16_2x1SM_TS_NOELECT::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); - } -}; - - - -// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync() -template -struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT -{ - static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256."); - - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[1]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t const& tmem_c, - uint32_t const& scaleC, - uint64_t const& idescE) - { -#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) - uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - asm volatile( - "{\n\t" - ".reg .pred p;\n\t" - "setp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" - "}\n" - : - : 
"r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), - "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), - "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); -#endif - } -}; - -// template -// struct MMA_Traits> : MMA_Traits> {}; -template -struct MMA_Traits> -{ - using ValTypeD = c_type; - using ValTypeA = a_type; - using ValTypeB = b_type; - using ValTypeC = c_type; - static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types"); - - using FrgTypeA = UMMA::smem_desc; - using FrgTypeB = UMMA::smem_desc; - using FrgTypeC = UMMA::tmem_frg_2sm; - - // Size of instructions's K extent is always 256bits, convert to units of element - constexpr static int K = 256 / cute::sizeof_bits::value; - - using Shape_MNK = Shape,Int,Int>; - using ThrID = Layout<_2>; - using ALayout = Layout,Int>>, - Stride,Stride< _1,Int>>>; - using BLayout = Layout,Int>>, - Stride,Stride< _1,Int>>>; - using CLayout = Layout,Int>>, - Stride,Stride< _1,Int>>>; - - UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< - a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); - - // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] - UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; - - template - CUTE_HOST_DEVICE constexpr friend - void - mma_unpack(MMA_Traits const& traits, - Tensor & D, - Tensor const& A, - Tensor const& B, - Tensor const& C) - { - static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); - static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); - static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); - - uint64_t desc_a = A[0]; - uint64_t desc_b = B[0]; - uint32_t tmem_c = raw_pointer_cast(D.data()); - uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); - - SM100_MMA_F16BF16_2x1SM_SS_NOELECT::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); - } -}; - -} \ No newline at end of file diff --git a/csrc/sm90/decode/dense/instantiations/bf16.cu b/csrc/sm90/decode/dense/instantiations/bf16.cu new file mode 100644 index 0000000..3a1dce9 --- /dev/null +++ b/csrc/sm90/decode/dense/instantiations/bf16.cu @@ -0,0 +1,8 @@ +#include "../splitkv_mla.cuh" +#include "../splitkv_mla.h" + +namespace sm90 { + +template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms); + +} diff --git a/csrc/sm90/decode/dense/instantiations/fp16.cu b/csrc/sm90/decode/dense/instantiations/fp16.cu new file mode 100644 index 0000000..bc6cd64 --- /dev/null +++ b/csrc/sm90/decode/dense/instantiations/fp16.cu @@ -0,0 +1,10 @@ +#include "../splitkv_mla.cuh" +#include "../splitkv_mla.h" + +namespace sm90 { + +#ifndef FLASH_MLA_DISABLE_FP16 +template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms); +#endif + +} diff --git a/csrc/sm90/decode/dense/splitkv_mla.cu b/csrc/sm90/decode/dense/splitkv_mla.cuh similarity index 95% rename from csrc/sm90/decode/dense/splitkv_mla.cu rename to csrc/sm90/decode/dense/splitkv_mla.cuh index cb2e476..cdd5441 100644 --- a/csrc/sm90/decode/dense/splitkv_mla.cu +++ 
b/csrc/sm90/decode/dense/splitkv_mla.cuh @@ -758,7 +758,7 @@ __forceinline__ __device__ void wg0_subroutine( TMABarrier barriers_K1[9], bool &cur_phase_K0, const TMAParams &tma_params, - const DecodingParams ¶ms, + const DenseAttnDecodeParams ¶ms, int* block_table_ptr, int seqlen_k, int block_idx, @@ -870,7 +870,7 @@ __forceinline__ __device__ void wg1_subroutine( TMABarrier barriers_K1[9], bool &cur_phase_K1, const TMAParams &tma_params, - const DecodingParams ¶ms, + const DenseAttnDecodeParams ¶ms, int* block_table_ptr, int seqlen_k, int block_idx, @@ -945,7 +945,7 @@ __forceinline__ __device__ void wg1_subroutine( } // A helper function for determining the length of the causal mask for one q token -__forceinline__ __device__ int get_mask_len(const DecodingParams ¶ms, int m_block_idx, int local_seq_q_idx) { +__forceinline__ __device__ int get_mask_len(const DenseAttnDecodeParams ¶ms, int m_block_idx, int local_seq_q_idx) { int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx; if (global_seq_q_idx < params.q_seq_per_hk) { int s_q_idx = global_seq_q_idx / params.q_head_per_hk; @@ -958,7 +958,7 @@ __forceinline__ __device__ int get_mask_len(const DecodingParams ¶ms, int m_ template __global__ void __launch_bounds__(T::NUM_THREADS, 1, 1) -flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { +flash_fwd_splitkv_mla_kernel(__grid_constant__ const DenseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) { // grid shape: [ // num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))), // num_kv_heads, @@ -968,7 +968,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr // If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. 
Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx]) // For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file). -#if IS_SM90 +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) const int m_block_idx = blockIdx.x; const int k_head_idx = blockIdx.y; const int partition_idx = blockIdx.z; @@ -1016,30 +1016,21 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr __syncthreads(); bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0; - // Programmatic Dependent Launch: Wait for the previous kernel to finish - cudaGridDependencySynchronize(); - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - // We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race. 
- int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); - int begin_idx = tile_scheduler_metadata.x; - int sched_begin_block_idx = tile_scheduler_metadata.y; - int end_idx = tile_scheduler_metadata.z; - int sched_end_block_idx = tile_scheduler_metadata.w; - if (begin_idx >= params.b) return; - int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); + DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx]; + if (sched_meta.begin_req_idx >= params.b) return; // Copy the first Q - launch_q_copy(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q); + launch_q_copy(tma_params, sched_meta.begin_req_idx, m_block_idx, k_head_idx, sQ, barrier_Q); #pragma unroll 1 - for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { constexpr int kBlockN = T::PAGE_BLOCK_SIZE; - const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; + const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0; int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); - const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0; - int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(seqlen_k, kBlockN); - const bool is_no_split = __ldg(params.num_splits_ptr + batch_idx + 1) - __ldg(params.num_splits_ptr + batch_idx) == 1; + const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; + int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN); + const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? 
!sched_meta.is_last_req_splitted : true); int rRightBorderForQSeq[2]; if (params.is_causal) { @@ -1061,7 +1052,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr // NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1); int last_block_in_seq = cute::ceil_div(seqlen_k-common_mask_len, kBlockN); - end_block_idx = batch_idx == end_idx ? min(sched_end_block_idx, last_block_in_seq) : last_block_in_seq; + end_block_idx = batch_idx == sched_meta.end_req_idx ? min(sched_meta.end_block_idx, last_block_in_seq) : last_block_in_seq; CUTLASS_PRAGMA_UNROLL for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { @@ -1127,7 +1118,9 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr cur_phase_K0 ^= 1; // Issue P0 = Q @ K0^T, wait - warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); + if (start_block_idx-16777216 < end_block_idx) { // NOTE We use this `if` to prevent register spilling + warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); + } // We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0 NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized); cute::warpgroup_wait<0>(); @@ -1225,7 +1218,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 
1.0f : rL[i]; // Copy Q for the next batch - if (batch_idx+1 <= end_idx) { + if (batch_idx+1 <= sched_meta.end_req_idx) { launch_q_copy(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q); } else { // Allow the next kernel (the combine kernel) to launch @@ -1268,7 +1261,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr cute::tma_store_wait<0>(); } - if (batch_idx != end_idx) + if (batch_idx != sched_meta.end_req_idx) __syncthreads(); } #else @@ -1280,7 +1273,10 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr template -void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream) { +void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms) { + FLASH_ASSERT(params.d == Config::HEAD_DIM_K); + FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V); + using T = Traits; auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); auto tma_Q = cute::make_tma_copy( @@ -1348,7 +1344,7 @@ void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream) { dim3(num_m_block, params.h_k, params.num_sm_parts), dim3(T::NUM_THREADS, 1, 1), smem_size, - stream, + params.stream, mla_kernel_attributes, 1 }; @@ -1356,10 +1352,4 @@ void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); - -#ifndef FLASH_MLA_DISABLE_FP16 -template void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); -#endif - } diff --git a/csrc/sm90/decode/dense/splitkv_mla.h b/csrc/sm90/decode/dense/splitkv_mla.h index 6d45cfa..b2c50c8 100644 --- a/csrc/sm90/decode/dense/splitkv_mla.h +++ b/csrc/sm90/decode/dense/splitkv_mla.h @@ -5,6 +5,6 @@ namespace sm90 { template -void run_flash_splitkv_mla_kernel(DecodingParams ¶ms, cudaStream_t stream); +void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms); } diff --git 
a/csrc/sm90/decode/sparse_fp8/components/config.h b/csrc/sm90/decode/sparse_fp8/components/config.h index bdba0b8..f38915b 100644 --- a/csrc/sm90/decode/sparse_fp8/components/config.h +++ b/csrc/sm90/decode/sparse_fp8/components/config.h @@ -3,119 +3,19 @@ #include #include #include - -using bf16 = cutlass::bfloat16_t; -using fp8 = cutlass::float_e4m3_t; -using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; +#include "defines.h" using namespace cute; -static constexpr int NUM_THREADS = 128*3; -static constexpr int BLOCK_M = 64; -static constexpr int TOPK_BLOCK_SIZE = 64; -static constexpr int PAGE_BLOCK_SIZE = 64; -static constexpr int QUANT_TILE_SIZE = 128; +namespace sm90::decode::sparse_fp8 { static constexpr int HEAD_DIM_K = 576; static constexpr int HEAD_DIM_V = 512; static constexpr int HEAD_DIM_NOPE = HEAD_DIM_V; static constexpr int HEAD_DIM_ROPE = HEAD_DIM_K - HEAD_DIM_V; +static constexpr int QUANT_TILE_SIZE = 128; static constexpr int NUM_SCALES = HEAD_DIM_NOPE / QUANT_TILE_SIZE; static constexpr int NUM_BYTES_PER_TOKEN = HEAD_DIM_NOPE + NUM_SCALES*sizeof(float) + HEAD_DIM_ROPE*sizeof(bf16); +static constexpr int PAGE_BLOCK_SIZE = 64; -static constexpr int NUM_K_BUFS = 2; - -using SmemLayoutQTile = decltype(tile_to_shape( - GMMA::Layout_SW128_Atom{}, - Shape, Int<64>>{} -)); - -template -using SmemLayoutQTiles = decltype(tile_to_shape( - SmemLayoutQTile{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -)); - -using SmemLayoutQ = SmemLayoutQTiles<9>; - -using SmemLayoutKTile = decltype(tile_to_shape( - GMMA::Layout_INTER_Atom{}, - Shape, _64>{}, - Step<_1, _2>{} -)); - -template -using SmemLayoutKTiles = decltype(tile_to_shape( - SmemLayoutKTile{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -)); - -template -using SmemLayoutKTilesTransposed = decltype(composition( - SmemLayoutKTiles{}, - Layout, Int>, Stride, _1>>{} -)); - -using SmemLayoutOBuf = decltype(tile_to_shape( - GMMA::Layout_K_SW128_Atom{}, - Shape, Int>{} -)); - -using 
SmemLayoutOAccumBuf = Layout< - Shape, Int>, - Stride, _1> // We use stride = 520 here to avoid bank conflict ->; - -using SmemLayoutK = SmemLayoutKTiles<9>; -using SmemLayoutV = SmemLayoutKTilesTransposed<8>; -using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>; - -using SmemLayoutS = decltype(tile_to_shape( - GMMA::Layout_K_SW128_Atom{}, - Shape, Int>{} -)); - -struct SharedMemoryPlan { - array_aligned> q; - union { - array_aligned> k[NUM_K_BUFS]; - array_aligned> oBuf; - array_aligned> oAccumBuf; - } u; - array_aligned> s; - bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE]; - - float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M]; - transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS]; -}; - -template< - typename Shape_Q, typename TMA_Q, - typename Shape_O, typename TMA_O -> -struct TmaParams { - Shape_Q shape_Q; TMA_Q tma_Q; - Shape_O shape_O; TMA_O tma_O; -}; - -using TiledMMA_QK = decltype(make_tiled_mma( - GMMA::MMA_64x64x16_F32BF16BF16_SS{}, - Layout>{} -)); - -using TiledMMA_QK_rQ = decltype(make_tiled_mma( - GMMA::MMA_64x64x16_F32BF16BF16_RS{}, - Layout>{} -)); - -using TiledMMA_PV_LocalP = decltype(make_tiled_mma( - GMMA::MMA_64x256x16_F32BF16BF16_RS{}, - Layout>{} -)); - -using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( - GMMA::MMA_64x256x16_F32BF16BF16_SS{}, - Layout>{} -)); +} \ No newline at end of file diff --git a/csrc/sm90/decode/sparse_fp8/components/dequant.h b/csrc/sm90/decode/sparse_fp8/components/dequant.h index c3efc05..0c4022d 100644 --- a/csrc/sm90/decode/sparse_fp8/components/dequant.h +++ b/csrc/sm90/decode/sparse_fp8/components/dequant.h @@ -3,6 +3,10 @@ #include #include +#include "defines.h" + +namespace sm90::decode::sparse_fp8 { + struct fp8x8 { __nv_fp8x4_e4m3 lo; __nv_fp8x4_e4m3 hi; @@ -13,14 +17,8 @@ struct fp8x16 { fp8x8 hi; }; -struct bf16x8 { - __nv_bfloat162 a, b, c, d; -}; - __device__ __forceinline__ -bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) { - 
__nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale); - +bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) { #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \ { \ float4 fp32x4 = (float4)(FP8x4); \ @@ -29,8 +27,8 @@ bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) { } bf16x8 result; - DEQUANT_FP8x4(result.a, result.b, inputs.lo); - DEQUANT_FP8x4(result.c, result.d, inputs.hi); + DEQUANT_FP8x4(result.a01, result.a23, inputs.lo); + DEQUANT_FP8x4(result.a45, result.a67, inputs.hi); return result; } @@ -86,3 +84,44 @@ T load_128b_from_gmem(const void* addr) { #undef DISPATCH_L2 return *reinterpret_cast(&ret); } + +template< + typename T, + L1CacheHint l1_cache_hint, + L2PrefetchHint l2_prefetch_hint +> +__device__ __forceinline__ +T load_64b_from_gmem(const void* addr) { + static_assert(sizeof(T) == 64/8); + int2 ret; + + #define EXEC(L1_HINT_STR, L2_HINT_STR) { \ + asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v2.s32 {%0, %1}, [%2];" \ + : "=r"(ret.x), "=r"(ret.y) \ + : "l"(addr)); \ + } + + #define DISPATCH_L2(L1_HINT_STR) { \ + if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \ + EXEC(L1_HINT_STR, "64B") \ + else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \ + EXEC(L1_HINT_STR, "128B") \ + else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \ + EXEC(L1_HINT_STR, "256B") \ + } + + if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE) + DISPATCH_L2("no_allocate") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST) + DISPATCH_L2("evict_first") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL) + DISPATCH_L2("evict_normal") + else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST) + DISPATCH_L2("evict_last") + + #undef EXEC + #undef DISPATCH_L2 + return *reinterpret_cast(&ret); +} + +} diff --git a/csrc/sm90/decode/sparse_fp8/components/epilogue.h b/csrc/sm90/decode/sparse_fp8/components/epilogue.h deleted file mode 100644 index 
038cbfd..0000000 --- a/csrc/sm90/decode/sparse_fp8/components/epilogue.h +++ /dev/null @@ -1,87 +0,0 @@ -#pragma once - -#include "named_barriers.h" - -// Store O / OAccum -template< - bool IS_NO_SPLIT, - typename TMAParams, - typename Tensor0, - typename Tensor1, - typename Tensor2, - typename Tensor3 -> -__forceinline__ __device__ void store_o( - Tensor0 &rO, // ((2, 2, 32), 1, 1) - Tensor1 &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V) - Tensor2 &sOutputBuf, - Tensor3 &sOutputAccumBuf, - float rL[2], - TMAParams &tma_params, - int batch_idx, - int s_q_idx, - int head_block_idx, - int num_valid_seq_q, - int warpgroup_idx, - int idx_in_warpgroup -) { - using cutlass::arch::NamedBarrier; - if constexpr (IS_NO_SPLIT) { - // Should convert the output to bfloat16 / float16, and save it to O - Tensor rOb = make_tensor_like(rO); - CUTLASS_PRAGMA_UNROLL - for (int idx = 0; idx < size(rO); ++idx) { - rOb(idx) = (bf16)(rO(idx) / rL[idx%4 >= 2]); - } - - Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); - TiledCopy r2s_tiled_copy = make_tiled_copy_C( - Copy_Atom{}, - TiledMMA_PV_LocalP{} - ); - ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup); - Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb); - Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf); - cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf); - cutlass::arch::fence_view_async_shared(); - - NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready); - - if (threadIdx.x == 0) { - Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx); - auto thr_tma = tma_params.tma_O.get_slice(_0{}); - Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, head_block_idx, _0{}); - cute::copy( - tma_params.tma_O, - thr_tma.partition_S(sOutputBuf), - thr_tma.partition_D(my_tma_gO) - ); - cute::tma_store_arrive(); - } - } else { - // Should save the result to OAccum - 
CUTLASS_PRAGMA_UNROLL - for (int idx = 0; idx < size(rO); idx += 2) { - int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0); - int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; - *(float2*)(&(sOutputAccumBuf(row, col))) = float2 { - rO(idx) / rL[idx%4 >= 2], - rO(idx+1) / rL[idx%4 >= 2], - }; - } - cutlass::arch::fence_view_async_shared(); - - NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready); - - if (elect_one_sync()) { - CUTLASS_PRAGMA_UNROLL - for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) { - int row = local_row * (256/32) + (threadIdx.x / 32); - if (row < num_valid_seq_q) { - SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float)); - } - } - cute::tma_store_arrive(); - } - } -} diff --git a/csrc/sm90/decode/sparse_fp8/components/helpers.h b/csrc/sm90/decode/sparse_fp8/components/helpers.h index 8a336ea..d47e492 100644 --- a/csrc/sm90/decode/sparse_fp8/components/helpers.h +++ b/csrc/sm90/decode/sparse_fp8/components/helpers.h @@ -1,5 +1,14 @@ #pragma once +#include +#include + +#include "config.h" + +using namespace cute; + +namespace sm90::decode::sparse_fp8 { + // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a __forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { @@ -78,9 +87,23 @@ static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mba ); } -static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 
不确定是不是在所有显卡上都是这个数字 +CUTE_DEVICE +static void cp_async_bulk_shared_cta_shared_cluster(void* dst_ptr, void* src_ptr, int size, transac_bar_t* mbar_ptr) { + uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); + uint32_t src_addr = cute::cast_smem_ptr_to_uint(src_ptr); + uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); + asm volatile ( + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n" + : + : "r"(dst_addr), "r"(src_addr), "r"(size), "r"(mbar_addr) + ); +} + +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. template CUTE_DEVICE -T* get_peer_addr(const T* p) { +T* get_peer_addr(T* p) { return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); } + +} diff --git a/csrc/sm90/decode/sparse_fp8/components/named_barriers.h b/csrc/sm90/decode/sparse_fp8/components/named_barriers.h deleted file mode 100644 index b91cb22..0000000 --- a/csrc/sm90/decode/sparse_fp8/components/named_barriers.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -enum NamedBarriers : uint32_t { - sScale_and_sS_ready = 0, - sScale_and_sS_free = 1, - oBuf_free_and_sL_ready = 2, - epilogue_r2s_ready = 3, - batch_loop_sync = 4, - warpgroup0_sync = 5 -}; diff --git a/csrc/sm90/decode/sparse_fp8/config.h b/csrc/sm90/decode/sparse_fp8/config.h new file mode 100644 index 0000000..e5631f3 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/config.h @@ -0,0 +1,279 @@ +#pragma once + +#include +#include +#include +#include + +#include "defines.h" +#include "params.h" + +using namespace cute; + +namespace sm90::decode::sparse_fp8 { + +template +class KernelTemplate { +public: + +static_assert(NUM_HEADS == 64 || NUM_HEADS == 128); +static constexpr int NUM_M_BLOCKS = NUM_HEADS / 64; +static constexpr int CLUSTER_SIZE = NUM_M_BLOCKS; + +static constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 
576 : 512; +static constexpr int HEAD_DIM_V = 512; +static constexpr int HEAD_DIM_ROPE = 64; +static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE; + +static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64; +static constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8; // For MODEL1: 7 fp8_e4m3 + 1 padding + +static constexpr int NUM_THREADS = 128*3; +static constexpr int BLOCK_M = 64; +static constexpr int TOPK_BLOCK_SIZE = 64; +static constexpr int NUM_K_BUFS = 2; + +using SmemLayoutQTile = decltype(tile_to_shape( + GMMA::Layout_SW128_Atom{}, + Shape, Int<64>>{} +)); + +template +using SmemLayoutQTiles = decltype(tile_to_shape( + SmemLayoutQTile{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +)); + +using SmemLayoutQ = SmemLayoutQTiles; + +using SmemLayoutKTile = decltype(tile_to_shape( + GMMA::Layout_INTER_Atom{}, + Shape, _64>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTiles = decltype(tile_to_shape( + SmemLayoutKTile{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +)); + +template +using SmemLayoutKTilesTransposed = decltype(composition( + SmemLayoutKTiles{}, + Layout, Int>, Stride, _1>>{} +)); + +static constexpr int OBUF_SW = 64; +using SmemLayoutOBufAtom = GMMA::Layout_K_SW128_Atom; +using SmemLayoutOBuf = decltype(tile_to_shape( + SmemLayoutOBufAtom{}, + Shape, Int>{}, + Step<_1, _2>{} +)); + +using SmemLayoutOAccumBuf = Layout< + Shape, Int>, + Stride, _1> // We use stride = 520 here to avoid bank conflict +>; + +using SmemLayoutK = SmemLayoutKTiles; +using SmemLayoutV = SmemLayoutKTilesTransposed; +using SmemLayoutHalfV = SmemLayoutKTilesTransposed; + +using SmemLayoutS = decltype(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +)); + +struct SharedMemoryPlan { + array_aligned> q; + union { + array_aligned> k[NUM_K_BUFS]; + array_aligned> oBuf; + array_aligned> oAccumBuf; + } u; + CUTE_ALIGNAS(1024) array_aligned> s; + bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE]; + + float 
sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M]; + transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS]; +}; + +template< + typename Shape_Q, typename TMA_Q +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + CUtensorMap tensor_map_o; +}; + +using TiledMMA_QK = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_SS{}, + Layout>{} +)); + +using TiledMMA_QK_rQ = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_LocalP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_SS{}, + Layout>{} +)); + + +enum NamedBarriers : uint32_t { + sScale_and_sS_ready = 0, + sScale_and_sS_free = 1, + oBuf_free_and_sL_ready = 2, + epilogue_r2s_ready = 3, + batch_loop_sync = 4, + warpgroup0_sync = 5 +}; + + +// Synchronize all threads within the cluster (which processes one q token) +static __forceinline__ __device__ void sync_all_threads_in_cluster() { + if constexpr (CLUSTER_SIZE == 1) { + __syncthreads(); + } else { + ku::barrier_cluster_arrive_relaxed(); + ku::barrier_cluster_wait_acquire(); + } +} + +// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction +template< + typename Tensor0, + typename Tensor1 +> +static __forceinline__ __device__ void save_rPb_to_sP( + Tensor0 const &rPb, + Tensor1 const &sP, + int idx_in_warpgroup +) { + auto r2s_copy = make_tiled_copy_C( + Copy_Atom{}, + TiledMMA_QK{} + ); + ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_rPb = thr_copy.retile_S(rPb); + Tensor thr_copy_sP = thr_copy.partition_D(sP); + cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); +} + + +template< + bool IS_NO_SPLIT, + typename TMAParams, + typename Tensor0, + typename Tensor1, + typename Tensor2, + typename Tensor3 +> +static __forceinline__ __device__ void store_o( + Tensor0 &rO, // ((2, 2, 32), 
1, 1) + Tensor1 &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V) + Tensor2 &sOutputBuf, + Tensor3 &sOutputAccumBuf, + SharedMemoryPlan &plan, + float o_scales[2], + TMAParams &tma_params, + int batch_idx, + int s_q_idx, + int head_block_idx, + int num_valid_seq_q, + int warpgroup_idx, + int idx_in_warpgroup +) { + using cutlass::arch::NamedBarrier; + if constexpr (IS_NO_SPLIT) { + // Should convert the output to bfloat16 / float16, and save it to O + // Here we don't pipeline STSM and tma store because it's slower + Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); + + // Calculate "base" ptrs in advance + // Each STSM fills a chunk of shape 16x16, while we are using SW-OBUF_SW, so we need OBUF_SW/16 base pointers + constexpr int NUM_CHUNKS_IN_SW_ATOM = OBUF_SW/16; + bf16* base_output_buf_ptrs[NUM_CHUNKS_IN_SW_ATOM]; + CUTE_UNROLL + for (int i = 0; i < NUM_CHUNKS_IN_SW_ATOM; ++i) { + base_output_buf_ptrs[i] = &sMyOutputBuf((idx_in_warpgroup/32)*16+idx_in_warpgroup%16, idx_in_warpgroup%32/16*8 + i*16); + } + + CUTE_UNROLL + for (int idx = 0; idx < (HEAD_DIM_V/2)/16; idx += 1) { + // In each iteration we deal with a chunk of shape 16x16 + using bf16x2 = __nv_bfloat162; + bf16x2 a01 = __float22bfloat162_rn(float2{rO(idx*8+0)*o_scales[0], rO(idx*8+1)*o_scales[0]}); + bf16x2 a23 = __float22bfloat162_rn(float2{rO(idx*8+2)*o_scales[1], rO(idx*8+3)*o_scales[1]}); + bf16x2 a45 = __float22bfloat162_rn(float2{rO(idx*8+4)*o_scales[0], rO(idx*8+5)*o_scales[0]}); + bf16x2 a67 = __float22bfloat162_rn(float2{rO(idx*8+6)*o_scales[1], rO(idx*8+7)*o_scales[1]}); + SM90_U32x4_STSM_N::copy( + *reinterpret_cast(&a01), + *reinterpret_cast(&a23), + *reinterpret_cast(&a45), + *reinterpret_cast(&a67), + *reinterpret_cast(base_output_buf_ptrs[idx%4] + (idx/4*4)*16*64) + ); + } + + cutlass::arch::fence_view_async_shared(); + NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready); + + if (threadIdx.x == 0) { + SM90_TMA_STORE_5D::copy( + 
&tma_params.tensor_map_o, + plan.u.oBuf.data(), + 0, head_block_idx*64, 0, + s_q_idx, batch_idx + ); + cute::tma_store_arrive(); + } + } else { + // Should save the result to OAccum + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); idx += 2) { + int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0); + int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; + *(float2*)(&(sOutputAccumBuf(row, col))) = float2 { + rO(idx) * o_scales[idx%4>=2], + rO(idx+1) * o_scales[idx%4>=2], + }; + } + cutlass::arch::fence_view_async_shared(); + + NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready); + + if (elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) { + int row = local_row * (256/32) + (threadIdx.x / 32); + if (row < num_valid_seq_q) { + SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float)); + } + } + cute::tma_store_arrive(); + } + } +} + + +template +static __device__ __forceinline__ void +devfunc(const SparseAttnDecodeParams ¶ms, const TMAParams &tma_params); + +static void run(const SparseAttnDecodeParams ¶ms); + +}; + +} diff --git a/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu b/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu new file mode 100644 index 0000000..af5058f --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu @@ -0,0 +1,7 @@ +#include "../splitkv_mla.cuh" + +namespace sm90::decode::sparse_fp8 { + +template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); + +} diff --git a/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu b/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu new file mode 100644 index 0000000..902a591 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu @@ -0,0 +1,8 @@ +#include 
"../splitkv_mla.cuh" + +namespace sm90::decode::sparse_fp8 { + +template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); + +} + diff --git a/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu b/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu new file mode 100644 index 0000000..9727642 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu @@ -0,0 +1,7 @@ +#include "../splitkv_mla.cuh" + +namespace sm90::decode::sparse_fp8 { + +template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); + +} diff --git a/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu b/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu new file mode 100644 index 0000000..f7a3f19 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu @@ -0,0 +1,7 @@ +#include "../splitkv_mla.cuh" + +namespace sm90::decode::sparse_fp8 { + +template void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); + +} diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu deleted file mode 100644 index 3283413..0000000 --- a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu +++ /dev/null @@ -1,614 +0,0 @@ -#include "splitkv_mla.h" - -#include -#include -#include -#include - -#include "utils.h" -#include "components/config.h" -#include "components/epilogue.h" -#include "components/helpers.h" -#include "components/named_barriers.h" -#include "components/dequant.h" -using namespace cute; - -namespace sm90 { - -static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan -using cutlass::arch::fence_view_async_shared; -using cutlass::arch::NamedBarrier; - -// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction -template< - typename Tensor0, - typename Tensor1 -> -__forceinline__ __device__ void save_rPb_to_sP( - Tensor0 const &rPb, - Tensor1 const &sP, - int 
idx_in_warpgroup -) { - auto r2s_copy = make_tiled_copy_C( - Copy_Atom{}, - TiledMMA_QK{} - ); - ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); - Tensor thr_copy_rPb = thr_copy.retile_S(rPb); - Tensor thr_copy_sP = thr_copy.partition_D(sP); - cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); -} - - -// Retrieve rPb (64x64, bfloat16) from sP using the ldmatrix instruction -template< - typename Tensor0, - typename Tensor1 -> -__forceinline__ __device__ void retrieve_rP_from_sP( - Tensor0 &rPb, - Tensor1 const &sP, - int idx_in_warpgroup -) { - TiledCopy s2r_copy = make_tiled_copy_A( - Copy_Atom{}, - TiledMMA_PV_LocalP{} - ); - ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup); - Tensor thr_copy_sP = thr_copy.partition_S(sP); - Tensor thr_copy_rPb = thr_copy.retile_D(rPb); - cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb); -} - - -template< - typename Tensor0, - typename Tensor1, - typename Tensor2 -> -__forceinline__ __device__ void scale_softmax( - Tensor0 &rP, - Tensor1 &rS, - Tensor2 &rO, - float scale_softmax_log2, - float sScale[], - float rM[2], - float rL[2], - bool is_kv_valid[], - int block_idx, - int idx_in_warpgroup -) { - float scale_for_olds[2]; - CUTE_UNROLL - for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { - Tensor cur_rP = flatten(rP(make_coord(_, local_row_idx, _), _, _)); - Tensor cur_rS = flatten(rS(make_coord(_, local_row_idx, _), _, _)); - Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); - - float cur_max = -INFINITY; - CUTE_UNROLL - for (int i = 0; i < size(cur_rP); ++i) { - if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2]) - cur_rP(i) = -INFINITY; - cur_max = max(cur_max, cur_rP(i)); - } - cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); - cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); - - cur_max *= scale_softmax_log2; - float old_max = rM[local_row_idx]; - rM[local_row_idx] = max(cur_max, old_max); - float scale_for_old = exp2f(old_max - 
rM[local_row_idx]); - scale_for_olds[local_row_idx] = scale_for_old; - - CUTE_UNROLL - for (int i = 0; i < size(cur_rO); ++i) { - cur_rO(i) *= scale_for_old; - } - - float cur_sum = 0; - CUTE_UNROLL - for (int i = 0; i < size(cur_rP); ++i) { - cur_rP(i) = exp2f(cur_rP(i)*scale_softmax_log2 - rM[local_row_idx]); - cur_rS(i) = (bf16)cur_rP(i); - cur_sum += cur_rP(i); - } - rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; - } - if (idx_in_warpgroup%4 == 0) - *(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds); -} - -template -__global__ void __launch_bounds__(NUM_THREADS, 1, 2) -flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) { -#if IS_SM90 - const int head_block_idx = blockIdx.x; - const int s_q_idx = blockIdx.y; - const int partition_idx = blockIdx.z; - const int idx_in_cluster = head_block_idx % 2; - const int warpgroup_idx = cutlass::canonical_warp_group_idx(); - const int idx_in_warpgroup = threadIdx.x % 128; - const int warp_idx = cutlass::canonical_warp_idx_sync(); - - // Define shared tensors - extern __shared__ char wksp_buf[]; - SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); - Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{}); - Tensor sOBuf = make_tensor(make_smem_ptr(plan.u.oBuf.data()), SmemLayoutOBuf{}); - Tensor sOAccumBuf = make_tensor(make_smem_ptr(plan.u.oAccumBuf.data()), SmemLayoutOAccumBuf{}); - Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); - float* sM = plan.sM; - float* sL = plan.sL; - float* sScale = plan.sScale; - - // Prefetch TMA descriptors - if (warp_idx == 0 && elect_one_sync()) { - cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); - cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); - } - - // Initialize TMA barriers - if (warp_idx == 0 && elect_one_sync()) { - plan.bar_q.init(1); - CUTE_UNROLL - for (int i = 0; i < NUM_K_BUFS; ++i) 
{ - plan.bar_k_local_ready[i].init(128); - plan.bar_k_remote_ready[i].init(1); - plan.bar_k_avail[i].init(4); - } - fence_view_async_shared(); - } - cute::cluster_arrive(); - - bool bar_phase_q = 0; - int bar_phase_k = 0; // Don't use array here to prevent using local memory - - // Programmatic Dependent Launch: Wait for the previous kernel to finish - // Don't use PDL because of compiler bugs! - // cudaGridDependencySynchronize(); - - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; - int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); - int begin_idx = tile_scheduler_metadata.x; - int sched_begin_block_idx = tile_scheduler_metadata.y; - int end_idx = tile_scheduler_metadata.z; - int sched_end_block_idx = tile_scheduler_metadata.w; - if (begin_idx >= params.b) return; - int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); - - if (warp_idx == 0 && elect_one_sync()) { - Tensor gQ = flat_divide( - tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, begin_idx), - Tile, Int>{} - )(_, _, head_block_idx, _0{}); - launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); - plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); - } - - cute::cluster_wait(); // Wait for barriers from the other CTA to be ready - - auto get_cur_req_info = [&](int batch_idx) -> std::tuple { - constexpr int kBlockN = TOPK_BLOCK_SIZE; - const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0; - // NOTE TopK attention has nothing to do with causal mask and sliding window - int end_block_idx = batch_idx == end_idx ? 
sched_end_block_idx : cute::ceil_div(params.topk, kBlockN); - const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(params.topk, kBlockN); - return {start_block_idx, end_block_idx, is_no_split}; - }; - - if (warpgroup_idx == 0) { - cutlass::arch::warpgroup_reg_alloc<192>(); - - TiledMMA tiled_mma_QK = TiledMMA_QK{}; - ThrMMA thr_mma_QK = tiled_mma_QK.get_slice(idx_in_warpgroup); - TiledMMA tiled_mma_PV = TiledMMA_PV_LocalP{}; - ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); - - float rL[2], rM[2]; - Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); - Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); - Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); - - #pragma unroll 1 - for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { - auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); - - rL[0] = rL[1] = 0.0f; - rM[0] = rM[1] = MAX_INIT_VAL; - cute::fill(rO, 0.); - - // Wait for Q - plan.bar_q.wait(bar_phase_q); - bar_phase_q ^= 1; - - CUTE_NO_UNROLL - for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { - int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; - Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{}); - - // Wait, issue WGMMA - plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); - plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); - - gemm( - tiled_mma_QK, - thr_mma_QK.partition_fragment_A(sQ), - thr_mma_QK.partition_fragment_B(sK), - rP - ); - - bar_phase_k ^= 1<(); - - // Calculate S = softmax(mask(scale(P))) - if (block_idx != start_block_idx) - NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_free); // Make sure that sScale and sS is free - - // Since in our case TOPK_BLOCK_SIZE == BLOCK_M, so we only need to do OOB checking 
for the last 2 blocks - scale_softmax(rP, rS, rO, params.scale_softmax_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup); - - // Store S into shared, inform warpgroup 1 - save_rPb_to_sP(rS, sS, idx_in_warpgroup); - fence_view_async_shared(); - - // Issue O += S @ V - gemm( - tiled_mma_PV, - rS, - thr_mma_PV.partition_fragment_B(sV), - rO - ); - - NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_ready); - - cute::warpgroup_wait<0>(); - - plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); - plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); - } - - // Copy the next q - if (warp_idx == 0 && elect_one_sync()) { - if (batch_idx != end_idx) { - Tensor gQ = flat_divide( - tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1), - Tile, Int>{} - )(_, _, head_block_idx, _0{}); - launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); - plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); - } else { - cudaTriggerProgrammaticLaunchCompletion(); - } - } - - // Synchronize L and M across warpgroups - rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); - rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); - rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); - rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); - if (idx_in_warpgroup%4 == 0) { - CUTE_UNROLL - for (int i = 0; i < 2; ++i) { - int row = get_AorC_row_idx(i, idx_in_warpgroup); - sL[row] = rL[i]; - sM[row] = rM[i]; - } - } - - // This is a synchronization point for warpgroup 0/1. - // Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free - // Warpgroup 1 should wait wg 0 for sL to be ready - NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); - - CUTE_UNROLL - for (int i = 0; i < 2; ++i) - rL[i] = rL[i] == 0.0f ? 
1.0f : rL[i]; - - int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M); - int start_seq_idx = s_q_idx*params.q_head_per_hk + head_block_idx*BLOCK_M; - if (is_no_split) { - bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1) - Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( - Shape, Int>{}, - make_stride(params.o_row_stride, _1{}) - )); - float* gSoftmaxLse = (float*)params.softmax_lse_ptr + batch_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1) - - store_o(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); - - int i = threadIdx.x; - if (i < num_valid_seq_q) { - float cur_L = sL[i]; - gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / (float)M_LOG2E; - } - - cute::tma_store_wait<0>(); - } else { - int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; - int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; - float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) - float* gSoftmaxLseAccum = (float*)params.softmax_lseaccum_ptr + split_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1) - Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< - Shape, Int>, - Stride, _1> - >{}); - store_o(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); - - int i = threadIdx.x; - if (i < num_valid_seq_q) { - float cur_L = sL[i]; - gSoftmaxLseAccum[i] = cur_L == 0.0f ? 
-INFINITY : log2f(cur_L) + sM[i]; - } - - cute::tma_store_wait<0>(); - } - - cute::cluster_sync(); // Must use arrive_and_wait here to prevent overwritting sL while WG1 is writing back its result - } - } else if (warpgroup_idx == 1) { - cutlass::arch::warpgroup_reg_dealloc<160>(); - - TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{}; - ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); - Tensor rO = partition_fragment_C(tiled_mma_PV, Shape, Int>{}); - float rL[2]; - - #pragma unroll 1 - for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { - auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx); - cute::fill(rO, 0.); - - CUTE_NO_UNROLL - for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { - int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; - Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{}); - - // Wait for S and sScale - NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready); - - // Scale O - float cur_scales[2]; - *(float2*)cur_scales = *(float2*)(sScale + (idx_in_warpgroup/4)*2); - CUTE_UNROLL - for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { - Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); - CUTE_UNROLL - for (int i = 0; i < size(cur_rO); ++i) { - cur_rO(i) *= cur_scales[local_row_idx]; - } - } - - // Issue O += S @ V, and wait - gemm( - tiled_mma_PV, - thr_mma_PV.partition_fragment_A(sS), - thr_mma_PV.partition_fragment_B(sV), - rO - ); - cute::warpgroup_wait<0>(); - - plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); - plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); - - if (block_idx != end_block_idx-1) - NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free); // Tell WG0 that sScale and sS are available - } - - NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); - CUTE_UNROLL - for (int i = 0; i < 2; ++i) 
{ - int row = get_AorC_row_idx(i, idx_in_warpgroup); - rL[i] = sL[row]; - } - - CUTE_UNROLL - for (int i = 0; i < 2; ++i) - rL[i] = rL[i] == 0.0f ? 1.0f : rL[i]; - - int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M); - int start_seq_idx = s_q_idx*params.q_head_per_hk+head_block_idx*BLOCK_M; - if (is_no_split) { - bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1) - Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( - Shape, Int>{}, - make_stride(params.o_row_stride, _1{}) - )); - - store_o(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); - - cute::tma_store_wait<0>(); - } else { - int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; - int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; - float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) - Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< - Shape, Int>, - Stride, _1> - >{}); - store_o(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); - - cute::tma_store_wait<0>(); - } - - cute::cluster_sync(); // We must use arrive_and_wait instead of arrive here to create an order between "forall warp in WG1, warp has done written back O" and "warp 2 signals `bar_k_avail`" - } - } else { - // Producer warpgroup - cutlass::arch::warpgroup_reg_dealloc<152>(); - - int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0); // NOTE TPBNO - int lane_idx = idx_in_warpgroup % 32; - int my_token_idx = warp_idx*8 + lane_idx%8; - - CUTE_NO_UNROLL - for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { - auto [start_block_idx, end_block_idx, is_no_split] = 
get_cur_req_info(batch_idx); - int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1) - - #define GET_TOKEN_INDEX(block_idx) __ldg(gIndices + (block_idx)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx) - int nxt_token_index = GET_TOKEN_INDEX(start_block_idx); - - CUTE_NO_UNROLL - for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) { - int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS; - - // Define shared and global tensors - bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE; - bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base); - - transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx])); - int token_index = nxt_token_index; - if (block_idx+1 != end_block_idx) - nxt_token_index = GET_TOKEN_INDEX(block_idx+1); - int block_index = token_index/PAGE_BLOCK_SIZE; - int rel_idx_in_block = (token_index+PAGE_BLOCK_SIZE) % PAGE_BLOCK_SIZE; // NOTE When token_index is -1, -1/PAGE_BLOCK_SIZE = 0 and (-1+PAGE_BLOCK_SIZE)%PAGE_BLOCK_SIZE = 63, so there will be no illegal-memory-access error - fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride; - float4 scales = load_128b_from_gmem((float*)(gK_base+HEAD_DIM_NOPE)); - - // Wait for the nope buffer to be available - plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1); - bar_phase_k ^= 1 << buf_idx; - - // Copy block #block_index - if (idx_in_warpgroup == 0) { - plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16)); - } - - // Collectively copy from global memory and dequant - // For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py - - fp8* gK_nope = gK_base + (lane_idx/8)*16; - if (token_index == -1) { - scales = {0.0f, 
0.0f, 0.0f, 0.0f}; - } - CUTE_UNROLL - for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) { - fp8x16 cur_fp8x16 = load_128b_from_gmem(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B - float scale = dim_idx < 4 ? (dim_idx < 2 ? scales.x : scales.y) : (dim_idx < 6 ? scales.z : scales.w); - auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) { - int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE; - bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, scale); - *(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; - st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); - }; - if (token_index == -1) - *(uint128_t*)(&cur_fp8x16) = uint128_t(); - dequant_and_save_bf16x8(cur_fp8x16.lo, 0); - dequant_and_save_bf16x8(cur_fp8x16.hi, 8); - } - - bf16* gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8; - bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE; - bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base); - - CUTE_UNROLL - for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) { - bf16x8 cur_bf16x8 = load_128b_from_gmem(gK_rope + dim_idx*32); - if (token_index == -1) - *(uint128_t*)(&cur_bf16x8) = uint128_t(); - int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE; - *(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; - st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); - } - - fence_view_async_shared(); - - if (idx_in_warpgroup < 32) { - // We put this after fence_view_async_shared() since this won't be read by async proxy - int2 indices = __ldg((int2*)(gIndices + block_idx*TOPK_BLOCK_SIZE + lane_idx*2)); - *(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = {indices.x != -1, indices.y != -1}; - } - - // Signal the barrier - plan.bar_k_local_ready[buf_idx].arrive(); - 
} - - cute::cluster_sync(); - } - } - - if (begin_idx > end_idx) { - cute::cluster_sync(); // Don't need a cluster_sync() when begin_idx <= end_idx, since the loop will execute at least once and the final statement is cluster_sync() - } -#else - if (cute::thread0()) { - CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); - } -#endif - -} - - -void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.h_k == 1); - FLASH_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0); - - auto shape_Q = make_shape(params.q_head_per_hk, params.d, params.s_q, params.b); - auto tma_Q = cute::make_tma_copy( - SM90_TMA_LOAD{}, - make_tensor( - make_gmem_ptr((bf16*)params.q_ptr), - make_layout( - shape_Q, - make_stride(params.q_row_stride, _1{}, params.q_head_per_hk*params.q_row_stride, params.q_batch_stride) - ) - ), - SmemLayoutQ{} - ); - - auto shape_O = make_shape(params.q_head_per_hk, params.d_v, params.s_q, params.b); - auto tma_O = cute::make_tma_copy( - SM90_TMA_STORE{}, - make_tensor( - make_gmem_ptr((bf16*)params.o_ptr), - make_layout( - shape_O, - make_stride(params.o_row_stride, _1{}, params.q_head_per_hk*params.o_row_stride, params.o_batch_stride) - ) - ), - SmemLayoutOBuf{} - ); - - TmaParams< - decltype(shape_Q), decltype(tma_Q), - decltype(shape_O), decltype(tma_O) - > tma_params = { - shape_Q, tma_Q, - shape_O, tma_O - }; - auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel; - - constexpr size_t smem_size = sizeof(SharedMemoryPlan); - CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - const int num_m_block = cute::ceil_div(params.q_head_per_hk, 2*BLOCK_M) * 2; - // NOTE Don't use PDL because of potential compiler bugs! 
- // cudaLaunchAttribute mla_kernel_attributes[1]; - // mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - // mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; - // cudaLaunchConfig_t mla_kernel_config = { - // dim3(num_m_block, params.h_k, params.num_sm_parts), - // dim3(NUM_THREADS, 1, 1), - // smem_size, - // stream, - // mla_kernel_attributes, - // 1 - // }; - // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); - cutlass::ClusterLaunchParams launch_params = { - dim3(num_m_block, params.s_q, params.num_sm_parts), - dim3(NUM_THREADS, 1, 1), - dim3(2, 1, 1), - smem_size, - stream - }; - cutlass::launch_kernel_on_cluster( - launch_params, (void*)mla_kernel, params, tma_params - ); - CHECK_CUDA_KERNEL_LAUNCH(); -} - -} diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh new file mode 100644 index 0000000..9994568 --- /dev/null +++ b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh @@ -0,0 +1,787 @@ +#pragma once + +#include "splitkv_mla.h" + +#include +#include +#include +#include +#include +#include + +#include + +#include "utils.h" +#include "components/dequant.h" +#include "components/helpers.h" +#include "config.h" +using namespace cute; + +namespace sm90::decode::sparse_fp8 { + +static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan +using cutlass::arch::fence_view_async_shared; +using cutlass::arch::NamedBarrier; +using fp8_e8m0 = __nv_fp8_e8m0; + +template< + typename Tensor0, + typename Tensor1, + typename Tensor2 +> +__forceinline__ __device__ void scale_softmax( + Tensor0 &rP, + Tensor1 &rS, + Tensor2 &rO, + float scale_softmax_log2, + float sScale[], + float rM[2], + float rL[2], + bool is_kv_valid[], + int block_idx, + int idx_in_warpgroup +) { + float scale_for_olds[2]; + CUTE_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + Tensor cur_rP = flatten(rP(make_coord(_, local_row_idx, 
_), _, _)); + Tensor cur_rS = flatten(rS(make_coord(_, local_row_idx, _), _, _)); + Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); + + float cur_max = -INFINITY; + CUTE_UNROLL + for (int i = 0; i < size(cur_rP); ++i) { + if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2]) + cur_rP(i) = -INFINITY; + cur_max = max(cur_max, cur_rP(i)); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + + cur_max *= scale_softmax_log2; + float old_max = rM[local_row_idx]; + rM[local_row_idx] = max(cur_max, old_max); + float scale_for_old = exp2f(old_max - rM[local_row_idx]); + scale_for_olds[local_row_idx] = scale_for_old; + + CUTE_UNROLL + for (int i = 0; i < size(cur_rO); ++i) { + cur_rO(i) *= scale_for_old; + } + + float cur_sum = 0; + CUTE_UNROLL + for (int i = 0; i < size(cur_rP); ++i) { + cur_rP(i) = exp2f(cur_rP(i)*scale_softmax_log2 - rM[local_row_idx]); + cur_rS(i) = (bf16)cur_rP(i); + cur_sum += cur_rP(i); + } + + rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; + } + if (idx_in_warpgroup%4 == 0) + *(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds); +} + +template +template +__device__ void KernelTemplate::devfunc(const SparseAttnDecodeParams ¶ms, const TMAParams &tma_params) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) + const int head_block_idx = NUM_M_BLOCKS == 1 ? 0 : blockIdx.x; + const int s_q_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int idx_in_cluster = CLUSTER_SIZE == 1 ? 
0 : head_block_idx % 2; + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int idx_in_warpgroup = threadIdx.x % 128; + const int warp_idx = cutlass::canonical_warp_idx_sync(); + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{}); + Tensor sOBuf = make_tensor(make_smem_ptr(plan.u.oBuf.data()), SmemLayoutOBuf{}); + Tensor sOAccumBuf = make_tensor(make_smem_ptr(plan.u.oAccumBuf.data()), SmemLayoutOAccumBuf{}); + Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); + float* sM = plan.sM; + float* sL = plan.sL; + float* sScale = plan.sScale; + + // Prefetch TMA descriptors + if (warp_idx == 0 && elect_one_sync()) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_o); + } + + // Initialize TMA barriers + if (warp_idx == 0 && elect_one_sync()) { + plan.bar_q.init(1); + if constexpr (CLUSTER_SIZE == 2) { + CUTE_UNROLL + for (int i = 0; i < NUM_K_BUFS; ++i) { + plan.bar_k_local_ready[i].init(128); + plan.bar_k_remote_ready[i].init(1); + plan.bar_k_avail[i].init(4); + } + } else { + CUTE_UNROLL + for (int i = 0; i < NUM_K_BUFS; ++i) { + plan.bar_k_local_ready[i].init(128); + plan.bar_k_avail[i].init(256); + } + } + cutlass::arch::fence_barrier_init(); + } + ku::barrier_cluster_arrive_relaxed(); + + int bar_phase_k = 0; // Don't use array here to prevent using local memory + + // Programmatic Dependent Launch: Wait for the previous kernel to finish + // Don't use PDL because of compiler bugs! 
+ // cudaGridDependencySynchronize(); + + DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx]; + + if (sched_meta.begin_req_idx >= params.b) return; + + if (warp_idx == 0 && elect_one_sync()) { + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, sched_meta.begin_req_idx), + Tile, Int>{} + )(_, _, head_block_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); + } + + ku::barrier_cluster_wait_acquire(); + + struct MainloopArgs { + int start_block_idx, end_block_idx; + bool is_no_split; + + // The following fields are only valid for MODEL1 + int topk_length, extra_topk_length, num_orig_kv_blocks; + }; + auto get_cur_req_info = [&](int batch_idx) -> MainloopArgs { + MainloopArgs args; + int total_topk_padded; + if constexpr (MODEL_TYPE == ModelType::V32) { + total_topk_padded = params.topk; + } else { + int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk; + int orig_topk_padded = max(ku::ceil(topk_length, (int)TOPK_BLOCK_SIZE), (int)TOPK_BLOCK_SIZE); + int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk; + total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)TOPK_BLOCK_SIZE); + args.topk_length = topk_length; + args.extra_topk_length = extra_topk_length; + args.num_orig_kv_blocks = orig_topk_padded / TOPK_BLOCK_SIZE; + } + + args.start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0; + args.end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / TOPK_BLOCK_SIZE; + args.is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? 
!sched_meta.is_last_req_splitted : true); + + return args; + }; + + if (warpgroup_idx == 0) { + cutlass::arch::warpgroup_reg_alloc<192>(); + + TiledMMA tiled_mma_QK = TiledMMA_QK{}; + ThrMMA thr_mma_QK = tiled_mma_QK.get_slice(idx_in_warpgroup); + TiledMMA tiled_mma_PV = TiledMMA_PV_LocalP{}; + ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); + + float rL[2], rM[2]; + Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); + Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); + Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); + + float rAttn_sink[2] = {-CUDART_INF_F, -CUDART_INF_F}; + if (params.attn_sink != nullptr) { + for (int i = 0; i < 2; ++i) { + int head_idx = head_block_idx*BLOCK_M + get_AorC_row_idx(i, idx_in_warpgroup); + rAttn_sink[i] = __ldg((float*)params.attn_sink + head_idx) * CUDART_L2E_F; + } + } + + #pragma unroll 1 + for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { + MainloopArgs args = get_cur_req_info(batch_idx); + + rL[0] = rL[1] = 0.0f; + rM[0] = rM[1] = MAX_INIT_VAL; + cute::fill(rO, 0.); + + // Wait for Q + plan.bar_q.wait((sched_meta.begin_req_idx-batch_idx)&1); + + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) { + int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS; + Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{}); + + // Wait, issue WGMMA + plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + if constexpr (CLUSTER_SIZE == 2) { + plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1); + } + + gemm( + tiled_mma_QK, + thr_mma_QK.partition_fragment_A(sQ), + thr_mma_QK.partition_fragment_B(sK), + rP + ); + + bar_phase_k ^= 1<(); + + // Calculate S = softmax(mask(scale(P))) + if (block_idx != args.start_block_idx) + 
NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_free); // Make sure that sScale and sS is free + + // Since in our case TOPK_BLOCK_SIZE == BLOCK_M, so we only need to do OOB checking for the last 2 blocks + scale_softmax(rP, rS, rO, params.sm_scale_div_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup); + + // Store S into shared, inform warpgroup 1 + save_rPb_to_sP(rS, sS, idx_in_warpgroup); + fence_view_async_shared(); + + // Issue O += S @ V + gemm( + tiled_mma_PV, + rS, + thr_mma_PV.partition_fragment_B(sV), + rO + ); + + NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_ready); + + cute::warpgroup_wait<0>(); + + if constexpr (CLUSTER_SIZE == 2) { + plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); + plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); + } else { + plan.bar_k_avail[buf_idx].arrive(); + } + } + + // Copy the next q + if (threadIdx.x/32 == 0 && elect_one_sync()) { + if (batch_idx != sched_meta.end_req_idx) { + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1), + Tile, Int>{} + )(_, _, head_block_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16)); + } else { + // This kernel is followed by the combine kernel, so we signal PDL here + cudaTriggerProgrammaticLaunchCompletion(); + } + } + + // Synchronize L and M across warpgroups + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + + if (idx_in_warpgroup%4 == 0) { + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + int row = get_AorC_row_idx(i, idx_in_warpgroup); + sL[row] = rL[i]; + sM[row] = rM[i]; + } + } + + float o_scales[2]; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + if (args.is_no_split) { + o_scales[i] = 
rL[i] == 0.0f ? 0.0f : __fdividef(1.0f, rL[i] + exp2f(rAttn_sink[i] - rM[i])); + } else { + o_scales[i] = rL[i] == 0.0f ? 0.0f : __fdividef(1.0f, rL[i]); + } + if (idx_in_warpgroup%4 == 0) { + int row = get_AorC_row_idx(i, idx_in_warpgroup); + plan.sOScale[row] = o_scales[i]; + } + } + + // This is a synchronization point for warpgroup 0/1. + // Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free + // Warpgroup 1 should wait wg 0 for sL to be ready + NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); + + CUTE_UNROLL + for (int i = 0; i < 2; ++i) + rL[i] = rL[i] == 0.0f ? 1.0f : rL[i]; + + int start_head_idx = head_block_idx*BLOCK_M; + int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M); + if (args.is_no_split) { + bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1) + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.stride_o_h_q, _1{}) + )); + float* gSoftmaxLse = (float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + start_head_idx; // (BLOCK_M) : (1) + + store_o(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL[i]; + gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / (float)M_LOG2E; + } + + cute::tma_store_wait<0>(); + } else { + int n_split_idx = batch_idx == sched_meta.begin_req_idx ? 
sched_meta.begin_split_idx : 0; + int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1) + float* gSoftmaxLseAccum = (float*)params.lse_accum + split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + start_head_idx; // (BLOCK_M) : (1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout( + Shape, Int>{}, + make_stride(params.stride_o_accum_h_q, _1{}) + )); + store_o(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL[i]; + gSoftmaxLseAccum[i] = cur_L == 0.0f ? -INFINITY : log2f(cur_L) + sM[i]; + } + + cute::tma_store_wait<0>(); + } + + sync_all_threads_in_cluster(); + } + } else if (warpgroup_idx == 1) { + cutlass::arch::warpgroup_reg_dealloc<160>(); + + TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{}; + ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup); + Tensor rO = partition_fragment_C(tiled_mma_PV, Shape, Int>{}); + + #pragma unroll 1 + for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { + MainloopArgs args = get_cur_req_info(batch_idx); + cute::fill(rO, 0.); + + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) { + int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS; + Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{}); + + // Wait for S and sScale + NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready); + + // Scale O + float cur_scales[2]; + *(float2*)cur_scales = *(float2*)(sScale + (idx_in_warpgroup/4)*2); + CUTE_UNROLL 
+ for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _)); + CUTE_UNROLL + for (int i = 0; i < size(cur_rO); ++i) { + cur_rO(i) *= cur_scales[local_row_idx]; + } + } + + // Issue O += S @ V, and wait + gemm( + tiled_mma_PV, + thr_mma_PV.partition_fragment_A(sS), + thr_mma_PV.partition_fragment_B(sV), + rO + ); + cute::warpgroup_wait<0>(); + + if constexpr (CLUSTER_SIZE == 2) { + plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32); + plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64); + } else { + plan.bar_k_avail[buf_idx].arrive(); + } + + if (block_idx != args.end_block_idx-1) + NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free); // Tell WG0 that sScale and sS are available + } + + NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready); + + float o_scales[2]; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + int row = get_AorC_row_idx(i, idx_in_warpgroup); + o_scales[i] = plan.sOScale[row]; + } + + int start_head_idx = head_block_idx*BLOCK_M; + int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M); + if (args.is_no_split) { + bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1) + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.stride_o_h_q, _1{}) + )); + + store_o(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + cute::tma_store_wait<0>(); + } else { + int n_split_idx = batch_idx == sched_meta.begin_req_idx ? 
sched_meta.begin_split_idx : 0; + int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx; + float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout( + Shape, Int>{}, + make_stride(params.stride_o_accum_h_q, _1{}) + )); + store_o(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + cute::tma_store_wait<0>(); + } + + sync_all_threads_in_cluster(); + } + } else { + // Producer warpgroup + cutlass::arch::warpgroup_reg_dealloc<152>(); + + static_assert(CLUSTER_SIZE == 1 || CLUSTER_SIZE == 2); + static constexpr int NUM_TOKENS_PER_THREAD = CLUSTER_SIZE == 1 ? 2 : 1; + static constexpr int NUM_TOKENS_PER_ROUND = 32; // If head is 128, each CTA is responsible for dequantizing 32 tokens (1 rounds); if head is 64, each CTA is responsible for dequantizing 64 tokens (2 rounds) + int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0); + int lane_idx = idx_in_warpgroup % 32; + int my_token_idx_base = warp_idx*8 + lane_idx%8; + + CUTE_NO_UNROLL + for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) { + MainloopArgs args = get_cur_req_info(batch_idx); + int* gIndices = params.indices + batch_idx*params.stride_indices_b + s_q_idx*params.stride_indices_s_q; // (topk) : (1) + int* gExtraIndices = params.extra_indices + batch_idx*params.stride_extra_indices_b + s_q_idx*params.stride_extra_indices_s_q; // (extra_topk) : (1) + + int nxt_token_indexs[NUM_TOKENS_PER_THREAD]; + CUTE_UNROLL + for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) { + if (MODEL_TYPE == ModelType::V32 || args.start_block_idx < args.num_orig_kv_blocks) + nxt_token_indexs[round] = __ldg(gIndices + 
args.start_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + round*NUM_TOKENS_PER_ROUND + my_token_idx_base); + } + + struct IsOrigBlock {}; + struct IsExtraBlock {}; + + struct IsFirstExtraBlock {}; + struct IsNotFirstExtraBlock {}; + auto process_one_block = [&](int block_idx, auto is_extra_block_t, auto is_first_extra_block_t) { + static constexpr bool IS_EXTRA_BLOCK = std::is_same_v; + static constexpr bool IS_FIRST_EXTRA_BLOCK = std::is_same_v; + int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS; + + int* indices_base; + int page_block_size; + int64_t k_block_stride, k_row_stride; + fp8* k_ptr; + if constexpr (!IS_EXTRA_BLOCK) { + indices_base = gIndices + (block_idx)*TOPK_BLOCK_SIZE; + page_block_size = params.page_block_size; + k_block_stride = params.stride_kv_block; + k_row_stride = params.stride_kv_row; + k_ptr = (fp8*)params.kv; + } else { + indices_base = gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE; + page_block_size = params.extra_page_block_size; + k_block_stride = params.stride_extra_kv_block; + k_row_stride = params.stride_extra_kv_row; + k_ptr = (fp8*)params.extra_kv; + } + [[maybe_unused]] int topk_length = IS_EXTRA_BLOCK ? args.extra_topk_length : args.topk_length; + [[maybe_unused]] int rel_block_idx = IS_EXTRA_BLOCK ? (block_idx - args.num_orig_kv_blocks) : block_idx; + transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx])); + + CUTE_UNROLL + for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) { + int my_token_idx = my_token_idx_base + round*NUM_TOKENS_PER_ROUND; + bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE; + bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base); + + // Get prefetched token index + int token_index; + if constexpr (!IS_EXTRA_BLOCK) { + token_index = nxt_token_indexs[round]; + if (block_idx+1 != (MODEL_TYPE == ModelType::V32 ? 
args.end_block_idx : args.num_orig_kv_blocks)) + nxt_token_indexs[round] = __ldg(gIndices + (block_idx+1)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx); + } else { + if constexpr (IS_FIRST_EXTRA_BLOCK) { + token_index = __ldg(gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx); + } else { + token_index = nxt_token_indexs[round]; + } + if (block_idx+1 != args.end_block_idx) + nxt_token_indexs[round] = __ldg(gExtraIndices + (block_idx+1-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx); + } + + if constexpr (MODEL_TYPE == ModelType::MODEL1) { + // For MODEL1, we need to check whether the token_index is within topk_length + if (rel_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx >= topk_length) { + token_index = -1; // To prevent IMA when we have invalid (e.g. INT_MAX) topk indexes outside topk_length + } + } + + int block_index = token_index == -1 ? 
0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance + int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error + + fp8* gK_base; + bf16 scales[NUM_SCALES]; + if constexpr (MODEL_TYPE == ModelType::V32) { + static_assert(NUM_SCALES == 4); + gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*k_row_stride; + float scales_float[NUM_SCALES]; + *(float4*)(scales_float) = load_128b_from_gmem((float*)(gK_base+HEAD_DIM_NOPE)); + CUTE_UNROLL + for (int i = 0; i < NUM_SCALES; ++i) { + scales[i] = (bf16)scales_float[i]; + } + } else { + static_assert(NUM_SCALES == 8); + gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*(HEAD_DIM_NOPE + HEAD_DIM_ROPE*sizeof(bf16)); + fp8_e8m0* gK_scales_base = (fp8_e8m0*)(k_ptr + block_index*k_block_stride + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*sizeof(bf16)) + rel_idx_in_block*NUM_SCALES*sizeof(fp8_e8m0)); + fp8_e8m0 scales_e8m0[NUM_SCALES]; + *(int64_t*)scales_e8m0 = __ldg((int64_t*)gK_scales_base); + CUTE_UNROLL + for (int i = 0; i < NUM_SCALES; i += 2) { + *(__nv_bfloat162_raw*)(scales+i) = __nv_cvt_e8m0x2_to_bf162raw(*(__nv_fp8x2_storage_t*)(scales_e8m0+i)); + } + } + + // Wait for the nope buffer to be available + if (round == 0) { + plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1); + } + + if (CLUSTER_SIZE == 2 && round == 0 && idx_in_warpgroup == 0) { + plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16)); + } + + // Collectively copy from global memory and dequant + // For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py + + fp8* gK_nope = gK_base + (lane_idx/8)*16; + if (token_index == -1) { + CUTE_UNROLL + for (int i = 0; i < NUM_SCALES; ++i) + scales[i] = (bf16)0.0f; + } + 
CUTE_UNROLL + for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) { + fp8x16 cur_fp8x16 = load_128b_from_gmem(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B (for V3.2) and the performance is the best among all cache hints (for MODEL1) + bf16 scale = scales[MODEL_TYPE == ModelType::V32 ? dim_idx/2 : dim_idx]; + auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) { + int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE; + bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, __bfloat162bfloat162(*(__nv_bfloat16*)(&scale))); + *(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; + if constexpr (CLUSTER_SIZE == 2) { + st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready); + } + }; + if (token_index == -1) + *(uint128_t*)(&cur_fp8x16) = uint128_t(); + dequant_and_save_bf16x8(cur_fp8x16.lo, 0); + dequant_and_save_bf16x8(cur_fp8x16.hi, 8); + } + + bf16* gK_rope; + if constexpr (MODEL_TYPE == ModelType::V32) { + gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8; + } else { + gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE) + (lane_idx/8)*8; + } + bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE; + bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base); + + CUTE_UNROLL + for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) { + bf16x8 cur_bf16x8 = load_128b_from_gmem(gK_rope + dim_idx*32); + if constexpr (MODEL_TYPE == ModelType::V32) { + // NOTE We do not need to mask the RoPE part for V3.2 since it isn't involved in the SV gemm + } else { + if (token_index == -1) + *(uint128_t*)(&cur_bf16x8) = uint128_t(); + } + int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE; + *(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8; + if constexpr (CLUSTER_SIZE == 2) { + st_async_128b(sK_rope_peer_base + smem_offset, 
cur_bf16x8, peer_bar_k_remote_ready); + } + } + } + + fence_view_async_shared(); + + if (idx_in_warpgroup < 32) { + // We put this after fence_view_async_shared() since this won't be read by async proxy + auto is_index_valid = [&](int index, int offset_within_thread) -> bool { + if constexpr (MODEL_TYPE == ModelType::V32) { + return index != -1; + } else { + return index != -1 && rel_block_idx*TOPK_BLOCK_SIZE + lane_idx*2 + offset_within_thread < topk_length; + } + }; + int2 indices = __ldg((int2*)(indices_base + lane_idx*2)); + *(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = { + is_index_valid(indices.x, 0), + is_index_valid(indices.y, 1) + }; + } + + // Signal the barrier + plan.bar_k_local_ready[buf_idx].arrive(); + bar_phase_k ^= 1 << buf_idx; + }; + + if constexpr (MODEL_TYPE == ModelType::V32) { + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) { + process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{}); + } + } else { + CUTE_NO_UNROLL + for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) { + process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{}); + } + + if (args.num_orig_kv_blocks < args.end_block_idx) { + process_one_block(max(args.start_block_idx, args.num_orig_kv_blocks), IsExtraBlock{}, IsFirstExtraBlock{}); + } + CUTE_NO_UNROLL + for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks)+1; block_idx < args.end_block_idx; ++block_idx) { + process_one_block(block_idx, IsExtraBlock{}, IsNotFirstExtraBlock{}); + } + } + + sync_all_threads_in_cluster(); + } + } +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); + } +#endif + +} + +template +__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, Kernel::CLUSTER_SIZE) +flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TMAParams 
tma_params) { + Kernel::devfunc(params, tma_params); +} + +template +void KernelTemplate::run(const SparseAttnDecodeParams ¶ms) { + KU_ASSERT(params.h_kv == 1); + KU_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0); + KU_ASSERT(params.d_qk == HEAD_DIM_K); + KU_ASSERT(params.d_v == HEAD_DIM_V); + KU_ASSERT(params.h_q % BLOCK_M == 0); + if constexpr (MODEL_TYPE == ModelType::MODEL1) { + constexpr int BYTES_PER_TOKEN = HEAD_DIM_NOPE + 2*HEAD_DIM_ROPE + 8; + KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous + if (params.extra_kv != nullptr) { + KU_ASSERT(params.stride_extra_kv_row == BYTES_PER_TOKEN, "Each page block in extra KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous + } + } else { + KU_ASSERT(params.extra_kv == nullptr, "V3.2 does not support extra KV cache"); + KU_ASSERT(params.topk_length == nullptr, "V3.2 does not support dynamic topk length"); + KU_ASSERT(params.stride_kv_row == 656); // number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16) + } + + auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q, params.b); + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q), + make_layout( + shape_Q, + make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b) + ) + ), + SmemLayoutQ{} + ); + + CUtensorMap tensor_map_o; + { + // Here we manually construct TMA descriptor to store O, in order to leverage 5D TMA + uint64_t size[5] = {OBUF_SW, (unsigned long)params.h_q, HEAD_DIM_V/OBUF_SW, (unsigned long)params.s_q, (unsigned long)params.b}; + uint64_t stride[4] = {params.stride_o_h_q*sizeof(bf16), OBUF_SW*sizeof(bf16), params.stride_o_s_q*sizeof(bf16), params.stride_o_b*sizeof(bf16)}; + uint32_t box_size[5] = {OBUF_SW, BLOCK_M, HEAD_DIM_V/OBUF_SW, 1, 1}; + uint32_t elem_stride[5] = {1, 1, 1, 
1, 1}; + CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tensor_map_o, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + 5, + params.out, + size, + stride, + box_size, + elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + OBUF_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B : + OBUF_SW == 32 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B : + OBUF_SW == 16 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B : + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + KU_ASSERT(res == CUresult::CUDA_SUCCESS); + } + + TmaParams< + decltype(shape_Q), decltype(tma_Q) + > tma_params = { + shape_Q, tma_Q, + tensor_map_o + }; + auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel, decltype(tma_params)>; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // NOTE Don't use PDL because of potential compiler bugs! 
+ // cudaLaunchAttribute mla_kernel_attributes[1]; + // mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + // mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; + // cudaLaunchConfig_t mla_kernel_config = { + // dim3(num_m_block, params.h_k, params.num_sm_parts), + // dim3(NUM_THREADS, 1, 1), + // smem_size, + // stream, + // mla_kernel_attributes, + // 1 + // }; + // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); + cutlass::ClusterLaunchParams launch_params = { + dim3(NUM_M_BLOCKS, params.s_q, params.num_sm_parts), + dim3(NUM_THREADS, 1, 1), + dim3(CLUSTER_SIZE, 1, 1), + smem_size, + params.stream + }; + cutlass::launch_kernel_on_cluster( + launch_params, (void*)mla_kernel, params, tma_params + ); + KU_CHECK_KERNEL_LAUNCH(); +} + +template +void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms) { + KernelTemplate::run(params); +} + +} diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.h b/csrc/sm90/decode/sparse_fp8/splitkv_mla.h index daa21a3..13b659b 100644 --- a/csrc/sm90/decode/sparse_fp8/splitkv_mla.h +++ b/csrc/sm90/decode/sparse_fp8/splitkv_mla.h @@ -2,8 +2,10 @@ #include "params.h" -namespace sm90 { +namespace sm90::decode::sparse_fp8 { -void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams ¶ms, cudaStream_t stream); +template +void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams ¶ms); } + diff --git a/csrc/sm90/prefill/sparse/helpers.h b/csrc/sm90/helpers.h similarity index 90% rename from csrc/sm90/prefill/sparse/helpers.h rename to csrc/sm90/helpers.h index fd68c36..2bab337 100644 --- a/csrc/sm90/prefill/sparse/helpers.h +++ b/csrc/sm90/helpers.h @@ -1,17 +1,10 @@ #pragma once -#include -#include #include +#include namespace sm90 { -using bf16 = cutlass::bfloat16_t; -using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; -using cutlass::arch::fence_view_async_shared; -using cutlass::arch::fence_barrier_init; 
-using cutlass::arch::NamedBarrier; - __forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" @@ -51,7 +44,7 @@ __forceinline__ __device__ int64_t createpolicy_evict_first() { __forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { - // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx + // In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); return row_idx; @@ -99,7 +92,7 @@ __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, T if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } } -// A simpiler version of gemm +// A simpler version of gemm template __forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { using namespace cute; @@ -142,11 +135,11 @@ __forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Te __forceinline__ __device__ uint32_t get_sm_id() { uint32_t ret; - asm("mov.u32 %0, %smid;" : "=r"(ret)); + asm("mov.u32 %0, %%smid;" : "=r"(ret)); return ret; } -static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字 +static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. Not sure if this number is the same on all GPUs. 
template CUTE_DEVICE T* get_peer_addr(const T* p) { @@ -163,12 +156,12 @@ void launch_tma_copy( const TMA &tma_copy, Tensor0 src, Tensor1 dst, - transac_bar_t &bar, + cutlass::arch::ClusterTransactionBarrier &bar, const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL ) { auto thr_tma = tma_copy.get_slice(cute::_0{}); cute::copy( - tma_copy.with(reinterpret_cast(bar), 0, cache_hint), + tma_copy.with(reinterpret_cast(bar), 0, cache_hint), thr_tma.partition_S(src), thr_tma.partition_D(dst) ); diff --git a/csrc/sm90/prefill/sparse/config.h b/csrc/sm90/prefill/sparse/config.h new file mode 100644 index 0000000..7500566 --- /dev/null +++ b/csrc/sm90/prefill/sparse/config.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "defines.h" +#include "params.h" + +namespace sm90::fwd { + +using namespace cute; + +template +class KernelTemplate { +public: + +static constexpr int D_Q = D_QK; +static constexpr int D_K = D_QK; +static constexpr int D_V = 512; + +static constexpr int B_H = 64; +static constexpr int B_TOPK = 64; // TopK block size +static constexpr int NUM_THREADS = 128*3; +static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) + +template +using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( + GMMA::Layout_SW128_Atom{}, + Shape, Int<64*NUM_TILES>>{}, + Step<_1, _2>{} +), Shape<_1, _1>{})); + +template +using SmemLayoutKTilesTransposed = decltype(composition( + SmemLayoutKTiles{}, + Layout, Int>, Stride, _1>>{} +)); + +using SmemLayoutQ = SmemLayoutQTiles; +using SmemLayoutO = 
SmemLayoutOTiles; +using SmemLayoutK = SmemLayoutKTiles; +using SmemLayoutV = SmemLayoutKTilesTransposed; +using SmemLayoutHalfV = SmemLayoutKTilesTransposed; + +using SmemLayoutS = decltype(coalesce(tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} +), Shape<_1, _1>{})); + +struct SharedMemoryPlan { + union { + array_aligned> q; + array_aligned> o; + } q_o; + array_aligned> k[2]; + array_aligned> s[D_QK == 576 ? 1 : 2]; // For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers + + bool is_kv_valid[2][B_TOPK]; + float2 sM[32]; + float2 sL[64]; // For reduction across WG0/1 in epilogue + float final_max_logits[64], final_lse[64]; + transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready; +}; + +using TiledMMA_QK = decltype(make_tiled_mma( + GMMA::MMA_64x64x16_F32BF16BF16_SS{}, + Layout>{} +)); + +using TiledMMA_PV_LocalP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_RS{}, + Layout>{} +)); + +using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::MMA_64x256x16_F32BF16BF16_SS{}, + Layout>{} +)); + +template< + typename Shape_Q, typename TMA_Q +> +struct TmaParams { + Shape_Q shape_Q; TMA_Q tma_Q; + CUtensorMap tensor_map_O; +}; + +enum NamedBarriers : uint32_t { + wg0_bunch_0_ready = 0, + wg1_bunch_0_ready = 1, + wg0_s0_ready = 2, + wg1_s1_ready = 3, + sL_ready = 4, + warpgroup0_sync = 5, + warpgroup1_sync = 6, + epilogue_sync = 7 +}; + +// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction +template< + typename Tensor0, + typename Tensor1 +> +static __forceinline__ __device__ void save_rS_to_sS( + Tensor0 const &rPb, + Tensor1 const &sP, + int idx_in_warpgroup +) { + auto r2s_copy = make_tiled_copy_C( + Copy_Atom{}, + TiledMMA_QK{} + ); + ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); + Tensor thr_copy_rPb = thr_copy.retile_S(rPb); + Tensor thr_copy_sP = 
thr_copy.partition_D(sP); + cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); +} + +template +static __device__ __forceinline__ void +devfunc(const SparseAttnFwdParams ¶ms, const TMAParams &tma_params); + +static void run(const SparseAttnFwdParams ¶ms); + +}; + + +}; diff --git a/csrc/sm90/prefill/sparse/fwd.cu b/csrc/sm90/prefill/sparse/fwd.cu index 084e0e2..eb62b5d 100644 --- a/csrc/sm90/prefill/sparse/fwd.cu +++ b/csrc/sm90/prefill/sparse/fwd.cu @@ -1,709 +1,30 @@ #include "fwd.h" -#include -#include -#include -#include -#include -#include +#include -#include "utils.h" -#include "helpers.h" +#include "phase1.h" namespace sm90 { -using namespace cute; +void run_fwd_kernel(const SparseAttnFwdParams& params) { + const bool have_topk_length = params.topk_length != nullptr; -constexpr int D_Q = 576; -constexpr int D_K = 576; -constexpr int D_V = 512; - -constexpr int B_H = 64; -constexpr int B_TOPK = 64; // TopK block size -constexpr int NUM_THREADS = 128*3; -static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) - -template -using SmemLayoutQTiles = decltype(coalesce(tile_to_shape( - GMMA::Layout_K_SW128_Atom{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -), Shape<_1, _1>{})); - -template -using SmemLayoutOTiles = decltype(coalesce(tile_to_shape( - GMMA::Layout_K_SW128_Atom{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -), Shape<_1, _1>{})); - -template -using SmemLayoutKTiles = decltype(coalesce(tile_to_shape( - GMMA::Layout_SW128_Atom{}, - Shape, Int<64*NUM_TILES>>{}, - Step<_1, _2>{} -), Shape<_1, _1>{})); - -template -using SmemLayoutKTilesTransposed = decltype(composition( - SmemLayoutKTiles{}, - Layout, Int>, Stride, _1>>{} -)); - -using SmemLayoutQ = SmemLayoutQTiles<9>; -using SmemLayoutO = SmemLayoutOTiles<8>; -using SmemLayoutK = SmemLayoutKTiles<9>; -using SmemLayoutV = SmemLayoutKTilesTransposed<8>; -using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>; - -using SmemLayoutS = 
decltype(coalesce(tile_to_shape( - GMMA::Layout_K_SW128_Atom{}, - Shape, Int>{} -), Shape<_1, _1>{})); - -struct SharedMemoryPlan { - union { - array_aligned> q; - array_aligned> o; - } q_o; - array_aligned> k[2]; - array_aligned> s; - - bool is_kv_valid[2][B_TOPK]; - float2 sM[32]; - float2 sL[64]; // For reduction across WG0/1 in epilogue - float final_max_logits[64], final_lse[64]; - transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready; -}; - -using TiledMMA_QK = decltype(make_tiled_mma( - GMMA::MMA_64x64x16_F32BF16BF16_SS{}, - Layout>{} -)); - -using TiledMMA_PV_LocalP = decltype(make_tiled_mma( - GMMA::MMA_64x256x16_F32BF16BF16_RS{}, - Layout>{} -)); - -using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( - GMMA::MMA_64x256x16_F32BF16BF16_SS{}, - Layout>{} -)); - -template< - typename Shape_Q, typename TMA_Q -> -struct TmaParams { - Shape_Q shape_Q; TMA_Q tma_Q; - CUtensorMap tensor_map_O; -}; - -enum NamedBarriers : uint32_t { - wg0_bunch_0_ready = 0, - wg1_bunch_0_ready = 1, - wg0_s0_ready = 2, - wg1_s1_ready = 3, - sL_ready = 4, - warpgroup0_sync = 5, - warpgroup1_sync = 6 -}; - -// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction -template< - typename Tensor0, - typename Tensor1 -> -__forceinline__ __device__ void save_rS_to_sS( - Tensor0 const &rPb, - Tensor1 const &sP, - int idx_in_warpgroup -) { - auto r2s_copy = make_tiled_copy_C( - Copy_Atom{}, - TiledMMA_QK{} - ); - ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup); - Tensor thr_copy_rPb = thr_copy.retile_S(rPb); - Tensor thr_copy_sP = thr_copy.partition_D(sP); - cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP); -} - - -template -__global__ void __launch_bounds__(NUM_THREADS, 1, 1) -sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) { - // NOTE This kernel uses a similar schedule to Flash MLA - 0422. 
For a detailed explanation, please refer to https://github.com/deepseek-ai/FlashMLA/blob/main/docs/20250422-new-kernel-deep-dive.md -#if IS_SM90 - const int q_h_idx = blockIdx.x % (params.h_q/B_H); - const int s_q_idx = blockIdx.x / (params.h_q/B_H); - const int warpgroup_idx = cutlass::canonical_warp_group_idx(); - const int warp_idx = cutlass::canonical_warp_idx_sync(); - const int idx_in_warpgroup = threadIdx.x % 128; - - // Define shared tensors - extern __shared__ char wksp_buf[]; - SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); - Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{}); - Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{}); - Tensor sS0 = make_tensor(make_smem_ptr(plan.k[0].data()+64*512), SmemLayoutS{}); // Overlap with sK0's RoPE part - Tensor sS1 = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{}); - - if (warp_idx == 0 && elect_one_sync()) { - // Prefetch TMA descriptors - cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); - cute::prefetch_tma_descriptor(&tma_params.tensor_map_O); - - // Initialize barriers - plan.bar_q.init(1); - CUTE_UNROLL - for (int i = 0; i < 2; ++i) { - plan.bar_k0_free[i].init(128); - plan.bar_k0_ready[i].init(128); - plan.bar_k1_free[i].init(128); - plan.bar_k1_ready[i].init(128); - } - plan.bar_is_kv_valid_ready.init(16); - fence_barrier_init(); - } - - __syncthreads(); - - const int num_topk_blocks = params.topk / B_TOPK; - if (warpgroup_idx == 0 || warpgroup_idx == 1) { - cutlass::arch::warpgroup_reg_alloc<216>(); - - if (warp_idx == 0 && elect_one_sync()) { - // Load Q - Tensor gQ = flat_divide( - tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx), - Tile, Int>{} - )(_, _, q_h_idx, _0{}); - launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); - plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16)); + // Dispatch based on d_qk dimension and presence of topk_length + if (params.d_qk == 512) { 
+ if (have_topk_length) { + sm90::fwd::run_fwd_phase1_kernel<512, true>(params); + } else { + sm90::fwd::run_fwd_phase1_kernel<512, false>(params); } - - float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation - float rL[2] = {0.0f, 0.0f}; - Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); - Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); - Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); - cute::fill(rO, 0.0f); - - // Wait for Q - plan.bar_q.wait(0); - - bool cur_bar_wait_phase = 0; - - struct Warpgroup0 {}; - struct Warpgroup1 {}; - - auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) { - constexpr bool IS_WG1 = std::is_same_v; - TiledMMA tiled_mma_QK = TiledMMA_QK{}; - Tensor sQ_tile = flat_divide(sQ, Tile, Int<64>>{})(_, _, _0{}, tile_idx); - Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{}); - gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup); - }; - - auto mask_rP = [&](auto warpgroup_idx) { - constexpr bool IS_WG1 = std::is_same_v; - plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); - CUTE_UNROLL - for (int row_idx = 0; row_idx < 2; ++row_idx) { - CUTE_UNROLL - for (int i = row_idx*2; i < size(rP); i += 4) { - int col = 8*(i/4) + (idx_in_warpgroup%4)*2; - if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY; - if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY; - } - } - }; - - auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) { - plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); - constexpr bool IS_WG1 = std::is_same_v; - const float scale = params.sm_scale_div_log2; - float r_sM[2]; - if constexpr (IS_WG1) { - *(float2*)r_sM = plan.sM[idx_in_warpgroup/4]; - } - float new_maxs[2]; - CUTE_UNROLL - for (int row_idx = 0; row_idx < 2; ++row_idx) { - // Get rowwise max - float cur_max = -INFINITY; - 
CUTE_UNROLL - for (int i = row_idx*2; i < size(rP); i += 4) { - cur_max = max(cur_max, max(rP(i), rP(i+1))); - } - cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); - cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); - cur_max *= scale; - - // Get new max and scale - // For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round) - new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max); - - // Scale O - float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]); - CUTE_UNROLL - for (int i = row_idx*2; i < size(rO); i += 4) { - rO(i) *= scale_for_o; - rO(i+1) *= scale_for_o; - } - - // Get rS - float cur_sum = 0; - CUTE_UNROLL - for (int i = row_idx*2; i < size(rP); i += 4) { - rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]); - rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]); - rS(i) = (bf16)rP(i); - rS(i+1) = (bf16)rP(i+1); - cur_sum += rP(i) + rP(i+1); - } - rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum; - } - __syncwarp(); - if (idx_in_warpgroup%4 == 0) { - plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs; - } - rM[0] = new_maxs[0]; - rM[1] = new_maxs[1]; - }; - - auto reduce_L = [&]() { - // Reduce L - // For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131 - rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); - rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); - rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); - rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); - if (idx_in_warpgroup%4 == 0) - plan.sL[threadIdx.x/4] = *(float2*)(rL); - NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready); - float2 peer_L = plan.sL[(threadIdx.x/4)^32]; - rL[0] += peer_L.x; - rL[1] += peer_L.y; - }; - - auto store_O = [&]() { - float scale_factors[2]; - CUTE_UNROLL - for (int i = 0; i < 2; ++i) - scale_factors[i] = rL[i] == 0.0f ? 
1.0f : 1.0f / rL[i]; - - Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{}); - bf16* stsm_addrs[4]; - int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16); - CUTE_UNROLL - for (int i = 0; i < 64/16; ++i) { - stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i); - } - bool s2g_pred = warp_idx%4 == 0 && elect_one_sync(); - - warpgroup_wait<0>(); - CUTE_UNROLL - for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) { - // Convert - constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size - bf16 cur_rOb[NUM_ELEMS_EACH_TILE]; - CUTE_UNROLL - for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) { - cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]); - } - // R -> S - CUTE_UNROLL - for (int i = 0; i < 64/16; ++i) { - SM90_U32x4_STSM_N::copy( - *reinterpret_cast(cur_rOb + i*8 + 0), - *reinterpret_cast(cur_rOb + i*8 + 2), - *reinterpret_cast(cur_rOb + i*8 + 4), - *reinterpret_cast(cur_rOb + i*8 + 6), - *reinterpret_cast(stsm_addrs[i] + tile_idx*(B_H*64)) - ); - } - fence_view_async_shared(); - NamedBarrier::arrive_and_wait(128, warpgroup_idx ? 
NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync); - // S -> G - if (s2g_pred) { - int g_tile_idx = warpgroup_idx*4 + tile_idx; - SM90_TMA_STORE_3D::copy( - &tma_params.tensor_map_O, - plan.q_o.o.data() + g_tile_idx*(B_H*64), - g_tile_idx*64, - q_h_idx*B_H, - s_q_idx - ); - } - } - cute::tma_store_arrive(); - }; - - - if (warpgroup_idx == 0) { - // Warpgroup 0 - - auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) { - plan.bar_k0_ready[0].wait(cur_bar_wait_phase); - qkt_gemm_one_tile(Warpgroup0{}, 0, true); - qkt_gemm_one_tile(Warpgroup0{}, 1, false); - qkt_gemm_one_tile(Warpgroup0{}, 2, false); - qkt_gemm_one_tile(Warpgroup0{}, 3, false); - warpgroup_commit_batch(); - }; - - auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) { - plan.bar_k0_ready[1].wait(cur_bar_wait_phase); - qkt_gemm_one_tile(Warpgroup0{}, 4, false); - qkt_gemm_one_tile(Warpgroup0{}, 5, false); - qkt_gemm_one_tile(Warpgroup0{}, 6, false); - qkt_gemm_one_tile(Warpgroup0{}, 7, false); - qkt_gemm_one_tile(Warpgroup0{}, 8, false); - warpgroup_commit_batch(); - }; - - auto scale_rS = [&](float scales[2]) { - CUTE_UNROLL - for (int row = 0; row < 2; ++row) { - CUTE_UNROLL - for (int i = row*2; i < size(rP); i += 4) { - rS(i) = (bf16)(rP(i) * scales[row]); - rS(i+1) = (bf16)(rP(i+1) * scales[row]); - } - } - }; - - auto rescale_rO = [&](float scales[2]) { - CUTE_UNROLL - for (int row = 0; row < 2; ++row) { - CUTE_UNROLL - for (int i = row*2; i < size(rO); i += 4) { - rO(i) *= scales[row]; - rO(i+1) *= scales[row]; - } - rL[row] *= scales[row]; - } - }; - - CUTE_NO_UNROLL - for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { - Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{}); - Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{}); - - if (block_idx == 0) { - // NOTE We put these code here to avoid register spilling - 
pipelined_wait_and_qkt_gemm_l(); - pipelined_wait_and_qkt_gemm_r(); - warpgroup_wait<0>(); - } - - // Online softmax, inform WG1 - mask_rP(Warpgroup0{}); - - online_softmax_and_rescale_o(Warpgroup0{}); - NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready); - - // Issue rO0 += rS0 @ sV0l - gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup); - warpgroup_commit_batch(); - - // Mark V0L as free - warpgroup_wait<0>(); - plan.bar_k0_free[0].arrive(); - - // Wait for new sM, scale rS, save, inform WG1 - NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready); - float new_rM[2], scale_factors[2]; - *(float2*)new_rM = plan.sM[idx_in_warpgroup/4]; - CUTE_UNROLL - for (int i = 0; i < 2; ++i) { - scale_factors[i] = exp2f(rM[i] - new_rM[i]); - rM[i] = new_rM[i]; - } - scale_rS(scale_factors); - save_rS_to_sS(rS, sS0, idx_in_warpgroup); - fence_view_async_shared(); - NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready); - - // Wait for sS1 - NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready); - - // Rescale rO0, Issue rO0 += sS1 @ sV1L - rescale_rO(scale_factors); - gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup); - warpgroup_commit_batch(); - - cur_bar_wait_phase ^= 1; - - if (block_idx+2 < num_topk_blocks) { - // Launch the next QK^T GEMM - pipelined_wait_and_qkt_gemm_l(); - - // Mark V1L as free - warpgroup_wait<1>(); - plan.bar_k1_free[0].arrive(); - pipelined_wait_and_qkt_gemm_r(); - - // Wait for rP0 = sQ @ sK0 - warpgroup_wait<0>(); - } else { - // Mark V1L as free - warpgroup_wait<0>(); - plan.bar_k1_free[0].arrive(); - } - } - - reduce_L(); - store_O(); + } else if (params.d_qk == 576) { + if (have_topk_length) { + sm90::fwd::run_fwd_phase1_kernel<576, true>(params); } else { - // Warpgroup 1 - - auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) { - plan.bar_k1_ready[1].wait(cur_bar_wait_phase); - qkt_gemm_one_tile(Warpgroup1{}, 4, true); - qkt_gemm_one_tile(Warpgroup1{}, 
5, false); - qkt_gemm_one_tile(Warpgroup1{}, 6, false); - qkt_gemm_one_tile(Warpgroup1{}, 7, false); - qkt_gemm_one_tile(Warpgroup1{}, 8, false); - plan.bar_k1_ready[0].wait(cur_bar_wait_phase); - qkt_gemm_one_tile(Warpgroup1{}, 0, false); - qkt_gemm_one_tile(Warpgroup1{}, 1, false); - qkt_gemm_one_tile(Warpgroup1{}, 2, false); - qkt_gemm_one_tile(Warpgroup1{}, 3, false); - warpgroup_commit_batch(); - }; - - CUTE_NO_UNROLL - for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { - Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{}); - Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{}); - - // Issue rP1 = sQ @ sK1, and wait - pipelined_wait_and_qkt_gemm(); - warpgroup_wait<0>(); - - mask_rP(Warpgroup1{}); - - // Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready) - NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready); - online_softmax_and_rescale_o(Warpgroup1{}); - NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready); - - - // Issue rO1 += rS1 @ sV1R - gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup); - warpgroup_commit_batch(); - - // Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R - save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster - NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready); - gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup); - warpgroup_commit_batch(); - - // Save rS1, inform WG0 - fence_view_async_shared(); - NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready); - - // Wait for GEMM, and inform that sV1R is free - warpgroup_wait<1>(); - plan.bar_k1_free[1].arrive(); - - // Wait for GEMM, and inform that sV0R is free - warpgroup_wait<0>(); - plan.bar_k0_free[1].arrive(); - - cur_bar_wait_phase ^= 1; - } - - reduce_L(); - store_O(); - - // Save lse - if (idx_in_warpgroup%4 == 0) { - for (int row = 0; row < 2; ++row) { - int real_row = 
get_AorC_row_idx(row, idx_in_warpgroup); - bool is_no_valid_tokens = rL[row] == 0.0f; - plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row]; - plan.final_lse[real_row] = is_no_valid_tokens ? -INFINITY : log2f(rL[row]) + rM[row]; - } - fence_view_async_shared(); - } - - NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync); - if (idx_in_warpgroup == 0) { - int g_offset = s_q_idx*params.h_q + q_h_idx*B_H; - SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float)); - SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float)); - cute::tma_store_arrive(); - } + sm90::fwd::run_fwd_phase1_kernel<576, false>(params); } } else { - // Producer warpgroup - cutlass::arch::warpgroup_reg_dealloc<72>(); - - constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE; - constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS; - int idx_in_group = idx_in_warpgroup % GROUP_SIZE; - int group_idx = idx_in_warpgroup / GROUP_SIZE; - int* gIndices = params.indices + s_q_idx*params.topk; // [topk] - - bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8)); - bf16* my_gKV_base = params.kv + idx_in_group*8; - - int64_t token_indices[2][NUM_ROWS_PER_GROUP]; - bool is_token_valid[2][NUM_ROWS_PER_GROUP]; - auto load_token_indices = [&](int block_idx) { - CUTE_UNROLL - for (int buf_idx = 0; buf_idx < 2; ++buf_idx) { - CUTE_UNROLL - for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { - int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx; - int t = __ldg(gIndices + offs); - token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster - is_token_valid[buf_idx][local_row] = t >= 0 && t < params.s_kv; - } - } - }; - - int64_t cache_policy = createpolicy_evict_last(); - auto copy_tiles = [&](int block_idx, int buf_idx, int 
tile_start, int tile_end) { - // Copy some K/V tiles from global memory to shared memory - // A tile has a shape of 64 (B_TOPK) x 64 - // `buf_idx` is the index of the shared memory buffer, 0 or 1 - // `tile_idx` is the index of the tile to load, from 0 to D_K/64-1 = 8 - CUTE_UNROLL - for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { - int64_t token_index = token_indices[buf_idx][local_row]; - CUTE_UNROLL - for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) { - cp_async_cacheglobal_l2_prefetch_256B( - my_gKV_base + token_index + tile_idx*64, - my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64), - is_token_valid[buf_idx][local_row], - cache_policy - ); - } - } - }; - - auto commit_to_mbar = [&](transac_bar_t &bar) { - cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar)); - }; - - int cur_bar_wait_phase = 1; - - CUTE_NO_UNROLL - for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { - load_token_indices(block_idx); - - // V0L - plan.bar_k0_free[0].wait(cur_bar_wait_phase); - copy_tiles(block_idx+0, 0, 0, 4); - commit_to_mbar(plan.bar_k0_ready[0]); - - // V1R - plan.bar_k1_free[1].wait(cur_bar_wait_phase); - copy_tiles(block_idx+1, 1, 4, 9); - commit_to_mbar(plan.bar_k1_ready[1]); - - // V0R - plan.bar_k0_free[1].wait(cur_bar_wait_phase); - copy_tiles(block_idx+0, 0, 4, 9); - commit_to_mbar(plan.bar_k0_ready[1]); - - // V1L - plan.bar_k1_free[0].wait(cur_bar_wait_phase); - copy_tiles(block_idx+1, 1, 0, 4); - commit_to_mbar(plan.bar_k1_ready[0]); - - // Valid mask - // NOTE V1R's finish implies maskings of the last round have finished - if (idx_in_group == 0) { - CUTE_UNROLL - for (int buf_idx = 0; buf_idx < 2; ++buf_idx) - CUTE_UNROLL - for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) - plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row]; - plan.bar_is_kv_valid_ready.arrive(); - } - - cur_bar_wait_phase ^= 1; - } 
+ throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel"); } -#else - if (cute::thread0()) { - CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); - } -#endif } - -void run_fwd_kernel(const SparsePrefillParams& params) { - FLASH_ASSERT(params.h_kv == 1); - FLASH_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings - FLASH_ASSERT(params.topk > 0); - FLASH_ASSERT(params.h_q % B_H == 0); - - auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); - auto tma_Q = cute::make_tma_copy( - SM90_TMA_LOAD{}, - make_tensor( - make_gmem_ptr((bf16*)params.q), - make_layout( - shape_Q, - make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) - ) - ), - SmemLayoutQ{} - ); - - CUtensorMap tensor_map_O; - { - uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q}; - uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)}; - uint32_t box_size[3] = {64, B_H, 1}; - uint32_t elem_stride[3] = {1, 1, 1}; - CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( - &tensor_map_O, - CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, - 3, - params.out, - size, - stride, - box_size, - elem_stride, - CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, - CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE - ); - FLASH_ASSERT(res == CUresult::CUDA_SUCCESS); - } - - TmaParams< - decltype(shape_Q), decltype(tma_Q) - > tma_params = { - shape_Q, tma_Q, - tensor_map_O - }; - auto kernel = &sparse_attn_fwd_kernel; - - constexpr size_t smem_size = sizeof(SharedMemoryPlan); - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - cutlass::ClusterLaunchParams launch_params = { - dim3((params.h_q/B_H)*params.s_q, 1, 1), // NOTE We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y 
and griddim.z) - dim3(NUM_THREADS, 1, 1), - dim3(1, 1, 1), - smem_size, - params.stream - }; - cutlass::launch_kernel_on_cluster( - launch_params, (void*)kernel, params, tma_params - ); - CHECK_CUDA_KERNEL_LAUNCH(); -} - -} +} // namespace sm90 diff --git a/csrc/sm90/prefill/sparse/fwd.h b/csrc/sm90/prefill/sparse/fwd.h index 60cb624..1c26d68 100644 --- a/csrc/sm90/prefill/sparse/fwd.h +++ b/csrc/sm90/prefill/sparse/fwd.h @@ -4,6 +4,6 @@ namespace sm90 { -void run_fwd_kernel(const SparsePrefillParams& params); +void run_fwd_kernel(const SparseAttnFwdParams& params); } diff --git a/csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu b/csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu new file mode 100644 index 0000000..046cfb3 --- /dev/null +++ b/csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu @@ -0,0 +1,10 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm90::fwd { + +// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH +// = true / false respectively, to compile them in parallel. +template void run_fwd_phase1_kernel<512, false>(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu b/csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu new file mode 100644 index 0000000..45da995 --- /dev/null +++ b/csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu @@ -0,0 +1,10 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm90::fwd { + +// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH +// = true / false respectively, to compile them in parallel. 
+template void run_fwd_phase1_kernel<512, true>(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu b/csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu new file mode 100644 index 0000000..f35db00 --- /dev/null +++ b/csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu @@ -0,0 +1,8 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm90::fwd { + +template void run_fwd_phase1_kernel<576, false>(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu b/csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu new file mode 100644 index 0000000..bd0f0ca --- /dev/null +++ b/csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu @@ -0,0 +1,8 @@ +#include "../phase1.h" +#include "../phase1.cuh" + +namespace sm90::fwd { + +template void run_fwd_phase1_kernel<576, true>(const SparseAttnFwdParams& params); + +} diff --git a/csrc/sm90/prefill/sparse/phase1.cuh b/csrc/sm90/prefill/sparse/phase1.cuh new file mode 100644 index 0000000..bf2fff8 --- /dev/null +++ b/csrc/sm90/prefill/sparse/phase1.cuh @@ -0,0 +1,646 @@ +#pragma once + +#include "config.h" + +#include "utils.h" +#include "../../helpers.h" + +namespace sm90::fwd { + +using namespace cute; + +CUTE_DEVICE void st_global_cs_128(float f0, float f1, float f2, float f3, void *dst_ptr) { + asm volatile("st.weak.global.cs.v4.f32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(dst_ptr), + "f"(f0), "f"(f1), "f"(f2), "f"(f3) + ); +} + +CUTE_DEVICE +float2 __shfl_xor_sync_float2( + uint32_t mask, float2 value, int offset +) { + float2 res; + *reinterpret_cast(&res) = __shfl_xor_sync( + mask, + *reinterpret_cast(&value), + offset + ); + return res; +} + +CUTE_DEVICE +void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) { + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(src_ptr); + asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], 
[%1], %2;\n" + : + : "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +} + +template +template +__device__ void KernelTemplate::devfunc(const SparseAttnFwdParams ¶ms, const TMAParams &tma_params) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) + const int q_h_idx = blockIdx.x % (params.h_q/B_H); + const int s_q_idx = blockIdx.x / (params.h_q/B_H); + const int warpgroup_idx = cutlass::canonical_warp_group_idx(); + const int warp_idx = cutlass::canonical_warp_idx_sync(); + const int idx_in_warpgroup = threadIdx.x % 128; + + // Define shared tensors + extern __shared__ char wksp_buf[]; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{}); + Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{}); + Tensor sS0 = make_tensor(make_smem_ptr(D_QK == 576 ? plan.k[0].data()+64*512 : plan.s[1].data()), SmemLayoutS{}); // Overlap with sK0's RoPE part for V3.2 + Tensor sS1 = make_tensor(make_smem_ptr(plan.s[0].data()), SmemLayoutS{}); + + if (warp_idx == 0 && elect_one_sync()) { + // Prefetch TMA descriptors + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(&tma_params.tensor_map_O); + + // Initialize barriers + plan.bar_q.init(1); + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + plan.bar_k0_free[i].init(128); + plan.bar_k0_ready[i].init(128); + plan.bar_k1_free[i].init(128); + plan.bar_k1_ready[i].init(128); + } + plan.bar_is_kv_valid_ready.init(16); + fence_barrier_init(); + } + + __syncthreads(); + + const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk; + const int num_topk_blocks = HAVE_TOPK_LENGTH ? 
ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK); + + if (warpgroup_idx == 0 || warpgroup_idx == 1) { + cutlass::arch::warpgroup_reg_alloc<216>(); + + if (warp_idx == 0 && elect_one_sync()) { + // Load Q + Tensor gQ = flat_divide( + tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx), + Tile, Int>{} + )(_, _, q_h_idx, _0{}); + launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST); + plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16)); + } + + float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation + float rL[2] = {0.0f, 0.0f}; + Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape, Int>{}); + Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape, Int>{}); + Tensor rS = make_tensor(partition_shape_A(TiledMMA_PV_LocalP{}, Shape, Int>{})); + cute::fill(rO, 0.0f); + + // Wait for Q + plan.bar_q.wait(0); + + bool cur_bar_wait_phase = 0; + + struct Warpgroup0 {}; + struct Warpgroup1 {}; + + auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) { + constexpr bool IS_WG1 = std::is_same_v; + TiledMMA tiled_mma_QK = TiledMMA_QK{}; + Tensor sQ_tile = flat_divide(sQ, Tile, Int<64>>{})(_, _, _0{}, tile_idx); + Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{}); + gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup); + }; + + auto mask_rP = [&](auto warpgroup_idx) { + constexpr bool IS_WG1 = std::is_same_v; + plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); + CUTE_UNROLL + for (int row_idx = 0; row_idx < 2; ++row_idx) { + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + int col = 8*(i/4) + (idx_in_warpgroup%4)*2; + if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY; + if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY; + } + } + }; + + auto online_softmax_and_rescale_o = [&](auto 
warpgroup_idx) { + plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase); + constexpr bool IS_WG1 = std::is_same_v; + const float scale = params.sm_scale_div_log2; + float r_sM[2]; + if constexpr (IS_WG1) { + *(float2*)r_sM = plan.sM[idx_in_warpgroup/4]; + } + float new_maxs[2]; + CUTE_UNROLL + for (int row_idx = 0; row_idx < 2; ++row_idx) { + // Get rowwise max + float cur_max = -INFINITY; + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + cur_max = max(cur_max, max(rP(i), rP(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + cur_max *= scale; + + // Get new max and scale + // For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round) + new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max); + + // Scale O + float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]); + CUTE_UNROLL + for (int i = row_idx*2; i < size(rO); i += 4) { + rO(i) *= scale_for_o; + rO(i+1) *= scale_for_o; + } + + // Get rS + float cur_sum = 0; + CUTE_UNROLL + for (int i = row_idx*2; i < size(rP); i += 4) { + rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]); + rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]); + rS(i) = (bf16)rP(i); + rS(i+1) = (bf16)rP(i+1); + cur_sum += rP(i) + rP(i+1); + } + rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum; + } + __syncwarp(); + if (idx_in_warpgroup%4 == 0) { + plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs; + } + rM[0] = new_maxs[0]; + rM[1] = new_maxs[1]; + }; + + auto reduce_L = [&]() { + // Reduce L + // For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131 + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + if (idx_in_warpgroup%4 == 0) + plan.sL[threadIdx.x/4] = *(float2*)(rL); + 
NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready); + float2 peer_L = plan.sL[(threadIdx.x/4)^32]; + rL[0] += peer_L.x; + rL[1] += peer_L.y; + }; + + auto store_O = [&]() { + float scale_factors[2]; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : params.attn_sink[q_h_idx*B_H + get_AorC_row_idx(i, idx_in_warpgroup)]*CUDART_L2E_F; + scale_factors[i] = 1.0f / (rL[i] + exp2f(attn_sink - rM[i])); + if (rL[i] == 0.0f) + scale_factors[i] = 0.0f; // The output should be 0 whatever attn_sink is + } + + Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{}); + bf16* stsm_addrs[4]; + int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16); + CUTE_UNROLL + for (int i = 0; i < 64/16; ++i) { + stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i); + } + bool s2g_pred = warp_idx%4 == 0 && elect_one_sync(); + + warpgroup_wait<0>(); + CUTE_UNROLL + for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) { + // Convert + constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size + bf16 cur_rOb[NUM_ELEMS_EACH_TILE]; + CUTE_UNROLL + for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) { + cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]); + } + // R -> S + CUTE_UNROLL + for (int i = 0; i < 64/16; ++i) { + SM90_U32x4_STSM_N::copy( + *reinterpret_cast(cur_rOb + i*8 + 0), + *reinterpret_cast(cur_rOb + i*8 + 2), + *reinterpret_cast(cur_rOb + i*8 + 4), + *reinterpret_cast(cur_rOb + i*8 + 6), + *reinterpret_cast(stsm_addrs[i] + tile_idx*(B_H*64)) + ); + } + fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, warpgroup_idx ? 
NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync); + // S -> G + if (s2g_pred) { + int g_tile_idx = warpgroup_idx*4 + tile_idx; + SM90_TMA_STORE_3D::copy( + &tma_params.tensor_map_O, + plan.q_o.o.data() + g_tile_idx*(B_H*64), + g_tile_idx*64, + q_h_idx*B_H, + s_q_idx + ); + } + } + cute::tma_store_arrive(); + }; + + + if (warpgroup_idx == 0) { + // Warpgroup 0 + + auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) { + plan.bar_k0_ready[0].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup0{}, 0, true); + qkt_gemm_one_tile(Warpgroup0{}, 1, false); + qkt_gemm_one_tile(Warpgroup0{}, 2, false); + qkt_gemm_one_tile(Warpgroup0{}, 3, false); + warpgroup_commit_batch(); + }; + + auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) { + plan.bar_k0_ready[1].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup0{}, 4, false); + qkt_gemm_one_tile(Warpgroup0{}, 5, false); + qkt_gemm_one_tile(Warpgroup0{}, 6, false); + qkt_gemm_one_tile(Warpgroup0{}, 7, false); + if constexpr (D_QK == 576) { + qkt_gemm_one_tile(Warpgroup0{}, 8, false); + } + warpgroup_commit_batch(); + }; + + auto scale_rS = [&](float scales[2]) { + CUTE_UNROLL + for (int row = 0; row < 2; ++row) { + CUTE_UNROLL + for (int i = row*2; i < size(rP); i += 4) { + rS(i) = (bf16)(rP(i) * scales[row]); + rS(i+1) = (bf16)(rP(i+1) * scales[row]); + } + } + }; + + auto rescale_rO = [&](float scales[2]) { + CUTE_UNROLL + for (int row = 0; row < 2; ++row) { + CUTE_UNROLL + for (int i = row*2; i < size(rO); i += 4) { + rO(i) *= scales[row]; + rO(i+1) *= scales[row]; + } + rL[row] *= scales[row]; + } + }; + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{}); + Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{}); + + if (block_idx == 0) { + // NOTE: We put this code here to avoid register 
spilling + pipelined_wait_and_qkt_gemm_l(); + pipelined_wait_and_qkt_gemm_r(); + warpgroup_wait<0>(); + } + + // Online softmax, inform WG1 + mask_rP(Warpgroup0{}); + + + online_softmax_and_rescale_o(Warpgroup0{}); + NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready); + + // Issue rO0 += rS0 @ sV0l + gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Mark V0L as free + warpgroup_wait<0>(); + plan.bar_k0_free[0].arrive(); + + // Wait for new sM, scale rS, save, inform WG1 + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready); + float new_rM[2], scale_factors[2]; + *(float2*)new_rM = plan.sM[idx_in_warpgroup/4]; + CUTE_UNROLL + for (int i = 0; i < 2; ++i) { + scale_factors[i] = exp2f(rM[i] - new_rM[i]); + rM[i] = new_rM[i]; + } + scale_rS(scale_factors); + save_rS_to_sS(rS, sS0, idx_in_warpgroup); + fence_view_async_shared(); + NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready); + + // Wait for sS1 + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready); + + // Rescale rO0, Issue rO0 += sS1 @ sV1L + rescale_rO(scale_factors); + gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + cur_bar_wait_phase ^= 1; + + if (block_idx+2 < num_topk_blocks) { + // Launch the next QK^T GEMM + pipelined_wait_and_qkt_gemm_l(); + + // Mark V1L as free + warpgroup_wait<1>(); + plan.bar_k1_free[0].arrive(); + pipelined_wait_and_qkt_gemm_r(); + + // Wait for rP0 = sQ @ sK0 + warpgroup_wait<0>(); + } else { + // Mark V1L as free + warpgroup_wait<0>(); + plan.bar_k1_free[0].arrive(); + } + } + + reduce_L(); + store_O(); + } else { + // Warpgroup 1 + + auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) { + plan.bar_k1_ready[1].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup1{}, 4, true); + qkt_gemm_one_tile(Warpgroup1{}, 5, false); + qkt_gemm_one_tile(Warpgroup1{}, 6, false); + qkt_gemm_one_tile(Warpgroup1{}, 7, false); 
+ if constexpr (D_QK == 576) { + qkt_gemm_one_tile(Warpgroup1{}, 8, false); + } + plan.bar_k1_ready[0].wait(cur_bar_wait_phase); + qkt_gemm_one_tile(Warpgroup1{}, 0, false); + qkt_gemm_one_tile(Warpgroup1{}, 1, false); + qkt_gemm_one_tile(Warpgroup1{}, 2, false); + qkt_gemm_one_tile(Warpgroup1{}, 3, false); + warpgroup_commit_batch(); + }; + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{}); + Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{}); + + // Issue rP1 = sQ @ sK1, and wait + pipelined_wait_and_qkt_gemm(); + warpgroup_wait<0>(); + + mask_rP(Warpgroup1{}); + + + // Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready) + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready); + online_softmax_and_rescale_o(Warpgroup1{}); + NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready); + + + // Issue rO1 += rS1 @ sV1R + gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R + save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster + NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready); + gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup); + warpgroup_commit_batch(); + + // Save rS1, inform WG0 + fence_view_async_shared(); + NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready); + + // Wait for GEMM, and inform that sV1R is free + warpgroup_wait<1>(); + plan.bar_k1_free[1].arrive(); + + // Wait for GEMM, and inform that sV0R is free + warpgroup_wait<0>(); + plan.bar_k0_free[1].arrive(); + + cur_bar_wait_phase ^= 1; + } + + reduce_L(); + store_O(); + + // Save lse + if (idx_in_warpgroup%4 == 0) { + for (int row = 0; row < 2; ++row) { + int real_row = get_AorC_row_idx(row, idx_in_warpgroup); + bool is_no_valid_tokens = 
rL[row] == 0.0f; + plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row]*CUDART_LN2_F; + plan.final_lse[real_row] = is_no_valid_tokens ? +INFINITY : logf(rL[row]) + rM[row]*CUDART_LN2_F; + } + fence_view_async_shared(); + } + + NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync); + if (idx_in_warpgroup == 0) { + int g_offset = s_q_idx*params.h_q + q_h_idx*B_H; + SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float)); + SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float)); + cute::tma_store_arrive(); + } + } + } else { + // Producer warpgroup + cutlass::arch::warpgroup_reg_dealloc<72>(); + + constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE; + constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS; + int idx_in_group = idx_in_warpgroup % GROUP_SIZE; + int group_idx = idx_in_warpgroup / GROUP_SIZE; + int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk] + + bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8)); + bf16* my_gKV_base = params.kv + idx_in_group*8; + + int64_t token_indices[2][NUM_ROWS_PER_GROUP]; + bool is_token_valid[2][NUM_ROWS_PER_GROUP]; + auto load_token_indices = [&](int block_idx) { + CUTE_UNROLL + for (int buf_idx = 0; buf_idx < 2; ++buf_idx) { + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { + int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx; + int t = __ldg(gIndices + offs); + token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster + bool is_cur_token_valid = t >= 0 && t < params.s_kv; + if constexpr (HAVE_TOPK_LENGTH) { + is_cur_token_valid &= offs < topk_length; + } + is_token_valid[buf_idx][local_row] = is_cur_token_valid; + } + } + }; + + int64_t cache_policy = createpolicy_evict_last(); + auto 
copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) { + // Copy some K/V tiles from global memory to shared memory + // A tile has a shape of 64 (B_TOPK) x 64 + // `buf_idx` is the index of the shared memory buffer, 0 or 1 + // `tile_idx` is the index of the tile to load, from 0 to D_K/64-1 = 8 + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) { + int64_t token_index = token_indices[buf_idx][local_row]; + CUTE_UNROLL + for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) { + cp_async_cacheglobal_l2_prefetch_256B( + my_gKV_base + token_index + tile_idx*64, + my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64), + is_token_valid[buf_idx][local_row], + cache_policy + ); + } + } + }; + + auto commit_to_mbar = [&](transac_bar_t &bar) { + cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar)); + }; + + int cur_bar_wait_phase = 1; + + CUTE_NO_UNROLL + for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) { + load_token_indices(block_idx); + + // V0L + plan.bar_k0_free[0].wait(cur_bar_wait_phase); + copy_tiles(block_idx+0, 0, 0, 4); + commit_to_mbar(plan.bar_k0_ready[0]); + + // V1R + plan.bar_k1_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 4, D_K/64); + commit_to_mbar(plan.bar_k1_ready[1]); + + // V0R + plan.bar_k0_free[1].wait(cur_bar_wait_phase); + copy_tiles(block_idx+0, 0, 4, D_K/64); + commit_to_mbar(plan.bar_k0_ready[1]); + + // V1L + plan.bar_k1_free[0].wait(cur_bar_wait_phase); + copy_tiles(block_idx+1, 1, 0, 4); + commit_to_mbar(plan.bar_k1_ready[0]); + + // Valid mask + // NOTE: V1R's finish implies maskings of the last round have finished + if (idx_in_group == 0) { + CUTE_UNROLL + for (int buf_idx = 0; buf_idx < 2; ++buf_idx) + CUTE_UNROLL + for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) + plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row]; + 
plan.bar_is_kv_valid_ready.arrive(); + } + + cur_bar_wait_phase ^= 1; + } + } + + +#else + if (cute::thread0()) { + CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90"); + } +#endif +} + +template +__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1) +sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TMAParams tma_params) { + Kernel::devfunc(params, tma_params); +} + +template +void KernelTemplate::run(const SparseAttnFwdParams ¶ms) { + KU_ASSERT(params.h_kv == 1); + KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings + KU_ASSERT(params.topk > 0); + KU_ASSERT(params.h_q % B_H == 0); + + auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((bf16*)params.q), + make_layout( + shape_Q, + make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q) + ) + ), + SmemLayoutQ{} + ); + + CUtensorMap tensor_map_O; + { + uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q}; + uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)}; + uint32_t box_size[3] = {64, B_H, 1}; + uint32_t elem_stride[3] = {1, 1, 1}; + CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tensor_map_O, + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + 3, + params.out, + size, + stride, + box_size, + elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + KU_ASSERT(res == CUresult::CUDA_SUCCESS); + } + + TmaParams< + decltype(shape_Q), decltype(tma_Q) + > tma_params = { + shape_Q, tma_Q, + tensor_map_O + }; + auto kernel = &sparse_attn_fwd_kernel, decltype(tma_params)>; + + constexpr size_t smem_size = sizeof(SharedMemoryPlan); + 
KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + cutlass::ClusterLaunchParams launch_params = { + dim3((params.h_q/B_H)*params.s_q, 1, 1), // NOTE: We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z) + dim3(NUM_THREADS, 1, 1), + dim3(1, 1, 1), + smem_size, + params.stream + }; + cutlass::launch_kernel_on_cluster( + launch_params, (void*)kernel, params, tma_params + ); + KU_CHECK_KERNEL_LAUNCH(); +} + +template +void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) { + KernelTemplate::run(params); +} + +} diff --git a/csrc/sm90/prefill/sparse/phase1.h b/csrc/sm90/prefill/sparse/phase1.h new file mode 100644 index 0000000..c315b2b --- /dev/null +++ b/csrc/sm90/prefill/sparse/phase1.h @@ -0,0 +1,10 @@ +#pragma once + +#include "../../../params.h" + +namespace sm90::fwd { + +template +void run_fwd_phase1_kernel(const SparseAttnFwdParams& params); + +} diff --git a/csrc/smxx/mla_combine.cu b/csrc/smxx/decode/combine/combine.cu similarity index 55% rename from csrc/smxx/mla_combine.cu rename to csrc/smxx/decode/combine/combine.cu index ff609bf..283f936 100644 --- a/csrc/smxx/mla_combine.cu +++ b/csrc/smxx/decode/combine/combine.cu @@ -1,70 +1,71 @@ -#include "mla_combine.h" +#include "combine.h" +#include #include #include #include #include +#include + #include "params.h" #include "utils.h" using namespace cute; +namespace smxx::decode { + template __global__ void __launch_bounds__(NUM_THREADS) -flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) { - // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M] +flash_fwd_mla_combine_kernel(__grid_constant__ const CombineParams params) { + // grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M] // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == 
block_size_m const int batch_idx = blockIdx.x; - const int m_block_idx = blockIdx.y; + const int s_q_idx = blockIdx.y; + const int h_block_idx = blockIdx.z; const int warp_idx = threadIdx.x / 32; const int lane_idx = threadIdx.x % 32; + int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx); + if (warp_idx >= num_valid_heads) { + return; + } + const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx); const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1); const int my_num_splits = end_split_idx - start_split_idx; - FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); if (my_num_splits == 1) { return; } - const int num_q_seqs = params.q_seq_per_hk * params.h_k; - const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M); + FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS); + Tensor gLseAccum = make_tensor( - make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M), + make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M), Shape, Int>{}, - make_stride(num_q_seqs, _1{}) + make_stride(params.stride_lse_accum_split, _1{}) ); Tensor gLse = make_tensor( - make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M), + make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M), Shape>{}, Stride<_1>{} ); - extern __shared__ float smem_buf[]; - Tensor sLseScale = make_tensor( - make_smem_ptr(smem_buf), - Shape, Int>{}, - Stride, _1>{} // +1 to avoid bank conflict - ); - + __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS]; + // Wait for the previous kernel (the MLA kernel) to finish cudaGridDependencySynchronize(); - - // Read gLseAccum into sLseScale - { - #pragma unroll 4 - for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; 
elem_idx += NUM_THREADS) { - int split_idx = elem_idx / BLOCK_SIZE_M; - int seq_idx = elem_idx % BLOCK_SIZE_M; - sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY; - } - __syncthreads(); - } - if (warp_idx >= num_cur_valid_q_seqs) - return; + // Prefetch + static_assert(HEAD_DIM_V % (32*4) == 0); + constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4); + float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q; + float4 datas[ELEMS_PER_THREAD]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ELEMS_PER_THREAD; ++i) { + datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL + } // Warp #i gathers LseAccum for seq #i { @@ -73,7 +74,7 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { const int split_idx = i*32 + lane_idx; - local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY; + local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY; } float max_lse = -INFINITY; @@ -93,14 +94,26 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) { for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); - float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse; + float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? 
INFINITY : log2f(sum_lse) + max_lse; if (lane_idx == 0) gLse(warp_idx) = global_lse / (float)M_LOG2E; - + + if (params.attn_sink != nullptr) { + int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; + float attn_sink = __ldg(params.attn_sink + q_head_idx); + if (global_lse != INFINITY) { + // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf) + // If attn_sink is -inf, this has no effect on global_lse + global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse)); + } else { + // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf) + global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F; + } + } CUTLASS_PRAGMA_UNROLL for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) { const int split_idx = i*32 + lane_idx; - if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse); + smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse); } } @@ -108,45 +121,42 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) { // Warp #i accumulates activation for seq #i { - const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V; - Tensor gOaccum = make_tensor( - make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape, Int>{}, - make_stride(num_q_seqs*HEAD_DIM_V, _1{}) - ); - - static_assert(HEAD_DIM_V % 32 == 0); - constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32; - float result[ELEMS_PER_THREAD]; + float4 result[ELEMS_PER_THREAD]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ELEMS_PER_THREAD; ++i) - result[i] = 0.0f; + result[i] = {0.0f, 0.0f, 0.0f, 0.0f}; - #pragma unroll 2 + #pragma unroll 1 for (int split = 0; split < my_num_splits; ++split) { - float lse_scale = sLseScale(warp_idx, split); - if (lse_scale != 0.f) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 
ELEMS_PER_THREAD; ++i) { - result[i] += lse_scale * gOaccum(split, lane_idx + i*32); + float lse_scale = smem_buf[warp_idx][split]; + // if (lse_scale != 0.f) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ELEMS_PER_THREAD; ++i) { + result[i].x += lse_scale * datas[i].x; + result[i].y += lse_scale * datas[i].y; + result[i].z += lse_scale * datas[i].z; + result[i].w += lse_scale * datas[i].w; + if (split != my_num_splits-1) { + datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128); } } + // } } - - cudaTriggerProgrammaticLaunchCompletion(); - const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx; - const int k_head_idx = q_seq_idx / params.q_seq_per_hk; - auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride; - Tensor gO = make_tensor( - make_gmem_ptr(o_ptr), - Shape>{}, - Stride<_1>{} - ); + const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx; + ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ELEMS_PER_THREAD; ++i) - gO(lane_idx+i*32) = (ElementT)result[i]; + for (int i = 0; i < ELEMS_PER_THREAD; ++i) { + float4 data = result[i]; + ElementT data_converted[4]; + data_converted[0] = (ElementT)(data.x); + data_converted[1] = (ElementT)(data.y); + data_converted[2] = (ElementT)(data.z); + data_converted[3] = (ElementT)(data.w); + static_assert(sizeof(ElementT) == 2); + *(uint64_t*)(o_ptr + lane_idx*4 + i*128) = *(uint64_t*)data_converted; + } } } @@ -175,7 +185,7 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) { template -void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream) { +void run_flash_mla_combine_kernel(CombineParams ¶ms) { static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA 
FLASH_ASSERT(params.d_v == HEAD_DIM_V); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] { @@ -189,20 +199,22 @@ void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream) { attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = 1; cudaLaunchConfig_t combine_kernel_config = { - dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1), + dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)), dim3(NUM_THREADS, 1, 1), - smem_size, - stream, + 0, + params.stream, attribute, 1 }; - cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params); + CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params)); }); CHECK_CUDA_KERNEL_LAUNCH(); } -template void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream); +template void run_flash_mla_combine_kernel(CombineParams ¶ms); #ifndef FLASH_MLA_DISABLE_FP16 -template void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream); -#endif \ No newline at end of file +template void run_flash_mla_combine_kernel(CombineParams ¶ms); +#endif + +} diff --git a/csrc/smxx/decode/combine/combine.h b/csrc/smxx/decode/combine/combine.h new file mode 100644 index 0000000..0ea21fd --- /dev/null +++ b/csrc/smxx/decode/combine/combine.h @@ -0,0 +1,10 @@ +#pragma once + +#include "params.h" + +namespace smxx::decode { + +template +void run_flash_mla_combine_kernel(CombineParams ¶ms); + +} diff --git a/csrc/smxx/get_mla_metadata.cu b/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu similarity index 55% rename from csrc/smxx/get_mla_metadata.cu rename to csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu index 9b5be62..083da60 100644 --- a/csrc/smxx/get_mla_metadata.cu +++ b/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu @@ -1,16 +1,19 @@ -#include "get_mla_metadata.h" +#include "get_decoding_sched_meta.h" 
#include #include +#include #include "utils.h" +namespace smxx::decode { + __global__ void __launch_bounds__(32, 1, 1) -get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params) { +get_mla_metadata_kernel(__grid_constant__ const GetDecodeSchedMetaParams params) { int *seqlens_k_ptr = params.seqlens_k_ptr; - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + DecodingSchedMeta *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; int *num_splits_ptr = params.num_splits_ptr; - int batch_size = params.batch_size; + int batch_size = params.b; int block_size_n = params.block_size_n; int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; int num_sm_parts = params.num_sm_parts; @@ -24,14 +27,25 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params int total_num_blocks = 0; for (int i = threadIdx.x; i < batch_size; i += 32) { - int cur_s_k = params.topk == -1 ? __ldg(seqlens_k_ptr + i) : params.topk; + int cur_s_k; + if (params.topk == -1) { + // Dense model, cur_s_k = actual s_k + cur_s_k = __ldg(seqlens_k_ptr + i); + } else { + // Sparse model, cur_s_k = topk (+ extra topk) + cur_s_k = params.topk_length ? __ldg(params.topk_length + i) : params.topk; + if (cur_s_k == 0) cur_s_k = 1; // Ensure the main loop will never be empty + if (params.extra_topk) { + cur_s_k = ku::ceil(cur_s_k, block_size_n); + cur_s_k += params.extra_topk_length ? __ldg(params.extra_topk_length + i) : params.extra_topk; + } + } seqlens_k_shared[i] = cur_s_k; int first_token_idx = 0; int last_token_idx = max(cur_s_k-1, 0); int cur_first_block_idx = first_token_idx / block_size_n; int cur_last_block_idx = last_token_idx / block_size_n; // NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx] - // NOTE Before clamping, first_token_idx <= last_token_idx always holds, so after clamping, first_token_idx <= last_token_idx still holds. 
// NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel. int num_blocks = cur_last_block_idx - cur_first_block_idx + 1; total_num_blocks += num_blocks + fixed_overhead_num_blocks; @@ -47,22 +61,23 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params if (threadIdx.x == 0) { int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; - int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + int now_req_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; num_splits_shared[0] = 0; for (int i = 0; i < num_sm_parts; ++i) { - int tile_scheduler_metadata0[4], tile_scheduler_metadata1; - tile_scheduler_metadata0[0] = now_idx; - tile_scheduler_metadata0[1] = now_block + first_block_idx_shared[now_idx]; - tile_scheduler_metadata1 = now_n_split_idx; + DecodingSchedMeta cur_meta; + cur_meta.begin_req_idx = now_req_idx; + cur_meta.begin_block_idx = now_block + first_block_idx_shared[now_req_idx]; + cur_meta.begin_split_idx = now_n_split_idx; + cur_meta.is_first_req_splitted = (now_block != 0); int remain_payload = payload; - while (now_idx < batch_size) { - int num_blocks = num_blocks_shared[now_idx]; + while (now_req_idx < batch_size) { + int num_blocks = num_blocks_shared[now_req_idx]; int now_remain_blocks = num_blocks - now_block; if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { cum_num_splits += now_n_split_idx + 1; - num_splits_shared[now_idx + 1] = cum_num_splits; + num_splits_shared[now_req_idx + 1] = cum_num_splits; remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; - ++now_idx; + ++now_req_idx; now_block = 0; now_n_split_idx = 0; } else { @@ -74,12 +89,15 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params break; } } - tile_scheduler_metadata0[2] = now_block > 0 ? 
now_idx : now_idx - 1; - tile_scheduler_metadata0[3] = now_block > 0 ? now_block + first_block_idx_shared[now_idx] : (seqlens_k_shared[now_idx-1] == 0 ? 0 : last_block_idx_shared[now_idx-1] + 1); - *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); - tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + cur_meta.end_req_idx = now_block > 0 ? now_req_idx : now_req_idx - 1; + cur_meta.end_block_idx = now_block > 0 ? now_block + first_block_idx_shared[now_req_idx] : (seqlens_k_shared[now_req_idx-1] == 0 ? 0 : last_block_idx_shared[now_req_idx-1] + 1); + cur_meta.is_last_req_splitted = cur_meta.end_block_idx != last_block_idx_shared[cur_meta.end_req_idx] + 1 && seqlens_k_shared[cur_meta.end_req_idx] != 0; + if (cur_meta.begin_req_idx == cur_meta.end_req_idx) { + cur_meta.is_first_req_splitted = cur_meta.is_last_req_splitted = cur_meta.is_first_req_splitted || cur_meta.is_last_req_splitted; + } + tile_scheduler_metadata_ptr[i] = cur_meta; } - FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0); } __syncwarp(); @@ -88,9 +106,11 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params } } -void run_get_mla_metadata_kernel(GetDecodingMetadataParams ¶ms, cudaStream_t stream) { - int smem_size = sizeof(int) * (params.batch_size*5+1); +void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams ¶ms) { + int smem_size = sizeof(int) * (params.b*5+1); CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params); + get_mla_metadata_kernel<<<1, 32, smem_size, params.stream>>>(params); CHECK_CUDA_KERNEL_LAUNCH(); } + +} diff --git a/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h 
b/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h new file mode 100644 index 0000000..0b1c288 --- /dev/null +++ b/csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h @@ -0,0 +1,9 @@ +#pragma once + +#include "params.h" + +namespace smxx::decode { + +void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams ¶ms); + +} diff --git a/csrc/smxx/get_mla_metadata.h b/csrc/smxx/get_mla_metadata.h deleted file mode 100644 index 7a1d1c4..0000000 --- a/csrc/smxx/get_mla_metadata.h +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "params.h" - -void run_get_mla_metadata_kernel(GetDecodingMetadataParams ¶ms, cudaStream_t stream); diff --git a/csrc/smxx/mla_combine.h b/csrc/smxx/mla_combine.h deleted file mode 100644 index eca7501..0000000 --- a/csrc/smxx/mla_combine.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#include "params.h" - -template -void run_flash_mla_combine_kernel(DecodingParams ¶ms, cudaStream_t stream); diff --git a/csrc/utils.h b/csrc/utils.h index 571412f..8de676f 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -1,5 +1,7 @@ #pragma once +#include + #define CHECK_CUDA(call) \ do { \ cudaError_t status_ = call; \ @@ -44,23 +46,37 @@ do { \ } while (0) #endif -// For development, we define both IS_SM100 and IS_SM90 when using CLion or VSCode IDEs so code highlighting will be correct. -#if defined(__CLION_IDE__) || defined(__VSCODE_IDE__) -#define IS_SM100 1 -#define IS_SM90 1 -#else - -// We define the following macros to detect the CUDA architecture, so that we can enable/disable certains kernels that depends on specific architectures. 
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000) -#define IS_SM100 1 -#else -#define IS_SM100 0 +#ifndef TRAP_ONLY_DEVICE_ASSERT +#define TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) #endif -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900) -#define IS_SM90 1 -#else -#define IS_SM90 0 -#endif -#endif // defined(__CLION_IDE__) || defined(__VSCODE_IDE__) \ No newline at end of file +struct RingBufferState { + uint32_t cur_block_idx = 0u; + + __device__ __forceinline__ + void update() { + cur_block_idx += 1; + } + + template + __device__ __forceinline__ + std::pair get() const { + uint32_t stage_idx = cur_block_idx % NUM_STAGES; + bool phase = (cur_block_idx / NUM_STAGES) & 1; + return {stage_idx, phase}; + } + + __device__ __forceinline__ + RingBufferState offset_by(const int offset) const { + // Must guarantee no underflow + uint32_t new_block_idx = static_cast(static_cast(cur_block_idx) + offset); + RingBufferState new_state; + new_state.cur_block_idx = new_block_idx; + return new_state; + } +}; diff --git a/flash_mla/__init__.py b/flash_mla/__init__.py index 66f1986..02a8bba 100644 --- a/flash_mla/__init__.py +++ b/flash_mla/__init__.py @@ -8,3 +8,12 @@ flash_attn_varlen_kvpacked_func, flash_mla_sparse_fwd ) + +__all__ = [ + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd" +] diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 4d27621..4fac685 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -1,82 +1,183 @@ from typing import Optional, Tuple +import dataclasses import torch import flash_mla.cuda as flash_mla_cuda +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + 
page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + def get_mla_metadata( - cache_seqlens: torch.Tensor, - num_q_tokens_per_head_k: int, - num_heads_k: int, - num_heads_q: Optional[int] = None, - is_fp8_kvcache: bool = False, - topk: Optional[int] = None -) -> Tuple[torch.Tensor, torch.Tensor]: + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + Arguments: - cache_seqlens: (batch_size), dtype torch.int32. - num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. - num_heads_k: The number of k heads. - num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. - topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. - Returns: - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32. + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. 
""" - return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk) + return FlashMLASchedMeta(), None def flash_mla_with_kvcache( q: torch.Tensor, k_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], head_dim_v: int, - tile_scheduler_metadata: torch.Tensor, - num_splits: torch.Tensor, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, softmax_scale: Optional[float] = None, causal: bool = False, is_fp8_kvcache: bool = False, indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: q: (batch_size, seq_len_q, num_heads_q, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). - block_table: (batch_size, max_num_blocks_per_seq), torch.int32. - cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head dimension of v. - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. - num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. - softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). - causal: bool. Whether to apply causal attention mask. - is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md - indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. 
For details about how to set up `indices`, please refer to README.md. - - Returns: + Different modes (including fp8/bf16, sparsity, and model version (i.e. V3.2 or MODEL1)) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + Besides, some kernels also have their own requirements on the layout of k cache, including: + - For sparse fp8 decoding kernel on F3, k_cache.stride(0) must be a multiple of 656B (for V32) or 576B (for MODEL1). Padding is needed sometimes. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. 
+ attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. This is used to support MODEL1. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + For DeepSeek MODEL1: + head_dim should be 512 while head_dim_v is also 512. + + In FP8+sparse mode, every block can be divided into two parts. The first parts stores NoPE0, RoPE0, NoPE1, RoPE1, ... while the second part stores scale factors: 7xue8m0, 1Bpad, 7xue8m0, 1Bpad, ... + + Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 
""" + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - if indices is not None: - assert causal == False, "causal must be `false` if sparse attention is enabled." - out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( - q, - k_cache, - head_dim_v, - cache_seqlens, - block_table, - softmax_scale, - causal, - tile_scheduler_metadata, - num_splits, - is_fp8_kvcache, - indices - ) - return out, softmax_lse + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." 
+ helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." 
+ helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) def flash_mla_sparse_fwd( @@ -85,6 +186,8 @@ def flash_mla_sparse_fwd( indices: torch.Tensor, sm_scale: float, d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Sparse attention prefill kernel @@ -95,16 +198,22 @@ def flash_mla_sparse_fwd( indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv sm_scale: float d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. 
+ If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. Returns: (output, max_logits, lse) - About the definition of output, max_logits and lse, please refer to README.md + Please refer to tests/ref.py for the precise definitions of these parameters. - output: [s_q, h_q, d_v], bfloat16 - max_logits: [s_q, h_q], float - - lse: [s_q, h_q], float, 2-based log-sum-exp + - lse: [s_q, h_q], float, log-sum-exp of attention scores """ results = flash_mla_cuda.sparse_prefill_fwd( - q, kv, indices, sm_scale, d_v + q, kv, indices, sm_scale, d_v, attn_sink, topk_length ) return results diff --git a/setup.py b/setup.py index 15fa671..513b435 100644 --- a/setup.py +++ b/setup.py @@ -36,11 +36,11 @@ def get_arch_flags(): DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90") if major < 12 or (major == 12 and minor <= 8): - assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." + assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." 
# TODO Implement this arch_flags = [] if not DISABLE_SM100: - arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) + arch_flags.extend(["-gencode", "arch=compute_100f,code=sm_100f"]) if not DISABLE_SM90: arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"]) return arch_flags @@ -54,31 +54,60 @@ def get_nvcc_thread_args(): this_dir = os.path.dirname(os.path.abspath(__file__)) if IS_WINDOWS: - cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"] + cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"] else: - cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"] + cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"] ext_modules = [] ext_modules.append( CUDAExtension( name="flash_mla.cuda", sources=[ - "csrc/pybind.cpp", - "csrc/smxx/get_mla_metadata.cu", - "csrc/smxx/mla_combine.cu", - "csrc/sm90/decode/dense/splitkv_mla.cu", - "csrc/sm90/decode/sparse_fp8/splitkv_mla.cu", + # API + "csrc/api/api.cpp", + + # Misc kernels for decoding + "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu", + "csrc/smxx/decode/combine/combine.cu", + + # sm90 dense decode + "csrc/sm90/decode/dense/instantiations/fp16.cu", + "csrc/sm90/decode/dense/instantiations/bf16.cu", + + # sm90 sparse decode + "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu", + "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu", + "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu", + "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu", + + # sm90 sparse prefill "csrc/sm90/prefill/sparse/fwd.cu", - "csrc/sm100/decode/sparse_fp8/splitkv_mla.cu", + "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu", + "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu", + "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu", + "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu", + + # sm100 dense prefill & backward 
"csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu", "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu", - "csrc/sm100/prefill/sparse/fwd.cu", + + # sm100 sparse prefill + "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu", + "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu", + "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu", + "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu", + "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu", + + # sm100 sparse decode + "csrc/sm100/decode/head64/instantiations/v32.cu", + "csrc/sm100/decode/head64/instantiations/model1.cu", + "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), "nvcc": [ "-O3", - "-std=c++17", + "-std=c++20", "-DNDEBUG", "-D_USE_MATH_DEFINES", "-Wno-deprecated-declarations", @@ -89,11 +118,14 @@ def get_nvcc_thread_args(): "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", - "--ptxas-options=-v,--register-usage-level=10" + "--ptxas-options=-v,--register-usage-level=10,--warn-on-spills,--warn-on-local-memory-usage,--warn-on-double-precision-use", + "-lineinfo", + "--source-in-ptx", ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(), }, include_dirs=[ Path(this_dir) / "csrc", + Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me Path(this_dir) / "csrc" / "sm90", Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", diff --git a/tests/kernelkit/.gitignore b/tests/kernelkit/.gitignore new file mode 100644 index 0000000..42e7a8a --- /dev/null +++ b/tests/kernelkit/.gitignore @@ -0,0 +1,9 @@ +build +*.so +*.egg-info/ +__pycache__/ +dist/ +/.vscode +.cache +/temp +/profiles diff --git a/tests/kernelkit/__init__.py b/tests/kernelkit/__init__.py new file mode 100644 index 
0000000..378c42f --- /dev/null +++ b/tests/kernelkit/__init__.py @@ -0,0 +1,11 @@ +from . import bench +from . import compare +from . import generate +from . import precision +from . import utils + +from .bench import bench_kineto, bench_by_cuda_events +from .compare import get_cos_diff, check_is_bitwise_equal, check_is_allclose, check_is_bitwise_equal_comparator, check_is_allclose_comparator +from .generate import gen_non_contiguous_randn_tensor, gen_non_contiguous_tensor, non_contiguousify +from .precision import LowPrecisionMode, is_low_precision_mode, optional_cast_to_bf16_and_cast_back +from .utils import colors, cdiv, is_using_profiling_tools, set_random_seed, Counter diff --git a/tests/kernelkit/bench.py b/tests/kernelkit/bench.py new file mode 100644 index 0000000..7b0b659 --- /dev/null +++ b/tests/kernelkit/bench.py @@ -0,0 +1,205 @@ +from typing import Tuple, List, Callable, Union, Dict, overload +import dataclasses + +import torch +import triton + +from .utils import is_using_profiling_tools + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + +@triton.jit +def profiler_range_start_marker_kernel(): + pass + +def _run_profiler_range_start_marker_kernel(): + profiler_range_start_marker_kernel[(1,)]() + +@dataclasses.dataclass +class BenchKinetoRawResult: + """ + A struct holding the result of `bench_kineto` + """ + + is_using_nsys: bool + num_tests: int + time_ranges: Dict[str, List[Tuple[float, float]]] + + def _get_matched_kernel_name(self, name_substr: str, allow_no_match: bool = False, allow_multiple_match: bool = False) -> List[str]: + matched_names = [name for name in self.time_ranges.keys() if name_substr in name] + if not allow_no_match and len(matched_names) == 0: + all_kernel_names_str = '\n - ' + '\n - '.join(self.time_ranges.keys()) + raise ValueError(f"Error: No kernel name matched for substring {name_substr}.\nAvailable kernels are: {all_kernel_names_str}") + if not allow_multiple_match and 
len(matched_names) > 1: + raise ValueError(f"Error: Multiple kernel matched for substring {name_substr}: {', '.join(matched_names)}") + return matched_names + + def get_kernel_names(self) -> List[str]: + return list(self.time_ranges.keys()) + + def get_kernel_times(self, kernel_names_substr: List[str], allow_indivisible_run_count: bool = False, allow_missing: bool = False, allow_multiple_match: bool = False, return_avg_individual_run: bool = False) -> List[float]: + """ + Get the average each-run time usage of each kernel provided in `kernel_names` + + If return_avg_individual_run is False, return sum(time) / num_tests, else return sum(time) / len(time) + If is_using_profiling_tools (which is conflict with bench_kineto), return a series of 1 seconds + """ + if is_using_profiling_tools(): + return [1 for _ in range(len(kernel_names_substr))] + + result = [] + for substr in kernel_names_substr: + matched_names = self._get_matched_kernel_name(substr, allow_no_match=allow_missing, allow_multiple_match=allow_multiple_match) + if len(matched_names) == 0: + assert allow_missing + result.append(0) + else: + time_usage_sum = 0 + run_cnt_sum = 0 + for matched_name in matched_names: + run_cnt = len(self.time_ranges[matched_name]) + if not allow_indivisible_run_count and run_cnt % self.num_tests != 0: + raise RuntimeError(f"Error: the number of runs for kernel {matched_name} ({run_cnt}) is indivisible by `num_tests` ({self.num_tests})") + time_usage_sum += sum([end-start for (start, end) in self.time_ranges[matched_name]]) + run_cnt_sum += run_cnt + denominator = run_cnt_sum if return_avg_individual_run else self.num_tests + result.append(time_usage_sum / denominator) + return result + + def get_kernel_time(self, kernel_name_substr: str) -> float: + return self.get_kernel_times([kernel_name_substr])[0] + + def get_e2e_time(self, start_kernel_name_substr: str, end_kenrel_name_substr: str) -> float: + """ + Get the end-to-end time usage for a sequence of kernels + defined as 
"last kernel end time" - "first kernel start time" + If is_using_profiling_tools (which is conflict with bench_kineto), return 1 second + """ + if is_using_profiling_tools(): + return 1 + + start_kernel_name = self._get_matched_kernel_name(start_kernel_name_substr)[0] + end_kernel_name = self._get_matched_kernel_name(end_kenrel_name_substr)[0] + num_start_kernels = len(self.time_ranges[start_kernel_name]) + num_end_kernels = len(self.time_ranges[end_kernel_name]) + if num_start_kernels%self.num_tests != 0: + raise RuntimeError(f"Error: the number of runs for kernel {start_kernel_name} ({num_start_kernels}) is indivisible by `num_tests` ({self.num_tests})") + if num_end_kernels%self.num_tests != 0: + raise RuntimeError(f"Error: the number of runs for kernel {end_kernel_name} ({num_end_kernels}) is indivisible by `num_tests` ({self.num_tests})") + time_spans = [] + for i in range(self.num_tests): + end_time = self.time_ranges[end_kernel_name][(i+1)*(num_end_kernels//self.num_tests)-1][1] + start_time = self.time_ranges[start_kernel_name][i*(num_start_kernels//self.num_tests)][0] + time_spans.append((start_time, end_time)) + result = sum([end-start for (start, end) in time_spans]) / self.num_tests + return result + + +def bench_kineto(fn: Callable, num_tests: int = 30, + flush_l2: bool = True) -> BenchKinetoRawResult: + """ + Run `fn` for `num_tests` times under `bench_kineto` (CUPTI), and returns a BenchKinetoRawResult + """ + using_nsys = is_using_profiling_tools() + + # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() + with profiler: + for i in range(2): + if i == 1 and not using_nsys: + 
_run_profiler_range_start_marker_kernel() # This marks the start of the profiling range + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + enable_nvtx_range = i == 1 and _ == num_tests-1 + if enable_nvtx_range: + torch.cuda.nvtx.range_push("profile_target") + fn() + if enable_nvtx_range: + torch.cuda.nvtx.range_pop() + if not using_nsys: + if i == 0: + torch.cuda.synchronize() + profiler.step() + + if using_nsys: + return BenchKinetoRawResult(True, num_tests, {}) + + from torch.autograd.profiler_util import EventList, FunctionEvent # pylint: disable=import-outside-toplevel + events: EventList = profiler.events() # type: ignore + + # Filter out all events that are not function events + events: List[FunctionEvent] = [event for event in events if isinstance(event, FunctionEvent)] + + # Filter out all events before the range marker + for idx, event in enumerate(events): + if event.name == "profiler_range_start_marker_kernel": + events = events[idx+1:] + break + else: + raise RuntimeError("Could not find profiler range start marker kernel event") + + # Get time ranges of each kernel + kernel_times = {} + for event in events: + kernel_name = event.name + if kernel_name not in kernel_times: + kernel_times[kernel_name] = [] + kernel_times[kernel_name].append((event.time_range.start/1e6, event.time_range.end/1e6)) + + return BenchKinetoRawResult(False, num_tests, kernel_times) + +@overload +def bench_by_cuda_events(kernels: List[Callable], num_warmups_each: int, num_runs_each: int) -> List[float]: ... + +@overload +def bench_by_cuda_events(kernels: Callable, num_warmups_each: int, num_runs_each: int) -> float: ... 
+ +def bench_by_cuda_events(kernels: Union[List[Callable], Callable], num_warmups_each: int, num_runs_each: int) -> Union[List[float], float]: + buf_for_l2_clear = torch.empty(int(256e6//4), dtype=torch.int32, device='cuda') + + is_kernel_single_callable = isinstance(kernels, Callable) + if is_kernel_single_callable: + kernels = [kernels] + + torch.cuda.synchronize() + for i in range(num_warmups_each): + for kernel in kernels: + kernel() + if i == 0: + # Ensure the first run is successful + try: + torch.cuda.synchronize() + except Exception as e: + print(f"Kernel {kernel.__name__} failed on warmup run {i}: {e}") + return [] + + start_events = [[torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] for _ in kernels] + end_events = [[torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] for _ in kernels] + for i in range(num_runs_each): + for j, kernel in enumerate(kernels): + buf_for_l2_clear.random_() + if i == num_runs_each-1: + torch.cuda.nvtx.range_push("profile_target") + start_events[j][i].record() + kernel() + end_events[j][i].record() + if i == num_runs_each-1: + torch.cuda.nvtx.range_pop() + + torch.cuda.synchronize() + time_usages = [ + sum([start_events[j][i].elapsed_time(end_events[j][i])*1e-3 for i in range(num_runs_each)]) / num_runs_each + for j in range(len(kernels)) + ] + if is_kernel_single_callable: + time_usages = time_usages[0] + return time_usages diff --git a/tests/kernelkit/compare.py b/tests/kernelkit/compare.py new file mode 100644 index 0000000..0ffde24 --- /dev/null +++ b/tests/kernelkit/compare.py @@ -0,0 +1,95 @@ +from typing import List + +import torch + +def check_is_bitwise_equal_comparator(ans: torch.Tensor, ref: torch.Tensor, result: torch.Tensor): + """ + Check whether two tensors are bitwise equal + The outcome is written into `result`; nothing is returned + """ + assert ans.shape == ref.shape, "Shape mismatch" + torch.all(torch.eq(ans, ref), out=result) + +def check_is_bitwise_equal(name: str, ans: 
torch.Tensor, ref: torch.Tensor, quiet: bool = False) -> bool: + is_bitwise_equal = torch.equal(ans, ref) + if not quiet and not is_bitwise_equal: + print(f"`{name}` mismatch: not bitwise equal. Mismatch count: {(ans != ref).sum().item()} out of {ans.numel()}") + return is_bitwise_equal + +def get_cos_diff(ans: torch.Tensor, ref: torch.Tensor) -> float: + """ + Calculate the cosine diff between two tensors + Return a Python number: 0 when sum(ref*ref) is below 1e-12, otherwise 1 - 2*sum(ans*ref) / sum(ans*ans + ref*ref) + """ + ans, ref = ans.double(), ref.double() + if (ref*ref).sum().item() < 1e-12: + return 0 + denominator = (ans*ans + ref*ref).sum().item() + sim = 2 * (ans*ref).sum().item() / denominator + return 1 - sim + +def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7, quiet: bool = False) -> bool: + """ + Check if two tensors are close enough + Return a bool; mismatch details are printed unless `quiet` is set + """ + assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}" + assert ans.dtype == ref.dtype, f"`{name}` Dtype mismatch: {ans.dtype} vs {ref.dtype}" + + ans = ans.clone().to(torch.float) + ref = ref.clone().to(torch.float) + + def report_err(*args, **kwargs): + if not quiet: + print(*args, **kwargs) + + # Deal with anomalies + def deal_with_anomalies(val: float): + ref_mask = (ref == val) if (val == val) else (ref != ref) + ans_mask = (ans == val) if (val == val) else (ans != ans) + ref[ref_mask] = 0.0 + ans[ans_mask] = 0.0 + if not torch.equal(ref_mask, ans_mask): + report_err(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref") + return False + return True + + anomalies_check_passed = True + anomalies_check_passed &= deal_with_anomalies(float("inf")) + anomalies_check_passed &= deal_with_anomalies(float("-inf")) + anomalies_check_passed &= deal_with_anomalies(float("nan")) + + cos_diff = get_cos_diff(ans, ref) + raw_abs_err = 
torch.abs(ans-ref) + raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6)) + rel_err = raw_rel_err.masked_fill(raw_abs_err List[int]: + result = [] + for size in t.shape[::-1]: + result.append(pos % size) + pos = pos // size + assert pos == 0 + return result[::-1] + report_err(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}") + report_err(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}") + report_err(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)") + report_err(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})") + return False + else: + if abs(cos_diff) > cos_diff_tol: + report_err(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})") + return False + return True + +def check_is_allclose_comparator(name: str, ans: torch.Tensor, ref: torch.Tensor, out: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7): + out.fill_(check_is_allclose(name, ans, ref, abs_tol, rel_tol, cos_diff_tol)) diff --git a/tests/kernelkit/generate.py b/tests/kernelkit/generate.py new file mode 100644 index 0000000..8ee4dcf --- /dev/null +++ b/tests/kernelkit/generate.py @@ -0,0 +1,25 @@ +import torch + +def _get_new_non_contiguous_tensor_shape(shape): + """ + Get the expanded shape for a non-contiguous tensor. 
+ The last dimension is increased by 128 (for alignment), and all other dimensions are increased by 1 + """ + return [dim+128 if dim_idx == len(shape)-1 else dim+1 for dim_idx, dim in enumerate(shape)] + +def gen_non_contiguous_randn_tensor(shape, *args, **kwargs): + new_shape = _get_new_non_contiguous_tensor_shape(shape) + base_tensor = torch.randn(new_shape, *args, **kwargs) + slices = [slice(0, dim) for dim in shape] + return base_tensor[slices] + +def gen_non_contiguous_tensor(shape, *args, **kwargs): + new_shape = _get_new_non_contiguous_tensor_shape(shape) + base_tensor = torch.empty(new_shape, *args, **kwargs) + slices = [slice(0, dim) for dim in shape] + return base_tensor[slices] + +def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor: + new_tensor = gen_non_contiguous_tensor(tensor.shape, dtype=tensor.dtype, device=tensor.device) + new_tensor[:] = tensor + return new_tensor diff --git a/tests/kernelkit/precision.py b/tests/kernelkit/precision.py new file mode 100644 index 0000000..5da06e4 --- /dev/null +++ b/tests/kernelkit/precision.py @@ -0,0 +1,30 @@ +import torch + +_is_low_precision_mode_stack = [] + +class LowPrecisionMode: + def __init__(self, enabled: bool = True): + self.enabled = enabled + + def __enter__(self): + global _is_low_precision_mode_stack + _is_low_precision_mode_stack.append(self.enabled) + + def __exit__(self, exc_type, exc_value, traceback): + global _is_low_precision_mode_stack + _is_low_precision_mode_stack.pop() + +def is_low_precision_mode() -> bool: + global _is_low_precision_mode_stack + if len(_is_low_precision_mode_stack) == 0: + return False + return _is_low_precision_mode_stack[-1] + +def optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.Tensor: + assert tensor.dtype == torch.float32, "Input tensor must be of dtype torch.float32 for optional casting." 
+ if is_low_precision_mode(): + tensor_bf16 = tensor.to(torch.bfloat16) + tensor_fp32 = tensor_bf16.to(torch.float32) + return tensor_fp32 + else: + return tensor diff --git a/tests/kernelkit/utils.py b/tests/kernelkit/utils.py new file mode 100644 index 0000000..84a5927 --- /dev/null +++ b/tests/kernelkit/utils.py @@ -0,0 +1,50 @@ +import os +import functools + +colors = { + 'RED_FG': '\033[31m', + 'GREEN_FG': '\033[32m', + 'CYAN_FG': '\033[36m', + 'GRAY_FG': '\033[90m', + 'YELLOW_FG': '\033[33m', + 'RED_BG': '\033[41m', + 'GREEN_BG': '\033[42m', + 'CYAN_BG': '\033[46m', + 'YELLOW_BG': '\033[43m', + 'GRAY_BG': '\033[100m', + 'CLEAR': '\033[0m' +} + +def cdiv(a: int, b: int) -> int: + return (a + b - 1) // b + +@functools.lru_cache() +def is_using_profiling_tools() -> bool: + """ + Return whether we are running under profiling tools like nsys or ncu + + NOTE cuda-gdb will also cause conflict with CUPTI (bench_kineto) but currently we lack ways to detect it + """ + is_using_nsys = os.environ.get('NSYS_PROFILING_SESSION_ID') is not None + is_using_ncu = os.environ.get('NV_COMPUTE_PROFILER_PERFWORKS_DIR') is not None + is_using_compute_sanitizer = os.environ.get('NV_SANITIZER_INJECTION_PORT_RANGE_BEGIN') is not None + return is_using_nsys or is_using_ncu or is_using_compute_sanitizer + +def set_random_seed(seed: int): + import random + import numpy as np + import torch + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + +class Counter: + def __init__(self): + self.count = 0 + + def next(self) -> int: + self.count += 1 + return self.count - 1 diff --git a/tests/lib.py b/tests/lib.py index f884721..139e130 100644 --- a/tests/lib.py +++ b/tests/lib.py @@ -1,73 +1,405 @@ -from typing import List +import dataclasses +import os +import enum +from typing import List, Optional +import random import torch +import kernelkit as kk +import flash_mla -def cdiv(x: int, y: int): - return (x+y-1) 
// y +import quant -def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7): +class TestTarget(enum.Enum): + FWD = 0 + DECODE = 1 + +@dataclasses.dataclass +class ExtraTestParamForDecode: + b: int + is_varlen: bool + have_zero_seqlen_k: bool + extra_s_k: Optional[int] = None + extra_topk: Optional[int] = None + block_size: int = 64 + extra_block_size: Optional[int] = None + have_extra_topk_length: bool = False + +@dataclasses.dataclass +class TestParam: + s_q: int + s_kv: int + topk: int + h_q: int = 128 + h_kv: int = 1 + d_qk: int = 512 + d_v: int = 512 + seed: int = -1 # -1: to be filled automatically + check_correctness: bool = True + is_all_indices_invalid: bool = False # Replace (some or all) indices with invalid, out-of-range values (e.g., -1 or 2147483647) + num_runs: int = 10 + have_attn_sink: bool = False + have_topk_length: bool = False + decode: Optional[ExtraTestParamForDecode] = None + +@dataclasses.dataclass +class RawTestParamForDecode: + """ + "Flattened" test parameters for decoding test + + In our test script, to maintain compatibility with TestParam, we embed decode-only parameters into TestParam.decode, which is not very convenient when constructing testcases. So here we have a "flattened" version of test parameters for decoding test. 
""" - Check if two tensors are close enough + b: int + h_q: int + s_q: int + h_kv: int + s_kv: int + is_varlen: bool + topk: int + is_all_indices_invalid: bool = False + have_zero_seqlen_k: bool = False + have_topk_length: bool = False + enable_attn_sink: bool = True + extra_s_k: Optional[int] = None + extra_topk: Optional[int] = None + block_size: int = 64 + extra_block_size: Optional[int] = None + have_extra_topk_length: bool = False + d_qk: int = 576 # Q/K head dim (= dv + RoPE dim) + d_v: int = 512 # V head dim + check_correctness: bool = True + num_runs: int = 10 + seed: int = -1 + + def to_test_param(self) -> TestParam: + return TestParam( + self.s_q, self.s_kv, self.topk, self.h_q, self.h_kv, self.d_qk, self.d_v, + self.seed, self.check_correctness, + self.is_all_indices_invalid, + self.num_runs, + self.enable_attn_sink, + self.have_topk_length, + decode = ExtraTestParamForDecode( + self.b, self.is_varlen, self.have_zero_seqlen_k, + self.extra_s_k, self.extra_topk, + self.block_size, self.extra_block_size, self.have_extra_topk_length + ) + ) + +@dataclasses.dataclass +class Testcase: + p: TestParam + dOut: torch.Tensor # [s_q, h_q, d_v] + q: torch.Tensor # [s_q, h_q, d_qk] + kv: torch.Tensor # [s_kv, h_kv, d_qk] + indices: torch.Tensor # [s_q, h_kv, topk] + sm_scale: float + attn_sink: Optional[torch.Tensor] # [h_q] + topk_length: Optional[torch.Tensor] # [s_q] + +def _randperm_batch(batch_size: int, perm_range: torch.Tensor, perm_size: int, paddings: List[int]) -> torch.Tensor: """ - def get_cos_diff(x: torch.Tensor, y: torch.Tensor) -> float: + Generate random permutations in batch + The return tensor, denoted as `res`, has a shape of [batch_size, perm_size]. `0 <= res[i, :] < perm_range[i]` holds. + Values within each row are unique. + If, for some `i`, `perm_range[i] < perm_size` holds, then `res[i, :]` contains values in `[0, perm_range[i])` as many as possible, and the rest are filled with `padding`. 
+ """ + assert not torch.are_deterministic_algorithms_enabled() + torch.use_deterministic_algorithms(True) + perm_range_max = max(int(torch.max(perm_range).item()), perm_size) + rand = torch.rand(batch_size, perm_range_max, dtype=torch.float32) + rand[torch.arange(0, perm_range_max).broadcast_to(batch_size, perm_range_max) >= perm_range.view(batch_size, 1)] = float("-inf") # Fill invalid positions, so that the following `topk` operators will select positions within `perm_range` first + res = rand.topk(perm_size, dim=-1, sorted=True).indices.to(torch.int32) + if len(paddings) == 1: + res[res >= perm_range.view(batch_size, 1)] = paddings[0] + else: + fillers = torch.tensor(paddings, dtype=torch.int32).index_select(0, torch.randint(0, len(paddings), (res.numel(), ), dtype=torch.int32)) + res.masked_scatter_(res >= perm_range.view(batch_size, 1), fillers) + torch.use_deterministic_algorithms(False) + return res + +def generate_testcase(t: TestParam) -> Testcase: + kk.set_random_seed(t.seed) + q = torch.randn((t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10 + kv = torch.randn((t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10 + do = torch.randn((t.s_q, t.h_q, t.d_v), dtype=torch.bfloat16)/10 + (random.random()-0.5)/10 + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + do.clamp_(-10, 10) + + invalid_indices_candidate = [-2147483648, -123456, -1, t.s_kv, 114514, 1919810, 2147480000, 2147483647] + indices = _randperm_batch(t.s_q, torch.full((t.s_q, ), t.s_kv, dtype=torch.int32), t.topk, invalid_indices_candidate).view(t.s_q, t.h_kv, t.topk) + + if t.is_all_indices_invalid: + all_indices_invalid_mask = torch.randn(t.s_q, device='cpu') < -2 + indices[all_indices_invalid_mask[:, None, None].broadcast_to(indices.shape)] = random.choice(invalid_indices_candidate) + indices = indices.to(q.device) + + attn_sink = None + if t.have_attn_sink: + attn_sink = torch.randn((t.h_q, ), dtype=torch.float32) + mask = torch.randn((t.h_q, ), 
dtype=torch.float32) + attn_sink[mask < -0.5] = float("-inf") + attn_sink[mask > +0.5] = float("+inf") + + topk_length = None + if t.have_topk_length: + topk_length = torch.randint(0, max(t.topk + 1, 64), (t.s_q, ), dtype=torch.int32, device=q.device).clamp_max(t.topk) + + q = kk.non_contiguousify(q) + kv = kk.non_contiguousify(kv) + do = kk.non_contiguousify(do) + indices = kk.non_contiguousify(indices) + + return Testcase( + p=t, + dOut=do, + q=q, + kv=kv, + indices=indices, + sm_scale=0.5, # Otherwise dK is too small compared to dV + attn_sink=attn_sink, + topk_length=topk_length + ) + + +@dataclasses.dataclass +class KVScope: + t: TestParam + cache_seqlens: torch.Tensor + block_table: torch.Tensor + blocked_k: torch.Tensor + abs_indices: torch.Tensor + indices_in_kvcache: torch.Tensor + topk_length: Optional[torch.Tensor] + blocked_k_quantized: Optional[torch.Tensor] = None + + def quant_and_dequant_(self): """ - Calculate the cosine diff between two tensors + For FP8 cases, we need to quantize the KV cache for Flash MLA. 
+ Besides, the quantization error may be too large to be distinguished from wrong kernels, so we de-quantize kvcache here to mitigate quantization error + """ + fp8_kvcache_layout = None + if self.t.d_qk == 576: + fp8_kvcache_layout = quant.FP8KVCacheLayout.V32_FP8Sparse + elif self.t.d_qk == 512: + assert self.abs_indices is not None + fp8_kvcache_layout = quant.FP8KVCacheLayout.MODEL1_FP8Sparse + else: + assert False + self.blocked_k_quantized = quant.quantize_k_cache(self.blocked_k, fp8_kvcache_layout) + blocked_k_dequantized = quant.dequantize_k_cache(self.blocked_k_quantized, fp8_kvcache_layout) + self.blocked_k = blocked_k_dequantized + + def get_kvcache_for_flash_mla(self) -> torch.Tensor: """ - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum().item() - if denominator == 0: - return 0 - sim = 2 * (x * y).sum().item() / denominator - return 1 - sim - assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}" + Return the quantized blocked_k for Flash MLA + """ + assert self.blocked_k_quantized is not None, "Please call `quant_and_dequant_` first before calling `get_kvcache_for_flash_mla`" + return self.blocked_k_quantized - ans = ans.clone().to(torch.float) - ref = ref.clone().to(torch.float) - - # Deal with anomalies - def deal_with_anomalies(val: float): - ref_mask = (ref == val) if (val == val) else (ref != ref) - ans_mask = (ans == val) if (val == val) else (ans != ans) - ref[ref_mask] = 0.0 - ans[ans_mask] = 0.0 - if not torch.equal(ref_mask, ans_mask): - print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref") - return False - return True + def apply_perm(self, perm: torch.Tensor) -> "KVScope": + """ + Apply a batch permutation to this KVScope. 
Used for batch-invariance test + """ + new_kvscope = KVScope( + self.t, + self.cache_seqlens[perm], + self.block_table[perm], + self.blocked_k, + self.abs_indices[perm], + self.indices_in_kvcache[perm], + self.topk_length[perm] if self.topk_length is not None else None, + self.blocked_k_quantized + ) + return new_kvscope - anomalies_check_passed = True - anomalies_check_passed &= deal_with_anomalies(float("inf")) - anomalies_check_passed &= deal_with_anomalies(float("-inf")) - anomalies_check_passed &= deal_with_anomalies(float("nan")) - - if not anomalies_check_passed: - return False - - cos_diff = get_cos_diff(ans, ref) - raw_abs_err = torch.abs(ans-ref) - raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6)) - rel_err = raw_rel_err.masked_fill(raw_abs_err List[int]: - result = [] - for size in t.shape[::-1]: - result.append(pos % size) - pos = pos // size - assert pos == 0 - return result[::-1] - print(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}") - print(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}") - print(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)") - print(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})") - return False +@dataclasses.dataclass +class TestcaseForDecode: + p: TestParam + q: torch.Tensor # [b, s_q, h_q, d_qk] + attn_sink: Optional[torch.Tensor] # [h_q] + sm_scale: float + kv_scope: KVScope + extra_kv_scope: Optional[KVScope] + +def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode: + kk.set_random_seed(t.seed) + assert t.h_q % t.h_kv == 0 + assert t.decode is not None + + q = torch.randn((t.decode.b, t.s_q, t.h_q, t.d_qk)) + q.clamp_(min=-1.0, max=1.0) + + attn_sink = None + if t.have_attn_sink: + 
attn_sink = torch.randn((t.h_q, ), dtype=torch.float32) + inf_mask = torch.randn((t.h_q, ), dtype=torch.float32) + attn_sink[inf_mask > 0.5] = float("inf") + attn_sink[inf_mask < -0.5] = float("-inf") + + def generate_one_k_scope(s_k: int, block_size: int, topk: int, is_varlen: bool, have_zero_seqlen: bool, is_all_indices_invalid: bool, have_topk_length: bool) -> KVScope: + b = t.decode.b # type: ignore + cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device='cpu') + if is_varlen: + for i in range(b): + cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), t.s_q) + + if have_zero_seqlen: + zeros_mask = torch.randn(b, dtype=torch.float32, device='cpu') > 0 + cache_seqlens_cpu[zeros_mask] = 0 + + max_seqlen_alignment = 4 * block_size + max_seqlen_pad = max(kk.cdiv(int(cache_seqlens_cpu.max().item()), max_seqlen_alignment), 1) * max_seqlen_alignment + cache_seqlens = cache_seqlens_cpu.cuda() + + assert max_seqlen_pad % block_size == 0 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1) + + blocked_k = kk.gen_non_contiguous_randn_tensor((block_table.numel(), block_size, t.h_kv, t.d_qk)) / 10 + blocked_k.clamp_(min=-1.0, max=1.0) + + abs_indices = torch.empty((b, t.s_q, topk), dtype=torch.int32) + if is_all_indices_invalid: + abs_indices.fill_(-1) + else: + abs_indices[:] = _randperm_batch(b*t.s_q, cache_seqlens.repeat_interleave(t.s_q), topk, [-1]).view(b, t.s_q, topk) + indices_in_kvcache = quant.abs_indices2indices_in_kvcache(abs_indices, block_table, block_size) + + topk_length = torch.randint(0, topk+1, (b, ), dtype=torch.int32, device=q.device) if have_topk_length else None + + # Mask nonused KV as NaN + if have_topk_length: + indices_in_kvcache_masked = indices_in_kvcache.clone() + indices_in_kvcache_masked[torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, t.s_q, topk) >= 
(topk_length.view(b, 1, 1) if have_topk_length else topk)] = -1 + else: + indices_in_kvcache_masked = indices_in_kvcache + + blocked_k = blocked_k.view(-1, t.h_kv, t.d_qk) + nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu') + nonused_indices_mask[indices_in_kvcache_masked] = False + blocked_k[nonused_indices_mask, :, :] = float("nan") + blocked_k = blocked_k.view(-1, block_size, t.h_kv, t.d_qk) + + block_table = kk.non_contiguousify(block_table) + abs_indices = kk.non_contiguousify(abs_indices) + indices_in_kvcache = kk.non_contiguousify(indices_in_kvcache) + return KVScope(t, cache_seqlens, block_table, blocked_k, abs_indices, indices_in_kvcache, topk_length) + + kv_scope0 = generate_one_k_scope(t.s_kv, t.decode.block_size, t.topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.have_topk_length) + kv_scope0.quant_and_dequant_() + if t.decode.extra_topk is not None: + if t.decode.extra_s_k is None: + t.decode.extra_s_k = t.decode.extra_topk*2 + if t.decode.extra_block_size is None: + t.decode.extra_block_size = t.decode.block_size + kv_scope1 = generate_one_k_scope(t.decode.extra_s_k, t.decode.extra_block_size, t.decode.extra_topk, t.decode.is_varlen, t.decode.have_zero_seqlen_k, t.is_all_indices_invalid, t.decode.have_extra_topk_length) + kv_scope1.quant_and_dequant_() else: - if abs(cos_diff) > cos_diff_tol: - print(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})") - return False - return True \ No newline at end of file + assert t.decode.extra_block_size is None and t.decode.extra_s_k is None and not t.decode.have_extra_topk_length + kv_scope1 = None + + sm_scale = t.d_qk ** -0.55 + + q = kk.non_contiguousify(q) + return TestcaseForDecode(t, q, attn_sink, sm_scale, kv_scope0, kv_scope1) + + +def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool): + assert not return_p_sum + return flash_mla.flash_mla_sparse_fwd( + t.q, t.kv, 
t.indices, + sm_scale=t.sm_scale, + attn_sink=t.attn_sink, + topk_length=t.topk_length + ) + +def run_flash_mla_decode(p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits): + assert p.decode is not None + return flash_mla.flash_mla_with_kvcache( + t.q, + t.kv_scope.get_kvcache_for_flash_mla(), + None, None, p.d_v, + tile_scheduler_metadata, num_splits, + + t.sm_scale, False, True, + t.kv_scope.indices_in_kvcache, + t.attn_sink, + t.extra_kv_scope.get_kvcache_for_flash_mla() if t.extra_kv_scope is not None else None, + t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not None else None, + t.kv_scope.topk_length, + t.extra_kv_scope.topk_length if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None else None + ) + + +@dataclasses.dataclass +class FlopsAndMemVolStatistics: + """ + FLOPs and memory volume statistics for prefilling + """ + fwd_flop: float + fwd_mem_vol: float + +def count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolStatistics: + total_topk = (p.s_q*p.topk) if t.topk_length is None else t.topk_length.sum().item() + indices_valid_mask = (t.indices >= 0) & (t.indices < p.s_kv) + if t.topk_length is not None: + indices_valid_mask &= (torch.arange(p.topk)[None, None, :].broadcast_to(p.s_q, p.h_kv, p.topk)) < t.topk_length[:, None, None] + num_valid_indices = indices_valid_mask.sum().item() + + fwd_flop = 2 * total_topk * p.h_q * (p.d_qk + p.d_v) + fwd_mem_vol = num_valid_indices*p.d_qk*2 + p.s_q*p.h_q*(p.d_qk+p.d_v)*2 + return FlopsAndMemVolStatistics( + fwd_flop, + fwd_mem_vol, + ) + +@dataclasses.dataclass +class FlopsAndMemVolStatisticsForDecode: + """ + FLOPs and memory volume statistics for decoding + """ + flop: float + mem_vol: float + +def count_flop_and_mem_vol_for_decode(p: TestParam, t: TestcaseForDecode) -> FlopsAndMemVolStatisticsForDecode: + assert p.decode + b = p.decode.b + + def get_num_attended_tokens(kv_scope: KVScope) -> int: + topk = kv_scope.indices_in_kvcache.shape[-1] 
+ if kv_scope.topk_length is None: + return b * p.s_q * topk + else: + return int(kv_scope.topk_length.sum().item()) * p.s_q + + def get_num_retrieved_tokens(kv_scope: KVScope) -> int: + if kv_scope.topk_length is None: + indices = kv_scope.indices_in_kvcache + else: + indices = kv_scope.indices_in_kvcache.clone() + batch, s_q, topk = indices.shape + mask = torch.arange(0, topk, device=indices.device).view(1, 1, topk).broadcast_to(batch, s_q, topk) >= kv_scope.topk_length.view(batch, 1, 1) + indices[mask] = -1 + num_unique_tokens = indices.unique().numel() # type: ignore + return num_unique_tokens + + num_attended_tokens = get_num_attended_tokens(t.kv_scope) + (get_num_attended_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0) + num_retrieved_tokens = get_num_retrieved_tokens(t.kv_scope) + (get_num_retrieved_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0) + + compute_flop = 2 * p.h_q * num_attended_tokens * (p.d_qk + p.d_v) + kv_token_size = 656 if p.d_qk == 576 else 576 # Assume FP8 KV Cache + mem_vol = sum([ + 2 * b * p.s_q * p.h_q * p.d_qk, # Q + num_retrieved_tokens * kv_token_size, # K + 2 * b * p.s_q * p.h_q * p.d_v, # O + ]) + return FlopsAndMemVolStatisticsForDecode( + compute_flop, + mem_vol + ) + +def is_no_cooldown() -> bool: + return os.environ.get('NO_COOLDOWN', '').lower() in ['1', 'yes', 'y'] diff --git a/tests/quant.py b/tests/quant.py index 0624759..b92b539 100644 --- a/tests/quant.py +++ b/tests/quant.py @@ -1,66 +1,158 @@ +import enum +from typing import Tuple + import torch +class FP8KVCacheLayout(enum.Enum): + V32_FP8Sparse = 1 + MODEL1_FP8Sparse = 2 + + def get_meta(self) -> Tuple[int, int, int, int, int]: + # Return: (d, d_nope, d_rope, tile_size, num_tiles) + return { + FP8KVCacheLayout.V32_FP8Sparse: (576, 512, 64, 128, 4), + FP8KVCacheLayout.MODEL1_FP8Sparse: (512, 448, 64, 64, 7) + }[self] + +def _cast_scale_inv_to_ue8m0(scales_inv: torch.Tensor, out_dtype = torch.float32) -> torch.Tensor: + return 
torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype) + def quantize_k_cache( input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d) - dv: int, - tile_size: int = 128, + kvcache_layout: FP8KVCacheLayout, ) -> torch.Tensor: """ Quantize the k-cache - Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size() - For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md + For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py """ - assert dv % tile_size == 0 - num_tiles = dv // tile_size - num_blocks, block_size, h_k, d = input_k_cache.shape + d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta() + assert input_k_cache.shape[-1] == d + num_blocks, block_size, h_k, _ = input_k_cache.shape assert h_k == 1 input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] input_elem_size = input_k_cache.element_size() - result = torch.empty((num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device) - result_k_nope_part = result[..., :dv] - result_k_scale_factor = result[..., dv:dv + num_tiles * 4].view(torch.float32) - result_k_rope_part = result[..., dv + num_tiles * 4:].view(input_k_cache.dtype) - result_k_rope_part[:] = input_k_cache[..., dv:] + if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse: + bytes_per_token = d_nope + num_tiles*4 + input_elem_size*d_rope + result = torch.empty((num_blocks, block_size+1, bytes_per_token), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size, :] + result_k_nope_part = result[..., :d_nope] + result_k_scale_factor = result[..., d_nope: d_nope + num_tiles*4].view(torch.float32) + result_k_rope_part = result[..., d_nope + num_tiles*4:].view(input_k_cache.dtype) + result_k_rope_part[:] = input_k_cache[..., d_nope:] - for tile_idx 
in range(0, num_tiles): - cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size] - result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv + for tile_idx in range(0, num_tiles): + cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0 # [num_blocks, block_size] + cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv) + result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv - cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] - cur_quantized_nope = (input_k_cache[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) - result_k_nope_part[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_quantized_nope + cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] + cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) + result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope + + result = result.view(num_blocks, block_size, 1, -1) + return result + + elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse: + bytes_per_token = d_nope + 2*d_rope + num_tiles + 1 + size_per_block_padded = (block_size*bytes_per_token + 576-1) // 576 * 576 + result = torch.empty((num_blocks, size_per_block_padded), dtype=torch.float8_e4m3fn, device=input_k_cache.device)[:, :block_size*bytes_per_token] + result_k_nope_rope_part = result[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope) + result_k_nope = result_k_nope_rope_part[:, :, :d_nope] # [num_blocks, block_size, d_nope] + result_k_rope = result_k_nope_rope_part[:, :, d_nope:].view(input_k_cache.dtype) # [num_blocks, block_size, d_rope] + result_k_scale_factor = result[:, 
block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu) # [num_blocks, block_size, num_tiles] - result = result.view(num_blocks, block_size, 1, -1) - return result + result_k_rope[:] = input_k_cache[..., d_nope:] + for tile_idx in range(0, num_tiles): + cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values.float() / 448.0 # [num_blocks, block_size] + cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv) + result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv.to(torch.float8_e8m0fnu) + + cur_scale_factors_inv = cur_scale_factors_inv.view(num_blocks, block_size, 1) + cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn) + result_k_nope[:, :, tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope + + result = result.view(num_blocks, block_size, 1, -1) + return result + else: + raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}") + def dequantize_k_cache( quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token) - dv: int = 512, - tile_size: int = 128, - d: int = 576 + kvcache_layout: FP8KVCacheLayout, ) -> torch.Tensor: """ De-quantize the k-cache """ - assert dv % tile_size == 0 - num_tiles = dv // tile_size + d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta() num_blocks, block_size, h_k, _ = quant_k_cache.shape assert h_k == 1 result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device) - quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) + if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse: + quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) - input_nope = quant_k_cache[..., :dv] - input_scale = quant_k_cache[..., dv:dv + num_tiles * 4].view(torch.float32) - input_rope = quant_k_cache[..., dv + num_tiles * 
4:].view(torch.bfloat16) - result[..., dv:] = input_rope + input_nope = quant_k_cache[..., :d_nope] + input_scale = quant_k_cache[..., d_nope:d_nope + num_tiles*4].view(torch.float32) + input_rope = quant_k_cache[..., d_nope + num_tiles*4:].view(torch.bfloat16) + result[..., d_nope:] = input_rope - for tile_idx in range(0, num_tiles): - cur_nope = input_nope[..., tile_idx * tile_size:(tile_idx + 1) * tile_size].to(torch.float32) - cur_scales = input_scale[..., tile_idx].unsqueeze(-1) - result[..., tile_idx * tile_size:(tile_idx + 1) * tile_size] = cur_nope * cur_scales + for tile_idx in range(0, num_tiles): + cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32) + cur_scales = input_scale[..., tile_idx].unsqueeze(-1) + result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales + elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse: + quant_k_cache = quant_k_cache.view(num_blocks, -1) # [num_blocks, ...] + input_nope_rope = quant_k_cache[:, :block_size*(d_nope+2*d_rope)].view(num_blocks, block_size, d_nope + 2*d_rope) + input_nope = input_nope_rope[:, :, :d_nope] + input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16) + input_scale = quant_k_cache[:, block_size*(d_nope+2*d_rope):].view(num_blocks, block_size, 8)[:, :, :7].view(torch.float8_e8m0fnu) # [num_blocks, block_size, num_tiles] + + result[..., d_nope:] = input_rope + for tile_idx in range(0, num_tiles): + cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.bfloat16) + cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1) + result[..., tile_idx*tile_size: (tile_idx+1)*tile_size] = cur_nope * cur_scales + + else: + raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}") + result = result.view(num_blocks, block_size, 1, d) return result + + +def abs_indices2indices_in_kvcache( + abs_indices: torch.Tensor, # [b, s_q, topk] + block_table: torch.Tensor, # [b, /] + block_size: int, 
+) -> torch.Tensor: + """ + Convert abs_indices (logical index, ranging from 0 to s_k-1) to index expected by the sparse attn kernel + Equivalent to: + + b, s_q, topk = abs_indices.shape + indices_in_kvcache = torch.empty_like(abs_indices) + for i in range(b): + cur_abs_indices = abs_indices[i, :, :].clone() # [s_q, topk] + invalid_mask = cur_abs_indices == -1 + cur_abs_indices[invalid_mask] = 0 + cur_indices_in_kvcache = block_table[i].index_select(0, cur_abs_indices.flatten()//block_size).view(s_q, topk)*block_size + cur_abs_indices%block_size + cur_indices_in_kvcache[invalid_mask] = -1 + indices_in_kvcache[i] = cur_indices_in_kvcache + return indices_in_kvcache + + """ + b, s_q, topk = abs_indices.shape + _, max_blocks_per_seq = block_table.shape + + abs_indices = abs_indices.clone() + invalid_mask = abs_indices == -1 + abs_indices[invalid_mask] = 0 + + real_block_idxs = block_table.view(-1).index_select(0, (abs_indices//block_size + torch.arange(0, b).view(b, 1, 1)*max_blocks_per_seq).view(-1)) + indices_in_kvcache = real_block_idxs.view(b, s_q, topk)*block_size + abs_indices%block_size + indices_in_kvcache[invalid_mask] = -1 + + return indices_in_kvcache \ No newline at end of file diff --git a/tests/ref.py b/tests/ref.py new file mode 100644 index 0000000..e5a14b3 --- /dev/null +++ b/tests/ref.py @@ -0,0 +1,103 @@ +from typing import Optional, Tuple + +import torch + +from lib import TestParam, Testcase, TestcaseForDecode, KVScope + +def _merge_two_lse(lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q: int, h_q: int) -> torch.Tensor: + if lse1 is None: + return lse0 + else: + return torch.logsumexp( + torch.stack([ + lse0.view(s_q, h_q), + lse1.broadcast_to(s_q, h_q) + ], dim=0), + dim=0 + ) + +def ref_sparse_attn_fwd(p: TestParam, t: Testcase) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Returns: + - o: [s_q, h_q, dv] + - o_fp32: [s_q, h_q, dv] + - max_logits: [s_q, h_q] + - lse: [s_q, h_q] + """ + indices = 
t.indices.clone().squeeze(1) + if t.topk_length is not None: + mask = torch.arange(p.topk, device=t.topk_length.device).unsqueeze(0).broadcast_to(p.s_q, p.topk) >= t.topk_length.unsqueeze(1) # [s_q, topk] + indices[mask] = -1 + invalid_mask = (indices < 0) | (indices >= p.s_kv) # [s_q, topk] + indices[invalid_mask] = 0 + + q = t.q.float() + gathered_kv = t.kv.index_select(dim=0, index=indices.flatten()).reshape(p.s_q, p.topk, p.d_qk).float() # [s_q, topk, d_qk] + P = (q @ gathered_kv.transpose(1, 2)) # [s_q, h_q, topk] + P *= t.sm_scale + P[invalid_mask.unsqueeze(1).broadcast_to(P.shape)] = float("-inf") + + orig_lse = torch.logsumexp(P, dim=-1) # [s_q, h_q] + max_logits = P.max(dim=-1).values # [s_q, h_q] + + lse_for_o = _merge_two_lse(orig_lse, t.attn_sink, p.s_q, p.h_q) + if not torch.is_inference_mode_enabled(): + lse_for_o = lse_for_o.clone() + lse_for_o[lse_for_o == float("-inf")] = float("+inf") # So that corresponding O will be 0 + s_for_o = torch.exp(P - lse_for_o.unsqueeze(-1)) + out = s_for_o @ gathered_kv[..., :p.d_v] # [s_q, h_q, dv] + + lonely_q_mask = orig_lse == float("-inf") # [s_q, h_q] + orig_lse[lonely_q_mask] = float("+inf") + return (out.to(torch.bfloat16), out, max_logits, orig_lse) + + +def ref_sparse_attn_decode( + p: TestParam, + t: TestcaseForDecode +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A reference implementation of sparse decoding attention in PyTorch + """ + assert p.h_kv == 1 + assert p.decode is not None + b = p.decode.b + + def process_kv_scope(kv_scope: KVScope) -> Tuple[torch.Tensor, torch.Tensor]: + assert kv_scope.indices_in_kvcache is not None + topk = kv_scope.indices_in_kvcache.size(-1) + indices_in_kv_cache_fixed = torch.clamp_min(kv_scope.indices_in_kvcache, 0) # Otherwise torch.index_select will complain + gathered_kv = kv_scope.blocked_k.view(-1, p.d_qk).index_select(0, indices_in_kv_cache_fixed.view(-1)).view(b, p.s_q, topk, p.d_qk) # [b, s_q, topk, d] + invalid_mask = kv_scope.indices_in_kvcache == -1 + if 
kv_scope.topk_length is not None: + invalid_mask |= torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, p.s_q, topk) >= kv_scope.topk_length.view(b, 1, 1) + return gathered_kv, invalid_mask + + gathered_kv, invalid_mask = process_kv_scope(t.kv_scope) + if t.extra_kv_scope is not None: + gathered_kv1, invalid_mask1 = process_kv_scope(t.extra_kv_scope) + gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2) # [b, s_q, topk+extra_topk, d] + invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2) # [b, s_q, topk+extra_topk] + + gathered_kv = gathered_kv.view(b*p.s_q, -1, p.d_qk).float() + gathered_kv[gathered_kv != gathered_kv] = 0.0 + q = t.q.float().view(b*p.s_q, p.h_q, p.d_qk) + attn_weight = q @ gathered_kv.transpose(-1, -2) # [t.b*t.s_q, t.h_q, topk+extra_topk] + attn_weight *= t.sm_scale + attn_weight[invalid_mask.view(b*p.s_q, 1, -1).broadcast_to(b*p.s_q, p.h_q, invalid_mask.size(-1))] = float("-inf") + lse = attn_weight.logsumexp(dim=-1) # [t.b*t.s_q, t.h_q] + attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1)) + output = attn_weight @ gathered_kv[..., :p.d_v] # [t.b*t.s_q, t.h_q, t.dv] + output = output.view(b, p.s_q, p.h_q, p.d_v) + lse = lse.view(b, p.s_q, p.h_q) + + # Attention sink + if t.attn_sink is not None: + output *= (1.0 / (1.0 + torch.exp(t.attn_sink.view(1, 1, p.h_q) - lse))).unsqueeze(-1) + + # Correct for q tokens which has no attendable k + lonely_q_mask = (lse == float("-inf")) + output[lonely_q_mask.unsqueeze(-1).broadcast_to(b, p.s_q, p.h_q, p.d_v)] = 0.0 + lse[lonely_q_mask] = float("+inf") + + return output.to(torch.bfloat16), lse.transpose(1, 2) \ No newline at end of file diff --git a/tests/test_flash_mla_decoding.py b/tests/test_flash_mla_dense_decoding.py similarity index 50% rename from tests/test_flash_mla_decoding.py rename to tests/test_flash_mla_dense_decoding.py index dc140d7..6cf92dd 100644 --- a/tests/test_flash_mla_decoding.py +++ b/tests/test_flash_mla_dense_decoding.py @@ -2,14 +2,12 @@ import math 
import random import dataclasses -from typing import Optional, Tuple +from typing import Tuple import torch -import triton +import kernelkit as kk import flash_mla -import quant -from lib import cdiv, check_is_allclose @dataclasses.dataclass class TestParam: @@ -18,10 +16,7 @@ class TestParam: s_k: int # Seq len, or mean seq len if varlen == True is_varlen: bool is_causal: bool - is_fp8: bool - topk: Optional[int] = None test_performance: bool = True - is_all_indices_invalid: bool = False have_zero_seqlen_k: bool = False block_size: int = 64 h_q: int = 128 # Number of q heads @@ -31,7 +26,7 @@ class TestParam: seed: int = 0 -def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: +def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Generate test data from a given configuration Return: [cache_seqlens, q, block_table, blocked_k] @@ -53,11 +48,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0 cache_seqlens_cpu[zeros_mask] = 0 - max_seqlen = cache_seqlens_cpu.max().item() - max_seqlen_pad = cdiv(max_seqlen, 256) * 256 + max_seqlen = int(cache_seqlens_cpu.max().item()) + max_seqlen_pad = kk.cdiv(max_seqlen, 256) * 256 cache_seqlens = cache_seqlens_cpu.cuda() - q = torch.randn(t.b, t.s_q, t.h_q, t.d) + q = torch.randn(t.b, t.s_q, t.h_q, t.d) / 10 q.clamp_(min=-1.0, max=1.0) block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size) @@ -65,59 +60,14 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch. 
blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10 blocked_k.clamp_(min=-1.0, max=1.0) - if t.topk is None: - for i in range(t.b): - cur_len = cache_seqlens_cpu[i].item() - cur_num_blocks = cdiv(cur_len, t.block_size) - blocked_k[block_table[i][cur_num_blocks:]] = float("nan") - if cur_len % t.block_size != 0: - blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan") - block_table[i][cur_num_blocks:] = 2147480000 - return cache_seqlens, q, block_table, blocked_k, None, None - else: - block_table_cpu = block_table.cpu() - abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu") - indices_in_kvcache = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu") - for i in range(t.b): - # Generate indices - for j in range(t.s_q): - cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk] - cur_blocked_indices = block_table_cpu[i, cur_abs_indices // t.block_size] * t.block_size + (cur_abs_indices % t.block_size) - if len(cur_abs_indices) < t.topk: - pad_len = t.topk - len(cur_abs_indices) - cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')]) - cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')]) - - # Mask KV - perm = torch.randperm(t.topk, device='cpu') - cur_abs_indices = cur_abs_indices[perm] - cur_blocked_indices = cur_blocked_indices[perm] - - # Fill it with invalid indices if needed - if t.is_all_indices_invalid: - cur_abs_indices.fill_(-1) - cur_blocked_indices.fill_(-1) - - abs_indices[i, j, :] = cur_abs_indices - indices_in_kvcache[i, j, :] = cur_blocked_indices - - # Mask nonused KV as NaN - all_indices = indices_in_kvcache.flatten().tolist() - all_indices = list(set(all_indices)) - if -1 in all_indices: - all_indices.remove(-1) - all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu') - - blocked_k = blocked_k.view(-1, t.h_kv, t.d) - 
nonused_indices_mask = torch.ones(blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device='cpu') - nonused_indices_mask[all_indices] = False - blocked_k[nonused_indices_mask, :, :] = float("nan") - blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d) - - abs_indices = abs_indices.to(q.device) - indices_in_kvcache = indices_in_kvcache.to(q.device) - - return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache + for i in range(t.b): + cur_len = int(cache_seqlens_cpu[i].item()) + cur_num_blocks = kk.cdiv(cur_len, t.block_size) + blocked_k[block_table[i][cur_num_blocks:]] = float("nan") + if cur_len % t.block_size != 0: + blocked_k[block_table[i][cur_num_blocks - 1]][cur_len % t.block_size:] = float("nan") + block_table[i][cur_num_blocks:] = 2147480000 + return cache_seqlens, q, block_table, blocked_k def reference_torch( @@ -127,18 +77,10 @@ def reference_torch( blocked_k: torch.Tensor, # [?, block_size, h_kv, d] dv: int, is_causal: bool, - indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk] ) -> Tuple[torch.Tensor, torch.Tensor]: """ A reference implementation in PyTorch """ - def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor): - mask = torch.zeros(s_q, s_k, dtype=torch.bool) - for i in range(s_q): - cur_indices = indices[i] - valid_indices = cur_indices[cur_indices != -1] - mask[i, valid_indices] = True - return mask def scaled_dot_product_attention( batch_idx: int, @@ -146,7 +88,6 @@ def scaled_dot_product_attention( kv: torch.Tensor, # [h_kv, s_k, d] dv: int, is_causal, - indices: Optional[torch.Tensor], # [s_q, topk] ) -> Tuple[torch.Tensor, torch.Tensor]: h_q = query.size(0) h_kv = kv.size(0) @@ -158,13 +99,10 @@ def scaled_dot_product_attention( kv = kv.repeat_interleave(h_q // h_kv, dim=0) kv[kv != kv] = 0.0 attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k] - if (is_causal and query.size(1) > 1) or indices is not None: + if is_causal and query.size(1) > 1: mask = torch.ones(s_q, s_k, 
dtype=torch.bool) if is_causal: - assert indices is None mask = mask.tril(diagonal=s_k - s_q) - if indices is not None: - mask &= get_topk_attn_mask(s_q, s_k, indices) attn_bias = torch.zeros(s_q, s_k, dtype=torch.float) attn_bias.masked_fill_(mask.logical_not(), float("-inf")) attn_weight += attn_bias.to(q.dtype) @@ -186,8 +124,8 @@ def scaled_dot_product_attention( out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): - cur_len = cache_seqlens_cpu[i].item() - cur_num_blocks = cdiv(cur_len, block_size) + cur_len = int(cache_seqlens_cpu[i].item()) + cur_num_blocks = kk.cdiv(cur_len, block_size) cur_block_indices = block_table[i][0: cur_num_blocks] cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...] cur_out, cur_lse = scaled_dot_product_attention( @@ -195,12 +133,11 @@ def scaled_dot_product_attention( q[i].transpose(0, 1), cur_kv.transpose(0, 1), dv, - is_causal, - indices[i] if indices is not None else None + is_causal ) out_ref[i] = cur_out.transpose(0, 1) lse_ref[i] = cur_lse - out_ref = out_ref.to(torch.bfloat16) + out_ref = out_ref.to(q.dtype) return out_ref, lse_ref @@ -211,58 +148,42 @@ def test_flash_mla(t: TestParam): # Generating test data torch.cuda.synchronize() - cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache = generate_test_data(t) + cache_seqlens, q, block_table, blocked_k, = generate_test_data(t) - if t.is_fp8: - # The quantization error may be too large to be distinguished from wrong kernels - # So we quantize and de-quantize kv-cache here to mitigate quantization error - blocked_k_quantized = quant.quantize_k_cache(blocked_k, t.dv, 128) - blocked_k_dequantized = quant.dequantize_k_cache(blocked_k_quantized) - blocked_k = blocked_k_dequantized - - # Get schedule metadata - torch.cuda.synchronize() - tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata( - cache_seqlens, - t.s_q * t.h_q // t.h_kv, - t.h_kv, - 
t.h_q, - t.is_fp8, - t.topk - ) - torch.cuda.synchronize() + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata() def run_flash_mla(): return flash_mla.flash_mla_with_kvcache( q, - blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore + blocked_k, block_table, cache_seqlens, t.dv, tile_scheduler_metadata, num_splits, - causal=t.is_causal, - is_fp8_kvcache=t.is_fp8, - indices=indices_in_kvcache + causal=t.is_causal ) out_ans, lse_ans = run_flash_mla() - out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices) - assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) - assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536) + out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal) + is_correct = True + is_correct &= kk.check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=5e-6) + is_correct &= kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536) + assert is_correct if t.test_performance: - time_usage: float = triton.testing.do_bench(run_flash_mla) / 1000 # type: ignore - mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk + time_usage = kk.bench_kineto(run_flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla_kernel") + + mean_attended_seqlens = cache_seqlens.float().mean().item() compute_volume_flop = t.b * t.h_q * t.s_q * sum([ 2 * t.d * mean_attended_seqlens, # Q * K^T 2 * mean_attended_seqlens * t.dv, # attention * V ]) q_elem_size = torch.bfloat16.itemsize - kv_token_size = 656 if t.is_fp8 else t.d * torch.bfloat16.itemsize + kv_token_size = t.d * torch.bfloat16.itemsize memory_volume_B = t.b * sum([ t.s_q * t.h_q * (t.d * q_elem_size), # Q - (t.s_q if t.topk is not None else 1) * mean_attended_seqlens * t.h_kv * kv_token_size, # K/V + mean_attended_seqlens * t.h_kv * 
kv_token_size, # K/V t.s_q * t.h_q * (t.dv * q_elem_size), # Output ]) achieved_tflops = compute_volume_flop / time_usage / 1e12 @@ -277,54 +198,39 @@ def main(torch_dtype): torch.set_default_device(device) torch.cuda.set_device(device) + cc_major, cc_minor = torch.cuda.get_device_capability() + assert cc_major == 9, "Dense MLA decoding is only supported on sm90 (Hopper) currently." + correctness_cases = [ - TestParam(b, s_q, s_k, is_varlen, is_causal, is_fp8, topk, test_performance=False) + TestParam(b, s_q, s_k, is_varlen, is_causal, test_performance=False, have_zero_seqlen_k=False, block_size=64, h_q=h_q, h_kv=h_kv) for b in [1, 2, 6, 64] for s_q in [1, 2, 4] for s_k in [20, 140, 4096] + for h_q in [1, 3, 9, 63, 64, 126, 128] + for h_kv in [1, 2, 3, 8] for is_varlen in [False, True] for is_causal in [False, True] - for (is_fp8, topk) in [ - (False, None), - (True, 128), - (True, 2048) - ] - if not (is_causal and topk is not None) + if h_q % h_kv == 0 ] corner_cases = [ - # Cases where all topk indices are invalid - TestParam(128, 2, 4096, is_varlen=True, is_causal=False, is_fp8=True, topk=topk, test_performance=False, is_all_indices_invalid=True) - for topk in [128, 2048, 4096] - ] + [ # Cases where some kv cache have zero length - TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=False, have_zero_seqlen_k=True) - for (is_causal, is_fp8, topk) in [ - (False, False, None), - (True, False, None), - (False, True, 128), - (False, True, 2048), - ] + TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, test_performance=False, have_zero_seqlen_k=True, h_q=h_q, h_kv=h_kv) + for h_q in [1, 3, 9, 63, 64, 126, 128] + for h_kv in [1, 2, 3, 8] + for is_causal in [False, True] + if h_q % h_kv == 0 ] performance_cases = [ - TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=True) - for (is_causal, is_fp8, topk) in [ - (False, False, None), - (True, False, None), 
- (False, True, 2048), - ] + TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, test_performance=True) + for is_causal in [False, True] for s_q in [1, 2] for s_k in [4096, 8192, 16384, 32768] ] testcases = correctness_cases + corner_cases + performance_cases - # Prune out unsupported cases - cc_major, cc_minor = torch.cuda.get_device_capability() - if cc_major == 10: - testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)] - for testcase in testcases: test_flash_mla(testcase) @@ -345,4 +251,4 @@ def main(torch_dtype): if args.dtype == "fp16": torch_dtype = torch.float16 - main(torch_dtype) + main(torch_dtype) \ No newline at end of file diff --git a/tests/test_flash_mla_prefill.py b/tests/test_flash_mla_prefill.py deleted file mode 100644 index d2f5b7e..0000000 --- a/tests/test_flash_mla_prefill.py +++ /dev/null @@ -1,196 +0,0 @@ -import math -import time -from typing import Tuple -import random -import dataclasses - -import torch -import triton - -from flash_mla import flash_mla_sparse_fwd -from lib import check_is_allclose - -@dataclasses.dataclass -class TestParam: - b: int - s_q: int - s_kv: int - topk: int - h_q: int = 128 - h_kv: int = 1 - d_qk: int = 576 - d_v: int = 512 - seed: int = 0 - check_correctness: bool = True - benchmark: bool = True - -@dataclasses.dataclass -class Testcase: - t: TestParam - q: torch.Tensor - kv: torch.Tensor - indices: torch.Tensor - -def generate_testcase(t: TestParam) -> Testcase: - torch.manual_seed(t.seed) - torch.cuda.manual_seed(t.seed) - random.seed(t.seed) - q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10 - kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10 - - q.clamp_(-10, 10) - kv.clamp_(-10, 10) - - indices = torch.full((t.b, t.s_q, t.h_kv, t.topk), t.s_kv, dtype=torch.int32) - for b in range(t.b): - for s in range(t.s_q): - for h in range(t.h_kv): - # NOTE We use the following method to generate indices so that most indices lies within 
[s_kv-20000, s_kv), which is more realistic for sparse attention - near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31 - cur_indices = torch.randperm(t.s_kv)[:t.topk] - cur_indices[near_mask] = torch.randint(max(0, t.s_kv - 20000), t.s_kv - 1, (near_mask.sum().item(),)) - if len(cur_indices) < t.topk: - cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)]) - cur_indices = cur_indices[torch.randperm(t.topk)] - indices[b, s, h] = cur_indices - indices = indices.to(q.device) - - return Testcase( - t=t, - q=q, - kv=kv, - indices=indices - ) - -def get_flop(p: TestParam) -> float: - flop = 2 * sum([ - p.h_q * p.d_qk * p.topk, - p.h_q * p.d_v * p.topk - ]) * p.b * p.s_q - return flop - -def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: - return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e) - - assert p.b == 1 - indices = t.indices[0, :, 0, :] # [s_q, topk] - invalid_indices_mask = (indices < 0) | (indices >= p.s_kv) - qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk] - kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk] - - kvs = torch.index_select(kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()).view(p.s_q, p.topk, p.d_qk) # [s_q, topk, d_qk] - attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk] - attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf')) - attn_score *= sm_scale * math.log2(math.e) - max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q] - lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q] - attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk] - result = attn_score @ kvs[:, :, :p.d_v] - return (max_logits, lse, result) - -@torch.inference_mode() -def run_test(p: TestParam) -> bool: - print("================") - print(f"Running on {p}") - torch.cuda.empty_cache() - assert p.b == 1 - - t = generate_testcase(p) - 
sm_scale = 1 / math.sqrt(p.d_qk) - torch.cuda.synchronize() - - def run_ans(): - return flash_mla_sparse_fwd( - t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale - ) - - ans_out, ans_max_logits, ans_lse = run_ans() - torch.cuda.synchronize() - - if p.benchmark: - flop = get_flop(p) - prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore - prefill_flops = flop / prefill_ans_time / 1e12 - print(f"Prefill: {prefill_ans_time * 1e6:4.0f} us, {prefill_flops:.3f} TFlops") - - if p.check_correctness: - torch.cuda.synchronize() - ref_max_logits, ref_lse, ref_out = reference_torch(p, t, sm_scale) - torch.cuda.synchronize() - - is_correct = True - is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01 / 128, cos_diff_tol=7e-6) - is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01 / 65536) - is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01 / 65536) - - return is_correct - else: - return True - - -if __name__ == '__main__': - device = torch.device("cuda:0") - torch.set_default_dtype(torch.bfloat16) - torch.set_default_device(device) - torch.cuda.set_device(device) - torch.set_float32_matmul_precision('high') - - correctness_cases = [ - # Regular shapes - TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) - for s_kv, topk in [ - # Regular shapes - (128, 128), - (256, 256), - (512, 512), - - # Irregular shapes - (592, 128), - (1840, 256), - (1592, 384), - (1521, 512), - - # Irregular shapes with OOB TopK - (95, 128), - (153, 256), - (114, 384), - ] - for s_q in [ - 1, 62 - ] - ] - - corner_cases = [ - # In these cases, some blocks may not have any valid topk indices - TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False) - for s_kv, topk in [ - (32, 2048), - (64, 8192) - ] - for s_q in [1, 1024] - ] - - performance_cases = [ - TestParam(1, s_q, s_kv, topk, h_q=128) - for s_q in [4096] - for 
s_kv in [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072] - for topk in [2048] - ] - - testcases = correctness_cases + corner_cases + performance_cases - - failed_cases = [] - for test in testcases: - if test.benchmark: - time.sleep(0.2) - is_correct = run_test(test) - if not is_correct: - failed_cases.append(test) - - if len(failed_cases) > 0: - print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") - for case in failed_cases: - print(f" {case}") - else: - print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") diff --git a/tests/test_flash_mla_sparse_decoding.py b/tests/test_flash_mla_sparse_decoding.py new file mode 100644 index 0000000..31aac48 --- /dev/null +++ b/tests/test_flash_mla_sparse_decoding.py @@ -0,0 +1,319 @@ +import time +import dataclasses +from typing import Tuple, List, Dict, Optional +import copy + +import rich.console +import rich.table + +import torch +import kernelkit as kk + +import flash_mla + +import lib +from lib import TestParam +from lib import RawTestParamForDecode as RawTestParam +import ref + +""" +Generate testcase for unit test +""" + +def gen_testcase() -> List[RawTestParam]: + correctness_cases = [] + corner_cases = [] + for d_qk in [576, 512]: + for have_extra_k in ([False, True] if d_qk == 512 else [False]): + for have_extra_topk_len in ([False, True] if have_extra_k else [False]): + for have_topk_len in ([False, True] if d_qk == 512 else [False]): + for h_q in [64, 128]: + cur_correctness_cases = [ + RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk, + have_topk_length=have_topk_len, + enable_attn_sink=True, + extra_s_k=extra_s_k, + extra_topk=extra_topk, + block_size=block_size, + extra_block_size=extra_block_size, + have_extra_topk_length=have_extra_topk_len, + d_qk=d_qk, + check_correctness=True, + num_runs=0) + for (s_k, topk, block_size) in [ + (512, 64, 2), + (512, 64, 64), + (512, 64, 69), + (1024, 576, 2), + (1024, 576, 61), + (2046, 2048, 2), + (2046, 
2048, 64), + (2046, 2048, 576) + ] + for (extra_s_k, extra_topk, extra_block_size) in ([ + (512, 64, 2), + (512, 64, 64), + (512, 64, 69), + (1024, 576, 2), + (1024, 576, 61), + (2046, 2048, 2), + (2046, 2048, 64), + (2046, 2048, 576) + ] if have_extra_k else [(None, None, None)]) + for b in [4, 74, 321] + for s_q in [1, 3] + for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True]) + ] + correctness_cases.extend(cur_correctness_cases) + + cur_corner_cases = [ + RawTestParam(b, h_q, s_q, 1, s_k, is_varlen, topk, + is_all_indices_invalid=is_all_indices_invalid, + have_zero_seqlen_k=have_zero_seqlen_k, + have_topk_length=have_topk_len, + enable_attn_sink=enable_attn_sink, + extra_s_k=extra_s_k, + extra_topk=extra_topk, + block_size=block_size, + extra_block_size=extra_block_size, + have_extra_topk_length=have_extra_topk_len, + d_qk=d_qk, + check_correctness=True, + num_runs=0, + ) + for (s_k, topk, block_size) in [ + (512, 64, 61), + (650, 576, 53), + ] + for (extra_s_k, extra_topk, extra_block_size) in ([ + (512, 64, 61), + (650, 576, 53), + ] if have_extra_k else [(None, None, None)]) + for b in [4, 74, 321] + for s_q in [3] + for is_varlen in ([True, False] if (b == 74 and not have_topk_len and not have_extra_topk_len) else [True]) + for is_all_indices_invalid in [True, False] + for have_zero_seqlen_k in [True, False] + for enable_attn_sink in [True, False] + if (is_all_indices_invalid or have_zero_seqlen_k or enable_attn_sink) + ] + corner_cases.extend(cur_corner_cases) + + base_and_bszs = [ + # V3.2 + (RawTestParam(0, 128, 2, 1, 32768, True, topk=2048, d_qk=576), [2, 64, 74, 128]), + # MODEL1 CONFIG1 + (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=512, block_size=256, extra_block_size=64), [2, 64, 74, 128, 74*2, 256]), + # MODEL1 CONFIG2 + (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=64), 
[2, 64, 74, 128, 74*2, 256]), + # MODEL1 CONFIG3 + (RawTestParam(0, 64, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]), + # MODEL1 CONFIG4 + (RawTestParam(0, 128, 2, 1, 16384, True, topk=128, d_qk=512, extra_s_k=16384, extra_topk=1024, block_size=256, extra_block_size=2, have_extra_topk_length=True), [2, 64, 74, 128, 74*2, 256]), + ] + performance_cases = [ + # Production cases + dataclasses.replace(base, b=b) + for base, bszs in base_and_bszs + for b in bszs + ] + [ + # Peak perf cases + RawTestParam(74*2, h_q, 2, 1, 32768, True, topk=16384, d_qk=d_qk) + for h_q in [64, 128] + for d_qk in [512, 576] + ] + + return correctness_cases + corner_cases + performance_cases + + +@dataclasses.dataclass +class Result: + is_correct: bool + compute_memory_ratio: float + time_usage_per_us: float + splitkv_time_usage_us: float + combine_time_usage_us: float + achieved_tflops: float + achieved_gBps: float + +_counter = kk.Counter() + +@torch.inference_mode() +def test_flash_mla(p: TestParam) -> Result: + if p.seed == -1: + global _counter + p.seed = _counter.next() + assert p.decode + + print("================") + print(f"Running on {p}") + torch.cuda.empty_cache() + + t = lib.generate_testcase_for_decode(p) + + tile_scheduler_metadata, _ = flash_mla.get_mla_metadata() + def run_decode(): + return lib.run_flash_mla_decode(p, t, tile_scheduler_metadata, None) + + # We first run the kernel once to generate output data for the correctness test + # We must do this first, otherwise when allocating tensors for storing answers, + # it may re-use memory that contains the correct answer, leading to false positives + if p.check_correctness: + torch.cuda.synchronize() + out_ans, lse_ans = run_decode() + torch.cuda.synchronize() + # torch.set_printoptions(profile='full') + # print(tile_scheduler_metadata.tile_scheduler_metadata[:, :7]) + + # We run the performance test 
before generating the answer for the correctness test to avoid interference + performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + if p.num_runs == 0: + performance_result = Result(True, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + else: + result = kk.bench_kineto(run_decode, p.num_runs) + + splitkv_kernel_name = "flash_fwd_splitkv_mla_fp8_sparse_kernel" + combine_kernel_name = "flash_fwd_mla_combine_kernel" + + # Get individual kernel time usages + kernel_time_usages_us: Dict[str, Optional[float]] = {} + def pick_kernel_time_usage(kernel_name: str): + t = [kernel_name in s for s in result.get_kernel_names()] + if any(t): + assert sum(t) == 1 + kernel_time_usages_us[kernel_name] = result.get_kernel_time(kernel_name) * 1e6 + else: + kernel_time_usages_us[kernel_name] = None + pick_kernel_time_usage(splitkv_kernel_name) + pick_kernel_time_usage(combine_kernel_name) + + # Get E2E time usages + def have_kernel(name: str): + return kernel_time_usages_us[name] is not None + + if kk.is_using_profiling_tools(): + e2e_time_usage_us = 1e6 + else: + assert have_kernel(splitkv_kernel_name) + if have_kernel(combine_kernel_name): + e2e_time_usage_us = result.get_e2e_time(splitkv_kernel_name, combine_kernel_name) * 1e6 + else: + e2e_time_usage_us = kernel_time_usages_us[splitkv_kernel_name] + assert e2e_time_usage_us is not None + + flops_and_mem_vol = lib.count_flop_and_mem_vol_for_decode(p, t) + + e2e_time_usage_s = e2e_time_usage_us / 1e6 + theoritical_compute_memory_ratio = flops_and_mem_vol.flop / flops_and_mem_vol.mem_vol + achieved_tflops = flops_and_mem_vol.flop / e2e_time_usage_s / 1e12 + achieved_gBps = flops_and_mem_vol.mem_vol / e2e_time_usage_s / 1e9 + def print_kernel_time_usage(name: str, short_name: str): + if kernel_time_usages_us[name] is not None: + print(f'{short_name} time: {kernel_time_usages_us[name]:.1f} us') + print(f'Compute/Memory: {theoritical_compute_memory_ratio:.2f}') + print(f'Time (per): {e2e_time_usage_us:.1f} us') + 
print_kernel_time_usage(splitkv_kernel_name, "Splitkv") + print_kernel_time_usage(combine_kernel_name, "Combine") + print(f'TFlops: {achieved_tflops:.1f}') + print(f'GB/s: {achieved_gBps:.0f}') + + performance_result = Result(True, theoritical_compute_memory_ratio, e2e_time_usage_us, kernel_time_usages_us[splitkv_kernel_name] or 0.0, kernel_time_usages_us[combine_kernel_name] or 0.0, achieved_tflops, achieved_gBps) + + is_correct = True + if p.check_correctness: + torch.cuda.synchronize() + with torch.profiler.record_function("reference_flash_mla"): + out_ref, lse_ref = ref.ref_sparse_attn_decode(p, t) + + is_out_correct = kk.check_is_allclose("out", out_ans, out_ref, abs_tol=1e-3, rel_tol=2.01/128, cos_diff_tol=5e-6) + is_lse_correct = kk.check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536) + is_correct &= is_out_correct and is_lse_correct + + performance_result.is_correct = is_correct + return performance_result + + +def main(): + dtype = torch.bfloat16 + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + torch.set_num_threads(32) + + raw_testcases = gen_testcase() + testcases = [t.to_test_param() for t in raw_testcases] + + print(f"{kk.colors['CYAN_BG']}{len(testcases)} testcases to run{kk.colors['CLEAR']}") + + is_no_cooldown = lib.is_no_cooldown() + num_testcases_len = len(str(len(testcases))) + failed_cases = [] + results: List[Tuple[TestParam, Result]] = [] + for testcase_idx, testcase in enumerate(testcases): + if testcase != testcases[0] and testcase.num_runs > 0 and not is_no_cooldown: + time.sleep(0.3) # Cooldown + print(f"[{testcase_idx+1:{num_testcases_len}d}/{len(testcases)}, {testcase_idx/len(testcases)*100:3.0f}%] ", end='') + result = test_flash_mla(testcase) + results.append((testcase, result)) + if not result.is_correct: + failed_cases.append(testcase) + import sys + sys.exit(1) + + console = 
rich.console.Console(width=120) + table = rich.table.Table(show_header=True, header_style="bold cyan") + table.add_column("topk") + table.add_column("Bsz") + table.add_column("h_q&k") + table.add_column("sq") + table.add_column("sk") + table.add_column("d_qk") + table.add_column("Feats") + table.add_column("C/M") + table.add_column("TFlops") + table.add_column("GBps") + table.add_column("us") + table.add_column(" ") + + for testcase, result in results: + assert testcase.decode + topk_str = f"{testcase.topk}" if testcase.decode.extra_topk is None else f"{testcase.topk}+{testcase.decode.extra_topk}" + table.add_row( + topk_str, + str(testcase.decode.b), + f"{testcase.h_q:3d} {testcase.h_kv}", + str(testcase.s_q), + str(testcase.s_kv), + str(testcase.d_qk), + " V"[testcase.decode.is_varlen] + " L"[testcase.have_topk_length] + " E"[testcase.decode.have_extra_topk_length], + f"{result.compute_memory_ratio:3.0f}", + f"{result.achieved_tflops:3.0f}", + f"{result.achieved_gBps:4.0f}", + f"{result.time_usage_per_us:4.1f}", + "" if result.is_correct else "X" + ) + console.print(table) + + def geomean(l) -> float: + import numpy + return numpy.exp(numpy.mean(numpy.log(l))) + + num_correct_testcases = [result.is_correct for t, result in results if t.check_correctness].count(True) + num_correctness_cases = sum([1 for t in testcases if t.check_correctness]) + if num_correct_testcases == num_correctness_cases: + print(f"{kk.colors['GREEN_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}") + else: + print(f"{kk.colors['RED_BG']}{num_correct_testcases}/{num_correctness_cases} correctness cases passed{kk.colors['CLEAR']}") + for t in failed_cases: + print(f"\t{t},") + + valid_achieved_tflops = [result.achieved_tflops for _, result in results if result.achieved_tflops > 0.1] + if len(valid_achieved_tflops) > 0: + achieved_tflops_geomean = geomean(valid_achieved_tflops) # > 0.1 to prune out correctness cases + print(f"TFlops geomean: 
{achieved_tflops_geomean:.1f}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_flash_mla_sparse_prefill.py b/tests/test_flash_mla_sparse_prefill.py new file mode 100644 index 0000000..f4d39dd --- /dev/null +++ b/tests/test_flash_mla_sparse_prefill.py @@ -0,0 +1,180 @@ +import time +import sys + +import torch +import kernelkit as kk + +from lib import TestParam +import lib +import ref + +_counter = kk.Counter() + +@torch.inference_mode() +def run_test(p: TestParam) -> bool: + if p.seed == -1: + global _counter + p.seed = _counter.next() + + print("================") + print(f"Running on {p}") + torch.cuda.empty_cache() + + t = lib.generate_testcase(p) + torch.cuda.synchronize() + + def run_prefill(): + return lib.run_flash_mla_sparse_fwd(p, t, False) + + prefill_ans_out, prefill_ans_max_logits, prefill_ans_lse = run_prefill() + torch.cuda.synchronize() + + if p.num_runs > 0: + flops_and_mem_vol = lib.count_flop_and_mem_vol(p, t) + prefill_ans_time = kk.bench_kineto(run_prefill, num_tests=p.num_runs).get_kernel_time("sparse_attn_fwd") + prefill_flops = flops_and_mem_vol.fwd_flop/prefill_ans_time/1e12 + prefill_mem_bw = flops_and_mem_vol.fwd_mem_vol/prefill_ans_time/1e12 + print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:6.1f} TFlops, {prefill_mem_bw:4.2f} TBps") + + if p.check_correctness: + torch.cuda.synchronize() + ref_out, ref_out_fp32, ref_max_logits, ref_lse = ref.ref_sparse_attn_fwd(p, t) + ref_lse[ref_lse == float("-inf")] = float("+inf") + torch.cuda.synchronize() + + is_correct = True + is_correct &= kk.check_is_allclose("out", prefill_ans_out.float(), ref_out_fp32, abs_tol=8e-4, rel_tol=3.01/128, cos_diff_tol=7e-6) + is_correct &= kk.check_is_allclose("max_logits", prefill_ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536) + is_correct &= kk.check_is_allclose("lse", prefill_ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536) + + return is_correct + else: + return True + + +if __name__ == '__main__': + device 
= torch.device("cuda:0") + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.set_float32_matmul_precision('high') + + correctness_cases = [ + # Regular shapes + TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, d_qk=d_qk) + for d_qk in [512, 576] + for h_q in [ + 128, 64 + ] + for s_kv, topk in [ + # Regular shapes + (128, 128), + (256, 256), + (512, 512), + + # Irregular shapes + (592, 128), + (1840, 256), + (1592, 384), + (1521, 512), + + # Irregular shapes with OOB TopK + (95, 128), + (153, 256), + (114, 384), + ] + for s_q in [ + 1, 62, 213 + ] + ] + + correctness_cases_with_features = [ + TestParam(s_q, s_kv, topk, h_q=h_q, num_runs=0, have_attn_sink=have_attn_sink, have_topk_length=have_topk_length, d_qk=d_qk) + for d_qk in [512, 576] + for h_q in [ + 128, 64 + ] + for s_kv, topk in [ + (592, 128), + (1840, 256), + (1592, 384), + (1521, 512), + + (95, 128), + (153, 256), + (114, 384), + ] + for s_q in [62, 213] + for have_sink_lse in [False, True] + for have_attn_sink in [False, True] + for have_topk_length in [False, True] + ] + + corner_cases = [ + TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=True, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk) + for d_qk in [512, 576] + for h_q in [ + 128, 64 + ] + for s_q, s_kv, topk in [ + (1, 128, 128), + (1, 256, 256), + (1234, 4321, 4096), + (4096, 2048, 2048) + ] + ] + [ + # In these cases, some blocks may not have any valid topk indices + TestParam(s_q, s_kv, topk, h_q=h_q, is_all_indices_invalid=False, num_runs=0, have_attn_sink=True, have_topk_length=True, d_qk=d_qk) + for d_qk in [512, 576] + for h_q in [ + 128, 64 + ] + for s_kv, topk in [ + (32, 2048), + (64, 8192) + ] + for s_q in [1, 1024] + ] + [ + # In this testcase, s_q is really large, so we cannot put it on the second dimension of grid shape + TestParam(70000, 256, 256, h_q=h_q, check_correctness=False, num_runs=0, have_attn_sink=True, have_topk_length=True, 
d_qk=d_qk) + for d_qk in [512, 576] + for h_q in [ + 128, 64 + ] + ] + + performance_case_templates = [ + # V3.2 + (576, 128, 2048, [8192, 32768, 65536, 98304, 131072]), + # MODEL1 CONFIG1 + (512, 64, 512, [8192, 32768, 49152, 65536]), + # MODEL1 CONFIG2 + (512, 128, 1024, [8192, 32768, 49152, 65536]), + ] + + performance_cases = [ + TestParam(s_q, s_kv, topk, h_q=h_q, d_qk=d_qk, have_attn_sink=True) + for (d_qk, h_q, topk, s_kv_list) in performance_case_templates + for s_q in [4096] + for s_kv in s_kv_list + ] + + testcases = correctness_cases + correctness_cases_with_features + corner_cases + performance_cases + + is_no_cooldown = lib.is_no_cooldown() + failed_cases = [] + for test in testcases: + if test != testcases[0] and test.num_runs > 0 and not is_no_cooldown: + time.sleep(0.3) + is_correct = run_test(test) + if not is_correct: + failed_cases.append(test) + + if len(failed_cases) > 0: + print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m") + for case in failed_cases: + print(f" {case}") + sys.exit(1) + else: + print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m") + diff --git a/tests/test_fmha_sm100.py b/tests/test_fmha_sm100.py index 62e3344..79ba556 100644 --- a/tests/test_fmha_sm100.py +++ b/tests/test_fmha_sm100.py @@ -5,7 +5,7 @@ import triton from flash_mla import flash_attn_varlen_func -from lib import check_is_allclose +from kernelkit import check_is_allclose def get_window_size(causal, window): if window > 0: @@ -116,14 +116,14 @@ def torch_attn(): out_flash, lse_flash = flash_attn() if has_bwd: out_flash.backward(grad_out, retain_graph=True) - dq1 = q1.grad.clone() + _dq1 = q1.grad.clone() dk1 = k1.grad.clone() dv1 = v1.grad.clone() if check_correctness: out_torch, lse_torch = torch_attn() - assert check_is_allclose("out", out_flash, out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) - assert check_is_allclose("lse", lse_flash, lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536) + assert 
check_is_allclose("out", out_flash.float(), out_torch, abs_tol=1e-3, rel_tol=8.01 / 128, cos_diff_tol=7e-6) + assert check_is_allclose("lse", lse_flash.float(), lse_torch, abs_tol=1e-6, rel_tol=2.01 / 65536) if has_bwd: out_torch.backward(grad_out, retain_graph=True) From c741387bcb1f17fe86d7367201505a6c4dcfaf42 Mon Sep 17 00:00:00 2001 From: Jiashi Li Date: Mon, 19 Jan 2026 14:06:14 +0800 Subject: [PATCH 22/24] Add missing include Co-authored-by: baowending.bwd --- csrc/api/common.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/api/common.h b/csrc/api/common.h index 6beeab4..2c930ed 100644 --- a/csrc/api/common.h +++ b/csrc/api/common.h @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include From 48c6dc426f045cb7743b18f5c7329f35f1b7ed79 Mon Sep 17 00:00:00 2001 From: Shengyu Liu Date: Tue, 20 Jan 2026 23:57:41 +0800 Subject: [PATCH 23/24] nits --- flash_mla/flash_mla_interface.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 4fac685..a3740b0 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -72,10 +72,8 @@ def flash_mla_with_kvcache( Arguments: q: (batch_size, seq_len_q, num_heads_q, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). - Different modes (including fp8/bf16, sparsity, and model version (i.e. V3.2 or MODEL1)) has different KV cache layouts. See comments below for details. + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. 
- Besides, some kernels also have their own requirements on the layout of k cache, including: - - For sparse fp8 decoding kernel on F3, k_cache.stride(0) must be a multiple of 656B (for V32) or 576B (for MODEL1). Padding is needed sometimes. block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. head_dim_v: Head_dim of v. Must be 512 @@ -88,7 +86,7 @@ def flash_mla_with_kvcache( Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), where t is the k-th token of the j-th q-sequence in the i-th batch. attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. - extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. This is used to support MODEL1. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: @@ -99,11 +97,6 @@ def flash_mla_with_kvcache( - Next 16 bytes: Scale factors, containing 4 float32 values. 
The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. - For DeepSeek MODEL1: - head_dim should be 512 while head_dim_v is also 512. - - In FP8+sparse mode, every block can be divided into two parts. The first parts stores NoPE0, RoPE0, NoPE1, RoPE1, ... while the second part stores scale factors: 7xue8m0, 1Bpad, 7xue8m0, 1Bpad, ... - Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. From 47c35a712362f11bc235854ead51819ad76f5a81 Mon Sep 17 00:00:00 2001 From: Zeyu WANG Date: Fri, 6 Feb 2026 15:31:00 +0800 Subject: [PATCH 24/24] Add CUDAGuard and device id assignment in sm100 dense fmha (#160) --- csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh | 8 +++++++- csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh | 6 +++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh b/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh index f4a1ce8..086a468 100644 --- a/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh +++ b/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cuh @@ -49,6 +49,9 @@ #include "collective/fmha_fusion.hpp" #include "device/fmha_device_bwd.hpp" +#include +#include + using namespace cute; using namespace cutlass::fmha::kernel; using namespace cutlass::fmha::collective; @@ -95,8 +98,11 @@ struct BwdRunner { at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, float softmax_scale, int max_seqlen_q, int max_seqlen_kv) { + const at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + const int device_id = q.get_device(); + cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; + hw_info.device_id =device_id; hw_info.sm_count = 
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); ProblemShape problem_shape; cute::tuple> tensor_shape; diff --git a/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh index 987a5f7..58adf3c 100644 --- a/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh +++ b/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cuh @@ -13,6 +13,7 @@ #include "kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp" #include +#include #include using namespace cute; @@ -288,8 +289,11 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, float scale_softmax, int max_seqlen_q, int max_seqlen_kv) { + const at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + const int device_id = q.get_device(); + cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; + hw_info.device_id = device_id; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);