diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..05015d5e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "deep-gemm/torch-ext/deep_gemm/include/third-party/cutlass"] + path = deep-gemm/torch-ext/deep_gemm/include/third-party/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/deep-gemm/CMakeLists.txt b/deep-gemm/CMakeLists.txt new file mode 100644 index 00000000..79f1964d --- /dev/null +++ b/deep-gemm/CMakeLists.txt @@ -0,0 +1,33 @@ +# NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT +cmake_minimum_required(VERSION 3.10) +project(deep_gemm LANGUAGES CXX CUDA) +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi") +set(CUDA_SEPARABLE_COMPILATION ON) +list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG") +list(APPEND CUDA_NVCC_FLAGS "-O3") +list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") + +set(USE_SYSTEM_NVTX on) +set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile") +set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") + +find_package(CUDAToolkit REQUIRED) +find_package(pybind11 REQUIRED) +find_package(Torch REQUIRED) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) + +include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include) +include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) +link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) + +# The main Python API entrance +pybind11_add_module(_C csrc/python_api.cpp) +target_link_libraries(_C PRIVATE ${TORCH_LIBRARIES} torch_python) + +# Enable kernel code indexing with CMake-based IDEs +cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu) diff --git a/deep-gemm/LICENSE b/deep-gemm/LICENSE new file mode 100644 index 00000000..5c48bdc9 --- /dev/null +++ b/deep-gemm/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 DeepSeek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/deep-gemm/README.md b/deep-gemm/README.md new file mode 100644 index 00000000..c81bf46b --- /dev/null +++ b/deep-gemm/README.md @@ -0,0 +1,179 @@ +# DeepGEMM + +DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (working in progress) for both normal and Mix-of-Experts (MoE) grouped scenarios. Written in CUDA, the library has no kernel compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module. + +DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques. + +Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. + +## News + +- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2. + - Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details. +- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module. + - NVRTC and post-compilation SASS optimization are all disabled. + - NVRTC will be supported later. + - As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported. + - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details. +- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details. +- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). +- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. + +## Roadmap + +- [x] More correctness tests for grouped-contiguous layout +- [x] Shared memory swizzling for output +- [x] MoE scheduler with TMA multicast compatibility +- [x] Fix TMA multicast compatibility for indivisible shapes +- [x] Skip useless computation on M +- [x] NVRTC as a faster compiler +- [x] Sanitizer for testing +- [x] Weight gradient kernels for dense models +- [x] Weight gradient kernels for MoE models +- [ ] Better `get_best_configs` modeling +- [ ] CUDA PDL support +- [ ] Larger TMA multicast size for some shapes +- [x] MMA template refactor with CUTLASS +- [x] Remove shape limitations on N and K +- [x] BF16 kernels +- [ ] Split/stream-k optimizations +- [ ] Ampere kernels +- [ ] Polish docs + +## Quick start + +### Requirements + +- NVIDIA SM90 or SM100 architecture GPU +- Python 3.8 or higher +- Compilers with C++20 support +- CUDA Toolkit: + - CUDA 12.3 or higher for SM90 + - **We highly recommend 12.9 or higher for the best performance** + - CUDA 12.9 or higher for SM100 +- PyTorch 2.1 or higher +- CUTLASS 4.0 or higher (could be cloned by Git submodule) +- `{fmt}` library (could be cloned by Git submodule) + +### Development + +```bash +# Submodule must be cloned +git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git +cd DeepGEMM + +# Link some essential includes and build the CPP JIT module +cat develop.sh +./develop.sh + +# Test all GEMM implements +python tests/test_layout.py +python tests/test_attention.py +python tests/test_core.py +``` + +### Installation + +```bash +cat install.sh +./install.sh +``` + +Then, import `deep_gemm` in your Python project, and enjoy! + +## Interfaces + +#### Notices + +This library provides optimized GEMM kernels for NVIDIA GPUs with a naming convention: `D = C + A @ B`. The input shape layout is NT (non-transposed A, transposed B). While the SM90 implementation supports only the NT memory layout (row-major, col-major), the SM100 implementation supports all memory layouts (NT, TN, NN, TT). For example, `fp8_gemm_nt` will do a `D = C + A @ B.T` + +For both architectures, the LHS scaling factor is required to have a TMA-aligned and transposed layout. And the data format for the scaling factor of SM90 and SM100 is different: + +- SM90 requires scaling factors in FP32 format. +- SM100 requires scaling factors in packed [UE8M0](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) format, which packs 4 UE8M0 into a single `torch.int`. + +Please note that operations like input transposition or FP8 casting must be handled separately by the user, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves. + +#### Normal dense GEMMs (non-grouped) + +To perform a basic non-grouped FP8 GEMM, call the `fp8_gemm_{nt, nn, tn, tt}` function. For more details, please refer to the function documentation. + +#### Grouped GEMMs (contiguous layout) + +Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_mk_alignment_for_contiguous_layout()`). For more information, please refer to the `m_grouped_fp8_gemm_{nt, nn}_contiguous` function documentation. + +We also provide a K-axis-grouped API for MoE weight backward (with M and N must remain fixed), please refer to `k_grouped_fp8_gemm_tn_contiguous` for more information. + +#### Grouped GEMMs (masked layout) + +During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions. + +Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. + +#### V3.2 MQA kernels for the indexer + +The kernel family has two versions, non-paged (for prefilling) and paged (for decoding). +Take the non-paged version `fp8_mqa_logits` as an example. It has 6 inputs: + +- `q`, E4M3 tensor with shape `[seq_len, num_heads, head_dim]` +- `kv`, E4M3 tensor (shaped as `[seq_len_kv, head_dim]`) with float SF (shaped as `[seq_len_kv]`) +- `weights`, float tensor with shape `[seq_len, num_heads]` +- `cu_seq_len_k_start` and `cu_seq_len_k_end`, int tensor with shape `[seq_len]` +- `clean_logits`, whether to clean the unfilled logits into `-inf` + +The output tensor is shaped as `[seq_len, seq_len_kv]`, indicating token-to-token logits. +For each token `i` in `q`, it will iterate all tokens `j` from `[cu_seq_len_k_start[i], cu_seq_len_k_end[i])`, +and calculate the logit `out[i, j]` as: + +```python +kv_j = kv[0][j, :] * kv[1][j].unsqueeze(1) # [head_dim] +out_ij = q[i, :, :] @ kv_j # [num_heads] +out_ij = out_ij.relu() * weights[i, :] # [num_heads] +out_ij = out_ij.sum() # Scalar +``` + +For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`. + +#### Utilities + +The library provides some utility functions besides the above kernels: + +- `deep_gemm.set_num_sms`: set the maximum SM count to use +- `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set) +- `deep_gemm.set_tc_util`: set an approximated tensor core utilization ratio +- `deep_gemm.get_tc_util`: get the current tensor core utilization ratio +- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout +- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size +- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout +- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor +- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0) +- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel + +The library also provides some environment variables, which may be useful: + +- General + - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default +- JIT cache related + - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default +- NVCC/NVRTC selections + - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default + - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default +- Compiler options + - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default + - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default +- Heuristic selection + - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default + +For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. + +## Acknowledgement + +DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers! + +## License + +This code repository is released under [the MIT License](LICENSE). + +-- + +vendored at commit 477618cd51baffca09c4b0b87e97c03fe827ef03 \ No newline at end of file diff --git a/deep-gemm/build.sh b/deep-gemm/build.sh new file mode 100755 index 00000000..abdfc406 --- /dev/null +++ b/deep-gemm/build.sh @@ -0,0 +1,12 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Remove old dist file, build files, and install +rm -rf build dist +rm -rf *.egg-info +python setup.py bdist_wheel + +# Open users' original directory +cd "$original_dir" diff --git a/deep-gemm/build.toml b/deep-gemm/build.toml new file mode 100644 index 00000000..06bd443c --- /dev/null +++ b/deep-gemm/build.toml @@ -0,0 +1,105 @@ +[general] +name = "deep-gemm" +license = "MIT" +backends = ["cuda"] +version = 1 + +[general.cuda] +minver = "12.3" + +[torch] +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", +] +pyext = ["py", "cuh", "hpp", "h"] + +[kernel.deep_gemm] +backend = "cuda" +cuda-minver = "12.3" +cuda-capabilities = ["9.0a"] +depends = ["torch", "cutlass_4_0"] +include = [ + "csrc", + "deep_gemm/include", +] +cuda-flags = [ + "-std=c++17", + "-O3", + "-DNDEBUG", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "-Wno-deprecated-declarations", +] +src = [ + # Compiled source + "csrc/impl.cu", + # API headers + "csrc/apis/attention.hpp", + "csrc/apis/einsum.hpp", + "csrc/apis/gemm.hpp", + "csrc/apis/hyperconnection.hpp", + "csrc/apis/layout.hpp", + "csrc/apis/runtime.hpp", + # JIT infrastructure headers + "csrc/jit/cache.hpp", + "csrc/jit/compiler.hpp", + "csrc/jit/device_runtime.hpp", + "csrc/jit/handle.hpp", + "csrc/jit/kernel_runtime.hpp", + # JIT kernel heuristics + "csrc/jit_kernels/heuristics/common.hpp", + "csrc/jit_kernels/heuristics/sm100.hpp", + "csrc/jit_kernels/heuristics/sm90.hpp", + # JIT kernel implementations + "csrc/jit_kernels/impls/epilogue.hpp", + "csrc/jit_kernels/impls/runtime_utils.hpp", + "csrc/jit_kernels/impls/sm100_bf16_gemm.hpp", + "csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp", + "csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp", + "csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp", + "csrc/jit_kernels/impls/sm90_bf16_gemm.hpp", + "csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp", + "csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp", + "csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp", + "csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp", + "csrc/jit_kernels/impls/smxx_clean_logits.hpp", + "csrc/jit_kernels/impls/smxx_cublaslt.hpp", + "csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp", + "csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp", + "csrc/jit_kernels/impls/smxx_layout.hpp", + # Utility headers + "csrc/utils/compatibility.hpp", + "csrc/utils/exception.hpp", + "csrc/utils/format.hpp", + "csrc/utils/hash.hpp", + "csrc/utils/layout.hpp", + "csrc/utils/lazy_init.hpp", + "csrc/utils/math.hpp", + "csrc/utils/system.hpp", + # Runtime JIT headers (deep_gemm/include) + "deep_gemm/include/deep_gemm/common/cute_tie.cuh", + "deep_gemm/include/deep_gemm/common/epilogue_utils.cuh", + "deep_gemm/include/deep_gemm/common/types.hpp", + "deep_gemm/include/deep_gemm/common/sm90_utils.cuh", + "deep_gemm/include/deep_gemm/common/reduction.cuh", + "deep_gemm/include/deep_gemm/common/utils.cuh", + "deep_gemm/include/deep_gemm/common/tma_utils.cuh", + "deep_gemm/include/deep_gemm/common/sm100_utils.cuh", + "deep_gemm/include/deep_gemm/common/scheduler.cuh", + "deep_gemm/include/deep_gemm/impls/smxx_layout.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh", + "deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh", + "deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh", + "deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh", +] diff --git a/deep-gemm/csrc/apis/attention.hpp b/deep-gemm/csrc/apis/attention.hpp new file mode 100644 index 00000000..c83233d0 --- /dev/null +++ b/deep-gemm/csrc/apis/attention.hpp @@ -0,0 +1,281 @@ +#pragma once + +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/smxx_fp8_mqa_logits.hpp" +#include "../jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp" +#include "../jit_kernels/impls/smxx_clean_logits.hpp" +#endif + +#include "layout.hpp" + +namespace deep_gemm::attention { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static void fp8_gemm_nt_skip_head_mid(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::tuple& head_splits, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + if (fp8_requires_k_major()) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + } + + // D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a.first); + const auto& [n , k_] = get_shape<2>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Check head splits and N + const auto& [left, mid, right] = head_splits; + DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid); + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + const auto& [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, disable_ue8m0_cast); + DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128); + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) { + const auto& major_sfb = get_major_type_ab(sfb); + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + // NOTES: Only granularity 128 and FP8 are exposed in the API + sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, + 128, 128, major_a, major_b, compiled_dims, epilogue_type); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static torch::Tensor fp8_mqa_logits(const torch::Tensor& q, + const std::pair& kv, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const bool& clean_logits, + const int& max_seqlen_k) { + const auto& [seq_len, num_heads, head_dim] = get_shape<3>(q); + const auto& [seq_len_kv, head_dim_] = get_shape<2>(kv.first); + const auto& [seq_len_, num_heads_] = get_shape<2>(weights); + const auto& [seq_len_kv_] = get_shape<1>(kv.second); + + DG_HOST_ASSERT(seq_len == seq_len_); + DG_HOST_ASSERT(num_heads == num_heads_ and head_dim == head_dim_); + DG_HOST_ASSERT(seq_len_kv == seq_len_kv_); + DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len); + DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len); + + DG_HOST_ASSERT(q.is_contiguous() and kv.first.is_contiguous()); + DG_HOST_ASSERT(kv.second.is_contiguous()); + DG_HOST_ASSERT(weights.is_contiguous()); + DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous()); + DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous()); + + DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(kv.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(kv.second.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt); + DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt); + + constexpr int seq_len_alignment = 4; + constexpr int block_kv = 256; + const auto aligned_seq_len = align(seq_len, seq_len_alignment); + + torch::Tensor logits; + int stride_logits; + if (max_seqlen_k == 0) { + stride_logits = align(seq_len_kv + block_kv, 4); + logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat)); + logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)}); + } else { + stride_logits = align(max_seqlen_k, block_kv); + logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat)); + logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, max_seqlen_k)}); + DG_HOST_ASSERT(not clean_logits); + } + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 or arch_major == 10) { + smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, + seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, seq_len_alignment); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Clean unfilled logits + if (clean_logits) + smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, stride_logits); + return logits; +} + +static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) { + const bool is_context_lens_2d = context_lens.dim() == 2; + int batch_size = 0, next_n = 0; + if (is_context_lens_2d) { + batch_size = context_lens.size(0); + next_n = context_lens.size(1); + } else { + DG_HOST_ASSERT(context_lens.dim() == 1); + batch_size = context_lens.size(0); + } + DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); + DG_HOST_ASSERT(context_lens.is_contiguous()); + + auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options()); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 or arch_major == 10) { + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + return schedule_metadata; +} + +static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& fused_kv_cache, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& max_context_len, + const bool& clean_logits) { + const auto& [batch_size, next_n, num_heads, head_dim] = get_shape<4>(q); + const auto& [num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf] = get_shape<4>(fused_kv_cache); + const auto& [batch_size_next_n, num_heads_] = get_shape<2>(weights); + const auto& [batch_size_, max_block_len] = get_shape<2>(block_table); + const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta); + const auto& num_sms = device_runtime->get_num_sms(); + const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0); + const auto& block_table_stride = block_table.stride(0); + + const bool is_context_lens_2d = context_lens.dim() == 2; + if (is_context_lens_2d) { + const auto& [batch_size__, next_n_] = get_shape<2>(context_lens); + DG_HOST_ASSERT(batch_size == batch_size__ and next_n == next_n_); + } else { + DG_HOST_ASSERT(context_lens.dim() == 1); + const auto& [batch_size__] = get_shape<1>(context_lens); + DG_HOST_ASSERT(batch_size == batch_size__); + } + + DG_HOST_ASSERT(batch_size == batch_size_); + DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n); + DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1); + DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast(sizeof(float))); + DG_HOST_ASSERT(schedule_meta_size == num_sms + 1 and meta_info_size == 2); + + DG_HOST_ASSERT(next_n == 1 or next_n == 2); + DG_HOST_ASSERT(block_kv == 64); + + DG_HOST_ASSERT(q.is_contiguous()); + DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0); + DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf); + DG_HOST_ASSERT(fused_kv_cache.stride(2) == head_dim_with_sf); + DG_HOST_ASSERT(fused_kv_cache.stride(3) == 1); + DG_HOST_ASSERT(weights.is_contiguous()); + DG_HOST_ASSERT(context_lens.is_contiguous()); + DG_HOST_ASSERT(block_table.stride(1) == 1); + DG_HOST_ASSERT(schedule_meta.is_contiguous()); + + DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte); + DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); + DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt); + DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt); + + // Derive FP8 values and SF tensor from KV cache + const auto& kv_cache = torch::from_blob( + fused_kv_cache.data_ptr(), + {num_kv_blocks, block_kv, head_dim}, + {kv_cache_stride_bytes, head_dim, 1}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn) + ); + const auto& kv_cache_scales = torch::from_blob( + fused_kv_cache.data_ptr() + block_kv * head_dim, + {num_kv_blocks, block_kv}, + {kv_cache_stride_bytes / static_cast(sizeof(float)), 1}, + torch::TensorOptions().dtype(torch::kFloat32) + ); + + // Allocate output + constexpr int split_kv = 256; + const auto& aligned_max_context_len = align(max_context_len, split_kv); + auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat)); + logits = logits.slice(-1, 0, max_context_len); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 or arch_major == 10) { + smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta, + batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d, + kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, split_kv); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Clean unfilled logits + if (clean_logits) { + DG_HOST_ASSERT(not is_context_lens_2d); + smxx_clean_logits(logits, std::nullopt, context_lens, next_n, batch_size * next_n, max_context_len, aligned_max_context_len); + } + return logits; +} + +#endif + +#ifdef DG_USE_PYBIND11 +static void register_apis(pybind11::module_& m) { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"), + py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_mqa_logits", &fp8_mqa_logits, + py::arg("q"), py::arg("kv"), py::arg("weights"), + py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"), + py::arg("clean_logits") = true, + py::arg("max_seqlen_k") = 0); + m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata, + py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms")); + m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits, + py::arg("q"), py::arg("kv_cache"), py::arg("weights"), + py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"), + py::arg("max_context_len"), py::arg("clean_logits") = false); +#endif +} +#endif + +} // namespace deep_gemm::attention diff --git a/deep-gemm/csrc/apis/einsum.hpp b/deep-gemm/csrc/apis/einsum.hpp new file mode 100644 index 00000000..ad489923 --- /dev/null +++ b/deep-gemm/csrc/apis/einsum.hpp @@ -0,0 +1,234 @@ +#pragma once + +#ifdef DG_USE_PYBIND11 +#include +#include +#endif + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/layout.hpp" +#include "../utils/compatibility.hpp" +#include "gemm.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp" +#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp" +#include "../jit_kernels/impls/sm90_bf16_gemm.hpp" +#include "../jit_kernels/impls/sm100_bf16_gemm.hpp" +#include "../jit_kernels/impls/smxx_cublaslt.hpp" +#endif + +namespace deep_gemm::einsum { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, + const std::optional& c) { + // Currently FP32 only support the accumulated expression + if (d.scalar_type() == torch::kFloat) { + DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides()); + } else { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(not c.has_value()); + + const auto& workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32)); + DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(), + c10::cuda::getCurrentCUDAStream())); + bmk_bnk_mn(a, b, workspace, workspace); + + // This line has an implicit FP32-to-BF16 casting + d.copy_(workspace); + return; + } + + DG_HOST_ASSERT(a.is_contiguous()); + DG_HOST_ASSERT(b.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + + const auto& [s , m, k ] = get_shape<3>(a); + const auto& [s_, n, k_] = get_shape<3>(b); + DG_HOST_ASSERT(s == s_ and k == k_); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k); + } else if (arch_major == 10) { + sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) { + const auto& [b , h , r ] = get_shape<3>(A); + const auto& [h_, d , r_] = get_shape<3>(B); + const auto& [b_, h__, d_] = get_shape<3>(D); + DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__); + + DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1); + DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1); + DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (use_cublaslt) { + cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else if (arch_major == 9) { + sm90_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else if (arch_major == 10) { + sm100_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) { + const auto& [b , h , d ] = get_shape<3>(A); + const auto& [h_, d_ , r ] = get_shape<3>(B); + const auto& [b_, h__, r_] = get_shape<3>(D); + DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__); + + DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1); + DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1); + DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (use_cublaslt) { + cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else if (arch_major == 9) { + sm90_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else if (arch_major == 10) { + sm100_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void einsum(const std::string& expr, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const bool& use_cublaslt) { + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + if (c.has_value()) { + DG_HOST_ASSERT(c->scalar_type() == torch::kFloat); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + } + + // Some hardcoded Einstein sum kernels + // TODO: support any expression + // TODO: canonicalize expression + if (expr == "bmk,bnk->mn") { + DG_HOST_ASSERT(not use_cublaslt); + bmk_bnk_mn(a, b, d, c); + } else if (expr == "bhr,hdr->bhd") { + DG_HOST_ASSERT(not c.has_value()); + bhr_hdr_bhd(a, b, d, use_cublaslt); + } else if (expr == "bhd,hdr->bhr") { + DG_HOST_ASSERT(not c.has_value()); + bhd_hdr_bhr(a, b, d, use_cublaslt); + } else { + DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr)); + } +} + +static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + const std::string& compiled_dims) { + // Shape must be `[B, M, K] @ [B, N, K].T` + const auto& major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; + const auto& major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; + DG_HOST_ASSERT(a.stride(-1) == 1 or a.stride(-2) == 1); + DG_HOST_ASSERT(b.stride(-1) == 1 or b.stride(-2) == 1); + DG_HOST_ASSERT(d.stride(-1) == 1); + + // Type and shape checks + const auto& [batch_size , m , k ] = get_shape<3>(a); + const auto& [batch_size_ , n , k_] = get_shape<3>(b); + const auto& [batch_size__, m_, n_] = get_shape<3>(d); + DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size_); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Early return for trivial cases + if (batch_size == 0 or gemm::early_return(m, n, k, d, c)) + return; + + // Transform scaling factors + const auto& [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 10) { + sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, compiled_dims); + } else { + const auto& major_sfb = get_major_type_ab(sfb); + sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims); + } +} + +static void fp8_einsum(const std::string& expr, + const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::tuple& recipe) { + // Some hardcoded Einstein sum kernels + const auto arch_major = device_runtime->get_arch_major(); + if (expr == "bhr,hdr->bhd") { + // Permute dims to satisfy the order of (batch_size, m, n, k) + // (batch_size, m, n, k): (h, b, d, r) + const auto& perm_a = a.first.permute({1, 0, 2}); + const auto& perm_sfa = a.second.permute({1, 0, 2}); + const auto& perm_d = d.permute({1, 0, 2}); + const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; + fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk"); + } else if (expr == "bhd,hdr->bhr" and arch_major == 10) { + // (batch_size, m, n, k): (h, b, r, d) + const auto& perm_a = a.first.permute({1, 0, 2}); + const auto& perm_sfa = a.second.permute({1, 0, 2}); + const auto& perm_b = b.first.permute({0, 2, 1}); + const auto& perm_sfb = b.second.permute({0, 2, 1}); + const auto& perm_d = d.permute({1, 0, 2}); + const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; + fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk"); + } else if (expr == "bhd,bhr->hdr" and arch_major == 10) { + // (batch_size, m, n, k): (h, d, r, b) + const auto& perm_a = a.first.permute({1, 2, 0}); + const auto& perm_sfa = a.second.permute({1, 2, 0}); + const auto& perm_b = b.first.permute({1, 2, 0}); + const auto& perm_sfb = b.second.permute({1, 2, 0}); + fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, d, c, recipe, "mn"); + } else { + DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr)); + } +} +#endif + +#ifdef DG_USE_PYBIND11 +static void register_apis(pybind11::module_& m) { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + m.def("einsum", &einsum, + py::arg("expr"), py::arg("a"), py::arg("b"), + py::arg("d"), py::arg("c") = std::nullopt, + py::arg("use_cublaslt") = false); + m.def("fp8_einsum", &fp8_einsum, + py::arg("expr"), py::arg("a"), py::arg("b"), + py::arg("d"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 128, 128)); +#endif +} +#endif + +} // namespace deep_gemm::einsum diff --git a/deep-gemm/csrc/apis/gemm.hpp b/deep-gemm/csrc/apis/gemm.hpp new file mode 100644 index 00000000..f12517cf --- /dev/null +++ b/deep-gemm/csrc/apis/gemm.hpp @@ -0,0 +1,714 @@ +#pragma once + +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm90_bf16_gemm.hpp" +#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm100_bf16_gemm.hpp" +#endif + +#include "../jit_kernels/impls/smxx_cublaslt.hpp" + +#include "layout.hpp" + +namespace deep_gemm::gemm { + +static bool early_return(const int& m, const int &n, const int& k, + const torch::Tensor& d, const std::optional& c) { + // Do nothing if the problem is empty + if (m == 0 or n == 0) + return true; + + // Checks + const bool& is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr(); + if (is_cd_same) + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + if (c.has_value()) { + check_major_type_cd(c.value()); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + } + + // No accumulation + if (k == 0) { + if (not is_cd_same) + c.has_value() ? d.copy_(c.value()) : d.zero_(); + return true; + } + + // With accumulation, do copy before GEMM (assuming the GEMM kernel does not support different C/D) + if (c.has_value() and not is_cd_same) + d.copy_(c.value()); + return false; +} + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + +static void fp8_fp4_gemm_nt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + if (fp8_requires_k_major()) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + } + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Early return for trivial cases + if (early_return(m, n, k, d, c)) + return; + + // Transform SFA and SFB into compute-required layout + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast); + + // Dispatch into different implements + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + const int gran_n = recipe.has_value() ? std::get<1>(recipe.value()) : std::get<0>(recipe_b.value()); + if (gran_n == 1) { + sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + const auto& major_sfb = get_major_type_ab(sfb); + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims); + } + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void fp8_fp4_gemm_nn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); +} + +static void fp8_fp4_gemm_tn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, + {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); +} + +static void fp8_fp4_gemm_tt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); +} + +static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + if (fp8_requires_k_major()) + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); + + // Type and shape checks + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + + // Layout checks + if (use_psum_layout) { + const auto& [num_groups_] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(num_groups == num_groups_); + } else { + const auto& [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(m == m__); + DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); + } + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, num_groups, disable_ue8m0_cast); + + // Dispatch implementation + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + const auto& major_sfb = get_major_type_ab(sfb); + DG_HOST_ASSERT(not use_psum_layout); + sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, major_sfb, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b, + compiled_dims, use_psum_layout, expected_m_for_psum_layout); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast, + const bool& use_psum_layout) { + m_grouped_fp8_fp4_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, + d, grouped_layout, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt); +} + +static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& expected_m, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto arch_major = device_runtime->get_arch_major(); + const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [num_groups__, m_, n_] = get_shape<3>(d); + const auto num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Transform scaling factors + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast); + + // Dispatch implementation + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + const auto& major_sfb = get_major_type_ab(sfb); + sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::tuple& recipe, + const std::string& compiled_dims) { + // Must be 1D1D kernel + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + + // Shape checks + const auto& [num_groups, m, n] = get_shape<3>(d); + const auto& [sum_k_ , m_] = get_shape<2>(a.first); + const auto& [sum_k__, n_] = get_shape<2>(b.first); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); + + // Contiguity checks + DG_HOST_ASSERT(a.first.is_contiguous()); + DG_HOST_ASSERT(b.first.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); + + // Early return for trivial cases + if (early_return(m, n, std::accumulate(ks.begin(), ks.end(), 0), d, c)) + return; + + // Transform SF with padding + const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); + const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 10) { + sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void k_grouped_fp8_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::tuple& recipe, + const std::string& compiled_dims) { + // Must be 1D1D kernel + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + + // Shape checks + const auto& [num_groups, m, n] = get_shape<3>(d); + const auto& sum_mk = a.first.numel(); + const auto& sum_nk = b.first.numel(); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(sum_mk == static_cast(sum_k) * m); + DG_HOST_ASSERT(sum_nk == static_cast(sum_k) * n); + + // Contiguity checks + DG_HOST_ASSERT(a.first.is_contiguous()); + DG_HOST_ASSERT(b.first.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); + + // Early return for trivial cases + if (early_return(m, n, accumulate(ks.begin(), ks.end(), 0), d, c)) + return; + + // Transform SF with padding + const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); + const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); + + // Allocate tensormap buffer + // `4` means the double buffering for both A and B operands (2 * 2) + const auto& num_sms = device_runtime->get_num_sms(); + const auto& tensor_map_buffer = torch::empty({num_sms * 4 * static_cast(sizeof(CUtensorMap))}, + a.first.options().dtype(torch::kByte)); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer, + cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} +#endif + +#if DG_TENSORMAP_COMPATIBLE +static void bf16_gemm_nt(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a); + const auto& [n , k_] = get_shape<2>(b); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Early return for trivial cases + if (early_return(m, n, k, d, c)) + return; + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bf16_gemm_nn(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a, b.transpose(0, 1), d, c, compiled_dims); +} + +static void bf16_gemm_tn(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c, compiled_dims); +} + +static void bf16_gemm_tt(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a.transpose(0, 1), b, d, c, compiled_dims); +} + +static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); + + // Type and shape checks + const auto& [m, k] = get_shape<2>(a); + const auto& [num_groups, n, k_] = get_shape<3>(b); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + + // Layout checks + if (use_psum_layout) { + const auto& [num_groups_] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(num_groups == num_groups_); + } else { + const auto& [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(m == m__); + DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); + } + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + DG_HOST_ASSERT(not use_psum_layout); + sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, compiled_dims, + use_psum_layout, expected_m_for_psum_layout); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void m_grouped_bf16_gemm_nn_contiguous(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& use_psum_layout) { + m_grouped_bf16_gemm_nt_contiguous(a, b.transpose(1, 2), + d, grouped_layout, compiled_dims, use_psum_layout, std::nullopt); +} + +static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& masked_m, + const int& expected_m, const std::string& compiled_dims) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto& [num_groups, m, k] = get_shape<3>(a); + const auto& [num_groups_, n, k_] = get_shape<3>(b); + const auto& [num_groups__, m_, n_] = get_shape<3>(d); + const auto& num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_m_grouped_bf16_gemm_masked(a, b, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::string& compiled_dims) { + // Shape checks + const auto& [num_groups, m, n] = get_shape<3>(d); + const auto& [sum_k_ , m_] = get_shape<2>(a); + const auto& [sum_k__, n_] = get_shape<2>(b); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); + + // Contiguity checks + DG_HOST_ASSERT(a.is_contiguous()); + DG_HOST_ASSERT(b.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); + + // Early return for trivial cases + if (early_return(m, n, std::accumulate(ks.begin(), ks.end(), 0), d, c)) + return; + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else if (arch_major == 10) { + sm100_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} +#endif + +static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a); + const auto& [n , k_] = get_shape<2>(b); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + + // Early return for trivial cases + if (early_return(m, n, k, d, c)) + return; + + cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b); +} + +static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a, b.transpose(0, 1), d, c); +} + +static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c); +} + +static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a.transpose(0, 1), b, d, c); +} + +#ifdef DG_USE_PYBIND11 +static void register_apis(pybind11::module_& m) { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + // FP8 FP4 GEMMs + m.def("fp8_fp4_gemm_nt", &fp8_fp4_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_fp4_gemm_nn", &fp8_fp4_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_fp4_gemm_tn", &fp8_fp4_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_fp4_gemm_tt", &fp8_fp4_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_fp4_gemm_nt_contiguous", &m_grouped_fp8_fp4_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false, + py::arg("use_psum_layout") = false, + py::arg("expected_m_for_psum_layout") = std::nullopt); + m.def("m_grouped_fp8_fp4_gemm_nn_contiguous", &m_grouped_fp8_fp4_gemm_nn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false, + py::arg("use_psum_layout") = false); + m.def("m_grouped_fp8_fp4_gemm_nt_masked", &m_grouped_fp8_fp4_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 1, 128), + py::arg("compiled_dims") = "mn"); + m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 1, 128), + py::arg("compiled_dims") = "mn"); + + // FP8 GEMM alias names + m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt"); + m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn"); + m.attr("fp8_gemm_tn") = m.attr("fp8_fp4_gemm_tn"); + m.attr("fp8_gemm_tt") = m.attr("fp8_fp4_gemm_tt"); + m.attr("m_grouped_fp8_gemm_nt_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nt_contiguous"); + m.attr("m_grouped_fp8_gemm_nn_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nn_contiguous"); + m.attr("m_grouped_fp8_gemm_nt_masked") = m.attr("m_grouped_fp8_fp4_gemm_nt_masked"); +#endif + +#if DG_TENSORMAP_COMPATIBLE + // BF16 GEMMs + m.def("bf16_gemm_nt", &bf16_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "nk"); + m.def("bf16_gemm_nn", &bf16_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "nk"); + m.def("bf16_gemm_tn", &bf16_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); + m.def("bf16_gemm_tt", &bf16_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); + m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("compiled_dims") = "nk", + py::arg("use_psum_layout") = false, + py::arg("expected_m_for_psum_layout") = std::nullopt); + m.def("m_grouped_bf16_gemm_nn_contiguous", &m_grouped_bf16_gemm_nn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("compiled_dims") = "nk", + py::arg("use_psum_layout") = false); + m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("compiled_dims") = "nk"); + m.def("k_grouped_bf16_gemm_tn_contiguous", &k_grouped_bf16_gemm_tn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); +#endif + + // cuBLASLt GEMMs + m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); +} +#endif + +} // namespace deep_gemm::gemm diff --git a/deep-gemm/csrc/apis/hyperconnection.hpp b/deep-gemm/csrc/apis/hyperconnection.hpp new file mode 100644 index 00000000..713de4a3 --- /dev/null +++ b/deep-gemm/csrc/apis/hyperconnection.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp" +#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp" +#endif + +namespace deep_gemm::hyperconnection { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static void tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const std::optional& num_splits) { + // A and B must be K-major, D must be N-major + DG_HOST_ASSERT(get_major_type_ab(a) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(b) == cute::UMMA::Major::K); + check_major_type_cd(d); + + // S must be contiguous + DG_HOST_ASSERT(sqr_sum.is_contiguous()); + + // Type and shape checks + const auto& [m, k ] = get_shape<2>(a); + const auto& [n, k_] = get_shape<2>(b); + if (num_splits.has_value()) { + const auto& [num_splits_, m_, n_] = get_shape<3>(d); + const auto& [num_splits__, m__] = get_shape<2>(sqr_sum); + DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + } else { + const auto& [m_, n_] = get_shape<2>(d); + const auto& [m__] = get_shape<1>(sqr_sum); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + } + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(sqr_sum.scalar_type() == torch::kFloat); + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1); + } else if (arch_major == 10) { + sm100_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +#endif + +#ifdef DG_USE_PYBIND11 +static void register_apis(pybind11::module_& m) { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("sqr_sum"), + py::arg("num_splits") = std::nullopt); +#endif +} +#endif + +} // namespace deep_gemm::hyperconnection diff --git a/deep-gemm/csrc/apis/layout.hpp b/deep-gemm/csrc/apis/layout.hpp new file mode 100644 index 00000000..3ec1c6a6 --- /dev/null +++ b/deep-gemm/csrc/apis/layout.hpp @@ -0,0 +1,122 @@ +#pragma once + +#include "../utils/layout.hpp" +#include "../utils/compatibility.hpp" + +#if DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/smxx_layout.hpp" +#endif + +namespace deep_gemm::layout { + +#if DG_TENSORMAP_COMPATIBLE +static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const std::optional>& recipe, + const std::optional>& recipe_ab, + const std::optional& num_groups, + const bool& is_sfa, + const bool& disable_ue8m0_cast) { + const auto& arch_major = device_runtime->get_arch_major(); + + int gran_mn, gran_k; + if (recipe.has_value()) { + DG_HOST_ASSERT(not recipe_ab.has_value()); + gran_mn = is_sfa ? std::get<0>(recipe.value()) : std::get<1>(recipe.value()); + gran_k = std::get<2>(recipe.value()); + } else { + DG_HOST_ASSERT(recipe_ab.has_value()); + std::tie(gran_mn, gran_k) = recipe_ab.value(); + } + + // Pre-transform checks + check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups); + + // (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return get_mn_major_tma_aligned_tensor(sf); + + // (FP32, 128, 128) on SM90: no need to transform, check SFB requirements + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat); + + // (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and (gran_k == 32 or gran_k == 128) and arch_major == 10) { + DG_HOST_ASSERT(not disable_ue8m0_cast); + const auto& broadcasted = gran_mn == 1 ? sf : + sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn)); + return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); + } + + // (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); + + DG_HOST_UNREACHABLE("Unknown SF transformation"); +} + +static std::tuple transform_sf_pair_into_required_layout( + const torch::Tensor& sfa, const torch::Tensor& sfb, + const int& m, const int& n, const int& k, + std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::optional& num_groups_a, + const std::optional& num_groups_b, + const bool& disable_ue8m0_cast = false) { + DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value()); + if (not recipe_a.has_value() and not recipe.has_value()) + recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type()); + const auto transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe, recipe_a, num_groups_a, true, disable_ue8m0_cast); + const auto transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe, recipe_b, num_groups_b, false, disable_ue8m0_cast); + const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value()); + const int gran_k_b = recipe_b.has_value() ? std::get<1>(recipe_b.value()) : std::get<2>(recipe.value()); + return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b); +} + +static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::tuple& recipe) { + DG_HOST_ASSERT(sf.dim() == 2); + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + const auto& arch_major = device_runtime->get_arch_major(); + + // FP32 on SM90 + if (sf.scalar_type() == torch::kFloat and arch_major == 9) + return get_mn_major_tma_aligned_tensor(sf); + + // FP32 on SM100 + if (sf.scalar_type() == torch::kFloat and arch_major == 10) + return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); + + // INT on SM100 + if (sf.scalar_type() == torch::kInt and arch_major == 10) + DG_HOST_UNREACHABLE("Unimplemented"); + + DG_HOST_UNREACHABLE("Unknown cases"); +} + +#endif + +#ifdef DG_USE_PYBIND11 +static void register_apis(pybind11::module_& m) { + +#if DG_TENSORMAP_COMPATIBLE + m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout, + py::arg("sf"), py::arg("mn"), py::arg("k"), + py::arg("recipe") = std::nullopt, py::arg("recipe_ab") = std::nullopt, + py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false, + py::arg("disable_ue8m0_cast") = false); + + m.def("get_tma_aligned_size", &get_tma_aligned_size); + m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor); + m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor); + m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); +#endif + + m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout); +} +#endif + +} // namespace deep_gemm::layout diff --git a/deep-gemm/csrc/apis/runtime.hpp b/deep-gemm/csrc/apis/runtime.hpp new file mode 100644 index 00000000..725cc09c --- /dev/null +++ b/deep-gemm/csrc/apis/runtime.hpp @@ -0,0 +1,56 @@ +#pragma once + +#if DG_TENSORMAP_COMPATIBLE +#include "../jit/compiler.hpp" +#endif +#include "../jit/device_runtime.hpp" + +namespace deep_gemm::runtime { + +static void deep_gemm_set_num_sms(int64_t new_num_sms) { + device_runtime->set_num_sms(static_cast(new_num_sms)); +} + +static int64_t deep_gemm_get_num_sms() { + return device_runtime->get_num_sms(); +} + +static void deep_gemm_set_tc_util(int64_t new_tc_util) { + device_runtime->set_tc_util(static_cast(new_tc_util)); +} + +static int64_t deep_gemm_get_tc_util() { + return device_runtime->get_tc_util(); +} + +static void deep_gemm_init(const std::string& library_root_path, const std::string& cuda_home_path_by_python) { +#if DG_TENSORMAP_COMPATIBLE + Compiler::prepare_init(library_root_path, cuda_home_path_by_python); + KernelRuntime::prepare_init(cuda_home_path_by_python); +#endif +} + +#ifdef DG_USE_PYBIND11 +static void register_apis(pybind11::module_& m) { + m.def("set_num_sms", [&](const int& new_num_sms) { + device_runtime->set_num_sms(new_num_sms); + }); + m.def("get_num_sms", [&]() { + return device_runtime->get_num_sms(); + }); + m.def("set_tc_util", [&](const int& new_tc_util) { + device_runtime->set_tc_util(new_tc_util); + }); + m.def("get_tc_util", [&]() { + return device_runtime->get_tc_util(); + }); + m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) { +#if DG_TENSORMAP_COMPATIBLE + Compiler::prepare_init(library_root_path, cuda_home_path_by_python); + KernelRuntime::prepare_init(cuda_home_path_by_python); +#endif + }); +} +#endif + +} // namespace deep_gemm::runtime diff --git a/deep-gemm/csrc/impl.cu b/deep-gemm/csrc/impl.cu new file mode 100644 index 00000000..ac2458db --- /dev/null +++ b/deep-gemm/csrc/impl.cu @@ -0,0 +1,433 @@ +#include +#include +#include + +#include "../torch-ext/torch_binding.h" +#include "apis/attention.hpp" +#include "apis/einsum.hpp" +#include "apis/hyperconnection.hpp" +#include "apis/gemm.hpp" +#include "apis/layout.hpp" +#include "apis/runtime.hpp" + +using Tensor = at::Tensor; + +// Helper: convert 1D int tensor to std::vector +static std::vector tensor_to_vec_int(const Tensor& t) { + auto cpu_t = t.cpu().contiguous(); + auto ptr = cpu_t.data_ptr(); + return std::vector(ptr, ptr + cpu_t.numel()); +} + +// Helper: reconstruct optional recipe tuple +static std::optional> make_recipe3( + int64_t r0, int64_t r1, int64_t r2, bool has) { + if (!has) return std::nullopt; + return std::make_tuple(static_cast(r0), static_cast(r1), static_cast(r2)); +} + +static std::optional> make_recipe2( + int64_t r0, int64_t r1, bool has) { + if (!has) return std::nullopt; + return std::make_tuple(static_cast(r0), static_cast(r1)); +} + +// Runtime ops + +void deep_gemm_init(const std::string& path, const std::string& cuda_home) { + deep_gemm::runtime::deep_gemm_init(path, cuda_home); +} + +void deep_gemm_set_num_sms(int64_t num_sms) { + deep_gemm::runtime::deep_gemm_set_num_sms(num_sms); +} + +int64_t deep_gemm_get_num_sms() { + return deep_gemm::runtime::deep_gemm_get_num_sms(); +} + +void deep_gemm_set_tc_util(int64_t tc_util) { + deep_gemm::runtime::deep_gemm_set_tc_util(tc_util); +} + +int64_t deep_gemm_get_tc_util() { + return deep_gemm::runtime::deep_gemm_get_tc_util(); +} + +// Layout ops + +int64_t deep_gemm_get_mk_alignment_for_contiguous_layout() { + return deep_gemm::get_mk_alignment_for_contiguous_layout(); +} + +Tensor deep_gemm_get_tma_aligned_size(int64_t mn, int64_t element_size) { + // Returns scalar tensor to satisfy TORCH_LIBRARY (can't return plain int) + auto result = deep_gemm::get_tma_aligned_size(static_cast(mn), static_cast(element_size)); + return torch::tensor(result, torch::kInt64); +} + +Tensor deep_gemm_get_mn_major_tma_aligned_tensor(const Tensor& sf) { + return deep_gemm::get_mn_major_tma_aligned_tensor(sf); +} + +Tensor deep_gemm_get_mn_major_tma_aligned_packed_ue8m0_tensor(const Tensor& sf) { + return deep_gemm::get_mn_major_tma_aligned_packed_ue8m0_tensor(sf); +} + +Tensor deep_gemm_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + const Tensor& sf, const Tensor& ks_tensor, const Tensor& ks_int_tensor) { + auto ks = tensor_to_vec_int(ks_int_tensor); + return deep_gemm::get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); +} + +Tensor deep_gemm_transform_sf_into_required_layout( + const Tensor& sf, int64_t mn, int64_t k, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_ab_0, int64_t recipe_ab_1, bool has_recipe_ab, + int64_t num_groups, bool has_num_groups, + bool is_sfa, bool disable_ue8m0_cast) { + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + auto recipe_ab = make_recipe2(recipe_ab_0, recipe_ab_1, has_recipe_ab); + auto ng = has_num_groups ? std::make_optional(static_cast(num_groups)) : std::nullopt; + return deep_gemm::layout::transform_sf_into_required_layout( + sf, static_cast(mn), static_cast(k), + recipe, recipe_ab, ng, is_sfa, disable_ue8m0_cast); +} + +// GEMM ops - FP8/FP4 + +void deep_gemm_fp8_fp4_gemm_nt( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + auto recipe_a = make_recipe2(recipe_a_0, recipe_a_1, has_recipe_a); + auto recipe_b = make_recipe2(recipe_b_0, recipe_b_1, has_recipe_b); + deep_gemm::gemm::fp8_fp4_gemm_nt(a, b, d, c, recipe, recipe_a, recipe_b, + compiled_dims, disable_ue8m0_cast); +} + +void deep_gemm_fp8_fp4_gemm_nn( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + auto recipe_a = make_recipe2(recipe_a_0, recipe_a_1, has_recipe_a); + auto recipe_b = make_recipe2(recipe_b_0, recipe_b_1, has_recipe_b); + deep_gemm::gemm::fp8_fp4_gemm_nn(a, b, d, c, recipe, recipe_a, recipe_b, + compiled_dims, disable_ue8m0_cast); +} + +void deep_gemm_fp8_fp4_gemm_tn( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + auto recipe_a = make_recipe2(recipe_a_0, recipe_a_1, has_recipe_a); + auto recipe_b = make_recipe2(recipe_b_0, recipe_b_1, has_recipe_b); + deep_gemm::gemm::fp8_fp4_gemm_tn(a, b, d, c, recipe, recipe_a, recipe_b, + compiled_dims, disable_ue8m0_cast); +} + +void deep_gemm_fp8_fp4_gemm_tt( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + auto recipe_a = make_recipe2(recipe_a_0, recipe_a_1, has_recipe_a); + auto recipe_b = make_recipe2(recipe_b_0, recipe_b_1, has_recipe_b); + deep_gemm::gemm::fp8_fp4_gemm_tt(a, b, d, c, recipe, recipe_a, recipe_b, + compiled_dims, disable_ue8m0_cast); +} + +// GEMM ops - M-grouped FP8/FP4 + +void deep_gemm_m_grouped_fp8_fp4_gemm_nt_contiguous( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& grouped_layout, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast, + bool use_psum_layout, int64_t expected_m_for_psum_layout, + bool has_expected_m_for_psum_layout) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + auto recipe_a = make_recipe2(recipe_a_0, recipe_a_1, has_recipe_a); + auto recipe_b = make_recipe2(recipe_b_0, recipe_b_1, has_recipe_b); + auto em = has_expected_m_for_psum_layout ? + std::make_optional(static_cast(expected_m_for_psum_layout)) : std::nullopt; + deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_contiguous( + a, b, d, grouped_layout, recipe, recipe_a, recipe_b, + compiled_dims, disable_ue8m0_cast, use_psum_layout, em); +} + +void deep_gemm_m_grouped_fp8_fp4_gemm_nn_contiguous( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& grouped_layout, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast, + bool use_psum_layout) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + auto recipe_a = make_recipe2(recipe_a_0, recipe_a_1, has_recipe_a); + auto recipe_b = make_recipe2(recipe_b_0, recipe_b_1, has_recipe_b); + deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nn_contiguous( + a, b, d, grouped_layout, recipe, recipe_a, recipe_b, + compiled_dims, disable_ue8m0_cast, use_psum_layout); +} + +void deep_gemm_m_grouped_fp8_fp4_gemm_nt_masked( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& masked_m, int64_t expected_m, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + auto recipe_a = make_recipe2(recipe_a_0, recipe_a_1, has_recipe_a); + auto recipe_b = make_recipe2(recipe_b_0, recipe_b_1, has_recipe_b); + deep_gemm::gemm::m_grouped_fp8_fp4_gemm_nt_masked( + a, b, d, masked_m, static_cast(expected_m), + recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); +} + +// GEMM ops - K-grouped FP8 + +void deep_gemm_k_grouped_fp8_gemm_tn_contiguous( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& ks_tensor, + const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, + const std::string& compiled_dims) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto ks = tensor_to_vec_int(ks_tensor); + auto recipe = std::make_tuple(static_cast(recipe_0), + static_cast(recipe_1), + static_cast(recipe_2)); + deep_gemm::gemm::k_grouped_fp8_gemm_tn_contiguous( + a, b, d, ks, ks_tensor, c, recipe, compiled_dims); +} + +void deep_gemm_k_grouped_fp8_gemm_nt_contiguous( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& ks_tensor, + const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, + const std::string& compiled_dims) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto ks = tensor_to_vec_int(ks_tensor); + auto recipe = std::make_tuple(static_cast(recipe_0), + static_cast(recipe_1), + static_cast(recipe_2)); + deep_gemm::gemm::k_grouped_fp8_gemm_nt_contiguous( + a, b, d, ks, ks_tensor, c, recipe, compiled_dims); +} + +// GEMM ops - BF16 + +void deep_gemm_bf16_gemm_nt( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, const std::string& compiled_dims) { + deep_gemm::gemm::bf16_gemm_nt(a, b, d, c, compiled_dims); +} + +void deep_gemm_bf16_gemm_nn( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, const std::string& compiled_dims) { + deep_gemm::gemm::bf16_gemm_nn(a, b, d, c, compiled_dims); +} + +void deep_gemm_bf16_gemm_tn( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, const std::string& compiled_dims) { + deep_gemm::gemm::bf16_gemm_tn(a, b, d, c, compiled_dims); +} + +void deep_gemm_bf16_gemm_tt( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, const std::string& compiled_dims) { + deep_gemm::gemm::bf16_gemm_tt(a, b, d, c, compiled_dims); +} + +// GEMM ops - M-grouped BF16 + +void deep_gemm_m_grouped_bf16_gemm_nt_contiguous( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& grouped_layout, const std::string& compiled_dims, + bool use_psum_layout, int64_t expected_m_for_psum_layout, + bool has_expected_m_for_psum_layout) { + auto em = has_expected_m_for_psum_layout ? + std::make_optional(static_cast(expected_m_for_psum_layout)) : std::nullopt; + deep_gemm::gemm::m_grouped_bf16_gemm_nt_contiguous( + a, b, d, grouped_layout, compiled_dims, use_psum_layout, em); +} + +void deep_gemm_m_grouped_bf16_gemm_nn_contiguous( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& grouped_layout, const std::string& compiled_dims, + bool use_psum_layout) { + deep_gemm::gemm::m_grouped_bf16_gemm_nn_contiguous( + a, b, d, grouped_layout, compiled_dims, use_psum_layout); +} + +void deep_gemm_m_grouped_bf16_gemm_nt_masked( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& masked_m, int64_t expected_m, + const std::string& compiled_dims) { + deep_gemm::gemm::m_grouped_bf16_gemm_nt_masked( + a, b, d, masked_m, static_cast(expected_m), compiled_dims); +} + +// GEMM ops - K-grouped BF16 + +void deep_gemm_k_grouped_bf16_gemm_tn_contiguous( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& ks_tensor, const std::optional& c, + const std::string& compiled_dims) { + auto ks = tensor_to_vec_int(ks_tensor); + deep_gemm::gemm::k_grouped_bf16_gemm_tn_contiguous( + a, b, d, ks, ks_tensor, c, compiled_dims); +} + +// GEMM ops - cuBLASLt + +void deep_gemm_cublaslt_gemm_nt( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c) { + deep_gemm::gemm::cublaslt_gemm_nt(a, b, d, c); +} + +void deep_gemm_cublaslt_gemm_nn( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c) { + deep_gemm::gemm::cublaslt_gemm_nn(a, b, d, c); +} + +void deep_gemm_cublaslt_gemm_tn( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c) { + deep_gemm::gemm::cublaslt_gemm_tn(a, b, d, c); +} + +void deep_gemm_cublaslt_gemm_tt( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c) { + deep_gemm::gemm::cublaslt_gemm_tt(a, b, d, c); +} + +// Attention ops + +void deep_gemm_fp8_gemm_nt_skip_head_mid( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, + int64_t head_split_left, int64_t head_split_mid, int64_t head_split_right, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + const std::string& compiled_dims, bool disable_ue8m0_cast) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto head_splits = std::make_tuple(static_cast(head_split_left), + static_cast(head_split_mid), + static_cast(head_split_right)); + auto recipe = make_recipe3(recipe_0, recipe_1, recipe_2, has_recipe); + deep_gemm::attention::fp8_gemm_nt_skip_head_mid( + a, b, d, head_splits, recipe, compiled_dims, disable_ue8m0_cast); +} + +Tensor deep_gemm_fp8_mqa_logits( + const Tensor& q, + const Tensor& kv_data, const Tensor& kv_sf, + const Tensor& weights, + const Tensor& cu_seq_len_k_start, const Tensor& cu_seq_len_k_end, + bool clean_logits, int64_t max_seqlen_k) { + auto kv = std::make_pair(kv_data, kv_sf); + return deep_gemm::attention::fp8_mqa_logits( + q, kv, weights, cu_seq_len_k_start, cu_seq_len_k_end, + clean_logits, static_cast(max_seqlen_k)); +} + +Tensor deep_gemm_get_paged_mqa_logits_metadata( + const Tensor& context_lens, int64_t block_kv, int64_t num_sms) { + return deep_gemm::attention::get_paged_mqa_logits_metadata( + context_lens, static_cast(block_kv), static_cast(num_sms)); +} + +Tensor deep_gemm_fp8_paged_mqa_logits( + const Tensor& q, const Tensor& fused_kv_cache, + const Tensor& weights, const Tensor& context_lens, + const Tensor& block_table, const Tensor& schedule_meta, + int64_t max_context_len, bool clean_logits) { + return deep_gemm::attention::fp8_paged_mqa_logits( + q, fused_kv_cache, weights, context_lens, block_table, schedule_meta, + static_cast(max_context_len), clean_logits); +} + +// Einsum ops + +void deep_gemm_einsum( + const std::string& expr, + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, bool use_cublaslt) { + deep_gemm::einsum::einsum(expr, a, b, d, c, use_cublaslt); +} + +void deep_gemm_fp8_einsum( + const std::string& expr, + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2) { + auto a = std::make_pair(a_data, a_sf); + auto b = std::make_pair(b_data, b_sf); + auto recipe = std::make_tuple(static_cast(recipe_0), + static_cast(recipe_1), + static_cast(recipe_2)); + deep_gemm::einsum::fp8_einsum(expr, a, b, d, c, recipe); +} + +// Hyperconnection ops + +void deep_gemm_tf32_hc_prenorm_gemm( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& sqr_sum, int64_t num_splits, bool has_num_splits) { + auto ns = has_num_splits ? std::make_optional(static_cast(num_splits)) : std::nullopt; + deep_gemm::hyperconnection::tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns); +} diff --git a/deep-gemm/csrc/indexing/main.cu b/deep-gemm/csrc/indexing/main.cu new file mode 100644 index 00000000..1b96da2f --- /dev/null +++ b/deep-gemm/csrc/indexing/main.cu @@ -0,0 +1,30 @@ +// GEMM kernels +#include +#include +#include +#include +#include + +// Attention kernels +#include +#include +#include +#include + +// Einsum kernels +#include +#include + +// Hyperconnection kernels +#include +#include + +// Layout kernels +#include +#include + +using namespace deep_gemm; + +int main() { + return 0; +} diff --git a/deep-gemm/csrc/jit/cache.hpp b/deep-gemm/csrc/jit/cache.hpp new file mode 100644 index 00000000..1e8659fd --- /dev/null +++ b/deep-gemm/csrc/jit/cache.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +#include "kernel_runtime.hpp" + +namespace deep_gemm { + +class KernelRuntimeCache { + std::unordered_map> cache; + +public: + // TODO: consider cache capacity + KernelRuntimeCache() = default; + + std::shared_ptr get(const std::filesystem::path& dir_path) { + // Hit the runtime cache + if (const auto& iterator = cache.find(dir_path); iterator != cache.end()) + return iterator->second; + + if (KernelRuntime::check_validity(dir_path)) + return cache[dir_path] = std::make_shared(dir_path); + return nullptr; + } +}; + +static auto kernel_runtime_cache = std::make_shared(); + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit/compiler.hpp b/deep-gemm/csrc/jit/compiler.hpp new file mode 100644 index 00000000..38d090e7 --- /dev/null +++ b/deep-gemm/csrc/jit/compiler.hpp @@ -0,0 +1,423 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/hash.hpp" +#include "../utils/lazy_init.hpp" +#include "../utils/system.hpp" +#include "cache.hpp" +#include "device_runtime.hpp" + +namespace deep_gemm { + +// Lazy-load NVRTC to avoid link-time dependency on libnvrtc.so. +// kernel-builder doesn't support linking extra CUDA libs yet, so we dlopen +// at runtime — same pattern as the CUDA driver API in jit/handle.hpp. +static void* get_nvrtc_handle() { + static void* handle = nullptr; + if (handle == nullptr) { + handle = dlopen("libnvrtc.so", RTLD_LAZY | RTLD_LOCAL); + if (handle == nullptr) + handle = dlopen("libnvrtc.so.12", RTLD_LAZY | RTLD_LOCAL); + DG_HOST_ASSERT(handle != nullptr and "Failed to load NVRTC library"); + } + return handle; +} + +#define DECL_LAZY_NVRTC_FUNCTION(name) \ +template \ +static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ + using FuncType = decltype(&name); \ + static FuncType func = nullptr; \ + if (func == nullptr) { \ + func = reinterpret_cast(dlsym(get_nvrtc_handle(), #name)); \ + DG_HOST_ASSERT(func != nullptr and "Failed to load NVRTC function"); \ + } \ + return func(std::forward(args)...); \ +} + +DECL_LAZY_NVRTC_FUNCTION(nvrtcVersion); +DECL_LAZY_NVRTC_FUNCTION(nvrtcCreateProgram); +DECL_LAZY_NVRTC_FUNCTION(nvrtcCompileProgram); +DECL_LAZY_NVRTC_FUNCTION(nvrtcGetProgramLogSize); +DECL_LAZY_NVRTC_FUNCTION(nvrtcGetProgramLog); +DECL_LAZY_NVRTC_FUNCTION(nvrtcGetPTXSize); +DECL_LAZY_NVRTC_FUNCTION(nvrtcGetPTX); +DECL_LAZY_NVRTC_FUNCTION(nvrtcGetCUBINSize); +DECL_LAZY_NVRTC_FUNCTION(nvrtcGetCUBIN); +DECL_LAZY_NVRTC_FUNCTION(nvrtcDestroyProgram); +DECL_LAZY_NVRTC_FUNCTION(nvrtcGetErrorString); + +// Redirect nvrtc calls to lazy-loaded versions so NVRTCCompiler is unchanged +#define nvrtcVersion lazy_nvrtcVersion +#define nvrtcCreateProgram lazy_nvrtcCreateProgram +#define nvrtcCompileProgram lazy_nvrtcCompileProgram +#define nvrtcGetProgramLogSize lazy_nvrtcGetProgramLogSize +#define nvrtcGetProgramLog lazy_nvrtcGetProgramLog +#define nvrtcGetPTXSize lazy_nvrtcGetPTXSize +#define nvrtcGetPTX lazy_nvrtcGetPTX +#define nvrtcGetCUBINSize lazy_nvrtcGetCUBINSize +#define nvrtcGetCUBIN lazy_nvrtcGetCUBIN +#define nvrtcDestroyProgram lazy_nvrtcDestroyProgram +#define nvrtcGetErrorString lazy_nvrtcGetErrorString + +class Compiler { +public: + static std::filesystem::path library_root_path; + static std::filesystem::path library_include_path; + static std::filesystem::path cuda_home; + static std::string library_version; + static std::filesystem::path cuobjdump_path; + + static std::string get_library_version() { + const auto dg_include = library_include_path / "deep_gemm"; + if (not std::filesystem::exists(dg_include)) { + // Fallback: hash the root path itself + std::string fallback(library_root_path.string()); + return get_hex_digest(std::vector(fallback.begin(), fallback.end())); + } + std::vector buffer; + for (const auto& f: collect_files(dg_include)) { + std::ifstream in(f, std::ios::binary); + DG_HOST_ASSERT(in.is_open()); + buffer.insert(buffer.end(), + std::istreambuf_iterator(in), + std::istreambuf_iterator()); + } + return get_hex_digest(buffer); + } + + static void prepare_init(const std::string& library_root_path, + const std::string& cuda_home_path_by_python) { + Compiler::library_root_path = library_root_path; + Compiler::library_include_path = Compiler::library_root_path / "include"; + Compiler::cuda_home = cuda_home_path_by_python; + Compiler::library_version = get_library_version(); + Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump"; + } + + std::string signature, flags; + std::filesystem::path cache_dir_path; + + Compiler() { + // Check `prepare_init` + DG_HOST_ASSERT(not library_root_path.empty()); + DG_HOST_ASSERT(not library_include_path.empty()); + DG_HOST_ASSERT(not cuda_home.empty()); + DG_HOST_ASSERT(not library_version.empty()); + DG_HOST_ASSERT(not cuobjdump_path.empty()); + + // Cache settings + cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; + if (const auto& env_cache_dir_path = get_env("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty()) + cache_dir_path = env_cache_dir_path; + + // The compiler flags applied to all derived compilers + signature = "unknown-compiler"; + flags = fmt::format("-std=c++{} --diag-suppress=39,161,174,177,186,940 " + "--ptxas-options=--register-usage-level=10", + get_env("DG_JIT_CPP_STANDARD", 20)); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0) or get_env("DG_JIT_PTXAS_CHECK", 0)) + flags += " --ptxas-options=--verbose,--warn-on-local-memory-usage"; + if (get_env("DG_JIT_WITH_LINEINFO", 0)) + flags += " -Xcompiler -rdynamic -lineinfo"; + } + + virtual ~Compiler() = default; + + std::filesystem::path make_tmp_dir() const { + return make_dirs(cache_dir_path / "tmp"); + } + + std::filesystem::path get_tmp_file_path() const { + return make_tmp_dir() / get_uuid(); + } + + void put(const std::filesystem::path& path, const std::string& data) const { + const auto tmp_file_path = get_tmp_file_path(); + + // Write into the temporary file + std::ofstream out(tmp_file_path, std::ios::binary); + DG_HOST_ASSERT(out.write(data.data(), data.size())); + out.close(); + + // Atomically replace + std::filesystem::rename(tmp_file_path, path); + } + + std::shared_ptr build(const std::string& name, const std::string& code) const { + const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code); + const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature)); + + // Hit the runtime cache + if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr) + return runtime; + + // Create the kernel directory + make_dirs(dir_path); + + // Compile into a temporary CUBIN + const auto tmp_cubin_path = get_tmp_file_path(); + if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_PTX")) { + // Dump PTX if needed + const auto tmp_ptx_path = get_tmp_file_path(); + compile(code, dir_path, tmp_cubin_path, tmp_ptx_path); + + // Replace into the cache directory + std::filesystem::rename(tmp_ptx_path, dir_path / "kernel.ptx"); + } else { + compile(code, dir_path, tmp_cubin_path); + } + + // Replace into the cache directory + const auto cubin_path = dir_path / "kernel.cubin"; + std::filesystem::rename(tmp_cubin_path, cubin_path); + + // Disassemble if needed + if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_SASS")) { + // Dump into a temporary SASS + const auto tmp_sass_path = get_tmp_file_path(); + disassemble(cubin_path, tmp_sass_path); + + // Replace into the current directory + std::filesystem::rename(tmp_sass_path, dir_path / "kernel.sass"); + } + + // Put into the runtime cache + const auto runtime = kernel_runtime_cache->get(dir_path); + DG_HOST_ASSERT(runtime != nullptr); + return runtime; + } + + static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) { + // Disassemble the CUBIN file to SASS + const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str()); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + fprintf(stderr, "Running cuobjdump command: %s\n", command.c_str()); + const auto [return_code, output] = call_external_command(command); + if (return_code != 0) { + fprintf(stderr, "cuobjdump failed: %s\n", output.c_str()); + DG_HOST_ASSERT(false and "cuobjdump failed"); + } + } + + virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional &ptx_path = std::nullopt) const = 0; +}; + +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path); + +class NVCCCompiler final: public Compiler { + std::filesystem::path nvcc_path; + + std::pair get_nvcc_version() const { + DG_HOST_ASSERT(std::filesystem::exists(nvcc_path)); + + // Call the version command + const auto& command = std::string(nvcc_path) + " --version"; + const auto& [return_code, output] = call_external_command(command); + DG_HOST_ASSERT(return_code == 0); + + // Parse "release X.Y" without std::regex + int major = 0, minor = 0; + const char* release_pos = std::strstr(output.c_str(), "release "); + DG_HOST_ASSERT(release_pos != nullptr and "Could not find 'release' in nvcc --version output"); + std::sscanf(release_pos + 8, "%d.%d", &major, &minor); + DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3"); + if (major == 12 and minor < 9) + fprintf(stderr, "Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n"); + return {major, minor}; + } + +public: + NVCCCompiler() { + // Override the compiler signature + nvcc_path = cuda_home / "bin" / "nvcc"; + if (const auto& env_nvcc_path = get_env("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) + nvcc_path = env_nvcc_path; + const auto& [nvcc_major, nvcc_minor] = get_nvcc_version(); + signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor); + + // The override the compiler flags + // Only NVCC >= 12.9 supports arch-specific family suffix + const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); + // DG_CUTLASS_INCLUDE is set by Python _find_cutlass_include() before ops.init() + const auto& cutlass_include = get_env("DG_CUTLASS_INCLUDE"); + std::string cutlass_flag = cutlass_include.empty() ? "" : fmt::format(" -I{}", cutlass_include); + flags = fmt::format("{} -I{}{} --gpu-architecture=sm_{} " + "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " + "-O3 --expt-relaxed-constexpr --expt-extended-lambda", + flags, library_include_path.c_str(), cutlass_flag, arch); + + // print flags if ENV is set + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_FLAGS", 0)) + fprintf(stderr, "NVCC compiler flags: %s\n", flags.c_str()); + } + + void compile(const std::string &code, const std::filesystem::path& dir_path, + const std::filesystem::path &cubin_path, + const std::optional &ptx_path) const override { + // Write the code into the cache directory + const auto& code_path = dir_path / "kernel.cu"; + put(code_path, code); + + // Compile + const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + fprintf(stderr, "Running NVCC command: %s\n", command.c_str()); + const auto& [return_code, output] = call_external_command(command); + if (return_code != 0) { + fprintf(stderr, "NVCC compilation failed: %s\n", output.c_str()); + DG_HOST_ASSERT(false and "NVCC compilation failed"); + } + + // Compile to PTX if needed + if (ptx_path.has_value()) { + const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + fprintf(stderr, "Running NVCC PTX command: %s\n", ptx_command.c_str()); + const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command); + if (ptx_return_code != 0) { + fprintf(stderr, "NVCC PTX compilation failed: %s\n", ptx_output.c_str()); + DG_HOST_ASSERT(false and "NVCC PTX compilation failed"); + } + } + + // Check local memory usage (without std::regex — avoids ABI issues) + if (get_env("DG_JIT_PTXAS_CHECK", 0)) + DG_HOST_ASSERT(output.find("Local memory used") == std::string::npos); + + // Print PTXAS log + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) + fprintf(stderr, "%s", output.c_str()); + } +}; + +class NVRTCCompiler final: public Compiler { +public: + NVRTCCompiler() { + // Override the compiler signature + int major, minor; + DG_NVRTC_CHECK(nvrtcVersion(&major, &minor)); + signature = fmt::format("NVRTC{}.{}", major, minor); + DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVRTC version should be >= 12.3"); + + // Build include directories list + std::string include_dirs; + include_dirs += fmt::format("-I{} ", library_include_path.string()); + include_dirs += fmt::format("-I{} ", (cuda_home / "include").string()); + // DG_CUTLASS_INCLUDE is set by Python _find_cutlass_include() before ops.init() + if (const auto& cutlass_include = get_env("DG_CUTLASS_INCLUDE"); not cutlass_include.empty()) + include_dirs += fmt::format("-I{} ", cutlass_include); + + // Add PCH support for version 12.8 and above + // NOTES: PCH is vital for compilation speed + std::string pch_flags; + if (major > 12 or minor >= 8) { + pch_flags = "--pch "; + if (get_env("DG_JIT_DEBUG", 0)) + pch_flags += "--pch-verbose=true "; + } + + // Override the compiler flags + // Only NVRTC >= 12.9 supports arch-specific family suffix + const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9); + flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128", + flags, include_dirs, arch, pch_flags); + } + + void compile(const std::string &code, const std::filesystem::path& dir_path, + const std::filesystem::path &cubin_path, + const std::optional &ptx_path) const override { + // Write the code into the cache directory + const auto& code_path = dir_path / "kernel.cu"; + put(code_path, code); + + // Split flags by whitespace (without std::istringstream — avoids ABI issues) + std::vector options; + { + size_t i = 0; + while (i < flags.size()) { + while (i < flags.size() && (flags[i] == ' ' || flags[i] == '\t')) ++i; + if (i >= flags.size()) break; + size_t start = i; + while (i < flags.size() && flags[i] != ' ' && flags[i] != '\t') ++i; + options.push_back(flags.substr(start, i - start)); + } + } + + // Convert to C-style string array for NVRTC + std::vector option_cstrs; + for (const auto& opt: options) + option_cstrs.push_back(opt.c_str()); + + // Print compiler command if requested + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) { + fprintf(stderr, "Compiling JIT runtime with NVRTC options: "); + for (const auto& opt: options) + fprintf(stderr, "%s ", opt.c_str()); + fprintf(stderr, "\n"); + } + + // Create NVRTC program and compile + nvrtcProgram program; + DG_NVRTC_CHECK(nvrtcCreateProgram(&program, code.c_str(), "kernel.cu", 0, nullptr, nullptr)); + const auto& compile_result = nvrtcCompileProgram(program, static_cast(option_cstrs.size()), option_cstrs.data()); + + // Get and print compiler log + size_t log_size; + DG_NVRTC_CHECK(nvrtcGetProgramLogSize(program, &log_size)); + if (get_env("DG_JIT_DEBUG", 0) or compile_result != NVRTC_SUCCESS) { + if (compile_result != NVRTC_SUCCESS) + DG_HOST_ASSERT(log_size > 1); + if (log_size > 1) { + std::string compilation_log(log_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetProgramLog(program, compilation_log.data())); + fprintf(stderr, "NVRTC log: %s\n", compilation_log.c_str()); + } + } + + if (ptx_path.has_value()) { + // Get PTX size and data if needed + size_t ptx_size; + DG_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size)); + std::string ptx_data(ptx_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetPTX(program, ptx_data.data())); + + // Write into the file system + put(ptx_path.value(), ptx_data); + } + + // Get CUBIN size and data + size_t cubin_size; + DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size)); + std::string cubin_data(cubin_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetCUBIN(program, cubin_data.data())); + + // Write into the file system + put(cubin_path, cubin_data); + + // Cleanup + DG_NVRTC_CHECK(nvrtcDestroyProgram(&program)); + } +}; + +static auto compiler = LazyInit([]() -> std::shared_ptr { + if (get_env("DG_JIT_USE_NVRTC", 0)) { + return std::make_shared(); + } + return std::make_shared(); +}); + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit/device_runtime.hpp b/deep-gemm/csrc/jit/device_runtime.hpp new file mode 100644 index 00000000..d33743ef --- /dev/null +++ b/deep-gemm/csrc/jit/device_runtime.hpp @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/lazy_init.hpp" + +namespace deep_gemm { + +class DeviceRuntime { + int num_sms = 0, tc_util = 0; + std::shared_ptr cached_prop; + + // cuBLASLt utils + static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024; + +public: + // Create the cuBLASLt handle ourselves + cublasLtHandle_t cublaslt_handle{}; + std::shared_ptr cublaslt_workspace; + + explicit DeviceRuntime() { + cublaslt_workspace = std::make_shared(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA))); + DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle)); + } + + ~DeviceRuntime() noexcept(false) { + DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle)); + } + + cublasLtHandle_t get_cublaslt_handle() const { + return cublaslt_handle; + } + + torch::Tensor get_cublaslt_workspace() const { + return *cublaslt_workspace; + } + + std::shared_ptr get_prop() { + if (cached_prop == nullptr) { + int device_idx; + cudaDeviceProp prop; + DG_CUDA_RUNTIME_CHECK(cudaGetDevice(&device_idx)); + DG_CUDA_RUNTIME_CHECK(cudaGetDeviceProperties(&prop, device_idx)); + cached_prop = std::make_shared(prop); + } + return cached_prop; + } + + std::pair get_arch_pair() { + const auto prop = get_prop(); + return {prop->major, prop->minor}; + } + + std::string get_arch(const bool& number_only = false, + const bool& support_arch_family = false) { + const auto& [major, minor] = get_arch_pair(); + if (major == 10 and minor != 1) { + if (number_only) + return "100"; + return support_arch_family ? "100f" : "100a"; + } + return std::to_string(major * 10 + minor) + (number_only ? "" : "a"); + } + + int get_arch_major() { + return get_arch_pair().first; + } + + void set_num_sms(const int& new_num_sms) { + DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount); + num_sms = new_num_sms; + } + + int get_num_sms() { + if (num_sms == 0) + num_sms = get_prop()->multiProcessorCount; + return num_sms; + } + + int get_l2_cache_size() { + return get_prop()->l2CacheSize; + } + + void set_tc_util(const int& new_tc_util) { + DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100); + tc_util = new_tc_util; + } + + int get_tc_util() const { + return tc_util == 0 ? 100 : tc_util; + } +}; + +static auto device_runtime = LazyInit([](){ return std::make_shared(); }); + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit/handle.hpp b/deep-gemm/csrc/jit/handle.hpp new file mode 100644 index 00000000..34447f91 --- /dev/null +++ b/deep-gemm/csrc/jit/handle.hpp @@ -0,0 +1,168 @@ +#pragma once + +#include +#include +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/compatibility.hpp" + +namespace deep_gemm { + +// Lazy loading all driver symbols +static void* get_driver_handle() { + static void* handle = nullptr; + if (handle == nullptr) { + handle = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_LOCAL); + DG_HOST_ASSERT(handle != nullptr and "Failed to load CUDA driver `libcuda.so.1`"); + } + return handle; +} + +// Macro to define wrapper functions named `lazy_cu{API name}` +#define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \ +template \ +static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ + using FuncType = decltype(&name); \ + static FuncType func = nullptr; \ + if (func == nullptr) { \ + func = reinterpret_cast(dlsym(get_driver_handle(), #name)); \ + DG_HOST_ASSERT(func != nullptr and "Failed to load CUDA driver API"); \ + } \ + return func(std::forward(args)...); \ +} + +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorName); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorString); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled); + +#if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API) + +// Use CUDA runtime API +using LibraryHandle = cudaLibrary_t; +using KernelHandle = cudaKernel_t; +using LaunchConfigHandle = cudaLaunchConfig_t; +using LaunchAttrHandle = cudaLaunchAttribute; + +#define DG_CUDA_UNIFIED_CHECK DG_CUDA_RUNTIME_CHECK + +static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, + LibraryHandle *library_opt = nullptr) { + LibraryHandle library; + KernelHandle kernel{}; + DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, func_name.c_str())); + + if (library_opt != nullptr) + *library_opt = library; + return kernel; +} + +static void unload_library(const LibraryHandle& library) { + const auto& error = cudaLibraryUnload(library); + DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading); +} + +static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, + const cudaStream_t& stream, const int& smem_size, + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + if (smem_size > 0) + DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + LaunchConfigHandle config; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + config.numAttrs = 0; + config.attrs = nullptr; + + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attr; + if (cluster_dim > 1) { + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {static_cast(cluster_dim), 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + } + return config; +} + +template +static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { + void *ptr_args[] = { &args... }; + return cudaLaunchKernelExC(&config, kernel, ptr_args); +} + +#else + +// Use CUDA driver API +using LibraryHandle = CUmodule; +using KernelHandle = CUfunction; +using LaunchConfigHandle = CUlaunchConfig; +using LaunchAttrHandle = CUlaunchAttribute; + +#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK + +static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, + LibraryHandle *library_opt = nullptr) { + LibraryHandle library; + KernelHandle kernel; + DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.c_str())); + DG_CUDA_DRIVER_CHECK(lazy_cuModuleGetFunction(&kernel, library, func_name.c_str())); + + if (library_opt != nullptr) + *library_opt = library; + return kernel; +} + +static void unload_library(const LibraryHandle& library) { + const auto& error = lazy_cuModuleUnload(library); + DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED); +} + +static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, + const cudaStream_t& stream, const int& smem_size, + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + if (smem_size > 0) + DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); + + LaunchConfigHandle config; + config.gridDimX = grid_dim.x; + config.gridDimY = grid_dim.y; + config.gridDimZ = grid_dim.z; + config.blockDimX = block_dim.x; + config.blockDimY = block_dim.y; + config.blockDimZ = block_dim.z; + config.sharedMemBytes = smem_size; + config.hStream = stream; + config.numAttrs = 0; + config.attrs = nullptr; + + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attr; + if (cluster_dim > 1) { + attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + attr.value.clusterDim.x = cluster_dim; + attr.value.clusterDim.y = 1; + attr.value.clusterDim.z = 1; + config.attrs = &attr; + config.numAttrs = 1; + } + return config; +} + +template +static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { + void *ptr_args[] = { &args... }; + return lazy_cuLaunchKernelEx(&config, kernel, ptr_args, nullptr); +} +#endif + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit/kernel_runtime.hpp b/deep-gemm/csrc/jit/kernel_runtime.hpp new file mode 100644 index 00000000..60563a1d --- /dev/null +++ b/deep-gemm/csrc/jit/kernel_runtime.hpp @@ -0,0 +1,123 @@ +#pragma once + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/system.hpp" +#include "device_runtime.hpp" +#include "handle.hpp" + +namespace deep_gemm { + +struct LaunchArgs { + std::pair grid_dim; + int num_threads; + int smem_size; + int cluster_dim; + + LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): + grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} + + LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1): + grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} +}; + +class KernelRuntime final { +public: + static std::filesystem::path cuda_home; + + LibraryHandle library; + KernelHandle kernel; + + explicit KernelRuntime(const std::filesystem::path& dir_path) { + // Check `prepare_init` + DG_HOST_ASSERT(not cuda_home.empty()); + + // NOLINT(*-pro-type-member-init) + const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump"; + const auto& cubin_path = dir_path / "kernel.cubin"; + if (get_env("DG_JIT_DEBUG")) + fprintf(stderr, "Loading CUBIN: %s\n", cubin_path.c_str()); + + // Find the only symbol + // TODO: use kernel enumeration for newer drivers + const std::vector illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"}; + const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); + DG_HOST_ASSERT(exit_code == 0); + // Parse line-by-line without std::istringstream + std::vector symbol_names; + size_t pos = 0; + while (pos < symbols.size()) { + size_t eol = symbols.find('\n', pos); + if (eol == std::string::npos) eol = symbols.size(); + std::string line = symbols.substr(pos, eol - pos); + pos = eol + 1; + if (line.find("STT_FUNC") == 0 and line.find("STO_ENTRY") != std::string::npos and + std::none_of(illegal_names.begin(), illegal_names.end(), + [&](const auto& name) { return line.find(name) != std::string::npos; })) { + const auto& last_space = line.rfind(' '); + symbol_names.push_back(line.substr(last_space + 1)); + } + } + if (get_env("DG_JIT_DEBUG")) { + fprintf(stderr, "Symbol names: "); + for (const auto& symbol: symbol_names) + fprintf(stderr, "%s, ", symbol.c_str()); + fprintf(stderr, "\n"); + } + + // Load from the library + DG_HOST_ASSERT(symbol_names.size() == 1); + kernel = load_kernel(cubin_path, symbol_names[0], &library); + } + + static void prepare_init(const std::string& cuda_home_path_by_python) { + cuda_home = cuda_home_path_by_python; + } + + static bool check_validity(const std::filesystem::path& dir_path) { + return std::filesystem::exists(dir_path / "kernel.cu") and + std::filesystem::exists(dir_path / "kernel.cubin"); + } + + ~KernelRuntime() noexcept(false) { + unload_library(library); + } +}; + +DG_DECLARE_STATIC_VAR_IN_CLASS(KernelRuntime, cuda_home); + +template +class LaunchRuntime { +public: + template + static std::string generate(const Args& args) { + const auto& code = Derived::generate_impl(args); + if (get_env("DG_JIT_DEBUG", 0)) + fprintf(stderr, "Generated kernel code: %s\n", code.c_str()); + return code; + } + + template + static void launch(const std::shared_ptr& kernel_runtime, const Args& args) { + const auto& kernel = kernel_runtime->kernel; + const auto& stream = at::cuda::getCurrentCUDAStream(); + const LaunchArgs& launch_args = args.launch_args; + + const dim3& grid_dim = {static_cast(launch_args.grid_dim.first), + static_cast(launch_args.grid_dim.second), + 1}; + const dim3& block_dim = {static_cast(launch_args.num_threads), 1, 1}; + auto config = construct_launch_config(kernel, stream, launch_args.smem_size, + grid_dim, block_dim, launch_args.cluster_dim); + + // Launch in the derived class + if (get_env("DG_JIT_DEBUG")) { + fprintf(stderr, "Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n", + launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads, + launch_args.smem_size, launch_args.cluster_dim, stream.id()); + } + Derived::launch_impl(kernel, config, args); + } +}; + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/heuristics/common.hpp b/deep-gemm/csrc/jit_kernels/heuristics/common.hpp new file mode 100644 index 00000000..a49584f4 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/heuristics/common.hpp @@ -0,0 +1,339 @@ +#pragma once + +#include + +#include "../../utils/math.hpp" +#include "../../utils/layout.hpp" +#include "../../utils/system.hpp" + +namespace deep_gemm { + +struct MulticastConfig { + int num_multicast; + bool is_multicast_on_a; + + MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a): + num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) { + DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2); + } +}; + +struct SharedMemoryConfig { + int smem_size; + int swizzle_a_mode; + int swizzle_b_mode; + int swizzle_cd_mode; +}; + +struct ThreadConfig { + int num_threads; + + // SM90 + int num_tma_threads; + int num_math_threads; + + // SM100 + int num_non_epilogue_threads; + int num_epilogue_threads; + + static ThreadConfig sm90(const int& num_tma_threads, + const int& num_math_threads) { + auto config = ThreadConfig(); + config.num_threads = num_tma_threads + num_math_threads; + config.num_tma_threads = num_tma_threads; + config.num_math_threads = num_math_threads; + return config; + } + + static ThreadConfig sm100(const int& num_non_epilogue_threads, + const int& num_epilogue_threads) { + auto config = ThreadConfig(); + config.num_threads = num_non_epilogue_threads + num_epilogue_threads; + config.num_non_epilogue_threads = num_non_epilogue_threads; + config.num_epilogue_threads = num_epilogue_threads; + return config; + } +}; + +struct GemmConfig { + // Templated configs + GemmType gemm_type; + KernelType kernel_type; + MmaKind mma_kind; + at::ScalarType a_dtype, b_dtype, cd_dtype; + cute::UMMA::Major major_a; + cute::UMMA::Major major_b; + bool with_accumulation; + int block_m, block_n, block_k; + int num_stages, num_last_stages; + + // Templated device configs + int num_sms; + int tc_util; + + // Structured configs + MulticastConfig multicast_config; + SharedMemoryConfig smem_config; + ThreadConfig thread_config; +}; + +static bool is_multicast_legal(const int& shape_dim, const int& block_dim, + const int& num_multicast, const int& num_sms, + const bool& require_divisible) { + const bool& divisible = ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible; + return divisible and num_sms % num_multicast == 0; +} + +template +static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) { + // `> 0` means interleaving + // 16B actually means non-swizzling (but interleaving) + for (const int& mode: {128, 64, 32, 16}) { + if ((block_size * static_cast(elem_size)) % mode == 0) + return mode; + } + DG_HOST_UNREACHABLE("Unreachable"); +} + +template +static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type, + const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype, + const int& num_stages, const MulticastConfig& multicast_config) { + const int& ab_elem_size = static_cast(get_element_size(mma_kind)); + const int& cd_elem_size = static_cast(c10::elementSize(cd_dtype)); + + const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m); + const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n); + const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size); + const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size); + const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0; + + // Different archs have different epilogue pipelines + const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype); + + // A/B shared memory + const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size; + const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size; + + // SF shared memory + const auto& [smem_sfa_per_stage, smem_sfb_per_stage] = + ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, mma_kind, cd_dtype); + const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k); + + // M-barriers and tensor memory pointers + const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages); + const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size(); + const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type); + + // Sum them up + int smem_size = 0; + smem_size += smem_tensor_map; + smem_size += smem_cd; + smem_size += num_stages * smem_a_per_stage; + smem_size += num_stages * smem_b_per_stage; + smem_size += num_stages * smem_sfa_per_stage; + smem_size += num_stages * smem_sfb_per_stage; + smem_size += smem_extra_sfb; + smem_size += smem_barrier; + smem_size += smem_tmem_ptr; + + return SharedMemoryConfig { + .smem_size = smem_size, + .swizzle_a_mode = swizzle_a_mode, + .swizzle_b_mode = swizzle_b_mode, + .swizzle_cd_mode = swizzle_cd_mode, + }; +} + +template +static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, + const int& m, const int& n, const int& k, const int& num_groups, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const at::ScalarType& a_dtype, const at::ScalarType& b_dtype, + const at::ScalarType& cd_dtype, + const bool& with_accumulation, const int& num_sms) { + const auto mma_kind = (a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4); + if (mma_kind == MmaKind::BF16) { + DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16); + } else { + DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4); + DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4); + } + DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); + + // Select M/N block sizes + auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m); + if (gemm_type == GemmType::MGroupedContiguous) + block_ms = std::vector{get_mk_alignment_for_contiguous_layout()}; + if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout) + block_ms = std::vector{64, 128}; // Exclude 256 for performance + auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype); + + // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B + // TODO: Optimize it + if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN) + block_ms = std::vector{128}; + if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN) + block_ns = std::vector{128}; + + // K block size is selected in a fixed manner + const auto& block_k = (mma_kind == MmaKind::BF16 ? 64 : 128); + + // Some util functions + const auto& get_num_blocks = [=](const int& block_m, const int& block_n) { + return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups; + }; + const auto& get_num_waves = [=](const int& block_m, const int& block_n) { + return ceil_div(get_num_blocks(block_m, block_n), num_sms); + }; + const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) { + const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms; + return num_last_blocks == 0 ? num_sms : num_last_blocks; + }; + + // Decide block sizes by waves + int best_block_m = 0, best_block_n = 0; + int best_num_waves = 0, best_last_util = 0; + for (const auto& block_m: block_ms) { + for (const auto& block_n: block_ns) { + const int& num_waves = get_num_waves(block_m, block_n); + const auto& last_util = get_last_wave_util(block_m, block_n); + if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, mma_kind, cd_dtype, m, n, k, block_m, block_n, block_k)) + continue; + + bool success = false; + if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) { + success = true; + } else if (num_waves == best_num_waves) { + // Check last wave utilization + success = last_util > best_last_util; + if (last_util == best_last_util) { + // Case 1: same `block_m`, smaller `block_n` (wasted) + success |= block_m == best_block_m and block_n < best_block_n; + // Case 2: same `block_n`, smaller `block_m` (wasted) + success |= block_n == best_block_n and block_m < best_block_m; + // Case 3: different for both `block_m` and `block_n`, larger `block_n` is better + // NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case + success |= block_m != best_block_m and block_n > best_block_n + and block_n <= n and block_m <= m; + } + } + + // Replace with the new config if successful + if (success) { + best_block_m = block_m, best_block_n = block_n; + best_num_waves = num_waves, best_last_util = last_util; + } + } + } + DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); + + // Decide the number of TMA multicasts and whether broadcast on A + MulticastConfig best_multicast_config = {1, false}; + auto [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( + gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms); + + // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B + // TODO: Optimize it + if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN) + is_legal_on_a = false; + if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN) + is_legal_on_b = false; + + const bool is_legal[2] = {is_legal_on_b, is_legal_on_a}; + bool order[2] = {false, true}; + if (best_block_m > best_block_n) + std::swap(order[0], order[1]); + for (const bool& is_multicast_on_a: order) { + if (m >= 512 and is_legal[static_cast(is_multicast_on_a)]) { + best_multicast_config = {2, is_multicast_on_a}; + break; + } + } + + // Always pick the largest number of stage + constexpr int smem_capacity = ArchSpec::smem_capacity; + int best_num_stages = 0; + SharedMemoryConfig best_smem_config; + for (int num_stages = 32; num_stages > 0; -- num_stages) { + if (not ArchSpec::is_num_stages_legal(mma_kind, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) + continue; + + best_smem_config = get_smem_config(gemm_type, kernel_type, + m, n, k, + best_block_m, best_block_n, block_k, + major_a, major_b, + mma_kind, cd_dtype, + num_stages, best_multicast_config); + if (best_smem_config.smem_size <= smem_capacity) { + best_num_stages = num_stages; + break; + } + } + DG_HOST_ASSERT(best_num_stages != 0); + + // Recompute the minimal number of SMs required + // NOTES: less L2 cache usage and less GPU frequency drop + int num_min_sms = num_sms; + if (get_env("DG_JIT_MINIMIZE_NUM_SMS", 0)) { + num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves); + num_min_sms = align(num_min_sms, best_multicast_config.num_multicast); + DG_HOST_ASSERT(num_min_sms <= num_sms); + } + + const auto& config = GemmConfig { + .gemm_type = gemm_type, + .kernel_type = kernel_type, + .mma_kind = mma_kind, + .a_dtype = a_dtype, + .b_dtype = b_dtype, + .cd_dtype = cd_dtype, + .major_a = major_a, + .major_b = major_b, + .with_accumulation = with_accumulation, + .block_m = best_block_m, + .block_n = best_block_n, + .block_k = block_k, + .num_stages = best_num_stages, + .num_last_stages = ceil_div(k, block_k) % best_num_stages, + .num_sms = num_min_sms, + .tc_util = device_runtime->get_tc_util(), + .multicast_config = best_multicast_config, + // ReSharper disable once CppLocalVariableMightNotBeInitialized + .smem_config = best_smem_config, + .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n) + }; + + // Only SM100 BF16 kernels support tensor core control + if (config.tc_util < 100) + DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and mma_kind == MmaKind::BF16); + + // Print configs for the first time + if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { + auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b, + mma_kind, a_dtype, b_dtype, cd_dtype, with_accumulation, num_sms); + static std::set printed; + if (printed.count(key) == 0) { + printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, " + "A major: %d, B major: %d, MMA kind: %d, A dtype: %s, B dtype: %s, CD dtype: %s, accumulation: %d, " + "SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, " + "SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, " + "swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n", + static_cast(gemm_type), static_cast(kernel_type), m, n, k, num_groups, + static_cast(major_a), static_cast(major_b), static_cast(mma_kind), + c10::toString(a_dtype), c10::toString(b_dtype), c10::toString(cd_dtype), + static_cast(with_accumulation), num_sms, best_block_m, best_block_n, block_k, + best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast, + static_cast(best_multicast_config.is_multicast_on_a), + best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode, + best_smem_config.swizzle_cd_mode, config.num_sms, config.thread_config.num_threads, config.tc_util); + printed.insert(key); + } + } + return config; +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/heuristics/sm100.hpp b/deep-gemm/csrc/jit_kernels/heuristics/sm100.hpp new file mode 100644 index 00000000..dd1e6024 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/heuristics/sm100.hpp @@ -0,0 +1,167 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +struct SM100ArchSpec { + static constexpr int smem_capacity = 232448; + + static std::vector get_block_m_candidates(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const int& m) { + std::vector candidates{128, 256}; + if ((kernel_type == KernelType::Kernel1D1D or kernel_type == KernelType::KernelNoSF) and major_a == cute::UMMA::Major::K) { + // NOTES: `block_m = 32/64` is smaller than `LAYOUT_AD_M`, should be careful in handling this + if (m <= 32) candidates.push_back(32); + if (m <= 64) candidates.push_back(64); + } + return candidates; + } + + static std::vector get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) { + // 16 is for better SM usage + // Stride 32 is due to low-performance swizzle-16/32B + std::vector candidates = {16}; + for (int i = 32; i <= 256; i += 32) + candidates.push_back(i); + return candidates; + } + + static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) { + return block_m / (config.is_multicast_on_a ? config.num_multicast : 1); + } + + static int get_ab_load_block_n(const MulticastConfig& config, const int& block_n) { + return block_n / (config.is_multicast_on_a ? 1 : config.num_multicast); + } + + static int get_cd_store_block_m(const int& block_m) { + constexpr int layout_ad_m = 128; + return std::min(block_m, layout_ad_m); + } + + static int get_cd_store_block_n(const int& block_n) { + return block_n; + } + + static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) { + return true; + } + + static std::pair get_sf_uttcp_aligned_block_sizes( + const int& block_m, const int& block_n, const MmaKind& mma_kind) { + constexpr int num_utccp_aligned_elems = 128; + switch (mma_kind) { + case MmaKind::BF16: return {0, 0}; + case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)}; + default: DG_HOST_UNREACHABLE("Unknown dtype"); + } + } + + static bool is_block_size_legal(const KernelType& kernel_type, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype, + const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k) { + // Layout A/D does not support `block_n % 16 != 0` + if (block_n % 16 != 0) + return false; + + // Performance is lower with 1D1D and `block_m == 256` + if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m > 128) + return false; + + // For small K, fewer store blocks improve store/compute overlap and reduce epilogue bottleneck + if (k <= 256 and (block_n > 128 or block_m > 128)) + return false; + + // Check tensor memory validity + int sf_block_m = 0, sf_block_n = 0; + if (kernel_type == KernelType::Kernel1D1D) { + const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); + sf_block_m = sf_block_m_, sf_block_n = sf_block_n_; + } + if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512) + return false; + + // NOTES: when B is MN-major, we restrict `block_n` to multiples of 64, + // since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA + return major_b == cute::UMMA::Major::K or (block_n * get_element_size(mma_kind)) % 64 == 0; + } + + static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype, + const int& num_stages, + const int& block_m, const int& block_n, const int& block_k) { + return true; + } + + static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, + const int& m, const int& n, const int& block_m, const int& block_n, + const int& num_sms) { + // TODO: support other layouts + return { + false, + is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous + or (gemm_type == GemmType::Batched and num_groups <= 32)), + }; + } + + static ThreadConfig get_thread_config(const KernelType& kernel_type, + const int& block_m, const int& block_n) { + return ThreadConfig::sm100(128, 128); + } + + static int get_smem_cd_size(const KernelType& kernel_type, + const int& block_m, const int& block_n, + const int& swizzle_cd_mode, + const at::ScalarType& cd_dtype) { + constexpr static int layout_ad_m = 128; + return std::min(block_m, layout_ad_m) * swizzle_cd_mode * 2; + } + + static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, + const int& block_m, const int& block_n, const int& block_k, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype) { + if (mma_kind == MmaKind::BF16) + return {0, 0}; + + int smem_sfa_per_stage = 0; + int smem_sfb_per_stage = 0; + if (kernel_type == KernelType::Kernel1D1D) { + const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); + smem_sfa_per_stage = sf_block_m * 4; + smem_sfb_per_stage = sf_block_n * 4; + } else { + smem_sfa_per_stage = block_m * 4; + smem_sfb_per_stage = 0; + } + return {smem_sfa_per_stage, smem_sfb_per_stage}; + } + + static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k) { + return 0; + } + + static int get_barrier_smem_size(const int& num_stages) { + // TODO: remove SF barriers for BF16 GEMMs + // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers + // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages + // NOTES: the last barrier is for tensor core utilization control + return num_stages * 8 * 3 + 2 * 8 * 2 + 8; + } + + static int get_tmem_ptr_smem_size() { + return 4; + } + + static int get_tensormap_smem_size(const GemmType& gemm_type) { + return 0; + } +}; + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/heuristics/sm90.hpp b/deep-gemm/csrc/jit_kernels/heuristics/sm90.hpp new file mode 100644 index 00000000..2fd2e9ec --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/heuristics/sm90.hpp @@ -0,0 +1,164 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" + +namespace deep_gemm { + +struct SM90ArchSpec { + static constexpr int smem_capacity = 232448; + + static std::vector get_block_m_candidates(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const int& m) { + std::vector candidates{64, 128, 256}; + if ((kernel_type == KernelType::Kernel1D2D or kernel_type == KernelType::KernelNoSF) and major_a == cute::UMMA::Major::K) { + // NOTES: `block_m = 16/32` is smaller than MMA M size, should be careful in handling this + if (m <= 16) candidates.push_back(16); + if (m <= 32) candidates.push_back(32); + } + return candidates; + } + + static std::vector get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) { + int start = 16; + + // Avoid bank conflicts for 1D1D kernel FP32 output + std::vector candidates; + if (kernel_type == KernelType::Kernel1D1D and cd_dtype == torch::kFloat) { + candidates.push_back(16); + start = 24; + } + + // Push the strided options + for (int i = start; i <= 256; i += 16) + candidates.push_back(i); + return candidates; + } + + static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) { + return block_m; + } + + static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) { + return block_n; + } + + static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) { + constexpr int wgmma_m = 64; + return single_warpgroup_sync ? wgmma_m : block_m; + } + + static int get_cd_store_block_n(const int& block_n) { + return block_n; + } + + static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) { + return cd_dtype != torch::kFloat; + } + + static bool is_block_size_legal(const KernelType& kernel_type, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype, + const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k) { + // SM90 FP32 output does not support `block_m == 256` + if (cd_dtype == at::kFloat and block_m == 256) + return false; + + // Avoid large C/D shared memory for FP32 output + // Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel) + if (block_n > 128 and cd_dtype == torch::kFloat) { + if (kernel_type == KernelType::Kernel1D1D and block_n > 152) + return false; + if (kernel_type == KernelType::KernelNoSF and block_n > 200) + return false; + } + + // When B is N Major, use swizzle 128B for better performance; only affects SM90 BF16 GEMM + if (major_b == cute::UMMA::Major::MN and block_n >= 128 and block_n % 64 != 0) + return false; + + // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k` + // Or too many register spills + if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192)) + return false; + + // The block sizes cannot be too large (for enough registers), so at least one dim less than 128 + return block_m <= 128 or block_n <= 128; + } + + static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype, + const int& num_stages, + const int& block_m, const int& block_n, const int& block_k) { + // Unrolling both stages and `num_former_iters` will cause large code size + if (mma_kind == MmaKind::MXFP8FP4 and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4) + return num_stages <= 4; + return true; + } + + static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, + const int& m, const int& n, const int& block_m, const int& block_n, + const int& num_sms) { + // Disable multicast when the number of k-groups is large (a heuristic) + if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4) + return {false, false}; + + if (gemm_type == GemmType::Batched) + return {false, false}; + + return { + is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked), + // For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even + is_multicast_legal(m, block_m, 2, num_sms, false) + and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true)) + }; + } + + static ThreadConfig get_thread_config(const KernelType& kernel_type, + const int& block_m, const int& block_n) { + return ThreadConfig::sm90(128, (block_m <= 64 ? 1 : 2) * 128); + } + + static int get_smem_cd_size(const KernelType& kernel_type, + const int& block_m, const int& block_n, + const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) { + // NOTES: 1024 is for TMA swizzling alignment requirement + return align(block_m * block_n * static_cast(c10::elementSize(cd_dtype)), 1024); + } + + static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, + const int& block_m, const int& block_n, const int& block_k, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype) { + if (mma_kind == MmaKind::BF16) + return {0, 0}; + + // NOTES: 128 is for 2D TMA alignment requirement + int smem_sfa_per_stage = align(block_m * static_cast(sizeof(float)), 128); + int smem_sfb_per_stage = 0; + if (kernel_type == KernelType::Kernel1D1D) + smem_sfb_per_stage = align(block_n * 4, 128); + return {smem_sfa_per_stage, smem_sfb_per_stage}; + } + + static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k, + const int& block_m, const int& block_n, const int& block_k) { + const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2; + return align(ceil_div(k, block_k) * static_cast(sizeof(float)) * use_uniform_sfb, 8); + } + + static int get_barrier_smem_size(const int& num_stages) { + return num_stages * 8 * 2; + } + + static int get_tmem_ptr_smem_size() { + return 0; + } + + static int get_tensormap_smem_size(const GemmType& gemm_type) { + return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast(sizeof(CUtensorMap)) : 0; + } +}; + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/epilogue.hpp b/deep-gemm/csrc/jit_kernels/impls/epilogue.hpp new file mode 100644 index 00000000..bd21de10 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/epilogue.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +static std::string get_default_epilogue_type(const std::optional& epilogue_type) { + return epilogue_type.value_or("EpilogueIdentity"); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp b/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp new file mode 100644 index 00000000..677a89ba --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/runtime_utils.hpp @@ -0,0 +1,239 @@ +#pragma once + +#include +#include + +#include "../heuristics/sm90.hpp" +#include "../../jit/handle.hpp" +#include "../../utils/math.hpp" +#include "../../utils/system.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +static std::pair get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) { + return major == cute::UMMA::Major::K ? std::make_pair(k, mn) : std::make_pair(mn, k); +} + +static int get_non_contiguous_dim(const cute::UMMA::Major& major) { + return major == cute::UMMA::Major::K ? -2 : -1; +} + +static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) { + for (const char& c: compiled_dims) { + if (name == c) + return dim; + } + return 0; +} + +static std::string to_string(const cute::UMMA::Major& major) { + switch (major) { + case cute::UMMA::Major::K: return "cute::UMMA::Major::K"; + case cute::UMMA::Major::MN: return "cute::UMMA::Major::MN"; + } + DG_HOST_UNREACHABLE("Unknown major"); +} + +static std::string to_string(const GemmType& type) { + switch (type) { + case GemmType::Normal: return "GemmType::Normal"; + case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous"; + case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked"; + case GemmType::MGroupedContiguousWithPsumLayout: return "GemmType::MGroupedContiguousWithPsumLayout"; + case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous"; + case GemmType::Batched: return "GemmType::Batched"; + } + DG_HOST_UNREACHABLE("Unknown GEMM type"); +} + +static std::string to_string(const at::ScalarType& dtype) { + switch (dtype) { + case torch::kInt: return "int"; + case torch::kFloat: return "float"; + case torch::kBFloat16: return "cutlass::bfloat16_t"; + case torch::kFloat8_e4m3fn: return "cutlass::float_e4m3_t"; + case kPackedFP4: return "cutlass::detail::float_e2m1_unpacksmem_t"; + default: DG_HOST_UNREACHABLE("Unsupported dtype"); + } +} + +static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype, + const bool& allow_tf32) { + if (allow_tf32 and dtype == torch::kFloat) + return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; + + switch (dtype) { + case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32; + case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; +#if CUDART_VERSION >= 12080 + case kPackedFP4: return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; +#endif + default: DG_HOST_UNREACHABLE("Unsupported dtype"); + } +} + +static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) { +#if CUDART_VERSION >= 12080 + if (base != 0) { + DG_HOST_ASSERT(base == 32 and mode == 128); + return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; + } +#endif + + DG_HOST_ASSERT(base == 0); + switch (mode) { + case 0: + case 16: return CU_TENSOR_MAP_SWIZZLE_NONE; + case 32: return CU_TENSOR_MAP_SWIZZLE_32B; + case 64: return CU_TENSOR_MAP_SWIZZLE_64B; + case 128: return CU_TENSOR_MAP_SWIZZLE_128B; + default: DG_HOST_UNREACHABLE("Unsupported swizzling mode"); + } +} + +static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, + int gmem_inner_dim, int gmem_outer_dim, + int smem_inner_dim, int smem_outer_dim, + const int& gmem_outer_stride, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + const auto& elem_size = static_cast(t.element_size()); + if (swizzle_mode != 0) + smem_inner_dim = swizzle_mode / elem_size; + + // Inner dim must be a multiple of 64B for .b4x16_p64 + if (t.scalar_type() == kPackedFP4) + DG_HOST_ASSERT(gmem_inner_dim % 128 == 0); + + CUtensorMap tensor_map; + const cuuint64_t gmem_dims[2] = {static_cast(gmem_inner_dim), static_cast(gmem_outer_dim)}; + const cuuint32_t smem_dims[2] = {static_cast(smem_inner_dim), static_cast(smem_outer_dim)}; + const cuuint64_t gmem_strides[1] = {static_cast(gmem_outer_stride * elem_size), }; + const cuuint32_t elem_strides[2] = {1, 1}; + if (get_env("DG_JIT_DEBUG")) { + printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n", + gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim, + gmem_outer_stride, swizzle_mode, swizzle_base, elem_size); + } + DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled( + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32), + 2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, + CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensor_map; +} + +static CUtensorMap make_tma_3d_desc(const torch::Tensor& t, + int gmem_dim_0, int gmem_dim_1, int gmem_dim_2, + int smem_dim_0, int smem_dim_1, int smem_dim_2, + const int& gmem_stride_0, const int& gmem_stride_1, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + const auto& elem_size = static_cast(t.element_size()); + if (swizzle_mode != 0) + smem_dim_0 = swizzle_mode / elem_size; + + // Inner dim must be a multiple of 64B for .b4x16_p64 + if (t.scalar_type() == kPackedFP4) + DG_HOST_ASSERT(gmem_dim_0 % 128 == 0); + + CUtensorMap tensor_map; + const cuuint64_t gmem_dims[3] = {static_cast(gmem_dim_0), static_cast(gmem_dim_1), static_cast(gmem_dim_2),}; + const cuuint32_t smem_dims[3] = {static_cast(smem_dim_0), static_cast(smem_dim_1), static_cast(smem_dim_2)}; + const cuuint64_t gmem_strides[2] = {static_cast(gmem_stride_0 * elem_size), static_cast(gmem_stride_1 * elem_size)}; + const cuuint32_t elem_strides[3] = {1, 1, 1}; + if (get_env("DG_JIT_DEBUG")) { + printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d, elem size: %d\n", + gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2, + gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size); + } + DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled( + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32), + 3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, + CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensor_map; +} + +static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + const int& shape_m, const int& shape_k, + const int& block_m, const int& block_k, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + if (num_groups > 1) + DG_HOST_ASSERT(major == cute::UMMA::Major::K); + const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups); + const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m); + return make_tma_2d_desc(t, + gmem_inner_dim, gmem_outer_dim, + smem_inner_dim, smem_outer_dim, + outer_stride, + swizzle_mode, swizzle_base, + allow_tf32); +} + +static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + const int& shape_n, const int& shape_k, + const int& block_n, const int& block_k, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n); + const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n); + + // `num_groups` is always applied into the outer dimensions + return make_tma_2d_desc(t, + gmem_inner_dim, gmem_outer_dim * num_groups, + smem_inner_dim, smem_outer_dim, + outer_stride, + swizzle_mode, swizzle_base, + allow_tf32); +} + +static CUtensorMap make_tma_cd_desc(const torch::Tensor& t, + const int& shape_m, const int& shape_n, + const int& block_m, const int& block_n, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + // Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode` + // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required + return make_tma_2d_desc(t, + shape_n, shape_m * num_groups, + block_n, block_m, + outer_stride, + swizzle_mode, swizzle_base, + allow_tf32); +} + +static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + int shape_mn, int shape_k, + const int& block_mn, const int& gran_k, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + DG_HOST_ASSERT(major == cute::UMMA::Major::MN); + + // TODO: maybe swizzle SF as well + DG_HOST_ASSERT(swizzle_mode == 0); + + shape_mn = get_tma_aligned_size(shape_mn, static_cast(t.element_size())); + return make_tma_2d_desc(t, + shape_mn, ceil_div(shape_k, gran_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups, + block_mn, 1, + shape_mn, + swizzle_mode, swizzle_base, + allow_tf32); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp new file mode 100644 index 00000000..bca47a3a --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -0,0 +1,391 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BF16GemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_bf16_gemm_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), + args.gemm_config.tc_util); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_cd)); + } +}; + +static void sm100_bf16_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::Normal, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_gemm", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m; + const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1; + + const auto& config = get_best_config( + gemm_type, KernelType::KernelNoSF, + // NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole + m_for_config, n, k, num_groups_for_config, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::KernelNoSF, + expected_m, n, k, num_groups, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0; + for (const auto& k: ks) { + sum_k += k; + DG_HOST_ASSERT(k % 128 == 0); + } + const auto& num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::KernelNoSF, + m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(0)), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(0)), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch kernel + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_k_grouped_gemm", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto& config = get_best_config( + GemmType::Batched, KernelType::KernelNoSF, + b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, + device_runtime->get_num_sms()); + + const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, + config.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.smem_config.swizzle_a_mode); + const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.block_k, load_block_n, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.smem_config.swizzle_b_mode); + const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = b, .n = d, .k = r, + .num_groups = h, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto& config = get_best_config( + GemmType::Batched, KernelType::KernelNoSF, + b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, + device_runtime->get_num_sms()); + + const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, + config.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.smem_config.swizzle_a_mode); + const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + load_block_n, config.block_k, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.smem_config.swizzle_b_mode); + const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = b, .n = r, .k = d, + .num_groups = h, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp new file mode 100644 index 00000000..dc8766cc --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BmkBnkMnRuntime final: public LaunchRuntime { +public: + struct Args { + int s, m, n, k; + int block_m, block_n, block_k; + int split_factor; + int swizzle_ab_mode, swizzle_cd_mode; + int num_stages; + int num_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_bmn_bnk_mn_gemm_impl< + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, + {}, {} + >); +}}; +)", + args.m, args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.split_factor, + args.swizzle_ab_mode, args.swizzle_cd_mode, + args.num_stages, args.num_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.s, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d)); + } +}; + + +static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a, + const torch::Tensor &b, + const torch::Tensor &d, + const int &s, const int &m, const int &n, const int &k) { + constexpr int block_m = 128; + constexpr int block_n = 128; + constexpr int block_k = 64; + constexpr int num_threads = 128; + DG_HOST_ASSERT(k % block_k == 0); + DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0); + DG_HOST_ASSERT(static_cast(s) * static_cast(std::max(m, n)) <= std::numeric_limits::max()); + + const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast(a.element_size())); + const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast(d.element_size())); + + // Get best config + const int num_sms = device_runtime->get_num_sms(); + const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n); + const int num_sk_blocks = s * (k / block_k); + const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1)); + + // Select best number of stages + // NOTES: we select 4 as start, as it is tested to be faster than values > 4 + int num_stages = 4, smem_size = 0; + while (true) { + const int& smem_cd = block_m * swizzle_cd_mode * 2; + const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); + const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); + const int& smem_barrier = SM100ArchSpec::get_barrier_smem_size(num_stages); + const int& smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size(); + + smem_size = 0; + smem_size += smem_cd; + smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages; + smem_size += smem_barrier; + smem_size += smem_tmem_ptr; + if (smem_size <= SM100ArchSpec::smem_capacity) + break; + + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("S: %d, M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split-K factor: %d" + "stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n", + s, m, n, k, block_m, block_n, block_k, split_factor, + num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode); + } + + const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); + const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); + const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode); + + const SM100BmkBnkMnRuntime::Args& args = { + .s = s, .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .split_factor = split_factor, + .swizzle_ab_mode = swizzle_ab_mode, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_threads = num_threads, + .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100BmkBnkMnRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code); + SM100BmkBnkMnRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp new file mode 100644 index 00000000..404369a4 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -0,0 +1,416 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" + +#include "epilogue.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + int gran_k_a, gran_k_b; + const std::string& compiled_dims; + const std::optional& epilogue_type; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_gemm_1d1d_impl< + {}, {}, + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {}, {}, {}, + {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + args.gran_k_a, args.gran_k_b, + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, + to_string(args.gemm_config.a_dtype), to_string(args.gemm_config.b_dtype), to_string(args.gemm_config.cd_dtype), + get_default_epilogue_type(args.epilogue_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_cd)); + } +}; + +static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& cd = c.value_or(d); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, static_cast(d.size(-1)), + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, gran_k_a, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, gran_k_b, 1, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .compiled_dims = compiled_dims, + .epilogue_type = epilogue_type, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m; + const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1; + + const auto& config = get_best_config( + gemm_type, KernelType::Kernel1D1D, + m_for_config, n, k, num_groups_for_config, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, gran_k_a, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, gran_k_b, num_groups, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D1D, + expected_m, n, k, num_groups, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, gran_k_a, num_groups, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, gran_k_b, num_groups, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0, sum_sf_k = 0; + for (const auto& k: ks) { + sum_k += k, sum_sf_k += ceil_div(k, 512); + DG_HOST_ASSERT(k % 128 == 0); + } + const auto& num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::Kernel1D1D, + m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(0)), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(0)), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512, + config.block_n, config.block_k, 1, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .gran_k_a = 128, + .gran_k_b = 128, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::Batched, KernelType::Kernel1D1D, + m, n, k, batch_size, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& [inner_dim_a, outer_dim_a] = get_inner_outer_dims(major_a, k, m); + const auto& [inner_block_a, outer_block_a] = get_inner_outer_dims(major_a, config.block_k, load_block_m); + const auto& tensor_map_a = make_tma_3d_desc(a, inner_dim_a, outer_dim_a, batch_size, + inner_block_a, outer_block_a, 1, + a.stride(major_a == cute::UMMA::Major::K ? 1 : 2), + a.stride(0), + config.smem_config.swizzle_a_mode); + + const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& [inner_dim_b, outer_dim_b] = get_inner_outer_dims(major_b, k, n); + const auto& [inner_block_b, outer_block_b] = get_inner_outer_dims(major_b, config.block_k, load_block_n); + const auto& tensor_map_b = make_tma_3d_desc(b, inner_dim_b, outer_dim_b, batch_size, + inner_block_b, outer_block_b, 1, + b.stride(major_b == cute::UMMA::Major::K ? 1 : 2), + b.stride(0), + config.smem_config.swizzle_b_mode); + + const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.smem_config.swizzle_cd_mode); + + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, batch_size, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, batch_size, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = batch_size, + .gran_k_a = 128, + .gran_k_b = 128, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp b/deep-gemm/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp new file mode 100644 index 00000000..bdb5b11d --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp @@ -0,0 +1,149 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BF16HCPrenormGemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k; + int block_m, block_n, block_k; + int num_splits; + int swizzle_cd_mode; + int num_stages; + int num_mma_threads, num_cast_and_reduce_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + float* sqr_sum; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_tf32_hc_prenorm_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.num_splits, + args.swizzle_cd_mode, + args.num_stages, + args.num_mma_threads, args.num_cast_and_reduce_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum)); + } +}; + +static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const int& m, const int& n, const int& k, + const int& num_splits) { + constexpr int block_m = 64; + constexpr int block_k = 64; + constexpr int num_mma_threads = 128; + constexpr int num_cast_and_reduce_threads = 128; + + const int block_n = align(n, 16); + DG_HOST_ASSERT(n <= block_n); + DG_HOST_ASSERT(n <= 128 and n % 8 == 0); + DG_HOST_ASSERT(k % block_k == 0); + + const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) + : make_tma_3d_desc(d, n, m, num_splits, + block_n, block_m, 1, + static_cast(d.stride(-2)), + static_cast(d.stride(-3)), + swizzle_cd_mode); + + // Calculate stages + int num_stages = 12, smem_size = 0; + while (num_stages > 0) { + const int smem_a_per_stage = block_m * block_k * static_cast(sizeof(nv_bfloat16)); + const int smem_b_per_stage = block_n * block_k * static_cast(sizeof(float)); + const int smem_cd = block_m * swizzle_cd_mode; + const int smem_barriers = (num_stages * 4 + 1) * 8; + const int smem_tmem_ptr = 4; + smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + + smem_cd + smem_barriers + smem_tmem_ptr; + + if (smem_size <= SM100ArchSpec::smem_capacity) + break; + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split K: %d" + "stages: %d, shared memory: %d, swizzle CD: %d\n", + m, n, k, block_m, block_n, block_k, num_splits, + num_stages, smem_size, swizzle_cd_mode); + } + + // Launch + const SM100BF16HCPrenormGemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .num_splits = num_splits, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_mma_threads = num_mma_threads, + .num_cast_and_reduce_threads = num_cast_and_reduce_threads, + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .sqr_sum = sqr_sum.data_ptr() + }; + const auto& code = SM100BF16HCPrenormGemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code); + SM100BF16HCPrenormGemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp new file mode 100644 index 00000000..6291d0d9 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -0,0 +1,390 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BF16GemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_bf16_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", + // TODO: add CD dtype + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, + to_string(args.gemm_config.cd_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_cd)); + } +}; + +static void sm90_bf16_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::Normal, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_gemm", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::KernelNoSF, + expected_m, n, k, num_groups, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0; + for (const auto& k: ks) { + sum_k += k; + DG_HOST_ASSERT(k % 128 == 0); + } + const auto& num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::KernelNoSF, + m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(0)), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(0)), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch kernel + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_k_grouped_gemm", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto& config = get_best_config( + GemmType::Batched, KernelType::KernelNoSF, + b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, + device_runtime->get_num_sms()); + + const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, + config.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.smem_config.swizzle_a_mode); + const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.block_k, load_block_n, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.smem_config.swizzle_b_mode); + const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.smem_config.swizzle_cd_mode); + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = b, .n = d, .k = r, + .num_groups = h, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto& config = get_best_config( + GemmType::Batched, KernelType::KernelNoSF, + b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, + device_runtime->get_num_sms()); + + const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, + config.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.smem_config.swizzle_a_mode); + const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + load_block_n, config.block_k, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.smem_config.swizzle_b_mode); + const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.smem_config.swizzle_cd_mode); + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = b, .n = r, .k = d, + .num_groups = h, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp new file mode 100644 index 00000000..8441e997 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp @@ -0,0 +1,131 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BmkBnkMnRuntime final: public LaunchRuntime { +public: + struct Args { + int s, m, n, k; + int block_m, block_n, block_k; + int split_factor; + int num_stages; + int num_tma_threads, num_math_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + float* d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_bmn_bnk_mn_gemm_impl< + {}, {}, {}, + {}, {}, {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.m, args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.split_factor, + args.num_stages, + args.num_tma_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.s, args.tensor_map_a, args.tensor_map_b, args.d)); + } +}; + + +static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a, + const torch::Tensor &b, + const torch::Tensor &d, + const int &s, const int &m, const int &n, const int &k) { + constexpr int block_m = 128; + constexpr int block_n = 128; + constexpr int block_k = 64; + constexpr int num_tma_threads = 128; + constexpr int num_math_threads = 256; + DG_HOST_ASSERT(k % block_k == 0); + DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0); + DG_HOST_ASSERT(static_cast(s) * static_cast(std::max(m, n)) <= std::numeric_limits::max()); + + const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast(a.element_size())); + DG_HOST_ASSERT(swizzle_ab_mode == 128); + + // Get best config + const int num_sms = device_runtime->get_num_sms(); + const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n); + const int num_sk_blocks = s * (k / block_k); + const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1)); + + // Select best number of stages + int num_stages = 4, smem_size = 0; + while (true) { + const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); + const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); + const int& smem_barrier = SM90ArchSpec::get_barrier_smem_size(num_stages); + + smem_size = 0; + smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages; + smem_size += smem_barrier; + + if (smem_size <= SM90ArchSpec::smem_capacity) + break; + + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("S: %d, M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split-K factor: %d" + "stages: %d, shared memory: %d, swizzle AB: %d\n", + s, m, n, k, block_m, block_n, block_k, split_factor, + num_stages, smem_size, swizzle_ab_mode); + } + + const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); + const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); + + const SM90BmkBnkMnRuntime::Args& args = { + .s = s, .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .split_factor = split_factor, + .num_stages = num_stages, + .num_tma_threads = num_tma_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .d = d.data_ptr() + }; + const auto& code = SM90BmkBnkMnRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code); + SM90BmkBnkMnRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp new file mode 100644 index 00000000..002b3873 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -0,0 +1,218 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *gmem_a_ptr; + void *gmem_b_ptr; + void *grouped_layout; + void *tensor_map_buffer; + CUtensorMap tensor_map_a_base; + CUtensorMap tensor_map_b_base; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d1d_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {} + >); +}}; +)", + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, + args.gemm_config.num_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), + to_string(args.gemm_config.cd_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.gmem_a_ptr, args.gmem_b_ptr, + args.grouped_layout, + args.tensor_map_buffer, + args.m, args.n, args.k, + args.tensor_map_a_base, args.tensor_map_b_base, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_cd)); + } +}; + +static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, k, 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, k, 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, 1, 0); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m, true), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + 0); + + // Launch + const SM90FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .gmem_a_ptr = nullptr, + .gmem_b_ptr = nullptr, + .grouped_layout = nullptr, + .tensor_map_buffer = nullptr, + .tensor_map_a_base = tensor_map_a, + .tensor_map_b_base = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); + + SM90FP8Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const torch::Tensor& tensor_map_buffer, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + // Get config using max K for better performance + const auto& num_groups = static_cast(ks.size()); + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::Kernel1D1D, + m, n, max_k, num_groups, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + + int first_k = 0, sum_k = 0, sum_sf_k = 0; + for (int i = 0; i < num_groups; ++ i) { + if (first_k == 0 and ks[i] != 0) + first_k = ks[i]; + sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128); + DG_HOST_ASSERT(ks[i] % 128 == 0); + } + const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, first_k, 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, first_k, 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128, + config.block_n, config.block_k, 1, 0); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m, true), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .gmem_a_ptr = a.data_ptr(), + .gmem_b_ptr = b.data_ptr(), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_buffer = tensor_map_buffer.data_ptr(), + .tensor_map_a_base = tensor_map_a_base, + .tensor_map_b_base = tensor_map_b_base, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); + + SM90FP8Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp new file mode 100644 index 00000000..b29017f8 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -0,0 +1,331 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" + +#include "epilogue.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime { +public: + struct Args { + cute::UMMA::Major major_sfb; + int m, n, k, num_groups; + const std::string& compiled_dims; + const std::optional& epilogue_type; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *sfb, *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + CUtensorMap tensor_map_sfa; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d2d_impl< + {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, {}, + {} + >); +}}; +)", + // TODO: add CD dtype + to_string(args.major_sfb), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), + get_default_epilogue_type(args.epilogue_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sfb, args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d, args.tensor_map_sfa)); + } +}; + +static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { + DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast(d.size(-1)), + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .epilogue_type = epilogue_type, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D2D, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D2D, + expected_m, n, k, num_groups, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, num_groups, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::Batched, KernelType::Kernel1D2D, + m, n, k, batch_size, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(a, k, m, batch_size, + config.block_k, load_block_m, 1, + a.stride(1), + a.stride(0), + config.smem_config.swizzle_a_mode); + + const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(b, k, n, batch_size, + config.block_k, load_block_n, 1, + b.stride(1), + b.stride(0), + config.smem_config.swizzle_b_mode); + + const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_d = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.smem_config.swizzle_cd_mode); + + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, batch_size, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, + .num_groups = batch_size, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp b/deep-gemm/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp new file mode 100644 index 00000000..63a47c32 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BF16HCPrenormGemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k; + int block_m, block_n, block_k; + int num_splits; + int swizzle_cd_mode; + int num_stages; + int num_math_threads, num_tma_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + float* sqr_sum; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_tf32_hc_prenorm_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.num_splits, + args.swizzle_cd_mode, + args.num_stages, + args.num_math_threads, args.num_tma_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum)); + } +}; + +static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const int& m, const int& n, const int& k, + const int& num_splits) { + constexpr int block_m = 64; + constexpr int block_k = 64; + constexpr int num_math_threads = 128; + constexpr int num_tma_threads = 128; + constexpr int num_threads = num_math_threads + num_tma_threads; + + const int block_n = align(n, 16); + DG_HOST_ASSERT(n <= block_n); + // Only support small N for now + DG_HOST_ASSERT(n <= 32 and n % 8 == 0); + DG_HOST_ASSERT(k % block_k == 0); + + const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) + : make_tma_3d_desc(d, n, m, num_splits, + block_n, block_m, 1, + static_cast(d.stride(-2)), + static_cast(d.stride(-3)), + swizzle_cd_mode); + + // Calculate stages + int num_stages = 12, smem_size = 0; + while (num_stages > 0) { + const int smem_a_per_stage = block_m * block_k * static_cast(sizeof(nv_bfloat16)); + const int smem_b_per_stage = block_n * block_k * static_cast(sizeof(float)); + const int smem_cd = block_m * swizzle_cd_mode; + const int smem_barriers = num_stages * 2 * 8; + smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + + smem_cd + smem_barriers; + + if (smem_size <= SM90ArchSpec::smem_capacity) + break; + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split K: %d" + "stages: %d, shared memory: %d, swizzle CD: %d\n", + m, n, k, block_m, block_n, block_k, num_splits, + num_stages, smem_size, swizzle_cd_mode); + } + + smem_size = SM90ArchSpec::smem_capacity; + + // Launch + const SM90BF16HCPrenormGemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .num_splits = num_splits, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_math_threads = num_math_threads, + .num_tma_threads = num_tma_threads, + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .sqr_sum = sqr_sum.data_ptr() + }; + const auto& code = SM90BF16HCPrenormGemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code); + SM90BF16HCPrenormGemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_clean_logits.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_clean_logits.hpp new file mode 100644 index 00000000..fdb91a03 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_clean_logits.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +class SMXXCleanLogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int next_n; + int seq_len; + int seq_len_kv; + uint64_t stride_logits; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + float* logits; + + int block_kv; + int num_warps; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&smxx_clean_logits< + {}, {}, {} + >); +}}; +)", args.next_n, args.block_kv, args.num_warps); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, static_cast(args.stride_logits), + args.cu_seq_len_k_start, args.cu_seq_len_k_end, args.logits + )); + } +}; + +static void smxx_clean_logits(const torch::Tensor& logits, + const std::optional& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const int& next_n, + const int& seq_len, const int& seq_len_kv, + const uint64_t &stride_logits) { + const int block_kv = 8192; + const int num_warps = 8; + const int smem_size = block_kv * sizeof(float); + + // Launch + const SMXXCleanLogitsRuntime::Args& args = { + .next_n = next_n, + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .stride_logits = stride_logits, + .cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? cu_seq_len_k_start.value().data_ptr() : nullptr, + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .block_kv = block_kv, + .num_warps = num_warps, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_warps * 32, smem_size) + }; + const auto& code = SMXXCleanLogitsRuntime::generate(args); + const auto& runtime = compiler->build("smxx_clean_logits", code); + SMXXCleanLogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_cublaslt.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_cublaslt.hpp new file mode 100644 index 00000000..dc20e334 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_cublaslt.hpp @@ -0,0 +1,162 @@ +#pragma once + +#include +#include +#include +#include + +#include "../../jit/device_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/compatibility.hpp" + +namespace deep_gemm { + +static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld, + const std::optional& batch_count = std::nullopt, + const std::optional& batch_offset = std::nullopt) { + cublasLtMatrixLayout_t layout; + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld)); + if (batch_count.has_value()) { + DG_HOST_ASSERT(batch_offset.has_value()); + + const int64_t batch_offset_int64 = batch_offset.value(); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count.value(), sizeof(batch_count.value()))); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset_int64, sizeof(batch_offset_int64))); + } + return layout; +} + +static void call_cublaslt_api(const cublasOperation_t& trans_a, + const cublasOperation_t& trans_b, + const cublasLtMatrixLayout_t& layout_a, + const cublasLtMatrixLayout_t& layout_b, + const cublasLtMatrixLayout_t& layout_d, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const bool& accumulate) { + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + cudaDataType_t scale_type = CUDA_R_32F; + + // Operation description + cublasLtMatmulDesc_t desc; + DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type)); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a))); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b))); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + +#if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE + const int& math_sms = device_runtime->get_num_sms(); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms))); +#endif + +#if DG_FP8_COMPATIBLE and DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE + bool fp8_fast_accumulate = false; + if (a.scalar_type() == torch::kFloat8_e4m3fn) + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate))); +#endif + + // Get cuBLASLt handle, workspace, and stream + const auto& handle = device_runtime->get_cublaslt_handle(); + const auto& workspace = device_runtime->get_cublaslt_workspace(); + const auto& workspace_bytes = workspace.nbytes(); + const auto& stream = at::cuda::getCurrentCUDAStream(); + + // Algorithm selection + cublasLtMatmulPreference_t pref; + cublasLtMatmulHeuristicResult_t heuristic; + int num_heuristic_results = 0; + uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE; + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref)); + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_bytes, sizeof(workspace_bytes))); + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, + &reduction_scheme_mask, sizeof(reduction_scheme_mask))); + DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d, + pref, 1, &heuristic, &num_heuristic_results)); + DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM"); + + // Call: D = alpha * (A @ B) + beta * C + const float& alpha = 1.0, beta = accumulate ? 1.0 : 0.0; + DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle + desc, // Operation description + &alpha, // Alpha + b.data_ptr(), layout_a, // A + a.data_ptr(), layout_b, // B + &beta, // Beta + d.data_ptr(), layout_d, // C + d.data_ptr(), layout_d, // D + &heuristic.algo, // Algorithm + workspace.data_ptr(), workspace_bytes, // Workspace + stream)); // Stream + + // Free memory + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d)); + DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc)); +} + +static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs, + const std::optional& acc, + const torch::Tensor& out, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) { + const auto& trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N; + const auto& trans_b = a_major == cute::UMMA::Major::K ? CUBLAS_OP_N : CUBLAS_OP_T; + + // Duplicate the accumulator if necessary + // TODO: remove this + if (acc.has_value()) { + if (acc->data_ptr() == out.data_ptr()) { + DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides()); + } else { + out.copy_(acc.value()); + } + } + + // Matrix layouts + const auto& cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type()); + const auto& cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type()); + const auto& cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type()); + const auto& layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0)) + : get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1)); + const auto& layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0)) + : get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1)); + const auto& layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, acc.has_value()); +} + + +static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out, + const int& b, const int& h, const int& r, const int& d) { + const auto& m = d, n = b, k = r; + const auto& trans_a = CUBLAS_OP_T; + const auto& trans_b = CUBLAS_OP_N; + + // Matrix layouts + const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0)); + const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); + const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false); +} + + +static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out, + const int& b, const int& h, const int& r, const int& d) { + const auto& m = r, n = b, k = d; + const auto& trans_a = CUBLAS_OP_N; + const auto& trans_b = CUBLAS_OP_N; + + // Matrix layouts + const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0)); + const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); + const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp new file mode 100644 index 00000000..f3b82e3d --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp @@ -0,0 +1,164 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXFP8MQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int seq_len; + int seq_len_kv; + int max_seqlen_k; + int stride_logits; + int num_heads, head_dim; + bool is_compressed_logits; + + int num_q_stages; + int num_kv_stages; + int block_q; + int block_kv; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + float* logits; + float softmax_scale; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_mqa_logits< + {}, {}, + {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", arch, arch, + args.num_heads, args.head_dim, + args.is_compressed_logits, + args.block_q, args.block_kv, + args.num_q_stages, args.num_kv_stages, + args.num_specialized_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, + args.max_seqlen_k, static_cast(args.stride_logits), + args.cu_seq_len_k_start, args.cu_seq_len_k_end, + args.logits, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv, const torch::Tensor& kv_scales, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const torch::Tensor& logits, + const int& seq_len, const int& seq_len_kv, + const int& max_seqlen_k, const int& stride_logits, + const int& num_heads, const int& head_dim, + const int& seq_len_alignment) { + constexpr int block_qh = 128; + constexpr int block_kv = 256; + constexpr int num_specialized_threads = 128; + constexpr int num_q_stages = 3, num_kv_stages = 3; + const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512); + const int block_q = block_qh / num_heads; + DG_HOST_ASSERT(block_qh % num_heads == 0); + DG_HOST_ASSERT(seq_len_alignment % block_q == 0); + + // Use compressed logits format when max_seqlen_k is specified + const bool is_compressed_logits = (max_seqlen_k > 0); + + // Construct TMAs + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads, + head_dim, block_qh, head_dim, head_dim); + const auto& tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv, + head_dim, block_kv, head_dim, head_dim); + // According to the driver API, the minimal alignment is 256 bytes + // So it is safe for us to do a 16-byte OOB + const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_scales, + get_tma_aligned_size(seq_len_kv, static_cast(kv_scales.element_size())), + 1, block_kv, 1, 0, 0); + const auto& tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len, + num_heads, block_q, num_heads, 0); + + // Calculate shared memory size + int smem_size = 0; + const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast(q.element_size()); + const int smem_weight_size_per_stage = block_q * num_heads * static_cast(weights.element_size()); + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv.element_size()); + const int kv_scale_size_per_stage = block_kv * static_cast(kv_scales.element_size()); + smem_size += num_q_stages * smem_q_size_per_stage; + smem_size += num_kv_stages * smem_kv_size_per_stage; + smem_size += num_q_stages * smem_weight_size_per_stage; + smem_size += num_kv_stages * kv_scale_size_per_stage; + smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8; + smem_size += 4; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SMXXFP8MQALogitsRuntime::Args& args = { + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .max_seqlen_k = max_seqlen_k, + .stride_logits = stride_logits, + .num_heads = num_heads, .head_dim = head_dim, + .is_compressed_logits = is_compressed_logits, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .block_q = block_q, + .block_kv = block_kv, + .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr(), + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto& code = SMXXFP8MQALogitsRuntime::generate(args); + const auto& runtime = compiler->build("smxx_fp8_mqa_logits", code); + SMXXFP8MQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp new file mode 100644 index 00000000..1240aad8 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp @@ -0,0 +1,265 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime { +public: + struct Args { + int aligned_batch_size; + int split_kv; + int num_sms; + + int batch_size; + int next_n; + bool is_context_lens_2d; + int* context_lens; + int* schedule_metadata; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&smxx_paged_mqa_logits_metadata< + {}, {}, {} + >); +}}; +)", arch, args.aligned_batch_size, args.split_kv, args.num_sms); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.next_n, + args.is_context_lens_2d, + args.context_lens, + args.schedule_metadata + )); + } +}; + +static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, + const torch::Tensor& schedule_metadata, + const int& batch_size, const int& next_n, + const int& block_kv, const int& num_sms, + const bool& is_context_lens_2d) { + constexpr int num_math_warpgroups = 4; + constexpr int num_threads = 32; + const int aligned_batch_size = align(batch_size, 32); + const int split_kv = block_kv * num_math_warpgroups; + + // Calculate shared memory size + const int smem_size = aligned_batch_size * static_cast(sizeof(int)); + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SMXXPagedMQALogitsMetadataRuntime::Args& args = { + .aligned_batch_size = aligned_batch_size, + .split_kv = split_kv, + .num_sms = num_sms, + .batch_size = batch_size, + .next_n = next_n, + .is_context_lens_2d = is_context_lens_2d, + .context_lens = context_lens.data_ptr(), + .schedule_metadata = schedule_metadata.data_ptr(), + .launch_args = LaunchArgs(1, num_threads, smem_size) + }; + const auto& code = SMXXPagedMQALogitsMetadataRuntime::generate(args); + const auto& runtime = compiler->build("smxx_paged_mqa_logits_metadata", code); + SMXXPagedMQALogitsMetadataRuntime::launch(runtime, args); +} + +class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int batch_size; + int next_n; + int num_heads; + int head_dim; + int block_kv; + bool is_context_lens_2d; + int block_table_stride; + int logits_stride; + + int num_q_stages; + int num_kv_stages; + int split_kv; + + int* context_lens; + float* logits; + int* block_table; + int* schedule_meta; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_paged_mqa_logits< + {}, {}, + {}, {}, + {}, + {}, {}, + {}, + {}, {} + >); +}}; +)", arch, arch, + args.next_n, args.num_heads, + args.head_dim, args.block_kv, + args.is_context_lens_2d, + args.num_q_stages, args.num_kv_stages, + args.split_kv, + args.num_specialized_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + static_cast(args.logits_stride), + static_cast(args.block_table_stride), + args.context_lens, args.logits, + args.block_table, args.schedule_meta, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv_cache, + const torch::Tensor& kv_cache_scales, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& logits, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& batch_size, const int& next_n, + const int& num_heads, const int& head_dim, + const int& num_kv_blocks, const int& block_kv, + const bool& is_context_lens_2d, + const int& kv_cache_stride_bytes, + const int& logits_stride, + const int& block_table_stride, + const int& num_sms, + const int& split_kv) { + const int num_specialized_threads = 128; + const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64); + const int num_math_warp_groups = split_kv / mma_m; + const int num_math_threads = num_math_warp_groups * 128; + const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3); + DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0); + + // Construct TMAs + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads, + head_dim, next_n * num_heads, head_dim, head_dim); + const auto& tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks, + head_dim, block_kv, 1, + head_dim, kv_cache_stride_bytes, head_dim); + // TODO: use 1D TMA + const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks, + block_kv, 1, kv_cache_stride_bytes / static_cast(sizeof(float)), 0); + const auto& tensor_map_weights = make_tma_2d_desc(weights, next_n * num_heads, batch_size, + next_n * num_heads, 1, next_n * num_heads, 0); + + // Calculate shared memory size + int smem_size = 0; + if (device_runtime->get_arch_major() == 9) { + const int swizzle_alignment = head_dim * 8; + + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); + const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); + + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); + const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); + const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); + + // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + } else { + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int smem_kv_size_per_stage = split_kv * head_dim * static_cast(kv_cache.element_size()); + const int smem_kv_scale_size_per_stage = split_kv * static_cast(kv_cache_scales.element_size()); + const int smem_weight_size_per_stage = next_n * num_heads * static_cast(weights.element_size()); + + const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8; + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) + + smem_barriers + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + } + + // Launch + const SMXXFP8PagedMQALogitsRuntime::Args& args = { + .batch_size = batch_size, + .next_n = next_n, + .num_heads = num_heads, + .head_dim = head_dim, + .block_kv = block_kv, + .is_context_lens_2d = is_context_lens_2d, + .block_table_stride = block_table_stride, + .logits_stride = logits_stride, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .split_kv = split_kv, + .context_lens = context_lens.data_ptr(), + .logits = logits.data_ptr(), + .block_table = block_table.data_ptr(), + .schedule_meta = schedule_meta.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_sms, + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args); + const auto& runtime = compiler->build("smxx_fp8_paged_mqa_logits", code); + SMXXFP8PagedMQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/jit_kernels/impls/smxx_layout.hpp b/deep-gemm/csrc/jit_kernels/impls/smxx_layout.hpp new file mode 100644 index 00000000..0b9eebd7 --- /dev/null +++ b/deep-gemm/csrc/jit_kernels/impls/smxx_layout.hpp @@ -0,0 +1,264 @@ +#pragma once + +#include + +#include "../../jit/kernel_runtime.hpp" +#include "../../jit/compiler.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../../utils/layout.hpp" + +namespace deep_gemm { + +class TransposeFP32Runtime final: public LaunchRuntime { +public: + struct Args { + int mn, sf_k; + int block_mn; + void *sf, *out; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&transpose_fp32< + {}, {}, {} + >); +}}; +)", args.launch_args.num_threads, args.block_mn, args.sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast(args.mn))); + } +}; + +class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime { +public: + struct Args { + int mn, sf_k; + int block_mn; + void *sf, *out; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&transpose_and_pack_fp32_into_ue8m0< + {}, {}, {} + >); +}}; +)", args.launch_args.num_threads, args.block_mn, args.sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast(args.mn))); + } +}; + +class PackFP32IntoUE8M0Runtime final: public LaunchRuntime { +public: + struct Args { + int num_groups, mn, sf_k, packed_sf_k; + int block_mn, block_packed_sf_k; + void *sf, *out, *ks; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&pack_fp32_into_ue8m0< + {}, {}, {}, {} + >); +}}; +)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k)); + } +}; + +static std::tuple preprocess_sf(const torch::Tensor& sf) { + // NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + const auto& dim = sf.dim(); + DG_HOST_ASSERT(dim == 2 or dim == 3); + DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat); + const auto& batched_sf = dim == 2 ? sf.unsqueeze(0) : sf; + + const auto& [num_groups, mn, sf_k] = get_shape<3>(batched_sf); + const auto& tma_aligned_mn = get_tma_aligned_size(mn, static_cast(sf.element_size())); + return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf}; +} + +static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { + const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + + // The last kernel already gives a column-major TMA aligned layout + if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn) + return (dim == 2) ? batched_sf.squeeze(0) : batched_sf; + + const auto& out = torch::empty_strided({num_groups, mn, sf_k}, + {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, + batched_sf.options()); + + if (not batched_sf.is_contiguous()) { + // Fallback to PyTorch's slow copy if not contiguous + // ReSharper disable once CppExpressionWithoutSideEffects + out.copy_(batched_sf); + } else { + constexpr int block_mn = 64; + constexpr int num_threads = 512; + const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast(sizeof(float)); + const TransposeFP32Runtime::Args& args = { + .mn = mn, + .sf_k = sf_k, + .block_mn = block_mn, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size) + }; + + const auto& code = TransposeFP32Runtime::generate(args); + const auto& runtime = compiler->build("transpose_fp32", code); + TransposeFP32Runtime::launch(runtime, args); + } + return (dim == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) { + const auto& sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf; + + // First, convert into UE8M0 `uint8_t` + const auto& ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8); + + // Second, make padded packed tensors + const auto& [num_groups, mn, k] = get_shape<3>(sf_reshaped); + const auto& aligned_mn = get_tma_aligned_size(mn, 4); + const auto& aligned_k = align(k, 4); + + const auto& options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8); + auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options); + // ReSharper disable once CppExpressionWithoutSideEffects + padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor); + padded = padded.view(-1).view(torch::kInt32).view({num_groups, aligned_mn, aligned_k / 4}); + + // Finally, transpose + auto out = torch::empty_strided({num_groups, aligned_mn, aligned_k / 4}, + {aligned_mn * (aligned_k / 4), 1, aligned_mn}, + at::TensorOptions().device(sf.device()).dtype(torch::kInt32)); + out = out.copy_(padded).slice(1, 0, mn); + return (sf.dim() == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) { + const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + const auto& packed_sf_k = ceil_div(sf_k, 4); + const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k}, + {packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn}, + at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); + // Launch the kernel + if (batched_sf.is_contiguous()) { + if ((mn * sf_k) % 4 != 0 and num_groups > 1) + return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); + + constexpr int block_mn = 48; + constexpr int num_threads = 512; + const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = { + .mn = mn, + .sf_k = sf_k, + .block_mn = block_mn, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4) + }; + + const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); + TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args); + } else { + if (mn % 4 != 0 or num_groups > 1) + return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); + DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn); + + constexpr int block_mn = 128; + constexpr int block_packed_sf_k = 16; + constexpr int num_threads = 512; + const PackFP32IntoUE8M0Runtime::Args& args = { + .num_groups = 1, + .mn = mn, + .sf_k = sf_k, + .packed_sf_k = packed_sf_k, + .block_mn = block_mn, + .block_packed_sf_k = block_packed_sf_k, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .ks = nullptr, + .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + }; + + const auto& code = PackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code); + PackFP32IntoUE8M0Runtime::launch(runtime, args); + } + return (dim == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf, + const torch::Tensor& ks_tensor, + const std::vector& ks) { + const auto& [sf_k, mn] = get_shape<2>(sf); + const auto& num_groups = static_cast(ks.size()); + + int ref_sf_k = 0, packed_sf_k = 0; + for (const auto& k: ks) + ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512); + DG_HOST_ASSERT(sf.is_contiguous()); + DG_HOST_ASSERT(ref_sf_k == sf_k); + DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0); + + const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt)); + + constexpr int block_mn = 128; + constexpr int block_packed_sf_k = 16; + constexpr int num_threads = 512; + const PackFP32IntoUE8M0Runtime::Args& args = { + .num_groups = num_groups, + .mn = mn, + .sf_k = sf_k, + .packed_sf_k = packed_sf_k, + .block_mn = block_mn, + .block_packed_sf_k = block_packed_sf_k, + .sf = sf.data_ptr(), + .out = out.data_ptr(), + .ks = ks_tensor.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + }; + + const auto& code = PackFP32IntoUE8M0Runtime::generate(args); + const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code); + PackFP32IntoUE8M0Runtime::launch(runtime, args); + return out; +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/python_api.cpp b/deep-gemm/csrc/python_api.cpp new file mode 100644 index 00000000..0354f1f8 --- /dev/null +++ b/deep-gemm/csrc/python_api.cpp @@ -0,0 +1,26 @@ +#include +#include + +#include "apis/attention.hpp" +#include "apis/einsum.hpp" +#include "apis/hyperconnection.hpp" +#include "apis/gemm.hpp" +#include "apis/layout.hpp" +#include "apis/runtime.hpp" + +#ifndef TORCH_EXTENSION_NAME +#define TORCH_EXTENSION_NAME _C +#endif + +// ReSharper disable once CppParameterMayBeConstPtrOrRef +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "DeepGEMM C++ library"; + + // TODO: make SM80 incompatible issues raise errors + deep_gemm::attention::register_apis(m); + deep_gemm::einsum::register_apis(m); + deep_gemm::hyperconnection::register_apis(m); + deep_gemm::gemm::register_apis(m); + deep_gemm::layout::register_apis(m); + deep_gemm::runtime::register_apis(m); +} diff --git a/deep-gemm/csrc/utils/compatibility.hpp b/deep-gemm/csrc/utils/compatibility.hpp new file mode 100644 index 00000000..9e2d6720 --- /dev/null +++ b/deep-gemm/csrc/utils/compatibility.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +// `torch::kFloat8_e4m3fn` is supported since PyTorch 2.1 +#define DG_FP8_COMPATIBLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 1)) + +// `cuTensorMapEncodeTiled` is supported since CUDA Driver API 12.1 +#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010) + +// `cublasGetErrorString` is supported since CUDA Runtime API 11.4.2 +#define DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE (CUDART_VERSION >= 11042) + +// `CUBLASLT_MATMUL_DESC_FAST_ACCUM` and `CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET` are supported since CUDA Runtime API 11.8 +#define DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE (CUDART_VERSION >= 11080) \ No newline at end of file diff --git a/deep-gemm/csrc/utils/exception.hpp b/deep-gemm/csrc/utils/exception.hpp new file mode 100644 index 00000000..2aa27066 --- /dev/null +++ b/deep-gemm/csrc/utils/exception.hpp @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include + +#include "compatibility.hpp" + +namespace deep_gemm { + +class DGException final : public std::exception { + std::string message = {}; + +public: + explicit DGException(const char *name, const char* file, const int line, const std::string& error) { + message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error; + } + + const char *what() const noexcept override { + return message.c_str(); + } +}; + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +#ifndef DG_HOST_ASSERT +#define DG_HOST_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + throw DGException("Assertion", __FILE__, __LINE__, #cond); \ + } \ +} while (0) +#endif + +#ifndef DG_HOST_UNREACHABLE +#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason)) +#endif + +#ifndef DG_NVRTC_CHECK +#define DG_NVRTC_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != NVRTC_SUCCESS) { \ + throw DGException("NVRTC", __FILE__, __LINE__, nvrtcGetErrorString(e)); \ + } \ +} while (0) +#endif + +#ifndef DG_CUDA_DRIVER_CHECK +#define DG_CUDA_DRIVER_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + std::stringstream ss; \ + const char *name, *info; \ + lazy_cuGetErrorName(e, &name), lazy_cuGetErrorString(e, &info); \ + ss << static_cast(e) << " (" << name << ", " << info << ")"; \ + throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + +#ifndef DG_CUDA_RUNTIME_CHECK +#define DG_CUDA_RUNTIME_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != cudaSuccess) { \ + std::stringstream ss; \ + ss << static_cast(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \ + throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + +#ifndef DG_CUBLASLT_CHECK + +#if !DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE +inline const char* cublasGetStatusString(cublasStatus_t status) { + switch(status) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + default: return "Unknown cuBLAS error"; + } +} +#endif + +#define DG_CUBLASLT_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != CUBLAS_STATUS_SUCCESS) { \ + std::ostringstream ss; \ + ss << static_cast(e) << " (" << cublasGetStatusString(e) << ")"; \ + throw DGException("cuBLASLt", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/utils/format.hpp b/deep-gemm/csrc/utils/format.hpp new file mode 100644 index 00000000..b89f9c83 --- /dev/null +++ b/deep-gemm/csrc/utils/format.hpp @@ -0,0 +1,103 @@ +#pragma once + +// Minimal fmt::format shim — supports only "{}" placeholders and "{{" / "}}" +// escapes. This covers all usage in DeepGEMM and avoids depending on libfmt. +// +// Uses std::string concatenation instead of std::ostringstream to avoid +// potential locale/ABI issues with ostringstream across different platforms. + +#include +#include +#include +#include +#include + +namespace fmt { + +namespace detail { + +// Convert value to string — specializations for common types +template +inline std::string to_str(const T& v) { + if constexpr (std::is_same_v) { + return v; + } else if constexpr (std::is_same_v || std::is_same_v) { + return std::string(v); + } else if constexpr (std::is_same_v) { + return std::string(v); + } else if constexpr (std::is_same_v) { + return std::string(1, v); + } else if constexpr (std::is_same_v, std::filesystem::path>) { + return v.string(); + } else if constexpr (std::is_integral_v) { + return std::to_string(v); + } else if constexpr (std::is_floating_point_v) { + return std::to_string(v); + } else { + // Fallback for other types with operator<< + std::ostringstream os; + os << v; + return os.str(); + } +} + +// Overload for C string literals (arrays) +template +inline std::string to_str(const char (&s)[N]) { + return std::string(s, N - 1); +} + +inline std::string format_impl(std::string_view fmt) { + std::string result; + result.reserve(fmt.size()); + size_t i = 0; + while (i < fmt.size()) { + if (fmt[i] == '{' && i + 1 < fmt.size() && fmt[i + 1] == '{') { + result += '{'; + i += 2; + } else if (fmt[i] == '}' && i + 1 < fmt.size() && fmt[i + 1] == '}') { + result += '}'; + i += 2; + } else { + result += fmt[i++]; + } + } + return result; +} + +template +std::string format_impl(std::string_view fmt, + const T& first, const Args&... rest) { + std::string result; + result.reserve(fmt.size()); + size_t i = 0; + while (i < fmt.size()) { + if (fmt[i] == '{') { + if (i + 1 < fmt.size() && fmt[i + 1] == '{') { + result += '{'; + i += 2; + } else if (i + 1 < fmt.size() && fmt[i + 1] == '}') { + result += to_str(first); + result += format_impl(fmt.substr(i + 2), rest...); + return result; + } else { + result += fmt[i++]; + } + } else if (fmt[i] == '}' && i + 1 < fmt.size() && fmt[i + 1] == '}') { + result += '}'; + i += 2; + } else { + result += fmt[i++]; + } + } + return result; +} + +} // namespace detail + +template +std::string format(std::string_view fmt, const Args&... args) { + return detail::format_impl(fmt, args...); +} + +} // namespace fmt diff --git a/deep-gemm/csrc/utils/hash.hpp b/deep-gemm/csrc/utils/hash.hpp new file mode 100644 index 00000000..ff36ef39 --- /dev/null +++ b/deep-gemm/csrc/utils/hash.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +static uint64_t fnv1a(const std::vector& data, const uint64_t& seed) { + uint64_t h = seed; + const uint64_t& prime = 0x100000001b3ull; + for (const char& c: data) { + h ^= static_cast(c); + h *= prime; + } + return h; +} + +static std::string get_hex_digest(const std::vector& data) { + const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull); + const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull); + + // Split-mix 64 + const auto& split_mix = [](uint64_t z) { + z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull; + z = (z ^ (z >> 27)) * 0x94d049bb133111ebull; + return z ^ (z >> 31); + }; + + // Use snprintf instead of ostringstream + char buf[64]; + snprintf(buf, sizeof(buf), "%016lx%016lx", + (unsigned long)split_mix(state_0), + (unsigned long)split_mix(state_1)); + return std::string(buf); +} + +static std::string get_hex_digest(const std::string& data) { + return get_hex_digest(std::vector{data.begin(), data.end()}); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/utils/layout.hpp b/deep-gemm/csrc/utils/layout.hpp new file mode 100644 index 00000000..c9ac9514 --- /dev/null +++ b/deep-gemm/csrc/utils/layout.hpp @@ -0,0 +1,124 @@ +#pragma once + +#include +#include + +#include "math.hpp" +#include "exception.hpp" +#include "../jit/device_runtime.hpp" + +namespace deep_gemm { + +// Major-ness stuffs +static void major_check(const torch::Tensor& t) { + const auto dim = t.dim(); + DG_HOST_ASSERT(dim == 2 or dim == 3); + if (dim == 3) + DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1)); + DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1); +} + +static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) { + major_check(t); + return t.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; +} + +static void check_major_type_cd(const torch::Tensor& t) { + // NOTES: the library only supports row-major output layouts + major_check(t); + DG_HOST_ASSERT(t.stride(-1) == 1); +} + +static bool fp8_requires_k_major() { + return device_runtime->get_arch_major() == 9; +} + +// Tensor utils +template +static auto get_shape(const torch::Tensor& t) { + DG_HOST_ASSERT(t.dim() == N); + return [&t] (std::index_sequence) { + return std::make_tuple(static_cast(t.sizes()[Is])...); + }(std::make_index_sequence()); +} + +static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { + auto [mn, k] = get_shape<2>(ab); + if (ab.scalar_type() != torch::kFloat8_e4m3fn) { + DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); + major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); + } + return std::make_tuple(mn, k); +} + +static std::tuple check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { + auto [num_groups, mn, k] = get_shape<3>(ab); + if (ab.scalar_type() != torch::kFloat8_e4m3fn) { + DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); + major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); + } + return std::make_tuple(num_groups, mn, k); +} + +// Recipe +static std::tuple +get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) { + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat); + return {1, 128, 128}; + } else if (arch_major == 10) { + DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt); + return sfb_dtype == torch::kFloat ? + std::make_tuple(1, 128, 128): // Legacy format + std::make_tuple(1, 1, 128); // 1D1D kernels + } + DG_HOST_UNREACHABLE("Unknown recipe"); +} + +// SF layouts +static torch::Tensor check_sf_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const int& gran_mn, const int& gran_k, + const std::optional& num_groups, + const bool& tma_stride_check = false, + const bool& sm90_sfb_check = false, + const std::optional& type_check = std::nullopt) { + // Type check + if (type_check.has_value()) + DG_HOST_ASSERT(sf.scalar_type() == type_check.value()); + + // Always do shape checks + const auto sf_dtype = sf.scalar_type(); + DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt); + DG_HOST_ASSERT(sf.dim() == static_cast(num_groups.has_value()) + 2); + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.size(-3) == num_groups.value()); + DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn)); + DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4))); + + // TMA stride checks: TMA aligned and MN-major + if (tma_stride_check) { + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1)); + // Check contiguity in the MN direction + DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1); + DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())); + } + + // SM90 SFB must be contiguous, or contiguous after transposing the last two dimensions + if (sm90_sfb_check) { + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1)); + DG_HOST_ASSERT((sf.stride(-1) == 1 and sf.stride(-2) == sf.size(-1)) or + (sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1)); + } + return sf; +} + +// Value matrix layout +static int get_mk_alignment_for_contiguous_layout() { + return 128; +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/utils/lazy_init.hpp b/deep-gemm/csrc/utils/lazy_init.hpp new file mode 100644 index 00000000..386b1b45 --- /dev/null +++ b/deep-gemm/csrc/utils/lazy_init.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +#define DG_DECLARE_STATIC_VAR_IN_CLASS(cls, name) decltype(cls::name) cls::name + +namespace deep_gemm { + +template +class LazyInit { +public: + explicit LazyInit(std::function()> factory) + : factory(std::move(factory)) {} + + T* operator -> () { + if (ptr == nullptr) + ptr = factory(); + return ptr.get(); + } + +private: + std::shared_ptr ptr; + std::function()> factory; +}; + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/utils/math.hpp b/deep-gemm/csrc/utils/math.hpp new file mode 100644 index 00000000..2af48e83 --- /dev/null +++ b/deep-gemm/csrc/utils/math.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include + +#include "exception.hpp" + +namespace deep_gemm { + +// TODO: Use `torch::kFloat4_e2m1fn_x2` +constexpr auto kPackedFP4 = torch::kUInt8; + +template +static T ceil_div(const T& a, const T& b) { + return (a + b - 1) / b; +} + +template +static constexpr T align(const T& a, const T& b) { + return ceil_div(a, b) * b; +} + +static int get_tma_aligned_size(const int& x, const int& element_size) { + constexpr int kNumTMAAlignmentBytes = 16; + DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0); + return align(x, kNumTMAAlignmentBytes / element_size); +} + +} // namespace deep_gemm diff --git a/deep-gemm/csrc/utils/system.hpp b/deep-gemm/csrc/utils/system.hpp new file mode 100644 index 00000000..2c97066f --- /dev/null +++ b/deep-gemm/csrc/utils/system.hpp @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "exception.hpp" +#include "format.hpp" + +namespace deep_gemm { + +// ReSharper disable once CppNotAllPathsReturnValue +template +static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) { + const auto& c_str = std::getenv(name.c_str()); + if (c_str == nullptr) + return default_value; + + // Read the env and convert to the desired type + if constexpr (std::is_same_v) { + return std::string(c_str); + } else if constexpr (std::is_same_v) { + int value; + std::sscanf(c_str, "%d", &value); + return value; + } else { + DG_HOST_ASSERT(false and "Unexpected type"); + } +} + +static std::tuple call_external_command(std::string command) { + command = command + " 2>&1"; + const auto& deleter = [](FILE* f) { if (f) pclose(f); }; + std::unique_ptr pipe(popen(command.c_str(), "r"), deleter); + DG_HOST_ASSERT(pipe != nullptr); + + std::array buffer; + std::string output; + while (fgets(buffer.data(), buffer.size(), pipe.get())) + output += buffer.data(); + const auto& exit_code = WEXITSTATUS(pclose(pipe.release())); + return {exit_code, output}; +} + +static std::vector collect_files(const std::filesystem::path& root) { + std::vector files; + std::function impl; + impl = [&](const std::filesystem::path& dir) { + for (const auto& entry: std::filesystem::directory_iterator(dir)) { + if (entry.is_directory()) { + impl(entry.path()); + } else if (entry.is_regular_file() and entry.path().extension() == ".cuh") { + files.emplace_back(entry.path()); + } + } + }; + impl(root); + + // Be consistent + std::sort(files.begin(), files.end()); + return files; +} + +static std::filesystem::path make_dirs(const std::filesystem::path& path) { + // OK if existed + std::error_code capture; + const bool& created = std::filesystem::create_directories(path, capture); + if (not (created or capture.value() == 0)) { + DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}", + path.c_str(), created, capture.value())); + } + if (created and get_env("DG_JIT_DEBUG")) + fprintf(stderr, "Create directory: %s\n", path.c_str()); + return path; +} + +static std::string get_uuid() { + static std::random_device rd; + static std::mt19937 gen([]() { + return rd() ^ std::chrono::steady_clock::now().time_since_epoch().count(); + }()); + static std::uniform_int_distribution dist; + + // Use snprintf instead of stringstream + char buf[64]; + std::snprintf(buf, sizeof(buf), "%d-%08x-%08x-%08x", + getpid(), dist(gen), dist(gen), dist(gen)); + return std::string(buf); +} + +} // deep_gemm diff --git a/deep-gemm/deep_gemm/__init__.py b/deep-gemm/deep_gemm/__init__.py new file mode 100644 index 00000000..1c07f5d9 --- /dev/null +++ b/deep-gemm/deep_gemm/__init__.py @@ -0,0 +1,112 @@ +import os +import subprocess +import torch + +# Set some default environment provided at setup +try: + # noinspection PyUnresolvedReferences + from .envs import persistent_envs + for key, value in persistent_envs.items(): + if key not in os.environ: + os.environ[key] = value +except ImportError: + pass + +# Configs +from . import _C +from ._C import ( + set_num_sms, + get_num_sms, + set_tc_util, + get_tc_util, +) + +# cuBLASLt Kernels +from ._C import ( + cublaslt_gemm_nt, cublaslt_gemm_nn, + cublaslt_gemm_tn, cublaslt_gemm_tt, +) + +try: + # DeepGEMM Kernels + from ._C import ( + # FP8 FP4 GEMMs + fp8_fp4_gemm_nt, fp8_fp4_gemm_nn, + fp8_fp4_gemm_tn, fp8_fp4_gemm_tt, + m_grouped_fp8_fp4_gemm_nt_contiguous, + m_grouped_fp8_fp4_gemm_nn_contiguous, + m_grouped_fp8_fp4_gemm_nt_masked, + # FP8 GEMMs + fp8_gemm_nt, fp8_gemm_nn, + fp8_gemm_tn, fp8_gemm_tt, + fp8_gemm_nt_skip_head_mid, + m_grouped_fp8_gemm_nt_contiguous, + m_grouped_fp8_gemm_nn_contiguous, + m_grouped_fp8_gemm_nt_masked, + k_grouped_fp8_gemm_nt_contiguous, + k_grouped_fp8_gemm_tn_contiguous, + # BF16 GEMMs + bf16_gemm_nt, bf16_gemm_nn, + bf16_gemm_tn, bf16_gemm_tt, + m_grouped_bf16_gemm_nt_contiguous, + m_grouped_bf16_gemm_nn_contiguous, + m_grouped_bf16_gemm_nt_masked, + k_grouped_bf16_gemm_tn_contiguous, + # Einsum kernels + einsum, + fp8_einsum, + # Attention kernels + fp8_mqa_logits, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, + # Hyperconnection kernels + tf32_hc_prenorm_gemm, + # Layout kernels + transform_sf_into_required_layout, + get_mk_alignment_for_contiguous_layout + ) + + # Some alias for legacy supports + # TODO: remove these later + fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked + bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked +except ImportError: + # Expected behavior for CUDA runtime version before 12.1 + pass + +# Some utils +from . import testing +from . import utils +from .utils import * + +# Legacy Triton kernels for A100 +try: + from . import legacy +except Exception as e: + print(f'Failed to load legacy DeepGEMM A100 Triton kernels: {e}') + +# Initialize CPP modules +def _find_cuda_home() -> str: + # TODO: reuse PyTorch API later + # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + if cuda_home is None: + # noinspection PyBroadException + try: + with open(os.devnull, 'w') as devnull: + nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n') + cuda_home = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + cuda_home = '/usr/local/cuda' + if not os.path.exists(cuda_home): + cuda_home = None + assert cuda_home is not None + return cuda_home + + +_C.init( + os.path.dirname(os.path.abspath(__file__)), # Library root directory path + _find_cuda_home() # CUDA home +) + +__version__ = '2.3.0' diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/cute_tie.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/cute_tie.cuh new file mode 100644 index 00000000..cd2aace7 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/cute_tie.cuh @@ -0,0 +1,48 @@ +#pragma once + +namespace cute { + +struct ignore_t { + template + constexpr const ignore_t& operator=(T&&) const noexcept { + return *this; + } +}; + +inline constexpr ignore_t ignore{}; + +} // namespace cute + +#define CUTE_TIE_CONCAT_IMPL(A, B) A##B +#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B) + +#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define CUTE_TIE_COUNT_ARGS(...) \ + CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get(TUPLE) +#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get(TUPLE) + +#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1); +#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2); +#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); +#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); +#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5); + +#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_DECL, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ) + +#define CUTE_TIE(TUPLE_EXPR, ...) \ + do { \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_ASSIGN, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ); \ + } while (0) diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh new file mode 100644 index 00000000..5f6a7a19 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +struct EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 + and kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/reduction.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/reduction.cuh new file mode 100644 index 00000000..d9e35f73 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/reduction.cuh @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include + +#include + +// Operation functors +template struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } }; +template struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +__forceinline__ __device__ T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +__forceinline__ __device__ T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/scheduler.cuh new file mode 100644 index 00000000..f93b96ee --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -0,0 +1,288 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for countiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + __device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = __ldg(grouped_layout + group_idx); + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k, + int* grouped_layout = nullptr) { + num_m_blocks = ceil_div(shape_m, BLOCK_M); + num_n_blocks = ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = __ldg(grouped_layout); + num_m_blocks = ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto& group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + current_group_idx)), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = align(current_psum_m, 128u); + current_psum_m = __ldg(grouped_layout + current_group_idx); + current_m_block_cumsum += num_m_blocks; + num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with 128 + m_block_idx += last_psum_m / BLOCK_M; + DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M"); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto& block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/sm100_utils.cuh new file mode 100644 index 00000000..537cbe08 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -0,0 +1,266 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace deep_gemm::sm100 { + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + uint32_t stride_byte_offset, uint32_t leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +__device__ __forceinline__ +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +__device__ __forceinline__ +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +__device__ __forceinline__ +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +template +__device__ __forceinline__ +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto& layout_type = to_umma_layout_type(); + const auto& num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +__device__ __forceinline__ +uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +template +__device__ constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if (kNumCols <= 32) return 32; + if (kNumCols <= 64) return 64; + if (kNumCols <= 128) return 128; + if (kNumCols <= 256) return 256; + return 512; +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +__device__ __forceinline__ +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +} // namespace `deep_gemm::sm100` diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/sm90_utils.cuh new file mode 100644 index 00000000..0874b675 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -0,0 +1,332 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm::sm90 { + +template +struct FP8MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct TF32MMARS { + + template + __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; + +template +struct SM90_U32x2_STSM_N { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +struct SM90_U32x2_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +__forceinline__ __device__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +__forceinline__ __device__ void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +template +__device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, + const int& leading_byte_offset = 0, + const int& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; + const auto& uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; +} + +template +__device__ __forceinline__ +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto& layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } +} + +} // namespace `deep_gemm::sm90` diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/tma_utils.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/tma_utils.cuh new file mode 100644 index 00000000..bd54adc2 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/tma_utils.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +// Tensormap related +__device__ __forceinline__ void tensor_map_release_cta() { + asm volatile ("fence.proxy.tensormap::generic.release.cta;"); +} + +__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +} // namespace `deep_gemm` diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/types.hpp b/deep-gemm/deep_gemm/include/deep_gemm/common/types.hpp new file mode 100644 index 00000000..410c5469 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/types.hpp @@ -0,0 +1,41 @@ +#pragma once + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/common/utils.cuh b/deep-gemm/deep_gemm/include/deep_gemm/common/utils.cuh new file mode 100644 index 00000000..8fb6c2fc --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/common/utils.cuh @@ -0,0 +1,183 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "cute_tie.cuh" + +#ifdef __CLION_IDE__ + +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +namespace deep_gemm { + +template +struct PatternVisitor { + FuncT func; + + __device__ __host__ + explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} + + __device__ __host__ + auto operator [](const uint32_t& i) { + return func(i); + } +}; + +template +__device__ __host__ T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ T align(T a, T b) { + return ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +__forceinline__ __device__ void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +__forceinline__ __device__ uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +__forceinline__ __device__ uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float2 ld_shared(const float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +template +__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +__device__ __forceinline__ void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +template +struct Vectorized { + static auto zeros() { + // TODO: add `ulonglong4` for SM100 once `__ldg` support this + if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) { + return make_uint4(0, 0, 0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) { + return make_uint2(0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) { + return 0; + } else { + DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization"); + } + } + + using vec_t = decltype(zeros()); +}; + +} // namespace `deep_gemm` diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh new file mode 100644 index 00000000..0227b3e8 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -0,0 +1,482 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `umma_arrive()` overhead + constexpr bool kDoMergeStages = + kNumStages_ >= 8 and kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 8; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive only at the leader CTA + full_barriers[i]->init(kNumMulticast); + // Arrive at all CTAs + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + if constexpr (kTensorCoreUtilControl < 100) + tensor_core_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = (stage_idx + 1) % kNumStages; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // UMMA and empty barrier arrival alias + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + + // Let tensor cores relax for lower possibility of frequency drop + DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); + if constexpr (kTensorCoreUtilControl < 100) { + // For utilization control + umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + + // Wait for last UMMA to be done + tensor_core_full_barrier->wait(tensor_core_phase); + tensor_core_phase ^= 1; + + // Sleep for certain cycles + constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; + const auto& start_clock = clock64(); + if (cute::elect_one_sync()) + while (clock64() - start_clock < kNumDummyCycles) {} + __syncwarp(); + } + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + + // Deallocate tensor memory by the last UMMA store warp + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) + Allocator().free(0, kNumTmemCols); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh new file mode 100644 index 00000000..86303347 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -0,0 +1,265 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumThreads, 1) +sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); + DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Shared memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Fill D/A/B + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx == 0) { + // TMA load warp + for (uint32_t s = 0; s < num_total_stages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + uint32_t m_idx = BLOCK_M * m_block_idx; + uint32_t n_idx = BLOCK_N * n_block_idx; + uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + uint32_t k_idx = sk_idx % SHAPE_K; + uint32_t s_idx = sk_idx / SHAPE_K; + + // Issue TMAs + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + } + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (cute::elect_one_sync()) + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } else if (warp_idx == 1) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + constexpr uint32_t UMMA_M = LAYOUT_AD_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Wait tensor memory empty barrier arrival + tcgen05_after_thread_sync(); + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + if (warp_idx == 2) + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + if (s >= kNumTMAStoreStages) { + if (warp_idx == 0 and cute::elect_one_sync()) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = s % kNumTMAStoreStages; + const auto m_idx = m_block_idx * BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is doing TMA stores + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh new file mode 100644 index 00000000..45a603ad --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -0,0 +1,563 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4); + const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad))); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx)); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32; + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Do SF copy at certain stages + // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + __syncwarp(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc, + kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), + kTmemStartColOfSFB); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + + // Deallocate tensor memory by the last UMMA store warp + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) + Allocator().free(0, kNumTmemCols); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh new file mode 100644 index 00000000..180a308b --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -0,0 +1,404 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warp_in_group_idx = warp_idx % 4; + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + + // Align to 512 bytes for swizzle-64B + extern __shared__ __align__(512) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + + // TMA configs + constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + + // Tensor memory allocation + auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); + + // Initialize barriers + DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); + const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); + const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (is_umma_warp) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 24; + constexpr uint32_t kNumMathRegisters = 240; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + + if (is_tma_load_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue UMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } else if (warp_idx >= kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + const auto& warp_offset = warp_idx * 32; + const auto& v_offset = lane_idx; + + // Preload weights + constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); + float weights[BLOCK_Q][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); + tcgen05_after_thread_sync(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + + constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); + uint32_t shifted_accum[kNumLDTMElems]; + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); + } + + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; + } else { + logits[q_idx * stride_logits + kv_offset + v_offset] = result; + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + + // Free tensor memory + __syncthreads(); + if (is_tma_load_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh new file mode 100644 index 00000000..7058c40f --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -0,0 +1,398 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); + }); + constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); + + constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); + const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); + const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); + + // Initialize barriers + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + cutlass::arch::fence_barrier_init(); + } + if (is_umma_warp) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = kNextN * kNumHeads; + DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + + if (is_tma_load_warp) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + } + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + if (cute::elect_one_sync()) { + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 1; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + if (q_idx != next_q_idx) { + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + } + + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(umma_phase); + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + umma_phase ^= 1; + } + } else if (is_math_warp) { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + const uint32_t thread_idx = threadIdx.x; + + // Weights + constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads)); + float weights[kNextN][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 0; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait(umma_phase); + tcgen05_after_thread_sync(); + umma_phase ^= 1; + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; + uint32_t shifted_accum[kNumLDTMElems]; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); + } + + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + thread_idx] = result; + } + } + } else { + cutlass::arch::warpgroup_reg_dealloc(); + } + + // Free tensor memory + __syncthreads(); + if (is_umma_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..4e4ff21d --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,345 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__device__ __forceinline__ +uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { + // Calculate the index of the bank group to be written in the atom + const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` + // - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)` + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups); + col ^= row % (kSwizzleMode / kSwizzleBase); + + return row * 128 + col * kSwizzleBase; +} + +template +__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t kNumCastStages = 2; + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + constexpr auto kMajorA = cute::UMMA::Major::K; + constexpr auto kMajorB = cute::UMMA::Major::K; + DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages"); + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 4 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + full_cast_barriers[i]->init(kNumCastAndReduceThreads); + empty_barriers[i]->init(1); + empty_cast_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + + // Dispatch warps into different roles + if (warp_idx < kNumMMAThreads / 32) { + // TMA load warp + if (warp_idx == 0 and cute::elect_one_sync()) { + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } + + // MMA issue warp + if (warp_idx == 1) { + // Make instruction descriptor + constexpr uint32_t UMMA_M = BLOCK_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(float); + constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float); + using umma_t = cute::SM100_MMA_TF32_TS; + auto instr_desc = cute::UMMA::make_instr_desc(); + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Launch MMAs + // We can not unroll this part + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + const auto& cast_stage_idx = s % kNumCastStages; + full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; + const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; + const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); + } + + // Commit + cutlass::arch::umma_arrive(reinterpret_cast(empty_cast_barriers[cast_stage_idx])); + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + + // Commit to epilogue threads + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Only support layout F (M = 64) and D (M = 128) + DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Source and destination memory address + uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd) + // Base pointer + warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset + get_swizzled_smem_offset(i, lane_idx); // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + if constexpr (BLOCK_M == 64) + __syncwarp(); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0); + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } else { + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32; + + // TODO: make even larger block K + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + + // Launch reductions + float2 sum[2] = {float2{0, 0}, float2{0, 0}}; + #pragma unroll kNumStages + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b) + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + const auto& smem_base_ptr = reinterpret_cast(smem_a[stage_idx]) + // Base pointer + sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset + + // 4 lanes shared a bank group + uint32_t uint32_values[2][kNumLoads]; + DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads"); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; i += 2) { + auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); + sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); + } + + // Wait tensor memory empty + const auto& cast_stage_idx = s % kNumCastStages; + empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1); + + // Cast, reduce and store into tensor memory + float2 fp32x2_values[2][kNumLoads]; + const auto& upper_view = reinterpret_cast(&fp32x2_values[0]); + const auto& lower_view = reinterpret_cast(&fp32x2_values[1]); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast(&uint32_values[u][i])); + sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]); + } + + // Store upper and lower part at the same time + const auto idx_0 = i * 2, idx_1 = i * 2 + 1; + cute::SM100_TMEM_STORE_16dp256b1x::copy( + upper_view[idx_0], upper_view[idx_1], + lower_view[idx_0], lower_view[idx_1], + cast_stage_idx * BLOCK_K + i * 8); + } + cutlass::arch::fence_view_async_tmem_store(); + + // Arrive for issuing MMAs + tcgen05_before_thread_sync(); + full_cast_barriers[cast_stage_idx]->arrive(); + } + + // Intra-warp reduction and write back + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + if (lane_idx % 4 == 0 and m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum; + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh new file mode 100644 index 00000000..7a77e4e8 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -0,0 +1,381 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `warpgroup_wait<0>()` overhead + constexpr uint32_t kDoMergeStages = + kNumStages_ >= 10 and + kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and + kNumMathThreads == 128; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 5; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B/D must be aligned to 1024 bytes"); + + // D/A/B shared memory + auto smem_d = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 48; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = make_gmma_desc(smem_b[0], 0, 0); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // TODO: remove some useless computation for unaligned Ms + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = advance_gmma_desc_lo( + a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); + b_desc.reg32_[0] = advance_gmma_desc_lo( + b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); + WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + if constexpr (cute::is_same_v) { + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + } else { + // Use `st.shared` if STSM is not available + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2); + auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); + } + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + // Use TMA store to write back to global memory + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh new file mode 100644 index 00000000..191a4fe2 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + float *d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Shared memory + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); + DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); + DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + // Fill shared memory pointers + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + #pragma unroll + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); + + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t& k_idx = sk_idx % SHAPE_K; + const uint32_t& s_idx = sk_idx / SHAPE_K; + + constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16); + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + float accum[WGMMA::kNumAccum] = {0}; + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrivals + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, 1); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + empty_barriers[stage_idx]->arrive(); + } + + const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + if (col + i * 8 >= SHAPE_N) + break; + if (row < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 0) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 0], accum[i * 4 + 1])); + } + if (row + 8 < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 8) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 2], accum[i * 4 + 3])); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh new file mode 100644 index 00000000..cdd28fcb --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -0,0 +1,349 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, + int* grouped_layout, + cute::TmaDescriptor* tensor_map_buffer, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_b_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = threadIdx.x % 32; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a_base); + cute::prefetch_tma_descriptor(&tensor_map_b_base); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Tensor maps on shared and global memory + auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i); + }); + auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i)); + }); + auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; }); + auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; }); + + // Data on shared memory + auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); + }); + auto smem_sfb = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); + }); + + // Barriers on shared memory + constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); + }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); + }); + + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // Load tensormap A/B to shared memory + if constexpr (kGemmType == GemmType::KGroupedContiguous) { + *smem_tensor_map_a[0] = tensor_map_a_base; + *smem_tensor_map_a[1] = tensor_map_a_base; + *smem_tensor_map_b[0] = tensor_map_b_base; + *smem_tensor_map_b[1] = tensor_map_b_base; + } + + // Initialize barriers + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Pipeline unroll control + constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages); + + // Register reconfigurations (more math registers are needed with unrolling) + constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); + constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // TMA and MMA pipeline + const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase + }; + uint32_t iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; + const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; + uint32_t last_group_idx = kNumGroups, sum_k = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t& m_idx = m_block_idx * BLOCK_M; + const uint32_t& n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) { + const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1; + const uint32_t& next_stage_idx = stage_idx ^ 1; + last_group_idx = scheduler.current_group_idx; + + // Prepare next tensor map + sum_k += scheduler.current_shape_k; + if (scheduler.next_group_idx < kNumGroups) { + tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); + tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); + *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); + tensor_map_release_cta(); + } + + // Get current tensor map + if (scheduler.current_num_valid_groups > 0) { + tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]); + tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]); + current_tensor_map_a = gmem_tensor_map_a[stage_idx]; + current_tensor_map_b = gmem_tensor_map_b[stage_idx]; + } + } + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& k_idx = k_block_idx * BLOCK_K; + const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; + const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Accumulation for WGMMA or CUDA promotion + DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); + const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); + const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + float2 scales_b[WGMMA::kNumAccum / 4]; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait TMA arrivals + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + full_barriers[stage_idx]->wait(phase); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1); + + // Read B scales + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + + // Promote with scales + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const float &scale_b_0 = scales_b[i].x; + const float &scale_b_1 = scales_b[i].y; + final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + + // Flush previous stores + if (warp_idx % 4 == 0 and cute::elect_one_sync()) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Store to D shared memory + const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Use TMA store to write back to global memory + if (warp_idx % 4 == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy( + &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N, + current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh new file mode 100644 index 00000000..9247304c --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -0,0 +1,440 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { + if (num_former_iters == kNumFormerIters) { + func(cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + dispatch_num_former_iters(num_former_iters, func); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA A + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t k_idx = k_block_idx * BLOCK_K; + tma_copy(&tensor_map_a, &full_barrier, + smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a, batch_idx); + tma_copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), + num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, &full_barrier, + smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b, batch_idx); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = make_smem_desc(smem_b[0], 1); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; + const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; + auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; + + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb)); + } + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + + // Accumulation for WGMMA or CUDA promotion + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&]() { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void(); + } + }; + + // Skip useless computations + if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { + // The compiler must know the dynamic variable `num_former_iters`'s real value + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // Dispatch `num_former_iters` and launch MMAs + dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { + #pragma unroll 8 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Read B scales + float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; + b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; + WGMMA::wgmma(a_desc, b_desc, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(); + + // Skip promotion for the unfilled parts + if (not do_wgmma_store) + continue; + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + } + }); + } else { + #pragma unroll + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + empty_barrier_arrive(); + } + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset); + auto m_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr, + n_idx, m_idx, scheduler.current_group_idx); + } else { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh new file mode 100644 index 00000000..d58c7162 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -0,0 +1,329 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +// ReSharper disable once CppNotAllPathsReturnValue +template +static constexpr int to_swizzle_cute_type() { + DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); + if constexpr (kHeadDim == 32) + return static_cast(cute::SM90::GMMA::LayoutType::B32); + if constexpr (kHeadDim == 64) + return static_cast(cute::SM90::GMMA::LayoutType::B64); + if constexpr (kHeadDim == 128) + return static_cast(cute::SM90::GMMA::LayoutType::B128); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + + // Initialize barriers + const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 32; + constexpr uint32_t kNumMathRegisters = 112; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // Only the first warp remains + if (not is_tma_load_warp) + return; + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& thread_idx = threadIdx.x % kNumMathThreads; + const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; + + const auto& warp_offset = warp_idx * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + } else { + logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; + logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh new file mode 100644 index 00000000..482a85a8 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -0,0 +1,413 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = get_lane_idx(); + + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t q_idx = k * 32 + lane_idx; + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0); + num_segs[k] = ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t q_idx = 0; + while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) + ++ q_idx; + const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } +} + +template +struct PagedMQALogitsScheduler { + uint32_t batch_size; + const uint32_t* context_lens; + + uint32_t current_q_idx, current_kv_idx; + uint32_t end_q_idx, end_kv_idx; + uint32_t current_num_kv; + + __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { + const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; + } + + __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, + const uint32_t* context_lens, const uint32_t* schedule_meta) { + this->batch_size = batch_size; + this->context_lens = context_lens; + + const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); + const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); + current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; + + current_num_kv = get_num_kv(current_q_idx); + } + + __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_idx = current_q_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (q_idx == end_q_idx and kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumBlocksPerSplit; + if (current_kv_idx >= current_num_kv) { + ++ current_q_idx; + current_kv_idx = 0; + current_num_kv = get_num_kv(current_q_idx); + } + + return true; + } + + __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { + return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; + } +}; + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + + constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + + constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q data and barriers on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + + // Separate math warpgroups and tma load warps into KV groups + // Each math warpgroup corresponds to a tma load warp + const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + + // Per group KV data and barriers on shared memory + const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + + // Initialize barriers + if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { + if (kv_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + } + if (kv_group_idx < kNumMathWarpGroups) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(128); + } + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 64; + constexpr uint32_t kNumMathRegisters = 104; + + // Scheduler + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + if (kv_group_idx >= kNumMathWarpGroups) + return; + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (kv_group_idx == 0 and cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? + __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + } + const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; + const auto& sub_warp_offset = (warp_idx % 4) * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + + // Wait WGMMA + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + v_0_offset] = v_0; + logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + } + } + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..e3bf9847 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,287 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ __forceinline__ +uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { + constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; + + const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups); + col ^= row % kGroupsInSwizzleRange; + + return (row * kNumBankGroups + col) % kGroupsInSwizzleRange; +} + +template +__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // kSwizzleAMode and kSwizzleBMode must be 128 for now + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode"); + + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 256; + + // TMA load warp + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cutlass::arch::warpgroup_reg_dealloc(); + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + + for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + } + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + constexpr uint32_t WGMMA_M = 64; + constexpr uint32_t WGMMA_N = BLOCK_N; + constexpr uint32_t WGMMA_K = 8; + + using WGMMA = typename TF32MMASelector::type; + float accum[WGMMA::kNumAccum] = {0}; + + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + float sqr_sum_acc_0 = 0; + float sqr_sum_acc_1 = 0; + + #pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2 + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128; + constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K; + + float a[kNumRegPerWgmma * kNumWgmmaPerBlockK]; + // Assume swizzle A mode is 128 + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + + // Load BF16 A fragment from shared memory into registers, and transpose to FP32 + uint32_t row = warp_idx * 16 + lane_idx / 4; + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a + uint32_t bank_group_idx = (row ^ i) % 8; + nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + + uint32_t elem_offset = lane_idx % 4; + nv_bfloat16 a_bf16[kNumRegPerWgmma]; + a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset]; + a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4]; + a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset]; + a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4]; + + auto a_bf16x2_ptr = reinterpret_cast(a_bf16); + auto a_float2_ptr = reinterpret_cast(a); + float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]); + float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]); + a_float2_ptr[i * 2 + 0] = a_float2_0; + a_float2_ptr[i * 2 + 1] = a_float2_1; + sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x; + sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; + } + + warpgroup_wait<0>(); + if (s > 0) + empty_barriers[(s - 1) % kNumStages]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + + constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); + constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; + DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K"); + + #pragma unroll + for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { + #pragma unroll + for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { + auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + } + + const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + + const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); + if (lane_idx % 4 == 0) { + if (m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum_0; + if (m_idx + 8 < shape_m) + sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; + } + warpgroup_wait<0>(); + empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); + + // Write accum to shared memory + // Every 2 threads (one pair) will write to the same bank group (16 bytes). + // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d + uint32_t is_odd_pair = lane_idx / 2 % 2; + + // Four threads per group; write the data to the same row. + uint32_t row_idx = lane_idx / 4; + + // Even/odd index pairs write to the same column, we need to reorder idx: + // group even pair indices consecutively, and likewise for odd ones. + uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx; + + auto shifted_smem_ptr = reinterpret_cast(smem_cd) + + (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows + lane_idx % 2 * 8; // One thread of a pair writes 8 bytes + + #pragma unroll + for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) { + // Get the swizzled bank group index (16 bytes per group) + uint32_t bank_group_idx = get_swizzled_bank_group_idx(i + is_odd_pair, reordered_pair_idx); + auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group + + // 0/1 write to the same row, 2/3 write to another row + auto values = reinterpret_cast(accum + i * 2); + st_shared(smem_ptr, values[0], values[1]); + st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, 1); + + // Issue TMA stores + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh new file mode 100644 index 00000000..cc9e5e6b --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(kNumWarps * 32, 1) +void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { + const uint32_t& num_sms = gridDim.x; + const uint32_t& sm_idx = blockIdx.x; + const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + constexpr float neg_inf = -cute::numeric_limits::infinity(); + + // Allocate filled `-inf` shared memory + extern __shared__ __align__(1024) float smem_buffer[]; + #pragma unroll + for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) + smem_buffer[i] = neg_inf; + cute::tma_store_fence(); + __syncthreads(); + + // Assign sequence to each warp + const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto& per = total / num, rem = total % num; + return {start + idx * per + min(idx, rem), per + (idx < rem)}; + }; + CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); + CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + + if (cute::elect_one_sync()) { + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + + for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { + const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); + if (right <= ks or ke <= left) { + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); + } else { + if (left < aligned_ks) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); + if (aligned_ke < right) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); + } + } + } + } + + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + for (uint32_t j = aligned_ks; j < ks; ++ j) + logits[i * stride_logits + j] = neg_inf; + for (uint32_t j = ke; j < aligned_ke; ++ j) + logits[i * stride_logits + j] = neg_inf; + } +} + +} diff --git a/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh new file mode 100644 index 00000000..bea70002 --- /dev/null +++ b/deep-gemm/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -0,0 +1,176 @@ +#pragma once + +#include + +namespace deep_gemm { + +template +__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename Vectorized::vec_t in_vec_t; + constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); + constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; + + // Shapes and strides + extern __shared__ float smem_buffer[]; + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the block + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; + const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + + // Load + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { + auto in_vec = __ldg(local_sf + i); + const auto& in_values = reinterpret_cast(&in_vec); + + const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerVec; ++ j) + smem_buffer[row * PADDED_SF_K + col + j] = in_values[j]; + } + __syncthreads(); + + // Store + #pragma unroll + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { + const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; + const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + } +} + +// NOTES: the two kernels below always pack the K dimension + +template +__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { + extern __shared__ uint32_t smem_buffer[]; + + // Shapes and strides + constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the group + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + + // Load FP32 SFs + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); + const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + const auto num_values = in_block_mn * SF_K; + const auto num_uint4 = num_values / 4; + #pragma unroll + for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { + const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); + st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + } + + // Fill unaligned values as well + if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) + st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + __syncthreads(); + + // Pack into UE8M0 and store + #pragma unroll + for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) { + const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN; + + // Load shared memory + uint32_t values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + const auto sf_k_idx = sf_k_pack_idx * 4 + j; + values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + } + + // Pack and store + uint32_t packed = 0; + packed |= (values[0] >> 23u); + packed |= (values[1] >> 15u); + packed |= (values[2] >> 7u); + packed |= (values[3] << 1u); + if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn) + out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed; + } +} + +template +__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { + // Always packing the K dimension + // NOTES: should also assert `mn % 4 == 0` at launch + DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes"); + + // Shapes and strides + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto in_block_mn_uint4 = in_block_mn / 4; + const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K); + + // Shift into the right block along MN + sf += blockIdx.x * BLOCK_MN; + out += blockIdx.x * BLOCK_MN; + + // Each warp is responsible for a packed row + const auto warp_idx = threadIdx.x / 32; + const auto lane_idx = get_lane_idx(); + const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; + if (warp_idx >= in_block_packed_sf_k) + return; + + // Make an offset on the input + uint32_t input_offset = 0; + if constexpr (kNumGroups > 1) { + // Load each group's size + DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups"); + uint32_t group_ks[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) { + const auto group_idx = lane_idx * 4 + i; + group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0; + } + __syncwarp(); + + // Make the offset + sf_k = 0; + auto sum_packed_sf_k = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumGroups; ++ i) { + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + sf_k += sf_k_in_group; + sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + if (packed_sf_k_idx < sum_packed_sf_k) + break; + if (const auto remainder = sf_k_in_group % 4; remainder > 0) + input_offset += 4 - remainder; + } + } + + for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + // Load + uint4 values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + values[j] = make_uint4(0, 0, 0, 0); + if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) + values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + } + + // Pack and store + uint4 packed; + packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u); + packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u); + packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u); + packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u); + reinterpret_cast(out + packed_sf_k_idx * mn)[mn_idx] = packed; + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/deep_gemm/legacy/__init__.py b/deep-gemm/deep_gemm/legacy/__init__.py new file mode 100644 index 00000000..cce39ec7 --- /dev/null +++ b/deep-gemm/deep_gemm/legacy/__init__.py @@ -0,0 +1,5 @@ +# All kernels may be deprecated in the future (or rewrite in TileLang) +from .m_grouped_gemm import * +from .a_fused_m_grouped_gemm import * +from .a_fused_k_grouped_gemm import * +from .b_fused_k_grouped_gemm import * diff --git a/deep-gemm/deep_gemm/legacy/a_fused_k_grouped_gemm.py b/deep-gemm/deep_gemm/legacy/a_fused_k_grouped_gemm.py new file mode 100644 index 00000000..7b42f152 --- /dev/null +++ b/deep-gemm/deep_gemm/legacy/a_fused_k_grouped_gemm.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def a_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k + tl.arange(0, BLOCK_SIZE_K) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + rows[None, :] * M + + b_ptrs = b_ptr + k_range[:, None].to(tl.int64) * N + n_range[None, :] + a = tl.load(a_ptrs, mask=(rows >= 0)[None, :] & m_mask, other=0) + b = tl.load(b_ptrs, mask=n_mask, other=0) + acc = tl.dot(a, b, acc) + + # Write back + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def a_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == b.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert b.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K_, M = a.shape + K, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + a_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/deep-gemm/deep_gemm/legacy/a_fused_m_grouped_gemm.py b/deep-gemm/deep_gemm/legacy/a_fused_m_grouped_gemm.py new file mode 100644 index 00000000..3f1f5294 --- /dev/null +++ b/deep-gemm/deep_gemm/legacy/a_fused_m_grouped_gemm.py @@ -0,0 +1,92 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, m_row_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # b block + rows = tl.load(m_row_indices_ptr + m_range).to(tl.int64) + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + k_range = (k + tl.arange(0, BLOCK_SIZE_K)).to(tl.int64) + k_mask = k_range < K + a_ptrs = a_ptr + rows[:, None] * K + k_range[None, :] + b_ptrs = b_ptr + batch_id * K * N + k_range[:, None] * (1 if IS_B_K_MAJOR else N) + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + a = tl.load(a_ptrs, mask=(rows >= 0)[:, None] & k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + d = acc.to(d_ptr.dtype.element_ty) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, d, mask=n_mask) + + +def a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + m_indices, m_row_indices = mappings + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous()) and d.is_contiguous() + assert m_indices.is_contiguous() and m_row_indices.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 and d.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and m_row_indices.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and d.size(0) == m_indices.numel() and d.size(1) == r1 + assert m_indices.numel() == m_row_indices.numel() + assert m_indices.numel() % get_mk_alignment_for_contiguous_layout() == 0 + + if d.size(0) == 0: + return d + + M_, K = a.shape + B, K, N = r0, r2, r1 + M = m_indices.numel() + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), ) + a_fused_m_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, m_indices, m_row_indices, + M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def a_fused_m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, mappings) diff --git a/deep-gemm/deep_gemm/legacy/b_fused_k_grouped_gemm.py b/deep-gemm/deep_gemm/legacy/b_fused_k_grouped_gemm.py new file mode 100644 index 00000000..a642204b --- /dev/null +++ b/deep-gemm/deep_gemm/legacy/b_fused_k_grouped_gemm.py @@ -0,0 +1,86 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = (k + tl.arange(0, BLOCK_SIZE_K)).to(tl.int64) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + k_range[None, :] * M + b_ptrs = b_ptr + rows[:, None] * N + n_range[None, :] + a = tl.load(a_ptrs, mask=m_mask, other=0.0) + b = tl.load(b_ptrs, mask=(rows >= 0)[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == a.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K, M = a.shape + K_, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + b_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/deep-gemm/deep_gemm/legacy/m_grouped_gemm.py b/deep-gemm/deep_gemm/legacy/m_grouped_gemm.py new file mode 100644 index 00000000..e685a9ab --- /dev/null +++ b/deep-gemm/deep_gemm/legacy/m_grouped_gemm.py @@ -0,0 +1,84 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + # Empty tokens + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # Compute + a_ptrs = a_ptr + m_range[:, None].to(tl.int64) * K + tl.arange(0, BLOCK_SIZE_K)[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + b_ptrs = b_ptr + batch_id * K * N + \ + tl.arange(0, BLOCK_SIZE_K)[:, None].to(tl.int64) * (1 if IS_B_K_MAJOR else N) + \ + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + for k in range(0, K, BLOCK_SIZE_K): + k_mask = (k + tl.arange(0, BLOCK_SIZE_K)) < K + a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * (1 if IS_B_K_MAJOR else N) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, accumulator.to(d_ptr.dtype.element_ty), mask=n_mask) + + +def m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous or b.mT.is_contiguous()) + assert m_indices.is_contiguous() and d.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and d.dtype == torch.bfloat16 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and a.size(0) == d.size(0) and r1 == d.size(1) + assert m_indices.numel() == a.size(0) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + M, K = a.shape + B, N, K_ = r0, r1, r2 + + # For Triton 2.0, persistent kernel will lead to errors + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + m_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, m_indices, M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, m_indices) diff --git a/deep-gemm/deep_gemm/legacy/tune_options.py b/deep-gemm/deep_gemm/legacy/tune_options.py new file mode 100644 index 00000000..ed6a7f77 --- /dev/null +++ b/deep-gemm/deep_gemm/legacy/tune_options.py @@ -0,0 +1,28 @@ +from triton import Config +from .._C import get_mk_alignment_for_contiguous_layout + + +def get_config_smem_size(config: Config, elem_bytes: int = 2): + # NOTES: FP8 kernels will not use Triton, so by default we assume BF16 kernels + return (config.kwargs['BLOCK_SIZE_M'] + config.kwargs['BLOCK_SIZE_N']) * config.kwargs['BLOCK_SIZE_K'] * elem_bytes * config.num_stages + + +_gemm_configs = [ + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), +] + +# NOTES: we only consider A100 shared memory sizes here, as legacy kernels are only used for Ampere +_gemm_configs = list(filter(lambda x: get_config_smem_size(x) <= 166912, _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) + +get_m_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +get_k_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) diff --git a/deep-gemm/deep_gemm/testing/__init__.py b/deep-gemm/deep_gemm/testing/__init__.py new file mode 100644 index 00000000..13a9d78d --- /dev/null +++ b/deep-gemm/deep_gemm/testing/__init__.py @@ -0,0 +1,4 @@ +from . import bench, numeric, utils +from .bench import * +from .numeric import * +from .utils import * diff --git a/deep-gemm/deep_gemm/testing/bench.py b/deep-gemm/deep_gemm/testing/bench.py new file mode 100644 index 00000000..2c752da2 --- /dev/null +++ b/deep-gemm/deep_gemm/testing/bench.py @@ -0,0 +1,137 @@ +import os +import sys +import torch + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests / 1e3 + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False): + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + + # Skip profiling + # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer + if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): + return (1, ) * len(kernel_names) if is_tuple else 1 + + # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + fn() + profiler.step() + + # Parse the profiling table + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num if total_num > 0 else 0) + + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/deep-gemm/deep_gemm/testing/numeric.py b/deep-gemm/deep_gemm/testing/numeric.py new file mode 100644 index 00000000..a42c4318 --- /dev/null +++ b/deep-gemm/deep_gemm/testing/numeric.py @@ -0,0 +1,21 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/deep-gemm/deep_gemm/testing/utils.py b/deep-gemm/deep_gemm/testing/utils.py new file mode 100644 index 00000000..2d202d41 --- /dev/null +++ b/deep-gemm/deep_gemm/testing/utils.py @@ -0,0 +1,38 @@ +import functools +import os +import torch +from typing import Callable + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def test_filter(condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + func(*args, **kwargs) + else: + print(f'{func.__name__}:') + print(f' > Filtered by {condition}') + print() + return wrapper + return decorator + + +def ignore_env(name: str, condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + saved = os.environ.pop(name, None) + func(*args, **kwargs) + if saved is not None: + os.environ[name] = saved + else: + func(*args, **kwargs) + + return wrapper + return decorator diff --git a/deep-gemm/deep_gemm/utils/__init__.py b/deep-gemm/deep_gemm/utils/__init__.py new file mode 100644 index 00000000..e8f859a2 --- /dev/null +++ b/deep-gemm/deep_gemm/utils/__init__.py @@ -0,0 +1,3 @@ +from . import math, layout +from .layout import * +from .math import * diff --git a/deep-gemm/deep_gemm/utils/layout.py b/deep-gemm/deep_gemm/utils/layout.py new file mode 100644 index 00000000..790e0d66 --- /dev/null +++ b/deep-gemm/deep_gemm/utils/layout.py @@ -0,0 +1,17 @@ +try: + from .._C import ( + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor + ) +except ImportError: + # Expected behavior for CUDA runtime version before 12.1 + pass + +# Valid for all CUDA versions +from .._C import get_mk_alignment_for_contiguous_layout + +# Some alias +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/deep-gemm/deep_gemm/utils/math.py b/deep-gemm/deep_gemm/utils/math.py new file mode 100644 index 00000000..c65026e5 --- /dev/null +++ b/deep-gemm/deep_gemm/utils/math.py @@ -0,0 +1,107 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + padded_n = align(n, gran_k) + x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 + m, n = x.shape + x_view = x.view(-1, gran_k, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code # uint8, 0..15 + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + assert n % 2 == 0 + padded_n = align(n, gran_k) + x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes2 = codes.view(m, padded_n // 2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 + return packed[:, :n // 2].contiguous(), sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.uint8 + assert a.dim() == 2 + m, n2 = a.shape + n = n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() \ No newline at end of file diff --git a/deep-gemm/develop.sh b/deep-gemm/develop.sh new file mode 100755 index 00000000..e784347a --- /dev/null +++ b/deep-gemm/develop.sh @@ -0,0 +1,25 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Link CUTLASS includes +ln -sf $script_dir/third-party/cutlass/include/cutlass deep_gemm/include +ln -sf $script_dir/third-party/cutlass/include/cute deep_gemm/include + +# Remove old dist file, build files, and build +rm -rf build dist +rm -rf *.egg-info +python setup.py build + +# Find the .so file in build directory and create symlink in current directory +so_file=$(find build -name "*.so" -type f | head -n 1) +if [ -n "$so_file" ]; then + ln -sf "../$so_file" deep_gemm/ +else + echo "Error: No SO file found in build directory" >&2 + exit 1 +fi + +# Open users' original directory +cd "$original_dir" diff --git a/deep-gemm/flake.lock b/deep-gemm/flake.lock new file mode 100644 index 00000000..07107fc1 --- /dev/null +++ b/deep-gemm/flake.lock @@ -0,0 +1,117 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1765121682, + "narHash": "sha256-4VBOP18BFeiPkyhy9o4ssBNQEvfvv1kXkasAYd0+rrA=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "65f23138d8d09a92e30f1e5c87611b23ef451bf3", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" + }, + "locked": { + "lastModified": 1771541347, + "narHash": "sha256-HnOoyFI66n+1TwxDjpPBS0hYS6JOtLqFd/8rU7B8ZaQ=", + "owner": "huggingface", + "repo": "kernels", + "rev": "93e3c8ed8ffe774bc626fd096e636a1140dfa53f", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernels", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1766341660, + "narHash": "sha256-4yG6vx7Dddk9/zh45Y2KM82OaRD4jO3HA9r98ORzysA=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "26861f5606e3e4d1400771b513cc63e5f70151a6", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "rust-overlay": { + "inputs": { + "nixpkgs": [ + "kernel-builder", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1769050281, + "narHash": "sha256-1H8DN4UZgEUqPUA5ecHOufLZMscJ4IlcGaEftaPtpBY=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "6deef0585c52d9e70f96b6121207e1496d4b0c49", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/deep-gemm/flake.nix b/deep-gemm/flake.nix new file mode 100644 index 00000000..036512b7 --- /dev/null +++ b/deep-gemm/flake.nix @@ -0,0 +1,18 @@ +{ + description = "Flake for DeepGEMM kernel"; + + inputs = { + self.submodules = true; + kernel-builder.url = "github:huggingface/kernels"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genKernelFlakeOutputs { + inherit self; + path = ./.; + }; +} diff --git a/deep-gemm/install.sh b/deep-gemm/install.sh new file mode 100755 index 00000000..5c7021c6 --- /dev/null +++ b/deep-gemm/install.sh @@ -0,0 +1,13 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Remove old dist file, build files, and install +rm -rf build dist +rm -rf *.egg-info +python setup.py bdist_wheel +pip install dist/*.whl --force-reinstall + +# Open users' original directory +cd "$original_dir" diff --git a/deep-gemm/scripts/generate_pyi.py b/deep-gemm/scripts/generate_pyi.py new file mode 100644 index 00000000..df7490d4 --- /dev/null +++ b/deep-gemm/scripts/generate_pyi.py @@ -0,0 +1,890 @@ +import re +from pathlib import Path + + +def build_cpp_function_index(root_path): + func_index = {} + extensions = {'.cpp', '.cc', '.cxx', '.c', '.hpp', '.h'} + + pattern = re.compile( + r'([\w:\s*<&>,\[\]\(\)]+?)' + r'\s+' + r'([a-zA-Z_][a-zA-Z0-9_:]*)' + r'\s*\(', + ) + + for file_path in Path(root_path).rglob('*'): + if file_path.suffix.lower() not in extensions: + continue + if not file_path.is_file(): + continue + + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + except Exception as e: + print(f'Failed to read file {file_path}: {e}') + continue + + # Remove the compile directives and comments + lines = content.split('\n') + clean_lines = [line for line in lines if not line.strip().startswith(('#', '//'))] + content = '\n'.join(clean_lines) + + for match in pattern.finditer(content): + return_type_part = match.group(1).strip() + full_func_name = match.group(2).strip() + + if not return_type_part or not re.match(r'^[a-zA-Z_]', return_type_part): + continue + + first_token = return_type_part.split()[0] + if first_token in {'return', 'if', 'else', 'for', 'while', 'switch', 'case', 'throw', 'catch', 'auto'}: + continue + + # Extract base name + if '::' in full_func_name: + base_name = full_func_name.split('::')[-1] + else: + base_name = full_func_name + + if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', base_name): + continue + + # Find matching ')' + paren_start = match.end() - 1 + paren_count = 0 + pos = paren_start + while pos < len(content): + ch = content[pos] + if ch == '(': + paren_count += 1 + elif ch == ')': + paren_count -= 1 + if paren_count == 0: + break + elif paren_count < 0: + pos = -1 + break + pos += 1 + else: + continue + + if pos == -1: + continue + + # Check context before match: should be at statement boundary + match_start = match.start() + context_before = content[max(0, match_start - 50):match_start] + if context_before and re.search(r'[a-zA-Z0-9_]$', context_before.rstrip()): + continue + + # Check for definition or header declaration + is_header = file_path.suffix.lower() in {'.h', '.hpp', '.cuh'} + after_paren = content[pos+1:pos+500] + has_brace = '{' in after_paren + has_semicolon = ';' in after_paren.split('{')[0] + + if has_brace or (is_header and has_semicolon): + sig_start = match.start(1) + full_signature = content[sig_start:pos+1].strip() + if base_name not in func_index: + func_index[base_name] = full_signature + + return func_index + + +class BracketTracker: + """ + Tracks nesting levels of various brackets in C++ code: + - () → paren + - [] → bracket + - {} → brace + - <> → angle (treated as template brackets only at top level) + Provides is_top_level() to check if currently outside all brackets. + """ + def __init__(self): + self.paren = 0 # () + self.bracket = 0 # [] + self.brace = 0 # {} + self.angle = 0 # <> + + def update(self, char: str): + """ + Update internal counters based on the given character. + """ + if char == '(': + self.paren += 1 + elif char == ')': + self.paren -= 1 + elif char == '[': + self.bracket += 1 + elif char == ']': + self.bracket -= 1 + elif char == '{': + self.brace += 1 + elif char == '}': + self.brace -= 1 + # Angle brackets < > are only treated as template delimiters + # when not inside (), [], or {} + elif char == '<' and self._in_top_level_of_other_brackets(): + self.angle += 1 + elif char == '>' and self.angle > 0 and self._in_top_level_of_other_brackets(): + self.angle -= 1 + + def _in_top_level_of_other_brackets(self): + """ + Check if not inside parentheses, square brackets, or braces (for correct template bracket recognition). + """ + return self.paren == 0 and self.bracket == 0 and self.brace == 0 + + def is_top_level(self): + """ + Check if completely at top level (all bracket counters are zero). + """ + return (self.paren == 0 and + self.bracket == 0 and + self.brace == 0 and + self.angle == 0) + + +def extract_m_def_statements(root_path): + """ + Scan all c files under root_path and extract all m.def(...) statements. + """ + results = [] + extensions = {'.hpp', '.cpp', '.h', '.cc'} + + # Regex: match m.def( ... ), supports multi-line + pattern = re.compile(r'm\.def\s*\(') + + for file_path in Path(root_path).rglob('*'): + if file_path.suffix.lower() not in extensions: + continue + if not file_path.is_file(): + continue + + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + except Exception as e: + print(f'Failed to read file {file_path}: {e}') + continue + + m_def_list = [] + lines = content.splitlines(keepends=True) + i = 0 + while i < len(lines): + line = lines[i] + if 'm.def(' in line: + # Found a potential starting line + start_i = i + # Check if it's a comment + stripped = line.lstrip() + if stripped.startswith('//') or stripped.startswith('/*'): + i += 1 + continue + + # Try to match the complete m.def(...) call + paren_count = 0 + j = i + found_start = False + while j < len(lines): + current_line = lines[j] + for k, char in enumerate(current_line): + if char == '(': + if not found_start and re.search(r'm\.def\s*\(', current_line[:k+1]): + found_start = True + if found_start: + paren_count += 1 + elif char == ')': + if found_start: + paren_count -= 1 + if paren_count == 0: + # Found complete statement + full_stmt = ''.join(lines[i:j+1]).rstrip() + m_def_list.append(full_stmt) + i = j + break + if paren_count <= 0 and found_start: + break + j += 1 + else: + pass + i += 1 + + if m_def_list: + results.append({ + 'file': str(file_path), + 'm_def_statements': m_def_list + }) + + return results + + +def parse_m_def_statement(m_def_str): + result = { + 'python_function_name': None, + 'num_args': 0, + 'default_args': {}, + 'is_lambda': False, + } + + # Extract top-level arguments + start = m_def_str.find('m.def(') + if start == -1: + raise ValueError(f'[{m_def_str}] Could not find m.def start position') + + paren_count = 0 + content_start = start + len('m.def(') + content_end = -1 + for i in range(content_start, len(m_def_str)): + ch = m_def_str[i] + if ch == '(': + paren_count += 1 + elif ch == ')': + if paren_count == 0: + content_end = i + break + else: + paren_count -= 1 + if content_end == -1: + raise ValueError(f'[{m_def_str}] m.def parentheses not closed') + + args_content = m_def_str[content_start:content_end] + + # Split arguments using BracketTracker + args_list = [] + current = [] + tracker = BracketTracker() + + for ch in args_content: + if ch in '()[]{}<>': + tracker.update(ch) + if ch == ',' and tracker.is_top_level(): + args_list.append(''.join(current).strip()) + current = [] + else: + current.append(ch) + + if current: + args_list.append(''.join(current).strip()) + + if len(args_list) < 2: + raise ValueError(f'[{m_def_str}] m.def has insufficient arguments') + + # Extract Python function name + first = args_list[0].strip() + str_match = re.match(r'^"([^"\\]*(?:\\.[^"\\]*)*)"', first) + if str_match: + result['python_function_name'] = str_match.group(1) + else: + raise ValueError(f'[{m_def_str}] m.def first argument should be a string literal') + + cpp_func_part = args_list[1].strip() + if cpp_func_part.startswith('&'): + cpp_func_part = cpp_func_part[1:].strip() + + if cpp_func_part.startswith('['): + result['is_lambda'] = True + result['cpp_function_name'] = None + else: + if '::' in cpp_func_part: + cpp_func_name = cpp_func_part.split('::')[-1] + else: + cpp_func_name = cpp_func_part + + match = re.match(r'^([a-zA-Z_][a-zA-Z0-9_]*)', cpp_func_name) + if match: + result['cpp_function_name'] = match.group(1) + else: + result['cpp_function_name'] = cpp_func_name + + # Parse py::arg arguments + py_args = args_list[2:] + result['num_args'] = len(py_args) + + for idx, arg_expr in enumerate(py_args): + expr = arg_expr.strip() + # Find top-level '=' + eq_pos = -1 + p_depth = b_depth = br_depth = angle_depth = 0 + i = 0 + while i < len(expr): + ch = expr[i] + if ch == '(': + p_depth += 1 + elif ch == ')': + p_depth -= 1 + elif ch == '[': + b_depth += 1 + elif ch == ']': + b_depth -= 1 + elif ch == '{': + br_depth += 1 + elif ch == '}': + br_depth -= 1 + elif ch == '<' and p_depth == 0 and b_depth == 0 and br_depth == 0: + angle_depth += 1 + elif ch == '>' and angle_depth > 0 and p_depth == 0 and b_depth == 0 and br_depth == 0: + angle_depth -= 1 + elif ch == '=' and all(d == 0 for d in [p_depth, b_depth, br_depth, angle_depth]): + eq_pos = i + break + i += 1 + + if eq_pos != -1: + default_val = expr[eq_pos + 1:].strip() + if not default_val: + raise ValueError(f'[{expr}] Default value is empty (arg {idx})') + result['default_args'][idx] = default_val + + return result + + +def extract_cpp_signature_from_content(cpp_func_name, content): + """ + Search for the C++ function signature of cpp_func_name in the given file content. + """ + if not cpp_func_name: + return None + + # Build regex: match function starting with cpp_func_name (after word boundary) + # Note: function name may be preceded by return type (with templates, namespaces, etc.), followed by '(' + pattern = re.compile( + r'^\s*' # leading whitespace + r'([\w:\s*<&>,\[\]\(\)]+?)' # return type (non-greedy, allows templates, pointers, etc.) + r'\s+' # at least one space + r'\b' + re.escape(cpp_func_name) + r'\b' # function name (word boundary) + r'\s*\(', # optional whitespace + start of param list + re.MULTILINE + ) + + for match in pattern.finditer(content): + # Find '(' position after function name + paren_start = match.end() - 1 + if content[paren_start] != '(': + paren_start = content.find('(', match.end(0) - 1) + if paren_start == -1: + continue + + # From '(', match to corresponding ')' + paren_count = 0 + pos = paren_start + while pos < len(content): + ch = content[pos] + if ch == '(': + paren_count += 1 + elif ch == ')': + paren_count -= 1 + if paren_count == 0: + start_sig = match.start(1) + full_signature = content[start_sig:pos+1].strip() + return full_signature + pos += 1 + + return None + + +def parse_mdef_and_attach_cpp_signatures(item, func_index): + """ + Enhance item by parsing m.def and extracting C++ function signature from global index + """ + statements_with_parsed_signatures = [] + + for stmt in item['m_def_statements']: + parsed = parse_m_def_statement(stmt,) + cpp_func_name = parsed.get('cpp_function_name') + + cpp_sig = None + if cpp_func_name and cpp_func_name in func_index: + cpp_sig = func_index[cpp_func_name] + else: + if not parsed['is_lambda']: + print(f'Warning: C++ function "{cpp_func_name}" not found in any .cpp file') + + parsed['cpp_signature'] = cpp_sig + statements_with_parsed_signatures.append({ + 'raw': stmt, + 'parsed': parsed + }) + + return { + 'm_def_statements': statements_with_parsed_signatures + } + + +def parse_cpp_signature(cpp_sig): + """ + Parse a C++ function signature and extract return type, parameter types, and names. + """ + if not cpp_sig or not cpp_sig.strip(): + return None + + # Find function name: last identifier before '(' + paren_pos = cpp_sig.find('(') + if paren_pos == -1: + return None + + before_paren = cpp_sig[:paren_pos].strip() + if not before_paren: + return None + + # Function name is the last word in before_paren (may include templates like func) + tokens = before_paren.split() + if len(tokens) < 2: + return None + + # Heuristic: function name is usually the last token (may include <>) + func_name_part = tokens[-1] + return_type = ' '.join(tokens[:-1]).strip() + + # Now extract parameter list content + param_list_str = cpp_sig[paren_pos+1:cpp_sig.rfind(')')].strip() + parameters = [] + + if param_list_str and param_list_str != 'void': # 'void' means no parameters + # Split parameters (handle commas not inside templates/brackets) + param_decls = split_cpp_parameters(param_list_str) + for decl in param_decls: + decl = decl.strip() + if not decl: + continue + # Try to split type and name from right to left + param_info = parse_parameter_declaration(decl) + if param_info: + parameters.append(param_info) + + return { + 'return_type': return_type, + 'parameters': parameters, + 'num_parameters': len(parameters) + } + + +def split_cpp_parameters(param_str: str): + """ + Split a C++ parameter list string by top-level commas, + e.g., 'int a, std::vector b' → ['int a', 'std::vector b'] + """ + if not param_str.strip() or param_str == 'void': + return [] + params = [] + current = [] + tracker = BracketTracker() + + for ch in param_str: + if ch in '()[]{}<>': + tracker.update(ch) + if ch == ',' and tracker.is_top_level(): + param = ''.join(current).strip() + if param: # Only add non-empty parameters + params.append(param) + current = [] + else: + current.append(ch) + + if current: + final_param = ''.join(current).strip() + if final_param: # Only add non-empty parameters + params.append(final_param) + return params + + +def parse_parameter_declaration(decl: str): + """ + Parse a single parameter declaration, e.g., 'const std::string& name' → {'type': 'const std::string&', 'name': 'name'} + Improved version that better handles template types. + """ + decl = decl.strip() + if not decl: + return None + + # Remove possible default value (starting from top-level '=') + tracker = BracketTracker() + eq_pos = -1 + for i, ch in enumerate(decl): + if ch in '()[]{}<>': + tracker.update(ch) + elif ch == '=' and tracker.is_top_level(): + eq_pos = i + break + + if eq_pos != -1: + decl = decl[:eq_pos].strip() + + # Now decl is 'type name' or just 'type' + # Instead of simple splitting, we'll use a more robust approach + # to find the parameter name + + # First, let's handle the case where there's no explicit parameter name + # (this sometimes happens in function declarations) + if not re.search(r'[a-zA-Z_][a-zA-Z0-9_]*$', decl): + # No parameter name found, just return the type + return { + 'type': decl, + 'name': None + } + + # Use bracket tracking to find where the type ends and name begins + tracker = BracketTracker() + name_start = -1 + + # Scan from the end to find the start of the parameter name + # We look for the first identifier that's outside all brackets + i = len(decl) - 1 + while i >= 0: + ch = decl[i] + + if ch in '()[]{}<>': + tracker.update(ch) + + # If we're at top level and find an identifier character + if tracker.is_top_level() and re.match(r'[a-zA-Z0-9_]', ch): + # Track back to find the start of this identifier + name_start = i + while name_start > 0 and re.match(r'[a-zA-Z0-9_]', decl[name_start - 1]): + name_start -= 1 + + # Check if this might be part of a type keyword (like 'int', 'bool', etc.) + potential_name = decl[name_start:i+1] + type_keywords = {'int', 'long', 'short', 'char', 'bool', 'float', 'double', + 'void', 'auto', 'const', 'static', 'volatile', 'mutable', + 'unsigned', 'signed'} + + # If it's not a type keyword and looks like a parameter name, use it + if (potential_name not in type_keywords and + re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', potential_name)): + break + + i -= 1 + + if name_start != -1 and i >= 0: + param_name = decl[name_start:i+1] + param_type = decl[:name_start].strip() + + # Clean up the type - remove trailing &, * and whitespace + param_type = param_type.rstrip('&* \t') + + return { + 'type': param_type, + 'name': param_name + } + + # Fallback: if we can't find a clear parameter name, just return the type + return { + 'type': decl, + 'name': None + } + + +def extract_cpp_signature_details(item): + """ + For each m.def entry in item, parse cpp_signature to extract return type and parameter details. + """ + statements_with_parsed_signatures = [] + for stmt_info in item['m_def_statements']: + parsed = stmt_info['parsed'] + cpp_sig = parsed.get('cpp_signature') + + cpp_params_info = None + if cpp_sig: + try: + cpp_params_info = parse_cpp_signature(cpp_sig) + except Exception as e: + print(f'Failed to parse C++ signature: {e}') + + parsed['cpp_parsed_signature'] = cpp_params_info + statements_with_parsed_signatures.append({ + 'raw': stmt_info['raw'], + 'parsed': parsed + }) + + return { + 'm_def_statements': statements_with_parsed_signatures + } + + +def cpp_type_to_python_type(cpp_type: str) -> str: + if not cpp_type: + return 'Any' + + original = cpp_type.strip() + if not original: + return 'Any' + + # Remove C++ specifiers that don't affect Python type + cleaned = re.sub(r'\b(static|inline|constexpr|thread_local|extern|mutable|const|volatile|endif)\b', '', original) + cleaned = cleaned.replace('&', '').replace('*', '').strip() + cleaned = re.sub(r'\s+', ' ', cleaned).strip() + + # Handle void + if cleaned == 'void': + return 'None' + + # Handle template types — ORDER MATTERS! Must come before internal type checks. + + # std::pair + if cleaned.startswith('std::pair<'): + inner = cleaned[10:-1].strip() # len('std::pair<') == 10 + args = split_template_args(inner) + if len(args) == 2: + t1 = cpp_type_to_python_type(args[0]) + t2 = cpp_type_to_python_type(args[1]) + return f'tuple[{t1}, {t2}]' + else: + print(f'Warning: std::pair with unexpected number of args: {cleaned}') + return 'Any' + + # std::tuple + if cleaned.startswith('std::tuple<'): + inner = cleaned[11:-1].strip() # len('std::tuple<') == 11 + args = split_template_args(inner) + py_types = [cpp_type_to_python_type(arg) for arg in args] + return f"tuple[{', '.join(py_types)}]" + + # std::vector + if cleaned.startswith('std::vector<'): + inner = cleaned[12:-1].strip() # len('std::vector<') == 12 + args = split_template_args(inner) + if len(args) == 1: + inner_py = cpp_type_to_python_type(args[0]) + return f'list[{inner_py}]' + else: + print(f'Warning: std::vector with unexpected args: {cleaned}') + return 'Any' + + # std::optional + if cleaned.startswith('std::optional<'): + inner = cleaned[14:-1].strip() # len('std::optional<') == 14 + args = split_template_args(inner) + if len(args) == 1: + inner_py = cpp_type_to_python_type(args[0]) + return f'Optional[{inner_py}]' + else: + print(f'Warning: std::optional with unexpected args: {cleaned}') + return 'Any' + + # std::string + if re.search(r'\bstd::string\b', original): + return 'str' + + # C-style strings: char*, const char*, char[], etc. + if re.search(r'\b(?:const\s+)?char\s*[\*\[]', original): + return 'str' + + # Boolean + if re.search(r'\bbool\b', cleaned): + return 'bool' + + # Integer types (including fixed-width and common aliases) + if re.search(r'\b(int|long|short|size_t|ssize_t|ptrdiff_t|' + r'int8_t|int16_t|int32_t|int64_t|' + r'uint8_t|uint16_t|uint32_t|uint64_t)\b', cleaned): + return 'int' + + # Floating-point + if re.search(r'\b(float|double|long\s+double)\b', cleaned): + return 'float' + + # torch::Tensor + if re.search(r'\btorch::Tensor\b', original): + return 'torch.Tensor' + + # Unrecognized type + print(f'Warning: Unrecognized C++ type: {original}') + return 'Any' + + +def split_template_args(template_args: str): + """ + Split template arguments, e.g., 'int, std::vector' → ['int', 'std::vector'] + """ + if not template_args.strip(): + return [] + args = [] + current = [] + tracker = BracketTracker() + + for ch in template_args: + if ch in '()[]{}<>': + tracker.update(ch) + if ch == ',' and tracker.is_top_level(): + args.append(''.join(current).strip()) + current = [] + else: + current.append(ch) + + if current: + args.append(''.join(current).strip()) + return args + + +def cpp_default_to_python_default(cpp_default: str): + """ + Convert C++ default value string to valid Python expression string. + """ + if not cpp_default: + return 'None' + + s = cpp_default.strip() + + # Handle string literals: 'bf16' → 'bf16' + # Match: starts and ends with unescaped double quotes + string_match = re.match(r'^"([^"\\]*(?:\\.[^"\\]*)*)"$', s) + if string_match: + return s + + # Handle boolean literals + if s == 'false': + return 'False' + if s == 'true': + return 'True' + + # Handle null-like values: nullptr, nullopt, NULL, etc. + if s in ('nullptr', 'NULL') or 'nullopt' in s: + return 'None' + + # Handle std::tuple({128, 128}) → (128, 128) + tuple_match = re.match(r'std::tuple\s*<[^>]*>\s*\(\s*({.*?})\s*\)', s) + if tuple_match: + inner = tuple_match.group(1) # {128, 128} + inner_py = inner.replace('{', '(').replace('}', ')') + return inner_py + + # Handle std::make_tuple(1, 2, 3) → (1, 2, 3) + make_tuple_match = re.match(r'std::make_tuple\s*\(\s*(.*?)\s*\)', s) + if make_tuple_match: + inner = make_tuple_match.group(1) + # Ensure it's a valid tuple even with one element: add comma if needed? + # But in C++ default args, it's usually multi-element, so we assume valid. + return f'({inner})' + + # Handle std::vector({1,2,3}) → [1, 2, 3] + vector_match = re.match(r'std::vector\s*<[^>]*>\s*\(\s*({.*?})\s*\)', s) + if vector_match: + inner = vector_match.group(1) + inner_py = inner.replace('{', '[').replace('}', ']') + return inner_py + + # Handle numeric literals: integers and floats + if re.match(r'^[+-]?\d+$', s): # integer + return s + if re.match(r'^[+-]?\d*\.\d+([eE][+-]?\d+)?$', s): # float + return s + + # Fallback: unrecognized → warn and return None + print(f'Warning: Unrecognized default value: {s}') + return 'None' + + +def generate_pyi_function(item_entry): + parsed = item_entry['parsed'] + py_name = parsed['python_function_name'] + + if parsed.get('is_lambda'): + return f'def {py_name}(*args, **kwargs) -> Any: ...' + + sig_info = parsed.get('cpp_parsed_signature') + default_args = parsed.get('default_args', {}) + + if not sig_info: + return f'def {py_name}(*args, **kwargs) -> Any: ...' + + return_type = cpp_type_to_python_type(sig_info['return_type']) + params = sig_info['parameters'] + num_params = len(params) + + # Build parameter list + param_lines = [] + for i in range(num_params): + param_info = params[i] if i < len(params) else {'type': 'Any', 'name': f'arg{i}'} + param_type = cpp_type_to_python_type(param_info['type']) + param_name = param_info['name'] or f'arg{i}' + + # Replace invalid Python identifiers (e.g., keywords) + if param_name in {'def', 'class', 'from', 'import', 'None', 'True', 'False'}: + param_name = f'{param_name}_' + + # Check for default value + if i in default_args: + cpp_default = default_args[i] + py_default = cpp_default_to_python_default(cpp_default) + param_str = f' {param_name}: {param_type} = {py_default}' + else: + param_str = f' {param_name}: {param_type}' + + param_lines.append(param_str) + + if param_lines: + params_block = ',\n'.join(param_lines) + func_def = f'def {py_name}(\n{params_block}\n) -> {return_type}: ...' + else: + func_def = f'def {py_name}() -> {return_type}: ...' + + return func_def + + +def generate_pyi_file_content(enhanced_results, module_name: str = 'my_module'): + function_decls = [] + has_optional = False + has_torch = False + has_numpy = False + + for item in enhanced_results: + for stmt in item['m_def_statements']: + try: + decl = generate_pyi_function(stmt) + function_decls.append(decl) + + if 'Optional[' in decl: + has_optional = True + if 'torch.Tensor' in decl: + has_torch = True + if 'numpy.ndarray' in decl or 'py::array' in str(stmt): + has_numpy = True + except Exception as e: + func_name = stmt['parsed'].get('python_function_name', 'unknown') + function_decls.append(f'# ERROR: failed to generate stub for {func_name}: {e}') + + imports = ['from typing import Any'] + if has_optional: + imports[0] += ', Optional' + + if has_torch: + imports.append('import torch') + if has_numpy: + imports.append('import numpy') + + lines = [f'# Stubs for module: {module_name}', ''] + lines.extend(imports) + lines.append('') + lines.append('') + + for decl in function_decls: + lines.append(decl) + lines.append('') + lines.append('') + + return '\n'.join(lines) + + +def generate_pyi_file(name, root, output_dir='.'): + func_index = build_cpp_function_index(root) + results = extract_m_def_statements(root) + + cpp_results = [] + for item in results: + enhanced_item = parse_mdef_and_attach_cpp_signatures(item, func_index) + cpp_item = extract_cpp_signature_details(enhanced_item) + cpp_results.append(cpp_item) + + pyi_content = generate_pyi_file_content(cpp_results, module_name=name) + + output_path = Path(output_dir) / f'{name}.pyi' + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(pyi_content) + + print(f'.pyi file generated: {output_path}') diff --git a/deep-gemm/scripts/readme_example.py b/deep-gemm/scripts/readme_example.py new file mode 100644 index 00000000..c3915b03 --- /dev/null +++ b/deep-gemm/scripts/readme_example.py @@ -0,0 +1,49 @@ +# /// script +# dependencies = [ +# "numpy", +# "torch", +# "kernels" +# ] +# /// + + +# CUDA_HOME=/usr/local/cuda-12.9 uv run scripts/readme_example.py +import torch +from kernels import get_local_kernel, get_kernel +from pathlib import Path + +# deep_gemm = get_local_kernel(Path("build"), "deep_gemm") +deep_gemm = get_kernel("drbh/deep-gemm", version=1) + +m, n, k = 256, 1024, 512 +device = "cuda" + +a = torch.randn((m, k), device=device, dtype=torch.bfloat16) +b = torch.randn((n, k), device=device, dtype=torch.bfloat16) +ref = a @ b.T + + +def compare(name, result, ref): + cos = torch.nn.functional.cosine_similarity( + result.float().flatten(), ref.float().flatten(), dim=0 + ) + diff = (result.float() - ref.float()).abs().max().item() + print(f"[{name}] shape: {m}x{n}x{k}, cosine_sim: {cos.item():.6f}, max_diff: {diff:.4f}") + + +# --- cuBLASLt GEMM (works on any GPU) --- +d = torch.empty((m, n), device=device, dtype=torch.bfloat16) +deep_gemm.cublaslt_gemm_nt(a, b, d) +compare("cuBLASLt BF16", d, ref) + +# --- FP8 GEMM (requires SM90+ / Hopper+) --- +arch = torch.cuda.get_device_capability()[0] +if arch >= 9: + # SFA: per-row (1, 128), SFB: per-block (128, 128) — SM90 recipe + a_fp8 = deep_gemm.utils.per_token_cast_to_fp8(a, use_ue8m0=False) + b_fp8 = deep_gemm.utils.per_block_cast_to_fp8(b, use_ue8m0=False) + d_fp8 = torch.empty((m, n), device=device, dtype=torch.bfloat16) + deep_gemm.fp8_gemm_nt(a_fp8, b_fp8, d_fp8) + compare("FP8 1D2D", d_fp8, ref) +else: + print(f"[FP8 GEMM] Skipped: requires SM90+ (Hopper), detected SM{arch}x") diff --git a/deep-gemm/setup.py b/deep-gemm/setup.py new file mode 100644 index 00000000..6199d7c3 --- /dev/null +++ b/deep-gemm/setup.py @@ -0,0 +1,213 @@ +import ast +import os +import re +import shutil +import setuptools +import subprocess +import sys +import torch +import platform +import urllib +import urllib.error +import urllib.request +from setuptools import find_packages +from setuptools.command.build_py import build_py +from packaging.version import parse +from pathlib import Path +from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from scripts.generate_pyi import generate_pyi_file + + +DG_SKIP_CUDA_BUILD = int(os.getenv('DG_SKIP_CUDA_BUILD', '0')) == 1 +DG_FORCE_BUILD = int(os.getenv('DG_FORCE_BUILD', '0')) == 1 +DG_USE_LOCAL_VERSION = int(os.getenv('DG_USE_LOCAL_VERSION', '1')) == 1 +DG_JIT_USE_RUNTIME_API = int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')) == 1 + +# Compiler flags +cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations', + f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}'] +if DG_JIT_USE_RUNTIME_API: + cxx_flags.append('-DDG_JIT_USE_RUNTIME_API') + +# Sources +current_dir = os.path.dirname(os.path.realpath(__file__)) +sources = ['csrc/python_api.cpp'] +build_include_dirs = [ + f'{CUDA_HOME}/include', + f'{CUDA_HOME}/include/cccl', + 'deep_gemm/include', + 'third-party/cutlass/include', + 'third-party/fmt/include', +] +build_libraries = ['cudart', 'nvrtc'] +build_library_dirs = [f'{CUDA_HOME}/lib64'] +third_party_include_dirs = [ + 'third-party/cutlass/include/cute', + 'third-party/cutlass/include/cutlass', +] + +# Release +base_wheel_url = 'https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}' + + +def get_package_version(): + with open(Path(current_dir) / 'deep_gemm' / '__init__.py', 'r') as f: + version_match = re.search(r'^__version__\s*=\s*(.*)$', f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + + revision = '' + if DG_USE_LOCAL_VERSION: + # noinspection PyBroadException + try: + status_cmd = ['git', 'status', '--porcelain'] + status_output = subprocess.check_output(status_cmd).decode('ascii').strip() + if status_output: + print(f'Warning: Git working directory is not clean. Uncommitted changes:\n{status_output}') + assert False, 'Git working directory is not clean' + + cmd = ['git', 'rev-parse', '--short', 'HEAD'] + revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + revision = '+local' + return f'{public_version}{revision}' + + +def get_platform(): + if sys.platform.startswith('linux'): + return f'linux_{platform.uname().machine}' + else: + raise ValueError('Unsupported platform: {}'.format(sys.platform)) + + +def get_wheel_url(): + torch_version = parse(torch.__version__) + torch_version = f'{torch_version.major}.{torch_version.minor}' + python_version = f'cp{sys.version_info.major}{sys.version_info.minor}' + platform_name = get_platform() + deep_gemm_version = get_package_version() + cxx11_abi = int(torch._C._GLIBCXX_USE_CXX11_ABI) + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + cuda_version = parse(torch.version.cuda) + cuda_version = f'{cuda_version.major}' + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f'deep_gemm-{deep_gemm_version}+cu{cuda_version}-torch{torch_version}-cxx11abi{cxx11_abi}-{python_version}-{platform_name}.whl' + wheel_url = base_wheel_url.format(tag_name=f'v{deep_gemm_version}', wheel_name=wheel_filename) + return wheel_url, wheel_filename + + +def get_ext_modules(): + if DG_SKIP_CUDA_BUILD: + return [] + + return [CUDAExtension(name='deep_gemm._C', + sources=sources, + include_dirs=build_include_dirs, + libraries=build_libraries, + library_dirs=build_library_dirs, + extra_compile_args=cxx_flags)] + + +class CustomBuildPy(build_py): + def run(self): + # First, prepare the include directories + self.prepare_includes() + + # Second, make clusters' cache setting default into `envs.py` + self.generate_default_envs() + + # Third, generate and copy .pyi file to build root directory + self.generate_pyi_file() + + # Finally, run the regular build + build_py.run(self) + + def generate_pyi_file(self): + generate_pyi_file(name='_C', root='./csrc', output_dir='./stubs') + pyi_source = os.path.join(current_dir, 'stubs', '_C.pyi') + pyi_target = os.path.join(self.build_lib, 'deep_gemm', '_C.pyi') + + if os.path.exists(pyi_source): + print(f"Copying .pyi file from {pyi_source} to {pyi_target}") + os.makedirs(os.path.dirname(pyi_target), exist_ok=True) + shutil.copy2(pyi_source, pyi_target) + else: + print(f"Warning: .pyi file not found at {pyi_source}") + + def generate_default_envs(self): + code = '# Pre-installed environment variables\n' + code += 'persistent_envs = dict()\n' + for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_CPP_STANDARD'): + code += f"persistent_envs['{name}'] = '{os.environ[name]}'\n" if name in os.environ else '' + + with open(os.path.join(self.build_lib, 'deep_gemm', 'envs.py'), 'w') as f: + f.write(code) + + def prepare_includes(self): + # Create temporary build directory instead of modifying package directory + build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include') + os.makedirs(build_include_dir, exist_ok=True) + + # Copy third-party includes to the build directory + for d in third_party_include_dirs: + dirname = d.split('/')[-1] + src_dir = os.path.join(current_dir, d) + dst_dir = os.path.join(build_include_dir, dirname) + + # Remove existing directory if it exists + if os.path.exists(dst_dir): + shutil.rmtree(dst_dir) + + # Copy the directory + shutil.copytree(src_dir, dst_dir) + + +class CachedWheelsCommand(_bdist_wheel): + def run(self): + if DG_FORCE_BUILD or DG_USE_LOCAL_VERSION: + return super().run() + + wheel_url, wheel_filename = get_wheel_url() + print(f'Try to download wheel from URL: {wheel_url}') + try: + with urllib.request.urlopen(wheel_url, timeout=1) as response: + with open(wheel_filename, 'wb') as out_file: + data = response.read() + out_file.write(data) + + # Make the archive + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f'{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}' + wheel_path = os.path.join(self.dist_dir, archive_basename + '.whl') + os.rename(wheel_filename, wheel_path) + except (urllib.error.HTTPError, urllib.error.URLError): + print('Precompiled wheel not found. Building from source...') + # If the wheel could not be downloaded, build from source + super().run() + + +if __name__ == '__main__': + # noinspection PyTypeChecker + setuptools.setup( + name='deep_gemm', + version=get_package_version(), + packages=find_packages('.'), + package_data={ + 'deep_gemm': [ + 'include/deep_gemm/**/*', + 'include/cute/**/*', + 'include/cutlass/**/*', + ] + }, + ext_modules=get_ext_modules(), + zip_safe=False, + cmdclass={ + 'build_py': CustomBuildPy, + 'bdist_wheel': CachedWheelsCommand, + }, + ) diff --git a/deep-gemm/tests/generators.py b/deep-gemm/tests/generators.py new file mode 100644 index 00000000..ee22e515 --- /dev/null +++ b/deep-gemm/tests/generators.py @@ -0,0 +1,397 @@ +import enum +import random +import torch +from typing import Generator, List, Optional, Tuple + +from deep_gemm.testing import get_arch_major +from deep_gemm.utils import ( + align, ceil_div, + per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8, + per_token_cast_to_fp4, transpose_packed_fp4, + get_mk_alignment_for_contiguous_layout +) + + +class KernelType(enum.Enum): + Kernel1D1D = 0 + Kernel1D2D = 1 + KernelNoSF = 2 + + def is_1d1d(self): + return self.value == 0 + + def is_1d2d(self): + return self.value == 1 + + def is_nosf(self): + return self.value == 2 + + +class MajorTypeAB(enum.Enum): + KMajor = 0 + MNMajor = 1 + + def is_k_major(self): + return self.value == 0 + + def is_mn_major(self): + return self.value == 1 + + +class QuantConfig: + _legacy_quant_config = (128, 128, False, False) + + def __init__(self, value: Tuple[int, int, bool, bool] = _legacy_quant_config): + self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b = value + + def print(self): + print(f' > Testing with gran_k_a={self.gran_k_a}, gran_k_b={self.gran_k_b}, ' + f'is_fp4_a={self.is_fp4_a}, is_fp4_b={self.is_fp4_b}') + + def is_legacy(self) -> bool: + return (self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b) == self._legacy_quant_config + + def get_recipes(self, is_wgrad: bool = False) -> Tuple[Tuple, Tuple, Tuple]: + recipe, recipe_a, recipe_b = None, None, None + if self.is_legacy(): + recipe = (1, 1, 128) if is_wgrad else None + else: + recipe_a = (1, self.gran_k_a) + recipe_b = (1, self.gran_k_b) if self.is_fp4_b or is_wgrad else (self.gran_k_b, self.gran_k_b) + return recipe, recipe_a, recipe_b + + def max_diff(self) -> float: + if self.is_fp4_a and self.is_fp4_b: + return 0.02 + if self.is_fp4_a or self.is_fp4_b: + return 0.01 + return 0.001 + + @staticmethod + def get_list_from_dtype(dtype: torch.dtype) -> List: + if dtype == torch.bfloat16: + return [None] + quant_config_list = [QuantConfig()] + if get_arch_major() == 10: + quant_config_list.append(QuantConfig((128, 32, False, True))) + return quant_config_list + + +def reset_seed(seed: int = 0): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def get_ue8m0_usage(kernel_type: KernelType) -> bool: + if get_arch_major() == 9: + return False + return kernel_type.is_1d1d() + + +def get_kernel_types(dtype: torch.dtype) -> tuple: + if dtype == torch.bfloat16: + return (KernelType.KernelNoSF, ) + + return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, ) + + +def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator: + for major_a in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor): + for major_b in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor): + if major_a.is_mn_major() and not allow_a_mn_major: + continue + if major_b.is_mn_major() and not allow_b_mn_major: + continue + yield major_a, major_b + + +def get_psum_layout_usage() -> tuple: + return (False, True) if get_arch_major() == 10 else (False, ) + + +def enumerate_normal(dtype: torch.dtype) -> Generator: + assert dtype in (torch.float8_e4m3fn, torch.bfloat16) + + quant_config_list = QuantConfig.get_list_from_dtype(dtype) + fp32_output_nk = [(256, 7168), (129280, 7168)] + bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)] + m_fwd_list, m_bwd_list = [1, 128, 4096], [4096, ] + nk_list = list(bf16_output_nk) + + # Only BF16 GEMM needs FP32 outputs + if dtype == torch.bfloat16: + nk_list += fp32_output_nk + + for kernel_type in get_kernel_types(dtype): + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + reset_seed() + + # Forward + for m in m_fwd_list: + for i in range(len(nk_list)): + n, k = nk_list[i] + out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float + yield kernel_type, quant_config, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype + + # Backward + for m in m_bwd_list: + for n, k in nk_list: + override_major = MajorTypeAB.MNMajor + override_kernel_type = kernel_type + if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: + override_major = MajorTypeAB.KMajor + override_kernel_type = KernelType.Kernel1D1D + yield kernel_type, quant_config, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad + yield override_kernel_type, quant_config, n, m, k, override_major, override_major, True, torch.float # Wgrad + yield override_kernel_type, quant_config, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad + + +def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: + quant_config_list = QuantConfig.get_list_from_dtype(dtype) + m_group_list = [(4, 8192), (8, 4096)] + n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)] + for kernel_type in get_kernel_types(dtype): + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + for use_psum_layout in get_psum_layout_usage(): + reset_seed() + for num_groups, expected_m_per_group in m_group_list: + for n, k in n_k_list: + for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn): + yield kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout + + +def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator: + quant_config_list = QuantConfig.get_list_from_dtype(dtype) + max_m = 4096 + m_group_list = [(6, 1024), (32, 192), (32, 50)] + n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)] + for kernel_type in get_kernel_types(dtype): + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + for use_psum_layout in get_psum_layout_usage(): + reset_seed() + for num_groups, m in m_group_list: + for n, k in n_k_list: + yield kernel_type, quant_config, num_groups, max_m, m, n, k, use_psum_layout + + +def enumerate_k_grouped_contiguous(dtype: torch.dtype): + # Only K-major is supported for SM90 FP8 + major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 and dtype == torch.float8_e4m3fn \ + else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + # Must with FP32 accumulation and 1D1D kernels + for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64 + ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32 + (16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16 + ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)] + yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group + + +def enumerate_sf_layout(): + for use_ue8m0 in (False, True): + for with_transpose in (True, False): + for mn in (4096, 4097, 8192): + for k in (128, 7168, 7296): + for num_groups in (1, 2, 4): + yield mn, k, with_transpose, use_ue8m0, num_groups + + +def enumerate_k_grouped_sf_layout(): + alignment = get_mk_alignment_for_contiguous_layout() + assert alignment % 128 == 0 + for mn in (4096, 7168): + for num_groups, avg_k in ((16, 2048), (8, 4096), (72, 384), (128, 256)): + ks = [align(int(random.uniform(0.7, 1.3) * avg_k), alignment) for _ in range(num_groups)] + yield mn, ks, num_groups + + +def enumerate_transpose(): + for mn in (64, 4096, 16384): + for delta in (0, 101, 202, 303): + for k in (128, 1024, 4096, 9984, 16384): + yield mn + delta, k + + +def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, + use_ue8m0: bool, use_block_cast_for_fp8: bool = False): + if is_fp4: + x_fp4 = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1]) + else: + x_fp8 = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ + else per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1]) + return x + + +def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, + use_ue8m0: bool, use_block_cast_for_fp8: bool = False): + num_groups, mn, k = x.size() + if is_fp4: + x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.uint8) if major.is_k_major() else \ + torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.uint8), + torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_i_fp4 = per_token_cast_to_fp4(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) + x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1]) + x = x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1]) + else: + x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device='cuda', dtype=torch.float) if use_block_cast_for_fp8 \ + else torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ + else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1]) + return x + + +def generate_normal(m: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, + accumulate: bool, out_dtype: torch.dtype, + kernel_type: KernelType, + use_ue8m0: bool = False, use_bf16: bool = False, + quant_config: Optional[QuantConfig] = None): + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ + torch.empty((m, n), device='cuda', dtype=out_dtype) + c = d if accumulate else None + ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype) + + if use_bf16: + a = a if major_a.is_k_major() else a.T.contiguous().T + b = b if major_b.is_k_major() else b.T.contiguous().T + return a, b, c, d, ref_d + + quant_config = QuantConfig() if quant_config is None else quant_config + a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, + use_block_cast_for_fp8=not (kernel_type.is_1d1d() and accumulate)) + + return a, b, c, d, ref_d + + +def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, + use_ue8m0: bool = False, use_bf16: bool = False, + use_psum_layout: bool = False, + quant_config: Optional[QuantConfig] = None): + actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] + m = sum(aligned_ms) + + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + grouped_layout = torch.empty(num_groups, device='cuda', dtype=torch.int32) if use_psum_layout \ + else torch.empty(m, device='cuda', dtype=torch.int32) + d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) + + start = 0 + for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)): + actual_end = start + actual_m + aligned_end = start + aligned_m + if use_psum_layout: + grouped_layout[i] = actual_end + else: + grouped_layout[start: actual_end] = i + grouped_layout[actual_end: aligned_end] = -1 + a[actual_end: aligned_end] = 0 + ref_d[start: aligned_end] = a[start: aligned_end] @ b[i].t() + start = aligned_end + + if use_bf16: + b = b if major_b.is_k_major() else b.mT.contiguous().mT + return m, a, b, grouped_layout, d, ref_d + + assert major_a.is_k_major() + quant_config = QuantConfig() if quant_config is None else quant_config + a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = grouped_cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True) + + return m, a, b, grouped_layout, d, ref_d + + +def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor): + num_groups, max_m, _ = x.size() + x_psum = torch.empty_like(x).view(num_groups * max_m, -1) + last_psum_m = 0 + for i in range(num_groups): + x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m] + last_psum_m = align(psum_m[i], 128) + return x_psum + + +def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, + use_ue8m0: bool = False, use_bf16: bool = False, + use_psum_layout: bool = False, + quant_config: Optional[QuantConfig] = None): + a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.einsum('gmk,gnk->gmn', a, b) + + masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + for j in range(num_groups): + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], 128)) + masked_m[j] + assert masked_m.amax().item() <= max_m + + if use_bf16: + return a, b, masked_m, psum_m, d, ref_d + + quant_config = QuantConfig() if quant_config is None else quant_config + a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True) + + return a, b, masked_m, psum_m, d, ref_d + + +def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], + use_ue8m0: bool = False, use_bf16: bool = False): + assert get_mk_alignment_for_contiguous_layout() % 128 == 0 + k = sum(ks) + + a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16) + b = torch.randn((k, n), device='cuda', dtype=torch.bfloat16) + c = torch.randn((num_groups, m, n), device='cuda', dtype=torch.float) * 32 + d = c + ref_d = torch.empty_like(c) + + start = 0 + for i, group_k in enumerate(ks): + end = start + group_k + ref_d[i] = c[i] + (a[start:end].T @ b[start:end]) + start = end + + if use_bf16: + assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + return k, a, b, c, d, ref_d + + a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0) + b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0) + + # Transpose for K Major A/B + if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor): + a, sfa = a_fp8 + b, sfb = b_fp8 + new_a = torch.empty((sum(ks) * m, ), dtype=a.dtype, device=a.device) + new_b = torch.empty((sum(ks) * n, ), dtype=b.dtype, device=b.device) + prefix = 0 + for K in ks: + new_a[prefix * m : (prefix + K) * m] = a[prefix : prefix + K, ].T.flatten() + new_b[prefix * n : (prefix + K) * n] = b[prefix : prefix + K, ].T.flatten() + prefix += K + a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T) + else: + assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + + return k, a_fp8, b_fp8, c, d, ref_d diff --git a/deep-gemm/tests/test_attention.py b/deep-gemm/tests/test_attention.py new file mode 100644 index 00000000..b26cf673 --- /dev/null +++ b/deep-gemm/tests/test_attention.py @@ -0,0 +1,285 @@ +import dataclasses +import random +import torch +from typing import Tuple, List + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + ignore_env, get_arch_major, + test_filter +) +from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8 + +from generators import generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB + + +def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]): + left, mid, right = head_splits + m, n = d.shape + assert n % (left + right) == 0 + num_heads = n // (left + right) + + # Split and insert padding tensor + d = d.view(m, num_heads, -1) + d_left = d[:, :, :left] + d_right = d[:, :, -right:] + + d_mid = torch.zeros((m, num_heads, mid), dtype=d.dtype, device=d.device) + return torch.cat([d_left, d_mid, d_right], dim=2).view(m, -1) + + +def test_gemm_skip_head_mid() -> None: + print('Testing GEMM skip head mid:') + head_splits = (128, 64, 128) + + major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor + out_dtype, accumulate = torch.bfloat16, False + + for kernel_type in get_kernel_types(dtype=torch.float8_e4m3fn): + for m in (128, 4096): + for n, k in [(32768, 512), (8192, 512)]: + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + a, b, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) + d = apply_skip_head_mid(d, head_splits) + ref_d = apply_skip_head_mid(ref_d, head_splits) + + deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}' + + t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast), + 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s') + print() + + +def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8) + x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(dtype=torch.uint8) + x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(dtype=torch.uint8) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + + +def generate_cp_test_data(seq_len, seq_len_kv): + assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0 + chunk_size = seq_len // 2 + cp_size = seq_len_kv // seq_len + # Select an arbitrary CP rank + cp_id = cp_size // 3 + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.zeros(seq_len, dtype=torch.int, device='cuda') + for i in range(chunk_size): + ke[i] = cp_id * chunk_size + i + ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i + return ks, ke + + +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + seq_len_kv = kv.shape[0] + + if cost_only: + start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + count_ones_per_row = (end - start).clamp(min=0) + return count_ones_per_row.sum() + + k = kv + q = q.float() + k = k.float() + + mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum('mhd,nd->hmn', q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float('-inf')) + + cost = mask.sum() + return logits, cost + + +@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 10) +def test_mqa_logits(): + print('Testing FP8 MQA Logits:') + num_heads, head_dim = 64, 128 + for seq_len in (2048, 4096): + for compressed_logits in (False, True): + for seq_len_kv in (4096, 8192): + for disable_cp in (False, True): + q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32) + + if disable_cp: + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len) + else: + ks, ke = generate_cp_test_data(seq_len, seq_len_kv) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False) + + if compressed_logits: + max_seqlen_k = (ke - ks).max().item() + logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False) + assert logits.size() == (seq_len, max_seqlen_k) + tmp = torch.full((seq_len, seq_len_kv), float('-inf'), device='cuda') + for i in range(seq_len): + tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]] + logits = tmp + else: + logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) + + do_check = (seq_len_kv < 32768) + if do_check: + ref_logits, ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + ref_neginf_mask = (ref_logits == float('-inf')) + neginf_mask = (logits == float('-inf')) + assert torch.equal(neginf_mask, ref_neginf_mask) + + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + logits = logits.masked_fill(neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f'{diff=}' + else: + ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True) + + tflops = 2 * ref_cost * num_heads * head_dim / 1e12 + if compressed_logits: + t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False), 'fp8_mqa_logits') + else: + t, clean_t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke), ('fp8_mqa_logits', 'clean_logits')) + clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke) + print(f' > S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: ' + f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, ' + f'{(count_bytes(q_fp8, kv_fp8, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s', end='') + # noinspection PyUnboundLocalVariable + print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if not compressed_logits else '') + print() + + +def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor, + weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, + max_model_len: int, is_context_lens_2d: bool): + batch_size, next_n, heads, dim = q.size() + num_block, block_size, _, dim = kv_cache.size() + logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) + context_lens = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens[i] + q_offsets = torch.full((next_n, ), context_len, device='cuda', dtype=torch.int32) if is_context_lens_2d \ + else torch.arange(context_len - next_n, context_len, device='cuda') + weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() + + num_blocks = (context_len + block_size - 1) // block_size + block_idxs = block_tables[i][:num_blocks] + kv_slice = kv_cache[block_idxs] # [num_blocks, block_size, kv_heads, dim] + kx = kv_slice.permute(2, 3, 0, 1).reshape(kv_slice.size(2), dim, -1) # [kv_heads, dim, total_tokens] + qx = q[i].transpose(0, 1) # q[i]: [next_n, heads, dim] -> [heads, next_n, dim] + s = torch.matmul(qx, kx).to(logits.dtype) # [heads, next_n, dim] @ [1, dim, total_tokens] -> [heads, next_n, total_tokens] + + total_len = num_blocks * block_size + k_offsets = torch.arange(0, total_len, device=q.device) + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) + s = torch.where(mask[None, :, :], s, float('-inf')) # mask shape: [1, next_n, total_tokens] + s = torch.relu(s) * weight_slice[..., None] # weight_slice: [heads, next_n] -> [heads, next_n, 1] + s = s.sum(dim=0) # [next_n, total_tokens] + logits[i * next_n:(i + 1) * next_n, :total_len] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) + + return logits + + +def test_paged_mqa_logits(): + print('Testing FP8 Paged MQA Logits:') + max_model_len = 111 * 1000 + for is_context_lens_2d in (False, True): + for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]: + for heads, index_dim in [(64, 128)]: + for avg_kv in (8192, 32768): + num_blocks, blocksize = max_model_len * 3, 64 + + q = torch.randn((batch_size, next_n, heads, index_dim), device='cuda', dtype=torch.bfloat16) + kv_cache = torch.randn((num_blocks, blocksize, 1, index_dim), device='cuda', dtype=torch.bfloat16) + weights = torch.randn((batch_size * next_n, heads), device='cuda', dtype=torch.float32) + q_fp8 = q.to(torch.float8_e4m3fn) + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + + context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size, )).cuda().to(torch.int32) + context_lens_list = context_lens.tolist() + max_block_len = (max(context_lens_list) + blocksize - 1) // blocksize * blocksize + block_tables = torch.zeros((batch_size, max_block_len), device='cuda', dtype=torch.int32) + + counter, block_idx_pool = 0, torch.randperm(num_blocks, device='cuda', dtype=torch.int32) + for i in range(batch_size): + num_blocks = ceil_div(context_lens_list[i], blocksize) + block_tables[i][:num_blocks] = block_idx_pool[counter: counter+num_blocks] + counter += num_blocks + + ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len, is_context_lens_2d) + positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1) + + if is_context_lens_2d: + context_lens_2d = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int() + context_lens_2d[:, next_n-1] = context_lens + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens_2d, blocksize, deep_gemm.get_num_sms()) + logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens_2d, block_tables, schedule_metadata, max_model_len, clean_logits=False) + ref_neginf_mask = ~(positions < context_lens_2d.view(-1).unsqueeze(1)) + else: + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms()) + logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True) + row_indices = torch.arange(batch_size * next_n, device='cuda') // next_n + next_n_offset = torch.arange(batch_size * next_n, device='cuda') % next_n + ref_neginf_mask = ~(positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(1)) + neginf_mask = (logits == float('-inf')) + assert torch.equal(neginf_mask, ref_neginf_mask) + + logits = logits.masked_fill(ref_neginf_mask, 0) + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + + sum_lens = sum(context_lens.to(torch.int64)) + tflops = 2 * sum_lens * next_n * heads * index_dim / 1e12 + input_bytes = count_bytes(q_fp8, weights, context_lens) + sum_lens * (index_dim + 4) + (sum_lens / blocksize) * 4 + output_bytes = sum_lens * next_n * 4 + if is_context_lens_2d: + t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens_2d, block_tables, schedule_metadata, max_model_len, clean_logits=False), + 'fp8_paged_mqa_logits') + else: + t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True), + ('fp8_paged_mqa_logits', 'clean_logits')) + clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens) + print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: ' + f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, ' + f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s', end='') + # noinspection PyUnboundLocalVariable + print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if not is_context_lens_2d else '') + print() + + + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + test_gemm_skip_head_mid() + + test_mqa_logits() + test_paged_mqa_logits() diff --git a/deep-gemm/tests/test_bf16.py b/deep-gemm/tests/test_bf16.py new file mode 100644 index 00000000..1a3b0467 --- /dev/null +++ b/deep-gemm/tests/test_bf16.py @@ -0,0 +1,204 @@ +import copy +import numpy as np +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes +) +from generators import ( + get_arch_major, layout_masked_to_psum, align, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous +) + + +def test_gemm() -> None: + print('Testing GEMM:') + scores = [] + for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else a.T + b = b if major_b.is_k_major() else b.T + assert a.is_contiguous() and b.is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c) + diff = calc_diff(d, ref_d) + assert diff < 1e-5, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + + t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True) + cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:7.1f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') + if cublas_t > 0: + scores.append((cublas_t + split_k_t) / t) + print(f"Average speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n") + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for test_alias in (False, True): + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_bf16=True, use_psum_layout=use_psum_layout) + func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else b.mT + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, grouped_layout, use_psum_layout=use_psum_layout) + diff = calc_diff(d, ref_d) + assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_bf16=True, use_psum_layout=use_psum_layout) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, use_psum_layout=use_psum_layout) + + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, psum={use_psum_layout}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for _, _, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.bfloat16): + num_tests = 8 + sum_t, max_t = 0, 0 + sum_ops, sum_bytes = 0, 0 + + for i in range(num_tests): + a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_bf16=True, use_psum_layout=use_psum_layout) + if use_psum_layout: + a_psum = layout_masked_to_psum(a, psum_m) + d_psum = layout_masked_to_psum(d, psum_m) + + # noinspection PyShadowingNames + def test_func(): + if use_psum_layout: + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, + use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group) + else: + deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + + test_func() + for j in range(num_groups): + if masked_m[j].item() == 0: + continue + if use_psum_layout: + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]] + else: + d_slice = d[j, :masked_m[j].item()] + diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) + assert diff < 1e-5, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + + sum_t += t + max_t = max(max_t, t) + sum_ops += 2 * valid_m * n * k + sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b) + + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' + f'psum={1 if use_psum_layout else 0}): ' + f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' + f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' + f'{sum_bytes / sum_t / 1e9:4.0f} GB/s') + print() + + +def test_k_grouped_gemm_contiguous() -> None: + print('Testing k-grouped contiguous GEMM:') + + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.bfloat16): + for test_empty_groups in (False, True): + new_ks = copy.deepcopy(ks) + if test_empty_groups and len(ks) > 1: + new_ks[random.randint(0, num_groups - 1)] = 0 + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_bf16=True) + new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') + deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c) + + diff = calc_diff(d, ref_d) + assert diff < 1e-5, f'{m=}, {n=}, {k=}, {ks=}, {diff:.7f}' + + # Test performance + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_bf16=True) + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c) + + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_cublaslt_gemm() -> None: + print('Testing cuBLASLt GEMM:') + for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + deep_gemm.cublaslt_gemm_nt(a, b, d, c=c) + diff = calc_diff(d, ref_d) + assert diff < 6e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})' + + t_nvjet, t_gemv, t_gemm = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'gemv', 'gemm'), suppress_kineto_output=True) + t = t_nvjet + t_gemv + t_gemm + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:5.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + if get_arch_major() >= 9: + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() + test_k_grouped_gemm_contiguous() + + test_cublaslt_gemm() diff --git a/deep-gemm/tests/test_cublaslt.py b/deep-gemm/tests/test_cublaslt.py new file mode 100644 index 00000000..afe8a175 --- /dev/null +++ b/deep-gemm/tests/test_cublaslt.py @@ -0,0 +1,20 @@ +import pytest +import torch + +import deep_gemm + + +@pytest.mark.kernels_ci +def test_cublaslt_gemm_nt(): + m, n, k = 256, 1024, 512 + a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + b = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) + d = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) + + deep_gemm.cublaslt_gemm_nt(a, b, d) + + ref = a @ b.T + cos = torch.nn.functional.cosine_similarity( + d.float().flatten(), ref.float().flatten(), dim=0 + ) + assert cos.item() > 0.99, f"cosine similarity too low: {cos.item()}" diff --git a/deep-gemm/tests/test_einsum.py b/deep-gemm/tests/test_einsum.py new file mode 100644 index 00000000..b7979989 --- /dev/null +++ b/deep-gemm/tests/test_einsum.py @@ -0,0 +1,181 @@ +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + get_arch_major, test_filter +) +from deep_gemm.utils.math import ( + ceil_div, + per_block_cast_to_fp8, per_channel_cast_to_fp8, per_token_cast_to_fp8 +) + + +def test_bmk_bnk_mn() -> None: + print('Testing "bmk, bnk -> mn":') + for s in (129, 4096, 8192): + for m, n, k in [(128, 384, 128), (256, 256, 256), (384, 128, 384)]: + for dtype in (torch.float, torch.bfloat16): + a = torch.randn((s, m, k), dtype=torch.bfloat16, device='cuda') + b = torch.randn((s, n, k), dtype=torch.bfloat16, device='cuda') + d = torch.randn((m, n), dtype=dtype, device='cuda') + c = d if dtype == torch.float else None + + # Test correctness + ref_d = (c if dtype == torch.float else 0) + torch.bmm(a.float(), b.float().mT).sum(0) + deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c) + assert calc_diff(d, ref_d) < 1e-5 + + t = bench_kineto(lambda: deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c), 'bmn_bnk_mn_gemm_impl', suppress_kineto_output=True) + print(f' > Perf (b={s:4.0f}, {m=}, {n=}, {k=}, {"FP32" if dtype == torch.float else "BF16"}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * s * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b) + (d.numel() * 4)) / 1e9 / t:4.0f} GB/s') + print() + + +def test_bhr_hdr_bhd(): + print('Testing "bhr, hdr -> bhd":') + for h, r, d in [(128, 512, 128), (8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) + fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16) + y = fy[:, :, :r] + ref_z = torch.einsum('bhr,hdr->bhd', x, y) + z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16) + deep_gemm.einsum('bhr,hdr->bhd', x, y, z) + assert calc_diff(z, ref_z) < 1e-10 + + t = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z), 'gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x, y, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +def test_bhd_hdr_bhr(): + print('Testing "bhd, hdr -> bhr":') + for h, r, d in [(128, 512, 128), (8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16) + y = fy[:, :, :r] + ref_z = torch.einsum('bhd,hdr->bhr', x, y) + z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16) + deep_gemm.einsum('bhd,hdr->bhr', x, y, z) + assert calc_diff(z, ref_z) < 1e-10 + + t = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z), 'gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x, y, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +def test_fp8_bhr_hdr_bhd(use_ue8m0: bool = True): + print('Testing FP8 "bhr, hdr -> bhd":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) + y = torch.randn((h, d, r), device='cuda', dtype=torch.bfloat16) + ref_z = torch.einsum('bhr,hdr->bhd', x, y) + + x_fp8 = per_token_cast_to_fp8(x.view(-1, r), use_ue8m0=use_ue8m0) + x_fp8 = x_fp8[0].view(b, h, r), x_fp8[1].view(b, h, ceil_div(r, 128)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty((h, ceil_div(d, 128), ceil_div(r, 128)), device='cuda', dtype=torch.float)) + for i in range(h): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], use_ue8m0=use_ue8m0) + z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16) + + deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z), 'fp8_gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +@test_filter(lambda: get_arch_major() >= 10) +def test_fp8_bhd_hdr_bhr(use_ue8m0: bool = True): + print('Testing FP8 "bhd, hdr -> bhr":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + y = torch.randn((h, d, r), device='cuda', dtype=torch.bfloat16) + ref_z = torch.einsum('bhd,hdr->bhr', x, y) + + x_fp8 = per_token_cast_to_fp8(x.view(-1, d), use_ue8m0=use_ue8m0) + x_fp8 = x_fp8[0].view(b, h, d), x_fp8[1].view(b, h, ceil_div(d, 128)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty((h, ceil_div(d, 128), ceil_div(r, 128)), device='cuda', dtype=torch.float)) + for i in range(h): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], use_ue8m0=use_ue8m0) + z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16) + + deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z), 'fp8_gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +@test_filter(lambda: get_arch_major() >= 10) +def test_fp8_bhd_bhr_hdr(use_ue8m0: bool = True): + print('Testing FP8 "bhd, bhr -> hdr":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4096, 8192): + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + y = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) + z_0 = torch.randn((h, d, r), device='cuda', dtype=torch.float32) * 10 + ref_z = z_0 + torch.einsum('bhd,bhr->hdr', x, y) + + x_fp8 = per_channel_cast_to_fp8(x.view(b, -1), use_ue8m0=use_ue8m0) + y_fp8 = per_channel_cast_to_fp8(y.view(b, -1), use_ue8m0=use_ue8m0) + x_fp8 = (x_fp8[0].view(b, h, d), x_fp8[1].view(ceil_div(b, 128), h, d)) + y_fp8 = (y_fp8[0].view(b, h, r), y_fp8[1].view(ceil_div(b, 128), h, r)) + z = z_0.clone() + deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)), 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z, z)) / t / 1e9:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_bmk_bnk_mn() + test_bhr_hdr_bhd() + test_bhd_hdr_bhr() + + test_fp8_bhr_hdr_bhd() + test_fp8_bhd_hdr_bhr() + test_fp8_bhd_bhr_hdr() diff --git a/deep-gemm/tests/test_fp8_fp4.py b/deep-gemm/tests/test_fp8_fp4.py new file mode 100644 index 00000000..f7e3e1c4 --- /dev/null +++ b/deep-gemm/tests/test_fp8_fp4.py @@ -0,0 +1,207 @@ +import copy +import numpy as np +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + ignore_env, get_arch_major +) + +from generators import ( + KernelType, get_ue8m0_usage, layout_masked_to_psum, align, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous +) + + +@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) +def test_gemm() -> None: + print('Testing GEMM:') + scores = [] + for kernel_type, quant_config, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes(is_wgrad=(kernel_type.is_1d1d() and accumulate)) + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) + func_name = f'fp8_fp4_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else (a[0].T, a[1].T) + b = b if major_b.is_k_major() else (b[0].T, b[1].T) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) + t = bench_kineto(lambda: deep_gemm.fp8_fp4_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b), + 'fp8_gemm', suppress_kineto_output=True) + cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) \ + if not quant_config.is_fp4_a and not quant_config.is_fp4_b else (0, 0) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') + if cublas_t > 0: + scores.append((cublas_t + split_k_t) / t) + print(f"Average FP8xFP8 GEMM speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n") + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes() + + for test_alias in (False, True): + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + func_name = f"m_grouped_fp8_fp4_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else (b[0].mT, b[1].mT) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}, psum={use_psum_layout}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for kernel_type, quant_config, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.float8_e4m3fn): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes() + + num_tests = 8 + sum_t, max_t = 0, 0 + sum_ops, sum_bytes = 0, 0 + + for i in range(num_tests): + a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + if use_psum_layout: + a_psum = (layout_masked_to_psum(a[0], psum_m), layout_masked_to_psum(a[1], psum_m)) + d_psum = layout_masked_to_psum(d, psum_m) + + # noinspection PyShadowingNames + def test_func(): + if use_psum_layout: + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, disable_ue8m0_cast=disable_ue8m0_cast, + use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + else: + deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + + test_func() + for j in range(num_groups): + if masked_m[j].item() == 0: + continue + if use_psum_layout: + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]] + else: + d_slice = d[j, :masked_m[j].item()] + diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) + assert diff < quant_config.max_diff(), f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + + sum_t += t + max_t = max(max_t, t) + sum_ops += 2 * valid_m * n * k + sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b) + + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' + f'{kernel_opt}, psum={1 if use_psum_layout else 0}): ' + f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' + f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' + f'{sum_bytes / sum_t / 1e9:4.0f} GB/s') + print() + + +@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) +def test_k_grouped_gemm_contiguous() -> None: + print('Testing k-grouped contiguous GEMM:') + + k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \ + else deep_gemm.k_grouped_fp8_gemm_tn_contiguous + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn): + use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) + + for test_empty_groups in (False, True): + new_ks = copy.deepcopy(ks) + if test_empty_groups and len(ks) > 1: + new_ks[random.randint(0, num_groups - 1)] = 0 + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0) + new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') + k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c) + + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}' + + # Test performance + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0) + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() + test_k_grouped_gemm_contiguous() diff --git a/deep-gemm/tests/test_hyperconnection.py b/deep-gemm/tests/test_hyperconnection.py new file mode 100644 index 00000000..24faf22c --- /dev/null +++ b/deep-gemm/tests/test_hyperconnection.py @@ -0,0 +1,57 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + test_filter, + bench_kineto, + calc_diff, count_bytes +) +from deep_gemm.utils import align +from generators import get_arch_major + + +@test_filter(lambda: get_arch_major() >= 9) +def test_hc_prenorm_gemm() -> None: + # Needs TF32 precision for PyTorch GEMMs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + print('Testing hyperconnection prenorm GEMM:') + for m in (13, 137, 4096, 8192): + for n, k in [(24, 28672), (24, 7680), (24, 7168)]: + for num_splits in [None, 16]: + a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda') + b = torch.randn((n, k), dtype=torch.float, device='cuda') + d = torch.empty((m, n), dtype=torch.float, device='cuda') if num_splits is None else \ + torch.empty((num_splits, m, n), dtype=torch.float, device='cuda') + s = torch.empty((m, ), dtype=torch.float, device='cuda') if num_splits is None else \ + torch.empty((num_splits, m), dtype=torch.float, device='cuda') + deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits) + final_d = d if num_splits is None else d.sum(0) + final_s = s if num_splits is None else s.sum(0) + + ref_d = a.float() @ b.T + ref_s = a.float().square().sum(-1) + + diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s)) + assert diff < 1e-8, f'{m=}, {n=}, {k=}, {diff:.10f}' + + t = bench_kineto(lambda: deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits), 'tf32_hc_prenorm_gemm', suppress_kineto_output=True) + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, num_splits={(num_splits or 0):2}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d, s) / 1e9 / t:4.0f} GB/s') + print() + + + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_hc_prenorm_gemm() diff --git a/deep-gemm/tests/test_layout.py b/deep-gemm/tests/test_layout.py new file mode 100644 index 00000000..7875733a --- /dev/null +++ b/deep-gemm/tests/test_layout.py @@ -0,0 +1,112 @@ +import torch +import random +from deep_gemm.testing import bench_kineto, count_bytes +from deep_gemm.utils import ( + align, ceil_div, + per_token_cast_to_fp8, per_channel_cast_to_fp8, + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor +) + +from generators import ( + enumerate_sf_layout, + enumerate_k_grouped_sf_layout +) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.float and x.dim() in (2, 3) + + # First, convert into UE8M0 `uint8_t` + ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) + + # Second, make padded packed tensors + mn, k = x.shape[-2], x.shape[-1] + remove_dim = False + if x.dim() == 2: + x, remove_dim = x.unsqueeze(0), True + b = x.shape[0] + aligned_mn = get_tma_aligned_size(mn, 4) + aligned_k = align(k, 4) + padded = torch.zeros((b, aligned_mn, aligned_k), device=x.device, dtype=torch.uint8) + padded[:, :mn, :k] = ue8m0_tensor + padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, aligned_k // 4) + + # Finally, transpose + transposed = torch.zeros((b, aligned_k // 4, aligned_mn), device=x.device, dtype=torch.int).mT + transposed[:, :, :] = padded + aligned_x = transposed[:, :mn, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def test_sf_layout_kernels() -> None: + print('Testing SF layout kernels:') + for mn, k, with_transpose, use_ue8m0, num_groups in enumerate_sf_layout(): + x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda') + x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0) + fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1) + fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2) + + # Correctness + if use_ue8m0: + impl, name = get_mn_major_tma_aligned_packed_ue8m0_tensor, 'pack_fp32_into_ue8m0' + packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf) + ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf) + assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}' + assert packed_sf.shape == ref_packed_sf.shape + assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())]) + else: + impl, name = get_mn_major_tma_aligned_tensor, 'transpose' + transposed_sf = get_mn_major_tma_aligned_tensor(fp32_sf) + tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, 128) + if num_groups > 1: + assert transposed_sf.size(0) == num_groups + assert transposed_sf.stride(0) == tma_aligned_mn * sf_k + assert transposed_sf.shape[-2:] == (mn, sf_k) + assert transposed_sf.stride()[-2:] == (1, tma_aligned_mn) + assert torch.equal(fp32_sf, transposed_sf) + + # Performance + try: + t = bench_kineto(lambda: impl(fp32_sf), name) + except AssertionError as e: + # Some cases may fallback to PyTorch impl + t = 0 + print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}): ' + f'{t * 1e6:4.0f} us | {count_bytes(fp32_sf, impl(fp32_sf)) / 1e9 / t if t else 0:4.0f} GB/s') + print() + + +def test_k_grouped_sf_layout_kernels() -> None: + print('Testing k-grouped SF layout kernels:') + for mn, ks, num_groups in enumerate_k_grouped_sf_layout(): + sf_ks = [k // 128 for k in ks] + packed_sf_ks = [ceil_div(k, 512) for k in ks] + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + x = torch.randn((sum(ks), mn), dtype=torch.bfloat16, device='cuda') + x, fp32_sf = per_channel_cast_to_fp8(x, use_ue8m0=True) + + # Correctness + packed_sf = get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks) + split_packed_sf = packed_sf.split(packed_sf_ks) + split_fp32_sf = fp32_sf.split(sf_ks) + for i in range(num_groups): + ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(split_fp32_sf[i].T).T + assert torch.equal(split_packed_sf[i], ref_packed_sf), f'{i=}' + + # Performance + t = bench_kineto(lambda: get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks), 'pack_fp32_into_ue8m0') + print(f' > Perf ({num_groups=:3}, {mn=:5}, sum_k={sum(ks):5}):' + f'{t * 1e6:4.0f} us | ' + f'{count_bytes(fp32_sf, packed_sf, ks_tensor) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(1) + random.seed(1) + + test_sf_layout_kernels() + test_k_grouped_sf_layout_kernels() diff --git a/deep-gemm/tests/test_lazy_init.py b/deep-gemm/tests/test_lazy_init.py new file mode 100644 index 00000000..5363b6db --- /dev/null +++ b/deep-gemm/tests/test_lazy_init.py @@ -0,0 +1,15 @@ +import torch +import torch.multiprocessing as mp +import deep_gemm + + +def main(local_rank: int): + torch.cuda.set_device(local_rank) + + +if __name__ == '__main__': + procs = [mp.Process(target=main, args=(i, ), ) for i in range(8)] + for p in procs: + p.start() + for p in procs: + p.join() diff --git a/deep-gemm/tests/test_legacy.py b/deep-gemm/tests/test_legacy.py new file mode 100644 index 00000000..4456799f --- /dev/null +++ b/deep-gemm/tests/test_legacy.py @@ -0,0 +1,90 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes +) +from generators import ( + enumerate_m_grouped_contiguous, enumerate_k_grouped_contiguous, + generate_m_grouped_contiguous, generate_k_grouped_contiguous, +) + +def test_m_grouped_gemm_contiguous_tl() -> None: + print('Testing m-grouped contiguous Triton GEMM:') + for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, _ in enumerate_m_grouped_contiguous(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for expand in (False, True): + for test_alias in (False, True): + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + func_name = f"{'a_fused_' if expand else ''}m_grouped_bf16_gemm_{major_opt.lower() if test_alias else 'nt'}_contiguous_tl" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else b.mT + assert a[0].is_contiguous() and b[0].is_contiguous() + if expand: + m_row_indices = torch.arange(0, m, dtype=torch.int32, device='cuda') + getattr(deep_gemm.legacy, func_name)(a, b, d, (m_indices, m_row_indices)) + else: + getattr(deep_gemm.legacy, func_name)(a, b, d, m_indices) + d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.legacy.m_grouped_bf16_gemm_nt_contiguous_tl(a, b, d, m_indices) + + t = bench_kineto(test_func, 'm_grouped_bf16_gemm_contiguous_tl_impl', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_k_grouped_gemm_contiguous_tl() -> None: + print('Testing k-grouped contiguous Triton GEMM:') + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for fused_operand in ('a', 'b'): + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=False, use_bf16=True) + func_name = f"{fused_operand}_fused_k_grouped_bf16_gemm_{major_opt.lower()}_contiguous_tl" + k_indices = torch.arange(0, k, dtype=torch.int32, device='cuda') + k_start = torch.empty(len(ks), dtype=torch.int32, device='cuda') + k_end = torch.empty(len(ks), dtype=torch.int32, device='cuda') + for i, group_k in enumerate(ks): + k_start[i] = k_end[i-1] if i > 0 else 0 + k_end[i] = k_start[i] + group_k + getattr(deep_gemm.legacy, func_name)(a, b, c, (k_indices, k_start, k_end), True) + diff = calc_diff(c, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}' + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=False, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.legacy.b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a, b, c, (k_indices, k_start, k_end), True) + + t = bench_kineto(test_func, 'b_fused_k_grouped_bf16_gemm_contiguous_tl_impl', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_m_grouped_gemm_contiguous_tl() + test_k_grouped_gemm_contiguous_tl() diff --git a/deep-gemm/tests/test_sanitizer.py b/deep-gemm/tests/test_sanitizer.py new file mode 100644 index 00000000..b063e6c4 --- /dev/null +++ b/deep-gemm/tests/test_sanitizer.py @@ -0,0 +1,78 @@ +import argparse +import importlib +import inspect +import os +import subprocess +import sys + +import deep_gemm + + +# Single test template +script_dir = os.path.dirname(os.path.abspath(__file__)) +test_template = """ +import random +import sys +import torch + +# Necessary for `generators.py` +sys.path.append('{script_dir}') + +torch.manual_seed(0) +random.seed(0) + +from tests.{module_name} import {func_name} +{func_name}() +""" + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--funcs', type=str, default='all') + parser.add_argument('--tools', type=str, default='memcheck,synccheck') + args = parser.parse_args() + + if args.funcs != 'all': + funcs = [] + for name in [x.strip() for x in args.funcs.split(',')]: + module_name, func_name = name.split('.') + funcs.append((module_name, func_name)) + else: + # Get all test functions except those related to cuBLAS + files = [f for f in os.listdir(script_dir) if f.endswith('.py')] + exclude_files = ['test_sanitizer.py', 'generators.py'] + funcs = [ + (module_name, name) + for module_name in [os.path.splitext(f)[0] for f in files if f not in exclude_files] + for name, obj in inspect.getmembers(importlib.import_module(module_name)) + if inspect.isfunction(obj) and name.startswith('test') and 'test_filter' not in name + ] + tools = [x.strip() for x in args.tools.split(',')] + + env = os.environ.copy() + env['CUDA_LAUNCH_BLOCKING'] = '1' + env['DG_JIT_PTXAS_CHECK'] = '1' + env['DG_USE_NVIDIA_TOOLS'] = '1' + env['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1' + env['TORCH_SHOW_CPP_STACKTRACES'] = '1' + + print(f'Library path: {deep_gemm.__path__}') + for module_name, func_name in funcs: + for tool in tools: + cmd = [ + '/usr/local/cuda/bin/compute-sanitizer', + f'--tool={tool}', + '--target-processes=application-only', + '--destroy-on-device-error=context', + '--force-blocking-launches', + '--check-api-memory-access=no', + '--kernel-name-exclude', 'kns=nvjet', + 'python', + '-c', + test_template.format(module_name=module_name, func_name=func_name, script_dir=script_dir) + ] + print(f'\n{"=" * 60}') + print(f'Running {module_name}.{func_name} with compute-sanitizer {tool}') + result = subprocess.run(cmd, env=env) + if result.returncode != 0: + sys.exit(result.returncode) diff --git a/deep-gemm/torch-ext/deep_gemm/__init__.py b/deep-gemm/torch-ext/deep_gemm/__init__.py new file mode 100644 index 00000000..9cd2069a --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/__init__.py @@ -0,0 +1,684 @@ +import os +import subprocess +import torch + +# Import the compiled extension +from ._ops import ops +from . import utils + +__version__ = "2.3.0" + + +# Runtime + + +def set_num_sms(num_sms: int): + ops.set_num_sms(num_sms) + + +def get_num_sms() -> int: + return ops.get_num_sms() + + +def set_tc_util(tc_util: int): + ops.set_tc_util(tc_util) + + +def get_tc_util() -> int: + return ops.get_tc_util() + + +def get_mk_alignment_for_contiguous_layout() -> int: + return ops.get_mk_alignment_for_contiguous_layout() + + +# Layout utilities + + +def get_tma_aligned_size(mn: int, element_size: int) -> int: + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): + ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks_int + ) + + +def transform_sf_into_required_layout( + sf, + mn, + k, + recipe=None, + recipe_ab=None, + num_groups=None, + is_sfa=False, + disable_ue8m0_cast=False, +): + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + has_recipe_ab = recipe_ab is not None + rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0) + has_ng = num_groups is not None + ng = num_groups if has_ng else 0 + return ops.transform_sf_into_required_layout( + sf, + mn, + k, + r0, + r1, + r2, + has_recipe, + rab0, + rab1, + has_recipe_ab, + ng, + has_ng, + is_sfa, + disable_ue8m0_cast, + ) + + +# Aliases for contiguous layout alignment +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout + + +# Helper to flatten recipe args + + +def _flatten_recipe(recipe, recipe_a=None, recipe_b=None): + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + has_ra = recipe_a is not None + ra0, ra1 = recipe_a if has_ra else (0, 0) + has_rb = recipe_b is not None + rb0, rb1 = recipe_b if has_rb else (0, 0) + return r0, r1, r2, has_recipe, ra0, ra1, has_ra, rb0, rb1, has_rb + + +# FP8/FP4 GEMM ops + + +def fp8_fp4_gemm_nt( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_nt( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_nn( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_nn( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_tn( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="mn", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_tn( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_tt( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="mn", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_tt( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +# FP8 aliases (same as FP8/FP4) +fp8_gemm_nt = fp8_fp4_gemm_nt +fp8_gemm_nn = fp8_fp4_gemm_nn +fp8_gemm_tn = fp8_fp4_gemm_tn +fp8_gemm_tt = fp8_fp4_gemm_tt + + +# M-grouped FP8/FP4 GEMM ops + + +def m_grouped_fp8_fp4_gemm_nt_contiguous( + a, + b, + d, + grouped_layout, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, + use_psum_layout=False, + expected_m_for_psum_layout=None, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + has_em = expected_m_for_psum_layout is not None + em = expected_m_for_psum_layout if has_em else 0 + ops.m_grouped_fp8_fp4_gemm_nt_contiguous( + a_data, + a_sf, + b_data, + b_sf, + d, + grouped_layout, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + use_psum_layout, + em, + has_em, + ) + + +def m_grouped_fp8_fp4_gemm_nn_contiguous( + a, + b, + d, + grouped_layout, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, + use_psum_layout=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.m_grouped_fp8_fp4_gemm_nn_contiguous( + a_data, + a_sf, + b_data, + b_sf, + d, + grouped_layout, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + use_psum_layout, + ) + + +def m_grouped_fp8_fp4_gemm_nt_masked( + a, + b, + d, + masked_m, + expected_m, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.m_grouped_fp8_fp4_gemm_nt_masked( + a_data, + a_sf, + b_data, + b_sf, + d, + masked_m, + expected_m, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +# M-grouped FP8 aliases +m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous +m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous +m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked + +# Legacy aliases +fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked + + +# K-grouped FP8 GEMM ops + + +def k_grouped_fp8_gemm_tn_contiguous( + a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn" +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + ops.k_grouped_fp8_gemm_tn_contiguous( + a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims + ) + + +def k_grouped_fp8_gemm_nt_contiguous( + a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn" +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + ops.k_grouped_fp8_gemm_nt_contiguous( + a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims + ) + + +# BF16 GEMM ops + + +def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"): + ops.bf16_gemm_nt(a, b, d, c, compiled_dims) + + +def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"): + ops.bf16_gemm_nn(a, b, d, c, compiled_dims) + + +def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"): + ops.bf16_gemm_tn(a, b, d, c, compiled_dims) + + +def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"): + ops.bf16_gemm_tt(a, b, d, c, compiled_dims) + + +# M-grouped BF16 GEMM ops + + +def m_grouped_bf16_gemm_nt_contiguous( + a, + b, + d, + grouped_layout, + compiled_dims="nk", + use_psum_layout=False, + expected_m_for_psum_layout=None, +): + has_em = expected_m_for_psum_layout is not None + em = expected_m_for_psum_layout if has_em else 0 + ops.m_grouped_bf16_gemm_nt_contiguous( + a, b, d, grouped_layout, compiled_dims, use_psum_layout, em, has_em + ) + + +def m_grouped_bf16_gemm_nn_contiguous( + a, b, d, grouped_layout, compiled_dims="nk", use_psum_layout=False +): + ops.m_grouped_bf16_gemm_nn_contiguous( + a, b, d, grouped_layout, compiled_dims, use_psum_layout + ) + + +def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims="nk"): + ops.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims) + + +# Legacy alias +bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked + + +# K-grouped BF16 GEMM ops + + +def k_grouped_bf16_gemm_tn_contiguous( + a, b, d, ks, ks_tensor, c=None, compiled_dims="mn" +): + ops.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks_tensor, c, compiled_dims) + + +# cuBLASLt GEMM ops + + +def cublaslt_gemm_nt(a, b, d, c=None): + ops.cublaslt_gemm_nt(a, b, d, c) + + +def cublaslt_gemm_nn(a, b, d, c=None): + ops.cublaslt_gemm_nn(a, b, d, c) + + +def cublaslt_gemm_tn(a, b, d, c=None): + ops.cublaslt_gemm_tn(a, b, d, c) + + +def cublaslt_gemm_tt(a, b, d, c=None): + ops.cublaslt_gemm_tt(a, b, d, c) + + +# Attention ops + + +def fp8_gemm_nt_skip_head_mid( + a, b, d, head_splits, recipe=None, compiled_dims="nk", disable_ue8m0_cast=False +): + a_data, a_sf = a + b_data, b_sf = b + left, mid, right = head_splits + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + ops.fp8_gemm_nt_skip_head_mid( + a_data, + a_sf, + b_data, + b_sf, + d, + left, + mid, + right, + r0, + r1, + r2, + has_recipe, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_mqa_logits( + q, + kv, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits=True, + max_seqlen_k=0, +): + kv_data, kv_sf = kv + return ops.fp8_mqa_logits( + q, + kv_data, + kv_sf, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits, + max_seqlen_k, + ) + + +def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms): + return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms) + + +def fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits=False, +): + return ops.fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits, + ) + + +# Einsum ops + + +def einsum(expr, a, b, d, c=None, use_cublaslt=False): + ops.einsum(expr, a, b, d, c, use_cublaslt) + + +def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + ops.fp8_einsum(expr, a_data, a_sf, b_data, b_sf, d, c, r0, r1, r2) + + +# Hyperconnection ops + + +def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None): + has_ns = num_splits is not None + ns = num_splits if has_ns else 0 + ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns) + + +# Initialize the C++ runtime + + +def _find_cuda_home() -> str: + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home is None: + try: + with open(os.devnull, "w") as devnull: + nvcc = ( + subprocess.check_output(["which", "nvcc"], stderr=devnull) + .decode() + .rstrip("\r\n") + ) + cuda_home = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + cuda_home = "/usr/local/cuda" + if not os.path.exists(cuda_home): + cuda_home = None + assert cuda_home is not None, "Could not find CUDA installation" + return cuda_home + + +# Find the library root for JIT headers +# In development: use the repo's deep_gemm/ directory +# In installed wheel: use this package's directory +_lib_root = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "deep_gemm" +) +if not os.path.isdir(os.path.join(_lib_root, "include")): + # Fallback: try the parent package + _lib_root = os.path.dirname(os.path.abspath(__file__)) + +_initialized = False + +# Set DG_CUTLASS_INCLUDE for JIT kernel compilation (if not already set by user) +if "DG_CUTLASS_INCLUDE" not in os.environ: + _include = os.path.join(_lib_root, "include") + _cutlass_include_candidates = [ + _include, # legacy layout: include/cutlass + os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout + ] + for _cutlass_include in _cutlass_include_candidates: + if os.path.isdir(os.path.join(_cutlass_include, "cutlass")): + os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include + break + else: + # Fall back to nvidia-cutlass pip package + try: + import nvidia.cutlass as _nc + os.environ["DG_CUTLASS_INCLUDE"] = os.path.join( + os.path.dirname(_nc.__file__), "include" + ) + except ImportError: + pass + +def _ensure_initialized(): + global _initialized + if _initialized: + return + _initialized = True + ops.init(_lib_root, _find_cuda_home()) + + +# Try to initialize eagerly, but don't fail if CUDA is not found +# (e.g., during build-time import checks). init() will be called +# lazily on first actual kernel use. +try: + _ensure_initialized() +except (AssertionError, RuntimeError): + pass diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/cute_tie.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/cute_tie.cuh new file mode 100644 index 00000000..cd2aace7 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/cute_tie.cuh @@ -0,0 +1,48 @@ +#pragma once + +namespace cute { + +struct ignore_t { + template + constexpr const ignore_t& operator=(T&&) const noexcept { + return *this; + } +}; + +inline constexpr ignore_t ignore{}; + +} // namespace cute + +#define CUTE_TIE_CONCAT_IMPL(A, B) A##B +#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B) + +#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define CUTE_TIE_COUNT_ARGS(...) \ + CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get(TUPLE) +#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get(TUPLE) + +#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1); +#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2); +#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); +#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); +#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5); + +#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_DECL, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ) + +#define CUTE_TIE(TUPLE_EXPR, ...) \ + do { \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_ASSIGN, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ); \ + } while (0) diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh new file mode 100644 index 00000000..5f6a7a19 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +struct EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 + and kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/reduction.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/reduction.cuh new file mode 100644 index 00000000..d9e35f73 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/reduction.cuh @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include + +#include + +// Operation functors +template struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } }; +template struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +__forceinline__ __device__ T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +__forceinline__ __device__ T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/scheduler.cuh new file mode 100644 index 00000000..f93b96ee --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -0,0 +1,288 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for countiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + __device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = __ldg(grouped_layout + group_idx); + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k, + int* grouped_layout = nullptr) { + num_m_blocks = ceil_div(shape_m, BLOCK_M); + num_n_blocks = ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = __ldg(grouped_layout); + num_m_blocks = ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto& group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + current_group_idx)), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = align(current_psum_m, 128u); + current_psum_m = __ldg(grouped_layout + current_group_idx); + current_m_block_cumsum += num_m_blocks; + num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with 128 + m_block_idx += last_psum_m / BLOCK_M; + DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M"); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto& block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/sm100_utils.cuh new file mode 100644 index 00000000..537cbe08 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -0,0 +1,266 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace deep_gemm::sm100 { + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + uint32_t stride_byte_offset, uint32_t leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +__device__ __forceinline__ +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +__device__ __forceinline__ +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +__device__ __forceinline__ +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +template +__device__ __forceinline__ +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto& layout_type = to_umma_layout_type(); + const auto& num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +__device__ __forceinline__ +uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +template +__device__ constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if (kNumCols <= 32) return 32; + if (kNumCols <= 64) return 64; + if (kNumCols <= 128) return 128; + if (kNumCols <= 256) return 256; + return 512; +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +__device__ __forceinline__ +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +} // namespace `deep_gemm::sm100` diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/sm90_utils.cuh new file mode 100644 index 00000000..0874b675 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -0,0 +1,332 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm::sm90 { + +template +struct FP8MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct TF32MMARS { + + template + __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; + +template +struct SM90_U32x2_STSM_N { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +struct SM90_U32x2_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +__forceinline__ __device__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +__forceinline__ __device__ void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +template +__device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, + const int& leading_byte_offset = 0, + const int& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; + const auto& uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; +} + +template +__device__ __forceinline__ +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto& layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } +} + +} // namespace `deep_gemm::sm90` diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/tma_utils.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/tma_utils.cuh new file mode 100644 index 00000000..bd54adc2 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/tma_utils.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +// Tensormap related +__device__ __forceinline__ void tensor_map_release_cta() { + asm volatile ("fence.proxy.tensormap::generic.release.cta;"); +} + +__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +} // namespace `deep_gemm` diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/types.hpp b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/types.hpp new file mode 100644 index 00000000..410c5469 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/types.hpp @@ -0,0 +1,41 @@ +#pragma once + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/utils.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/utils.cuh new file mode 100644 index 00000000..8fb6c2fc --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/common/utils.cuh @@ -0,0 +1,183 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "cute_tie.cuh" + +#ifdef __CLION_IDE__ + +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +namespace deep_gemm { + +template +struct PatternVisitor { + FuncT func; + + __device__ __host__ + explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} + + __device__ __host__ + auto operator [](const uint32_t& i) { + return func(i); + } +}; + +template +__device__ __host__ T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ T align(T a, T b) { + return ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +__forceinline__ __device__ void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +__forceinline__ __device__ uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +__forceinline__ __device__ uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float2 ld_shared(const float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +template +__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +__device__ __forceinline__ void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +template +struct Vectorized { + static auto zeros() { + // TODO: add `ulonglong4` for SM100 once `__ldg` support this + if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) { + return make_uint4(0, 0, 0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) { + return make_uint2(0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) { + return 0; + } else { + DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization"); + } + } + + using vec_t = decltype(zeros()); +}; + +} // namespace `deep_gemm` diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh new file mode 100644 index 00000000..0227b3e8 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -0,0 +1,482 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `umma_arrive()` overhead + constexpr bool kDoMergeStages = + kNumStages_ >= 8 and kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 8; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive only at the leader CTA + full_barriers[i]->init(kNumMulticast); + // Arrive at all CTAs + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + if constexpr (kTensorCoreUtilControl < 100) + tensor_core_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = (stage_idx + 1) % kNumStages; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // UMMA and empty barrier arrival alias + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + + // Let tensor cores relax for lower possibility of frequency drop + DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); + if constexpr (kTensorCoreUtilControl < 100) { + // For utilization control + umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + + // Wait for last UMMA to be done + tensor_core_full_barrier->wait(tensor_core_phase); + tensor_core_phase ^= 1; + + // Sleep for certain cycles + constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; + const auto& start_clock = clock64(); + if (cute::elect_one_sync()) + while (clock64() - start_clock < kNumDummyCycles) {} + __syncwarp(); + } + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + + // Deallocate tensor memory by the last UMMA store warp + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) + Allocator().free(0, kNumTmemCols); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh new file mode 100644 index 00000000..86303347 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -0,0 +1,265 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumThreads, 1) +sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); + DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Shared memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Fill D/A/B + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx == 0) { + // TMA load warp + for (uint32_t s = 0; s < num_total_stages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + uint32_t m_idx = BLOCK_M * m_block_idx; + uint32_t n_idx = BLOCK_N * n_block_idx; + uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + uint32_t k_idx = sk_idx % SHAPE_K; + uint32_t s_idx = sk_idx / SHAPE_K; + + // Issue TMAs + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + } + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (cute::elect_one_sync()) + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } else if (warp_idx == 1) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + constexpr uint32_t UMMA_M = LAYOUT_AD_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Wait tensor memory empty barrier arrival + tcgen05_after_thread_sync(); + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + if (warp_idx == 2) + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + if (s >= kNumTMAStoreStages) { + if (warp_idx == 0 and cute::elect_one_sync()) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = s % kNumTMAStoreStages; + const auto m_idx = m_block_idx * BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is doing TMA stores + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh new file mode 100644 index 00000000..45a603ad --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -0,0 +1,563 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4); + const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad))); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx)); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32; + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Do SF copy at certain stages + // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + __syncwarp(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc, + kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), + kTmemStartColOfSFB); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + + // Deallocate tensor memory by the last UMMA store warp + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) + Allocator().free(0, kNumTmemCols); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh new file mode 100644 index 00000000..180a308b --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -0,0 +1,404 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warp_in_group_idx = warp_idx % 4; + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + + // Align to 512 bytes for swizzle-64B + extern __shared__ __align__(512) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + + // TMA configs + constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + + // Tensor memory allocation + auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); + + // Initialize barriers + DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); + const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); + const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (is_umma_warp) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 24; + constexpr uint32_t kNumMathRegisters = 240; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + + if (is_tma_load_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue UMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } else if (warp_idx >= kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + const auto& warp_offset = warp_idx * 32; + const auto& v_offset = lane_idx; + + // Preload weights + constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); + float weights[BLOCK_Q][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); + tcgen05_after_thread_sync(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + + constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); + uint32_t shifted_accum[kNumLDTMElems]; + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); + } + + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; + } else { + logits[q_idx * stride_logits + kv_offset + v_offset] = result; + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + + // Free tensor memory + __syncthreads(); + if (is_tma_load_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh new file mode 100644 index 00000000..7058c40f --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -0,0 +1,398 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); + }); + constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); + + constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); + const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); + const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); + + // Initialize barriers + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + cutlass::arch::fence_barrier_init(); + } + if (is_umma_warp) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = kNextN * kNumHeads; + DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + + if (is_tma_load_warp) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + } + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + if (cute::elect_one_sync()) { + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 1; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + if (q_idx != next_q_idx) { + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + } + + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(umma_phase); + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + umma_phase ^= 1; + } + } else if (is_math_warp) { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + const uint32_t thread_idx = threadIdx.x; + + // Weights + constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads)); + float weights[kNextN][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 0; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait(umma_phase); + tcgen05_after_thread_sync(); + umma_phase ^= 1; + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; + uint32_t shifted_accum[kNumLDTMElems]; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); + } + + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + thread_idx] = result; + } + } + } else { + cutlass::arch::warpgroup_reg_dealloc(); + } + + // Free tensor memory + __syncthreads(); + if (is_umma_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..4e4ff21d --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,345 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__device__ __forceinline__ +uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { + // Calculate the index of the bank group to be written in the atom + const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` + // - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)` + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups); + col ^= row % (kSwizzleMode / kSwizzleBase); + + return row * 128 + col * kSwizzleBase; +} + +template +__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t kNumCastStages = 2; + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + constexpr auto kMajorA = cute::UMMA::Major::K; + constexpr auto kMajorB = cute::UMMA::Major::K; + DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages"); + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 4 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + full_cast_barriers[i]->init(kNumCastAndReduceThreads); + empty_barriers[i]->init(1); + empty_cast_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + + // Dispatch warps into different roles + if (warp_idx < kNumMMAThreads / 32) { + // TMA load warp + if (warp_idx == 0 and cute::elect_one_sync()) { + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } + + // MMA issue warp + if (warp_idx == 1) { + // Make instruction descriptor + constexpr uint32_t UMMA_M = BLOCK_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(float); + constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float); + using umma_t = cute::SM100_MMA_TF32_TS; + auto instr_desc = cute::UMMA::make_instr_desc(); + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Launch MMAs + // We can not unroll this part + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + const auto& cast_stage_idx = s % kNumCastStages; + full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; + const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; + const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); + } + + // Commit + cutlass::arch::umma_arrive(reinterpret_cast(empty_cast_barriers[cast_stage_idx])); + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + + // Commit to epilogue threads + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Only support layout F (M = 64) and D (M = 128) + DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Source and destination memory address + uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd) + // Base pointer + warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset + get_swizzled_smem_offset(i, lane_idx); // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + if constexpr (BLOCK_M == 64) + __syncwarp(); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0); + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } else { + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32; + + // TODO: make even larger block K + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + + // Launch reductions + float2 sum[2] = {float2{0, 0}, float2{0, 0}}; + #pragma unroll kNumStages + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b) + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + const auto& smem_base_ptr = reinterpret_cast(smem_a[stage_idx]) + // Base pointer + sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset + + // 4 lanes shared a bank group + uint32_t uint32_values[2][kNumLoads]; + DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads"); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; i += 2) { + auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); + sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); + } + + // Wait tensor memory empty + const auto& cast_stage_idx = s % kNumCastStages; + empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1); + + // Cast, reduce and store into tensor memory + float2 fp32x2_values[2][kNumLoads]; + const auto& upper_view = reinterpret_cast(&fp32x2_values[0]); + const auto& lower_view = reinterpret_cast(&fp32x2_values[1]); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast(&uint32_values[u][i])); + sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]); + } + + // Store upper and lower part at the same time + const auto idx_0 = i * 2, idx_1 = i * 2 + 1; + cute::SM100_TMEM_STORE_16dp256b1x::copy( + upper_view[idx_0], upper_view[idx_1], + lower_view[idx_0], lower_view[idx_1], + cast_stage_idx * BLOCK_K + i * 8); + } + cutlass::arch::fence_view_async_tmem_store(); + + // Arrive for issuing MMAs + tcgen05_before_thread_sync(); + full_cast_barriers[cast_stage_idx]->arrive(); + } + + // Intra-warp reduction and write back + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + if (lane_idx % 4 == 0 and m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum; + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh new file mode 100644 index 00000000..7a77e4e8 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -0,0 +1,381 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `warpgroup_wait<0>()` overhead + constexpr uint32_t kDoMergeStages = + kNumStages_ >= 10 and + kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and + kNumMathThreads == 128; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 5; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B/D must be aligned to 1024 bytes"); + + // D/A/B shared memory + auto smem_d = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 48; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = make_gmma_desc(smem_b[0], 0, 0); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // TODO: remove some useless computation for unaligned Ms + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = advance_gmma_desc_lo( + a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); + b_desc.reg32_[0] = advance_gmma_desc_lo( + b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); + WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + if constexpr (cute::is_same_v) { + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + } else { + // Use `st.shared` if STSM is not available + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2); + auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); + } + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + // Use TMA store to write back to global memory + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh new file mode 100644 index 00000000..191a4fe2 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + float *d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Shared memory + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); + DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); + DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + // Fill shared memory pointers + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + #pragma unroll + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); + + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t& k_idx = sk_idx % SHAPE_K; + const uint32_t& s_idx = sk_idx / SHAPE_K; + + constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16); + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + float accum[WGMMA::kNumAccum] = {0}; + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrivals + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, 1); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + empty_barriers[stage_idx]->arrive(); + } + + const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + if (col + i * 8 >= SHAPE_N) + break; + if (row < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 0) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 0], accum[i * 4 + 1])); + } + if (row + 8 < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 8) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 2], accum[i * 4 + 3])); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh new file mode 100644 index 00000000..cdd28fcb --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -0,0 +1,349 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, + int* grouped_layout, + cute::TmaDescriptor* tensor_map_buffer, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_b_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = threadIdx.x % 32; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a_base); + cute::prefetch_tma_descriptor(&tensor_map_b_base); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Tensor maps on shared and global memory + auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i); + }); + auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i)); + }); + auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; }); + auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; }); + + // Data on shared memory + auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); + }); + auto smem_sfb = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); + }); + + // Barriers on shared memory + constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); + }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); + }); + + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // Load tensormap A/B to shared memory + if constexpr (kGemmType == GemmType::KGroupedContiguous) { + *smem_tensor_map_a[0] = tensor_map_a_base; + *smem_tensor_map_a[1] = tensor_map_a_base; + *smem_tensor_map_b[0] = tensor_map_b_base; + *smem_tensor_map_b[1] = tensor_map_b_base; + } + + // Initialize barriers + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Pipeline unroll control + constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages); + + // Register reconfigurations (more math registers are needed with unrolling) + constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); + constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // TMA and MMA pipeline + const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase + }; + uint32_t iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; + const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; + uint32_t last_group_idx = kNumGroups, sum_k = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t& m_idx = m_block_idx * BLOCK_M; + const uint32_t& n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) { + const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1; + const uint32_t& next_stage_idx = stage_idx ^ 1; + last_group_idx = scheduler.current_group_idx; + + // Prepare next tensor map + sum_k += scheduler.current_shape_k; + if (scheduler.next_group_idx < kNumGroups) { + tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); + tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); + *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); + tensor_map_release_cta(); + } + + // Get current tensor map + if (scheduler.current_num_valid_groups > 0) { + tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]); + tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]); + current_tensor_map_a = gmem_tensor_map_a[stage_idx]; + current_tensor_map_b = gmem_tensor_map_b[stage_idx]; + } + } + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& k_idx = k_block_idx * BLOCK_K; + const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; + const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Accumulation for WGMMA or CUDA promotion + DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); + const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); + const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + float2 scales_b[WGMMA::kNumAccum / 4]; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait TMA arrivals + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + full_barriers[stage_idx]->wait(phase); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1); + + // Read B scales + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + + // Promote with scales + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const float &scale_b_0 = scales_b[i].x; + const float &scale_b_1 = scales_b[i].y; + final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + + // Flush previous stores + if (warp_idx % 4 == 0 and cute::elect_one_sync()) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Store to D shared memory + const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Use TMA store to write back to global memory + if (warp_idx % 4 == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy( + &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N, + current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh new file mode 100644 index 00000000..9247304c --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -0,0 +1,440 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { + if (num_former_iters == kNumFormerIters) { + func(cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + dispatch_num_former_iters(num_former_iters, func); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA A + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t k_idx = k_block_idx * BLOCK_K; + tma_copy(&tensor_map_a, &full_barrier, + smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a, batch_idx); + tma_copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), + num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, &full_barrier, + smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b, batch_idx); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = make_smem_desc(smem_b[0], 1); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; + const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; + auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; + + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb)); + } + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + + // Accumulation for WGMMA or CUDA promotion + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&]() { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void(); + } + }; + + // Skip useless computations + if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { + // The compiler must know the dynamic variable `num_former_iters`'s real value + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // Dispatch `num_former_iters` and launch MMAs + dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { + #pragma unroll 8 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Read B scales + float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; + b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; + WGMMA::wgmma(a_desc, b_desc, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(); + + // Skip promotion for the unfilled parts + if (not do_wgmma_store) + continue; + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + } + }); + } else { + #pragma unroll + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + empty_barrier_arrive(); + } + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset); + auto m_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr, + n_idx, m_idx, scheduler.current_group_idx); + } else { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh new file mode 100644 index 00000000..d58c7162 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -0,0 +1,329 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +// ReSharper disable once CppNotAllPathsReturnValue +template +static constexpr int to_swizzle_cute_type() { + DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); + if constexpr (kHeadDim == 32) + return static_cast(cute::SM90::GMMA::LayoutType::B32); + if constexpr (kHeadDim == 64) + return static_cast(cute::SM90::GMMA::LayoutType::B64); + if constexpr (kHeadDim == 128) + return static_cast(cute::SM90::GMMA::LayoutType::B128); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + + // Initialize barriers + const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 32; + constexpr uint32_t kNumMathRegisters = 112; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // Only the first warp remains + if (not is_tma_load_warp) + return; + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& thread_idx = threadIdx.x % kNumMathThreads; + const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; + + const auto& warp_offset = warp_idx * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + } else { + logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; + logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh new file mode 100644 index 00000000..482a85a8 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -0,0 +1,413 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = get_lane_idx(); + + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t q_idx = k * 32 + lane_idx; + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0); + num_segs[k] = ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t q_idx = 0; + while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) + ++ q_idx; + const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } +} + +template +struct PagedMQALogitsScheduler { + uint32_t batch_size; + const uint32_t* context_lens; + + uint32_t current_q_idx, current_kv_idx; + uint32_t end_q_idx, end_kv_idx; + uint32_t current_num_kv; + + __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { + const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; + } + + __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, + const uint32_t* context_lens, const uint32_t* schedule_meta) { + this->batch_size = batch_size; + this->context_lens = context_lens; + + const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); + const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); + current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; + + current_num_kv = get_num_kv(current_q_idx); + } + + __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_idx = current_q_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (q_idx == end_q_idx and kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumBlocksPerSplit; + if (current_kv_idx >= current_num_kv) { + ++ current_q_idx; + current_kv_idx = 0; + current_num_kv = get_num_kv(current_q_idx); + } + + return true; + } + + __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { + return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; + } +}; + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + + constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + + constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q data and barriers on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + + // Separate math warpgroups and tma load warps into KV groups + // Each math warpgroup corresponds to a tma load warp + const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + + // Per group KV data and barriers on shared memory + const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + + // Initialize barriers + if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { + if (kv_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + } + if (kv_group_idx < kNumMathWarpGroups) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(128); + } + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 64; + constexpr uint32_t kNumMathRegisters = 104; + + // Scheduler + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + if (kv_group_idx >= kNumMathWarpGroups) + return; + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (kv_group_idx == 0 and cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? + __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + } + const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; + const auto& sub_warp_offset = (warp_idx % 4) * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + + // Wait WGMMA + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + v_0_offset] = v_0; + logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + } + } + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..e3bf9847 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,287 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ __forceinline__ +uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { + constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; + + const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups); + col ^= row % kGroupsInSwizzleRange; + + return (row * kNumBankGroups + col) % kGroupsInSwizzleRange; +} + +template +__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // kSwizzleAMode and kSwizzleBMode must be 128 for now + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode"); + + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 256; + + // TMA load warp + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cutlass::arch::warpgroup_reg_dealloc(); + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + + for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + } + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + constexpr uint32_t WGMMA_M = 64; + constexpr uint32_t WGMMA_N = BLOCK_N; + constexpr uint32_t WGMMA_K = 8; + + using WGMMA = typename TF32MMASelector::type; + float accum[WGMMA::kNumAccum] = {0}; + + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + float sqr_sum_acc_0 = 0; + float sqr_sum_acc_1 = 0; + + #pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2 + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128; + constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K; + + float a[kNumRegPerWgmma * kNumWgmmaPerBlockK]; + // Assume swizzle A mode is 128 + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + + // Load BF16 A fragment from shared memory into registers, and transpose to FP32 + uint32_t row = warp_idx * 16 + lane_idx / 4; + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a + uint32_t bank_group_idx = (row ^ i) % 8; + nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + + uint32_t elem_offset = lane_idx % 4; + nv_bfloat16 a_bf16[kNumRegPerWgmma]; + a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset]; + a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4]; + a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset]; + a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4]; + + auto a_bf16x2_ptr = reinterpret_cast(a_bf16); + auto a_float2_ptr = reinterpret_cast(a); + float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]); + float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]); + a_float2_ptr[i * 2 + 0] = a_float2_0; + a_float2_ptr[i * 2 + 1] = a_float2_1; + sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x; + sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; + } + + warpgroup_wait<0>(); + if (s > 0) + empty_barriers[(s - 1) % kNumStages]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + + constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); + constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; + DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K"); + + #pragma unroll + for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { + #pragma unroll + for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { + auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + } + + const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + + const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); + if (lane_idx % 4 == 0) { + if (m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum_0; + if (m_idx + 8 < shape_m) + sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; + } + warpgroup_wait<0>(); + empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); + + // Write accum to shared memory + // Every 2 threads (one pair) will write to the same bank group (16 bytes). + // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d + uint32_t is_odd_pair = lane_idx / 2 % 2; + + // Four threads per group; write the data to the same row. + uint32_t row_idx = lane_idx / 4; + + // Even/odd index pairs write to the same column, we need to reorder idx: + // group even pair indices consecutively, and likewise for odd ones. + uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx; + + auto shifted_smem_ptr = reinterpret_cast(smem_cd) + + (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows + lane_idx % 2 * 8; // One thread of a pair writes 8 bytes + + #pragma unroll + for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) { + // Get the swizzled bank group index (16 bytes per group) + uint32_t bank_group_idx = get_swizzled_bank_group_idx(i + is_odd_pair, reordered_pair_idx); + auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group + + // 0/1 write to the same row, 2/3 write to another row + auto values = reinterpret_cast(accum + i * 2); + st_shared(smem_ptr, values[0], values[1]); + st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, 1); + + // Issue TMA stores + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh new file mode 100644 index 00000000..cc9e5e6b --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(kNumWarps * 32, 1) +void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { + const uint32_t& num_sms = gridDim.x; + const uint32_t& sm_idx = blockIdx.x; + const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + constexpr float neg_inf = -cute::numeric_limits::infinity(); + + // Allocate filled `-inf` shared memory + extern __shared__ __align__(1024) float smem_buffer[]; + #pragma unroll + for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) + smem_buffer[i] = neg_inf; + cute::tma_store_fence(); + __syncthreads(); + + // Assign sequence to each warp + const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto& per = total / num, rem = total % num; + return {start + idx * per + min(idx, rem), per + (idx < rem)}; + }; + CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); + CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + + if (cute::elect_one_sync()) { + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + + for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { + const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); + if (right <= ks or ke <= left) { + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); + } else { + if (left < aligned_ks) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); + if (aligned_ke < right) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); + } + } + } + } + + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + for (uint32_t j = aligned_ks; j < ks; ++ j) + logits[i * stride_logits + j] = neg_inf; + for (uint32_t j = ke; j < aligned_ke; ++ j) + logits[i * stride_logits + j] = neg_inf; + } +} + +} diff --git a/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh new file mode 100644 index 00000000..bea70002 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -0,0 +1,176 @@ +#pragma once + +#include + +namespace deep_gemm { + +template +__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename Vectorized::vec_t in_vec_t; + constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); + constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; + + // Shapes and strides + extern __shared__ float smem_buffer[]; + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the block + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; + const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + + // Load + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { + auto in_vec = __ldg(local_sf + i); + const auto& in_values = reinterpret_cast(&in_vec); + + const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerVec; ++ j) + smem_buffer[row * PADDED_SF_K + col + j] = in_values[j]; + } + __syncthreads(); + + // Store + #pragma unroll + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { + const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; + const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + } +} + +// NOTES: the two kernels below always pack the K dimension + +template +__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { + extern __shared__ uint32_t smem_buffer[]; + + // Shapes and strides + constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the group + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + + // Load FP32 SFs + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); + const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + const auto num_values = in_block_mn * SF_K; + const auto num_uint4 = num_values / 4; + #pragma unroll + for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { + const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); + st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + } + + // Fill unaligned values as well + if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) + st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + __syncthreads(); + + // Pack into UE8M0 and store + #pragma unroll + for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) { + const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN; + + // Load shared memory + uint32_t values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + const auto sf_k_idx = sf_k_pack_idx * 4 + j; + values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + } + + // Pack and store + uint32_t packed = 0; + packed |= (values[0] >> 23u); + packed |= (values[1] >> 15u); + packed |= (values[2] >> 7u); + packed |= (values[3] << 1u); + if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn) + out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed; + } +} + +template +__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { + // Always packing the K dimension + // NOTES: should also assert `mn % 4 == 0` at launch + DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes"); + + // Shapes and strides + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto in_block_mn_uint4 = in_block_mn / 4; + const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K); + + // Shift into the right block along MN + sf += blockIdx.x * BLOCK_MN; + out += blockIdx.x * BLOCK_MN; + + // Each warp is responsible for a packed row + const auto warp_idx = threadIdx.x / 32; + const auto lane_idx = get_lane_idx(); + const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; + if (warp_idx >= in_block_packed_sf_k) + return; + + // Make an offset on the input + uint32_t input_offset = 0; + if constexpr (kNumGroups > 1) { + // Load each group's size + DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups"); + uint32_t group_ks[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) { + const auto group_idx = lane_idx * 4 + i; + group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0; + } + __syncwarp(); + + // Make the offset + sf_k = 0; + auto sum_packed_sf_k = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumGroups; ++ i) { + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + sf_k += sf_k_in_group; + sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + if (packed_sf_k_idx < sum_packed_sf_k) + break; + if (const auto remainder = sf_k_in_group % 4; remainder > 0) + input_offset += 4 - remainder; + } + } + + for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + // Load + uint4 values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + values[j] = make_uint4(0, 0, 0, 0); + if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) + values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + } + + // Pack and store + uint4 packed; + packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u); + packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u); + packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u); + packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u); + reinterpret_cast(out + packed_sf_k_idx * mn)[mn_idx] = packed; + } +} + +} // namespace deep_gemm diff --git a/deep-gemm/torch-ext/deep_gemm/include/third-party/cutlass b/deep-gemm/torch-ext/deep_gemm/include/third-party/cutlass new file mode 160000 index 00000000..f3fde583 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/include/third-party/cutlass @@ -0,0 +1 @@ +Subproject commit f3fde58372d33e9a5650ba7b80fc48b3b49d40c8 diff --git a/deep-gemm/torch-ext/deep_gemm/testing/__init__.py b/deep-gemm/torch-ext/deep_gemm/testing/__init__.py new file mode 100644 index 00000000..13a9d78d --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/testing/__init__.py @@ -0,0 +1,4 @@ +from . import bench, numeric, utils +from .bench import * +from .numeric import * +from .utils import * diff --git a/deep-gemm/torch-ext/deep_gemm/testing/bench.py b/deep-gemm/torch-ext/deep_gemm/testing/bench.py new file mode 100644 index 00000000..2c752da2 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/testing/bench.py @@ -0,0 +1,137 @@ +import os +import sys +import torch + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests / 1e3 + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False): + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + + # Skip profiling + # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer + if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): + return (1, ) * len(kernel_names) if is_tuple else 1 + + # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + fn() + profiler.step() + + # Parse the profiling table + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num if total_num > 0 else 0) + + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/deep-gemm/torch-ext/deep_gemm/testing/numeric.py b/deep-gemm/torch-ext/deep_gemm/testing/numeric.py new file mode 100644 index 00000000..a42c4318 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/testing/numeric.py @@ -0,0 +1,21 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/deep-gemm/torch-ext/deep_gemm/testing/utils.py b/deep-gemm/torch-ext/deep_gemm/testing/utils.py new file mode 100644 index 00000000..2d202d41 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/testing/utils.py @@ -0,0 +1,38 @@ +import functools +import os +import torch +from typing import Callable + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def test_filter(condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + func(*args, **kwargs) + else: + print(f'{func.__name__}:') + print(f' > Filtered by {condition}') + print() + return wrapper + return decorator + + +def ignore_env(name: str, condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + saved = os.environ.pop(name, None) + func(*args, **kwargs) + if saved is not None: + os.environ[name] = saved + else: + func(*args, **kwargs) + + return wrapper + return decorator diff --git a/deep-gemm/torch-ext/deep_gemm/utils/__init__.py b/deep-gemm/torch-ext/deep_gemm/utils/__init__.py new file mode 100644 index 00000000..e8f859a2 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/utils/__init__.py @@ -0,0 +1,3 @@ +from . import math, layout +from .layout import * +from .math import * diff --git a/deep-gemm/torch-ext/deep_gemm/utils/layout.py b/deep-gemm/torch-ext/deep_gemm/utils/layout.py new file mode 100644 index 00000000..a6bc29d9 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/utils/layout.py @@ -0,0 +1,25 @@ +from .._ops import ops + + +def get_mk_alignment_for_contiguous_layout(): + return ops.get_mk_alignment_for_contiguous_layout() + + +def get_tma_aligned_size(mn: int, element_size: int): + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks) + + +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/deep-gemm/torch-ext/deep_gemm/utils/math.py b/deep-gemm/torch-ext/deep_gemm/utils/math.py new file mode 100644 index 00000000..c65026e5 --- /dev/null +++ b/deep-gemm/torch-ext/deep_gemm/utils/math.py @@ -0,0 +1,107 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + padded_n = align(n, gran_k) + x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 + m, n = x.shape + x_view = x.view(-1, gran_k, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code # uint8, 0..15 + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + assert n % 2 == 0 + padded_n = align(n, gran_k) + x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes2 = codes.view(m, padded_n // 2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 + return packed[:, :n // 2].contiguous(), sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.uint8 + assert a.dim() == 2 + m, n2 = a.shape + n = n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() \ No newline at end of file diff --git a/deep-gemm/torch-ext/torch_binding.cpp b/deep-gemm/torch-ext/torch_binding.cpp new file mode 100644 index 00000000..e2ccb749 --- /dev/null +++ b/deep-gemm/torch-ext/torch_binding.cpp @@ -0,0 +1,313 @@ +#include +#include + +#include "registration.h" +#include "torch_binding.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // Runtime ops (no device dispatch - these are host-side) + ops.def("init(str path, str cuda_home) -> ()"); + ops.impl("init", &deep_gemm_init); + + ops.def("set_num_sms(int num_sms) -> ()"); + ops.impl("set_num_sms", &deep_gemm_set_num_sms); + + ops.def("get_num_sms() -> int"); + ops.impl("get_num_sms", &deep_gemm_get_num_sms); + + ops.def("set_tc_util(int tc_util) -> ()"); + ops.impl("set_tc_util", &deep_gemm_set_tc_util); + + ops.def("get_tc_util() -> int"); + ops.impl("get_tc_util", &deep_gemm_get_tc_util); + + ops.def("get_mk_alignment_for_contiguous_layout() -> int"); + ops.impl("get_mk_alignment_for_contiguous_layout", + &deep_gemm_get_mk_alignment_for_contiguous_layout); + + // Layout ops (CUDA dispatch) + ops.def( + "get_tma_aligned_size(int mn, int element_size) -> Tensor" + ); + ops.impl("get_tma_aligned_size", torch::kCUDA, &deep_gemm_get_tma_aligned_size); + + ops.def( + "get_mn_major_tma_aligned_tensor(Tensor sf) -> Tensor" + ); + ops.impl("get_mn_major_tma_aligned_tensor", torch::kCUDA, + &deep_gemm_get_mn_major_tma_aligned_tensor); + + ops.def( + "get_mn_major_tma_aligned_packed_ue8m0_tensor(Tensor sf) -> Tensor" + ); + ops.impl("get_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, + &deep_gemm_get_mn_major_tma_aligned_packed_ue8m0_tensor); + + ops.def( + "get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(" + "Tensor sf, Tensor ks_tensor, Tensor ks_int_tensor) -> Tensor" + ); + ops.impl("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", torch::kCUDA, + &deep_gemm_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); + + ops.def( + "transform_sf_into_required_layout(" + "Tensor sf, int mn, int k, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "int recipe_ab_0, int recipe_ab_1, bool has_recipe_ab, " + "int num_groups, bool has_num_groups, " + "bool is_sfa, bool disable_ue8m0_cast) -> Tensor" + ); + ops.impl("transform_sf_into_required_layout", torch::kCUDA, + &deep_gemm_transform_sf_into_required_layout); + + // FP8/FP4 GEMM ops (CUDA dispatch) + ops.def( + "fp8_fp4_gemm_nt(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor? c, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "int recipe_a_0, int recipe_a_1, bool has_recipe_a, " + "int recipe_b_0, int recipe_b_1, bool has_recipe_b, " + "str compiled_dims, bool disable_ue8m0_cast) -> ()" + ); + ops.impl("fp8_fp4_gemm_nt", torch::kCUDA, &deep_gemm_fp8_fp4_gemm_nt); + + ops.def( + "fp8_fp4_gemm_nn(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor? c, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "int recipe_a_0, int recipe_a_1, bool has_recipe_a, " + "int recipe_b_0, int recipe_b_1, bool has_recipe_b, " + "str compiled_dims, bool disable_ue8m0_cast) -> ()" + ); + ops.impl("fp8_fp4_gemm_nn", torch::kCUDA, &deep_gemm_fp8_fp4_gemm_nn); + + ops.def( + "fp8_fp4_gemm_tn(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor? c, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "int recipe_a_0, int recipe_a_1, bool has_recipe_a, " + "int recipe_b_0, int recipe_b_1, bool has_recipe_b, " + "str compiled_dims, bool disable_ue8m0_cast) -> ()" + ); + ops.impl("fp8_fp4_gemm_tn", torch::kCUDA, &deep_gemm_fp8_fp4_gemm_tn); + + ops.def( + "fp8_fp4_gemm_tt(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor? c, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "int recipe_a_0, int recipe_a_1, bool has_recipe_a, " + "int recipe_b_0, int recipe_b_1, bool has_recipe_b, " + "str compiled_dims, bool disable_ue8m0_cast) -> ()" + ); + ops.impl("fp8_fp4_gemm_tt", torch::kCUDA, &deep_gemm_fp8_fp4_gemm_tt); + + // M-grouped FP8/FP4 GEMM ops (CUDA dispatch) + ops.def( + "m_grouped_fp8_fp4_gemm_nt_contiguous(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor grouped_layout, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "int recipe_a_0, int recipe_a_1, bool has_recipe_a, " + "int recipe_b_0, int recipe_b_1, bool has_recipe_b, " + "str compiled_dims, bool disable_ue8m0_cast, " + "bool use_psum_layout, int expected_m_for_psum_layout, " + "bool has_expected_m_for_psum_layout) -> ()" + ); + ops.impl("m_grouped_fp8_fp4_gemm_nt_contiguous", torch::kCUDA, + &deep_gemm_m_grouped_fp8_fp4_gemm_nt_contiguous); + + ops.def( + "m_grouped_fp8_fp4_gemm_nn_contiguous(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor grouped_layout, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "int recipe_a_0, int recipe_a_1, bool has_recipe_a, " + "int recipe_b_0, int recipe_b_1, bool has_recipe_b, " + "str compiled_dims, bool disable_ue8m0_cast, " + "bool use_psum_layout) -> ()" + ); + ops.impl("m_grouped_fp8_fp4_gemm_nn_contiguous", torch::kCUDA, + &deep_gemm_m_grouped_fp8_fp4_gemm_nn_contiguous); + + ops.def( + "m_grouped_fp8_fp4_gemm_nt_masked(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor masked_m, int expected_m, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "int recipe_a_0, int recipe_a_1, bool has_recipe_a, " + "int recipe_b_0, int recipe_b_1, bool has_recipe_b, " + "str compiled_dims, bool disable_ue8m0_cast) -> ()" + ); + ops.impl("m_grouped_fp8_fp4_gemm_nt_masked", torch::kCUDA, + &deep_gemm_m_grouped_fp8_fp4_gemm_nt_masked); + + // K-grouped FP8 GEMM ops (CUDA dispatch) + ops.def( + "k_grouped_fp8_gemm_tn_contiguous(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor ks_tensor, Tensor? c, " + "int recipe_0, int recipe_1, int recipe_2, " + "str compiled_dims) -> ()" + ); + ops.impl("k_grouped_fp8_gemm_tn_contiguous", torch::kCUDA, + &deep_gemm_k_grouped_fp8_gemm_tn_contiguous); + + ops.def( + "k_grouped_fp8_gemm_nt_contiguous(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor ks_tensor, Tensor? c, " + "int recipe_0, int recipe_1, int recipe_2, " + "str compiled_dims) -> ()" + ); + ops.impl("k_grouped_fp8_gemm_nt_contiguous", torch::kCUDA, + &deep_gemm_k_grouped_fp8_gemm_nt_contiguous); + + // BF16 GEMM ops (CUDA dispatch) + ops.def( + "bf16_gemm_nt(Tensor a, Tensor b, Tensor d, Tensor? c, " + "str compiled_dims) -> ()" + ); + ops.impl("bf16_gemm_nt", torch::kCUDA, &deep_gemm_bf16_gemm_nt); + + ops.def( + "bf16_gemm_nn(Tensor a, Tensor b, Tensor d, Tensor? c, " + "str compiled_dims) -> ()" + ); + ops.impl("bf16_gemm_nn", torch::kCUDA, &deep_gemm_bf16_gemm_nn); + + ops.def( + "bf16_gemm_tn(Tensor a, Tensor b, Tensor d, Tensor? c, " + "str compiled_dims) -> ()" + ); + ops.impl("bf16_gemm_tn", torch::kCUDA, &deep_gemm_bf16_gemm_tn); + + ops.def( + "bf16_gemm_tt(Tensor a, Tensor b, Tensor d, Tensor? c, " + "str compiled_dims) -> ()" + ); + ops.impl("bf16_gemm_tt", torch::kCUDA, &deep_gemm_bf16_gemm_tt); + + // M-grouped BF16 GEMM ops (CUDA dispatch) + ops.def( + "m_grouped_bf16_gemm_nt_contiguous(" + "Tensor a, Tensor b, Tensor d, Tensor grouped_layout, " + "str compiled_dims, bool use_psum_layout, " + "int expected_m_for_psum_layout, bool has_expected_m_for_psum_layout) -> ()" + ); + ops.impl("m_grouped_bf16_gemm_nt_contiguous", torch::kCUDA, + &deep_gemm_m_grouped_bf16_gemm_nt_contiguous); + + ops.def( + "m_grouped_bf16_gemm_nn_contiguous(" + "Tensor a, Tensor b, Tensor d, Tensor grouped_layout, " + "str compiled_dims, bool use_psum_layout) -> ()" + ); + ops.impl("m_grouped_bf16_gemm_nn_contiguous", torch::kCUDA, + &deep_gemm_m_grouped_bf16_gemm_nn_contiguous); + + ops.def( + "m_grouped_bf16_gemm_nt_masked(" + "Tensor a, Tensor b, Tensor d, Tensor masked_m, " + "int expected_m, str compiled_dims) -> ()" + ); + ops.impl("m_grouped_bf16_gemm_nt_masked", torch::kCUDA, + &deep_gemm_m_grouped_bf16_gemm_nt_masked); + + // K-grouped BF16 GEMM ops (CUDA dispatch) + ops.def( + "k_grouped_bf16_gemm_tn_contiguous(" + "Tensor a, Tensor b, Tensor d, Tensor ks_tensor, Tensor? c, " + "str compiled_dims) -> ()" + ); + ops.impl("k_grouped_bf16_gemm_tn_contiguous", torch::kCUDA, + &deep_gemm_k_grouped_bf16_gemm_tn_contiguous); + + // cuBLASLt GEMM ops (CUDA dispatch) + ops.def( + "cublaslt_gemm_nt(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()" + ); + ops.impl("cublaslt_gemm_nt", torch::kCUDA, &deep_gemm_cublaslt_gemm_nt); + + ops.def( + "cublaslt_gemm_nn(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()" + ); + ops.impl("cublaslt_gemm_nn", torch::kCUDA, &deep_gemm_cublaslt_gemm_nn); + + ops.def( + "cublaslt_gemm_tn(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()" + ); + ops.impl("cublaslt_gemm_tn", torch::kCUDA, &deep_gemm_cublaslt_gemm_tn); + + ops.def( + "cublaslt_gemm_tt(Tensor a, Tensor b, Tensor d, Tensor? c) -> ()" + ); + ops.impl("cublaslt_gemm_tt", torch::kCUDA, &deep_gemm_cublaslt_gemm_tt); + + // Attention ops (CUDA dispatch) + ops.def( + "fp8_gemm_nt_skip_head_mid(" + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, " + "int head_split_left, int head_split_mid, int head_split_right, " + "int recipe_0, int recipe_1, int recipe_2, bool has_recipe, " + "str compiled_dims, bool disable_ue8m0_cast) -> ()" + ); + ops.impl("fp8_gemm_nt_skip_head_mid", torch::kCUDA, + &deep_gemm_fp8_gemm_nt_skip_head_mid); + + ops.def( + "fp8_mqa_logits(" + "Tensor q, Tensor kv_data, Tensor kv_sf, " + "Tensor weights, Tensor cu_seq_len_k_start, Tensor cu_seq_len_k_end, " + "bool clean_logits, int max_seqlen_k) -> Tensor" + ); + ops.impl("fp8_mqa_logits", torch::kCUDA, &deep_gemm_fp8_mqa_logits); + + ops.def( + "get_paged_mqa_logits_metadata(" + "Tensor context_lens, int block_kv, int num_sms) -> Tensor" + ); + ops.impl("get_paged_mqa_logits_metadata", torch::kCUDA, + &deep_gemm_get_paged_mqa_logits_metadata); + + ops.def( + "fp8_paged_mqa_logits(" + "Tensor q, Tensor fused_kv_cache, " + "Tensor weights, Tensor context_lens, " + "Tensor block_table, Tensor schedule_meta, " + "int max_context_len, bool clean_logits) -> Tensor" + ); + ops.impl("fp8_paged_mqa_logits", torch::kCUDA, + &deep_gemm_fp8_paged_mqa_logits); + + // Einsum ops (CUDA dispatch) + ops.def( + "einsum(str expr, Tensor a, Tensor b, Tensor d, " + "Tensor? c, bool use_cublaslt) -> ()" + ); + ops.impl("einsum", torch::kCUDA, &deep_gemm_einsum); + + ops.def( + "fp8_einsum(str expr, " + "Tensor a_data, Tensor a_sf, Tensor b_data, Tensor b_sf, " + "Tensor d, Tensor? c, " + "int recipe_0, int recipe_1, int recipe_2) -> ()" + ); + ops.impl("fp8_einsum", torch::kCUDA, &deep_gemm_fp8_einsum); + + // Hyperconnection ops (CUDA dispatch) + ops.def( + "tf32_hc_prenorm_gemm(" + "Tensor a, Tensor b, Tensor d, Tensor sqr_sum, " + "int num_splits, bool has_num_splits) -> ()" + ); + ops.impl("tf32_hc_prenorm_gemm", torch::kCUDA, + &deep_gemm_tf32_hc_prenorm_gemm); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/deep-gemm/torch-ext/torch_binding.h b/deep-gemm/torch-ext/torch_binding.h new file mode 100644 index 00000000..82bc9012 --- /dev/null +++ b/deep-gemm/torch-ext/torch_binding.h @@ -0,0 +1,256 @@ +#pragma once + +#include +#include +#include + +using Tensor = at::Tensor; + +// ============================================================================ +// Runtime ops +// ============================================================================ + +void deep_gemm_init(const std::string& path, const std::string& cuda_home); + +void deep_gemm_set_num_sms(int64_t num_sms); +int64_t deep_gemm_get_num_sms(); + +void deep_gemm_set_tc_util(int64_t tc_util); +int64_t deep_gemm_get_tc_util(); + +// ============================================================================ +// Layout ops +// ============================================================================ + +int64_t deep_gemm_get_mk_alignment_for_contiguous_layout(); + +Tensor deep_gemm_get_tma_aligned_size(int64_t mn, int64_t element_size); + +Tensor deep_gemm_get_mn_major_tma_aligned_tensor(const Tensor& sf); + +Tensor deep_gemm_get_mn_major_tma_aligned_packed_ue8m0_tensor(const Tensor& sf); + +Tensor deep_gemm_get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + const Tensor& sf, const Tensor& ks_tensor, const Tensor& ks_int_tensor); + +Tensor deep_gemm_transform_sf_into_required_layout( + const Tensor& sf, int64_t mn, int64_t k, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_ab_0, int64_t recipe_ab_1, bool has_recipe_ab, + int64_t num_groups, bool has_num_groups, + bool is_sfa, bool disable_ue8m0_cast); + +// ============================================================================ +// GEMM ops - FP8/FP4 +// ============================================================================ + +void deep_gemm_fp8_fp4_gemm_nt( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast); + +void deep_gemm_fp8_fp4_gemm_nn( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast); + +void deep_gemm_fp8_fp4_gemm_tn( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast); + +void deep_gemm_fp8_fp4_gemm_tt( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast); + +// ============================================================================ +// GEMM ops - M-grouped FP8/FP4 +// ============================================================================ + +void deep_gemm_m_grouped_fp8_fp4_gemm_nt_contiguous( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& grouped_layout, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast, + bool use_psum_layout, int64_t expected_m_for_psum_layout, + bool has_expected_m_for_psum_layout); + +void deep_gemm_m_grouped_fp8_fp4_gemm_nn_contiguous( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& grouped_layout, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast, + bool use_psum_layout); + +void deep_gemm_m_grouped_fp8_fp4_gemm_nt_masked( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& masked_m, int64_t expected_m, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + int64_t recipe_a_0, int64_t recipe_a_1, bool has_recipe_a, + int64_t recipe_b_0, int64_t recipe_b_1, bool has_recipe_b, + const std::string& compiled_dims, bool disable_ue8m0_cast); + +// ============================================================================ +// GEMM ops - K-grouped FP8 +// ============================================================================ + +void deep_gemm_k_grouped_fp8_gemm_tn_contiguous( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& ks_tensor, + const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, + const std::string& compiled_dims); + +void deep_gemm_k_grouped_fp8_gemm_nt_contiguous( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const Tensor& ks_tensor, + const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, + const std::string& compiled_dims); + +// ============================================================================ +// GEMM ops - BF16 +// ============================================================================ + +void deep_gemm_bf16_gemm_nt( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, const std::string& compiled_dims); + +void deep_gemm_bf16_gemm_nn( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, const std::string& compiled_dims); + +void deep_gemm_bf16_gemm_tn( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, const std::string& compiled_dims); + +void deep_gemm_bf16_gemm_tt( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, const std::string& compiled_dims); + +// ============================================================================ +// GEMM ops - M-grouped BF16 +// ============================================================================ + +void deep_gemm_m_grouped_bf16_gemm_nt_contiguous( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& grouped_layout, const std::string& compiled_dims, + bool use_psum_layout, int64_t expected_m_for_psum_layout, + bool has_expected_m_for_psum_layout); + +void deep_gemm_m_grouped_bf16_gemm_nn_contiguous( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& grouped_layout, const std::string& compiled_dims, + bool use_psum_layout); + +void deep_gemm_m_grouped_bf16_gemm_nt_masked( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& masked_m, int64_t expected_m, + const std::string& compiled_dims); + +// ============================================================================ +// GEMM ops - K-grouped BF16 +// ============================================================================ + +void deep_gemm_k_grouped_bf16_gemm_tn_contiguous( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& ks_tensor, const std::optional& c, + const std::string& compiled_dims); + +// ============================================================================ +// GEMM ops - cuBLASLt +// ============================================================================ + +void deep_gemm_cublaslt_gemm_nt( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c); + +void deep_gemm_cublaslt_gemm_nn( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c); + +void deep_gemm_cublaslt_gemm_tn( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c); + +void deep_gemm_cublaslt_gemm_tt( + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c); + +// ============================================================================ +// Attention ops +// ============================================================================ + +void deep_gemm_fp8_gemm_nt_skip_head_mid( + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, + int64_t head_split_left, int64_t head_split_mid, int64_t head_split_right, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2, bool has_recipe, + const std::string& compiled_dims, bool disable_ue8m0_cast); + +Tensor deep_gemm_fp8_mqa_logits( + const Tensor& q, + const Tensor& kv_data, const Tensor& kv_sf, + const Tensor& weights, + const Tensor& cu_seq_len_k_start, const Tensor& cu_seq_len_k_end, + bool clean_logits, int64_t max_seqlen_k); + +Tensor deep_gemm_get_paged_mqa_logits_metadata( + const Tensor& context_lens, int64_t block_kv, int64_t num_sms); + +Tensor deep_gemm_fp8_paged_mqa_logits( + const Tensor& q, const Tensor& fused_kv_cache, + const Tensor& weights, const Tensor& context_lens, + const Tensor& block_table, const Tensor& schedule_meta, + int64_t max_context_len, bool clean_logits); + +// ============================================================================ +// Einsum ops +// ============================================================================ + +void deep_gemm_einsum( + const std::string& expr, + const Tensor& a, const Tensor& b, const Tensor& d, + const std::optional& c, bool use_cublaslt); + +void deep_gemm_fp8_einsum( + const std::string& expr, + const Tensor& a_data, const Tensor& a_sf, + const Tensor& b_data, const Tensor& b_sf, + const Tensor& d, const std::optional& c, + int64_t recipe_0, int64_t recipe_1, int64_t recipe_2); + +// ============================================================================ +// Hyperconnection ops +// ============================================================================ + +void deep_gemm_tf32_hc_prenorm_gemm( + const Tensor& a, const Tensor& b, const Tensor& d, + const Tensor& sqr_sum, int64_t num_splits, bool has_num_splits); diff --git a/scripts/check_kernel_freshness.py b/scripts/check_kernel_freshness.py index a6fbb669..a2cc7b03 100644 --- a/scripts/check_kernel_freshness.py +++ b/scripts/check_kernel_freshness.py @@ -45,6 +45,7 @@ "mlx-quantization-metal-kernels": "https://github.com/ml-explore/mlx", "mlx-rmsnorm": "https://github.com/ml-explore/mlx", "sage-attention": "https://github.com/thu-ml/SageAttention", + "deep-gemm": "https://github.com/deepseek-ai/DeepGEMM", "bitsandbytes-mps": "", }