Changes from all commits
51 commits
441de82
async copy code
carlushuang Dec 1, 2023
f90c80a
modify stream config
carlushuang Dec 1, 2023
94e7723
modify some internal API
carlushuang Dec 1, 2023
3ab658e
remove some useless code
carlushuang Dec 1, 2023
3380eeb
support MQA/GQA
carlushuang Dec 4, 2023
f810f41
rename some code
carlushuang Dec 6, 2023
149a242
separate async to different pipeline
carlushuang Dec 6, 2023
72b5a47
rename radio->ratio
carlushuang Dec 7, 2023
8f3a9ad
merge feature/fmha-pad-support aece827
carlushuang Dec 10, 2023
5999fd9
merge main
carlushuang Dec 10, 2023
4790257
Merge branch 'main' into fmha_attemp_async_copy_unify
carlushuang Dec 10, 2023
ced4670
add missing bf16 type
carlushuang Dec 11, 2023
cfcc7e7
Fix loop counter update logics
poyenc Dec 11, 2023
6048487
Merge branch 'fmha_attemp_async_copy_unify' of github.com:asroy/ck_ti…
poyenc Dec 11, 2023
5a24af3
Disable exp() and log() overloading for half_t to support xformers C+…
qianfengz Dec 1, 2023
d180391
Add include/ck/config.h to support xformers c++ extension building
qianfengz Dec 1, 2023
c1814f9
refactor mask in async copy pipeline
carlushuang Dec 11, 2023
7fab8b0
Make sure RNG data for MaskUpperTriangleFromBottomRightPredicate is v…
poyenc Dec 11, 2023
dc9ba2e
Use std::make_tuple() to construct temp std::tuple<>
poyenc Dec 11, 2023
b7e3f3b
Add bhalf2_t, bhalf4_t inner_product
qianfengz Dec 13, 2023
3ffae93
Merge pull request #55 from asroy/fmah_attemp_async_copy_unify_innerprod
carlushuang Dec 13, 2023
d205bc5
Choose constant according precision
poyenc Dec 20, 2023
7ddff7b
Avoid inefficient instruction
poyenc Dec 20, 2023
81f1b0f
Merge pull request #57 from asroy/feature/speed-up-bias-mode
carlushuang Dec 20, 2023
6b888b6
Remove sched_barrier() for non-bias mode
poyenc Dec 20, 2023
3913a40
WIP add generic masking (#59)
carlushuang Dec 29, 2023
afea739
add __device__ to make_generic_attention_mask_coordinates_from_lr_window
carlushuang Jan 3, 2024
b556a44
Re-organize example directories (#60)
poyenc Jan 5, 2024
0c1bf34
modify bench script
carlushuang Jan 5, 2024
33a2ee1
Fix in block_masking.hpp
qianfengz Jan 5, 2024
62b17b9
Merge pull request #62 from asroy/fmha_attemp_async_copy_unify_maskin…
poyenc Jan 5, 2024
65c8f98
Fix inconsistent mask creation logics (#63)
poyenc Jan 6, 2024
539f967
support non-broadcast in block reduce sync
carlushuang Jan 6, 2024
f188b80
Fix wrong data type used for bias tensor
poyenc Jan 11, 2024
1787c23
Flexible head dimension (#66)
poyenc Jan 11, 2024
cd4c060
Fix compilation error
poyenc Jan 12, 2024
bf427ce
Extract distributed indices conversion logics as function
poyenc Jan 12, 2024
6cbea7d
Flash attention fwd store LSE (#65)
ltqin Jan 18, 2024
73166db
Support head dim = 256 for fMHA (#70)
poyenc Jan 19, 2024
bcb6592
F8 enablement (#71)
carlushuang Jan 23, 2024
97997b6
Fix wrong arg order of transform_tensor_view()
poyenc Jan 23, 2024
f8c746b
Merge pull request #73 from asroy/feature/fix-wrong-trans-desc
poyenc Jan 24, 2024
5b6b5df
Validate m values before use them (#75)
poyenc Jan 29, 2024
9a302e6
Rename & separate TileFmhaTraits<> padding flags for better comprehen…
poyenc Jan 29, 2024
52b621e
Feature fMHA generic mask / bias issues (#80)
poyenc Jan 30, 2024
1bed0e7
Fallback changes for init=0
poyenc Jan 30, 2024
0d231a7
Add back sched_barrier() in pipeline
poyenc Jan 30, 2024
e914fa2
Fix README.md wording (#78)
poyenc Jan 30, 2024
d1adca3
restore init dist
carlushuang Jan 31, 2024
eb53e23
Check padding boundary in GenericAttentionMask<>::IsEdgeTile() (#81)
poyenc Feb 1, 2024
3bda955
Allow infinity reference value while checking LSE (#83)
poyenc Feb 4, 2024
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -21,6 +21,8 @@ set(version 1.1.0)
# Check support for CUDA/HIP in Cmake
project(composable_kernel VERSION ${version})

find_package(Python3 3.7 COMPONENTS Interpreter REQUIRED)

list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

if (DTYPES)
20 changes: 12 additions & 8 deletions example/91_tile_program/CMakeLists.txt
@@ -1,8 +1,12 @@
add_example_executable(example_im2col im2col.cpp)
add_example_executable(example_gemm gemm.cpp)
add_example_executable(example_gemm_gemm gemm_gemm.cpp)
add_example_executable(example_reduce reduce.cpp)
add_example_executable(example_softmax softmax.cpp)
add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp)
add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp)
add_example_executable(example_fmha_fwd fmha_fwd.cpp)
include_directories(AFTER
${CMAKE_CURRENT_LIST_DIR}
)

add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(fmha)
add_subdirectory(gemm)
add_subdirectory(gemm_gemm)
add_subdirectory(gemm_softmax_gemm)
add_subdirectory(im2col)
add_subdirectory(reduce)
add_subdirectory(softmax)
@@ -0,0 +1 @@
add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp)
@@ -13,8 +13,8 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "reference_batched_gemm.hpp"
#include "reference_batched_softmax.hpp"
#include "reference/reference_batched_gemm.hpp"
#include "reference/reference_batched_softmax.hpp"
#include "batched_gemm_softmax_gemm.hpp"

int main(int argc, char* argv[])
@@ -17,7 +17,7 @@
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"

#include "gemm_softmax_gemm_impl.hpp"
#include "gemm_softmax_gemm/gemm_softmax_gemm_impl.hpp"

// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
182 changes: 182 additions & 0 deletions example/91_tile_program/common/arg_parser.hpp
@@ -0,0 +1,182 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include <iomanip>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <unordered_map>
#include <vector>

/*
* arg parser for
* -[key0]=[value0] -[key1]=[value1] ...
*/
class ArgParser
{
public:
class Arg
{
public:
std::string name;
std::string value;
std::string help_text;
};

ArgParser() {}
ArgParser& insert(const std::string& _name,
const std::string& _default_value,
const std::string& _help_text)
{
Arg in;
in.name = _name;
in.value = _default_value;
in.help_text = _help_text;

if(input_map.count(_name) != 0)
{
printf("arg:%s already exist\n", _name.c_str());
}
else
{
input_map[_name] = in;
keys.push_back(_name);
}
return *this;
}
void print()
{
printf("args:\n");
for(auto& key : keys)
{
auto value = input_map[key];
std::vector<std::string> help_text_lines;
size_t pos = 0;
for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
{
help_text_lines.push_back(std::string(value.help_text.begin() + pos,
value.help_text.begin() + next_pos++));
pos = next_pos;
next_pos = value.help_text.find('\n', pos);
}
help_text_lines.push_back(
std::string(value.help_text.begin() + pos, value.help_text.end()));

std::string default_value = std::string("(default:") + value.value + std::string(")");

std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key
<< std::setw(4) << " " << help_text_lines[0] << " " << default_value
<< std::endl;

for(auto help_next_line = std::next(help_text_lines.begin());
help_next_line != help_text_lines.end();
++help_next_line)
{
std::cout << std::setw(17) << " " << *help_next_line << std::endl;
}
}
}
bool parse(int argc, char* argv[], int start_index = 1)
{
if(argc < start_index)
{
printf("not enough args\n");
return false;
}
for(int i = start_index; i < argc; i++)
{
char* cur_arg = argv[i];
if(cur_arg[0] != '-')
{
printf("illegal input\n");
print();
return false;
}
else
{
std::string text(cur_arg + 1);
if(text == "?")
{
print();
return false;
}
auto pos = text.find('=');
if(pos == std::string::npos)
{
printf("arg should be [key]=[value] pair, here:%s\n", text.c_str());
return false;
}
if(pos >= (text.size() - 1))
{
printf("cant find value after \"=\", here:%s\n", text.c_str());
return false;
}
auto key = text.substr(0, pos);
auto value = text.substr(pos + 1);
if(input_map.count(key) == 0)
{
printf("no such arg:%s\n", key.c_str());
return false;
}
input_map[key].value = value;
}
}
return true;
}

std::string get_str(const std::string& name) const
{
std::string value = input_map.at(name).value;
return value;
}

int get_int(const std::string& name) const
{
int value = atoi(input_map.at(name).value.c_str());
return value;
}

uint32_t get_uint32(const std::string& name) const
{
uint32_t value = strtoul(input_map.at(name).value.c_str(), nullptr, 10);
return value;
}

uint64_t get_uint64(const std::string& name) const
{
uint64_t value = strtoull(input_map.at(name).value.c_str(), nullptr, 10);
return value;
}

bool get_bool(const std::string& name) const
{
auto v = input_map.at(name).value;
if(v.compare("t") == 0 || v.compare("true") == 0)
return true;
if(v.compare("f") == 0 || v.compare("false") == 0)
return false;
int value = atoi(v.c_str());
return value == 0 ? false : true;
}

float get_float(const std::string& name) const
{
double value = atof(input_map.at(name).value.c_str());
return static_cast<float>(value);
}

double get_double(const std::string& name) const
{
double value = atof(input_map.at(name).value.c_str());
return value;
}

private:
std::unordered_map<std::string, Arg> input_map;
std::vector<std::string> keys;
};
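A minimal usage sketch for the `ArgParser` above (a hypothetical caller written for illustration, not a file from this PR; the include path is assumed):

```cpp
// Hypothetical caller of the ArgParser above -- mirrors how the fmha example
// wires its "-key=value" command-line options.
#include "arg_parser.hpp"

int main(int argc, char* argv[])
{
    ArgParser parser;
    parser.insert("b", "2", "batch size")
          .insert("s", "3328", "seqlen_q")
          .insert("v", "1", "whether to do CPU validation or not");

    // accepts e.g. ./a.out -b=1 -s=16384; "-?" prints the help text
    if(!parser.parse(argc, argv))
        return -1;

    const int  batch    = parser.get_int("b");
    const int  seqlen_q = parser.get_int("s");
    const bool validate = parser.get_bool("v");
    (void)batch; (void)seqlen_q; (void)validate;
    return 0;
}
```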
41 changes: 41 additions & 0 deletions example/91_tile_program/fmha/CMakeLists.txt
@@ -0,0 +1,41 @@
# generate a list of kernels, but do not actually emit files at the config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt
)

# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory
# as the current cmake list, otherwise cmake will not figure out the dependency properly
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS)

add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
)

set(EXAMPLE_FMHA_FWD "example_fmha_fwd")
add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})

# NOTE: this is dangerous since it will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
set(FMHA_FWD_FAST_EXP2 true)
endif()

set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)

# NOTE: we turn off undefined-func-template to let the source compile without explicitly declared function specializations
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=0)
endif()

# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)

target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
90 changes: 90 additions & 0 deletions example/91_tile_program/fmha/README.md
@@ -0,0 +1,90 @@
# fused multi-head attention

This folder contains an example of fmha (fused multi-head attention) built with the ck tile-programming implementation. It is a good demonstration of the tile-programming API, and it also illustrates the new approach of constructing a kernel template and instantiating it while keeping compile time short.

## build
```
# in the root of ck
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with your target, e.g. gfx90a, gfx942...
make example_fmha_fwd -j
```
This will result in an executable `build/bin/example_fmha_fwd`

## kernel
The kernel template is `fmha_fwd_kernel.hpp`; this is the grid-wise op in old ck's terminology. We put it here on purpose, to demonstrate that one can construct a kernel from ck's various internal components. We may still provide an implementation of the kernel template under ck's include path in the future.

This kernel template has 3 template parameters (see the sketch after this list).
* `TilePartitioner` maps each workgroup to its corresponding tile; `fmha_fwd_tile_partitioner.hpp` in this folder serves this purpose.
* `FmhaPipeline` is one of the block_tile_pipelines (under `include/ck/tile_program/block_tile_pipeline`), which is the performance-critical component. We have put a lot of optimization work into the pipelines and may still work out more performant ones and add them to that folder. One only needs to swap this pipeline type to benefit from a different performant implementation (stay tuned for updated pipelines).
* `EpiloguePipeline` modifies and stores the result in the last phase. A lot of post-fusion is usually done at this stage, so we abstract this concept as well. Currently we don't do much at the epilogue stage, but it leaves room for possible future support.
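A schematic of how the three parameters might compose (illustrative only; the real template and its exact signature live in `fmha_fwd_kernel.hpp` and differ in detail; compile with hipcc):

```cpp
// Illustrative composition only -- not the actual fmha_fwd_kernel.hpp API.
// The empty body just marks where each component acts.
template <typename TilePartitioner, typename FmhaPipeline, typename EpiloguePipeline>
struct FmhaFwdKernelSketch
{
    __device__ void operator()(/* kernel args */) const
    {
        // 1. TilePartitioner: map this workgroup to its output tile.
        // 2. FmhaPipeline: run the tiled Q*K^T -> softmax -> P*V main loop.
        // 3. EpiloguePipeline: convert/store the accumulated tile (post-fusion hook).
    }
};
```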

## codegen
To speed up compile time, we instantiate the kernels into separate files so that we benefit from the parallel building of the CMake/Make system. This is achieved by the `generate.py` script. You can also look into this script to learn how a kernel instance is instantiated step by step, which is described in the `FMHA_FWD_KERNEL_BODY` variable.
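As a sketch of the pattern (placeholder names; the real generated files use the kernel and trait types from this folder), each generated translation unit pins one parameter combination via explicit instantiation:

```cpp
// Pattern sketch: one generated .cpp per parameter combination, so make -j
// can compile them in parallel (names below are placeholders, not generated code).
template <int HDim, bool kPadSeqLenQ>
struct HeavyKernelSketch
{
    static void run() { /* heavy tiled attention code would live here */ }
};

// e.g. the file generated for hdim=128 with seqlen padding would contain:
template struct HeavyKernelSketch<128, true>;
```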

## executable
`example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can run `./bin/example_fmha_fwd -?` to list all supported args:
```
args:
-v whether to do CPU validation or not (default:1)
-mode kernel mode. 0:batch, 1:group (default:0)
-b batch size (default:2)
-h num of head, for q (default:8)
-h_k num of head, for k/v, 0 means equal to h (default:0)
if not equal to h, then this is GQA/MQA case
-s seqlen_q (default:3328)
-s_k seqlen_k, 0 means equal to s (default:0)
-d head dim for q, k (default:128)
-d_v head dim for v, 0 means equal to d (default:0)
-scale scale factor. 0 means equal to 1/sqrt(seqlen) (default:0)
-iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1)
-bias add bias or not (default:0)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-mask 0: no mask, 1: top-left, 2:bottom-right (default:0)
't:l,r', top-left local-attn with left right size
'b:l,r', bottom-right local-attn with left right size
'g:y,x', generic attention mask coordinate with y/x size
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0)
-init init method. 0:random int, 1:random float, 2:trig float (default:1)
```
Example: `./bin/example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run an fp16 fmha case with batch=1, nhead=16, sequence length=16384, hdim=128.

## supported features
We are still in a rapid development stage, so more features/optimizations are coming soon.

### hdim
Currently we support hdim `32/64/128/256` for `fp16`/`bf16`, of which `64`/`128` are better optimized. We may optimize other hdim sizes if there is more demand. We also have experimental support for arbitrary hdim (even odd numbers); you can change the return value of `get_pad()` inside `generate.py` to achieve this. (Note: we may change the method or further optimize arbitrary-hdim support in the future.)

### group/batch mode
Currently we support both batch and group mode by setting `-mode` to `0` or `1`. In group mode, each batch can have a different seqlen.

### MQA/GQA
By setting `-h` (nhead for q) and `-h_k` (nhead for k/v) to different numbers, you can achieve MQA/GQA. Note that `h % h_k == 0` must hold when you set different numbers. The snippet below sketches the head mapping.
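A sketch of the standard GQA/MQA head mapping (the conventional definition, not code from this repo): every group of `h/h_k` query heads shares one k/v head.

```cpp
// Standard GQA/MQA head mapping sketch (not repo code).
#include <cassert>

int kv_head_for_query_head(int q_head, int h, int h_k)
{
    assert(h % h_k == 0);           // required, as noted above
    const int group_size = h / h_k; // h_k == 1 gives MQA, 1 < h_k < h gives GQA
    return q_head / group_size;
}
```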

### input/output permute, and `b*s*3*h*d`
If you look at the kernel arguments inside `fmha_fwd_kernel.hpp`, we support arbitrary strides for the seqlen (stride_q/k/v), nhead, and batch dimensions of the q/k/v matrices, which makes it flexible enough to support `b*h*s*d` or `b*s*h*d` input/output permutes. The `-iperm=0/1` and `-operm=0/1` flags are a convenient way to exercise this through the executable. We didn't provide a command-line arg to test the `b*s*3*h*d` layout that torch/FA use by default, but it is trivial to achieve by setting the proper `stride_q/k/v` value of `3*h*d`. A stride sketch follows.
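A small illustration of the seqlen strides implied by each layout (assuming contiguous storage; not the repo's API):

```cpp
// Sketch: element strides along the seqlen dimension of q/k/v under
// different memory layouts, assuming densely packed tensors.
struct SeqLenStride
{
    // b*h*s*d: moving one step in seqlen skips one row of d elements.
    static int bhsd(int d) { return d; }
    // b*s*h*d: all h heads of one token are stored together.
    static int bshd(int h, int d) { return h * d; }
    // b*s*3*h*d: q, k and v of one token are interleaved (torch/FA default).
    static int bs3hd(int h, int d) { return 3 * h * d; }
};
```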

### attention bias
Attention bias is supported with a `b*h*s*s` layout and float bias values.

### lse
For training kernels, the "log sum exp" values need to be stored in the forward pass and reused in the backward pass. We support this by setting `-lse=1`. A reference computation is sketched below.
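Reference math only (the standard numerically stable definition, not the kernel's implementation): for each query row, `lse_i = m_i + log(sum_j exp(s_ij - m_i))` where `m_i` is the row maximum.

```cpp
// Per-row "log sum exp" reference, matching what -lse=1 stores conceptually.
#include <algorithm>
#include <cmath>
#include <vector>

float row_lse(const std::vector<float>& s_row)
{
    const float m = *std::max_element(s_row.begin(), s_row.end()); // row max, for stability
    float sum = 0.f;
    for(float s : s_row)
        sum += std::exp(s - m);
    return m + std::log(sum);
}
```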

### vlayout
We support the V matrix in both row-major (`seqlen*hdim`) and col-major (`hdim*seqlen`) layouts. Since the accumulate (reduce) dimension for V is along `seqlen`, and AMD's current mfma layout expects each thread to hold contiguous registers along the reduce dimension, the col-major V layout is easier to support. However, col-major is not necessarily faster than row-major; many factors affect the overall performance. We provide `-vlayout=r/c` to switch between and test the two layouts, as indexed in the sketch below.
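For clarity, the two layouts index the same V element as follows (illustration only, assuming one densely packed head):

```cpp
// Indexing sketch for the two V layouts of a single head.
float v_rowmajor(const float* v, int s, int d, int hdim)   { return v[s * hdim   + d]; } // seqlen*hdim
float v_colmajor(const float* v, int s, int d, int seqlen) { return v[d * seqlen + s]; } // hdim*seqlen
```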

### generic attention mask coordinate
We unify the mask expression into a generic attention mask coordinate, providing a uniform approach to describing causal top-left, causal bottom-right, and local attention.
![](misc/gamc.png)

(more description to be added)
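As one concrete case that these coordinates can express, here is a sketch of the standard top-left-aligned local-attention predicate with left/right window sizes, i.e. what `-mask=t:l,r` describes (the repo's actual coordinate conventions live in `block_masking.hpp` and may differ in detail):

```cpp
// Sketch: standard local-attention predicate with left/right window sizes
// around the diagonal (illustration only, not repo code).
bool is_masked_out(int i, int j, int left, int right)
{
    return j < i - left || j > i + right; // keep only i-left <= j <= i+right
}
```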

### dropout
TBD

## FP8 experimental support
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have experimental support for fp8 fmha kernels. You can evaluate the performance by passing `-prec=fp8` to `example_fmha_fwd` on a gfx940/941/942 machine with ROCm 6.0+. Currently, unless you explicitly set `-v=0` (which disables CPU verification), it will report an error of up to `0.05`. We are still tuning the kernel performance as well as the precision, so stay tuned for updated performance (pipelines).
Currently we only support `-vlayout=c` for fp8, i.e. `hdim*seqlen` for the V matrix. Row-major V matrix support will come later.