diff --git a/example/ck_tile/42_unified_attention/CMakeLists.txt b/example/ck_tile/42_unified_attention/CMakeLists.txt new file mode 100644 index 00000000000..45f67f3e0d6 --- /dev/null +++ b/example/ck_tile/42_unified_attention/CMakeLists.txt @@ -0,0 +1,228 @@ +# Commented out: FMHA fwd/bwd instance generation and codegen commands not used by unified_attention +# +# set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +# # Currently only gfx9 archs are supported by FMHA +# list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") +# if(NOT INST_TARGETS) +# message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +# return() +# endif() +# +# # validate user-specified fmha_fwd API list +# set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") +# set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING +# "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") +# if(BUILD_TESTING) +# # Build instances of all APIs for tests +# set(FMHA_FWD_ENABLE_APIS "all") +# endif() +# if(FMHA_FWD_ENABLE_APIS STREQUAL "all") +# set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) +# endif() +# +# foreach(api ${FMHA_FWD_ENABLE_APIS}) +# if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) +# message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.") +# endif() +# endforeach() +# +# # "fwd" is a must-have api for the fmha_fwd example, add it if not specified +# if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) +# list(PREPEND FMHA_FWD_ENABLE_APIS "fwd") +# endif() +# +# file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS +# ${CMAKE_CURRENT_LIST_DIR}/generate.py +# ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py +# ) +# set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") +# +# string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +# set(FMHA_FWD_CODE_GEN_COMMON_ARGS +# ${CMAKE_CURRENT_LIST_DIR}/generate.py +# --api ${FMHA_FWD_APIS} +# --optdim 32,64,128,256 +# ) +# set(FMHA_BWD_CODE_GEN_COMMON_ARGS +# ${CMAKE_CURRENT_LIST_DIR}/generate.py +# --api bwd +# --receipt 3 +# --optdim 32,64,96,128,256 +# ) +# +# if(BUILD_TESTING) +# list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) +# endif() +# +# execute_process( +# COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} +# --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt +# RESULT_VARIABLE ret +# ) +# if(ret AND NOT ret EQUAL 0) +# message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") +# endif() +# +# execute_process( +# COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} +# --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt +# RESULT_VARIABLE ret +# ) +# if(ret AND NOT ret EQUAL 0) +# message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") +# endif() +# +# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) +# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) +# +# add_custom_command( +# OUTPUT ${FMHA_FWD_GEN_BLOBS} +# COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} +# --output_dir ${CMAKE_CURRENT_BINARY_DIR} +# DEPENDS ${CODE_GEN_SCRIPTS} +# ) +# +# add_custom_command( +# OUTPUT ${FMHA_BWD_GEN_BLOBS} +# COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} +# --output_dir ${CMAKE_CURRENT_BINARY_DIR} +# DEPENDS ${CODE_GEN_SCRIPTS} +# ) +# +# set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") +# set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") +# +# message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}") +# add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) +# target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +# set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +# set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) +# +# message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}") +# add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) +# target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS}) +# set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +# set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) +# +# set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS) +# set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS) +# set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS) +# set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS) +# +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) +# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) +# +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) +# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) +# +# if(NOT DEFINED FMHA_FWD_FAST_EXP2) +# set(FMHA_FWD_FAST_EXP2 ON) +# endif() +# +# if(FMHA_FWD_FAST_EXP2) +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) +# else() +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0) +# endif() +# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero) +# +# if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) +# else() +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) +# endif() +# +# if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS) +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) +# else() +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) +# endif() +# +# if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) +# else() +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) +# endif() +# +# if(CK_USE_OCP_FP8) +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +# endif() +# +# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +# list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +# +# target_compile_options(${FMHA_FWD_INSTANCES} +# PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} +# INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) +# target_compile_options(${FMHA_BWD_INSTANCES} +# PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS} +# INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS}) +# +# set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") +# set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") +# +# message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") +# add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp) +# target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES}) +# target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# +# message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") +# add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) +# target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) +# target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# +# set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) + +# --- Unified Attention target (kept) --- + +# +set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +# Currently only gfx9 archs are supported by FMHA +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") +if(NOT INST_TARGETS) + message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + +set(EXAMPLE_UNIFIED_ATTENTION "tile_example_unified_attention") +message(DEBUG "adding example ${EXAMPLE_UNIFIED_ATTENTION}") + +add_executable(${EXAMPLE_UNIFIED_ATTENTION} EXCLUDE_FROM_ALL example_unified_attention.cpp) +target_include_directories(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +file(GLOB UNIFIED_ATTENTION_INSTANCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" +) +target_sources(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE + unified_attention.cpp + ${UNIFIED_ATTENTION_INSTANCES} +) + +set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS) +list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS + -fgpu-flush-denormals-to-zero + -Wno-undefined-func-template + --save-temps +) +set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS) + +check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) +if(HAS_DISABLE_PACKED_FP32) + list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS + -mllvm --amdgpu-disable-packed-fp32=1 + ) + list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS + -DCK_TILE_DISABLE_PACKED_FP32=1 + ) +endif() + +target_compile_options(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS}) +target_compile_definitions(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/example/ck_tile/42_unified_attention/README.md b/example/ck_tile/42_unified_attention/README.md new file mode 100644 index 00000000000..bcc10f7e657 --- /dev/null +++ b/example/ck_tile/42_unified_attention/README.md @@ -0,0 +1,161 @@ +# fused multi-head attention + +This folder contains examples for unified attention (fused multi-head attention) using the ck_tile tile-programming implementation. The examples demonstrate the usage of the tile-programming API, as well as the new approach to constructing kernel templates and instantiating them. + +## build + +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +../script/cmake-ck-dev.sh ../ +make tile_example_unified_attention -j +``` +This will result in an executable `build/bin/tile_example_unified_attention` + +## kernel + +The kernel template is `unified_attention.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. + +There are 2 template parameters for this kernel template. + +* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). +* `EpiloguePipeline` is the last stage of the pipeline. It modifies and stores the result. Post-fusion can be done at this stage though the example only returns the result. + +## codegen +To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. + +## executable +`tile_example_unified_attention` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_unified_attention -?` to list all the arguments. Below is an example of the output (may subject to change) +``` +args: + -v weather do CPU validation or not (default:1) + -mode kernel mode. 0:batch, 1:group (default:0) + -b batch size (default:2) + -h num of head, for q (default:8) + -h_k num of head, for k/v, -1 means equal to h (default:-1) + if not equal to h, then this is GQA/MQA case + -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) + total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) + -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) + also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode) + -s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1) + Provide positive strides per-batch to simulate physical padding on Q + -s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1) + for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride + along seqlen, instead of packed, same as xformer kv_padding, + must be greater than or equal to s_k + -d head dim for q, k (default:128) + -d_v head dim for v, -1 means equal to d (default:-1) + -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) + note when squant=1, this value will be modified by range_q/k + -range_q per-tensor quantization range of q. used if squant=1. (default:16) + -range_k per-tensor quantization range of k. used if squant=1. (default:16) + -range_v per-tensor quantization range of v. used if squant=1. (default:16) + -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) + -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) + -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) + 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. + calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o + -iperm permute input (default:1) + if true, will be b*h*s*d, else b*s*h*d + -operm permute output (default:1) + -bias n or 0, no bias (default:n) + e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s + a(libi) or 2, alibi with 1*h. a:1, b*h + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) + 't', top-left causal mask, 'b', bottom-r causal mask + 't:l,r', top-left sliding window attn(swa) with FA style left right size + 'b:l,r', bottom-r sliding window attn(swa) with FA style left right size + 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa + 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa + 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) + -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) + -lse 0 not store lse, 1 store lse (default:0) + -kname if set to 1 will print kernel name (default:0) + -init init method. ui, uniform random int, ni, normalized random int (default:uf) + uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization + -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) + -drop_seed seed for random number generator (default:1) +-drop_offset offset for random number generator (default:0) + -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) + -num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1) + -warmup number of iterations before benchmark the kernel (default:5) + -repeat number of iterations to benchmark the kernel (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:fmha_fwd.json) + -q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override +-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override +``` +Example 1: `./bin/tile_example_unified_attention -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. +Example 2: `./bin/tile_example_unified_attention -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with + batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case + +## Padding Examples +Example 3 (Group mode with padding): `./bin/tile_example_unified_attention -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively. + +Example 4 (Batch mode with effective lengths): `./bin/tile_example_unified_attention -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively. + +## support features +Currently we are still in rapid development stage, so more features/optimizations will be coming soon. + +### hdim +Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default) + +### group/batch mode +Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's term), by setting `-mode` = `0` or `1`. In `group mode` different kind of attention mask is also supported(see below) + +### MQA/GQA +By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers. + +### input/output permute, and `b*s*3*h*d` +If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`. + +### attention bias +Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number. + +### alibi +alibi is supported + +### lse +For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` + +### vlayout +We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts. + +### attention mask +we support `causal mask` and `sliding window attention(swa)` mask in both batch and group mode, either from top-left or bottom-right. +Underneath, we unify the mask expression into `generic attention mask coordinate`, providing an uniformed approach for each batch to locate the corresponding pixel need to be masked out. +![](misc/gamc.png) + +Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline. + +| mask case| cmdline | FA style | xformer style | +|----------|:-------------:|:-------------:|:-------------:| +| no mask | `-mask=0`(default) | | | +| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` | +| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` | +| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` | +| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` | + +Note FA use bottom-right by default to express swa case, here we require you explicitly specify top-left/bottom-right. + +### dropout +TBD + +### sequence padding and variable length support +We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths. + +**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment. + +**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste. + +Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. + +## FP8 experimental support +As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_unified_attention`, on a gfx942 machine and ROCm 6.0+. + +Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. diff --git a/example/ck_tile/42_unified_attention/example_unified_attention.cpp b/example/ck_tile/42_unified_attention/example_unified_attention.cpp new file mode 100644 index 00000000000..e43a8df76ef --- /dev/null +++ b/example/ck_tile/42_unified_attention/example_unified_attention.cpp @@ -0,0 +1,680 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "unified_attention.hpp" +#include "mask.hpp" + +// const ck_tile::index_t page_blk_size = 32; +const ck_tile::index_t num_queries_per_kv = 1; + +auto parse_cmd_args(int argc, char* argv[]) -> std::pair +{ + ck_tile::ArgParser arg_parser; + arg_parser + .insert("prec", "bf16", "data type. fp16/bf16") + // .insert("b", "3", "batch size") + .insert("h_k", + "8", + "num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) + + " times this") + .insert("s", "3328", "max seqlen_q") + .insert("s_k", "-1", "max seqlen_k, -1 means equal to s") + .insert("nb", "1024", "num_blks") + .insert("b", "3", "batch") + .insert("d", "128", "head dim for q & k") + .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") + // TODO scale factors + .insert("scale", "1", "") + .insert("scale_k", "1", "") + .insert("scale_v", "1", "") + .insert("scale_out", "1", "") + .insert("iperm", + "0", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "0", "permute output") + .insert("causal", "0", "0: no mask, 1: causal mask") + .insert("verify", "1", "0:no verify, 1:verify") + .insert("varlen", "1", "0: fixed length, 1: variable length") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "30", "number of iterations to benchmark the kernel") + .insert("page_blk_size", "128", "page block size of kv cache") + // Optional effective seqlen override (exclude PAD) for batch mode + .insert("query_lens", + "", + "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override.") + .insert("kv_lens", + "", + "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override."); + + bool result = arg_parser.parse(argc, argv); + return std::make_pair(result, arg_parser); +} + +auto seqlen_preprocess(ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t max_seqlen_kv, + const std::vector& query_lens_input, + const std::vector& kv_lens_input, + bool varlen) -> std::pair, std::vector> +{ + // If both query_lens and kv_lens are provided, return them directly + if(!query_lens_input.empty() && !kv_lens_input.empty()) + { + return std::make_pair(query_lens_input, kv_lens_input); + } + + std::vector query_lens; + std::vector kv_lens; + + if(!varlen) + { + // Fixed length mode: fill with max seqlen + query_lens.assign(batch, max_seqlen_q); + kv_lens.assign(batch, max_seqlen_kv); + } + else + { + // Variable length mode: generate random lengths up to max + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution q_dist(1, max_seqlen_q); + std::uniform_int_distribution kv_dist(1, max_seqlen_kv); + + query_lens.resize(batch); + kv_lens.resize(batch); + + for(ck_tile::index_t i = 0; i < batch; ++i) + { + query_lens[i] = q_dist(gen); + kv_lens[i] = kv_dist(gen); + } + } + + return std::make_pair(query_lens, kv_lens); +} + +struct Problem +{ + explicit Problem(const ck_tile::ArgParser& args) + { + data_type = args.get_str("prec") == "fp16" + ? ck_tile::unified_attention_args::data_type_enum::fp16 + : ck_tile::unified_attention_args::data_type_enum::bf16; + num_blks = args.get_int("nb"); + nhead_kv = args.get_int("h_k"); + // TODO: support other GQA/MQA cases than just 4x + nhead_q = nhead_kv * num_queries_per_kv; + + ck_tile::index_t max_seqlen_q = args.get_int("s"); + ck_tile::index_t max_seqlen_kv = args.get_int("s_k"); + + if(max_seqlen_kv == -1) + { + max_seqlen_kv = max_seqlen_q; + } + + hdim = args.get_int("d"); + query_lens = args.get_int_vec("query_lens"); + kv_lens = args.get_int_vec("kv_lens"); + assert(query_lens.size() == kv_lens.size() && + "query_lens and kv_lens must have the same length b"); + batch = args.get_int("b"); + page_blk_size = args.get_int("page_blk_size"); + + bool varlen = args.get_bool("varlen"); + auto [query_lens_, kv_lens_] = + seqlen_preprocess(batch, max_seqlen_q, max_seqlen_kv, query_lens, kv_lens, varlen); + + query_lens = query_lens_; + kv_lens = kv_lens_; + batch = query_lens.size(); + + // Calculate scale_s + scale_s = args.get_float("scale_s"); + if(scale_s == 0.0f) + scale_s = 1.0f / ck_tile::sqrt(static_cast(hdim)); + + // Initialize other scales + scale = args.get_float("scale"); + scale_k = args.get_float("scale_k"); + scale_v = args.get_float("scale_v"); + num_tokens = 0; + for(const auto& len : query_lens) + { + num_tokens += len; + } + } + + std::vector get_query_shape() const { return {num_tokens, nhead_q, hdim}; } + + std::vector get_key_shape() const + { + return {num_blks, page_blk_size, nhead_kv, hdim}; + } + + std::vector get_value_shape() const + { + return {num_blks, page_blk_size, nhead_kv, hdim}; + } + + std::vector get_output_shape() const { return {num_tokens, nhead_q, hdim}; } + + ck_tile::unified_attention_args::data_type_enum data_type; + ck_tile::index_t batch; + ck_tile::index_t num_blks; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_kv; + ck_tile::index_t hdim; + ck_tile::index_t page_blk_size; + ck_tile::index_t num_tokens; + float scale_s; + float scale; + float scale_k; + float scale_v; + mask_info mask; + std::vector query_lens; + std::vector kv_lens; +}; + +struct RunConfig +{ + explicit RunConfig(const ck_tile::ArgParser& args) + { + seed = args.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + kernel_warmup = args.get_int("warmup"); + kernel_repeat = args.get_int("repeat"); + verify = args.get_bool("verify"); + } + + std::optional seed; + int kernel_warmup; + int kernel_repeat; + bool verify; +}; + +template +auto generate_qkv(const Problem& problem, + [[maybe_unused]] std::optional seed = std::nullopt) + -> std::tuple, + ck_tile::HostTensor, + ck_tile::HostTensor> +{ + ck_tile::HostTensor q(problem.get_query_shape()); + ck_tile::HostTensor k(problem.get_key_shape()); + ck_tile::HostTensor v(problem.get_value_shape()); + + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); + + return std::make_tuple(q, k, v); +} + +namespace host { +template +CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, + const ck_tile::HostTensor& k_bshd, + const ck_tile::HostTensor& v_bshd, + // const mask_info& mask, + ck_tile::HostTensor& o_bshd, + const QElementOp& q_element_op = {}, + const KElementOp& k_element_op = {}, + const VElementOp& v_element_op = {}, + const SAccElementOp& s_acc_element_op = {}) +{ + const int batch_size = q_bshd.mDesc.get_lengths()[0]; + const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; + const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; + const int nhead_q = q_bshd.mDesc.get_lengths()[2]; + const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; + const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; + const int hdim_v = v_bshd.mDesc.get_lengths()[3]; + + const int nr = nhead_q / nhead_kv; + + ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); + ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); + ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); + ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); + ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); + // do computation for each batch + for(int b = 0; b < batch_size; ++b) + { + // copy per-batch data from input tensors + // clang-format off + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , + idx[2]); }); + k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], + idx[0] / nr, idx[2]); }); + v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = + v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + // clang-format on + ck_tile::reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); + + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + -1, 0, seqlen_q, seqlen_kv, 1, false)); + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, ck_tile::identity{}); + ck_tile::reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); + + // copy resulting per-batch data to the output tensor + o_host_ref.ForEach( + [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); + } +} +} // namespace host + +template +bool run_impl(const Problem& problem, const RunConfig& run_config) +{ + auto [q, k, v] = generate_qkv(problem, run_config.seed); + + ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); + /// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v + ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q.data()); + k_buf.ToDevice(k.data()); + v_buf.ToDevice(v.data()); + // Ensure output buffer is zero-initialized so padded regions compare cleanly + o_buf.SetZero(); + + ck_tile::unified_attention_args args{}; + + args.scale_s = problem.scale_s; + args.data_type = problem.data_type; + args.num_seqs = problem.batch; + args.num_head_q = problem.nhead_q; + args.num_queries_per_kv = num_queries_per_kv; + args.page_blk_size = problem.page_blk_size; + args.mask_type = 2; + args.hdim = problem.hdim; + + args.num_blks = problem.num_blks; + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.query_stride_0 = problem.hdim * problem.nhead_q; + args.query_stride_1 = problem.hdim; + + args.k_ptr = k_buf.GetDeviceBuffer(); + + args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * problem.page_blk_size; + args.stride_k_cache_1 = problem.hdim * problem.nhead_kv; + args.stride_k_cache_2 = problem.hdim; + args.stride_k_cache_3 = 1; + + args.v_ptr = v_buf.GetDeviceBuffer(); + args.stride_v_cache_0 = args.stride_k_cache_0; + args.stride_v_cache_1 = args.stride_k_cache_1; + args.stride_v_cache_2 = args.stride_k_cache_2; + args.stride_v_cache_3 = args.stride_k_cache_3; + + args.o_ptr = o_buf.GetDeviceBuffer(); + args.output_stride_0 = args.query_stride_0; + args.output_stride_1 = args.query_stride_1; + + // Optional cumulative seqlen overrides (exclude PAD) + auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { + std::vector eff; + if(!opt_vec.empty() && opt_vec[0] != -1) + { + eff.assign(opt_vec.begin(), opt_vec.end()); + if(eff.size() < static_cast(problem.batch)) + { + eff.resize(problem.batch, eff.back()); + } + } + else + { + eff.assign(problem.batch, fallback); + } + return eff; + }; + + const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024); + const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024); + + args.num_tokens = std::accumulate(eff_query_lens.begin(), eff_query_lens.end(), 0); + + // Calculate cumulative sums for kernel arguments if varlen is used + std::vector cu_query_lens; + + auto calculate_cumulative = [&](const std::vector& per_batch_vec, + std::vector& cum_vec) { + cum_vec.resize(per_batch_vec.size() + 1); + cum_vec[0] = 0; + for(std::size_t i = 0; i < per_batch_vec.size(); ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + }; + calculate_cumulative(eff_query_lens, cu_query_lens); + + ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t)); + + seq_lens_buf.ToDevice(eff_kv_lens.data()); + query_start_len_buf.ToDevice(cu_query_lens.data()); + + args.seq_lens_ptr = reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); + args.query_start_len_ptr = + reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); + + auto max_element = [&](const std::vector& opt_vec) { + ck_tile::index_t max = opt_vec[0]; + for(ck_tile::index_t i : opt_vec) + { + if(i > max) + { + max = i; + } + } + return max; + }; + + ck_tile::index_t max_kv_len = max_element(eff_kv_lens); + + ck_tile::index_t max_num_blocks_per_seq = + (max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size; + + // Create block_tables + ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * + sizeof(ck_tile::index_t)); + + // Allocate host memory for block_tables + std::vector block_tables_host(problem.batch * max_num_blocks_per_seq); + + // Fill block_tables with random integers between 0 and num_blocks-1 + std::mt19937 rng(run_config.seed ? *run_config.seed : std::random_device{}()); + std::uniform_int_distribution dist(0, problem.num_blks - 1); + for(size_t i = 0; i < block_tables_host.size(); ++i) + { + block_tables_host[i] = dist(rng); + } + + // Copy to device + block_tables_buf.ToDevice(block_tables_host.data()); + + // Set pointer in args + args.block_tables_ptr = + reinterpret_cast(block_tables_buf.GetDeviceBuffer()); + args.block_table_stride = max_num_blocks_per_seq; + + ck_tile::stream_config stream_config{nullptr, + true, + /*log_level=*/0, + run_config.kernel_warmup, + run_config.kernel_repeat}; + + auto [result, time] = ck_tile::unified_attention(args, stream_config); + + if(!result) + { + std::cerr << "faild to run unified_attention()" << std::endl; + return false; + } + + std::size_t flop = [&] { + long flop_result = 0; + + for(size_t b = 0; b < eff_query_lens.size(); ++b) + { + long query_lens = eff_query_lens[b]; + long kv_lens = eff_kv_lens[b]; + long valid_out_elements = 0; + + // Causal logic for valid output elements + if(query_lens > kv_lens) + { + valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2; + } + else + { + valid_out_elements = + query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2); + } + + flop_result += 2 * problem.nhead_q * valid_out_elements * (problem.hdim + problem.hdim); + } + return flop_result; + }(); + // TODO fix this + // std::size_t flop = 1; + float tflops = static_cast(flop) / 1.e9 / time; + long mem = 0; + + mem += problem.num_tokens * problem.nhead_q * problem.hdim * 2 * 2; // q and o, fp16 + // Count unique block indices used in block_tables_host + std::unordered_set unique_blocks(block_tables_host.begin(), + block_tables_host.end()); + mem += unique_blocks.size() * problem.nhead_kv * problem.hdim * 2 * 2; // k and v, fp16 + mem += problem.batch * max_num_blocks_per_seq * 4; // int32 block table + mem += problem.batch * 4; // int32 seq_lens_ptr + + std::cout << "[" << problem.data_type << "|"; + std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv + << ", d:" << problem.hdim << ", scale_s:" << problem.scale_s << ", query_lens:["; + for(size_t i = 0; i < problem.query_lens.size(); ++i) + { + std::cout << problem.query_lens[i]; + if(i < problem.query_lens.size() - 1) + std::cout << ","; + } + std::cout << "], kv_lens:["; + for(size_t i = 0; i < problem.kv_lens.size(); ++i) + { + std::cout << problem.kv_lens[i]; + if(i < problem.kv_lens.size() - 1) + std::cout << ","; + } + std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time + << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) + << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; + + if(!run_config.verify) + { + return true; + } + + // variable lengths are provided -> compute per-batch references + // with the effective lengths; else compute a single full reference. + // Variable-length aware verification: zero-fill padded region and only compute valid part. + ck_tile::HostTensor o_ref(problem.get_output_shape()); + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_query_lens[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_lens[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::index_t seq_q_off = cu_query_lens[b]; + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { + // kv cache is paged + ck_tile::index_t table_col = int(idx[1] / problem.page_blk_size); + ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col; + ck_tile::index_t block_idx = block_tables_host[block_table_offset]; + + self(idx) = k(block_idx, idx[1] % problem.page_blk_size, idx[2], idx[3]); + }); + v_b.ForEach([&](auto& self, auto idx) { + ck_tile::index_t table_col = int(idx[1] / problem.page_blk_size); + ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col; + ck_tile::index_t block_idx = block_tables_host[block_table_offset]; + + self(idx) = v(block_idx, idx[1] % problem.page_blk_size, idx[2], idx[3]); + }); + // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + // problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.scale_s}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d); + } + } + } + } + + ck_tile::HostTensor o(problem.get_output_shape()); + o_buf.FromDevice(o.data()); + + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); + + size_t total = static_cast(problem.num_tokens) * static_cast(problem.nhead_q) * + static_cast(problem.hdim); + + size_t nonzero = 0; + + for(int tok = 0; tok < problem.num_tokens; ++tok) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + if(static_cast(o(tok, h, d)) != 0.0f) + { + nonzero++; + } + } + } + } + + float percent = + (total > 0) ? (100.0f * static_cast(nonzero) / static_cast(total)) : 0.0f; + + std::cout << "\nNon-zero elements in output tensor o: " << nonzero << " / " << total << " (" + << percent << "%)\n"; + + // std::cout << "\n=== Complete Output Tensor (o) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o(tok, h, d)) << " "; + // } + // std::cout << "\n"; + // } + // } + + // std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o_ref(tok, h, d)) << " "; + // } + // std::cout << "\n"; + // } + // } + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); +} + +int main(int argc, char* argv[]) +{ + + auto [parse_result, args] = parse_cmd_args(argc, argv); + + if(!parse_result) + { + std::cerr << "failed to parse command line arguments" << std::endl; + } + + Problem problem(args); + RunConfig run_config(args); + + const auto run = [&] { + if(problem.data_type == ck_tile::unified_attention_args::data_type_enum::fp16) + { + return run_impl(problem, run_config); + } + else + { + return run_impl(problem, run_config); + } + }; + + return !run(); +} diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp new file mode 100644 index 00000000000..72717026bc5 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp new file mode 100644 index 00000000000..391103891a9 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp new file mode 100644 index 00000000000..f2cc00f8356 --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp new file mode 100644 index 00000000000..6a2a9984d1f --- /dev/null +++ b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/mask.hpp b/example/ck_tile/42_unified_attention/mask.hpp new file mode 100644 index 00000000000..33f9bf72a9b --- /dev/null +++ b/example/ck_tile/42_unified_attention/mask.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/unified_attention.hpp" + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + tmp.seqlen_q = seqlen_q; + tmp.seqlen_k = seqlen_k; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = std::stoi(v); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else if(t == "t" || t == "b" || t == "g") + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + throw std::invalid_argument("invalid mask value: " + str); + } + ck_tile::index_t v0 = std::stoi(v.substr(0, found_1)); + ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1)); + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.type = mask_enum::window_generic; + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? + tmp.right = v1; + } + } + else + { + throw std::invalid_argument("invalid mask value: " + str); + } + } + else if(str == "0") + { + tmp.type = mask_enum::no_mask; + } + else if(str == "1" || str == "t") + { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + } + else if(str == "2" || str == "b") + { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + } + else + { + throw std::invalid_argument("invalid mask value: " + str); + } + return tmp; + } + + ck_tile::index_t get_unmaskarea() const + { + if(type == mask_enum::no_mask) + return seqlen_q * seqlen_k; + ck_tile::index_t area = 0; + for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) + { + ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast(0)); + ck_tile::index_t x_end = std::min(i_y + x, seqlen_k); + if(x_end > x_start) + { + area += (x_end - x_start); + } + } + return area; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/42_unified_attention/misc/gamc.png b/example/ck_tile/42_unified_attention/misc/gamc.png new file mode 100644 index 00000000000..2c96951f30f Binary files /dev/null and b/example/ck_tile/42_unified_attention/misc/gamc.png differ diff --git a/example/ck_tile/42_unified_attention/rotary.hpp b/example/ck_tile/42_unified_attention/rotary.hpp new file mode 100644 index 00000000000..346f2a5e7ee --- /dev/null +++ b/example/ck_tile/42_unified_attention/rotary.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include + +// keep sync with RotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +template +std::tuple, ck_tile::HostTensor> +generate_rotary_cos_sin(ck_tile::index_t seqlen, + ck_tile::index_t rotary_dim, + std::optional seed = std::nullopt) +{ + // return dummy tensors if we won't apply RoPE at all + if(rotary_dim <= 0) + { + ck_tile::HostTensor dummy({1, 1}); + return std::make_tuple(dummy, dummy); + } + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_real_distribution generator(0.0f, 1.0f); + + const ck_tile::index_t num_rows = seqlen * 2; + const ck_tile::index_t num_cols = rotary_dim / 2; + + using std::begin, std::end; + + ck_tile::HostTensor angle({num_rows, num_cols}); + std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); + + ck_tile::HostTensor cos({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { + return ck_tile::type_convert(std::cos(origin_value)); + }); + + ck_tile::HostTensor sin({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { + return ck_tile::type_convert(std::sin(origin_value)); + }); + + return std::make_tuple(cos, sin); +} + +template +std::tuple, ck_tile::HostTensor> +slice_rotary_cos_sin(const ck_tile::HostTensor& cos, + const ck_tile::HostTensor& sin, + ck_tile::index_t seqlen_offset, + ck_tile::index_t seqlen) +{ + assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); + assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); + + assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); + + const ck_tile::index_t num_rows = seqlen; + const ck_tile::index_t num_cols = cos.get_length(1); + + ck_tile::HostTensor cos_pt({num_rows, num_cols}); + cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); + + ck_tile::HostTensor sin_pt({num_rows, num_cols}); + sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); + + return std::make_tuple(cos_pt, sin_pt); +} diff --git a/example/ck_tile/42_unified_attention/script/benchmark_fwd.sh b/example/ck_tile/42_unified_attention/script/benchmark_fwd.sh new file mode 100755 index 00000000000..3a3b9389002 --- /dev/null +++ b/example/ck_tile/42_unified_attention/script/benchmark_fwd.sh @@ -0,0 +1,53 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_unified_attention -type f | head -n 1)" +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 64 128 256 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done + +#Padding Benchmarks: batch mode (baseline vs low/med/high pad) +prec="fp16" +base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no pad) +$EXE $base_batch_args + +# low pad (≈90–95% effective) +$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + +# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) +seqlens_q="1024,768,512,256" +seqlens_k="1024,768,512,256" +base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no physical pad) +$EXE $base_group_args + +# low physical pad +$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + +# medium physical pad +$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + +# high physical pad +$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx942.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx950.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt b/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx942.txt b/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx942.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx950.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/42_unified_attention/script/run_full_test.sh b/example/ck_tile/42_unified_attention/script/run_full_test.sh new file mode 100755 index 00000000000..5c2a5a4b3d0 --- /dev/null +++ b/example/ck_tile/42_unified_attention/script/run_full_test.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/ +# +# run the script as "./run_full_test.sh +# input arguments: +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# host name : $hostname +# gpu architecture: e.g., gfx90a, or gfx942, etc. + +set -euo pipefail + +#get the command line arguments: +export env_type=$1 +echo 'Environment type: ' $env_type +export branch=$2 +echo 'Branch name: ' $branch +export host_name=$3 +echo 'Host name: ' $host_name +export GPU_arch=$4 +echo 'GPU_arch: ' $GPU_arch + +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + #get GPU_arch and number of compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + +#run verification tests +time example/ck_tile/01_fmha/script/smoke_test_fwd.sh +time example/ck_tile/01_fmha/script/smoke_test_bwd.sh + +#run performance benchmarks +export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" +print_log_header $fmha_fwd_log $env_type $branch $host_name +time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log + +export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log" +print_log_header $fmha_bwd_log $env_type $branch $host_name +time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log + diff --git a/example/ck_tile/42_unified_attention/script/smoke_test_bwd.sh b/example/ck_tile/42_unified_attention/script/smoke_test_bwd.sh new file mode 100755 index 00000000000..cd51dde2d4e --- /dev/null +++ b/example/ck_tile/42_unified_attention/script/smoke_test_bwd.sh @@ -0,0 +1,90 @@ +#!/bin/bash +# TODO: run this script from CK root or build directory +set -euo pipefail + +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_bwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" +KNAME=1 +GPU_arch=${GPU_arch:-""} +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi + +export CK_WARMUP=0 +export CK_REPEAT=1 + +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"} + +COMMON_ARGS='-v=1' + +run_exe() { + set +ex + $EXE $@ + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + +test_h_s_mask() { + run_exe -b=1 -h=4 -h_k=2 -s=259 $@ + run_exe -b=2 -h=2 -s=516 -s_k=253 $@ + run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@ + run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@ + run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@ + run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@ +} + +set -x +# main tests +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 256 ; do +for mode in 0 1 ; do +for bias in "n" "a" ; do +for dbias in 0 ; do +for p_drop in 0.0 0.2 ; do +for deterministic in 0 ; do +test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +done +done +done +done +done +done +done +done + +# additional cases +for hdim in 40 48 72 96 ; do +test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +done +set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) diff --git a/example/ck_tile/42_unified_attention/script/smoke_test_fwd.sh b/example/ck_tile/42_unified_attention/script/smoke_test_fwd.sh new file mode 100755 index 00000000000..fca6b8d0cd3 --- /dev/null +++ b/example/ck_tile/42_unified_attention/script/smoke_test_fwd.sh @@ -0,0 +1,281 @@ +#!/bin/bash +# TODO: run this script from CK root or build directory +set -euo pipefail + +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_fwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" +KNAME=1 +GPU_arch=$GPU_arch +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi + +export CK_WARMUP=0 +export CK_REPEAT=1 + +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"} + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 + +TEST_SPLITKV=0 +TEST_APPENDKV=0 +# options: +# -s: run splitkv tests +# -a: run appendkv tests +while getopts ":sa" opt; do + case "${opt}" in + s) + TEST_SPLITKV=1 + ;; + a) + TEST_APPENDKV=1 + ;; + *) + ;; + esac +done + +run_exe() { + set +ex + $EXE $@ + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + +run_fp16_bf16_tests() { + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" + + if [ $TEST_SPLITKV -eq 1 ] ; then + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" + fi + + for prec in "fp16" "bf16" ; do + for mode in 1 0 ; do + for perm in 0 1 ; do + for hdim in 32 64 128 256 ; do + for lse in 0 1 ; do + for bias in "n" "e" "a" ; do + for p_drop in 0.0 0.2 ; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do + + # run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done ; done ; done +} + +run_fp8_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8bf16_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8fp32_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp16_appendkv_tests() { + for s in $(seq 63 1 65) ; do + for s_k in 65 129 ; do + for s_knew in 0 64 $s_k ; do + for hdim in 32 64 128 256 ; do + for ri in 0 1 ; do + for rdim in 0 16 32 $hdim ; do + for page_block_size in 0 128 ; do + for cache_batch_idx in 0 1 ; do + + run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done +} + +run_padding_smoke_tests() { + # Padding-only smoke tests for batch/group mode using COMMON_ARGS + local prec="fp16" + + # Batch mode: padding via effective lengths (exclude PAD) + # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches + local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low pad (≈90–95% effective) + $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + # medium pad (≈60–75% effective) + $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + # high pad (≈30–40% effective) + $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + + # Group mode: padding via physical stride along seqlen + local seqlens_q="1024,768,512,256" + local seqlens_k="1024,768,512,256" + local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low physical pad + $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + # medium physical pad + $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + # high physical pad + $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 +} + +run_padding_basic_boundary_tests() { + # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) + local prec + local perm + + # Group mode: Q&K padded with per-batch different strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ + -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # slightly larger, uneven padding strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ + -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # only K padded; Q unpadded + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ + -s=55 -s_k=256 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # use cu_seqlen overrides to skip tail PAD + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ + -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ + -q_eff_lens=55,60 -kv_eff_lens=200,256 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # no padding (equal), mixed Q/KV, all len=1 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # highly variable logical lengths + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ + -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ + -s=256,129 -s_k=256,129 -s_kpad=256 \ + -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ + -kname=$KNAME $COMMON_ARGS + done +} + +set -x + +run_fp16_bf16_tests +run_padding_smoke_tests +run_padding_basic_boundary_tests +run_fp8_tests +run_fp8bf16_tests +run_fp8fp32_tests + +if [ $TEST_APPENDKV -eq 1 ] ; then + run_fp16_appendkv_tests +fi + +set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp new file mode 100644 index 00000000000..fb3e37e1e06 --- /dev/null +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" +#include "mask.hpp" + +namespace ck_tile { + +std::ostream& operator<<(std::ostream& stream, + const unified_attention_args::data_type_enum& data_type) +{ + switch(data_type) + { + case unified_attention_args::data_type_enum::fp16: return stream << "fp16"; + case unified_attention_args::data_type_enum::bf16: return stream << "bf16"; + default: return stream << "unknown"; + } +} + +std::pair unified_attention(const unified_attention_args& args, + const stream_config& config) +{ + if(args.data_type == unified_attention_args::data_type_enum::fp16) + { + if(args.mask_type == static_cast(mask_enum::no_mask)) + { + using kernel_traits = + unified_attention_kernel_traits; + + return unified_attention_kernel_dispatch(args, config); + } + else + { + using kernel_traits = + unified_attention_kernel_traits; + + return unified_attention_kernel_dispatch(args, config); + } + } + else if(args.data_type == unified_attention_args::data_type_enum::bf16) + { + if(args.mask_type == static_cast(mask_enum::no_mask)) + { + using kernel_traits = + unified_attention_kernel_traits; + + return unified_attention_kernel_dispatch(args, config); + } + else + { + using kernel_traits = + unified_attention_kernel_traits; + + return unified_attention_kernel_dispatch(args, config); + } + } + + return std::make_pair(false, -1.f); +} + +} // namespace ck_tile diff --git a/example/ck_tile/42_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp new file mode 100644 index 00000000000..64f340c5562 --- /dev/null +++ b/example/ck_tile/42_unified_attention/unified_attention.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/host/stream_config.hpp" +#include "ck_tile/ops/unified_attention.hpp" + +namespace ck_tile { + +struct unified_attention_args +{ + enum class data_type_enum + { + fp16, + bf16 + }; + + data_type_enum data_type; + // bool is_varlen; + index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and + // window_size_right == 0). + + index_t num_tokens; // total number of tokens in query + index_t num_blks; + index_t num_head_q; + index_t num_queries_per_kv; + index_t page_blk_size; + // index_t BLOCK_SIZE; + + index_t hdim; + // TODO window + float scale_s; + float scale; + float scale_k; + float scale_v; + float scale_out; + + const void* q_ptr; + index_t query_stride_0; + index_t query_stride_1; + + const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + index_t stride_k_cache_0; + index_t stride_k_cache_1; + index_t stride_k_cache_2; + index_t stride_k_cache_3; + + const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + index_t stride_v_cache_0; + index_t stride_v_cache_1; + index_t stride_v_cache_2; + index_t stride_v_cache_3; + + void* o_ptr; + index_t output_stride_0; + index_t output_stride_1; + + const int32_t* block_tables_ptr; + index_t block_table_stride; + const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* query_start_len_ptr; // [num_seqs+1] + + index_t num_seqs; // number of batches for q +}; + +std::ostream& operator<<(std::ostream& stream, + const unified_attention_args::data_type_enum& data_type); + +// return value: +// first = whether the kernel was launched (true = launched, false = skipped) +// second = elapsed time (ms) of the kernel launch, valid only if first == true +std::pair unified_attention(const unified_attention_args& args, + const stream_config& config); + +} // namespace ck_tile + +struct UnifiedAttentionMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; diff --git a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp new file mode 100644 index 00000000000..8087c4b8e64 --- /dev/null +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" + +#include "unified_attention.hpp" +#include "mask.hpp" + +#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \ + template <> \ + std::pair unified_attention_kernel_dispatch( \ + const unified_attention_args& args, const stream_config& config) \ + { \ + return std::make_pair( \ + true, unified_attention_kernel_launch(args, config)); \ + } + +namespace ck_tile { + +template +struct unified_attention_problem_traits; + +template <> +struct unified_attention_problem_traits +{ + using qkvp_dtype = ck_tile::half_t; + using acc_dtype = float; + using o_dtype = ck_tile::half_t; + using lse_dtype = float; +}; + +template <> +struct unified_attention_problem_traits +{ + using qkvp_dtype = ck_tile::bf16_t; + using acc_dtype = float; + using o_dtype = ck_tile::bf16_t; + using lse_dtype = float; +}; + +template +struct unified_attention_kernel_traits +{ + static constexpr auto date_type = DataType; + static constexpr bool is_masking = IsMasking; + + static constexpr index_t kBlockM = 256; + static constexpr index_t BLOCK_SIZE = 32; + static constexpr index_t HEAD_SIZE = 128; + + // TODO please fix this to support also other num_queries_per_kv + static constexpr index_t num_queries_per_kv = 1; + static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; + + // kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE + using unified_attention_block_tile = sequence; + using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; + // need to have 8 warps per workgroup to have warp specialization + using unified_attention_block_warps = sequence<8, 1, 1>; + + using unified_attention_shape = TileUnifiedAttentionShape; + + using unified_attention_traits = TileUnifiedAttentionTraits; + + using unified_attention_mask = GenericAttentionMask; + + using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::lse_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::o_dtype, + unified_attention_shape, + unified_attention_mask, + unified_attention_traits>; + + using unified_attention_pipeline = UnifiedAttentionPipeline; + + using epilogue = Default2DEpilogue< + Default2DEpilogueProblem::acc_dtype, + typename unified_attention_problem_traits::o_dtype, + true, // kPadM + true, // kPadM + true // UseRawStore + >>; + + using kernel = UnifiedAttentionKernel; +}; + +template +float unified_attention_kernel_launch(const unified_attention_args& args, + const stream_config& config) +{ + index_t kBlockQ = Kernel::kBlockQ; + assert(args.num_queries_per_kv == Kernel::num_queries_per_kv && + "argument num_queries_per_kv must equal compiled num_queries_per_kv"); + assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && + "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); + assert(kBlockQ == kBlockM / args.num_queries_per_kv && + "kBlockQ must equal kBlockM / num_queries_per_kv"); + index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs; + auto kargs = Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.o_ptr, + args.num_blks, + args.num_head_q, + args.num_queries_per_kv, + args.scale_s, + args.scale, + args.scale_k, + args.scale_v, + args.scale_out, + args.page_blk_size, + total_num_q_blocks, + args.query_stride_0, + args.query_stride_1, + args.stride_k_cache_0, + args.stride_k_cache_1, + args.stride_k_cache_2, + args.stride_k_cache_3, + args.stride_v_cache_0, + args.stride_v_cache_1, + args.stride_v_cache_2, + args.stride_v_cache_3, + args.output_stride_0, + args.output_stride_1, + args.block_tables_ptr, + args.block_table_stride, + args.seq_lens_ptr, + args.query_start_len_ptr, + args.num_seqs); + + dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; + + return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} + +// return value: +// first = whether the kernel was launched (true = launched, false = skipped) +// second = elapsed time (ms) of the kernel launch, valid only if first == true +template +std::pair unified_attention_kernel_dispatch(const unified_attention_args& args, + const stream_config& config); + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 215525878b8..ca3fe67867e 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -4,7 +4,6 @@ include_directories(AFTER ${CMAKE_CURRENT_LIST_DIR} ) - add_subdirectory(01_fmha) add_subdirectory(02_layernorm2d) add_subdirectory(03_gemm) @@ -30,4 +29,4 @@ add_subdirectory(36_pooling) add_subdirectory(38_block_scale_gemm) add_subdirectory(40_streamk_gemm) add_subdirectory(41_batched_contraction) - +add_subdirectory(42_unified_attention) diff --git a/include/ck_tile/ops/unified_attention.hpp b/include/ck_tile/ops/unified_attention.hpp new file mode 100644 index 00000000000..53ca1da684d --- /dev/null +++ b/include/ck_tile/ops/unified_attention.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp new file mode 100644 index 00000000000..33ca84d2c5c --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -0,0 +1,300 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +enum struct GenericAttentionMaskEnum +{ + NO_MASK = 0, + + // below enum could be causal, or sliding window + MASK_FROM_TOP_LEFT = 1, + MASK_FROM_BOTTOM_RIGHT = 2, + + // this enum maybe not used by xformer/FA, since it's hard to + // specify left/right window for varlen case. put it here for + // debug purpose + MASK_GENERIC, +}; + +// clang-format off +/* generic Attention Mask Coordinate + use x(horizontal axis), y(vertical axis) to describe mask. + top-left corner is origin + + x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask) + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + 1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + l=7,-1/r=0(tl) l=7,-1/r=0(br) + + x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2 + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 + * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1 + l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl) + l=4/r=0(br) l=4/r=2(br) l=4/r=4(br) + + x=4/y=-1 x=6/y=-1 x=8/y=-1 + * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 + * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 + * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1 + * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1 + + x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r) + * * * * * * * * 1 * * * * * * * + * * * * * * * * 1 1 * * 1 * * * + * * * * * * * * 1 1 1 * 1 1 * * + 1 * * * * * * * 1 1 1 1 1 1 1 * + 1 1 * * * * * * 1 1 1 1 1 1 1 1 + + Validations: + x + y > 1 (x + y >= 2) + + Note: + y = seq_q, x = 1 -> top-left + y = seq_q, x = seq_k - seq_q + 1 -> bottom-right + y < seq_q, x < seq_k -> local-attn + y = seq_q, x = seq_k -> no mask + +*/ +namespace impl { + template struct MaskName; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mc"; }; + template<> struct MaskName { static constexpr const char * name = "mg"; }; +} +// clang-format on + +template +struct GenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask, + // else only upper-right could have mask + + static constexpr const char* name = impl::MaskName::name; + + // New constructor accepting repeat_idx with default value 1 + CK_TILE_HOST_DEVICE + GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) + : GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx_) + { + } + + CK_TILE_HOST_DEVICE + GenericAttentionMask( + index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx_) + { + } + + template + CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, + index_t repeat_idx_ = 1) + : y(mask_coord.at(number<0>{})), + x(mask_coord.at(number<1>{})), + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})), + repeat_idx(repeat_idx_) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + // Transform the y index according to repeat_idx + index_t y_eff = i_y / repeat_idx; + + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assuming we loop over along X tile by tile + index_t x_start = [&]() { + if constexpr(IsLocal) + { + index_t tmp = max(-y + y_eff + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + } + else + { + return 0; + } + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(y_eff + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // Note: this function does not take a dynamic y index so no transform is needed + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assuming we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + // Transform the y index according to repeat_idx + index_t y_eff = i_y / repeat_idx; + + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + y_eff + 1; + index_t x_end = min(y_eff + x, x_total); + + if constexpr(IsLocal) + { + return i_x < x_start || i_x >= x_end; + } + else + { + return i_x >= x_end || y_eff >= y_total; + } + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the index passed in this function is within range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number, number) const + { + // Transform the y index according to repeat_idx + index_t y_eff = i_tile_top / repeat_idx; + + if constexpr(!IsMasking) + { + // TODO: no need to check begin + return (i_tile_left + TileWidth) > x_total; + } + else + { + if constexpr(IsLocal) + { + // check top-right corner > x or left-bottom corner < x + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = y_eff + TileHeight; + index_t x_end = min(y_eff + x, x_total); + + bool top_right_edge = i_tile_right > (y_eff + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = + i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + } + else + { + // only need to check top-right corner > x + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = min(y_eff + x, x_total); + + bool top_right_edge = i_tile_right > x_end; + return top_right_edge; + } + } + } + + private: + index_t y, x; + index_t y_total, x_total; + index_t repeat_idx; +}; + +// TODO: prefer use this function in host code +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + // TODO: below should all use sgpr arithmetic + index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1; + index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1; + + left_size = left_size < 0 ? left_size_tmp : left_size; + right_size = right_size < 0 ? right_size_tmp : right_size; + + index_t x_tmp = is_top_left ? 0 : x_total - y_total; + index_t y_tmp = is_top_left ? 0 : y_total - x_total; + + index_t x = 1 + right_size + x_tmp; + index_t y = 1 + left_size + y_tmp; + + return ck_tile::make_tuple(y, x, y_total, x_total); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + index_t repeat_idx = 1, + bool is_top_left = true) +{ + auto r = make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, is_top_left); + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total, repeat_idx}; +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp new file mode 100644 index 00000000000..1a69afad201 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -0,0 +1,460 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" +#include "ck_tile/core/numeric/math.hpp" + +#include +#include +#include +#include + +namespace ck_tile { + +template +struct UnifiedAttentionKernel +{ + using UnifiedAttentionPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = UnifiedAttentionPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool kPadSeqLenK = UnifiedAttentionPipeline::kPadSeqLenK; + static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ; + static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV; + + static constexpr index_t kHeadDim = UnifiedAttentionPipeline::kHeadDim; + static constexpr index_t kHeadDimPadded = UnifiedAttentionPipeline::kHeadDimPadded; + + // kBlockQ = kBlockM // num_queries_per_kv + // kBlockQ is the block size for q seqlen + /// static constexpr index_t kBlockQ = UnifiedAttentionPipeline::kBlockQ; + static constexpr index_t kBlockM = UnifiedAttentionPipeline::kBlockM; + static constexpr index_t kBlockQ = UnifiedAttentionPipeline::kBlockQ; + // BLOCK size for K seqlen + static constexpr index_t kPageBlockSize = UnifiedAttentionPipeline::kPageBlockSize; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + // The attention is default causal + struct UnifiedAttentionCommonKargs + { + const void* q_ptr; + const void* k_ptr; // [num_blks, page_size, num_kv_heads, head_size] + const void* v_ptr; // [num_blks, page_size, num_kv_heads, head_size] + void* o_ptr; + + ck_tile::index_t num_blks; + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + const ck_tile::index_t num_queries_per_kv; + // scales + float scale_s; + float scale; + float scale_k; + float scale_v; + float scale_out; + + ck_tile::index_t page_size; + + ck_tile::index_t total_num_q_blocks; + ck_tile::index_t query_stride_0; + ck_tile::index_t query_stride_1; + ck_tile::index_t stride_k_cache_0; + ck_tile::index_t stride_k_cache_1; + ck_tile::index_t stride_k_cache_2; + ck_tile::index_t stride_k_cache_3; + ck_tile::index_t stride_v_cache_0; + ck_tile::index_t stride_v_cache_1; + ck_tile::index_t stride_v_cache_2; + ck_tile::index_t stride_v_cache_3; + ck_tile::index_t output_stride_0; + ck_tile::index_t output_stride_1; + }; + + struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs + { + const int32_t* block_tables_ptr; + ck_tile::index_t block_table_stride; + const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* query_start_len_ptr; // [num_seqs+1] + + ck_tile::index_t num_seqs; // number of batches for q + }; + + using Kargs = UnifiedAttentionVarlenKargs; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* o_ptr, + ck_tile::index_t num_blks, + ck_tile::index_t num_head_q, + const ck_tile::index_t num_queries_per_kv, + float scale_s, + float scale, + float scale_k, + float scale_v, + float scale_out, + ck_tile::index_t page_size, + ck_tile::index_t total_num_q_blocks, + ck_tile::index_t query_stride_0, + ck_tile::index_t query_stride_1, + ck_tile::index_t stride_k_cache_0, + ck_tile::index_t stride_k_cache_1, + ck_tile::index_t stride_k_cache_2, + ck_tile::index_t stride_k_cache_3, + ck_tile::index_t stride_v_cache_0, + ck_tile::index_t stride_v_cache_1, + ck_tile::index_t stride_v_cache_2, + ck_tile::index_t stride_v_cache_3, + ck_tile::index_t output_stride_0, + ck_tile::index_t output_stride_1, + const int32_t* block_tables_ptr, + ck_tile::index_t block_table_stride, + const int32_t* seq_lens_ptr, + const int32_t* query_start_len_ptr, + ck_tile::index_t num_seqs) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + num_blks, + num_head_q, + num_queries_per_kv, + static_cast(scale_s * ck_tile::log2e_v<>), + scale, + scale_k, + scale_v, + scale_out, + page_size, + total_num_q_blocks, + query_stride_0, + query_stride_1, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + output_stride_0, + output_stride_1}, + block_tables_ptr, + block_table_stride, + seq_lens_ptr, + query_start_len_ptr, + num_seqs}; + + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads, + ck_tile::index_t total_num_q_blocks) + { + return dim3(num_kv_heads * total_num_q_blocks); + } + + // Binary search to find the sequence index for a given target index + CK_TILE_DEVICE static constexpr ck_tile::index_t + find_seq_idx(const int32_t* query_start_len_ptr, + ck_tile::index_t target_idx, + ck_tile::index_t num_seqs, + ck_tile::index_t block_q, + bool use_q_block_mode) + { + ck_tile::index_t left = 0; + ck_tile::index_t right = num_seqs; + + while(left < right) + { + ck_tile::index_t mid = (left + right) / 2; + ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]); + ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val; + + if(mid_val <= target_idx) + { + left = mid + 1; + } + else + { + right = mid; + } + } + + return left - 1; + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, + const Kargs& kargs) + { + using namespace ck_tile; + + ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv; + + return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv); + } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(), + EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + using namespace ck_tile; + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + ck_tile::index_t pid = blockIdx.x; + + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + + assert(kBlockM / num_queries_per_kv == kBlockQ); + + // for simplicity, batch stride we just modify the pointer + // const index_t num_head_q = kargs.num_head_q; + + // divide problem + const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs); + + // grid size is (num_kv_heads, total_num_q_blocks) + // total_num_q_blocks = q.shape[0] // kBlockQ + num_seqs + // q.shape[0] is total number of query tokens across all batches + // one q_block spans kBlockQ = kBlockM // num_queries_per_kv number of query token groups. + // One query token group shares one kv token + + const index_t seq_idx = find_seq_idx(kargs.query_start_len_ptr, + q_block_global_idx, + kargs.num_seqs, + kBlockQ, + true); // which batch + + const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / kBlockQ + seq_idx; + + const index_t q_block_local_idx = + amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + + const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx]; + const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; + + const index_t cur_batch_query_len = + amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); + + // TODO check if we get the block size info from pipeline + if(q_block_local_idx * kBlockQ >= cur_batch_query_len) + { + return; + } + + const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ); + const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; + + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); + + index_t _max_seq_prefix_len = amd_wave_read_first_lane( + (context_len + q_block_local_idx * kBlockQ + (kBlockM - 1) + 1)); + + if(seq_len < _max_seq_prefix_len) + { + _max_seq_prefix_len = seq_len; + } + + const auto max_seq_prefix_len = _max_seq_prefix_len; + const index_t num_blocks = + amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize); + + // TODO sliding window + const index_t num_blocks_start = 0; + index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2; + + // Q/K/V DRAM and DRAM window + index_t q_ptr_offset_0 = cur_batch_in_all_start_index * + kargs.query_stride_0; // move the pointer to the batch start + index_t q_ptr_offset_1 = + kv_head_idx * num_queries_per_kv * + kargs.query_stride_1; // move the pointer to the correct head group start + index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; + + index_t o_ptr_offset_0 = cur_batch_in_all_start_index * + kargs.output_stride_0; // move the pointer to the batch start + index_t o_ptr_offset_1 = + kv_head_idx * num_queries_per_kv * + kargs.output_stride_1; // move the pointer to the correct head group start + index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; + index_t block_table_offset = seq_idx * kargs.block_table_stride; + + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + kv_head_offset; + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; + + index_t query_len_padded = + amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, kBlockQ) * kBlockQ); + // const bool is_query_len_padded = (cur_batch_query_len % kBlockQ == 0); + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_base = make_naive_tensor_view( + q_ptr, + make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim), + make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), + number{}, + number<1>{}); + + const auto q_dram_pad = + pad_tensor_view( // aling seqlen with kBlockQ and head dim with kHeadDimPadded + q_dram_base, + // block sizes + make_tuple(number{}, 1, kHeadDimPadded), + sequence{}); // pads to (seq_len_padded, num_head_q, + // kHeadDimPadded) + + const auto q_dram_merged = transform_tensor_view( + q_dram_pad, + make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), + make_pass_through_transform(kHeadDimPadded)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, + sequence<1>{})); // flattens the first two dims, head idx is the fastest + // changing dim in the merged dim + + return q_dram_merged; + }(); + // static_assert(q_dram.desc_[number<0>{}] == 0, + // "q_dram.get_bottom_tensor_view()[number<0>{}] == 0"); + + // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) + // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) + auto q_dram_window = + make_tile_window(q_dram, + make_tuple(number{}, number{}), + {query_pos * num_queries_per_kv, 0}); + + const auto k_dram = [&]() { + // HEAD dim is skipped as defined in the ptrs + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_blks * kargs.page_size, kHeadDim), + make_tuple(kargs.stride_k_cache_1, kargs.stride_k_cache_3), + number{}, + number<1>{}); + + const auto k_dram_pad = + pad_tensor_view(k_dram_naive, + // TODO can the kPageBlockSize_RAW needs padding? + make_tuple(kPageBlockSize, kHeadDimPadded), + sequence{}); + + return k_dram_pad; + }(); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.num_blks * kargs.page_size, kHeadDim), + make_tuple(kargs.stride_v_cache_1, kargs.stride_v_cache_3), + number{}, + number<1>{}); + + const auto v_dram_pad = pad_tensor_view(v_dram_naive, + make_tuple(kPageBlockSize, kHeadDimPadded), + sequence{}); + + return v_dram_pad; + }(); + + auto v_dram_window = make_tile_window( + v_dram, make_tuple(number{}, number{}), {0, 0}); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + -1, + 0, + cur_batch_query_len, // y_total + seq_len, // x_total + num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv + // times along x dim of the tile + false); + else + return FmhaMask{cur_batch_query_len, seq_len}; + }(); + + const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize; + assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size + + auto o_acc_tile = [&]() { + return UnifiedAttentionPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + num_blocks, + num_blocks_start, + kargs.block_tables_ptr, + block_table_offset, + kv_page_size_in_blocks, + mask, + kargs.scale_s, + smem_ptr); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_base = make_naive_tensor_view( + o_ptr, + make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim), + make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1), + number{}, + number<1>{}); + + const auto o_dram_pad = + pad_tensor_view( // aling cu_seqlen with kBlockQ and head dim with kHeadDimPadded + o_dram_base, + // block sizes + make_tuple(kBlockQ, 1, kHeadDimPadded), + sequence{}); // pads to (seq_len_padded, num_head_q, + // kHeadDimPadded) + + const auto o_dram_merged = transform_tensor_view( + o_dram_pad, + make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), + make_pass_through_transform(kHeadDimPadded)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return o_dram_merged; + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {query_pos * num_queries_per_kv, 0}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp new file mode 100644 index 00000000000..68c53401cdf --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length() +{ + if constexpr(Headdim == 48) + return 48; + else if constexpr(Headdim == 96) + return 128; + else if constexpr(Headdim == 160) + return 256; + else if constexpr(Headdim == 192) + return 192; + else if constexpr(is_power_of_two_integer(Headdim)) + return Headdim; + else + static_assert(Headdim == 0, + "only Headdim of 48, 96, 160, 192 and power-of-two is supported"); +}; + +template +struct TileUnifiedAttentionShape +{ + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + + static constexpr index_t NumGemm0Warps = + reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + static constexpr index_t NumGemm1Warps = + reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}); + static_assert(NumGemm1Warps % NumGemm0Warps == 0); + + static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); + + static constexpr index_t kBlockM = BlockTile::at( + number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) + static constexpr index_t kBlockQ = BlockTile::at( + number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) + // static constexpr index_t kBlockM = BlockTile::at(number<1>{}); // tile size along q seqlen * + // num_queries_per_kv (q_head//kv_head) + static constexpr index_t kPageBlockSize = + BlockTile::at(number<2>{}); // BLOCK size for K seqlen + static constexpr index_t kHeadDim = BlockTile::at(number<3>{}); // BLOCK size for K seqlen + + // BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at + // // once (or repeately load Q as a whole tile) + // static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); + + static constexpr index_t kHeadDimPadded = ceil_to_qualified_tile_length(); + + // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen + static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; + using VLayout = std::conditional_t; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp new file mode 100644 index 00000000000..8b01a5722da --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileUnifiedAttentionTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDim = kPadHeadDim_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp new file mode 100644 index 00000000000..74693460ec1 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -0,0 +1,1053 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" +#define ENABLE_ASM_MARKER 1 +#if ENABLE_ASM_MARKER +#define ASM_MARKER(marker) \ + __builtin_amdgcn_sched_barrier(0); \ + asm volatile("; [POYENC] " #marker); \ + __builtin_amdgcn_sched_barrier(0); +#else +#define ASM_MARKER(marker) +#endif + +#define ADD_SBARRIER_FOR_PHASE0 1 +#if !defined(CK_TILE_DISABLE_PACKED_FP32) +#define CK_TILE_DISABLE_PACKED_FP32 0 +#endif + +#define WARP_ID 0 +#define LANE_ID 0 + +#define ENABLE_DEBUG_STMTS 1 +#if ENABLE_DEBUG_STMTS +#define DEBUG_STMTS \ + if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID) +#else +#define DEBUG_STMTS if constexpr(false) +#endif + +namespace ck_tile { + +template +struct UnifiedAttentionPipeline +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + + static_assert(std::is_same_v, + "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); + + using UnifiedAttentionShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; + + static constexpr ck_tile::index_t kBlockM = UnifiedAttentionShape::kBlockM; + static constexpr ck_tile::index_t kBlockQ = UnifiedAttentionShape::kBlockQ; + + static constexpr ck_tile::index_t kPageBlockSize = UnifiedAttentionShape::kPageBlockSize; + static constexpr ck_tile::index_t kHeadDim = UnifiedAttentionShape::kHeadDim; + static constexpr ck_tile::index_t kHeadDimPadded = UnifiedAttentionShape::kHeadDimPadded; + + static_assert(kHeadDimPadded <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + // static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDim; + // static constexpr bool kStoreLSE = Problem::kStoreLSE; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr ck_tile::index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr ck_tile::index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr ck_tile::index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr ck_tile::index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr ck_tile::index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + return 2; + } + }(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // create another LDS buffer for p + return ck_tile::max(kBlockM * kHeadDimPadded * sizeof(PDataType), + Policy::template GetSmemSize() + + kBlockM * kPageBlockSize * sizeof(PDataType)); + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc() + { + using namespace ck_tile; + constexpr auto lds_block_desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number<1>{}, + number<1>{}); + + return lds_block_desc; + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D() + { + using namespace ck_tile; + constexpr auto lds_block_desc = make_naive_tensor_descriptor( + make_tuple(number{}), make_tuple(number<1>{}), number<1>{}, number<1>{}); + + return lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc) + { + using namespace ck_tile; + + auto tensor_view = + make_tensor_view(reinterpret_cast(base), desc); + return make_tile_window(tensor_view, desc.get_lengths(), {0, 0}); + } + + // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7 + template + CK_TILE_DEVICE static constexpr void s_waitcnt() + { + // vmcnt use bits {[15:14],[3:0]} + // expcnt use bits [6:4] + // lgkmcnt use bits [11:8] + __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) | + ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8)); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt() + { + s_waitcnt(); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt() + { + s_waitcnt<63, Lgkmcnt>(); + } + + template + CK_TILE_DEVICE auto operator()( + const QDramBlockWindowTmp& q_dram_block_window_tmp, // kBlockM * kHeadDimPadded tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // kPageBlockSize * kHeadDimPadded tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // kHeadDimPadded * kPageBlockSize tile + [[maybe_unused]] const VElementFunction& v_element_func, + const index_t num_blocks, + const index_t num_blocks_start, + const void* block_tables_ptr, + index_t block_table_offset, + const index_t kv_page_size_in_blocks, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert( + kBlockM == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kPageBlockSize == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kHeadDimPadded == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kPageBlockSize == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kHeadDimPadded == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + static_assert(sizeof(SaccDataType) * kPageBlockSize * kBlockM <= GetSmemSize()); + auto s_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto s_lds_window = make_tile_window( + s_lds, make_tuple(number{}, number{}), {0, 0}); + + auto p_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto p_lds_window = make_tile_window( + p_lds, make_tuple(number{}, number{}), {0, 0}); + + auto o_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto o_lds_window = make_tile_window( + o_lds, make_tuple(number{}, number{}), {0, 0}); + + auto m_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc1D()); + [[maybe_unused]] auto m_lds_window = + make_tile_window(m_lds, make_tuple(number{}), {0}); + + const index_t warp_group_id = get_warp_id() / 4; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + auto q_dram_window = make_tile_window_linear( + q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution()); + + // auto q_dram_window = q_dram_block_window_tmp; + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto k_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + auto v_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + statically_indexed_array( + nullptr, + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution())), + 2> + k_lds_window_load; + + statically_indexed_array( + nullptr, + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution())), + 2> + v_lds_window_load; + + decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())) q_tile; + + union kv_tile_type + { + CK_TILE_DEVICE kv_tile_type() {} + + decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile; + + decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile; + } kv_tile; + + union sp_compute_type + { + CK_TILE_DEVICE sp_compute_type() {} + + decltype(gemm_0.MakeCBlockTile()) sp_compute; + decltype(make_static_distributed_tensor( + Policy::template MakePRegTileDistribution())) p; + }; + statically_indexed_array sp; + + decltype(gemm_1.MakeCBlockTile()) o_acc; + constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd() + // instructions should we move to fmha_alu1() + static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); + + decltype(block_tile_reduce( + sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m; + decltype(m) l; + + // initialize k_lds_window and v_lds_window + static_for<0, 2, 1>{}([&](auto idx) { + k_lds_window_load(idx) = make_tile_window( + make_lds_tile_window( + static_cast(smem_ptr) + (idx)*Policy::template GetSmemSizeKV(), + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution()); + }); + + static_for<0, 2, 1>{}([&](auto idx) { + v_lds_window_load(idx) = + make_tile_window(make_lds_tile_window( + static_cast(smem_ptr) + + (idx + 2) * Policy::template GetSmemSizeKV(), + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution()); + }); + + { + auto origin_q = load_tile(q_dram_window); + auto transformed_q = tile_elementwise_in(q_element_func, origin_q); + + q_tile = transformed_q; + } + + clear_tile(o_acc); + set_tile(m, bit_cast(0xff7fffff)); // a bit larger than -infinity + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + + const auto num_total_loop = num_blocks; + index_t k_block_idx = 0; + index_t v_block_idx = 0; + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop - num_blocks_start <= 0) + { + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + // TODO check correctness of this + index_t i_total_loops = num_blocks_start; + const index_t PageSize = kv_page_size_in_blocks * kPageBlockSize; + const ck_tile::index_t* block_tables_ptr_ = + reinterpret_cast(block_tables_ptr); + assert(k_block_idx == v_block_idx); // because of the following line + block_table_offset += num_blocks_start; + index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx]; + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {kv_blk_idx_initial * PageSize, 0}, + Policy::template MakeKDramTileDistribution()); + k_dram_window.init_raw(); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {kv_blk_idx_initial * PageSize, 0}, + Policy::template MakeVDramTileDistribution()); + v_dram_window.init_raw(); + + // prefetch K tile + constexpr index_t k0_loops = 1; + constexpr index_t k1_loops = 1; + static_assert(1 == k0_loops); + static_assert(1 == k1_loops); + // static_assert(kPageBlockSize == kHeadDimPadded); + + constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; + static_assert(NumWarpGroups == 2); + + [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) { + printf("[POYENC] %s (size=%d): %5.2f", + name, + decltype(dist_tensor.thread_buf_)::size(), + ck_tile::type_convert(dist_tensor.thread_buf_[0])); + static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) { + printf(", %5.2f", ck_tile::type_convert(dist_tensor.thread_buf_[i])); + }); + printf("\n"); + }; + + [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) { + const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{}); + const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + if constexpr(true || num_rows < num_cols) + { + for(int row = 0; row < num_rows; ++row) + { + int offset = desc.calculate_offset(make_tuple(row, 0)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + row, + ck_tile::type_convert(data[offset])); + for(int col = 1; col < num_cols; ++col) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + else + { + for(int col = 0; col < num_cols; ++col) + { + int offset = desc.calculate_offset(make_tuple(0, col)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + col, + ck_tile::type_convert(data[offset])); + for(int row = 1; row < num_rows; ++row) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + }; + + [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) { + const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + int offset = desc.calculate_offset(make_tuple(0)); + printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert(data[offset])); + for(int e = 1; e < num_elems; ++e) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(e)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + }; + + // K_mem_su_ld_insts = 1 for 32 x 128 + // V_mem_su_ld_insts = 1 for 128 x 32 + constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access(); + constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); + + // Page block index tracking + // const index_t kv_page_size_in_blocks = + // PageSize / kPageBlockSize; + // index_t kv_block_idx = 0; + // only for block 0 and thread + if(blockIdx.x == 0 && threadIdx.x == 0) {} + auto K_mem_load = [&](auto k_lds_write_idx) { + async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + k_block_idx++; + + index_t k_page_blk_idx = + block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)]; + k_dram_window.set_window_origin( + {k_page_blk_idx * PageSize + + (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize, + 0}); + }; + + auto V_mem_load = [&](auto v_lds_write_idx) { + async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + // prefetch next V tile (only if not at the end of loop) + v_block_idx++; + + index_t v_page_blk_idx = + block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)]; + v_dram_window.set_window_origin( + {v_page_blk_idx * PageSize + + (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize, + 0}); + // we assume that v load is always after k + }; + + auto K_lds_load = [&](auto k_lds_read_idx) { + kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx)); + }; + + auto V_lds_load = [&](auto v_lds_read_idx) { + kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + }; + + decltype(m) m_old; + SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd() + /// TODO: remove the sp_delta and use sp_compute directly + statically_indexed_array{}).sp_compute), 2> sp_delta; + + auto fmha_alu0 = [&](auto sp_reg_idx) { + m_old = m; // m{j-1} + static_assert(m.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowmax value"); + auto m_latest = block_tile_reduce( + sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(m_latest.thread_buf_[0]), + bit_cast(m_latest.thread_buf_[0]), + false, + false); + /// TODO: eliminate 2 redudant v_max_f32 instructions generated by the compiler + m_latest.thread_buf_[0] = f_max(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(m_latest, f_max, bool_constant{}); +#endif + m = m_latest; + + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv( + sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx)); + }); + }); + /// TODO: move some fmha_alu1() code here if necessary + }; + + auto fmha_alu1 = [&](auto sp_reg_idx) { + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp(sp_reg_idx).sp_compute(i_j_idx) = + ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx)); + }); + }); + + auto rowsum_p = block_tile_reduce( + sp(sp_reg_idx).sp_compute, + sequence<1>{}, + f_sum, + SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + static_assert(rowsum_p.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowsum value"); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(rowsum_p.thread_buf_[0]), + bit_cast(rowsum_p.thread_buf_[0]), + false, + false); + rowsum_p.thread_buf_[0] = f_sum(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#endif + + // l{j} + /// Note: The compiler keeps moving the following instructions elsewhere because 'l' + /// is first consumed later. To anchor them here, we rewrite the final addition in + /// inline assembly to create a dependency, forcing the dependent instructions to + /// be emitted at this point. + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx])); + + l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); + }); + + // update partial o_acc [0, fmha_alu_D_reg_cnt) + static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) { + o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale); + }); + + /// Note: The compiler keeps sinking the conversion instructions because the + /// result 'p' is only consumed later. To anchor them here, we rewrite + /// the cast_tile() call as inline assembly, forcing the conversions to be + /// emitted at this point. + static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); + static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { + float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); + float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); + if constexpr(std::is_same_v) + { + auto casted = detail::cvt_pk_fp16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + else + { + auto casted = detail::cvt_pk_bf16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + }); + + /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly + /// can interfere with the behavior of sched_group_barrier(), so ending the phase here + /// avoids unintended reordering. + }; + + auto gemm = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kHeadDimPadded>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kHeadDimPadded>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{})); + } + }; + + auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kHeadDimPadded>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kHeadDimPadded>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{})); + fmha_alu0(number<1>{} - sp_reg_idx); + } + }; + + auto fmha_alu_D_upd = [&] { + o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0])); + + fp32x2_t pk_o_acc_scale; + pk_o_acc_scale.x = o_acc_scale; + pk_o_acc_scale.y = o_acc_scale; + + static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0); +#if CK_TILE_DISABLE_PACKED_FP32 + static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size()); + static_for{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); +#endif + + constexpr auto issued_D_reg_cnt = +#if CK_TILE_DISABLE_PACKED_FP32 + fmha_alu_D_reg_cnt + 2 +#else + fmha_alu_D_reg_cnt +#endif + ; + /// NOTICE: Use inline asm v_pk_mul_f32 to reduce latency. The fmha_alu_D_upd() call + /// should be placed at the end of a phase. + // update partial o_acc after [issued_D_reg_cnt] + static_for{}([&](auto idx) { + fp32x2_t input; + input.x = o_acc.thread_buf_[idx]; + input.y = o_acc.thread_buf_[idx + 1]; + + auto output = detail::pk_mul_f32(input, pk_o_acc_scale); + + o_acc.thread_buf_[idx] = output.x; + o_acc.thread_buf_[idx + 1] = output.y; + }); + }; + + auto fmha_mask = [&](auto sp_reg_idx) { + if constexpr(FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + i_total_loops * kPageBlockSize, + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(sp(sp_reg_idx).sp_compute, + -numeric::infinity(), + [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + i_total_loops * kPageBlockSize + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + }; + + auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) { + if constexpr(load_type == 0) + { + V_mem_load(mem_wr_idx); + K_lds_load(lds_rd_idx); + } + else + { + K_mem_load(mem_wr_idx); + V_lds_load(lds_rd_idx); + } + }; + + auto core_loop = [&](auto cl_p) { + auto gemm0 = number<0>{}; + auto gemm1 = number<1>{}; + + auto memV = number<0>{}; + auto memK = number<1>{}; + + using Scheduler = CoreLoopScheduler; + + auto iteration = [&](auto pi) { + auto xdl_SP_p01_reg_idx = number<1>{} - pi; + auto xdl_SP_p23_reg_idx = pi; + + auto K_w0_lds_wr_idx = number<1>{} - pi; + auto V_w0_lds_wr_idx = pi; + auto K_w0_lds_rd_idx = pi; + auto V_w0_lds_rd_idx = pi; + + auto K_w4_lds_wr_idx = number<1>{} - pi; + auto V_w4_lds_wr_idx = number<1>{} - pi; + auto K_w4_lds_rd_idx = number<1>{} - pi; + auto V_w4_lds_rd_idx = pi; + + bool result = true; + + if constexpr(cl_p == 0) + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave0-3 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave0-3 (pi=1)"); + } + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + Scheduler::schedule(cl_p, number<1>{}); + fmha_mask(xdl_SP_p01_reg_idx); + + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave0-3"); + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 0"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<2>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); + + Scheduler::schedule(cl_p, number<3>{}); + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + } + else + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave4-7 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave4-7 (pi=1)"); + } + cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<1>{}); + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave4-7"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); + Scheduler::schedule(cl_p, number<2>{}); + fmha_mask(xdl_SP_p01_reg_idx); + + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<3>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + } + return result; + }; + return iteration(number<0>{}) && iteration(number<1>{}); + }; + + auto fmha_post_process = [&](auto d) { + auto ps_pi = number<1>{} - d; + auto V_lds_rd_idx = ps_pi; + + if(1 < num_total_loop) + { + s_waitcnt_vmcnt(); + } + else + { + s_waitcnt_vmcnt<0>(); + } + __builtin_amdgcn_s_barrier(); + + V_lds_load(V_lds_rd_idx); + fmha_alu1(ps_pi); + + s_waitcnt_lgkmcnt<0>(); + + auto xdl_SP_p23_reg_idx = ps_pi; + gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{}); + }; + + // pre-stage + { + ASM_MARKER("before pre-stage"); + // (1) load K0 to LDS & VGPR + K_mem_load(number<0>{}); // mem_K0 + + s_waitcnt_vmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + K_lds_load(number<0>{}); // lds_K0 + + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + // (2) prefetch K1 and V0 to LDS in parallel with GEMM0 + if(1 < num_total_loop) + { + K_mem_load(number<1>{}); // mem_K1 + } + V_mem_load(number<0>{}); // mem_V0 + + // (3) mfma (Q*K0) + softmax + gemm(number<0>{}, /*gemm_idx=*/number<0>{}); + + fmha_mask(number<0>{}); + /// TODO: find better way to map fmha_alu(0,96) call + fmha_alu0(number<0>{}); + fmha_alu_D_upd(); + + ++i_total_loops; + if(num_total_loop <= i_total_loops) + { + goto label_main_loops_exit; + } + + if(2 < num_total_loop) + { + K_mem_load(number<0>{}); // mem_K2 + + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + } + + ASM_MARKER("end pre-stage"); + } + + if(1 < num_total_loop) + { + if(warp_group_id == 0) + { + V_mem_load(number<1>{}); // V1 + K_lds_load(number<1>{}); // K1 + + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<0>{})) + ; + } + if(warp_group_id != 0) + { + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<1>{})) + ; + } + } + label_main_loops_exit: + if(num_total_loop % 2) + { + fmha_post_process(number<1>{}); + } + if(!(num_total_loop % 2)) + { + fmha_post_process(number<0>{}); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_DEVICE auto operator()( + const QDramBlockWindowTmp& q_dram_block_window_tmp, // kBlockM * kHeadDimPadded tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // kPageBlockSize * kHeadDimPadded tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // kHeadDimPadded * kPageBlockSize tile + const index_t num_blocks, + const index_t num_blocks_start, + const void* block_tables_ptr, + index_t block_table_offset, + const index_t kv_page_size_in_blocks, + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + num_blocks, + num_blocks_start, + block_tables_ptr, + block_table_offset, + kv_page_size_in_blocks, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp new file mode 100644 index 00000000000..b0f8b26af68 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -0,0 +1,599 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" + +namespace ck_tile { + +struct UnifiedAttentionPipelineDefaultPolicy +{ + static constexpr ck_tile::index_t NumWarpPerGroup = 4; + static constexpr ck_tile::index_t NumThreadPerWarpGroup = + NumWarpPerGroup * ck_tile::get_warp_size(); + + // TODO: GetAlignment*() currently didn't consider if need padding or not + // so in pipeline still need check padding requirement + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentK() + { + using namespace ck_tile; + using KDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(KDataType); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentV() + { + using namespace ck_tile; + using VDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentK(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64 + + constexpr index_t KVector = GetAlignmentV(); // this is for global load + // 4 + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; // 8 + constexpr index_t N1 = LaneGroups; // 2 + constexpr index_t N2 = NumWarps; // 8 + constexpr index_t K0 = LanesPerK; // 32 + constexpr index_t K1 = KVector; // 4 + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + // compute the endcoding before transpose + constexpr auto v_block_dstr = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(v_block_dstr_encode), + typename Problem::VDataType>::TransposedDstrEncode{}); + + return v_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto GetQKBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, + typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>; + + using WarpGemm = + WarpGemmDispatcher{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false>; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, + WarpGemm, + GemmLoopOrder::MNK>; + + return BlockGemmARegBRegCRegV2{}; + } + + template + CK_TILE_DEVICE static constexpr auto GetPVBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::UnifiedAttentionShape::Gemm1BlockWarps, + typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>; + /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass + /// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single + using WarpGemm = + WarpGemmDispatcher{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + typename Problem::UnifiedAttentionShape::Gemm1BlockWarps, + WarpGemm, + GemmLoopOrder::MNK>; + return BlockGemmARegBRegCRegV2{}; + } + + static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords + static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords + + template + CK_TILE_DEVICE static constexpr auto + MakeKLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = WarpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (WarpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return max(SingleKSize, SingleVSize); + } + + template + CK_TILE_DEVICE static constexpr auto + MakeVLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentV(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return v_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() + { + using namespace ck_tile; + + static_assert(MakeKLdsLoadBlockDescriptor().get_element_space_size() == + MakeKLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t k_element_space_size = + MakeKLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(MakeVLdsLoadBlockDescriptor().get_element_space_size() == + MakeVLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t v_element_space_size = + MakeVLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <= + GetSingleSmemElementSpaceSize()); + + /// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() & + /// MakeVLdsBlockDescriptor() + static_assert(std::is_same_v); + constexpr index_t kv_element_space_size_in_bytes = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return kv_element_space_size_in_bytes; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return 4 * GetSmemSizeKV(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp new file mode 100644 index 00000000000..2b655c74b3f --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct UnifiedAttentionPipelineProblem +{ + // TODO kM0 and KN1?? + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + // first gemm accumulation dtype + using SaccDataType = remove_cvref_t; + // Softmax dtype + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + // data type for A matrix of second gemm + using PDataType = remove_cvref_t; + // data type for second gemm accumulation + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using UnifiedAttentionShape = remove_cvref_t; + using Traits = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps; + static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size(); + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDim = Traits::kPadHeadDim; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; + static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; + static constexpr bool kHasDropout = Traits::kHasDropout; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; +} // namespace ck_tile diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 31d724deb66..455bcf1954d 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -20,7 +20,7 @@ else MY_PROJECT_SOURCE=".." fi -GPU_TARGETS="gfx908;gfx90a;gfx942" +GPU_TARGETS="gfx950" if [ $# -ge 1 ]; then case "$1" in