diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index 8582691d..44d5b898 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -49,7 +49,7 @@ auto create_args(int argc, char* argv[]) .insert("s_k", "0", "seqlen_k, 0 means equal to s") .insert("d", "128", "head dim for q, k") .insert("d_v", "0", "head dim for v, 0 means equal to d") - .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(seqlen)") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") .insert("descale_q", "1", "scale factor for fp8 quantization") .insert("descale_k", "1", "scale factor for fp8 quantization") .insert("descale_v", "1", "scale factor for fp8 quantization") @@ -68,6 +68,7 @@ auto create_args(int argc, char* argv[]) "'g:y,x', generic attention mask coordinate with y/x size\n") .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") + .insert("kname", "0", "if set to 1 will print kernel name") .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") .insert("seed", "11939", @@ -157,8 +158,9 @@ bool run(const ArgParser& arg_parser) int stream_warmup = env_get_int("CK_WARMUP", 5); int stream_repeat = env_get_int("CK_REPEAT", 20); + bool kname = arg_parser.get_bool("kname"); - StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat}; + StreamConfig stream_config{nullptr, true, kname ? 1 : 0, stream_warmup, stream_repeat}; const auto [seqlens_q, seqstart_q_host] = generate_seqlens_seqstarts_q(mode, batch, seqlen_q); const std::vector seqstart_k_host = @@ -296,9 +298,15 @@ bool run(const ArgParser& arg_parser) << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias << ", lse:" << lse << ", mask:" << mask << ", v:" << vlayout << std::flush; - auto fmha_traits = fmha_fwd_traits{ - hdim_q, data_type, mode == mode_enum::group, is_v_rowmajor, mask.type, use_bias, lse}; - auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(), + auto fmha_traits = fmha_fwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + is_v_rowmajor, + mask.type, + use_bias, + lse}; + auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), bias_buf.GetDeviceBuffer(), @@ -440,11 +448,11 @@ bool run(const ArgParser& arg_parser) auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck::utils::check_err( - o_host_result, o_host_ref, std::string("O Error: Incorrect results!"), rtol, atol); + o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); pass &= cur_pass; if(!cur_pass) { - std::cerr << "O mismatch found at batch: " << wb << std::endl + std::cerr << "OUT mismatch found at batch: " << wb << std::endl << "\tseqlen_q: " << real_seqlen_q << std::endl << "\tseqlen_k: " << real_seqlen_k << std::endl << "\tseqstart_q: " << seqstart_q_host << std::endl diff --git a/example/91_tile_program/fmha/fmha_fwd.hpp b/example/91_tile_program/fmha/fmha_fwd.hpp index a6db9439..f6a07849 100644 --- a/example/91_tile_program/fmha/fmha_fwd.hpp +++ b/example/91_tile_program/fmha/fmha_fwd.hpp @@ -290,23 +290,43 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.mask_x); } -// this is internal API, will be generated across different files to speedup compile +// this is used to pattern-match internl kernel implementation, not to instantiate kernel template + bool kStoreLse_, + bool kPadS_, 
+ bool kPadSK_, + bool kPadD_, + bool kPadDv_> struct fmha_fwd_traits_ { - static constexpr ck::index_t HDim = HDim_; - using DataType = ck::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; - using FmhaMask = ck::remove_cvref_t; - static constexpr bool kHasBias = kHasBias_; - static constexpr bool kStoreLse = kStoreLse_; + static constexpr ck::index_t HDim = HDim_; + using DataType = ck::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck::index_t kM0 = kM0_; + static constexpr ck::index_t kN0 = kN0_; + static constexpr ck::index_t kK0 = kK0_; + static constexpr ck::index_t kN1 = kN1_; + static constexpr ck::index_t kK1 = kK1_; + static constexpr ck::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + using FmhaMask = ck::remove_cvref_t; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; }; template @@ -315,12 +335,14 @@ float fmha_fwd_(const StreamConfig&, fmha_fwd_args); // This is the public API, will be generated by script struct fmha_fwd_traits { - int hdim; + int hdim_q; + int hdim_v; std::string data_type; bool is_group_mode; bool is_v_rowmajor; mask_enum mask_type; bool has_bias; bool has_lse; + // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const StreamConfig&); diff --git a/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp b/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp index 6c5e6e86..84cba479 100644 --- a/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp @@ -7,19 +7,23 @@ #include "ck/tile_program/tile/store_tile.hpp" #include "ck/tile_program/tile/tile_elementwise.hpp" -template +template struct FmhaFwdEpilogueProblem { - using OaccDataType = ck::remove_cvref_t; - using ODataType = ck::remove_cvref_t; + using OaccDataType = ck::remove_cvref_t; + using ODataType = ck::remove_cvref_t; + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; }; template struct FmhaFwdEpilogue { - using Problem = ck::remove_cvref_t; - using OaccDataType = ck::remove_cvref_t; - using ODataType = ck::remove_cvref_t; + using Problem = ck::remove_cvref_t; + using OaccDataType = ck::remove_cvref_t; + using ODataType = ck::remove_cvref_t; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; __host__ __device__ static constexpr ck::index_t GetSmemSize() { return 0; } @@ -29,6 +33,15 @@ struct FmhaFwdEpilogue using namespace ck; using namespace ck::tile_program; - store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + // TODO: this is ugly + if constexpr(kPadSeqLenQ || kPadHeadDimV) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + buffer_store_fence(); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } } }; diff --git a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp index 4b2c6d09..435f527a 100644 --- a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include "ck/utility/common_header.hpp" #include 
"ck/tensor/tensor_view.hpp" @@ -44,6 +45,46 @@ struct FmhaFwdKernel using FmhaMask = ck::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + + "r" + _TS_(gbr::At(ck::Number<0>{})) + "x" + _TS_(gbr::At(ck::Number<1>{})) + "x" + _TS_(gbr::At(ck::Number<2>{})) + "_" + + "w" + _TS_(gwt::At(ck::Number<0>{})) + "x" + _TS_(gwt::At(ck::Number<1>{})) + "x" + _TS_(gwt::At(ck::Number<2>{})) + "_" + + "o" + _TS_(kBlockPerCu) + "_" + _SS_(FmhaPipeline::name) + "_" + + "v" + (ck::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? 
"_lse" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + template // to avoid duplicated base class prblem, introduce an template arg struct FmhaFwdEmptyKargs { @@ -447,7 +488,7 @@ struct FmhaFwdKernel q_ptr, make_tuple(kargs.seqlen_q, kargs.hdim_q), make_tuple(kargs.stride_q, 1), - Number<32>{}, + Number{}, Number<1>{}); if constexpr(FmhaPipeline::kQLoadOnce) { @@ -469,7 +510,7 @@ struct FmhaFwdKernel k_ptr, make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(kargs.stride_k, 1), - Number<32>{}, + Number{}, Number<1>{}); return pad_tensor_view( @@ -484,7 +525,7 @@ struct FmhaFwdKernel v_ptr, make_tuple(kargs.seqlen_k, kargs.hdim_v), make_tuple(kargs.stride_v, 1), - Number<32>{}, + Number{}, Number<1>{}); const auto v_dram_transposed = @@ -505,7 +546,7 @@ struct FmhaFwdKernel v_ptr, make_tuple(kargs.hdim_v, kargs.seqlen_k), make_tuple(kargs.stride_v, 1), - Number<32>{}, + Number{}, Number<1>{}); return pad_tensor_view( @@ -551,7 +592,7 @@ struct FmhaFwdKernel bias_ptr, make_tuple(kargs.seqlen_q, kargs.seqlen_k), make_tuple(kargs.stride_bias, 1), - Number<32>{}, + Number{}, Number<1>{}); return pad_tensor_view(bias_dram_naive, @@ -636,7 +677,7 @@ struct FmhaFwdKernel o_ptr, make_tuple(kargs.seqlen_q, kargs.hdim_v), make_tuple(kargs.stride_o, 1), - Number<32>{}, + Number{}, Number<1>{}); return pad_tensor_view( diff --git a/example/91_tile_program/fmha/generate.py b/example/91_tile_program/fmha/generate.py index f4594639..9142d08e 100644 --- a/example/91_tile_program/fmha/generate.py +++ b/example/91_tile_program/fmha/generate.py @@ -15,6 +15,14 @@ "fp8" : "ck::f8_t" } +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + MASK_MAP = { "no" : "FmhaMasks::NoMask", "causal" : "FmhaMasks::CausalMask", @@ -96,19 +104,24 @@ using fmha_epilogue_{F_idx} = FmhaFwdEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType>>; + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = FmhaFwdKernel, fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}>; +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include template<> float fmha_fwd_(const StreamConfig& s, fmha_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck::index_t kBlockPerCu = k_::kBlockPerCu; @@ -126,17 +139,12 @@ """ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ - switch (t.hdim){{ {F_hdim_case} - default: - break; - }} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" case {F_hdim}: {{ +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ {F_inner_dispatch} - }} - break; + }} """ MASK_CHECK_MAP = { "no" : "t.mask_type == mask_enum::no_mask", @@ -144,26 +152,107 @@ "generic" : "t.mask_type == mask_enum::window_generic", } -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}>; - 
return fmha_fwd_(s, a); - }} +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_(s, a); + }} """ @dataclass class FmhaFwdApiTrait: + pipeline_tag : str # sync with fmha_fwd_traits<>, to generate fallback calls hdim : str dtype : str # data type mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0blen : int vlayout : str mask : str bias : str # true/false lse : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0' + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0' + else : return f'a.hdim_q % {self.bk0blen} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0' + else : return f'a.hdim_v % {self.bk0blen} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_mask : str # value from MASK_MAP @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.vlayout}-{self.mask}-{self.bias}-{self.lse}' + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias == 't' : n += '_bias' + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + return n class FmhaFwdApiPool: def __init__(self): @@ -183,18 +272,21 @@ def api(self) -> str: per_dtypes=str() for i, dtype in enumerate(self.pool.keys()): 
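+            # illustrative sketch of what this loop emits per dtype (hdim values are examples only):
+            #   if(t.data_type.compare("fp16") == 0){
+            #       if (t.hdim_q <= 64 && t.hdim_v <= 64) { /* inner dispatch */ }
+            #       else if (t.hdim_q <= 128 && t.hdim_v <= 128) { /* inner dispatch */ }
+            #   }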
per_hdim_case=str() - for hdim in self.pool[dtype].keys(): + for j, hdim in enumerate(self.pool[dtype].keys()): traits=self.pool[dtype][hdim] inners=str() - for j, trait in enumerate(traits): - if0 = 'if' if j == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if0, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask], - F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) - - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_hdim=hdim, F_inner_dispatch=inners) - if1 = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if1, F_dtype=dtype, F_hdim_case=per_hdim_case) - + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask], + F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) @dataclass @@ -223,17 +315,9 @@ class FmhaFwdKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim F_dtype : str # data type - F_tile : FmhaFwdTileSize - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # true/false - F_lse : str # - F_mask : str # value from MASK_MAP F_mode : str # value from MODE_MAP - F_pipeline : str # value from PIPIELINE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline @property def template(self) -> str: @@ -254,48 +338,59 @@ def template(self) -> str: F_wm = self.F_tile.F_wm, F_wn = self.F_tile.F_wn, F_wk = self.F_tile.F_wk, - F_vlayout = LAYOUT_MAP[self.F_vlayout], - F_spad = BOOL_MAP[self.F_spad], - F_skpad = BOOL_MAP[self.F_skpad], - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], - F_bias = BOOL_MAP[self.F_bias], - F_lse = BOOL_MAP[self.F_lse], + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BOOL_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_occupancy = self.F_tile.F_occupancy , - F_mask = MASK_MAP[self.F_mask], + F_mask = MASK_MAP[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + 
f"_v{self.F_vlayout[0]}" +\ - f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ - f"_{BOOL_MAP[self.F_bias][0]}_m{self.F_mask[0]}_l{BOOL_MAP[self.F_lse][0]}_{self.F_pipeline}" + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\ + self.F_tile.name + '_' + self.F_pipeline.name @property def filename(self) -> str: return self.name + ".cpp" def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait(hdim=str(self.F_hdim), + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), dtype=self.F_dtype, mode=self.F_mode, - vlayout=self.F_vlayout, - mask=self.F_mask, - bias=self.F_bias, - lse=self.F_lse) + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) # TODO: design a more practical way to do it -# this is current supported tile size. +# this is current supported tile size per hdim def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: if direction == 'fwd': if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, 2), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, 3), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, 2), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, 1), + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, 2), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, 3), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, 2), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, 1), } elif dtype == 'fp8' or dtype == 'bf8': return { @@ -309,25 +404,26 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ def get_blobs() -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - def get_vlayout(dtype, hdim): - if dtype in ['fp16', 'bf16']: - return 'row' - elif dtype in ['fp8', 'bf8']: - return 'col' - else: - assert Fasle - def get_pipeline(dtype, hdim): + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + pipelines = [] if dtype in ['fp16', 'bf16']: - if hdim == 256: - return 'qr' - else: - return 'qr_async' + for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask)) elif dtype in ['fp8', 'bf8']: - return 'qr_fp8' + # no need lse kernels + for mask, bias in 
itertools.product(MASK_MAP.keys(), ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask)) else: assert Fasle - def get_pad(dtype, hdim): - return 'f' + return pipelines gen = list() api_pool = FmhaFwdApiPool() @@ -336,17 +432,14 @@ def get_pad(dtype, hdim): d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) if d == None: continue - for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) - if dtype in ['fp8', 'bf8'] and lse == "t": - continue - k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, F_vlayout=get_vlayout(dtype, hdim), - F_spad=get_pad(dtype, hdim), F_skpad=get_pad(dtype, hdim), F_dpad=get_pad(dtype, hdim), - F_dvpad=get_pad(dtype, hdim), F_bias=bias, F_lse=lse, F_mask=mask, F_mode=mode, - F_pipeline=get_pipeline(dtype, hdim)) - api_pool.register_traits(k.api_trait()) - gen.append(k) + for pipeline in get_pipelines(dtype, hdim): + k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline) + api_pool.register_traits(k.api_trait()) + gen.append(k) return (api_pool, gen) diff --git a/example/91_tile_program/fmha/script/smoke_test.sh b/example/91_tile_program/fmha/script/smoke_test.sh index 40e17fd8..a81af8e2 100644 --- a/example/91_tile_program/fmha/script/smoke_test.sh +++ b/example/91_tile_program/fmha/script/smoke_test.sh @@ -2,22 +2,26 @@ # TODO: run this script from CK root BUILD=build EXE=$BUILD/bin/example_fmha_fwd +KNAME=1 for prec in "fp16" "bf16" ; do for perm in 0 1 ; do -for hdim in 128 64 256 ; do +for vlayout in "r" "c" ; do +for hdim in 128 64 256 32 ; do +for lse in 0 1 ; do for bias in 0 1 ; do -$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -iperm=$perm -operm=$perm -v=1 -$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=256 -bias=$bias -iperm=$perm -operm=$perm -v=1 -$EXE -prec=$prec -b=2 -h=2 -h_k=1 -d=$hdim -s=512 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -v=1 -$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=256 -s_k=512 -bias=$bias -iperm=$perm -operm=$perm -v=1 -$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -mask=1 -v=1 -$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -mask=2 -v=1 -$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -iperm=$perm -operm=$perm -mask=g:128,32 -v=1 - +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -v=1 -kname=$KNAME +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -v=1 -kname=$KNAME +$EXE -prec=$prec -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -v=1 -kname=$KNAME +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=100 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -v=1 -kname=$KNAME +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -v=1 -kname=$KNAME +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -v=1 
-kname=$KNAME +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=g:128,32 -vlayout=$vlayout -v=1 -kname=$KNAME done done done done +done +done diff --git a/include/ck/tensor/tensor_view.hpp b/include/ck/tensor/tensor_view.hpp index 03a8fcad..01f74440 100644 --- a/include/ck/tensor/tensor_view.hpp +++ b/include/ck/tensor/tensor_view.hpp @@ -53,29 +53,35 @@ struct TensorView // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X template >::type, typename scalar_type>::type>, bool>::type = false> __host__ __device__ constexpr remove_cvref_t - GetVectorizedElements(const TensorCoord& coord, bool_constant = {}) const + GetVectorizedElements(const TensorCoord& coord, bool_constant = {}) const { return buf_.template Get( coord.GetOffset(), coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), - bool_constant{}); + bool_constant{}); } // X is vector of DataType. // "coord" is coordinate of DataType, not X. "coord" should be aligned to X template >::type, typename scalar_type>::type>, bool>::type = false> - __host__ __device__ void GetVectorizedElementsRaw(remove_cvref_t& dst, - const TensorCoord& coord) const + __host__ __device__ void + GetVectorizedElementsRaw(remove_cvref_t& dst, + const TensorCoord& coord, + bool_constant = {}) const { - return buf_.template GetRaw(dst, coord.GetOffset()); + return buf_.template GetRaw( + dst, + coord.GetOffset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); } template >::type, typename scalar_type>::type>, bool>::type = false> - __host__ __device__ constexpr void SetVectorizedElements(const TensorCoord& coord, const X& x) + __host__ __device__ constexpr void SetVectorizedElements( + const TensorCoord& coord, const X& x, bool_constant = {}) { - buf_.template Set(coord.GetOffset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), - x); + buf_.template Set( + coord.GetOffset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + + template >::type, + typename scalar_type>::type>, + bool>::type = false> + __host__ __device__ constexpr void SetVectorizedElementsRaw( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template SetRaw( + coord.GetOffset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); } __host__ __device__ void Print() const diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 445004dc..4cfd5760 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -310,6 +310,18 @@ struct LeftPad is_known_at_compile_time::value; } + // MUST be static function + template + __host__ __device__ static constexpr auto + CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + // TODO: we allow pass through this vector length. If one need per-pixel check, + // should change the guaranteed vector length while creating the tensor view. 
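+        // (e.g. keeping an 8-wide guaranteed vector is only safe when the pad amount is itself a
+        // multiple of 8, so no vectorized access straddles the valid/padded boundary)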
+ // It's up to runtime to check the padding length should be multiple of vector length + return make_tuple(low_vector_lengths, low_vector_strides); + } + __host__ __device__ void Print() const { printf("LeftPad{"); @@ -397,9 +409,21 @@ struct RightPad : public BaseTransform<1, 1> is_known_at_compile_time::value; } + // MUST be static function + template + __host__ __device__ static constexpr auto + CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + // TODO: we allow pass through this vector length. If one need per-pixel check, + // should change the guaranteed vector length while creating the tensor view. + // It's up to runtime to check the padding length should be multiple of vector length + return make_tuple(low_vector_lengths, low_vector_strides); + } + __host__ __device__ void Print() const { - printf("LeftPad{"); + printf("RightPad{"); // printf("up_lengths_: "); diff --git a/include/ck/tile_program/block_tile/block_masking.hpp b/include/ck/tile_program/block_tile/block_masking.hpp index 1e01310d..a794fbbd 100644 --- a/include/ck/tile_program/block_tile/block_masking.hpp +++ b/include/ck/tile_program/block_tile/block_masking.hpp @@ -57,7 +57,15 @@ namespace block { y = seq_q, x = seq_k -> no mask */ +namespace impl { + template struct MaskName; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mc"; }; + template<> struct MaskName { static constexpr const char * name = "mg"; }; +} // clang-format on + template struct GenericAttentionMask { @@ -65,6 +73,8 @@ struct GenericAttentionMask static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask, // else only upper-right could have mask + static constexpr const char* name = impl::MaskName::name; + __host__ __device__ GenericAttentionMask(index_t y_total_, index_t x_total_) : GenericAttentionMask(0, 0, y_total_, x_total_) { diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index c1c48a58..1c09224d 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -63,6 +63,26 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kHasBias = Problem::kHasBias; static constexpr bool kStoreLSE = Problem::kStoreLSE; + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(ck::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + + static constexpr const char* name = "qr"; + __host__ __device__ static constexpr ck::index_t GetSmemSize() { return Policy::template GetSmemSize(); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index b8a488ab..e0eecab5 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -56,17 +56,37 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr bool kHasBias = Problem::kHasBias; static constexpr bool kStoreLSE = Problem::kStoreLSE; + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(ck::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + #if CK_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / math::log2e_v; #endif + static constexpr const char* name = "qr_async"; + __host__ __device__ static constexpr ck::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -165,7 +185,9 @@ struct BlockFmhaPipelineQRKSVSAsync // TODO: we use async Copy for K, which is inline asm // a side effect is we have to use inline asm for q as well - auto q = load_tile_raw(q_dram_window); + auto q = decltype(load_tile(q_dram_window)){}; + clear_tile(q); + load_tile_raw(q, q_dram_window); __builtin_amdgcn_sched_barrier(0); using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); @@ -289,7 +311,7 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_s_barrier(); const auto bias_tile = load_tile(bias_dram_window); // load bias tile - auto v_buf = load_tile(v_dram_window); + auto v_buf = load_tile(v_dram_window, bool_constant{}); __builtin_amdgcn_sched_barrier(0); { // tail gemm_0(s_acc, @@ -382,7 +404,11 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - store_tile(v_lds_window, + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, tile_elementwise_in(v_element_func, v_buf)); // store the prefetch } @@ -391,7 +417,7 @@ struct BlockFmhaPipelineQRKSVSAsync move_tile_window( v_dram_window, {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile(v_dram_window); // load next v_buf + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf } __builtin_amdgcn_sched_barrier(0); @@ -474,7 +500,7 @@ struct BlockFmhaPipelineQRKSVSAsync static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) { - v_buf = load_tile(v_dram_window); // load next v_buf + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf } block_sync_lds(); gemm_1(o_acc, @@ -500,7 +526,11 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - store_tile(v_lds_window, + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, tile_elementwise_in(v_element_func, v_buf)); // store next v_buf } if constexpr(i_k1 < k1_loops - 1) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 3e74e405..6afc7514 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -63,6 +63,26 @@ struct BlockFmhaPipelineQRKSVSFp8 static constexpr bool kHasBias = Problem::kHasBias; static constexpr bool kStoreLSE = Problem::kStoreLSE; + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(ck::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 
1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr const char* name = "qr_fp8"; + __host__ __device__ static constexpr ck::index_t GetSmemSize() { return Policy::template GetSmemSize(); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 3c94597a..5a9e5772 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -63,6 +63,8 @@ struct BlockFmhaPipelineQSKSVS static constexpr bool kHasBias = Problem::kHasBias; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr const char* name = "qs"; + __host__ __device__ static constexpr ck::index_t GetSmemSize() { return Policy::template GetSmemSize(); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 4bb59d79..c28b5bc8 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -42,6 +42,17 @@ struct BlockFmhaPipelineQXCustomPolicy return 0; } + // TODO: GetAlignment*() currently didn't consider if need padding or not + // so in pipeline still need check padding requirement + template + __host__ __device__ static constexpr auto GetAlignmentQ() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + } + template __host__ __device__ static constexpr auto MakeQDramTileDistribution() { @@ -136,6 +147,13 @@ struct BlockFmhaPipelineQXCustomPolicy return q_smem_size; } + template + __host__ __device__ static constexpr auto GetAlignmentQ() + { + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + template __host__ __device__ static constexpr auto MakeQDramTileDistribution() { @@ -309,6 +327,20 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + __host__ __device__ static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + if constexpr(AsyncCopyK) + { + return 4 / sizeof(KDataType); + } + else + { + return 16 / sizeof(KDataType); + } + } + template __host__ __device__ static constexpr auto GetSmemKPackV() { @@ -316,20 +348,53 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; return 16 / sizeof(VDataType); } + template - __host__ __device__ static constexpr auto GetVectorloadV() + __host__ __device__ static constexpr auto GetAlignmentV() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + if constexpr(ck::is_same_v) + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = 
Problem::BlockFmhaShape::kK1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - // TODO: not correct! - if constexpr(total_pixels > 4) - return 4; + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } else - return 2; + { + return 16 / sizeof(VDataType); + } + } + + template + __host__ __device__ static constexpr auto GetAlignmentBias() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.GetYs2DDescriptor().GetLengths().At(Number{}); + return vec; + } + + template + __host__ __device__ static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.GetYs2DDescriptor().GetLengths().At(Number{}); + return vec; } template @@ -348,8 +413,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; static_assert(warpSize * KVector >= kKPerBlock && @@ -366,7 +431,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; constexpr index_t Banks = 32; // TODO: need change based on arch constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t kKPack = GetSmemKPackK(); static_assert(PixelsPerRow % kKPack == 0); constexpr index_t NPerRow = PixelsPerRow / kKPack; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; @@ -380,13 +445,6 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - __host__ __device__ static constexpr auto GetVectorloadK() - { - using KDataType = remove_cvref_t; - return 4 / sizeof(KDataType); // TODO: this is for async copy - } - template __host__ __device__ static constexpr auto MakeQRegBlockDescriptor() { @@ -425,7 +483,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); + constexpr index_t kKPack = GetSmemKPackK(); constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(Number{}, Number{}, Number{}), @@ -454,8 +512,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; // for async-copy, this pad is between warps. 
Optimize this for lds_read speed @@ -509,8 +567,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; // for async-copy, this pad is between warps static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); @@ -556,8 +614,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load constexpr index_t kPad = KPack; // for async-copy, this pad is between warps static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); @@ -687,7 +745,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); // this is for global load + constexpr index_t KVector = GetAlignmentK(); // this is for global load static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave @@ -714,8 +772,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy __device__ static constexpr auto MakeVDramTileDistribution() { - using VDataType = remove_cvref_t; - using VLayout = remove_cvref_t; + using VLayout = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; @@ -723,7 +780,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy) { - constexpr index_t N1 = GetVectorloadV(); + constexpr index_t N1 = GetAlignmentV(); constexpr index_t N0 = kNPerBlock / N1; // P constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -764,7 +821,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); constexpr index_t K0 = kKPerBlock / K1; constexpr index_t N2 = get_warp_size() / K0; constexpr index_t N1 = kBlockSize / get_warp_size(); @@ -823,7 +880,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); + constexpr index_t N1 = GetAlignmentV(); constexpr index_t N0 = kNPerBlock / N1; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; static_assert(total_pixels % N1 == 0); // TODO: this is not always true? diff --git a/include/ck/tile_program/tile/load_tile.hpp b/include/ck/tile_program/tile/load_tile.hpp index 1aae8ad3..40eb3ddf 100644 --- a/include/ck/tile_program/tile/load_tile.hpp +++ b/include/ck/tile_program/tile/load_tile.hpp @@ -20,26 +20,31 @@ namespace tile_program { template + index_t NumCoord, + bool oob_conditional_check = true> __device__ auto load_tile(const TileWindowWithStaticDistribution& tile_window) + NumCoord>& tile_window, + bool_constant = {}) { - return tile_window.Load(); + return tile_window.Load(bool_constant{}); } -// This version use inline asm to do loading. 
-template -__device__ auto load_tile_raw(const TileWindowWithStaticDistribution +__device__ auto load_tile_raw(T& tile, + const TileWindowWithStaticDistribution& tile_window) + NumCoord>& tile_window, + bool_constant = {}) { - return tile_window.Load(bool_constant{}); + tile_window.LoadRaw(tile, bool_constant{}); } template +__device__ void +store_tile_raw(TileWindowWithStaticDistribution& tile_window, + const StaticDistributedTensor& dstr_tensor) +{ + tile_window.StoreRaw(dstr_tensor); +} + } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/tile/store_tile_impl_static_lengths.hpp b/include/ck/tile_program/tile/store_tile_impl_static_lengths.hpp index 5d25c5e4..4ce80f0f 100644 --- a/include/ck/tile_program/tile/store_tile_impl_static_lengths.hpp +++ b/include/ck/tile_program/tile/store_tile_impl_static_lengths.hpp @@ -38,5 +38,28 @@ store_tile(TileWindowWithStaticLengths& tile_ tile_window.Store(dstr_tensor); } +template +__device__ void +store_tile_raw(TileWindowWithStaticLengths& tile_window_tmp, + const StaticDistributedTensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.GetBottomTensorView(), + tile_window_tmp.GetWindowLengths(), + tile_window_tmp.GetWindowOrigin(), + tile_dstr); + + tile_window.StoreRaw(dstr_tensor); +} + } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp index 5e67dd4f..11306c49 100644 --- a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp +++ b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp @@ -263,8 +263,8 @@ struct TileWindowWithStaticDistribution __device__ constexpr auto GetNumAccess() const { return LoadStoreTraits::NumAccess; } - template - __device__ auto Load(bool_constant = {}) const + template + __device__ auto Load(bool_constant = {}) const { using Traits = LoadStoreTraits; @@ -291,7 +291,7 @@ struct TileWindowWithStaticDistribution // read from bottom tensor const vector_t vec_value = GetBottomTensorView().template GetVectorizedElements( - bottom_tensor_thread_coord, bool_constant{}); + bottom_tensor_thread_coord, bool_constant{}); const vector_type_t vec{vec_value}; @@ -327,9 +327,70 @@ struct TileWindowWithStaticDistribution return dst_tensor; } + __device__ auto MakeLoadTile() + { + constexpr auto tile_dstr = TileDstr{}; + return make_static_distributed_tensor(tile_dstr); + } + + template + __device__ void LoadRaw(DstTile& dst_tensor, bool_constant = {}) const + { + using Traits = LoadStoreTraits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + static constexpr index_t YElementSize = + TileDstr{}.GetYs2DDescriptor().GetElementSpaceSize(); + static_assert(YElementSize % Traits::ScalarPerVector == 0); + using vectorized_tbuf = StaticBuffer; + + constexpr auto tile_dstr = TileDstr{}; + + auto& dst_vec_tbuf = reinterpret_cast(dst_tensor.GetThreadBuffer()); + + // loop over thread tensor space [y0, y1, ...] 
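+        // for each pre-computed (window adaptor, bottom tensor) coordinate pair, walk the
+        // space-filling curve over the Y dims: issue one raw vectorized load per access directly
+        // into this thread's destination buffer, then advance both coordinates by the curve's
+        // forward step (the P-dim part of the step stays zero)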
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = Number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess); + constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + + GetBottomTensorView().template GetVectorizedElementsRaw( + dst_vec_tbuf.template At(), + bottom_tensor_thread_coord, + bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(Array{0}, idx_diff_ys); + + MoveWindowAdaptorAndBottomTensorThreadCoordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + // TODO: currently async load only implemented in inline asm - template - __device__ auto AsyncLoad(LdsTileWindow_&& lds_tile, bool_constant = {}) const + template + __device__ auto AsyncLoad(LdsTileWindow_&& lds_tile, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; // using LdsTensorView = typename LdsTileWindow::BottomTensorView; @@ -397,7 +458,9 @@ struct TileWindowWithStaticDistribution }); } - __device__ void Store(const StaticDistributedTensor& dstr_tensor) const + template + __device__ void Store(const StaticDistributedTensor& dstr_tensor, + bool_constant = {}) const { using Traits = LoadStoreTraits; @@ -440,7 +503,69 @@ struct TileWindowWithStaticDistribution // write into bottom tensor GetBottomTensorView().template SetVectorizedElements( - bottom_tensor_thread_coord, vec_value); + bottom_tensor_thread_coord, vec_value, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(Array{0}, idx_diff_ys); + + MoveWindowAdaptorAndBottomTensorThreadCoordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + __device__ void StoreRaw(const StaticDistributedTensor& dstr_tensor) const + { + using Traits = LoadStoreTraits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + static constexpr bool oob_conditional_check = true; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = Number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess); + + // read from distributed tensor + vector_type_t vec; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? 
(idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + Number{}); + + constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys); + + vec.template AsType()(j) = + dstr_tensor.GetThreadBuffer().template At(); + }); + + const vector_t vec_value = vec.template AsType().template At<0>(); + + // write into bottom tensor + GetBottomTensorView() + .template SetVectorizedElementsRaw( + bottom_tensor_thread_coord, vec_value); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 5220d8dd..16cbdbad 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -155,11 +155,375 @@ struct buffer_load<1> } }; +template +struct buffer_load_if; + +template <> +struct buffer_load_if<16> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 16); + auto saved_exec = __builtin_amdgcn_read_exec(); + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<8> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 8); + auto saved_exec = __builtin_amdgcn_read_exec(); + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<4> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 4); + auto saved_exec = __builtin_amdgcn_read_exec(); + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<2> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 2); + auto saved_exec = __builtin_amdgcn_read_exec(); + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<1> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 1); + auto saved_exec = __builtin_amdgcn_read_exec(); + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(value) + : 
"v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template +struct buffer_store; + +template <> +struct buffer_store<16> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 16); + asm volatile("buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<8> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 8); + asm volatile("buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<4> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 4); + asm volatile("buffer_store_dword %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<2> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 2); + asm volatile("buffer_store_short %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<1> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 1); + asm volatile("buffer_store_byte %0, %1, %2, %3 offen offset:%4" + : + : "v"(value), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template +struct buffer_store_if; + +template <> +struct buffer_store_if<16> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 16); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<8> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 8); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<4> 
+{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 4); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<2> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 2); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<1> +{ + template + __device__ void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 1); + auto save_exec = __builtin_amdgcn_read_exec(); + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(value), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + __device__ void buffer_load_fence(index_t cnt = 0) { asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +__device__ void buffer_store_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + // buffer load i8 __device__ int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, @@ -537,8 +901,7 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence> __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) @@ -554,15 +917,7 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! 
not implemented"); - if constexpr(use_inline_asm) - { - using type = typename vector_type::type; - type tmp; - buffer_load{}( - tmp, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - return tmp; - } - else if constexpr(is_same::value) // fp32 + if constexpr(is_same::value) // fp32 { if constexpr(N == 1) { @@ -714,30 +1069,29 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool oob_conditional_check = true> __device__ void amd_buffer_load_raw_impl(typename vector_type::type& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, - index_t src_wave_addr_offset) + index_t src_wave_addr_offset, + index_t flag = 0) { - static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), - "wrong! not implemented"); -#if BUFFER_LOAD_USE_INLINEASM + constexpr index_t bytes = sizeof(T) * N; + static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, + "wrong! not supported by buffer_load instruction"); + using type = typename vector_type::type; - buffer_load{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); -#else - (void)dst; - (void)src_wave_buffer_resource; - (void)src_thread_addr_offset; - (void)src_wave_addr_offset; -#endif + if constexpr(oob_conditional_check) + { + buffer_load_if{}( + dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + } + else + { + buffer_load{}( + dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + } } template ::type src } } +template +__device__ void amd_buffer_store_raw_impl(const typename vector_type::type dst_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset, + index_t is_valid_element = 1) +{ + constexpr index_t bytes = sizeof(T) * N; + static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, + "wrong! not supported by buffer_store instruction"); + + using type = typename vector_type::type; + if constexpr(oob_conditional_check) + { + buffer_store_if{}(dst_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0, + is_valid_element); + } + else + { + buffer_store{}(dst_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } +} + template __device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, int32x4_t dst_wave_buffer_resource, @@ -1247,10 +1635,11 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type::typ // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. 
+// oob_conditional_check : dynamic check if out-of-bound template + bool oob_conditional_check = true> __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, @@ -1268,15 +1657,21 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - return amd_buffer_load_impl( + uint32_t src_addr_shift = [&]() { + if constexpr(oob_conditional_check) + return src_thread_element_valid ? 0 : 0x80000000; + else + return 0; + }(); + return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); - #else - - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_element_valid ? tmp : vector_t(0); + if constexpr(oob_conditional_check) + return src_thread_element_valid ? tmp : vector_t(0); + else + return tmp; #endif } @@ -1287,7 +1682,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, template + bool oob_conditional_check = true> __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, @@ -1305,19 +1700,24 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_element_valid ? tmp : vector_t(customized_value); + if constexpr(oob_conditional_check) + return src_thread_element_valid ? tmp : vector_t(customized_value); + else + return tmp; } template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool oob_conditional_check = true> __device__ void amd_buffer_load_raw(typename vector_type_maker::type::type& dst, const T* p_src_wave, index_t src_thread_element_offset, - index_t src_element_space_size) + index_t src_element_space_size, + index_t is_valid_element = 0) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size); @@ -1329,8 +1729,8 @@ __device__ void amd_buffer_load_raw(typename vector_type_maker::type::type constexpr index_t vector_size = scalar_type::vector_size; - amd_buffer_load_raw_impl( - dst, src_wave_buffer_resource, src_thread_addr_offset, 0); + amd_buffer_load_raw_impl( + dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); } // unfortunately async copy can not make sure invalid data is zero inside LDS @@ -1360,7 +1760,8 @@ __device__ void amd_async_buffer_load_with_oob(T* smem, // It is user's responsibility to make sure that is true. template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool oob_conditional_check = true> __device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, @@ -1377,11 +1778,24 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x80000000; + uint32_t dst_addr_shift = [&]() { + if constexpr(oob_conditional_check) + return dst_thread_element_valid ? 0 : 0x80000000; + else + return 0; + }(); amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else - if(dst_thread_element_valid) + if constexpr(oob_conditional_check) + { + if(dst_thread_element_valid) + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } + } + else { amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); @@ -1389,6 +1803,34 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #endif } +template +__device__ void +amd_buffer_store_raw(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + + amd_buffer_store_raw_impl( + src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + 0, + dst_thread_element_valid); +} + // buffer_atomic_add requires: // 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. diff --git a/include/ck/utility/buffer_view_impl_generic.hpp b/include/ck/utility/buffer_view_impl_generic.hpp index 1b88bf5c..f547a021 100644 --- a/include/ck/utility/buffer_view_impl_generic.hpp +++ b/include/ck/utility/buffer_view_impl_generic.hpp @@ -60,12 +60,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> __device__ constexpr auto - Get(index_t i, bool is_valid_element, bool_constant = {}) const + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; diff --git a/include/ck/utility/buffer_view_impl_global.hpp b/include/ck/utility/buffer_view_impl_global.hpp index 62150940..156bcb2a 100644 --- a/include/ck/utility/buffer_view_impl_global.hpp +++ b/include/ck/utility/buffer_view_impl_global.hpp @@ -63,12 +63,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> __device__ constexpr auto - Get(index_t i, bool is_valid_element, bool_constant = {}) const + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -93,15 +93,16 @@ struct BufferView, t_per_x, Coherence, - use_inline_asm>( + oob_conditional_check>( p_data_, i, is_valid_element, buffer_size_); } else { - return amd_buffer_load_invalid_element_return_customized_value, - t_per_x, - Coherence, - use_inline_asm>( + return amd_buffer_load_invalid_element_return_customized_value< + remove_cvref_t, + t_per_x, + Coherence, + oob_conditional_check>( p_data_, i, is_valid_element, buffer_size_, invalid_element_value_); } } @@ -135,10 +136,11 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto GetRaw(remove_cvref_t& dst, index_t i) const + __device__ constexpr auto GetRaw(remove_cvref_t& dst, index_t i, bool is_valid_element) const { constexpr 
index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -149,7 +151,8 @@ struct BufferView, t_per_x, Coherence>(dst, p_data_, i, buffer_size_); + amd_buffer_load_raw, t_per_x, Coherence, oob_conditional_check>( + dst, p_data_, i, buffer_size_, is_valid_element); } // i is offset of T, not X. i should be aligned to X @@ -205,6 +208,7 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> @@ -246,6 +250,27 @@ struct BufferView>::type, + typename scalar_type>::type>::value, + bool>::type = false> + __device__ void SetRaw(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + amd_buffer_store_raw, t_per_x, Coherence, oob_conditional_check>( + x, p_data_, i, is_valid_element, buffer_size_); + } + template >::type, typename scalar_type>::type>::value, diff --git a/include/ck/utility/buffer_view_impl_lds.hpp b/include/ck/utility/buffer_view_impl_lds.hpp index a4620330..f809e8b0 100644 --- a/include/ck/utility/buffer_view_impl_lds.hpp +++ b/include/ck/utility/buffer_view_impl_lds.hpp @@ -56,12 +56,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> __device__ constexpr auto - Get(index_t i, bool is_valid_element, bool_constant = {}) const + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; diff --git a/include/ck/utility/buffer_view_impl_vgpr.hpp b/include/ck/utility/buffer_view_impl_vgpr.hpp index 15bdf135..a7efd7a5 100644 --- a/include/ck/utility/buffer_view_impl_vgpr.hpp +++ b/include/ck/utility/buffer_view_impl_vgpr.hpp @@ -60,12 +60,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> __device__ constexpr auto - Get(index_t i, bool is_valid_element, bool_constant = {}) const + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size;
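// A host-side sketch of the calling pattern the BufferView changes enable:
// the compile-time use_inline_asm switch becomes oob_conditional_check, and
// GetRaw/SetRaw now carry a runtime is_valid_element flag down into
// amd_buffer_load_raw / amd_buffer_store_raw, so each access is either
// predicated on validity or performed unconditionally. SimpleView, Get and
// Set are illustrative names only, not the CK BufferView interface.

#include <cstddef>
#include <cstdio>
#include <type_traits>
#include <vector>

template <typename T>
struct SimpleView
{
    std::vector<T> data;

    template <bool oob_conditional_check = true>
    T Get(std::size_t i,
          bool is_valid_element,
          std::integral_constant<bool, oob_conditional_check> = {}) const
    {
        if constexpr(oob_conditional_check)
            return is_valid_element ? data[i] : T{0}; // guarded: invalid elements read as zero
        else
            return data[i]; // caller guarantees the access is in bounds
    }

    template <bool oob_conditional_check = true>
    void Set(std::size_t i,
             bool is_valid_element,
             const T& x,
             std::integral_constant<bool, oob_conditional_check> = {})
    {
        if constexpr(oob_conditional_check)
        {
            if(is_valid_element)
                data[i] = x; // guarded: out-of-bound stores are dropped
        }
        else
        {
            data[i] = x;
        }
    }
};

int main()
{
    SimpleView<float> view{std::vector<float>(4, 0.f)};
    view.Set(1, /*is_valid_element=*/true, 3.5f);

    // Opt out of the validity check at compile time when the tile window is
    // statically known to lie inside the tensor.
    std::printf("%f %f\n", view.Get(1, true), view.Get<false>(0, true));
}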