diff --git a/CMakeLists.txt b/CMakeLists.txt index e780c1565..5eefd68b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,6 +21,8 @@ set(version 1.1.0) # Check support for CUDA/HIP in Cmake project(composable_kernel VERSION ${version}) +find_package(Python3 3.7 COMPONENTS Interpreter REQUIRED) + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") if (DTYPES) diff --git a/example/91_tile_program/CMakeLists.txt b/example/91_tile_program/CMakeLists.txt index 27d3c67ad..123eabe3e 100644 --- a/example/91_tile_program/CMakeLists.txt +++ b/example/91_tile_program/CMakeLists.txt @@ -1,8 +1,12 @@ -add_example_executable(example_im2col im2col.cpp) -add_example_executable(example_gemm gemm.cpp) -add_example_executable(example_gemm_gemm gemm_gemm.cpp) -add_example_executable(example_reduce reduce.cpp) -add_example_executable(example_softmax softmax.cpp) -add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp) -add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp) -add_example_executable(example_fmha_fwd fmha_fwd.cpp) +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(batched_gemm_softmax_gemm) +add_subdirectory(fmha) +add_subdirectory(gemm) +add_subdirectory(gemm_gemm) +add_subdirectory(gemm_softmax_gemm) +add_subdirectory(im2col) +add_subdirectory(reduce) +add_subdirectory(softmax) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm/CMakeLists.txt b/example/91_tile_program/batched_gemm_softmax_gemm/CMakeLists.txt new file mode 100644 index 000000000..69fae2e10 --- /dev/null +++ b/example/91_tile_program/batched_gemm_softmax_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.cpp b/example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.cpp similarity index 98% rename from example/91_tile_program/batched_gemm_softmax_gemm.cpp rename to example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.cpp index f785ffcf9..8c6cbbf45 100644 --- a/example/91_tile_program/batched_gemm_softmax_gemm.cpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.cpp @@ -13,8 +13,8 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_batched_gemm.hpp" -#include "reference_batched_softmax.hpp" +#include "reference/reference_batched_gemm.hpp" +#include "reference/reference_batched_softmax.hpp" #include "batched_gemm_softmax_gemm.hpp" int main(int argc, char* argv[]) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.hpp b/example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.hpp similarity index 98% rename from example/91_tile_program/batched_gemm_softmax_gemm.hpp rename to example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.hpp index f396222fe..179440a89 100644 --- a/example/91_tile_program/batched_gemm_softmax_gemm.hpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.hpp @@ -17,7 +17,7 @@ #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck/tile_program/block_tile/block_reduce.hpp" -#include "gemm_softmax_gemm_impl.hpp" +#include "gemm_softmax_gemm/gemm_softmax_gemm_impl.hpp" // S[M0, N0] = Q[M0, K0] * K[N0, K0] // P[M0, N0] = Softmax(S[M0, N0]) diff --git a/example/91_tile_program/common/arg_parser.hpp b/example/91_tile_program/common/arg_parser.hpp new file mode 100644 index 000000000..58155d078 --- /dev/null +++ b/example/91_tile_program/common/arg_parser.hpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +/* + * arg parser for + * -[key0]=[value0] -[key1]=[value1] ... + */ +class ArgParser +{ + public: + class Arg + { + public: + std::string name; + std::string value; + std::string help_text; + }; + + ArgParser() {} + ArgParser& insert(const std::string& _name, + const std::string& _default_value, + const std::string& _help_text) + { + Arg in; + in.name = _name; + in.value = _default_value; + in.help_text = _help_text; + + if(input_map.count(_name) != 0) + { + printf("arg:%s already exist\n", _name.c_str()); + } + else + { + input_map[_name] = in; + keys.push_back(_name); + } + return *this; + } + void print() + { + printf("args:\n"); + for(auto& key : keys) + { + auto value = input_map[key]; + std::vector help_text_lines; + size_t pos = 0; + for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) + { + help_text_lines.push_back(std::string(value.help_text.begin() + pos, + value.help_text.begin() + next_pos++)); + pos = next_pos; + next_pos = value.help_text.find('\n', pos); + } + help_text_lines.push_back( + std::string(value.help_text.begin() + pos, value.help_text.end())); + + std::string default_value = std::string("(default:") + value.value + std::string(")"); + + std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key + << std::setw(4) << " " << help_text_lines[0] << " " << default_value + << std::endl; + + for(auto help_next_line = std::next(help_text_lines.begin()); + help_next_line != help_text_lines.end(); + ++help_next_line) + { + std::cout << std::setw(17) << " " << *help_next_line << std::endl; + } + } + } + bool parse(int argc, char* argv[], int start_index = 1) + { + if(argc < start_index) + { + printf("not enough args\n"); + return false; + } + for(int i = start_index; i < argc; i++) + { + char* cur_arg = argv[i]; + if(cur_arg[0] != '-') + { + printf("illegal input\n"); + print(); + return false; + } + else + { + std::string text(cur_arg + 1); + if(text == "?") + { + print(); + return false; + } + auto pos = text.find('='); + if(pos == std::string::npos) + { + printf("arg should be [key]=[value] pair, here:%s\n", text.c_str()); + return false; + } + if(pos >= (text.size() - 1)) + { + printf("cant find value after \"=\", here:%s\n", text.c_str()); + return false; + } + auto key = text.substr(0, pos); + auto value = text.substr(pos + 1); + if(input_map.count(key) == 0) + { + printf("no such arg:%s\n", key.c_str()); + return false; + } + input_map[key].value = value; + } + } + return true; + } + + std::string get_str(const std::string& name) const + { + std::string value = input_map.at(name).value; + return value; + } + + int get_int(const std::string& name) const + { + int value = atoi(input_map.at(name).value.c_str()); + return value; + } + + uint32_t get_uint32(const std::string& name) const + { + uint32_t value = strtoul(input_map.at(name).value.c_str(), nullptr, 10); + return value; + } + + uint64_t get_uint64(const std::string& name) const + { + uint64_t value = strtoull(input_map.at(name).value.c_str(), nullptr, 10); + return value; + } + + bool get_bool(const std::string& name) const + { + auto v = input_map.at(name).value; + if(v.compare("t") == 0 || v.compare("true") == 0) + return true; + if(v.compare("f") == 0 || v.compare("false") == 0) + return false; + int value = atoi(v.c_str()); + return value == 0 ? false : true; + } + + float get_float(const std::string& name) const + { + double value = atof(input_map.at(name).value.c_str()); + return static_cast(value); + } + + double get_double(const std::string& name) const + { + double value = atof(input_map.at(name).value.c_str()); + return value; + } + + private: + std::unordered_map input_map; + std::vector keys; +}; diff --git a/example/91_tile_program/fmha/CMakeLists.txt b/example/91_tile_program/fmha/CMakeLists.txt new file mode 100644 index 000000000..a2255a025 --- /dev/null +++ b/example/91_tile_program/fmha/CMakeLists.txt @@ -0,0 +1,41 @@ +# generate a list of kernels, but not actually emit files at config stage +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt +) + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory +# as current cmake list, otherwise will not figure out the dependency properly +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${FMHA_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +set(EXAMPLE_FMHA_FWD "example_fmha_fwd") +add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) + +# NOTE: this is dangerous since will change the whole kernel to flush denormals +# WIP with compiler team for an exp2 intrinsic..., then remove this +if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 true) +endif() + +set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +# ... because they are auto-generated +if(FMHA_FWD_FAST_EXP2) +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) +else() +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=0) +endif() + +# Allow comparing floating points directly in order to check sentinel values +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) + +target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) diff --git a/example/91_tile_program/fmha/README.md b/example/91_tile_program/fmha/README.md new file mode 100644 index 000000000..8b0d2521a --- /dev/null +++ b/example/91_tile_program/fmha/README.md @@ -0,0 +1,90 @@ +# fused multi-head attention + +This folder contains example for fmha(fused multi-head attention) using ck tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. + +## build +``` +# in the root of ck +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make example_fmha_fwd -j +``` +This will result in an executable `build/bin/example_fmha_fwd` + +## kernel +The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck. We may still have an implementation under ck's include path (in the future) for the kernel template. + +There are 3 template parameters for this kernel template. +* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose. +* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). +* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. + +## codegen +To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. + +## executable +`example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/example_fmha_fwd -?` to list all supported args +``` +args: + -v weather do CPU validation or not (default:1) + -mode kernel mode. 0:batch, 1:group (default:0) + -b batch size (default:2) + -h num of head, for q (default:8) + -h_k num of head, for k/v, 0 means equal to h (default:0) + if not equal to h, then this is GQA/MQA case + -s seqlen_q (default:3328) + -s_k seqlen_k, 0 means equal to s (default:0) + -d head dim for q, k (default:128) + -d_v head dim for v, 0 means equal to d (default:0) + -scale scale factor. 0 means equal to 1/sqrt(seqlen) (default:0) + -iperm permute input (default:1) + if true, will be b*h*s*d, else b*s*h*d + -operm permute output (default:1) + -bias add bias or not (default:0) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -mask 0: no mask, 1: top-left, 2:bottom-right (default:0) + 't:l,r', top-left local-attn with left right size + 'b:l,r', bottom-r local-attn with left right size + 'g:y,x', generic attention mask coordinate with y/x size + -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) + -lse 0 not store lse, 1 store lse (default:0) + -init init method. 0:random int, 1:random float, 2:trig float (default:1) +``` +Example: `./bin/example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. + +## support features +Currently we are still in rapid development stage, so more features/optimizations will be coming soon. + +### hdim +Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. We may consider optimize other hdim performance if have more request. We also have an experimental support for arbitrary hdim(even odd number), one can change the return value of `get_pad()` inside `generate.py` to achieve this. (Note: we may change the method or optimize arbitrary hdim support in the future) + +### group/batch mode +Currently we support both batch and group mode, by setting `-mode` = `0` or `1`, where in group mode we support each batch can have different seqlen + +### MQA/GQA +By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers. + +### input/output permute, and `b*s*3*h*d` +If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`. + +### attention bias +Attention bias is supported with the layout of `b*h*s*s` and bias value in float number. + +### lse +For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` + +### vlayout +We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts. + +### generic attention mask coordinate +We unify the mask expression into generic attention mask coordinate, providing an uniformed approach to describe causal top-left, causal bottom-right, local attention. +![](misc/gamc.png) + +(more description to be added) + +### dropout +TBD + +## FP8 experimental support +As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `example_fmha_fwd`, on a gfx940/941/942 machine and ROCm 6.0+. Currently if you not explicitly setting `-v=0`(which will disable CPU verification), it will printout an error as much as `0.05`. We are still WIP to tune the kernel performance as well as the precision, so stay tuned for the updated performance(pipeline) +Currently we only support `-vlayout=c` for fp8, which is `hdim*seqlen` for V matrix. row major for V matrix support will come later. diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp new file mode 100644 index 000000000..1de17cfa8 --- /dev/null +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -0,0 +1,531 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor/tensor_view.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/common_header.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "common/arg_parser.hpp" +#include "fmha_fwd.hpp" +#include "mask.hpp" +#include "reference/reference_batched_elementwise.hpp" +#include "reference/reference_batched_gemm.hpp" +#include "reference/reference_batched_masking.hpp" +#include "reference/reference_batched_softmax.hpp" +#include "utils.hpp" + +auto create_args(int argc, char* argv[]) +{ + ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "0", + "num of head, for k/v, 0 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", "3328", "seqlen_q") + .insert("s_k", "0", "seqlen_k, 0 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "0", "head dim for v, 0 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(seqlen)") + .insert("descale_q", "1", "scale factor for fp8 quantization") + .insert("descale_k", "1", "scale factor for fp8 quantization") + .insert("descale_v", "1", "scale factor for fp8 quantization") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", "0", "add bias or not") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("mask", + "0", + "0: no mask, 1: top-left, 2:bottom-right\n" + "'t:l,r', top-left local-attn with left right size\n" + "'b:l,r', bottom-r local-attn with left right size\n" + "'g:y,x', generic attention mask coordinate with y/x size\n") + .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") + .insert("lse", "0", "0 not store lse, 1 store lse") + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 to use " + "non-deterministic random number as seed"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(int /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(int init_method) +{ + if(init_method == 0) + { + double rtol = 1e-2; + double atol = 1e-2; + return ck::make_tuple(rtol, atol); + } + else + { + double rtol = 3e-3; + double atol = 3e-3; + return ck::make_tuple(rtol, atol); + } +} + +template +bool run(const ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck::index_t batch = arg_parser.get_int("b"); + ck::index_t nhead = arg_parser.get_int("h"); + ck::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k == 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck::index_t seqlen_q = arg_parser.get_int("s"); + ck::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k == 0) + seqlen_k = seqlen_q; + ck::index_t hdim_q = arg_parser.get_int("d"); + ck::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v == 0) + hdim_v = hdim_q; + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); + if(scale == .0f) + scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + + float descale_q = arg_parser.get_float("descale_q"); + float descale_k = arg_parser.get_float("descale_k"); + float descale_v = arg_parser.get_float("descale_v"); + + std::string vlayout = arg_parser.get_str("vlayout"); + bool use_bias = arg_parser.get_bool("bias"); + bool lse = arg_parser.get_bool("lse"); + + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + int stream_warmup = env_get_int("CK_WARMUP", 5); + int stream_repeat = env_get_int("CK_REPEAT", 20); + + StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat}; + + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + using TypeConfig = FmhaFwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + using namespace ck::literals; + + flop += nhead * (2_uz * real_seqlen_q * real_seqlen_k * hdim_q + + 2_uz * real_seqlen_q * hdim_v * real_seqlen_k); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k + + sizeof(ODataType) * real_seqlen_q * hdim_v); + } + } + + auto get_lengths = [&](bool permute, + ck::index_t b /*batch*/, + ck::index_t h /*nhead*/, + ck::index_t s /*seqlen*/, + ck::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + bool is_v_rowmajor = vlayout == std::string("r"); + + // host memory for storing all the tensor elements + const ck::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + + Tensor q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + Tensor k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + Tensor v_host( + is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); + // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host + // will not be used for verification at all (but will be copied to device anyway). + Tensor bias_host( + use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] + Tensor lse_host( + lse ? std::array{shape_batch, nhead, shape_seqlen_q} + : std::array{1, 1, 1} /* dummy shape for simplifying code */); + + Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + + if(init_method == 0) + { + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + } + else if(init_method == 1) + { + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + } + else if(init_method == 2) + { + ck::utils::FillTrigValue{}(q_host); + ck::utils::FillTrigValue{}(k_host); + ck::utils::FillTrigValue{}(v_host); + ck::utils::FillTrigValue{}(bias_host); + } + + DeviceMem q_buf(q_host.GetElementSpaceSizeInBytes()); + DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes()); + DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes()); + DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes()); + DeviceMem lse_buf(lse_host.GetElementSpaceSizeInBytes()); + DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes()); + DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if (permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias + << ", lse:" << lse << ", mask:" << mask << ", v:" << vlayout << std::flush; + + auto fmha_traits = fmha_fwd_traits{ + hdim_q, data_type, mode == mode_enum::group, is_v_rowmajor, mask.type, use_bias, lse}; + auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + batch, + nhead, + nhead_k, + shape_seqlen_q, + shape_seqlen_k, + hdim_q, + hdim_v, + max_seqlen_q, + scale, + descale_q * descale_k, + descale_v, + i_perm, + o_perm, + mask.y, + mask.x}; + + float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); + + if(ave_time < 0) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + o_buf.FromDevice(o_host.data()); + lse_buf.FromDevice(lse_host.data()); + + bool pass = true; + + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + const auto v_host_ref_lengths = std::array{nhead, hdim_v, real_seqlen_k}; + const auto v_host_ref_strides = + is_v_rowmajor ? std::array{hdim_v * real_seqlen_k, 1, hdim_v} + : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; + + Tensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + Tensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + Tensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); + Tensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + Tensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + Tensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + Tensor lse_host_ref({nhead, real_seqlen_q}); + + ck::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + if (is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); + } + // clang-format on + + // reference + reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck::identity{}, + ck::identity{}, + [&](SaccDataType x) { return scale * x; }); + + if(use_bias) + { + Tensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + reference_batched_masking( + s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + } + else + { + reference_batched_masking( + s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + } + if(lse) + { + reference_batched_softmax( + s_host_ref, p_host_ref, lse_host_ref); + } + else + { + reference_batched_softmax( + s_host_ref, p_host_ref); + } + + reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref); + + Tensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck::utils::check_err( + o_host_result, o_host_ref, std::string("O Error: Incorrect results!"), rtol, atol); + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "O mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + + if(lse) + { + Tensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b, idx[0], idx[1] + query_offset); + }); + + bool lse_pass = ck::utils::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); + + pass &= lse_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/91_tile_program/fmha/fmha_fwd.hpp b/example/91_tile_program/fmha/fmha_fwd.hpp new file mode 100644 index 000000000..a6db9439c --- /dev/null +++ b/example/91_tile_program/fmha/fmha_fwd.hpp @@ -0,0 +1,326 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +#include "ck/tile_program/block_tile/block_masking.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp" +#include "ck/tile_program/tile/tile_fmha_shape.hpp" +#include "ck/tile_program/tile/tile_fmha_traits.hpp" + +#include "fmha_fwd_epilogue.hpp" +#include "fmha_fwd_kernel.hpp" +#include "fmha_fwd_tile_partitioner.hpp" +#include "mask.hpp" + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck::half_t; + using KDataType = ck::half_t; + using VDataType = ck::half_t; + using BiasDataType = ck::half_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck::bhalf_t; + using KDataType = ck::bhalf_t; + using VDataType = ck::bhalf_t; + using BiasDataType = ck::bhalf_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::bhalf_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::bhalf_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck::f8_t; + using KDataType = ck::f8_t; + using VDataType = ck::f8_t; + using BiasDataType = float; // TODO: fix me + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::f8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::f8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck::bf8_t; + using KDataType = ck::bf8_t; + using VDataType = ck::bf8_t; + using BiasDataType = ck::bf8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck::tile_program::block::GenericAttentionMask; + using GenericMask = ck::tile_program::block::GenericAttentionMask; + using CausalMask = ck::tile_program::block::GenericAttentionMask; +}; + +// internal API, don't use this directly +template +auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck::index_t batch, + ck::index_t nhead, + ck::index_t nhead_k, + ck::index_t seqlen_q, + ck::index_t seqlen_k, + ck::index_t hdim_q, + ck::index_t hdim_v, + ck::index_t max_seqlen_q, + float scale, + float descale_qk, + float descale_sv, + bool i_perm, + bool o_perm, + ck::index_t mask_y, + ck::index_t mask_x) +{ + constexpr bool is_v_rowmajor = + ck::is_same_v; + + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias' + /// are 0. + // setup stride_* arguments + const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck::index_t stride_v = [&]() { + if constexpr(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_k : nhead_k * seqlen_k; + }(); + const ck::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k); + const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q); + const ck::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q); + const ck::index_t nhead_stride_v = [&]() { + if constexpr(is_v_rowmajor) + return i_perm ? seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_k : seqlen_k; + }(); + const ck::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k); + const ck::index_t nhead_stride_lse = (seqlen_q * 1); + const ck::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck::index_t batch_stride_q = (nhead * seqlen_q * hdim_q); + const ck::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q); + const ck::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k); + const ck::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k); + const ck::index_t batch_stride_lse = (nhead * seqlen_q * 1); + const ck::index_t batch_stride_o = (nhead * seqlen_q * hdim_v); + + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + nhead / nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + mask_y, + mask_x, + descale_qk, + descale_sv); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + nhead / nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + mask_y, + mask_x, + descale_qk, + descale_sv); + } + }(); + + dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v); + return ck::make_tuple(kargs, grids); +} + +// This is the args from caller to underneath API, different from the kernel +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + void* lse_ptr; + void* o_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck::index_t batch; + ck::index_t nhead; + ck::index_t nhead_k; + ck::index_t seqlen_q; + ck::index_t seqlen_k; + ck::index_t hdim_q; + ck::index_t hdim_v; + ck::index_t max_seqlen_q; + float scale; + float descale_qk; + float descale_sv; + bool i_perm; + bool o_perm; + ck::index_t mask_y; + ck::index_t mask_x; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + return fmha_fwd_create_kargs_and_grids(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.batch, + args.nhead, + args.nhead_k, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.max_seqlen_q, + args.scale, + args.descale_qk, + args.descale_sv, + args.i_perm, + args.o_perm, + args.mask_y, + args.mask_x); +} + +// this is internal API, will be generated across different files to speedup compile +template +struct fmha_fwd_traits_ +{ + static constexpr ck::index_t HDim = HDim_; + using DataType = ck::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + using FmhaMask = ck::remove_cvref_t; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLse = kStoreLse_; +}; + +template +float fmha_fwd_(const StreamConfig&, fmha_fwd_args); + +// This is the public API, will be generated by script +struct fmha_fwd_traits +{ + int hdim; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bool has_bias; + bool has_lse; +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const StreamConfig&); diff --git a/example/91_tile_program/fmha_fwd_epilogue.hpp b/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp similarity index 88% rename from example/91_tile_program/fmha_fwd_epilogue.hpp rename to example/91_tile_program/fmha/fmha_fwd_epilogue.hpp index 94f4b17e5..6c5e6e861 100644 --- a/example/91_tile_program/fmha_fwd_epilogue.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp @@ -29,7 +29,6 @@ struct FmhaFwdEpilogue using namespace ck; using namespace ck::tile_program; - const auto o = tile_elementwise_in(type_convert, o_acc_tile); - store_tile(o_dram_window_tmp, o); + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); } }; diff --git a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp new file mode 100644 index 000000000..4b2c6d09f --- /dev/null +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -0,0 +1,655 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor/tensor_view.hpp" +#include "ck/tile_program/tile/tile_window.hpp" + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] + +template +struct FmhaFwdKernel +{ + using TilePartitioner = ck::remove_cvref_t; + using FmhaPipeline = ck::remove_cvref_t; + using EpiloguePipeline = ck::remove_cvref_t; + static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + + using QDataType = ck::remove_cvref_t; + using KDataType = ck::remove_cvref_t; + using VDataType = ck::remove_cvref_t; + using BiasDataType = ck::remove_cvref_t; + using LSEDataType = ck::remove_cvref_t; + using ODataType = ck::remove_cvref_t; + static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8; + + using VLayout = ck::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + using FmhaMask = ck::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + template // to avoid duplicated base class prblem, introduce an template arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck::index_t seqlen_q; + ck::index_t seqlen_k; + ck::index_t hdim_q; + ck::index_t hdim_v; + + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck::index_t nhead_ratio_qk; + float scale; + + ck::index_t stride_q; + ck::index_t stride_k; + ck::index_t stride_v; + ck::index_t stride_o; + + ck::index_t nhead_stride_q; + ck::index_t nhead_stride_k; + ck::index_t nhead_stride_v; + ck::index_t nhead_stride_o; + }; + + struct FmhaFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck::index_t stride_bias = 0; + ck::index_t nhead_stride_bias = 0; + }; + + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs + { + ck::index_t batch_stride_bias = 0; + }; + + struct FmhaFwdMaskKargs + { + ck::index_t mask_y, mask_x; + }; + + struct FmhaFwdFP8Kargs + { + float descale_qk; // q*k + float descale_sv; // s*v + // float * o_amax_ptr; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck::index_t nhead_stride_lse = 0; + }; + + struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs + { + ck::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck::index_t batch_stride_q; + ck::index_t batch_stride_k; + ck::index_t batch_stride_v; + ck::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + ck::index_t seqlen_q, + ck::index_t seqlen_k, + ck::index_t hdim_q, + ck::index_t hdim_v, + ck::index_t nhead_ratio_qk, + float scale, + ck::index_t stride_q, + ck::index_t stride_k, + ck::index_t stride_v, + ck::index_t stride_bias, + ck::index_t stride_o, + ck::index_t nhead_stride_q, + ck::index_t nhead_stride_k, + ck::index_t nhead_stride_v, + ck::index_t nhead_stride_bias, + ck::index_t nhead_stride_lse, + ck::index_t nhead_stride_o, + ck::index_t batch_stride_q, + ck::index_t batch_stride_k, + ck::index_t batch_stride_v, + ck::index_t batch_stride_bias, + ck::index_t batch_stride_lse, + ck::index_t batch_stride_o, + ck::index_t mask_y, + ck::index_t mask_x, + float descale_qk, + float descale_sv) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + nhead_ratio_qk, +#if CK_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck::math::log2e_v<>), +#else + scale, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8 args + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + if constexpr(kHasMask) + { + kargs.mask_y = mask_y; + kargs.mask_x = mask_x; + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kIsFp8) + { + kargs.descale_qk = descale_qk; + kargs.descale_sv = descale_sv; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck::index_t hdim_q, + ck::index_t hdim_v, + ck::index_t nhead_ratio_qk, + float scale, + ck::index_t stride_q, + ck::index_t stride_k, + ck::index_t stride_v, + ck::index_t stride_bias, + ck::index_t stride_o, + ck::index_t nhead_stride_q, + ck::index_t nhead_stride_k, + ck::index_t nhead_stride_v, + ck::index_t nhead_stride_bias, + ck::index_t nhead_stride_lse, + ck::index_t nhead_stride_o, + ck::index_t mask_y, + ck::index_t mask_x, + float descale_qk, + float descale_sv) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + nhead_ratio_qk, +#if CK_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck::math::log2e_v<>), +#else + scale, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8 args + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + if constexpr(kHasMask) + { + kargs.mask_y = mask_y; + kargs.mask_x = mask_x; + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kIsFp8) + { + kargs.descale_qk = descale_qk; + kargs.descale_sv = descale_sv; + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck::index_t batch_size_, + ck::index_t nhead_, + ck::index_t seqlen_q_, + ck::index_t hdim_v_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + return ck::math::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + __device__ void operator()(Kargs kargs) const + { + using namespace ck; + using namespace ck::tile_program; + using namespace ck::tile_program::block; + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = + TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(ck::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(kHasBias) + { + batch_offset_bias = query_start * kargs.stride_bias + key_start; + } + else + { + batch_offset_bias = key_start; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(kHasBias) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + Number<32>{}, + Number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + Number<32>{}, + Number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(ck::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + Number<32>{}, + Number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return pad_tensor_view( + v_dram_transposed, + make_tuple(Number{}, Number{}), + Sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + Number<32>{}, + Number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(Number{}, + Number{}); + else + return make_tuple(Number{}, Number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(Number{}, Number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(Number{}, Number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables is not supported. Remove + /// following copy capture of the 'i_nhead' + /// if compiled in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(Number{}, Number{}); + if constexpr(kHasBias) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + Number<32>{}, + Number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + Sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(Number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = + make_naive_tensor_view(lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + Number<1>{}, + Number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, Sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k}; + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto o_acc_tile = [&]() { + if constexpr(kIsFp8) + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + kargs.scale, + kargs.descale_qk, + kargs.descale_sv, + smem_ptr); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + kargs.scale, + smem_ptr); + } + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + Number<32>{}, + Number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(Number{}, Number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; diff --git a/example/91_tile_program/fmha_fwd_tile_partitioner.hpp b/example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp similarity index 81% rename from example/91_tile_program/fmha_fwd_tile_partitioner.hpp rename to example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp index c01cadc4e..b6194716f 100644 --- a/example/91_tile_program/fmha_fwd_tile_partitioner.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp @@ -24,7 +24,10 @@ struct FmhaFwdTilePartitioner ck::index_t hdim_v_) { // TODO: this may need tuning - return dim3((seqlen_q_ / kM0) * (hdim_v_ / kN1), batch_size_, nhead_); + return dim3(ck::math::integer_divide_ceil(seqlen_q_, kM0) * + ck::math::integer_divide_ceil(hdim_v_, kN1), + nhead_, + batch_size_); } __device__ auto operator()(ck::index_t /*seqlen_q*/, ck::index_t hdim_v) @@ -32,11 +35,11 @@ struct FmhaFwdTilePartitioner using namespace ck; // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = hdim_v / kN1; + const index_t num_tile_n1 = ck::math::integer_divide_ceil(hdim_v, kN1); const index_t i_block = blockIdx.x; - const index_t i_batch = blockIdx.y; - const index_t i_nhead = blockIdx.z; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; const auto f = [](index_t dividend, index_t divisor) { index_t quotient = dividend / divisor; diff --git a/example/91_tile_program/fmha/generate.py b/example/91_tile_program/fmha/generate.py new file mode 100644 index 000000000..f4594639b --- /dev/null +++ b/example/91_tile_program/fmha/generate.py @@ -0,0 +1,402 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +import itertools +from pathlib import Path +from typing import List, Optional, Tuple +from dataclasses import dataclass +import copy + +DTYPE_MAP = { + "fp16": "ck::half_t", + "bf16": "ck::bhalf_t", + "fp8" : "ck::f8_t" +} + +MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck::tile_program::block::BlockFmhaPipelineQRKSVS", + "qr_fp8" : "ck::tile_program::block::BlockFmhaPipelineQRKSVSFp8", + "qr_async" : "ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} + +MASKS = ["no", "causal", "generic"] +DIRECTIONS = ["fwd"] +GEN_DIR = "" # in Cmake, have to generate files in same folder + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck::Sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps_{F_idx} = ck::Sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile_{F_idx} = ck::Sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck::tile_program::TileFmhaShape; + +using fmha_trait_{F_idx} = ck::tile_program::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_lse}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck::tile_program::block::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + FmhaFwdEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType>>; + +using fmha_kernel_{F_idx} = + FmhaFwdKernel, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}>; + +template<> +float fmha_fwd_(const StreamConfig& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck::index_t kBlockPerCu = k_::kBlockPerCu; + return launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const StreamConfig& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ + switch (t.hdim){{ +{F_hdim_case} + default: + break; + }} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" case {F_hdim}: {{ +{F_inner_dispatch} + }} + break; +""" +MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::causal_top_left || t.mask_type == mask_enum::causal_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + vlayout : str + mask : str + bias : str # true/false + lse : str # + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.vlayout}-{self.mask}-{self.bias}-{self.lse}' + +class FmhaFwdApiPool: + def __init__(self): + self.pool = dict() + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for hdim in self.pool[dtype].keys(): + traits=self.pool[dtype][hdim] + inners=str() + for j, trait in enumerate(traits): + if0 = 'if' if j == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if0, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask], + F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_hdim=hdim, F_inner_dispatch=inners) + if1 = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if1, F_dtype=dtype, F_hdim_case=per_hdim_case) + + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along qk seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm : int # number of warps along q seqlen (block warps) + F_rn : int # number of warps along k seqlen(not used) + F_rk : int # number of warps along gemm-k(not used) + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn0}x{self.F_bk1}x{self.F_bk0blen}" +\ + f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" + +@dataclass +class FmhaFwdKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaFwdTileSize + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_pipeline : str # value from PIPIELINE_MAP + + @property + def template(self) -> str: + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_vlayout], + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BOOL_MAP[self.F_bias], + F_lse = BOOL_MAP[self.F_lse], + F_occupancy = self.F_tile.F_occupancy , + F_mask = MASK_MAP[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f"_v{self.F_vlayout[0]}" +\ + f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ + f"_{BOOL_MAP[self.F_bias][0]}_m{self.F_mask[0]}_l{BOOL_MAP[self.F_lse][0]}_{self.F_pipeline}" + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait(hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + vlayout=self.F_vlayout, + mask=self.F_mask, + bias=self.F_bias, + lse=self.F_lse) + +# TODO: design a more practical way to do it +# this is current supported tile size. +def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: + if direction == 'fwd': + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, 2), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, 3), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, 2), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, 1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, 2) + } + else: + return None + else: + return None + +def get_blobs() -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_vlayout(dtype, hdim): + if dtype in ['fp16', 'bf16']: + return 'row' + elif dtype in ['fp8', 'bf8']: + return 'col' + else: + assert Fasle + def get_pipeline(dtype, hdim): + if dtype in ['fp16', 'bf16']: + if hdim == 256: + return 'qr' + else: + return 'qr_async' + elif dtype in ['fp8', 'bf8']: + return 'qr_fp8' + else: + assert Fasle + def get_pad(dtype, hdim): + return 'f' + + gen = list() + api_pool = FmhaFwdApiPool() + + for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): + d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + tile = d[hdim_str] + hdim = int(hdim_str) + if dtype in ['fp8', 'bf8'] and lse == "t": + continue + k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, F_vlayout=get_vlayout(dtype, hdim), + F_spad=get_pad(dtype, hdim), F_skpad=get_pad(dtype, hdim), F_dpad=get_pad(dtype, hdim), + F_dvpad=get_pad(dtype, hdim), F_bias=bias, F_lse=lse, F_mask=mask, F_mode=mode, + F_pipeline=get_pipeline(dtype, hdim)) + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir: Optional[str]) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + api_pool, kernels = get_blobs() + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_api(api_pool, output_dir) + +# list all the files that will be generated +def list_blobs(output_file: Optional[str]) -> None: + assert output_file is not None + file_path = Path(output_file) + with file_path.open('a') as f: + _, kernels = get_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen api for CK fmha kernel", + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory" + ) + parser.add_argument( + "-l", + "--list_blobs", + required=False, + help="list all the kernels to a file" + ) + args = parser.parse_args() + if args.list_blobs is not None: + list_blobs(args.list_blobs) + else: + write_blobs(args.output_dir) diff --git a/example/91_tile_program/fmha/mask.hpp b/example/91_tile_program/fmha/mask.hpp new file mode 100644 index 000000000..e64df0ba0 --- /dev/null +++ b/example/91_tile_program/fmha/mask.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tile_program/block_tile/block_masking.hpp" + +enum class mask_enum +{ + no_mask = 0, + causal_top_left, + causal_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck::index_t y, x; + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::causal_top_left) + os << "tl"; + else if(type == mask_enum::causal_bottom_right) + os << "br"; + else + { + os << "g(" << y << "/" << x << ")"; + } + } + static mask_info decode(std::string str, ck::index_t seqlen_q, ck::index_t seqlen_k) + { + ck::index_t x_total = seqlen_k; + ck::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.At(ck::Number<0>{}); + tmp.x = r.At(ck::Number<1>{}); + } + else if(t == "b") + { + auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.At(ck::Number<0>{}); + tmp.x = r.At(ck::Number<1>{}); + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + else + { + // should be 0, 1, 2 + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::causal_top_left) + { + tmp.y = seqlen_q; + tmp.x = 1; + } + else if(tmp.type == mask_enum::causal_bottom_right) + { + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + } + } + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); +}; + +inline std::ostream& operator<<(std::ostream& os, const mask_info& mi) +{ + mi.serialize(os); + return os; +} diff --git a/example/91_tile_program/fmha/misc/gamc.png b/example/91_tile_program/fmha/misc/gamc.png new file mode 100644 index 000000000..2c96951f3 Binary files /dev/null and b/example/91_tile_program/fmha/misc/gamc.png differ diff --git a/example/91_tile_program/fmha/script/benchmark.sh b/example/91_tile_program/fmha/script/benchmark.sh new file mode 100644 index 000000000..a8f3a8202 --- /dev/null +++ b/example/91_tile_program/fmha/script/benchmark.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/example_fmha_fwd +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 64 128 256 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 + +done +done +done diff --git a/example/91_tile_program/fmha/script/smoke_test.sh b/example/91_tile_program/fmha/script/smoke_test.sh new file mode 100644 index 000000000..40e17fd88 --- /dev/null +++ b/example/91_tile_program/fmha/script/smoke_test.sh @@ -0,0 +1,23 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/example_fmha_fwd + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 128 64 256 ; do +for bias in 0 1 ; do + +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -iperm=$perm -operm=$perm -v=1 +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=256 -bias=$bias -iperm=$perm -operm=$perm -v=1 +$EXE -prec=$prec -b=2 -h=2 -h_k=1 -d=$hdim -s=512 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -v=1 +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=256 -s_k=512 -bias=$bias -iperm=$perm -operm=$perm -v=1 +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -mask=1 -v=1 +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -mask=2 -v=1 +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -iperm=$perm -operm=$perm -mask=g:128,32 -v=1 + + +done +done +done +done diff --git a/example/91_tile_program/fmha/utils.hpp b/example/91_tile_program/fmha/utils.hpp new file mode 100644 index 000000000..5a8ef1042 --- /dev/null +++ b/example/91_tile_program/fmha/utils.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/utility/span.hpp" + +enum class mode_enum +{ + batch = 0, + group +}; + +std::ostream& operator<<(std::ostream& stream, mode_enum mode) +{ + return stream << (mode == mode_enum::batch ? "batch" : "group"); +} + +std::vector to_seqstarts(ck::span seqlens) +{ + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + assert(seqstarts.size() == seqlens.size() + 1); + return seqstarts; +} + +std::vector generate_seqlens(mode_enum mode, + unsigned count, + int32_t seqlens_sum, + std::optional seed = std::nullopt) +{ + assert(0 < count); + + std::vector seqlens(count, seqlens_sum); + + if(mode == mode_enum::group && 1 < count) + { + using size_type = std::vector::size_type; + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens is always greater than 0 + if(seqlens[to_decrease] == 1) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + --seqlens[to_decrease]; + ++seqlens[to_increase]; + } + } + + return seqlens; +} + +std::vector generate_seqstarts(mode_enum mode, + unsigned count, + int32_t seqlens_sum, + std::optional seed = std::nullopt) +{ + return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); +} + +int env_get_int(const char* var_name, int default_int) +{ + char* v = getenv(var_name); + int r = default_int; + if(v) + r = atoi(v); + return r; +} diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp deleted file mode 100644 index 7b0934278..000000000 --- a/example/91_tile_program/fmha_fwd.cpp +++ /dev/null @@ -1,365 +0,0 @@ -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/tensor/tensor_view.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" - -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" -#include "ck/tile_program/tile/tile_fmha_shape.hpp" - -#include "reference_batched_gemm.hpp" -#include "reference_batched_softmax.hpp" -#include "fmha_fwd_kernel.hpp" -#include "fmha_fwd_tile_partitioner.hpp" -#include "fmha_fwd_epilogue.hpp" - -#if 1 -using QDataType = ck::half_t; -using KDataType = ck::half_t; -using VDataType = ck::half_t; -using SaccDataType = float; // data type for first gemm accumulation -using SMPLComputeDataType = float; // data type for reduction, softmax -using PDataType = ck::half_t; // data type for A matrix of second gemm -using OaccDataType = float; // data type for second gemm accumulation -using ODataType = ck::half_t; -#else -using QDataType = ck::bhalf_t; -using KDataType = ck::bhalf_t; -using VDataType = ck::bhalf_t; -using SaccDataType = float; // data type for first gemm accumulation -using SMPLComputeDataType = float; // data type for reduction, softmax -using PDataType = ck::bhalf_t; // data type for A matrix of second gemm -using OaccDataType = float; // data type for second gemm accumulation -using ODataType = ck::bhalf_t; -#endif - -// M0 N0 K0 N1 K1 K0L -// using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>; -// using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>; -using VLayout = ck::tensor_layout::gemm::RowMajor; // (bs, nhead) seqlen * hdim -// using VLayout = ck::tensor_layout::gemm::ColumnMajor; // (bs, nhead) hdim * seqlen - -using FmhaBlockTileHdim64 = ck::Sequence<128, 64, 32, 64, 32, 64>; -using FmhaBlockTileHdim128 = ck::Sequence<128, 128, 32, 128, 32, 128>; -using FmhaBlockWarps = ck::Sequence<4, 1, 1>; -using FmhaWarpTile = ck::Sequence<32, 32, 16>; -using FmhaShapeHDim64 = ck::tile_program::TileFmhaShape; -using FmhaShapeHDim128 = ck::tile_program::TileFmhaShape; - -using FmhaTilePartitionerHDim64 = FmhaFwdTilePartitioner; -using FmhaTilePartitionerHDim128 = FmhaFwdTilePartitioner; -using FmhaPipelineProblemHDim64 = - ck::tile_program::block::BlockFmhaPipelineProblem; -using FmhaPipelineProblemHDim128 = - ck::tile_program::block::BlockFmhaPipelineProblem; -// using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS; -using FmhaPipelineHDim64 = - ck::tile_program::block::BlockFmhaPipelineQRKSVS; -using FmhaPipelineHDim128 = - ck::tile_program::block::BlockFmhaPipelineQRKSVS; - -using FmhaEpilogue = FmhaFwdEpilogue>; -using FmhaKernelHDim64 = FmhaFwdKernel; -using FmhaKernelHDim128 = - FmhaFwdKernel; - -template -float invoker_fmha_kernel(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* o_ptr, - ck::index_t batch, - ck::index_t nhead, - ck::index_t seqlen_q, - ck::index_t seqlen_k, - ck::index_t hdim_q, - ck::index_t hdim_v, - float scale, - bool i_perm, - bool o_perm) -{ - dim3 kGridSize = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v); - constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); - - constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD - constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize; - constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; - - constexpr bool is_v_rowmajor = - ck::is_same_v; - - // batch * nhead * seqlen * hdim or batch * seqlen * nhead * hdim - auto kargs = FmhaKernel::MakeKargs( - q_ptr, - k_ptr, - v_ptr, - o_ptr, - seqlen_q, // seqlen_q - seqlen_k, // seqlen_k - hdim_q, // hdim_q - hdim_v, // hdim_v - scale, - i_perm ? hdim_q : nhead * hdim_q, // stride_q - i_perm ? hdim_q : nhead * hdim_q, // stride_k - [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? hdim_v : nhead * hdim_v; - else - return i_perm ? seqlen_k : nhead * seqlen_k; - }(), // stride_v - o_perm ? hdim_v : nhead * hdim_v, // stride_o - i_perm ? seqlen_q * hdim_q : hdim_q, // nhead_stride_q - i_perm ? seqlen_k * hdim_q : hdim_q, // nhead_stride_k - [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * seqlen_k : seqlen_k; - }(), // nhead_stride_v - o_perm ? seqlen_q * hdim_v : hdim_v, // nhead_stride_o - nhead * seqlen_q * hdim_q, // batch_stride_q - nhead * seqlen_k * hdim_q, // batch_stride_k - nhead * hdim_v * seqlen_k, // batch_stride_v - nhead * seqlen_q * hdim_v); // batch_stride_o - - float ave_time = launch_kernel(StreamConfig{nullptr, true}, - FmhaKernel{}, - kGridSize, - kBlockSize, - 0, - kargs); // BatchStrideO - return ave_time; -} - -int main(int argc, char* argv[]) -{ - int do_validation = 1; - ck::index_t batch = 2; - ck::index_t nhead = 8; - ck::index_t seqlen_q = 3328; - ck::index_t seqlen_k = 4096; - ck::index_t hdim_q = 128; - ck::index_t hdim_v = 128; - - float scale = .0f; - - bool i_perm = true; // if true, will be batch * nhead * seqlen * hdim - bool o_perm = true; // if false, will be batch * seqlen * nhead * hdim - - if(argc >= 2) - do_validation = std::stoi(argv[1]); - - if(argc >= 8) - { - batch = std::stoi(argv[2]); - nhead = std::stoi(argv[3]); - seqlen_q = std::stoi(argv[4]); - seqlen_k = std::stoi(argv[5]); - hdim_q = std::stoi(argv[6]); - hdim_v = std::stoi(argv[7]); - } - if(argc >= 9) - scale = std::stof(argv[8]); - if(argc >= 10) - i_perm = static_cast(std::stoi(argv[9])); - if(argc >= 11) - o_perm = static_cast(std::stoi(argv[10])); - - if(scale == .0f) - scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); // TODO: q ? v ? - - auto get_lengths = [&](bool permute, - ck::index_t b /*batch*/, - ck::index_t h /*nhead*/, - ck::index_t s /*seqlen*/, - ck::index_t d /*hdim*/) { - if(permute) - return std::array{b, h, s, d}; - else - return std::array{b, s, h, d}; - }; - - constexpr bool is_v_rowmajor = - ck::is_same_v; - - // host verify - Tensor q_host(get_lengths(i_perm, batch, nhead, seqlen_q, hdim_q)); - Tensor k_host(get_lengths(i_perm, batch, nhead, seqlen_k, hdim_q)); - Tensor v_host(is_v_rowmajor ? get_lengths(i_perm, batch, nhead, seqlen_k, hdim_v) - : get_lengths(i_perm, batch, nhead, hdim_v, seqlen_k)); - Tensor o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v)); - -#if 0 - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); -#else - ck::utils::FillUniformDistribution{0.f, 1.f}(q_host); - ck::utils::FillUniformDistribution{0.f, 1.f}(k_host); - ck::utils::FillUniformDistribution{-.5f, .5f}(v_host); -#endif - - DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize()); - DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize()); - DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize()); - DeviceMem o_buf(sizeof(ODataType) * o_host.GetElementSpaceSize()); - - q_buf.ToDevice(q_host.mData.data()); - k_buf.ToDevice(k_host.mData.data()); - v_buf.ToDevice(v_host.mData.data()); - - std::cout << "batch:" << batch << ", nhead:" << nhead << ", seqlen_q:" << seqlen_q - << ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v - << ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm - << ", v:" << std::string(FmhaKernelHDim64::VLayout::name) << std::flush << std::endl; - - float ave_time = 0; - if(hdim_q == hdim_v && hdim_q == 64) - ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - batch, - nhead, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - scale, - i_perm, - o_perm); - else if(hdim_q == hdim_v && hdim_q == 128) - ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - batch, - nhead, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - scale, - i_perm, - o_perm); - else - { - std::cout << "not support hdim, will not run" << std::endl; - return -1; - } - - std::size_t flop = std::size_t(2) * batch * nhead * seqlen_q * seqlen_k * hdim_q + - std::size_t(2) * batch * nhead * seqlen_q * hdim_v * seqlen_k; - - std::size_t num_btype = sizeof(QDataType) * batch * nhead * seqlen_q * hdim_q + - sizeof(KDataType) * batch * nhead * seqlen_k * hdim_q + - sizeof(VDataType) * batch * nhead * hdim_v * seqlen_k + - sizeof(ODataType) * batch * nhead * seqlen_q * hdim_v; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - if(do_validation) - { - Tensor q_host_ref({batch * nhead, seqlen_q, hdim_q}); - Tensor k_host_ref({batch * nhead, seqlen_k, hdim_q}); - const auto v_lengths = std::array{batch * nhead, hdim_v, seqlen_k}; - const auto v_strides = is_v_rowmajor - ? std::array{hdim_v * seqlen_k, 1, hdim_v} - : std::array{hdim_v * seqlen_k, seqlen_k, 1}; - Tensor v_host_ref(v_lengths, v_strides); - Tensor o_host_ref({batch * nhead, seqlen_q, hdim_v}); - Tensor o_host_result_ref(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v)); - - Tensor s_host_ref({batch * nhead, seqlen_q, seqlen_k}); - Tensor p_host_ref({batch * nhead, seqlen_q, seqlen_k}); - - // clang-format off - // permute - if(i_perm) q_host.ForEach([&](auto& self, auto idx) { q_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); }); - else q_host.ForEach([&](auto& self, auto idx) { q_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); }); - - if(i_perm) k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); }); - else k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); }); - - if constexpr (is_v_rowmajor) { - // v_host :b, h, s, d, v_host_ref : batch*hdim*seq - if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[3], idx[2]) = self(idx); }); - // v_host : b, s, h, d, v_host_ref : batch*hdim*seq - else v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[3], idx[1]) = self(idx); }); - } - else { - if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); }); - else v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); }); - } - - // reference - reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, - [](const QDataType& x) { return x; }, - [](const KDataType& x) { return x; }, - [&scale](const SaccDataType& x) { return scale * x; }); - reference_batched_softmax(s_host_ref, - p_host_ref); - reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref); - - // permute - if(o_perm) o_host_result_ref.ForEach([&](auto& self, auto idx) { self(idx) = o_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]); }); - else o_host_result_ref.ForEach([&](auto& self, auto idx) { self(idx) = o_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]); }); - // clang-format on - - o_buf.FromDevice(o_host.mData.data()); - return !ck::utils::check_err(o_host, o_host_result_ref); - } - else - { - return 0; - } -} diff --git a/example/91_tile_program/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha_fwd_kernel.hpp deleted file mode 100644 index b447de1db..000000000 --- a/example/91_tile_program/fmha_fwd_kernel.hpp +++ /dev/null @@ -1,214 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/common_header.hpp" -#include "ck/tensor/tensor_view.hpp" -#include "ck/tile_program/tile/tile_window.hpp" - -// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q] -// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) -// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] - -#define C_LOG2E 1.44269504088896340736 // log2(e) - -template -struct FmhaFwdKernel -{ - using TilePartitioner = ck::remove_cvref_t; - using FmhaPipeline = ck::remove_cvref_t; - using EpiloguePipeline = ck::remove_cvref_t; - static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize; - - using QDataType = ck::remove_cvref_t; - using KDataType = ck::remove_cvref_t; - using VDataType = ck::remove_cvref_t; - using ODataType = ck::remove_cvref_t; - - using VLayout = ck::remove_cvref_t; - - struct Kargs - { - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - void* o_ptr; - ck::index_t seqlen_q; - ck::index_t seqlen_k; - ck::index_t hdim_q; - ck::index_t hdim_v; - - float scale; - - ck::index_t stride_q; - ck::index_t stride_k; - ck::index_t stride_v; - ck::index_t stride_o; - - ck::index_t nhead_stride_q; - ck::index_t nhead_stride_k; - ck::index_t nhead_stride_v; - ck::index_t nhead_stride_o; - - ck::index_t batch_stride_q; - ck::index_t batch_stride_k; - ck::index_t batch_stride_v; - ck::index_t batch_stride_o; - }; - - __host__ static constexpr Kargs MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* o_ptr, - ck::index_t seqlen_q, - ck::index_t seqlen_k, - ck::index_t hdim_q, - ck::index_t hdim_v, - float scale, - ck::index_t stride_q, - ck::index_t stride_k, - ck::index_t stride_v, - ck::index_t stride_o, - ck::index_t nhead_stride_q, - ck::index_t nhead_stride_k, - ck::index_t nhead_stride_v, - ck::index_t nhead_stride_o, - ck::index_t batch_stride_q, - ck::index_t batch_stride_k, - ck::index_t batch_stride_v, - ck::index_t batch_stride_o) - { - return Kargs{q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, - seqlen_k, hdim_q, hdim_v, scale, stride_q, - stride_k, stride_v, stride_o, nhead_stride_q, nhead_stride_k, - nhead_stride_v, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, - batch_stride_o}; - } - - __host__ static constexpr auto GridSize(ck::index_t batch_size_, - ck::index_t nhead_, - ck::index_t seqlen_q_, - ck::index_t hdim_v_) - { - return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); - } - - __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } - - __host__ __device__ static constexpr ck::index_t GetSmemSize() - { - return ck::math::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); - } - - __device__ void operator()(Kargs kargs) const - { - using namespace ck; - using namespace ck::tile_program; - using namespace ck::tile_program::block; - - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; - - // divide problem - const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = - TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); - - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - - // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - i_nhead * kargs.nhead_stride_q + i_batch * kargs.batch_stride_q; - const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - i_nhead * kargs.nhead_stride_k + i_batch * kargs.batch_stride_k; - const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - i_nhead * kargs.nhead_stride_v + i_batch * kargs.batch_stride_v; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o; - - // Q/K/V DRAM and DRAM window - const auto q_dram = make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - Number<32>{}, - Number<1>{}); - - const auto k_dram = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - Number<32>{}, - Number<1>{}); - - const auto v_dram = [&]() { - if constexpr(ck::is_same_v) - { - const auto v_dram_tmp = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - Number<32>{}, - Number<1>{}); - return transform_tensor_view( - v_dram_tmp, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - return make_naive_tensor_view( - v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k), - make_tuple(kargs.stride_v, 1), - Number<32>{}, - Number<1>{}); - } - }(); - - auto q_dram_window = make_tile_window( - q_dram, - [&]() { - if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(Number{}, - Number{}); - else - return make_tuple(Number{}, Number{}); - }(), - {i_m0, 0}); - - auto k_dram_window = make_tile_window( - k_dram, make_tuple(Number{}, Number{}), {0, 0}); - - auto v_dram_window = - make_tile_window(v_dram, - make_tuple(Number{}, Number{}), - {i_n1, 0}); - - auto o_acc_tile = FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - kargs.scale, - kargs.seqlen_k / FmhaPipeline::kN0, - kargs.hdim_q / FmhaPipeline::kK0, - smem_ptr); - - // O DRAM and O DRAM window - auto o_dram = make_naive_tensor_view( - o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), - Number<32>{}, - Number<1>{}); - - auto o_dram_window = - make_tile_window(o_dram, - make_tuple(Number{}, Number{}), - {i_m0, i_n1}); - - EpiloguePipeline{}(o_dram_window, o_acc_tile); - } -}; diff --git a/example/91_tile_program/gemm/CMakeLists.txt b/example/91_tile_program/gemm/CMakeLists.txt new file mode 100644 index 000000000..a6e8f1ef5 --- /dev/null +++ b/example/91_tile_program/gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm gemm.cpp) diff --git a/example/91_tile_program/gemm.cpp b/example/91_tile_program/gemm/gemm.cpp similarity index 99% rename from example/91_tile_program/gemm.cpp rename to example/91_tile_program/gemm/gemm.cpp index 67e8479ea..7e8ad59c1 100644 --- a/example/91_tile_program/gemm.cpp +++ b/example/91_tile_program/gemm/gemm.cpp @@ -13,7 +13,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_gemm.hpp" +#include "reference/reference_gemm.hpp" #include "gemm.hpp" // elementwise lambda diff --git a/example/91_tile_program/gemm.hpp b/example/91_tile_program/gemm/gemm.hpp similarity index 100% rename from example/91_tile_program/gemm.hpp rename to example/91_tile_program/gemm/gemm.hpp diff --git a/example/91_tile_program/gemm_gemm/CMakeLists.txt b/example/91_tile_program/gemm_gemm/CMakeLists.txt new file mode 100644 index 000000000..0034ade28 --- /dev/null +++ b/example/91_tile_program/gemm_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_gemm gemm_gemm.cpp) diff --git a/example/91_tile_program/gemm_gemm.cpp b/example/91_tile_program/gemm_gemm/gemm_gemm.cpp similarity index 99% rename from example/91_tile_program/gemm_gemm.cpp rename to example/91_tile_program/gemm_gemm/gemm_gemm.cpp index ccbea2369..e65eaf40a 100644 --- a/example/91_tile_program/gemm_gemm.cpp +++ b/example/91_tile_program/gemm_gemm/gemm_gemm.cpp @@ -13,7 +13,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_gemm.hpp" +#include "reference/reference_gemm.hpp" #include "gemm_gemm.hpp" int main(int argc, char* argv[]) diff --git a/example/91_tile_program/gemm_gemm.hpp b/example/91_tile_program/gemm_gemm/gemm_gemm.hpp similarity index 100% rename from example/91_tile_program/gemm_gemm.hpp rename to example/91_tile_program/gemm_gemm/gemm_gemm.hpp diff --git a/example/91_tile_program/gemm_softmax_gemm/CMakeLists.txt b/example/91_tile_program/gemm_softmax_gemm/CMakeLists.txt new file mode 100644 index 000000000..8ce4b41fd --- /dev/null +++ b/example/91_tile_program/gemm_softmax_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp) diff --git a/example/91_tile_program/gemm_softmax_gemm.cpp b/example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.cpp similarity index 98% rename from example/91_tile_program/gemm_softmax_gemm.cpp rename to example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.cpp index b887b8ab9..4dddfaa75 100644 --- a/example/91_tile_program/gemm_softmax_gemm.cpp +++ b/example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.cpp @@ -13,8 +13,8 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_gemm.hpp" -#include "reference_softmax.hpp" +#include "reference/reference_gemm.hpp" +#include "reference/reference_softmax.hpp" #include "gemm_softmax_gemm.hpp" int main(int argc, char* argv[]) diff --git a/example/91_tile_program/gemm_softmax_gemm.hpp b/example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.hpp similarity index 100% rename from example/91_tile_program/gemm_softmax_gemm.hpp rename to example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.hpp diff --git a/example/91_tile_program/gemm_softmax_gemm_impl.hpp b/example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm_impl.hpp similarity index 100% rename from example/91_tile_program/gemm_softmax_gemm_impl.hpp rename to example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm_impl.hpp diff --git a/example/91_tile_program/im2col/CMakeLists.txt b/example/91_tile_program/im2col/CMakeLists.txt new file mode 100644 index 000000000..7a72732bc --- /dev/null +++ b/example/91_tile_program/im2col/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_im2col im2col.cpp) diff --git a/example/91_tile_program/im2col.cpp b/example/91_tile_program/im2col/im2col.cpp similarity index 88% rename from example/91_tile_program/im2col.cpp rename to example/91_tile_program/im2col/im2col.cpp index 83a8ba55f..d0744cfef 100644 --- a/example/91_tile_program/im2col.cpp +++ b/example/91_tile_program/im2col/im2col.cpp @@ -24,55 +24,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -template -void reference_im2col(Tensor& in_mtx_host_ref, - const Tensor& in_host, - int /*N*/, - int /*K*/, - int C, - int /*Y*/, - int X, - int Hi, - int Wi, - int Ho, - int Wo, - int ConvStrideH, - int ConvStrideW, - int ConvDilationH, - int ConvDilationW, - int InLeftPadH, - int InLeftPadW, - int /*InRightPadH*/, - int /*InRightPadW*/) -{ - int GemmM = in_mtx_host_ref.GetLengths()[0]; - int GemmK = in_mtx_host_ref.GetLengths()[1]; - - for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m) - { - int mtmp = gemm_m; - int n = mtmp / (Ho * Wo); - mtmp -= n * Ho * Wo; - int ho = mtmp / Wo; - int wo = mtmp - ho * Wo; - - for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k) - { - int ktmp = gemm_k; - int y = ktmp / (X * C); - ktmp -= y * X * C; - int x = ktmp / C; - int c = ktmp - x * C; - - int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH; - int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW; - - bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi); - - in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0; - } - } -} +#include "reference/reference_im2col.hpp" template -void reference_reduce(const Tensor& a_m_n, Tensor& b_m) -{ - auto f = [&](auto m) { - const int N = a_m_n.mDesc.GetLengths()[1]; - - AccDataType v_acc = 0; - - for(int n = 0; n < N; ++n) - { - const ADataType v_a = a_m_n(m, n); - - v_acc += v_a; - } - - b_m(m) = ck::type_convert(v_acc); - }; - - make_ParallelTensorFunctor(f, b_m.mDesc.GetLengths()[0])(std::thread::hardware_concurrency()); -} - int main(int argc, char* argv[]) { using ADataType = ck::half_t; diff --git a/example/91_tile_program/reduce.hpp b/example/91_tile_program/reduce/reduce.hpp similarity index 100% rename from example/91_tile_program/reduce.hpp rename to example/91_tile_program/reduce/reduce.hpp diff --git a/example/91_tile_program/reference/reference_batched_elementwise.hpp b/example/91_tile_program/reference/reference_batched_elementwise.hpp new file mode 100644 index 000000000..cf5beec2d --- /dev/null +++ b/example/91_tile_program/reference/reference_batched_elementwise.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template > +void reference_batched_elementwise(const Tensor& a_b_m_n, + const Tensor& b_b_m_n, + Tensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const BinaryElementOp& binary_element_op = {}) +{ + const ck::index_t N = c_b_m_n.mDesc.GetLengths()[2]; + + const bool broadcast_a_dim_b = (a_b_m_n.GetLengths()[0] == 1); + const bool broadcast_a_dim_m = (a_b_m_n.GetLengths()[1] == 1); + const bool broadcast_a_dim_n = (a_b_m_n.GetLengths()[2] == 1); + + const bool broadcast_b_dim_b = (b_b_m_n.GetLengths()[0] == 1); + const bool broadcast_b_dim_m = (b_b_m_n.GetLengths()[1] == 1); + const bool broadcast_b_dim_n = (b_b_m_n.GetLengths()[2] == 1); + + auto f = [&](auto batch, auto m) { + for(ck::index_t n = 0; n < N; ++n) + { + AccDataType v_a{}; + { + ck::index_t i_b = (broadcast_a_dim_b ? 0 : batch); + ck::index_t i_m = (broadcast_a_dim_m ? 0 : m); + ck::index_t i_n = (broadcast_a_dim_n ? 0 : n); + + v_a = ck::type_convert(a_element_op(a_b_m_n(i_b, i_m, i_n))); + } + + AccDataType v_b{}; + { + ck::index_t i_b = (broadcast_b_dim_b ? 0 : batch); + ck::index_t i_m = (broadcast_b_dim_m ? 0 : m); + ck::index_t i_n = (broadcast_b_dim_n ? 0 : n); + + v_b = ck::type_convert(b_element_op(b_b_m_n(i_b, i_m, i_n))); + } + + c_b_m_n(batch, m, n) = ck::type_convert(binary_element_op(v_a, v_b)); + } + }; + + make_ParallelTensorFunctor(f, c_b_m_n.mDesc.GetLengths()[0], c_b_m_n.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); +} diff --git a/example/91_tile_program/reference_batched_gemm.hpp b/example/91_tile_program/reference/reference_batched_gemm.hpp similarity index 61% rename from example/91_tile_program/reference_batched_gemm.hpp rename to example/91_tile_program/reference/reference_batched_gemm.hpp index a29af3e30..f4e03fcc9 100644 --- a/example/91_tile_program/reference_batched_gemm.hpp +++ b/example/91_tile_program/reference/reference_batched_gemm.hpp @@ -10,15 +10,15 @@ template + typename AElementOp = ck::identity, + typename BElementOp = ck::identity, + typename ACCElementOp = ck::identity> void reference_batched_gemm(const Tensor& a_b_m_k, const Tensor& b_b_n_k, Tensor& c_b_m_n, - const AElementOp& a_element_op, - const BElementOp& b_element_op, - const ACCElementOp& acc_element_op) + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const int N = b_b_n_k.mDesc.GetLengths()[1]; const int K = b_b_n_k.mDesc.GetLengths()[2]; @@ -43,17 +43,3 @@ void reference_batched_gemm(const Tensor& a_b_m_k, make_ParallelTensorFunctor(f, c_b_m_n.mDesc.GetLengths()[0], c_b_m_n.mDesc.GetLengths()[1])( std::thread::hardware_concurrency()); } - -template -void reference_batched_gemm(const Tensor& a_b_m_k, - const Tensor& b_b_n_k, - Tensor& c_b_m_n) -{ - reference_batched_gemm( - a_b_m_k, - b_b_n_k, - c_b_m_n, - [](const ADataType& x) { return x; }, - [](const BDataType& x) { return x; }, - [](const AccDataType& x) { return x; }); -} diff --git a/example/91_tile_program/reference/reference_batched_masking.hpp b/example/91_tile_program/reference/reference_batched_masking.hpp new file mode 100644 index 000000000..5fc54457d --- /dev/null +++ b/example/91_tile_program/reference/reference_batched_masking.hpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/tile_program/block_tile/block_masking.hpp" + +template +void reference_batched_masking(Tensor& c_b_m_n, const MaskingType& mask) +{ + const int M = c_b_m_n.mDesc.GetLengths()[1]; + const int N = c_b_m_n.mDesc.GetLengths()[2]; + + auto f = [&](auto batch) { + for(int n = 0; n < N; ++n) + { + for(int m = 0; m < M; ++m) + { + if(mask.IsOutOfBound(m, n)) + c_b_m_n(batch, m, n) = -ck::NumericLimits::Infinity(); + } + } + }; + + make_ParallelTensorFunctor(f, + c_b_m_n.mDesc.GetLengths()[0])(std::thread::hardware_concurrency()); +} diff --git a/example/91_tile_program/reference_batched_softmax.hpp b/example/91_tile_program/reference/reference_batched_softmax.hpp similarity index 60% rename from example/91_tile_program/reference_batched_softmax.hpp rename to example/91_tile_program/reference/reference_batched_softmax.hpp index a9fa3f103..ae6c861a4 100644 --- a/example/91_tile_program/reference_batched_softmax.hpp +++ b/example/91_tile_program/reference/reference_batched_softmax.hpp @@ -3,16 +3,23 @@ #pragma once +#include +#include +#include + #include "ck/utility/common_header.hpp" #include "ck/library/utility/host_tensor.hpp" template -void reference_batched_softmax(const Tensor& a_b_m_n, Tensor& b_b_m_n) +void reference_batched_softmax( + const Tensor& a_b_m_n, + Tensor& b_b_m_n, + std::optional>> lse_b_m = std::nullopt) { const int N = a_b_m_n.mDesc.GetLengths()[2]; auto f = [&](auto batch, auto m) { - CompDataType v_max = ck::NumericLimits::Lowest(); + CompDataType v_max = -ck::NumericLimits::Infinity(); // max for(int n = 0; n < N; ++n) @@ -23,6 +30,11 @@ void reference_batched_softmax(const Tensor& a_b_m_n, Tensor(0.f); + } // sum for(int n = 0; n < N; ++n) @@ -32,13 +44,21 @@ void reference_batched_softmax(const Tensor& a_b_m_n, Tensor(a_b_m_n(batch, m, n)); b_b_m_n(batch, m, n) = - ck::type_convert(ck::math::exp(v_a - v_max) / v_exp_sum); + ck::type_convert(ck::math::exp(v_a - v_max) * inv_sum); + } + // lse + if(lse_b_m) + { + lse_b_m->get()(batch, m) = v_max + ck::math::log(v_exp_sum); } }; diff --git a/example/91_tile_program/reference_gemm.hpp b/example/91_tile_program/reference/reference_gemm.hpp similarity index 53% rename from example/91_tile_program/reference_gemm.hpp rename to example/91_tile_program/reference/reference_gemm.hpp index a558e5719..1972214b9 100644 --- a/example/91_tile_program/reference_gemm.hpp +++ b/example/91_tile_program/reference/reference_gemm.hpp @@ -6,10 +6,19 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/host_tensor.hpp" -template +template void reference_gemm(const Tensor& a_m_k, const Tensor& b_n_k, - Tensor& c_m_n) + Tensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const int N = b_n_k.mDesc.GetLengths()[0]; const int K = b_n_k.mDesc.GetLengths()[1]; @@ -21,13 +30,13 @@ void reference_gemm(const Tensor& a_m_k, for(int k = 0; k < K; ++k) { - ADataType v_a = a_m_k(m, k); - BDataType v_b = b_n_k(n, k); + ADataType v_a = a_element_op(a_m_k(m, k)); + BDataType v_b = b_element_op(b_n_k(n, k)); v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); } - c_m_n(m, n) = ck::type_convert(v_acc); + c_m_n(m, n) = ck::type_convert(acc_element_op(v_acc)); } }; diff --git a/example/91_tile_program/reference/reference_im2col.hpp b/example/91_tile_program/reference/reference_im2col.hpp new file mode 100644 index 000000000..44ecab29f --- /dev/null +++ b/example/91_tile_program/reference/reference_im2col.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template +void reference_im2col(Tensor& in_mtx_host_ref, + const Tensor& in_host, + int /*N*/, + int /*K*/, + int C, + int /*Y*/, + int X, + int Hi, + int Wi, + int Ho, + int Wo, + int ConvStrideH, + int ConvStrideW, + int ConvDilationH, + int ConvDilationW, + int InLeftPadH, + int InLeftPadW, + int /*InRightPadH*/, + int /*InRightPadW*/) +{ + int GemmM = in_mtx_host_ref.GetLengths()[0]; + int GemmK = in_mtx_host_ref.GetLengths()[1]; + + for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m) + { + int mtmp = gemm_m; + int n = mtmp / (Ho * Wo); + mtmp -= n * Ho * Wo; + int ho = mtmp / Wo; + int wo = mtmp - ho * Wo; + + for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k) + { + int ktmp = gemm_k; + int y = ktmp / (X * C); + ktmp -= y * X * C; + int x = ktmp / C; + int c = ktmp - x * C; + + int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH; + int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW; + + bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi); + + in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0; + } + } +} diff --git a/example/91_tile_program/reference/reference_reduce.hpp b/example/91_tile_program/reference/reference_reduce.hpp new file mode 100644 index 000000000..a4e0941f3 --- /dev/null +++ b/example/91_tile_program/reference/reference_reduce.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template +void reference_reduce(const Tensor& a_m_n, Tensor& b_m) +{ + auto f = [&](auto m) { + const int N = a_m_n.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_m_n(m, n); + + v_acc += v_a; + } + + b_m(m) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f, b_m.mDesc.GetLengths()[0])(std::thread::hardware_concurrency()); +} diff --git a/example/91_tile_program/reference_softmax.hpp b/example/91_tile_program/reference/reference_softmax.hpp similarity index 100% rename from example/91_tile_program/reference_softmax.hpp rename to example/91_tile_program/reference/reference_softmax.hpp diff --git a/example/91_tile_program/softmax/CMakeLists.txt b/example/91_tile_program/softmax/CMakeLists.txt new file mode 100644 index 000000000..da580fbff --- /dev/null +++ b/example/91_tile_program/softmax/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_softmax softmax.cpp) diff --git a/example/91_tile_program/softmax.cpp b/example/91_tile_program/softmax/softmax.cpp similarity index 98% rename from example/91_tile_program/softmax.cpp rename to example/91_tile_program/softmax/softmax.cpp index f78d609f2..93d1279d8 100644 --- a/example/91_tile_program/softmax.cpp +++ b/example/91_tile_program/softmax/softmax.cpp @@ -13,7 +13,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_softmax.hpp" +#include "reference/reference_softmax.hpp" #include "softmax.hpp" int main(int argc, char* argv[]) diff --git a/example/91_tile_program/softmax.hpp b/example/91_tile_program/softmax/softmax.hpp similarity index 100% rename from example/91_tile_program/softmax.hpp rename to example/91_tile_program/softmax/softmax.hpp diff --git a/include/ck/config.h b/include/ck/config.h new file mode 100644 index 000000000..dbf4a9597 --- /dev/null +++ b/include/ck/config.h @@ -0,0 +1,109 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_CONFIG_H_IN +#define CK_CONFIG_H_IN + +// clang-format off +// +// DataType supports in the current CK build +// +#ifndef DTYPES +/* #undef DTYPES */ +#endif +// if DTYPES is not defined, enable all datatypes in headerfiles +#ifndef CK_ENABLE_ALL_DTYPES +#define CK_ENABLE_ALL_DTYPES ON +#if defined(CK_ENABLE_ALL_DTYPES) +#ifndef CK_ENABLE_INT8 +#define CK_ENABLE_INT8 "ON" +#endif +#ifndef CK_ENABLE_FP8 +#define CK_ENABLE_FP8 "ON" +#endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif +#ifndef CK_ENABLE_FP16 +#define CK_ENABLE_FP16 "ON" +#endif +#ifndef CK_ENABLE_BF16 +#define CK_ENABLE_BF16 "ON" +#endif +#ifndef CK_ENABLE_FP32 +#define CK_ENABLE_FP32 "ON" +#endif +#ifndef CK_ENABLE_FP64 +#define CK_ENABLE_FP64 "ON" +#endif +#endif +#endif +// if DTYPES are selectively enabled +#ifndef CK_ENABLE_INT8 +/* #undef CK_ENABLE_INT8 */ +#endif + +#ifndef CK_ENABLE_FP8 +/* #undef CK_ENABLE_FP8 */ +#endif + +#ifndef CK_ENABLE_BF8 +/* #undef CK_ENABLE_BF8 */ +#endif + +#ifndef CK_ENABLE_FP16 +/* #undef CK_ENABLE_FP16 */ +#endif + +#ifndef CK_ENABLE_BF16 +/* #undef CK_ENABLE_BF16 */ +#endif + +#ifndef CK_ENABLE_FP32 +/* #undef CK_ENABLE_FP32 */ +#endif + +#ifndef CK_ENABLE_FP64 +/* #undef CK_ENABLE_FP64 */ +#endif + +// +// Legacy DL kernel supports in the current CK build +// by default DL kernels are turned OFF +// +#ifndef CK_ENABLE_DL_KERNELS +/* #undef CK_ENABLE_DL_KERNELS */ +#endif + +// +// Instances supports in the current CK build +// +#ifndef CK_ENABLE_INSTANCES_ONLY +/* #undef CK_ENABLE_INSTANCES_ONLY */ +#endif + +// clang-format on + +#endif // CK_CONFIG_H_IN diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index 3e44faecb..ca1f91ad8 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -4,6 +4,8 @@ #pragma once #include +#include + #include // To be removed, which really does not tell the location of failed HIP functional call diff --git a/include/ck/host_utility/io.hpp b/include/ck/host_utility/io.hpp index 55734bab2..7578537be 100644 --- a/include/ck/host_utility/io.hpp +++ b/include/ck/host_utility/io.hpp @@ -13,15 +13,33 @@ template std::ostream& operator<<(std::ostream& os, const std::vector& v) { - std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); - return os; + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; } template std::ostream& operator<<(std::ostream& os, const std::array& v) { - std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); - return os; + os << "["; + for(std::size_t idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; } template diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index e0d6e32ec..a58eca557 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -54,11 +54,11 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #endif hipEvent_t start, stop; - hip_check_error(hipEventCreate(&start)); - hip_check_error(hipEventCreate(&stop)); + HIP_CHECK_ERROR(hipEventCreate(&start)); + HIP_CHECK_ERROR(hipEventCreate(&stop)); - hip_check_error(hipDeviceSynchronize()); - hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start, stream_config.stream_id_)); for(int i = 0; i < nrepeat; ++i) { @@ -66,12 +66,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config, hip_check_error(hipGetLastError()); } - hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); - hip_check_error(hipEventSynchronize(stop)); + HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_)); + HIP_CHECK_ERROR(hipEventSynchronize(stop)); float total_time = 0; - hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); return total_time / nrepeat; } @@ -125,11 +125,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #endif hipEvent_t start, stop; - hip_check_error(hipEventCreate(&start)); - hip_check_error(hipEventCreate(&stop)); + HIP_CHECK_ERROR(hipEventCreate(&start)); + HIP_CHECK_ERROR(hipEventCreate(&stop)); - hip_check_error(hipDeviceSynchronize()); - hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start, stream_config.stream_id_)); for(int i = 0; i < nrepeat; ++i) { @@ -138,12 +138,12 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, hip_check_error(hipGetLastError()); } - hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); - hip_check_error(hipEventSynchronize(stop)); + HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_)); + HIP_CHECK_ERROR(hipEventSynchronize(stop)); float total_time = 0; - hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); return total_time / nrepeat; } diff --git a/include/ck/tensor/tensor_view.hpp b/include/ck/tensor/tensor_view.hpp index 0ecfcfa0a..03a8fcadf 100644 --- a/include/ck/tensor/tensor_view.hpp +++ b/include/ck/tensor/tensor_view.hpp @@ -53,15 +53,39 @@ struct TensorView // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X template >::type, typename scalar_type>::type>, bool>::type = false> __host__ __device__ constexpr remove_cvref_t - GetVectorizedElements(const TensorCoord& coord) const + GetVectorizedElements(const TensorCoord& coord, bool_constant = {}) const { return buf_.template Get( coord.GetOffset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + } + + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::type, + typename scalar_type>::type>, + bool>::type = false> + __host__ __device__ void GetVectorizedElementsRaw(remove_cvref_t& dst, + const TensorCoord& coord) const + { + return buf_.template GetRaw(dst, coord.GetOffset()); + } + + template >::type, + typename scalar_type>::type>, + bool>::type = false> + __host__ __device__ constexpr void AsyncGetVectorizedElements(remove_cvref_t* smem, + const TensorCoord& coord) const + { + return buf_.template AsyncGet(smem, coord.GetOffset(), true /*not used*/); } // X is vector of DataType. @@ -98,6 +122,11 @@ struct TensorView TensorDesc desc_; }; +// placeholder type if we want to opt-out a tile view parameter +struct NullTensorView +{ +}; + template @@ -168,4 +197,47 @@ __host__ __device__ constexpr auto transform_tensor_view(const OldTensorView& ol old_tensor_view.buf_, new_desc}; } +template + typename DoPads> // Sequence +__host__ __device__ constexpr auto +pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads) +{ + constexpr index_t num_dim = DoPads::Size(); + + static_assert(num_dim == TileLengths::Size() && num_dim == TensorView::GetNumOfDimension(), + "wrong! inconsistent # of dimensions"); + + // transforms + const auto transforms = generate_tuple( + [&](auto idim) { + const auto old_length = tensor_view.GetTensorDescriptor().GetLength(idim); + + const auto tile_length = tile_lengths[idim]; + + const auto new_length = + math::integer_divide_ceil(old_length, tile_length) * tile_length; + + const auto pad_length = new_length - old_length; + + constexpr bool DoPad = DoPads::At(idim); + + const auto transform = + conditional_expr(make_right_pad_transform(old_length, pad_length), + make_pass_through_transform(old_length)); + + return transform; + }, + Number{}); + + // lower dimension Id + const auto lower_dimss = + generate_tuple([&](auto idim) { return Sequence{}; }, Number{}); + + // upper dimension Id + const auto upper_dimss = lower_dimss; + + return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss); +} + } // namespace ck diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 5c166b9b6..445004dc6 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -18,6 +18,7 @@ enum struct IndexTransformEnum UnMerge, Replicate, Xor, + Offset, }; template @@ -1401,4 +1402,88 @@ struct Xor : public BaseTransform<2, 2> } }; +template +struct Offset : public BaseTransform<1, 1> +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + OffsetLength offset_length_; + + __host__ __device__ constexpr Offset() = default; + + __host__ __device__ constexpr Offset(const LowLength& low_length, + const OffsetLength& offset_length) + : up_lengths_{make_tuple(low_length)}, offset_length_{offset_length} + { + } + + __host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::Offset; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] + offset_length_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Offset{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("offset_length_: "); + print(offset_length_); + + printf("}"); + } +}; + } // namespace ck diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index 649a36a04..7ce003670 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -119,4 +119,11 @@ __host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_leng return Xor{low_lengths, right_shift}; } +template +__host__ __device__ constexpr auto make_offset_transform(const LowLength& low_length, + const OffsetLength& offset_length) +{ + return Offset{low_length, offset_length}; +} + } // namespace ck diff --git a/include/ck/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp index c49cc91c0..54f9f80b4 100644 --- a/include/ck/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -83,6 +83,59 @@ make_naive_tensor_descriptor(const Tuple& lengths, GuaranteedVectorStrides>{transforms, element_space_size}; } +// tensor descriptor with offset, the offset will not be added into element space size +// only have an information of the starting offset, and will impact on offset calculation +template ::type = false> +__host__ __device__ constexpr auto +make_naive_tensor_descriptor_with_offset(const Tuple& lengths, + const Tuple& strides, + const Offset& offset, + Number = Number<-1>{}, + Number = Number<-1>{}) +{ + const auto desc_0 = [&]() { + const auto element_space_size = detail::calculate_element_space_size_impl( + lengths, strides, Number<0>{}, LongNumber<1>{}); + + const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = make_tuple(Sequence<1>{}); + + constexpr auto visible_dim_hidden_ids = Sequence<1>{}; + + using GuaranteedVectorLengths = + typename sequence_merge::type, + Sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, + Sequence>::type; + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; + }(); + + constexpr index_t N = sizeof...(Lengths); + + return transform_tensor_descriptor( + desc_0, + make_tuple(make_embed_transform(lengths, strides)), + make_tuple(Sequence<0>{}), + make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{})); +} + // Lengths... could be: // 1) index_t, which is known at run-time, or // 2) Number<>, which is known at compile-time @@ -123,6 +176,53 @@ make_naive_tensor_descriptor_packed(const Tuple& lengths, GuaranteedVectorStrides>{transforms, element_space_size}; } +template ::type = false> +__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed_with_offset( + const Tuple& lengths, + const Offset& offset, + Number = Number<-1>{}) +{ + const auto desc_0 = [&]() { + const auto element_space_size = + container_reduce(lengths, math::multiplies{}, LongNumber<1>{}); + + const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = make_tuple(Sequence<1>{}); + + constexpr auto visible_dim_hidden_ids = Sequence<1>{}; + + using GuaranteedVectorLengths = + typename sequence_merge::type, + Sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, Sequence<1>>::type; + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; + }(); + + constexpr index_t N = sizeof...(Lengths); + + return transform_tensor_descriptor( + desc_0, + make_tuple(make_unmerge_transform(lengths)), + make_tuple(Sequence<0>{}), + make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{})); +} + // Lengths... could be: // 1) index_t, which is known at run-time, or // 2) Number<>, which is known at compile-time diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index 8fc0b03a5..d4e35ee82 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -20,6 +20,10 @@ template // # of scalars per access in each dimension struct SpaceFillingCurve { + static constexpr index_t TensorSize = + reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}); + static_assert(0 < TensorSize, "SpaceFillingCurve should be used to access a non-empty tensor"); + static constexpr index_t nDim = TensorLengths::Size(); using Index = MultiIndex; diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp new file mode 100644 index 000000000..ad6193cc9 --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + __device__ void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.GetLengths()[Number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.GetLengths()[Number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.GetThreadBuffer() = a_block_tensor_tmp.GetThreadBuffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + b_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using Array will cause register spill + Array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + StaticallyIndexedArray, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert(is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths()); + constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetYSlicedThreadData( + merge_sequences(Sequence{}, a_warp_y_index_zeros), + merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetYSlicedThreadData( + merge_sequences(Sequence{}, c_warp_y_index_zeros), + merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.SetYSlicedThreadData( + merge_sequences(Sequence{}, c_warp_y_index_zeros), + merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.GetThreadBuffer()); + }); + }); + }); + } + + __device__ constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + __device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp new file mode 100644 index 000000000..842b0ce38 --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +template +struct BlockGemmARegBSmemCRegV2CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::At(Number<0>{}); + static constexpr index_t kNWarps = BlockWarps::At(Number<1>{}); + static constexpr index_t kKWarps = BlockWarps::At(Number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() + { + using namespace ck::tile_program::warp; + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp new file mode 100644 index 000000000..f7306e67a --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// Default policy for BlockGemmARegBSmemCRegV2 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmARegBSmemCRegV2DefaultPolicy +{ + template + __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() + { + using namespace ck::tile_program::warp; + +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp index ff3c44db7..c87023485 100644 --- a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp @@ -177,22 +177,10 @@ struct BlockGemmASmemBSmemCRegV1 }); } - // C = A * B - template - __device__ auto operator()(const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + __device__ constexpr auto MakeCBlockTile() const { - static_assert(is_same_v && - is_same_v, - "wrong!"); - - constexpr index_t MPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}]; - constexpr index_t KPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<1>{}]; - - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -203,86 +191,7 @@ struct BlockGemmASmemBSmemCRegV1 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - - constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; - constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() % NWarp; - - // construct A-warp-window - auto a_warp_window_tmp = make_tile_window( - a_block_window_tmp.GetBottomTensorView(), - make_tuple(Number{}, Number{}), - a_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - -#if 0 // FIXME: using Array will cause register spill - Array, MIterPerWarp> a_warp_windows{ - {a_warp_window_tmp}}; - - for(index_t mIter = 0; mIter < MIterPerWarp; mIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - StaticallyIndexedArray, - MIterPerWarp> - a_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window_tmp.GetBottomTensorView(), - make_tuple(Number{}, Number{}), - b_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iNWarp * WG::kN, 0}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - -#if 0 // FIXME: using Array will cause register spill - Array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; - - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - StaticallyIndexedArray, - NIterPerWarp> - b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - static_assert(is_same_v, "wrong!"); - - // Construct C-Block-Tensor constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding< Sequence<>, Tuple, Sequence>, @@ -297,51 +206,16 @@ struct BlockGemmASmemBSmemCRegV1 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - // warp GEMM - if constexpr(KIterPerWarp == 0) - { - // c = a * b - c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor); - } - else - { - // c += a * b - c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetYSlicedThreadData( - merge_sequences(Sequence{}, c_warp_y_index_zeros), - merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths)); - - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } - - // write C warp tensor into C block tensor - c_block_tensor.SetYSlicedThreadData( - merge_sequences(Sequence{}, c_warp_y_index_zeros), - merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.GetThreadBuffer()); - }); - }); - }); - + // C = A * B + template + __device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); return c_block_tensor; } }; diff --git a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index d22d702cb..dfe545bd3 100644 --- a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -23,34 +23,26 @@ template + typename WarpGemm_> struct BlockGemmASmemBSmemCRegV1CustomPolicy { using AType = remove_cvref_t; using BType = remove_cvref_t; using CType = remove_cvref_t; - using BlockWarps = remove_cvref_t; - using WarpTile = remove_cvref_t; - static constexpr index_t BlockMWarps = BlockWarps::At(Number<0>{}); - static constexpr index_t BlockNWarps = BlockWarps::At(Number<1>{}); - static constexpr index_t BlockKWarps = BlockWarps::At(Number<2>{}); + using BlockWarps = remove_cvref_t; - static constexpr index_t MPerWarp = WarpTile::At(Number<0>{}); - static constexpr index_t NPerWarp = WarpTile::At(Number<1>{}); - static constexpr index_t KPerWarp = WarpTile::At(Number<2>{}); + static constexpr index_t kMWarps = BlockWarps::At(Number<0>{}); + static constexpr index_t kNWarps = BlockWarps::At(Number<1>{}); + static constexpr index_t kKWarps = BlockWarps::At(Number<2>{}); - static constexpr bool TranposeC = TranposeC_; - - using WarpGemm = ck::tile_program::warp:: - WarpGemmMfmaDispatcher; + using WarpGemm = remove_cvref_t; template __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() { using namespace ck::tile_program::warp; - return make_tuple(WarpGemm{}, BlockMWarps, BlockNWarps); + return make_tuple(WarpGemm{}, kMWarps, kNWarps); } }; diff --git a/include/ck/tile_program/block_tile/block_masking.hpp b/include/ck/tile_program/block_tile/block_masking.hpp new file mode 100644 index 000000000..1e01310d8 --- /dev/null +++ b/include/ck/tile_program/block_tile/block_masking.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// clang-format off +/* Generic Attention Mask Coordinate + use x(horizontal axis), y(vertical axis) to describe mask. + top-left corner is origin + + x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask) + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + 1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + l=7,-1/r=0(tl) l=7,-1/r=0(br) + + x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2 + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 + * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1 + l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl) + l=4/r=0(br) l=4/r=2(br) l=4/r=4(br) + + x=4/y=-1 x=6/y=-1 x=8/y=-1 + * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 + * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 + * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1 + * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1 + + x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r) + * * * * * * * * 1 * * * * * * * + * * * * * * * * 1 1 * * 1 * * * + * * * * * * * * 1 1 1 * 1 1 * * + 1 * * * * * * * 1 1 1 1 1 1 1 * + 1 1 * * * * * * 1 1 1 1 1 1 1 1 + + Validations: + x + y > 1 (x + y >= 2) + + Note: + y = seq_q, x = 1 -> top-left + y = seq_q, x = seq_k - seq_q + 1 -> bottom-right + y < seq_q, x < seq_k -> local-attn + y = seq_q, x = seq_k -> no mask + +*/ +// clang-format on +template +struct GenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask, + // else only upper-right could have mask + + __host__ __device__ GenericAttentionMask(index_t y_total_, index_t x_total_) + : GenericAttentionMask(0, 0, y_total_, x_total_) + { + } + + __host__ __device__ + GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + __host__ __device__ GenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.At(Number<0>{})), + x(mask_coord.At(Number<1>{})), + y_total(mask_coord.At(Number<2>{})), + x_total(mask_coord.At(Number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + __host__ __device__ constexpr auto + GetTileRangeAlongX(index_t i_y, Number, Number) const + { + if constexpr(!IsMasking) + { + return ck::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + if constexpr(IsLocal) + { + index_t tmp = math::max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + } + else + { + return 0; + } + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = math::min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck::make_tuple(x_start, x_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + __host__ __device__ constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; + index_t x_end = math::min(i_y + x, x_total); + + if constexpr(IsLocal) + { + return i_x < x_start || i_x >= x_end; + } + else + { + return i_x >= x_end; + } + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // can be used as a fast-path to decide if do per-pixel check or not + template + __host__ __device__ constexpr auto + IsEdgeTile(index_t i_tile_top, index_t i_tile_left, Number, Number) const + { + if constexpr(IsLocal) + { + // check top-right corner > x or left-borrom corner < x + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + index_t x_end = math::min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > (i_tile_top + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + } + else + { + // only need to check top-right corner > x + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = math::min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > x_end; + return top_right_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + +} // namespace block +} // namespace tile_program + +// TODO: prefer use this function in host code +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +__host__ __device__ constexpr auto +make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + index_t x = 0, y = 0; + + if(is_top_left) + { + if(left_size < 0) + left_size = y_total - 1; + if(right_size < 0) + right_size = x_total - 1; + + x = 1 + right_size; + y = left_size + 1; + } + else + { + if(left_size < 0) + left_size = x_total - 1; + if(right_size < 0) + right_size = y_total - 1; + + x = x_total - y_total + 1 + right_size; + y = y_total - x_total + 1 + left_size; + } + + return ck::make_tuple(y, x, y_total, x_total); +} +} // namespace ck diff --git a/include/ck/tile_program/block_tile/block_reduce.hpp b/include/ck/tile_program/block_tile/block_reduce.hpp index 1cba690c7..08a8f8a42 100644 --- a/include/ck/tile_program/block_tile/block_reduce.hpp +++ b/include/ck/tile_program/block_tile/block_reduce.hpp @@ -14,9 +14,10 @@ namespace tile_program { namespace block { // synchronize reduce result (cross lane reduction and broadcast on replicated dimension) -template +template __device__ void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, - const ReduceFunc& reduce_func) + const ReduceFunc& reduce_func, + bool_constant = {}) { using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using DstrEncode = typename Dstr::DstrEncode; @@ -67,40 +68,43 @@ __device__ void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, } }); - // cross-lane broadcast for replication - // only broadcast on R dimension correspond to lane - // (lane id maps to this R dimension) - static_for<0, NDimR, 1>{}([&](auto idim_r) { - // FIXME: nasty to use does_p_own_r_ - if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) - { - const index_t r_id = rs_idx[idim_r]; + if constexpr(WithBroadcast) + { + // cross-lane broadcast for replication + // only broadcast on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + const index_t r_id = rs_idx[idim_r]; - constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; - constexpr index_t lid_over_rid_derivative = - DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; - static_assert(math::is_power_of_two_integer(r_length), - "wrong! only support power of 2 reduction"); + static_assert(math::is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); - constexpr index_t nstage = math::integer_log2_floor(r_length); + constexpr index_t nstage = math::integer_log2_floor(r_length); - // broadcast sweep backward - static_for<0, nstage, 1>{}([&](auto istage) { - // do I hold reduced data? - const bool do_i_hold_reduced_data = r_id < (1 << istage); + // broadcast sweep backward + static_for<0, nstage, 1>{}([&](auto istage) { + // do I hold reduced data? + const bool do_i_hold_reduced_data = r_id < (1 << istage); - constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); + constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); - // pull data from remote lane - const auto v_remote = warp_shuffle_up(v_local, lid_delta); + // pull data from remote lane + const auto v_remote = warp_shuffle_up(v_local, lid_delta); - // decide whether to update local data with remote data - v_local = do_i_hold_reduced_data ? v_local : v_remote; - }); - } - }); + // decide whether to update local data with remote data + v_local = do_i_hold_reduced_data ? v_local : v_remote; + }); + } + }); + } acc_tensor.GetThreadBuffer()(i) = v_local; }); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp index 7804ff259..a00163fe9 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp @@ -4,7 +4,9 @@ #pragma once #include "ck/ck.hpp" +#include "ck/utility/get_id.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/data_type.hpp" namespace ck { namespace tile_program { @@ -15,11 +17,15 @@ template + typename BlockFmhaShape_, + bool kIsGroupMode_, + typename FmhaMask_, + typename Traits_> struct BlockFmhaPipelineProblem { using QDataType = remove_cvref_t; @@ -27,12 +33,33 @@ struct BlockFmhaPipelineProblem using VDataType = remove_cvref_t; using SaccDataType = remove_cvref_t; using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasBias = Traits::kHasBias; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr bool kIsFp8 = + (is_same_v || is_same_v)&&( + is_same_v || + is_same_v)&&(is_same_v || + is_same_v)&&is_same_v && + is_same_v; }; } // namespace block diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp deleted file mode 100644 index 934c5c90f..000000000 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp +++ /dev/null @@ -1,348 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" - -#include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/load_tile.hpp" -#include "ck/tile_program/tile/store_tile.hpp" -#include "ck/tile_program/tile/tile_elementwise.hpp" -#include "ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/tile/slice_tile.hpp" -#include "ck/tile_program/warp_tile/warp_gemm.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp" -#include "ck/tile_program/block_tile/block_reduce.hpp" - -namespace ck { -namespace tile_program { -namespace block { - -// This pipeline is qkv all located in LDS -template -struct BlockFmhaPipelineQKVS -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - - using BlockFmhaShape = remove_cvref_t; - static constexpr bool kQLoadOnce = false; // if q load whole block length (hdim) at once - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - - __host__ __device__ static constexpr ck::index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - __host__ __device__ auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, - float scale, - index_t num_total_loop, - index_t num_sub_loop_qk, - void* smem_ptr) const - { - static_assert( - is_same_v> && - is_same_v> && - is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && - kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && - kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], - "wrong!"); - - // Q tile in LDS - auto q_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_window = - make_tile_window(q_lds, make_tuple(Number{}, Number{}), {0, 0}); - - // K tile in LDS - KDataType* k_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); - auto k_lds = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_window = - make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); - - // V tile in LDS - auto v_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_window = - make_tile_window(v_lds, make_tuple(Number{}, Number{}), {0, 0}); - - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - - auto s_acc = decltype(gemm_0(q_lds_window, k_lds_window)){}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = - decltype(tile_elementwise_in(type_convert, s_acc)); - - using PBlockTileType = - decltype(tile_elementwise_in(type_convert, s_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); - - using OaccBlockTileType = decltype(gemm_1( - get_slice_tile(PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), - v_lds_window)); - - // init Oacc, M, L - auto o_acc = OaccBlockTileType{}; - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; - - tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); - tile_elementwise_inout([](auto& e) { e = NumericLimits::Lowest(); }, - m); - tile_elementwise_inout([](auto& e) { e = 0; }, l); - - auto k_dram_block_window = k_dram_block_window_tmp; - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), - v_dram_block_window_tmp.GetWindowLengths(), - v_dram_block_window_tmp.GetWindowOrigin(), - Policy::template MakeVDramTileDistribution()); - - index_t i_total_loops = 0; - do - { - // STAGE 1, QK gemm - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp.GetBottomTensorView(), - q_dram_block_window_tmp.GetWindowLengths(), - q_dram_block_window_tmp.GetWindowOrigin(), - Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for - // load - - auto k_dram_window = make_tile_window( - k_dram_block_window.GetBottomTensorView(), - k_dram_block_window.GetWindowLengths(), - k_dram_block_window.GetWindowOrigin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load - - auto q_block_tile = load_tile(q_dram_window); // prefetch, global read 0 - auto k_block_tile = load_tile(k_dram_window); - { - move_tile_window(q_dram_window, {0, kK0}); // move to 1 - move_tile_window(k_dram_window, {0, kK0}); - - tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); // Initialize C - - store_tile(q_lds_window, - tile_elementwise_in(q_element_func, q_block_tile)); // LDS write 0 - q_block_tile = load_tile(q_dram_window); // global read 1 - store_tile(k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write 0 - k_block_tile = load_tile(k_dram_window); // global read 1 - } - - index_t i_k0_loops = num_sub_loop_qk - 2; - do - { - block_sync_lds(); - gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM i - block_sync_lds(); - - move_tile_window(q_dram_window, {0, kK0}); // move to i + 2 - move_tile_window(k_dram_window, {0, kK0}); - - store_tile(q_lds_window, - tile_elementwise_in(q_element_func, q_block_tile)); // LDS write i + 1 - q_block_tile = load_tile(q_dram_window); // global read i + 2 - store_tile(k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 - k_block_tile = load_tile(k_dram_window); // global read i + 2 - - i_k0_loops--; - } while(i_k0_loops > 0); - - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile - - { // tail - block_sync_lds(); - gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM num_loop - 2 - block_sync_lds(); - - store_tile( - q_lds_window, - tile_elementwise_in(q_element_func, q_block_tile)); // LDS write num_loop - 1 - store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); - block_sync_lds(); - - gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM num_loop - 1 - } - - // STAGE 2, scale softmax - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); - - const auto s = - tile_elementwise_in(type_convert, s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - Sequence<1>{}, - f_max, - NumericLimits::Lowest()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s.GetTileDistribution()); // Pcompute{j} - - constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); - sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync(rowsum_p, f_sum); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); - sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? - o_acc(i_j_idx) *= tmp; - }); - }); - - block_sync_lds(); - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch - move_tile_window(v_dram_window, {0, kK1}); - - const auto p = - tile_elementwise_in(type_convert, p_compute); - - // STAGE 3, KV gemm - constexpr index_t k1_loops = kN0 / kK1; - if constexpr(k1_loops > 1) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, Sequence<0, i_k1 * kK1>{}, Sequence{}), - v_lds_window); - block_sync_lds(); - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v)); // store next v - move_tile_window(v_dram_window, {0, kK1}); - }); - } - // tail - { - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), - v_lds_window); - block_sync_lds(); - } - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - - i_total_loops++; - } while(i_total_loops < num_total_loop); - - // finally, O - constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); - - sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = 1 / l[i_idx]; - sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - - return o_acc; - } - - template - __host__ __device__ auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - float scale, - index_t num_total_loop, - index_t num_sub_loop_qk, - void* smem_ptr) const - { - return operator()( - q_dram_block_window_tmp, - [](const QDataType& x) { return x; }, - k_dram_block_window_tmp, - [](const KDataType& x) { return x; }, - v_dram_block_window_tmp, - [](const VDataType& x) { return x; }, - scale, - num_total_loop, - num_sub_loop_qk, - smem_ptr); - } -}; - -} // namespace block -} // namespace tile_program -} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp deleted file mode 100644 index 833e0787d..000000000 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp +++ /dev/null @@ -1,264 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" - -#include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/tile_elementwise.hpp" -#include "ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/warp_tile/warp_gemm.hpp" -#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" -#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp" - -namespace ck { -namespace tile_program { -namespace block { - -// This pipeline is qkv all located in LDS -struct BlockFmhaPipelineQKVSDefaultPolicy -{ - // 3d + padding - template - __host__ __device__ static constexpr auto MakeQLdsBlockDescriptor() - { - using namespace ck; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number<8>{}), - make_tuple(Number<(kMPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), - Number<8>{}, - Number<1>{}); - - constexpr auto q_lds_block_desc = transform_tensor_descriptor( - q_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return q_lds_block_desc; - } - - // 3d + padding - template - __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() - { - using namespace ck; - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number<8>{}), - make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), - Number<8>{}, - Number<1>{}); - - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return k_lds_block_desc; - } - - // 3d + padding - template - __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kPad = 1; - constexpr index_t kK1 = 8; - - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number{}), - make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number{}, Number<1>{}), - Number{}, - Number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return v_lds_block_desc; - } - - template - __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() - { - constexpr index_t lds_alignment = 16; // optional - constexpr index_t q_smem_size = - ck::math::integer_divide_ceil( - sizeof(typename Problem::QDataType) * - MakeQLdsBlockDescriptor().GetElementSpaceSize(), - lds_alignment) * - lds_alignment; - return q_smem_size; - } - - template - __host__ __device__ static constexpr ck::index_t GetSmemSize() - { - constexpr index_t smem_size_gemm_0 = - GetSmemSizeQ() + sizeof(typename Problem::KDataType) * - MakeKLdsBlockDescriptor().GetElementSpaceSize(); - constexpr index_t smem_size_gemm_1 = - MakeVLdsBlockDescriptor().GetElementSpaceSize() * - sizeof(typename Problem::VDataType); - - // TODO: consider shuffle requirement - return math::max(smem_size_gemm_0, smem_size_gemm_1); - } - - template - __host__ __device__ static constexpr auto MakeQDramTileDistribution() - { - using QDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M1 = kMPerBlock / (M2 * M0); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<1, 1>>{}); -#endif - } - - template - __host__ __device__ static constexpr auto MakeKDramTileDistribution() - { - using KDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t K1 = 16 / sizeof(KDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t N0 = kBlockSize / get_warp_size(); - constexpr index_t N1 = kNPerBlock / (N2 * N0); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<1, 1>>{}); -#endif - } - - template - __device__ static constexpr auto MakeVDramTileDistribution() - { - using VDataType = remove_cvref_t; - ; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr index_t K1 = 16 / sizeof(VDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); - } - - template - __host__ __device__ static constexpr auto GetQKBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; - - return BlockGemmASmemBSmemCRegV1{}; - } - - template - __host__ __device__ static constexpr auto GetKVBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy; - - return BlockGemmARegBSmemCRegV1{}; - } -}; - -} // namespace block -} // namespace tile_program -} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9224469e6..1400a2b20 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -32,15 +32,21 @@ struct BlockFmhaPipelineQRKSVS using VDataType = remove_cvref_t; using SaccDataType = remove_cvref_t; using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; - static constexpr bool kQLoadOnce = true; // if q load whole block length (hdim) at once + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; - static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kN0 = BlockFmhaShape::kN0; @@ -49,6 +55,14 @@ struct BlockFmhaPipelineQRKSVS static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + __host__ __device__ static constexpr ck::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -57,9 +71,13 @@ struct BlockFmhaPipelineQRKSVS template + typename VElementFunction, + typename BiasElementFunction, + typename LSEElementFunction> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -67,9 +85,12 @@ struct BlockFmhaPipelineQRKSVS const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + FmhaMask mask, float scale, - index_t num_total_loop, - index_t /*num_sub_loop_qk*/, // in this pipeline, the 1st gemm loop must be static void* smem_ptr) const { static_assert( @@ -82,7 +103,9 @@ struct BlockFmhaPipelineQRKSVS kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], "wrong!"); // K tile in LDS @@ -97,8 +120,8 @@ struct BlockFmhaPipelineQRKSVS auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_window = - make_tile_window(v_lds, make_tuple(Number{}, Number{}), {0, 0}); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().GetLengths(), {0, 0}); // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); @@ -110,50 +133,84 @@ struct BlockFmhaPipelineQRKSVS q_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeQDramTileDistribution()); - auto q = load_tile(q_dram_window); // persistent q register tile + auto q = load_tile(q_dram_window); - auto s_acc = decltype(gemm_0(get_slice_tile(tile_elementwise_in(q_element_func, q), - Sequence<0, 0>{}, - Sequence{}), - k_lds_window)){}; + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = - decltype(tile_elementwise_in(type_convert, s_acc)); - - using PBlockTileType = - decltype(tile_elementwise_in(type_convert, s_acc)); + using SBlockTileType = decltype(cast_tile(s_acc)); using MLBlockTileType = decltype(block_tile_reduce( SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); - using OaccBlockTileType = decltype(gemm_1( - get_slice_tile(PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), - v_lds_window)); + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); // init Oacc, M, L auto o_acc = OaccBlockTileType{}; auto m = MLBlockTileType{}; auto l = MLBlockTileType{}; - tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); - tile_elementwise_inout([](auto& e) { e = NumericLimits::Lowest(); }, - m); - tile_elementwise_inout([](auto& e) { e = 0; }, l); + clear_tile(o_acc); + set_tile(m, -NumericLimits::Infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.GetWindowOrigin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.At(Number<0>{}), Number{}, Number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.GetBottomTensorView(), + k_dram_block_window_tmp.GetWindowLengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.GetWindowOrigin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.GetBottomTensorView(), + bias_dram_block_window_tmp.GetWindowLengths(), + {bias_origin.At(Number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); - auto k_dram_block_window = k_dram_block_window_tmp; auto v_dram_window = make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), v_dram_block_window_tmp.GetWindowLengths(), - v_dram_block_window_tmp.GetWindowOrigin(), + {0, seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); - auto q_tile = tile_elementwise_in(q_element_func, q); - index_t i_total_loops = 0; + auto q_tile = tile_elementwise_in(q_element_func, q); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); do { // STAGE 1, QK gemm @@ -167,16 +224,22 @@ struct BlockFmhaPipelineQRKSVS auto k_block_tile = load_tile(k_dram_window); { move_tile_window(k_dram_window, {0, kK0}); - - tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); // Initialize C - - store_tile(k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write 0 - k_block_tile = load_tile(k_dram_window); // global read 1 + clear_tile(s_acc); // Initialize C + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); } - // index_t i_k0_loops = num_sub_loop_qk - 2; - constexpr index_t k0_loops = kK0BlockLength / kK0; + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } if constexpr(k0_loops > 2) { @@ -217,17 +280,53 @@ struct BlockFmhaPipelineQRKSVS k_lds_window); } - // STAGE 2, scale softmax - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + // STAGE 2, scale, add bias, mask, softmax + if constexpr(kHasBias) + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), + k_origin.At(Number<0>{}), + Number{}, + Number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } - const auto s = - tile_elementwise_in(type_convert, s_acc); // S{j} + const auto s = cast_tile(s_acc); // S{j} auto m_local = block_tile_reduce( s, Sequence<1>{}, f_max, - NumericLimits::Lowest()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max); + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} tile_elementwise_inout( @@ -236,25 +335,68 @@ struct BlockFmhaPipelineQRKSVS auto p_compute = make_static_distributed_tensor( s.GetTileDistribution()); // Pcompute{j} + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * get_validated_m(m[i_idx]); +#endif sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif }); }); auto rowsum_p = block_tile_reduce( p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - block_tile_reduce_sync(rowsum_p, f_sum); + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); // l{j}, Oacc{j} constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(kHasBias) + { + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); // FIXME: this use different equation from FA v2 paper, @@ -281,11 +423,9 @@ struct BlockFmhaPipelineQRKSVS } move_tile_window(v_dram_window, {0, kK1}); - const auto p = - tile_elementwise_in(type_convert, p_compute); + const auto p = cast_tile(p_compute); // STAGE 3, KV gemm - constexpr index_t k1_loops = kN0 / kK1; if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { @@ -315,7 +455,6 @@ struct BlockFmhaPipelineQRKSVS } // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); - i_total_loops++; // tail { block_sync_lds(); @@ -324,14 +463,46 @@ struct BlockFmhaPipelineQRKSVS v_lds_window); block_sync_lds(); } - } while(i_total_loops < num_total_loop); + } while(++i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + constexpr auto lse_spans = decltype(lse)::GetDistributedSpans(); + sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } // finally, O constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - const auto tmp = 1 / l[i_idx]; + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); o_acc(i_j_idx) *= tmp; @@ -343,27 +514,32 @@ struct BlockFmhaPipelineQRKSVS template + typename VDramBlockWindowTmp, + typename BiasDramBlockWindowTmp, + typename LSEDramBlockWindowTmp> __host__ __device__ auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, float scale, - index_t num_total_loop, - index_t num_sub_loop_qk, void* smem_ptr) const { - return operator()( - q_dram_block_window_tmp, - [](const QDataType& x) { return x; }, - k_dram_block_window_tmp, - [](const KDataType& x) { return x; }, - v_dram_block_window_tmp, - [](const VDataType& x) { return x; }, - scale, - num_total_loop, - num_sub_loop_qk, - smem_ptr); + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + mask, + scale, + smem_ptr); } }; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 000000000..eb33aea53 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,629 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/tile/slice_tile.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck/tile_program/block_tile/block_reduce.hpp" +#include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaPipelineQRKSVSAsync +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + +#if CK_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / math::log2e_v; +#endif + + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + FmhaMask mask, + float scale, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).GetLengths(), + {0, 0, 0}); + }, + Number{}); + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + auto k_lds_load = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), + Policy::template MakeKLdsLoadBlockDescriptor(i_buf).GetLengths(), + {0, 0}); + }, + Number{}); +#else + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().GetLengths(), + {0, 0}); +#endif + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().GetLengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.GetBottomTensorView(), + q_dram_block_window_tmp.GetWindowLengths(), + q_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeQDramTileDistribution()); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = load_tile_raw(q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -NumericLimits::Infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.GetWindowOrigin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.At(Number<0>{}), Number{}, Number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.GetBottomTensorView(), + k_dram_block_window_tmp.GetWindowLengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + const auto bias_origin = bias_dram_block_window_tmp.GetWindowOrigin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.GetBottomTensorView(), + bias_dram_block_window_tmp.GetWindowLengths(), + {bias_origin.At(Number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.GetNumAccess()); + + auto q_tile = tile_elementwise_in(q_element_func, q); + __builtin_amdgcn_sched_barrier(0); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // Initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(Number{})>{}), + k_dram_window); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.GetNumAccess()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, i_k0 * kK0>{}, + Sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[Number{})>{}]); + +#else + get_slice_tile(k_lds_load, + Sequence<(LdsSeq.At(Number{})) * kN0, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN0, kK0>{})); +#endif + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 1) * kK0>{}, + Sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[Number{})>{}]); + +#else + get_slice_tile( + k_lds_load, + Sequence<(LdsSeq.At(Number{})) * kN0, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN0, kK0>{})); +#endif + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(kHasBias) + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), + k_origin.At(Number<0>{}), + Number{}, + Number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + Sequence<1>{}, + f_max, + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.GetTileDistribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile(v_dram_window); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(kHasBias) + { + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, Sequence<0, i_k1 * kK1>{}, Sequence{}), + get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{})); + + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = + make_tile_window(k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); + + if constexpr(k1_loops >= 2 && + LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), + get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + constexpr auto lse_spans = decltype(lse)::GetDistributedSpans(); + sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale * R_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale, + void* smem_ptr) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + mask, + scale, + smem_ptr); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 000000000..29bcde926 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +using BlockFmhaPipelineQRKSVSAsyncDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp index b5680f78f..28d7ba2da 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -3,396 +3,19 @@ #pragma once -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" - -#include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/tile_elementwise.hpp" -#include "ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp" -#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" namespace ck { namespace tile_program { namespace block { // This pipeline is qkv all located in LDS -struct BlockFmhaPipelineQRKSVSDefaultPolicy -{ - template - __host__ __device__ static constexpr auto GetSmemKPackK() - { - // TODO: this is for 3d layout - using KDataType = remove_cvref_t; - return 16 / sizeof(KDataType); - } - - template - __host__ __device__ static constexpr auto GetSmemKPackV() - { - // TODO: this is for 3d layout - using VDataType = remove_cvref_t; - return 16 / sizeof(VDataType); - } - template - __host__ __device__ static constexpr auto GetTransposedVectorloadV() - { - return 4; // TODO: fix me - } - - template - __host__ __device__ static constexpr auto MakeQRegBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; - - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template At<1>(); - constexpr index_t NWarp = config.template At<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto q_block_outer_dstr_encoding = StaticTileDistributionEncoding< - Sequence, - Tuple, Sequence>, - Tuple>, - Tuple>, - Sequence<1, 2>, - Sequence<0, 0>>{}; - - constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - - return q_block_dstr; - } - - // 3d + padding - template - __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kKPack = GetSmemKPackV(); - - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number{}), - make_tuple(Number<(kNPerBlock + 1) * kKPack>{}, Number{}, Number<1>{}), - Number<8>{}, - Number<1>{}); - - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return k_lds_block_desc; - } - - // 3d + padding - template - __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() - { -#if 0 - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kPad = 1; - constexpr index_t kKPack = GetSmemKPackV(); - - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number{}), - make_tuple(Number<(kNPerBlock + kPad) * kKPack>{}, Number{}, Number<1>{}), - Number{}, - Number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return v_lds_block_desc; -#else - using VDataType = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kKPerBlock % kKPack == 0); - - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}), - make_tuple(Number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, - Number{}, - Number{}, - Number<1>{}), - Number{}, - Number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple(Number{}, Number{})), - make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<1, 2>{}, Sequence<0, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return v_lds_block_desc; -#endif - } - - template - __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() - { - return 0; - } - - template - __host__ __device__ static constexpr ck::index_t GetSmemSize() - { - constexpr index_t smem_size_gemm_0 = - GetSmemSizeQ() + sizeof(typename Problem::KDataType) * - MakeKLdsBlockDescriptor().GetElementSpaceSize(); - constexpr index_t smem_size_gemm_1 = - MakeVLdsBlockDescriptor().GetElementSpaceSize() * - sizeof(typename Problem::VDataType); - - // TODO: consider shuffle requirement - return math::max(smem_size_gemm_0, smem_size_gemm_1); - } - - template - __host__ __device__ static constexpr auto MakeQDramTileDistribution() - { - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template At<1>(); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; - - constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1>>, - Tuple, Sequence<1, 2>>, - Sequence<2, 1, 2>, - Sequence<0, 0, 2>>{}); - } - - template - __host__ __device__ static constexpr auto MakeKDramTileDistribution() - { - using KDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t K1 = 16 / sizeof(KDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t N0 = kBlockSize / get_warp_size(); - constexpr index_t N1 = kNPerBlock / (N2 * N0); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<1, 1>>{}); -#endif - } - - template - __device__ static constexpr auto MakeVDramTileDistribution() - { - using VDataType = remove_cvref_t; - using VLayout = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - if constexpr(ck::is_same_v) - { - constexpr index_t N1 = GetTransposedVectorloadV(); - constexpr index_t N0 = kNPerBlock / N1; // P - - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1, 2>>, - Tuple, Sequence<1, 0, 2>>, - Sequence<2, 1>, - Sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = 16 / sizeof(VDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); - } - } - - template - __host__ __device__ static constexpr auto MakeShuffledVRegBlockDescriptor() - { - // This descriptor only used when V layout is seqlen * hdim - using VLayout = remove_cvref_t; - static_assert(ck::is_same_v); - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr index_t N1 = GetTransposedVectorloadV(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1, 2>>, - Tuple, Sequence<1, 0, 2>>, - Sequence<1, 2>, - Sequence<1, 3>>{}); - } - - template - __host__ __device__ static constexpr auto GetQKBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - - constexpr auto warp_gemm = []() { - if constexpr(is_same_v && - is_same_v && - is_same_v) - { - return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; - } - else if constexpr(is_same_v && - is_same_v && - is_same_v) - { - return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; - } - }(); - - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV1CustomPolicy; - - return BlockGemmARegBSmemCRegV1{}; - } - - template - __host__ __device__ static constexpr auto GetKVBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - - using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<0>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}), - true>; - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV1CustomPolicy; - return BlockGemmARegBSmemCRegV1{}; - } -}; +using BlockFmhaPipelineQRKSVSDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; } // namespace block } // namespace tile_program diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp new file mode 100644 index 000000000..3e74e4058 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -0,0 +1,475 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/tile/slice_tile.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck/tile_program/block_tile/block_reduce.hpp" +#include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVSFp8 +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported + FmhaMask mask, + float scale, + float descale_qk, + float descale_sv, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().GetLengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.GetBottomTensorView(), + q_dram_block_window_tmp.GetWindowLengths(), + q_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -NumericLimits::Infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.GetWindowOrigin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.At(Number<0>{}), Number{}, Number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.GetBottomTensorView(), + k_dram_block_window_tmp.GetWindowLengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.GetWindowOrigin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.GetBottomTensorView(), + bias_dram_block_window_tmp.GetWindowLengths(), + {bias_origin.At(Number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // auto q_tile = tile_elementwise_in(q_element_func, q); + auto q_tile = q; + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + + scale = scale * descale_qk; + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // Initialize C + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, i_k0 * kK0>{}, + Sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile(k_lds_window, + k_block_tile); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 2) * kK0>{}, + Sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, k_block_tile); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 1) * kK0>{}, + Sequence{}), + k_lds_window); + } + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(kHasBias) + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert((y)); +#else + x = scale * x + + math::log2e_v * type_convert((y)); +#endif + }, + s_acc, + bias_tile); + } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), + k_origin.At(Number<0>{}), + Number{}, + Number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + Sequence<1>{}, + f_max, + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.GetTileDistribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(kHasBias) + { + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_prefetch); + store_tile(v_lds_window, + v_shuffle_tmp); // store the prefetch + } + else + { + store_tile(v_lds_window, + v_prefetch); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, Sequence<0, i_k1 * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v); + store_tile(v_lds_window, v_shuffle_tmp); + } + else + { + store_tile(v_lds_window, v); + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // finally, O + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + tmp = tmp * descale_sv; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp new file mode 100644 index 000000000..87332d5fc --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -0,0 +1,558 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/tile/slice_tile.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" +#include "ck/tile_program/block_tile/block_reduce.hpp" +#include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQSKSVS +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = false; + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() + { + return Policy::template GetSmemSizeQ(); + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + FmhaMask mask, + float scale, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + // Q tile in LDS + auto q_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().GetLengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -NumericLimits::Infinity()); + clear_tile(l); + + const auto q_origin = q_dram_block_window_tmp.GetWindowOrigin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.At(Number<0>{}), Number{}, Number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.GetBottomTensorView(), + k_dram_block_window_tmp.GetWindowLengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.GetWindowOrigin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.GetBottomTensorView(), + bias_dram_block_window_tmp.GetWindowLengths(), + {bias_origin.At(Number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.GetBottomTensorView(), + q_dram_block_window_tmp.GetWindowLengths(), + q_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeQDramTileDistribution()); + + auto k_dram_window = + make_tile_window(k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); + + auto q_block_tile = load_tile(q_dram_window); + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {0, kK0}); + + clear_tile(s_acc); // Initialize C + + store_tile(q_lds_window, tile_elementwise_in(q_element_func, q_block_tile)); + q_block_tile = load_tile(q_dram_window); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(kHasBias) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto) { + block_sync_lds(); + gemm_0(s_acc, q_lds_window, k_lds_window); + block_sync_lds(); + + move_tile_window(q_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + q_lds_window, + tile_elementwise_in(q_element_func, q_block_tile)); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, q_lds_window, k_lds_window); + block_sync_lds(); + + store_tile(q_lds_window, tile_elementwise_in(q_element_func, q_block_tile)); + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, q_lds_window, k_lds_window); + } + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(kHasBias) + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), + k_origin.At(Number<0>{}), + Number{}, + Number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + Sequence<1>{}, + f_max, + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.GetTileDistribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(kHasBias) + { + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, Sequence<0, i_k1 * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + constexpr auto lse_spans = decltype(lse)::GetDistributedSpans(); + sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(kHasBias) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale, + void* smem_ptr) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + mask, + scale, + smem_ptr); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp new file mode 100644 index 000000000..8fdf2c0b1 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +using BlockFmhaPipelineQSKSVSDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp new file mode 100644 index 000000000..4bb59d79f --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -0,0 +1,916 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp" +#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +// TODO: remove this +#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0 + +namespace ck { +namespace tile_program { +namespace block { + +template +struct BlockFmhaPipelineQXCustomPolicy; + +template <> +struct BlockFmhaPipelineQXCustomPolicy +{ + static constexpr bool QLoadOnce = true; + + template + __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() + { + return 0; + } + + template + __host__ __device__ static constexpr auto MakeQDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template At<1>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1>>, + Tuple, Sequence<1, 2>>, + Sequence<1, 2, 2>, + Sequence<0, 0, 2>>{}); + } + + template + __host__ __device__ static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(Problem::kIsFp8) + { + constexpr index_t swizzle_factor = 4; // TODO: hard coded here + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::QDataType, + typename Problem::KDataType>, + 2, + swizzle_factor>>{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + return BlockGemmARegBSmemCRegV2{}; + } +}; + +template <> +struct BlockFmhaPipelineQXCustomPolicy +{ + static constexpr bool QLoadOnce = false; + + template + __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() + { + constexpr index_t lds_alignment = 16; // optional + constexpr index_t q_smem_size = + ck::math::integer_divide_ceil( + sizeof(typename Problem::QDataType) * + MakeQLdsBlockDescriptor().GetElementSpaceSize(), + lds_alignment) * + lds_alignment; + return q_smem_size; + } + + template + __host__ __device__ static constexpr auto MakeQDramTileDistribution() + { + using QDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + + // 3d + padding + template + __host__ __device__ static constexpr auto MakeQLdsBlockDescriptor() + { + using QDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPack = 16 / sizeof(QDataType); + + constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number<(kMPerBlock + 1) * kKPack>{}, Number{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return q_lds_block_desc; + } + + template + __host__ __device__ static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(Problem::kIsFp8) + { + constexpr index_t swizzle_factor = 4; // TODO: hard coded here + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::QDataType, + typename Problem::KDataType>, + 2, + swizzle_factor>>{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy +{ + static constexpr bool AsyncCopyK = AsyncCopyK_; + static constexpr bool AsyncCopyV = AsyncCopyV_; // TODO: this not supported yet + + static constexpr index_t NumPrefetchK = NumPrefetchK_; + static constexpr index_t NumPrefetchV = NumPrefetchK_; + + using QXPolicy = BlockFmhaPipelineQXCustomPolicy; + + template + struct LdsBufferSequence + { + static constexpr auto Make() + { + return transform_sequences( + [&](auto i) { + if(i < k_loops_) + return i % k_prefetches_; + return (i - k_loops_) % v_prefetches_; + }, + typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{}); + }; + + using type = remove_cvref_t; + }; + // clang-format off + template<> struct + LdsBufferSequence<3, 3, 4, 4> { using type = Sequence<1, 2, 0, 1, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 4, 2> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 2, 4> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 3, 3> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 2, 2> { using type = Sequence<1, 2, 1, 0>;}; + // clang-format on + + template + __host__ __device__ static constexpr auto GetLdsBufferSequence() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t kN0 = BlockFmhaShape::kN0; + constexpr index_t kK0 = BlockFmhaShape::kK0; + constexpr index_t kK1 = BlockFmhaShape::kK1; + constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + return typename LdsBufferSequence::type{}; + } + + template + __host__ __device__ static constexpr auto GetSmemKPackK() + { + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + __host__ __device__ static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + template + __host__ __device__ static constexpr auto GetVectorloadV() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + __host__ __device__ static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + if constexpr(!AsyncCopyK) + { + return MakeKLdsBlockDescriptor().GetElementSpaceSize(); + } + else + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(warpSize * KVector >= kKPerBlock && + warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = warpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (warpSize * KVector + kPad); + } + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return math::max(SingleKSize, SingleVSize); + } + + template + __host__ __device__ static constexpr auto GetVectorloadK() + { + using KDataType = remove_cvref_t; + return 4 / sizeof(KDataType); // TODO: this is for async copy + } + + template + __host__ __device__ static constexpr auto MakeQRegBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto q_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + + // TODO: this is used for non async copy desc. unify in the future + template + __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPack = GetSmemKPackV(); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number<(kNPerBlock + 1) * kKPack>{}, Number{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return k_lds_block_desc; + } + + template + __host__ __device__ static constexpr auto + MakeKLdsStoreBlockDescriptor(Number = Number<0>{}) + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = + KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + warpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(Number{}, // n0 + Number{}, // n1 + Number{}, // n2 + Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number()>{}, + Number{}, + Number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + template + __host__ __device__ static constexpr auto + MakeKLdsLoadBlockDescriptor(Number = Number<0>{}) + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(Number{}, // n0 + Number{}, // n2 + Number{}, // n1 + Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number()>{}, + Number{}, + Number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0, 2, 1>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return k_lds_block_desc; + } +#else + template + __host__ __device__ static constexpr auto MakeKLdsLoadBlockDescriptor() + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + // constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad); + // constexpr index_t SingleVSize = MakeVLdsBlockDescriptor().GetElementSpaceSize(); + constexpr index_t BufferSize = + GetSingleSmemElementSpaceSize(); // math::max(SingleKSize, SingleVSize); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(Number{}, // num_buffers + Number{}, // n0 + Number{}, // n2 + Number{}, // n1 + Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number{}, + Number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0, 1, 3, 2>{}, Sequence<4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return k_lds_block_desc; + } +#endif + + // 3d + padding + template + __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}), + make_tuple(Number()>{}, + Number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + Number{}, + Number{}, + Number<1>{}), + Number{}, + Number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return v_lds_block_desc; + } + + template + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + // TODO: assume Q is in register + // TODO: assume K/V has same data type + constexpr index_t single_smem_size = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return QXPolicy::template GetSmemSizeQ() + + single_smem_size * math::max(NumPrefetchK, NumPrefetchV); + } + + template + __host__ __device__ static constexpr auto MakeKDramTileDistribution() + { + if constexpr(!AsyncCopyK) + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = 16 / sizeof(KDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + else + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KVector = GetVectorloadK(); // this is for global load + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<1, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + } + + template + __device__ static constexpr auto MakeVDramTileDistribution() + { + using VDataType = remove_cvref_t; + using VLayout = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(ck::is_same_v) + { + constexpr index_t N1 = GetVectorloadV(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding< + Sequence<1>, + Tuple, Sequence>, + Tuple, Sequence<2, 1, 2>>, + Tuple, Sequence<1, 0, 2>>, + Sequence<2, 1>, + Sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding< + Sequence<1>, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<0, 2>>, + Sequence<2, 1>, + Sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + static_assert(N0 != 0); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + } + + template + __host__ __device__ static constexpr auto MakeBiasDramTileDistribution() + { + constexpr index_t MPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t NPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + return c_block_dstr; + } + + template + __host__ __device__ static constexpr auto MakeShuffledVRegBlockDescriptor() + { + // This descriptor only used when V layout is seqlen * hdim + using VLayout = remove_cvref_t; + static_assert(ck::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t N1 = GetVectorloadV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1, 2>>, + Tuple, Sequence<1, 0, 2>>, + Sequence<1, 2>, + Sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<0, 2>>, + Sequence<1, 2>, + Sequence<1, 3>>{}); + } + } + + template + __host__ __device__ static constexpr auto GetKVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + auto warp_gemm = [&]() { + if constexpr(Problem::kIsFp8) + { + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::PDataType, + typename Problem::VDataType>, + 2>>{}; + // return + // warp::WarpGemmImpl>>{}; + } + else + { + return ck::tile_program::warp::WarpGemmMfmaDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}), + true>{}; + } + }(); + + using WarpGemm = remove_cvref_t; + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + return BlockGemmARegBSmemCRegV2{}; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/load_tile.hpp b/include/ck/tile_program/tile/load_tile.hpp index b2ea61afa..1aae8ad3a 100644 --- a/include/ck/tile_program/tile/load_tile.hpp +++ b/include/ck/tile_program/tile/load_tile.hpp @@ -11,6 +11,7 @@ #include "ck/tile_program/tile/tile_distribution.hpp" #include "ck/tile_program/tile/tile_window.hpp" +#include "ck/tile_program/tile/null_tensor.hpp" #include "ck/tile_program/tile/static_distributed_tensor.hpp" namespace ck { @@ -28,5 +29,43 @@ __device__ auto load_tile(const TileWindowWithStaticDistribution +__device__ auto load_tile_raw(const TileWindowWithStaticDistribution& tile_window) +{ + return tile_window.Load(bool_constant{}); +} + +template +__device__ auto async_load_tile_raw(LdsTileWindow_&& lds_tile, + const TileWindowWithStaticDistribution& tile_window) +{ + return tile_window.AsyncLoad(lds_tile); +} + +__device__ auto async_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +template +__device__ auto load_tile(const NullTileWindow&) +{ + return NullTensor{}; +} + } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/tile/null_tensor.hpp b/include/ck/tile_program/tile/null_tensor.hpp new file mode 100644 index 000000000..50f1efa17 --- /dev/null +++ b/include/ck/tile_program/tile/null_tensor.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tile_program { + +struct NullTensor +{ +}; + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/null_tile_window.hpp b/include/ck/tile_program/tile/null_tile_window.hpp new file mode 100644 index 000000000..2d873bcfc --- /dev/null +++ b/include/ck/tile_program/tile/null_tile_window.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor/tensor_view.hpp" +#include "ck/utility/common_header.hpp" +#include + +namespace ck { +namespace tile_program { + +// placeholder type if we want to opt-out a tile window parameter +template +struct NullTileWindow +{ + using BottomTensorView = NullTensorView; + using WindowLengths = remove_cvref_t; + + using BottomTensorIndex = Array; + + __device__ constexpr NullTileWindow() = default; + + __device__ constexpr NullTileWindow(const WindowLengths& window_lengths) + : window_lengths_{window_lengths} + { + } + + __device__ constexpr auto GetWindowLengths() const { return window_lengths_; } + + __device__ constexpr auto GetBottomTensorView() const { return NullTensorView{}; } + + __device__ constexpr auto GetWindowOrigin() const { return BottomTensorIndex{}; } + + WindowLengths window_lengths_; +}; + +// utility to check if this is a Null Tile Window +namespace impl { +template +struct IsNullTileWindow : public std::false_type +{ +}; + +template +struct IsNullTileWindow> : public std::true_type +{ +}; +} // namespace impl + +template +__device__ constexpr auto is_null_tile_window(const T&) +{ + return impl::IsNullTileWindow>::value; +} + +template +__device__ constexpr auto make_null_tile_window(const WindowLengths& window_lengths) +{ + static_assert(is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return NullTileWindow>{window_lengths}; +} + +template +__device__ constexpr auto make_tile_window(NullTensorView, + const WindowLengths& window_lengths, + const MultiIndex& /*origin*/, + Ts&&...) +{ + static_assert(is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return NullTileWindow>{window_lengths}; +} + +template +__device__ void move_tile_window(NullTileWindow&, + const typename NullTileWindow::BottomTensorIndex&) +{ +} + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/slice_tile.hpp b/include/ck/tile_program/tile/slice_tile.hpp index e7999f26a..7de77db49 100644 --- a/include/ck/tile_program/tile/slice_tile.hpp +++ b/include/ck/tile_program/tile/slice_tile.hpp @@ -3,64 +3,5 @@ #pragma once -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" -#include "ck/tensor_description/tensor_space_filling_curve.hpp" - -#include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" -#include "ck/tile_program/tile/tile_window.hpp" -#include "ck/tile_program/tile/static_distributed_tensor.hpp" - -namespace ck { -namespace tile_program { - -template -__host__ __device__ constexpr auto get_slice_tile(const StaticDistributedTensor_& tile, - Sequence slice_begins, - Sequence slice_ends) -{ - using Distribution = decltype(StaticDistributedTensor_::GetTileDistribution()); - using DataType = typename StaticDistributedTensor_::DataType; - - constexpr auto sliced_dstr_yidx_ylen = - detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends); - - constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); - constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); - constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); - - auto sliced_tensor = make_static_distributed_tensor(sliced_dstr); - - sliced_tensor.GetThreadBuffer() = tile.GetYSlicedThreadData(sliced_y_origins, sliced_y_lengths); - - return sliced_tensor; -} - -template -__host__ __device__ constexpr auto set_slice_tile(DstStaticDistributedTensor_& dst_tile, - const SrcStaticDistributedTensor_& src_tile, - Sequence slice_begins, - Sequence slice_ends) -{ - using DstDistribution = decltype(DstStaticDistributedTensor_::GetTileDistribution()); - - constexpr auto sliced_dstr_yidx_ylen = - detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); - - constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); - constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); - constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); - - static_assert(is_same_v, "wrong!"); - - dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer()); -} - -} // namespace tile_program -} // namespace ck +#include "ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp" +#include "ck/tile_program/tile/slice_tile_impl_static_lengths.hpp" diff --git a/include/ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp b/include/ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp new file mode 100644 index 000000000..2b31b449b --- /dev/null +++ b/include/ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" +#include "ck/tile_program/tile/tile_window.hpp" +#include "ck/tile_program/tile/static_distributed_tensor.hpp" + +namespace ck { +namespace tile_program { + +template +__host__ __device__ constexpr auto +get_slice_tile(const StaticDistributedTensor& tile, + Sequence slice_begins, + Sequence slice_ends) +{ + using DataType = remove_cvref_t; + using Distribution = remove_cvref_t; + + constexpr auto sliced_dstr_yidx_ylen = + detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends); + + constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); + constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); + constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); + + auto sliced_tensor = make_static_distributed_tensor(sliced_dstr); + + sliced_tensor.GetThreadBuffer() = tile.GetYSlicedThreadData(sliced_y_origins, sliced_y_lengths); + + return sliced_tensor; +} + +template +__host__ __device__ constexpr auto +set_slice_tile(StaticDistributedTensor& dst_tile, + const StaticDistributedTensor& src_tile, + Sequence slice_begins, + Sequence slice_ends) +{ + using DstDistribution = remove_cvref_t; + + constexpr auto sliced_dstr_yidx_ylen = + detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); + + constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); + constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); + constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); + + static_assert(is_same_v, "wrong!"); + + dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer()); +} + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/slice_tile_impl_static_lengths.hpp b/include/ck/tile_program/tile/slice_tile_impl_static_lengths.hpp new file mode 100644 index 000000000..de7aa8f03 --- /dev/null +++ b/include/ck/tile_program/tile/slice_tile_impl_static_lengths.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" +#include "ck/tile_program/tile/tile_window.hpp" +#include "ck/tile_program/tile/tile_window_impl_static_lengths.hpp" + +namespace ck { +namespace tile_program { + +template +__host__ __device__ constexpr auto +get_slice_tile(const TileWindowWithStaticLengths& tile, + Sequence slice_begins, + Sequence slice_ends) +{ + using TileWindow = TileWindowWithStaticLengths; + // NOTE: This API will override the origin of the tile window! + static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds)); + static_assert(sizeof...(SliceBegins) == TileWindow::GetNumOfDimension()); + + constexpr auto slice_lengths = slice_ends - slice_begins; + + return make_tile_window(tile.GetBottomTensorView(), + sequence_to_tuple_of_number(slice_lengths), + to_multi_index(slice_begins)); +} + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/static_distributed_tensor.hpp b/include/ck/tile_program/tile/static_distributed_tensor.hpp index 19f420aa4..d0a0355be 100644 --- a/include/ck/tile_program/tile/static_distributed_tensor.hpp +++ b/include/ck/tile_program/tile/static_distributed_tensor.hpp @@ -4,10 +4,12 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor_coordinate.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" #include "ck/tile_program/tile/tile_distribution.hpp" namespace ck { @@ -176,5 +178,44 @@ __host__ __device__ constexpr auto make_static_distributed_tensor(const StaticTi remove_cvref_t>{}; } +// get X indices from tuple of TileDistributedIndex<> +template +__host__ __device__ constexpr auto +get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices) +{ + const auto partition_index = detail::get_partition_index(tile_distribution); + constexpr auto y_indices = + tile_distribution.GetYIndicesFromDistributedIndices(distributed_indices); + + const auto x_coord = make_tensor_adaptor_coordinate( + tile_distribution.GetPsYs2XsAdaptor(), + container_concat(partition_index, to_array(y_indices))); + + return x_coord.GetBottomIndex(); +} + +template +__host__ __device__ void +set_tile_if(StaticDistributedTensor& out_tensor, + DataType value, + XIndicesPredicate predicate) +{ + constexpr auto out_spans = + StaticDistributedTensor::GetDistributedSpans(); + sweep_tile_span(out_spans[Number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[Number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{}, + distributed_indices); + + if(predicate(x_indices)) + { + out_tensor(distributed_indices) = value; + } + }); + }); +} + } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/tile/tile_elementwise.hpp b/include/ck/tile_program/tile/tile_elementwise.hpp index dedfc5961..4b3ea237f 100644 --- a/include/ck/tile_program/tile/tile_elementwise.hpp +++ b/include/ck/tile_program/tile/tile_elementwise.hpp @@ -9,13 +9,17 @@ #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/null_tensor.hpp" #include "ck/tile_program/tile/static_distributed_tensor.hpp" namespace ck { namespace tile_program { // TODO: support tensors with different distribution -template +template , NullTensor>>...>>> __device__ void tile_elementwise_inout(const InOutElementFunc& inout_element_func, InOutDstrTensors&... inout_dstr_tensors) { @@ -29,7 +33,10 @@ __device__ void tile_elementwise_inout(const InOutElementFunc& inout_element_fun [&](auto i) { inout_element_func(inout_dstr_tensors.GetThreadBuffer().At(i)...); }); } -template +template >...>>> __device__ auto tile_elementwise_in(const InElementFunc& in_element_func, const InDstrTensors&... in_dstr_tensors) { @@ -52,5 +59,102 @@ __device__ auto tile_elementwise_in(const InElementFunc& in_element_func, return out_dstr_tensor; } +template +__device__ void set_tile(DstrTensors& dstr_tensor, const T& value) +{ + tile_elementwise_inout( + [&value](auto& x) { + x = type_convert>(value); + }, + dstr_tensor); +} + +template +__device__ void clear_tile(DstrTensors& dstr_tensor) +{ + set_tile(dstr_tensor, 0); +} + +template +__device__ auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // This API is designed to use the _pk_ serious of function + constexpr auto in_tile_dstr = InDstrTensors::GetTileDistribution(); + + constexpr index_t thread_buffer_size = InDstrTensors::GetThreadBufferSize(); + static_assert(thread_buffer_size % 4 == 0); + constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4; + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wuninitialized" + // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and + // will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA + // so we prepare an uninitialized variable purposely, and turn off the warning + int dummy_old; + static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) { + uint32_t x = + __builtin_amdgcn_cvt_pk_fp8_f32(in_dstr_tensors.GetThreadBuffer()[Number<4 * i + 0>{}], + in_dstr_tensors.GetThreadBuffer()[Number<4 * i + 1>{}], + dummy_old, + false); // false -> WORD0 + + uint32_t y = + __builtin_amdgcn_cvt_pk_fp8_f32(in_dstr_tensors.GetThreadBuffer()[Number<4 * i + 2>{}], + in_dstr_tensors.GetThreadBuffer()[Number<4 * i + 3>{}], + dummy_old, + false); // false -> WORD0 + + constexpr int32_t m0 = 0x05040100; + using vec_t = typename vector_type::type; + + vec_t d = bit_cast(__builtin_amdgcn_perm(y, x, m0)); + out_dstr_tensor.GetThreadBuffer().template SetAsType(Number<4 * i>{}, d); + }); +#pragma clang diagnostic pop + + return out_dstr_tensor; +#else + // fallback + return tile_elementwise_in(type_convert, + in_dstr_tensors); +#endif +} + +template +__device__ auto cast_tile(const SrcDstrTensors& src_tensor) +{ + if constexpr((ck::is_same_v || + ck::is_same_v)&&ck::is_same_v && + (SrcDstrTensors::GetThreadBufferSize() % 4 == 0)) + { + return cast_tile_pk_fp8x4(src_tensor); + } + else + return tile_elementwise_in(type_convert, + src_tensor); +} + +// no-op function for NullTensor arguments +template , NullTensor>...>>> +__device__ void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...) +{ +} + +// no-op function for NullTensor arguments +template , NullTensor>...>>> +__device__ auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...) +{ + return NullTensor{}; +} + } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/tile/tile_fmha_shape.hpp b/include/ck/tile_program/tile/tile_fmha_shape.hpp index c577a3943..88d3c0a2b 100644 --- a/include/ck/tile_program/tile/tile_fmha_shape.hpp +++ b/include/ck/tile_program/tile/tile_fmha_shape.hpp @@ -5,6 +5,8 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/math.hpp" namespace ck { namespace tile_program { @@ -14,7 +16,7 @@ template + bool IsVLayoutRowMajor_> struct TileFmhaShape { using BlockTile = remove_cvref_t; @@ -23,6 +25,12 @@ struct TileFmhaShape using Gemm1BlockWarps = remove_cvref_t; using Gemm1WarpTile = remove_cvref_t; + static constexpr index_t NumWarps = + reduce_on_sequence(Gemm0BlockWarps{}, math::multiplies{}, Number<1>{}); + + static_assert(NumWarps == + reduce_on_sequence(Gemm1BlockWarps{}, math::multiplies{}, Number<1>{})); + static constexpr index_t kM0 = BlockTile::At(Number<0>{}); // tile size along q seqlen static constexpr index_t kN0 = BlockTile::At(Number<1>{}); // tile size along k seqlen static constexpr index_t kK0 = BlockTile::At(Number<2>{}); // tile size along qk gemm unroll @@ -31,8 +39,13 @@ struct TileFmhaShape static constexpr index_t kK0BlockLength = BlockTile::At(Number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) + static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0"); - using VLayout = remove_cvref_t; // rowmajor : seqlen*hdim, colmajor : hdim*seqlen + // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen + static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; + using VLayout = std::conditional_t; }; } // namespace tile_program diff --git a/include/ck/tile_program/tile/tile_fmha_traits.hpp b/include/ck/tile_program/tile/tile_fmha_traits.hpp new file mode 100644 index 000000000..2bced1feb --- /dev/null +++ b/include/ck/tile_program/tile/tile_fmha_traits.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +namespace ck { +namespace tile_program { + +template +struct TileFmhaTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/tile_window.hpp b/include/ck/tile_program/tile/tile_window.hpp index a07b52800..d3df06d24 100644 --- a/include/ck/tile_program/tile/tile_window.hpp +++ b/include/ck/tile_program/tile/tile_window.hpp @@ -7,6 +7,7 @@ #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor_coordinate.hpp" +#include "ck/tile_program/tile/null_tile_window.hpp" #include "ck/tile_program/tile/tile_distribution.hpp" #include "ck/tile_program/tile/tile_window_impl_static_distribution.hpp" #include "ck/tile_program/tile/tile_window_impl_static_lengths.hpp" diff --git a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp index f67fa6a28..5e67dd4ff 100644 --- a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp +++ b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp @@ -261,7 +261,10 @@ struct TileWindowWithStaticDistribution get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); } - __device__ auto Load() const + __device__ constexpr auto GetNumAccess() const { return LoadStoreTraits::NumAccess; } + + template + __device__ auto Load(bool_constant = {}) const { using Traits = LoadStoreTraits; @@ -288,7 +291,7 @@ struct TileWindowWithStaticDistribution // read from bottom tensor const vector_t vec_value = GetBottomTensorView().template GetVectorizedElements( - bottom_tensor_thread_coord); + bottom_tensor_thread_coord, bool_constant{}); const vector_type_t vec{vec_value}; @@ -324,6 +327,76 @@ struct TileWindowWithStaticDistribution return dst_tensor; } + // TODO: currently async load only implemented in inline asm + template + __device__ auto AsyncLoad(LdsTileWindow_&& lds_tile, bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + // using LdsTensorView = typename LdsTileWindow::BottomTensorView; + using LdsDataType = typename LdsTileWindow::DataType; + // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc; + + // issues * warps * lanes + static_assert(LdsTileWindow::GetNumOfDimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.GetBottomTensorView().GetTensorDescriptor().CalculateOffset( + make_tuple(Number<0>{}, Number<0>{}, Number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.GetBottomTensorView().GetTensorDescriptor().CalculateOffset( + make_tuple(Number<0>{}, Number<1>{}, Number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.GetBottomTensorView().GetTensorDescriptor().CalculateOffset( + make_tuple(Number<1>{}, Number<0>{}, Number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(m0_init_value); // This should be wave independent + + using Traits = LoadStoreTraits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + + LdsDataType* smem = lds_tile.GetBottomTensorView().GetBufferView().p_data_; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + // TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = Number{}; + + // read from bottom tensor + GetBottomTensorView().template AsyncGetVectorizedElements( + smem, bottom_tensor_thread_coord); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(Array{0}, idx_diff_ys); + + MoveWindowAdaptorAndBottomTensorThreadCoordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + m0_inc_with_memory(size_per_issue); + } + }); + }); + } + __device__ void Store(const StaticDistributedTensor& dstr_tensor) const { using Traits = LoadStoreTraits; diff --git a/include/ck/tile_program/warp_tile/warp_gemm.hpp b/include/ck/tile_program/warp_tile/warp_gemm.hpp index f08e24631..112b9d622 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm.hpp @@ -81,6 +81,31 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; +// fp8 +using WarpGemmMfma_f32_32x32x16_fp8_fp8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_fp8_bf8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_bf8_fp8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_bf8_bf8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + } // namespace warp } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp index 6e98e9115..85cabac37 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp @@ -207,6 +207,67 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution } }; +template +struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::BVecType; + using BVecType = typename Impl::AVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK; + + using AWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<2>, + Sequence<1>>; + + using BWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2>, + Sequence<1>>; + + using CWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2, 2>, + Sequence<0, 2>>; + + // c_vec += a_vec * b_vec + __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + Impl{}(c_vec, b_vec, a_vec); + } + + // c_vec = a_vec * b_vec + __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + return Impl{}(b_vec, a_vec); + } +}; + template struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution { @@ -287,7 +348,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution } }; -template +template struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB { using Impl = remove_cvref_t; @@ -301,9 +362,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB using BVecType = typename vector_type_maker::type::type; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kN; - static constexpr index_t kN = Impl::kM; - static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together using AWarpDstrEncoding = StaticTileDistributionEncoding< Sequence<>, @@ -312,7 +374,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB Tuple>, Sequence<2>, Sequence<1>>; - +#if 0 using BWarpDstrEncoding = StaticTileDistributionEncoding< Sequence<>, Tuple>, Sequence<2, 2>, Sequence<0, 2>>; +#else + using BWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2>, + Sequence<1>>; + using CWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2, 2>, + Sequence<0, 2>>; +#endif // c_vec += a_vec * b_vec __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { diff --git a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp index 72431c802..0a2badda6 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp @@ -159,6 +159,88 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 } }; +// FP8 +template +struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base +{ + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); +#else + vector_type a_(a_vec); + vector_type b_(b_vec); + + static_for<0, 8, 1>{}([&](auto k) { + float a_f32 = type_convert(a_.template AsType()[Number{}]); + float b_f32 = type_convert(b_.template AsType()[Number{}]); + + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); + }); +#endif + } + + // c_vec = a_vec * b_vec + __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + } +}; + +using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + } // namespace warp } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp b/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp index 68f2255b5..2c9e3089d 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp @@ -40,6 +40,17 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; + +// fp8 +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; + // clang-format on } // namespace impl diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 952174c7d..5220d8dd3 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -49,6 +49,117 @@ __device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave) return wave_buffer_resource.content; } +// TODO: glc/slc/... +template +struct buffer_load; + +template <> +struct buffer_load<16> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 16); + // using dummy_vector = vector_type; + // using dummy_vector = StaticallyIndexedArray; + asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" + // : "+v"(reinterpret_cast(value)) + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<8> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 8); + // using dummy_vector = vector_type; + using dummy_vector = float __attribute__((ext_vector_type(2))); + asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<4> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 4); + // using dummy_vector = vector_type; + using dummy_vector = float __attribute__((ext_vector_type(1))); + asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<2> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 2); + asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<1> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 1); + asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +__device__ void buffer_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + // buffer load i8 __device__ int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, @@ -286,6 +397,24 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); +__device__ void async_buffer_load_dword(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0) +{ + asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) + : "memory"); +} + +__device__ void async_buffer_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + // memory coherency bit for buffer store/load instruction // check ISA manual for each GFX target // e.g. for @@ -402,9 +531,14 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, } } +#ifndef BUFFER_LOAD_USE_INLINEASM +#define BUFFER_LOAD_USE_INLINEASM 0 +#endif + template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool use_inline_asm = false> __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) @@ -420,7 +554,15 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) // fp32 + if constexpr(use_inline_asm) + { + using type = typename vector_type::type; + type tmp; + buffer_load{}( + tmp, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return tmp; + } + else if constexpr(is_same::value) // fp32 { if constexpr(N == 1) { @@ -461,6 +603,36 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w return tmp.AsType()(Number<0>{}); } + else if constexpr(N == 16) + { + vector_type tmp; + + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + tmp.AsType()(Number<2>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(float), + static_cast(coherence)); + + tmp.AsType()(Number<3>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(float), + static_cast(coherence)); + + return tmp.AsType()(Number<0>{}); + } } else if constexpr(is_same::value) // fp16 { @@ -540,6 +712,52 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w } } +template +__device__ void amd_buffer_load_raw_impl(typename vector_type::type& dst, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); +#if BUFFER_LOAD_USE_INLINEASM + using type = typename vector_type::type; + buffer_load{}( + dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + (void)dst; + (void)src_wave_buffer_resource; + (void)src_thread_addr_offset; + (void)src_wave_addr_offset; +#endif +} + +template +__device__ void amd_async_buffer_load_impl(T* smem, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t src_immediate_addr_offset = 0) +{ + static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + + async_buffer_load_dword(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset); +} + template __device__ void amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, @@ -1031,7 +1249,8 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type::typ // It is user's responsibility to make sure that is true. template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool use_inline_asm = false> __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, @@ -1050,12 +1269,12 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - return amd_buffer_load_impl( + return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(0); #endif @@ -1067,7 +1286,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, // It is user's responsibility to make sure that is true. template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool use_inline_asm> __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, @@ -1085,12 +1305,55 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(customized_value); } +template +__device__ void amd_buffer_load_raw(typename vector_type_maker::type::type& dst, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + amd_buffer_load_raw_impl( + dst, src_wave_buffer_resource, src_thread_addr_offset, 0); +} + +// unfortunately async copy can not make sure invalid data is zero inside LDS +// ... unless people manually write zero to LDS at the proper address. +// so not support invalid_element check for now. +// buffer_load OOB still working. +template +__device__ void amd_async_buffer_load_with_oob(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); +} + // buffer_store requires: // 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 43baa817d..2a43e2b57 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -367,5 +367,17 @@ __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, f #endif } +// TODO: we have "memory" clobber here because this inline asm is used for async copy +__device__ void m0_set_with_memory(index_t v) +{ + asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory"); +} + +// NOTE: this is an immediate value +__device__ void m0_inc_with_memory(index_t v) +{ + asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory"); +} + } // namespace ck #endif diff --git a/include/ck/utility/buffer_view_declare.hpp b/include/ck/utility/buffer_view_declare.hpp index 747f1ab63..42b31954d 100644 --- a/include/ck/utility/buffer_view_declare.hpp +++ b/include/ck/utility/buffer_view_declare.hpp @@ -5,9 +5,11 @@ #pragma once #include "ck/ck.hpp" +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/amd_buffer_addressing.hpp" +#include "ck/utility/c_style_pointer_cast.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/enable_if.hpp" -#include "ck/utility/c_style_pointer_cast.hpp" namespace ck { diff --git a/include/ck/utility/buffer_view_impl_generic.hpp b/include/ck/utility/buffer_view_impl_generic.hpp index 78c7b8e9a..1b88bf5c4 100644 --- a/include/ck/utility/buffer_view_impl_generic.hpp +++ b/include/ck/utility/buffer_view_impl_generic.hpp @@ -60,10 +60,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __device__ constexpr auto + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; diff --git a/include/ck/utility/buffer_view_impl_global.hpp b/include/ck/utility/buffer_view_impl_global.hpp index f8d716ca5..621509407 100644 --- a/include/ck/utility/buffer_view_impl_global.hpp +++ b/include/ck/utility/buffer_view_impl_global.hpp @@ -63,10 +63,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __device__ constexpr auto + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -90,14 +92,16 @@ struct BufferView, t_per_x, - Coherence>( + Coherence, + use_inline_asm>( p_data_, i, is_valid_element, buffer_size_); } else { return amd_buffer_load_invalid_element_return_customized_value, t_per_x, - Coherence>( + Coherence, + use_inline_asm>( p_data_, i, is_valid_element, buffer_size_, invalid_element_value_); } } @@ -129,6 +133,46 @@ struct BufferView>::type, + typename scalar_type>::type>::value, + bool>::type = false> + __device__ constexpr auto GetRaw(remove_cvref_t& dst, index_t i) const + { + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_load_raw, t_per_x, Coherence>(dst, p_data_, i, buffer_size_); + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __device__ constexpr auto + AsyncGet(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + { + // X is vector of T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_async_buffer_load_with_oob, t_per_x, Coherence>( + smem, p_data_, i, buffer_size_); + } + // i is offset of T, not X. i should be aligned to X template >::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __device__ constexpr auto + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; diff --git a/include/ck/utility/buffer_view_impl_vgpr.hpp b/include/ck/utility/buffer_view_impl_vgpr.hpp index 4c3e94884..15bdf1354 100644 --- a/include/ck/utility/buffer_view_impl_vgpr.hpp +++ b/include/ck/utility/buffer_view_impl_vgpr.hpp @@ -60,10 +60,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __device__ constexpr auto + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 58ebac9d8..31e3f5140 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" #include "ck/utility/static_assert.hpp" +#include "ck/utility/static_switch.hpp" #include "ck/utility/remove_cvref.hpp" #include "ck/utility/is_static.hpp" #include "ck/utility/bit_cast.hpp" diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index 91797d240..e1bab6f59 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -3,6 +3,8 @@ #pragma once +#include + #include "ck/utility/integral_constant.hpp" #include "ck/utility/type.hpp" @@ -128,4 +130,13 @@ constexpr auto conditional_expr(X&& x, Y&& y) } } +struct identity +{ + template + __host__ __device__ constexpr T&& operator()(T&& arg) const noexcept + { + return std::forward(arg); + } +}; + } // namespace ck diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index 65efaf388..c1943520c 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -152,6 +152,48 @@ __device__ void inner_product(const half8_t& a, const h c); } +template <> +__device__ void +inner_product(const bhalf2_t& a, const bhalf2_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const bhalf4_t& a, const bhalf4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + template <> __device__ void inner_product(const int8_t& a, const int8_t& b, int32_t& c) { diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 04c07bc4c..9e8d3771b 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -50,4 +50,7 @@ __host__ __device__ constexpr auto operator%(integral_constant, integral_ return integral_constant{}; } +template +using bool_constant = integral_constant; + } // namespace ck diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index b8e7380f0..e2f12216f 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -45,7 +45,7 @@ struct MagicDivision32BitRange } // integral_constant - template + template > __host__ __device__ static constexpr auto CalculateMagicNumbers(integral_constant) { diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index e654f7dfd..c4039bbcb 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -251,5 +251,27 @@ __host__ __device__ constexpr bool is_power_of_two_integer(int32_t x) return x == (1 << integer_log2_floor(x)); } +#ifndef C_LOG2E +#define C_LOG2E 1.44269504088896340736 // log2(e) +#endif + +template +struct log2e; + +template <> +struct log2e +{ + static constexpr double value = C_LOG2E; +}; + +template <> +struct log2e +{ + static constexpr float value = C_LOG2E; +}; + +template +inline constexpr T log2e_v = log2e::value; + } // namespace math } // namespace ck diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 082fa7baa..594123097 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -312,11 +312,13 @@ inline __device__ float exp(float x) return __expf(x); } +/* template <> inline __device__ half_t exp(half_t x) { return hexp(x); }; +*/ template <> inline __device__ double exp(double x) @@ -346,11 +348,13 @@ inline __device__ T log(T x) return ck::type_convert(__logf(ck::type_convert(x))); }; +/* template <> inline __device__ half_t log(half_t x) { return hlog(x); }; +*/ template <> inline __device__ float log(float x) diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 0ccebd476..f79e774a8 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -14,6 +14,8 @@ template // TODO remove this bool, no longer needed struct StaticBuffer : public StaticallyIndexedArray, N> { + static_assert(0 < N, "StaticBuffer should not be empty"); + using S = remove_cvref_t; using type = S; using base = StaticallyIndexedArray; diff --git a/include/ck/utility/static_switch.hpp b/include/ck/utility/static_switch.hpp new file mode 100644 index 000000000..9ddfed6a0 --- /dev/null +++ b/include/ck/utility/static_switch.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#define BOOL_SWITCH(COND1, CONST_NAME1, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + __VA_ARGS__(); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + __VA_ARGS__(); \ + } \ + }() + +#define BOOL_SWITCH_2(COND1, CONST_NAME1, COND2, CONST_NAME2, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__); \ + } \ + }() + +#define BOOL_SWITCH_3(COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__); \ + } \ + }() + +#define BOOL_SWITCH_4( \ + COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH_3( \ + COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH_3( \ + COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ##__VA_ARGS__); \ + } \ + }() diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index eb2995872..cf8b0229d 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -171,14 +171,14 @@ struct Tuple : detail::TupleImpl + template > __host__ __device__ constexpr const auto& operator[](Number i) const { return At(i); } // write access - template + template > __host__ __device__ constexpr auto& operator()(Number i) { return At(i); diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index a3df884ee..33bc06ea0 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -31,9 +31,10 @@ typename std::enable_if< bool>::type check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-5, - double atol = 3e-6) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-6, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -42,6 +43,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -51,7 +59,7 @@ check_err(const Range& out, const double o = *std::next(std::begin(out), i); const double r = *std::next(std::begin(ref), i); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; @@ -81,9 +89,10 @@ typename std::enable_if< bool>::type check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -92,6 +101,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -102,7 +118,7 @@ check_err(const Range& out, const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; @@ -132,9 +148,10 @@ typename std::enable_if< bool>::type check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -143,6 +160,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -152,7 +176,7 @@ check_err(const Range& out, const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; @@ -236,9 +260,10 @@ std::enable_if_t<(std::is_same_v, ranges::range_val bool> check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -247,6 +272,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -256,7 +288,7 @@ check_err(const Range& out, const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; @@ -281,9 +313,10 @@ std::enable_if_t<(std::is_same_v, ranges::range_val bool> check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -292,6 +325,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -301,7 +341,7 @@ check_err(const Range& out, const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index 4e075df43..e6666fa51 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -20,11 +21,12 @@ struct FillUniformDistribution { float a_{-5.f}; float b_{5.f}; + std::optional seed_{11939}; template void operator()(ForwardIter first, ForwardIter last) const { - std::mt19937 gen(11939); + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); } @@ -40,6 +42,32 @@ struct FillUniformDistribution } }; +template +struct FillNormalDistribution +{ + float mean_{0.f}; + float variance_{1.f}; + std::optional seed_{11939}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::normal_distribution dis(mean_, std::sqrt(variance_)); + std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + // Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below. // However this produces segfaults in std::mt19937 which look like inifite loop. // template @@ -64,11 +92,12 @@ struct FillUniformDistributionIntegerValue { float a_{-5.f}; float b_{5.f}; + std::optional seed_{11939}; template void operator()(ForwardIter first, ForwardIter last) const { - std::mt19937 gen(11939); + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate( first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); @@ -85,6 +114,33 @@ struct FillUniformDistributionIntegerValue } }; +template +struct FillNormalDistributionIntegerValue +{ + float mean_{0.f}; + float variance_{1.f}; + std::optional seed_{11939}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::normal_distribution dis(mean_, std::sqrt(variance_)); + std::generate( + first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + template struct FillMonotonicSeq { @@ -133,5 +189,46 @@ struct FillConstant } }; +template +struct FillTrigValue +{ + template + struct LinearTrigGen + { + int i{0}; + auto operator()() + { + float v = 0; + if constexpr(UseCos_) + { + v = cos(i); + } + else + { + v = sin(i); + } + if constexpr(UseAbs_) + v = abs(v); + i++; + return static_cast(v); + } + }; + template + void operator()(ForwardIter first, ForwardIter last) const + { + LinearTrigGen gen; + std::generate(first, last, gen); + } + + template + auto operator()(ForwardRange&& range) const -> std::void_t< + decltype(std::declval()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + } // namespace utils } // namespace ck diff --git a/library/src/utility/device_memory.cpp b/library/src/utility/device_memory.cpp index 61b6326b5..c95619ecd 100644 --- a/library/src/utility/device_memory.cpp +++ b/library/src/utility/device_memory.cpp @@ -1,23 +1,25 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include + #include "ck/host_utility/hip_check_error.hpp" #include "ck/library/utility/device_memory.hpp" DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) { - hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); } void DeviceMem::Realloc(std::size_t mem_size) { if(mpDeviceBuf) { - hip_check_error(hipFree(mpDeviceBuf)); + HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); } mMemSize = mem_size; - hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); } void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; } @@ -28,7 +30,7 @@ void DeviceMem::ToDevice(const void* p) const { if(mpDeviceBuf) { - hip_check_error( + HIP_CHECK_ERROR( hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); } else @@ -39,14 +41,14 @@ void DeviceMem::ToDevice(const void* p) const void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const { - hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); } void DeviceMem::FromDevice(void* p) const { if(mpDeviceBuf) { - hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } else { @@ -56,14 +58,14 @@ void DeviceMem::FromDevice(void* p) const void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const { - hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); } void DeviceMem::SetZero() const { if(mpDeviceBuf) { - hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); + HIP_CHECK_ERROR(hipMemset(mpDeviceBuf, 0, mMemSize)); } } @@ -71,6 +73,13 @@ DeviceMem::~DeviceMem() { if(mpDeviceBuf) { - hip_check_error(hipFree(mpDeviceBuf)); + try + { + HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); + } + catch(std::runtime_error& re) + { + std::cerr << re.what() << std::endl; + } } } diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 51d6f7a30..b85ed239c 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -5,13 +5,19 @@ rm -rf CMakeFiles MY_PROJECT_SOURCE=$1 +if [ $# -ge 2 ] ; then + GPU_TARGETS=$2 +else + GPU_TARGETS="gfx908;gfx90a;gfx940" +fi + cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ --D GPU_TARGETS="gfx908;gfx90a;gfx940" \ +-D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE} diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index 787eabbf9..25ccb5c79 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -5,13 +5,19 @@ rm -rf CMakeFiles MY_PROJECT_SOURCE=$1 +if [ $# -ge 2 ] ; then + GPU_TARGETS=$2 +else + GPU_TARGETS="gfx908;gfx90a;gfx940" +fi + cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_FLAGS="-O3" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=OFF \ --D GPU_TARGETS="gfx908;gfx90a;gfx940" \ +-D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE}