diff --git a/.gitignore b/.gitignore
index d3b18b358..ab0fa3618 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,3 +55,4 @@ artifacts/
 **/times.csv
 transformer_engine/build_info.txt
 transformer_engine/common/util/hip_nvml.*
+transformer_engine/lib/aiter
diff --git a/3rdparty/aiter b/3rdparty/aiter
index a64fa18e6..dd3bd2c93 160000
--- a/3rdparty/aiter
+++ b/3rdparty/aiter
@@ -1 +1 @@
-Subproject commit a64fa18e60235994e4cbfd7059cc2f60d06e743f
+Subproject commit dd3bd2c93ed2806f2f6ef88c80b025a510a7cf8a
diff --git a/tools/check_aiter_mha_args_usage.py b/tools/check_aiter_mha_args_usage.py
new file mode 100644
index 000000000..760c017dc
--- /dev/null
+++ b/tools/check_aiter_mha_args_usage.py
@@ -0,0 +1,97 @@
+"""Cross-check the fields of aiter's mha_{fwd,bwd}_args structs against the
+fmha_args assignments made in TE's ck_fused_attn sources.
+
+Usage (from the repo root): python tools/check_aiter_mha_args_usage.py --mode both
+"""
+import argparse
+import re
+from pathlib import Path
+from typing import List, Set
+import sys
+
+def parse_with_skip_comments(buffer, line, regex, outputs):
+    # skip blank lines and pure // comment lines, strip trailing // comments,
+    # then accumulate the statement in buffer[0] until a ';' completes it
+    stripped = line.strip()
+    if not stripped or stripped.startswith("//"):
+        return
+    line_no_comment = re.sub(r"//.*", "", line)
+    buffer[0] += " " + line_no_comment.strip()
+    if ";" not in line_no_comment:
+        return
+    match = regex.search(buffer[0])
+    if match:
+        outputs.append(match.group(1))
+    buffer[0] = ""
+
+
+def extract_fields_from_header(text: str, struct_name: str) -> List[str]:
+    struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")
+    struct_end_re = re.compile(r"^\s*};\s*$")
+
+    struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b")
+    lines = text.splitlines()
+    in_struct = False
+    fields: List[str] = []
+    buffer = [""]
+    for line in lines:
+        if not in_struct:
+            if struct_start_re.search(line):
+                in_struct = True
+            continue
+        if struct_end_re.search(line):
+            break
+        parse_with_skip_comments(buffer, line, struct_field_re, fields)
+    return fields
+
+
+def extract_usage_from_source(text: str, var_name: str) -> Set[str]:
+    # '(?!=)' keeps '==' comparisons from being miscounted as assignments
+    assign_re = re.compile(rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=(?!=)")
+    assignments = []
+    lines = text.splitlines()
+    buffer = [""]
+    for line in lines:
+        parse_with_skip_comments(buffer, line, assign_re, assignments)
+    return set(assignments)
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition")
+    parser.add_argument("--mode", choices=["fwd", "bwd", "both"], required=True, help="Mode: fwd, bwd, or both")
+    args = parser.parse_args()
+    modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode]
+    mismatch = 0
+    for mode in modes:
+        header_path = Path(f"3rdparty/aiter/csrc/include/mha_{mode}.h")
+        source_path = Path(f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp")
+        header_text = header_path.read_text(encoding="utf-8")
+        source_text = source_path.read_text(encoding="utf-8")
+
+        header_fields = extract_fields_from_header(header_text, f"mha_{mode}_args")
+        header_set = set(header_fields)
+        used_fields = extract_usage_from_source(source_text, "fmha_args")
+
+        missing_in_usage = sorted(header_set - used_fields)
+        unknown_in_header = sorted(used_fields - header_set)
+        mismatch += len(missing_in_usage) + len(unknown_in_header)
+
+        print(f"\nAnalyzing mha_{mode}_args\n")
+        print(f"mha_{mode}_args fields in header:", len(header_set))
+        print(f"mha_{mode}_args fields referenced in source:", len(used_fields))
+
+        if missing_in_usage:
+            print("\nFields present in header but not referenced in source:")
+            for name in missing_in_usage:
+                print(f"  - {name}")
+        else:
+            print("\nAll header fields are referenced in source.")
+
+        if unknown_in_header:
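+            # assignments in the .cpp that have no matching field in the header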
print("\nFields referenced in source but not in header:") + for name in unknown_in_header: + print(f" - {name}") + else: + print("\nNo unknown fields referenced in source.") + + if mismatch: + print(f"\nTotal mismatched fields: {mismatch}") + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 3f51b96b6..e1453aba5 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -16,6 +16,17 @@ namespace ck_fused_attn{ +int ck_to_aiter_mask_type(mask_enum mask_type, ck_tile::index_t left, ck_tile::index_t right){ + if( + mask_type == mask_enum::no_mask || + mask_type == mask_enum::window_generic + ) return 0; + if(left == -1 && right == 0){ + return mask_type == mask_enum::mask_top_left ? 1 : 2; + } + return 3; +} + // TODO: unify with binary search in TE/common/fused_attn(rocm)/util // no device std::upper_bound // in an increasing array with given size len, search for the index that: @@ -45,12 +56,12 @@ __global__ void dk_dv_reduce( DataType *dv, //k,v, dk, dv guaranteed to have the same stride uint64_t stride_b_dkv, uint64_t stride_h_dkv, uint64_t stride_s_dkv){ - + uint64_t batch_idx = blockIdx.x; uint64_t seqlen_idx = blockIdx.y; uint64_t head_k_idx = blockIdx.z; uint64_t hdim_idx = threadIdx.x; - + // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -60,7 +71,7 @@ __global__ void dk_dv_reduce( assert(hdim_idx){ @@ -92,12 +103,12 @@ __global__ void dk_or_dv_reduce( DataType *dk_or_dv, //k,v, dk, dv guaranteed to have the same stride uint64_t stride_b_dk_or_dv, uint64_t stride_h_dk_or_dv, uint64_t stride_s_dk_or_dv){ - + uint64_t batch_idx = blockIdx.x; uint64_t seqlen_idx = blockIdx.y; uint64_t head_k_or_v_idx = blockIdx.z; uint64_t hdim_idx = threadIdx.x; - + // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -106,7 +117,7 @@ __global__ void dk_or_dv_reduce( assert(hdim_idx){ @@ -142,9 +153,9 @@ __global__ void dk_dv_reduce_thd( uint64_t seqlen_idx = blockIdx.x; uint64_t head_k_idx = blockIdx.y; uint64_t hdim_idx = threadIdx.x; - + assert(hdim_idx= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } @@ -164,7 +175,7 @@ __global__ void dk_dv_reduce_thd( uint64_t read_idx = head_k_idx*head_idx_offset*stride_h_dkv_expanded + seqlen_idx*stride_s_dkv_expanded + hdim_idx; uint64_t write_idx = head_k_idx*stride_h_dkv + seqlen_idx* stride_s_dkv + hdim_idx; - + for(uint64_t ii = 0; ii < head_idx_offset; ii++){ // bf16 requires special casting in CK if constexpr (std::is_same_v){ @@ -202,7 +213,7 @@ __global__ void dk_or_dv_reduce_thd( uint64_t seqlen_idx = blockIdx.x; uint64_t head_k_or_v_idx = blockIdx.y; uint64_t hdim_idx = threadIdx.x; - + assert(hdim_idx= *((cu_seqlen_kv_padded_ptr? 
cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ @@ -222,7 +233,7 @@ __global__ void dk_or_dv_reduce_thd( uint64_t read_idx = head_k_or_v_idx*head_idx_offset*stride_h_dk_or_dv_expanded + seqlen_idx*stride_s_dk_or_dv_expanded + hdim_idx; uint64_t write_idx = head_k_or_v_idx*stride_h_dk_or_dv + seqlen_idx* stride_s_dk_or_dv + hdim_idx; - + for(uint64_t ii = 0; ii < head_idx_offset; ii++){ // bf16 requires special casting in CK if constexpr (std::is_same_v){ @@ -248,7 +259,7 @@ __global__ void dbias_reduce_11ss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -278,7 +289,7 @@ __global__ void dbias_reduce_1hss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -308,7 +319,7 @@ __global__ void dbias_reduce_b1ss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -333,125 +344,118 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and args passed into ck apis -void log_bwd_config(const char* func_name, - const std::string data_type_str, - const bool is_group_mode, - const mask_enum mask_type, - const bias_enum bias_type, - const bool has_dbias, - const bool has_dropout, - const bool is_store_randval, - const bool is_deterministic, - const bool uses_bwd_v3, - const bool is_v3_atomic_fp32, - const int how_v3_bf16_cvt, - const fmha_bwd_args& fmha_args){ - - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - if (ck_fused_attn_log_config) { - std::cout<::type>(mask_type)<::type>(bias_type)<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset)) + ); + log_value( + "dropout_offset_ptr", + std::get<1>(std::get>(fmha_args.drop_seed_offset)) + ); } void dump_bwd_timings(const char* dump_path, float average_runtime){ @@ -460,37 +464,42 @@ void dump_bwd_timings(const char* dump_path, float average_runtime){ file << average_runtime << "\n"; } -hipError_t ck_attn_bwd( +hipError_t _ck_attn_bwd_impl( DType dtype, - uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - const void* q_ptr, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, + uint64_t d_qk, uint64_t d_v, + uint64_t bias_b, uint64_t bias_h, + uint64_t max_tokens_q, uint64_t max_tokens_kv, + const void* q_ptr, uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, const void* bias_ptr, const void* alibi_slope_ptr, - const void* o_ptr, + const void* cu_seqlen_q_ptr, const void* 
cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, + const void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, - const void* lse_ptr, - const void* do_ptr, + const void* lse_ptr, + const void* do_ptr, uint64_t stride_b_do, uint64_t stride_h_do, uint64_t stride_s_do, float scaling_factor, float dropout_probability, void* philox_seed_ptr, void* philox_offset_ptr, BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* dq_ptr, + void* dq_ptr, uint64_t stride_b_dq, uint64_t stride_h_dq, uint64_t stride_s_dq, void* dq_acc_ptr, void* dk_expanded_ptr, void* dv_expanded_ptr, uint64_t stride_b_dk_expanded, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, uint64_t stride_b_dv_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, - void* dk_ptr, + void* dk_ptr, uint64_t stride_b_dk, uint64_t stride_h_dk, uint64_t stride_s_dk, - void* dv_ptr, + void* dv_ptr, uint64_t stride_b_dv, uint64_t stride_h_dv, uint64_t stride_s_dv, void* dbias_expanded_ptr, void* dbias_ptr, @@ -499,10 +508,13 @@ hipError_t ck_attn_bwd( bool uses_bwd_v3, bool is_v3_atomic_fp32, int how_v3_bf16_cvt, + bool is_group_mode, + const char* func_name, + bool ck_log_config, hipStream_t stream){ bool has_dropout = (dropout_probability > 0.f); - bool has_dbias = dbias_ptr!=nullptr; + bool has_dbias = dbias_ptr != nullptr; bool is_mqa_gqa = (h > hg); /* CK input parameters */ @@ -518,184 +530,158 @@ hipError_t ck_attn_bwd( float scale_s = scaling_factor; float p_drop = dropout_probability; float p_undrop = 1.0 - p_drop; - bool is_group_mode = false; bool s_randval = false; - bias_enum bias_type; - BiasShape bias_shape; - std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + bias_enum bias_type = bias_enum::no_bias; + BiasShape bias_shape = BiasShape::k11SS; + if (!is_group_mode) { + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + } + ck_tile::index_t left, right; left = window_size_left; right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } + const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; - - ck_tile::index_t shape_seqlen_q = seqlen_q; - ck_tile::index_t shape_seqlen_k = seqlen_k; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_log_config}; std::string data_type_str = get_data_type_str(dtype); - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias of shape (bias_b, bias_h, s_q, s_kv) - const ck_tile::index_t stride_bias = max_seqlen_k; - const ck_tile::index_t stride_o = stride_s_o; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_do = stride_s_do; - const ck_tile::index_t stride_dq = stride_s_dq; - const ck_tile::index_t stride_dk = stride_s_dk; - const ck_tile::index_t stride_dv = stride_s_dv; - const ck_tile::index_t stride_dk_expanded = stride_s_dk_expanded; - const ck_tile::index_t stride_dv_expanded = 
stride_s_dv_expanded; - const ck_tile::index_t stride_dq_acc = d_qk; //dq_acc of shape (nsplits, B, H, S, D) - // dbias is of the same shape as bias - // but ck only take dbias with BHSS - const ck_tile::index_t stride_dbias = max_seqlen_k; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias input can be of different shapes (11SS, 1HSS, B1SS, and BHSS), but dbias must be of BHSS - const ck_tile::index_t nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; - const ck_tile::index_t nhead_stride_o = stride_h_o; - const ck_tile::index_t nhead_stride_randval = - shape_seqlen_q * max_seqlen_k; - const ck_tile::index_t nhead_stride_do = stride_h_do; - const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; - const ck_tile::index_t nhead_stride_dq = stride_h_dq; - const ck_tile::index_t nhead_stride_dk = stride_h_dk; - const ck_tile::index_t nhead_stride_dv = stride_h_dv; - const ck_tile::index_t nhead_stride_dk_expanded = stride_h_dk_expanded; - const ck_tile::index_t nhead_stride_dv_expanded = stride_h_dv_expanded; - // dbias can only be of BHSS - const ck_tile::index_t nhead_stride_dbias = max_seqlen_q * max_seqlen_k; - const ck_tile::index_t nhead_stride_dq_acc = s_q*d_qk; //dq_acc of shape (nsplits, B, H, S, D) - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = stride_b_q; - const ck_tile::index_t batch_stride_k = stride_b_k; - const ck_tile::index_t batch_stride_v = stride_b_v; - // bias input can be of different shapes (11SS, 1HSS, B1SS, and BHSS), but dbias must be of BHSS - // for B1SS and BHSS, batch stride for bias are both bias_h x s_q x s_kv (bias_h==1 for B1SS and bias_h == h for BHSS) - const ck_tile::index_t batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: bias_h* max_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_o = stride_b_o; - const ck_tile::index_t batch_stride_randval = - nhead * shape_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_do = stride_b_do; - const ck_tile::index_t batch_stride_lsed = nhead * max_seqlen_q; - const ck_tile::index_t batch_stride_dq = stride_b_dq; - const ck_tile::index_t batch_stride_dk = stride_b_dk; - const ck_tile::index_t batch_stride_dv = stride_b_dv; - const ck_tile::index_t batch_stride_dk_expanded = stride_b_dk_expanded; - const ck_tile::index_t batch_stride_dv_expanded = stride_b_dv_expanded; - // for dbias, use h since h can be different from bias_h - const ck_tile::index_t batch_stride_dbias = h* max_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_dq_acc = h*s_q*d_qk; //dq_acc of shape (nsplits, B, H, S, D) - const ck_tile::index_t split_stride_dq_acc = b * h * s_q * d_qk; - - return fmha_bwd_args{q_ptr, - k_ptr, - v_ptr, - bias_type==bias_enum::no_bias? nullptr : (bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr), - o_ptr, - lse_ptr, - do_ptr, - lse_workspace_ptr, - nullptr, - dq_ptr, - is_mqa_gqa? dk_expanded_ptr:dk_ptr, - is_mqa_gqa? dv_expanded_ptr:dv_ptr, - has_dbias? (bias_shape==BiasShape::kBHSS ? 
dbias_ptr: dbias_expanded_ptr): nullptr, - dq_acc_ptr, //dq_acc_buf - nullptr,//seqstart_q_ptr - nullptr,//seqstart_k_ptr - nullptr, /* seqlen_q_ptr */ - nullptr, /* seqlen_k_ptr */ - nullptr, //cu_seqlen_q_ptr - nullptr, //cu_seqlen_k_ptr - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - max_seqlen_k, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - stride_q, - stride_k, - stride_v, - bias_type==bias_enum::alibi? 0: stride_bias, - stride_o, - stride_randval, - stride_do, - stride_dq_acc,//stride_dq_acc - stride_dq,//stride_dq - is_mqa_gqa? stride_dk_expanded:stride_dk, - is_mqa_gqa? stride_dv_expanded:stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, //nhead_stride_dq_acc - nhead_stride_dq, - is_mqa_gqa? nhead_stride_dk_expanded:nhead_stride_dk, - is_mqa_gqa? nhead_stride_dv_expanded:nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, //batch_stride_dq_acc - batch_stride_dq, - is_mqa_gqa? batch_stride_dk_expanded:batch_stride_dk, - is_mqa_gqa? batch_stride_dv_expanded:batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - left, - right, - static_cast(mask_type), - p_drop, - p_undrop, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); + aiter::mha_bwd_args fmha_args{}; + fmha_args.mask_type = ck_to_aiter_mask_type(mask_type, left, right); + fmha_args.use_asm_v3 = uses_bwd_v3; + fmha_args.v3_atomic_fp32 = is_v3_atomic_fp32; + fmha_args.v3_bf16_cvt = how_v3_bf16_cvt; + fmha_args.v3_api_check = false; + + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.data_type = data_type_str; + fmha_args.is_group_mode = is_group_mode; + fmha_args.ck_mask_type = static_cast(mask_type); + fmha_args.bias_type = static_cast(bias_type); + fmha_args.has_dbias = (!is_group_mode) && has_dbias; + fmha_args.has_dropout = has_dropout; + fmha_args.is_store_randval = s_randval; + fmha_args.is_deterministic = deterministic; + + fmha_args.q_ptr = q_ptr; + fmha_args.k_ptr = k_ptr; + fmha_args.v_ptr = v_ptr; + fmha_args.bias_ptr = (bias_type==bias_enum::no_bias || is_group_mode) ? nullptr + : (bias_type==bias_enum::alibi? alibi_slope_ptr : bias_ptr); + fmha_args.o_ptr = o_ptr; + fmha_args.lse_ptr = lse_ptr; + fmha_args.do_ptr = do_ptr; + fmha_args.d_ptr = lse_workspace_ptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.dq_ptr = dq_ptr; + fmha_args.dk_ptr = is_mqa_gqa? dk_expanded_ptr:dk_ptr; + fmha_args.dv_ptr = is_mqa_gqa? dv_expanded_ptr:dv_ptr; + fmha_args.dbias_ptr = ((!is_group_mode) && has_dbias) + ? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr) + : nullptr; + fmha_args.dq_acc_ptr = dq_acc_ptr; + + if (is_group_mode) { + fmha_args.seqstart_q_ptr = cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr; + fmha_args.seqstart_k_ptr = cu_seqlen_kv_padded_ptr==nullptr? 
cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr;
+    fmha_args.seqlen_q_ptr = nullptr;
+    fmha_args.seqlen_k_ptr = nullptr;
+    fmha_args.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
+    fmha_args.cu_seqlen_k_ptr = cu_seqlen_kv_ptr;
+  } else {
+    fmha_args.seqstart_q_ptr = nullptr;
+    fmha_args.seqstart_k_ptr = nullptr;
+    fmha_args.seqlen_q_ptr = nullptr;
+    fmha_args.seqlen_k_ptr = nullptr;
+    fmha_args.cu_seqlen_q_ptr = nullptr;
+    fmha_args.cu_seqlen_k_ptr = nullptr;
+  }
+
+  fmha_args.seqlen_q = is_group_mode ? max_seqlen_q : seqlen_q;
+  fmha_args.seqlen_k = is_group_mode ? max_seqlen_k : seqlen_k;
+  fmha_args.batch = batch;
+  fmha_args.max_seqlen_q = max_seqlen_q;
+  fmha_args.max_seqlen_k = max_seqlen_k;
+  fmha_args.nhead_q = nhead;
+  fmha_args.nhead_k = nhead_k;
+  fmha_args.scale = scale_s;
+
+  // setup stride_* arguments
+  fmha_args.stride_q = stride_s_q;
+  fmha_args.stride_k = stride_s_k;
+  fmha_args.stride_v = stride_s_v;
+  fmha_args.stride_bias = (!is_group_mode && bias_type!=bias_enum::alibi) ? max_seqlen_k : 0;
+  fmha_args.stride_o = stride_s_o;
+  fmha_args.stride_randval = max_seqlen_k;
+  fmha_args.stride_do = stride_s_do;
+  fmha_args.stride_dq_acc = d_qk;
+  fmha_args.stride_dq = stride_s_dq;
+  fmha_args.stride_dk = is_mqa_gqa? stride_s_dk_expanded:stride_s_dk;
+  fmha_args.stride_dv = is_mqa_gqa? stride_s_dv_expanded:stride_s_dv;
+  fmha_args.stride_dbias = (!is_group_mode && bias_type!=bias_enum::alibi) ? max_seqlen_k : 0;
+
+  // setup nhead_stride_* arguments
+  fmha_args.nhead_stride_q = stride_h_q;
+  fmha_args.nhead_stride_k = stride_h_k;
+  fmha_args.nhead_stride_v = stride_h_v;
+  fmha_args.nhead_stride_bias = (!is_group_mode && (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS))
+                                  ? max_seqlen_q * max_seqlen_k
+                                  : 0;
+  fmha_args.nhead_stride_o = stride_h_o;
+  fmha_args.nhead_stride_randval = is_group_mode ? 0 : seqlen_q * max_seqlen_k;
+  fmha_args.nhead_stride_do = stride_h_do;
+  fmha_args.nhead_stride_lsed = is_group_mode ? max_tokens_q : max_seqlen_q;
+  fmha_args.nhead_stride_dq_acc = static_cast<ck_tile::index_t>((is_group_mode ? max_tokens_q : s_q) * d_qk);
+  fmha_args.nhead_stride_dq = stride_h_dq;
+  fmha_args.nhead_stride_dk = is_mqa_gqa? stride_h_dk_expanded:stride_h_dk;
+  fmha_args.nhead_stride_dv = is_mqa_gqa? stride_h_dv_expanded:stride_h_dv;
+  fmha_args.nhead_stride_dbias = (!is_group_mode) ? max_seqlen_q * max_seqlen_k : 0;
+
+  // setup batch_stride_* arguments
+  fmha_args.batch_stride_q = is_group_mode ? 0 : stride_b_q;
+  fmha_args.batch_stride_k = is_group_mode ? 0 : stride_b_k;
+  fmha_args.batch_stride_v = is_group_mode ? 0 : stride_b_v;
+  fmha_args.batch_stride_bias = (!is_group_mode && (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS))
+                                  ? 0
+                                  : (is_group_mode ? 0 : bias_h * max_seqlen_q * max_seqlen_k);
+  fmha_args.batch_stride_o = is_group_mode ? 0 : stride_b_o;
+  fmha_args.batch_stride_randval = is_group_mode ? 0 : nhead * seqlen_q * max_seqlen_k;
+  fmha_args.batch_stride_do = is_group_mode ? 0 : stride_b_do;
+  fmha_args.batch_stride_lsed = is_group_mode ? 0 : nhead * max_seqlen_q;
+  fmha_args.batch_stride_dq_acc = is_group_mode ? 0 : static_cast<ck_tile::index_t>(h * s_q * d_qk);
+  fmha_args.batch_stride_dq = is_group_mode ? 0 : stride_b_dq;
+  fmha_args.batch_stride_dk = is_group_mode ? 0 : (is_mqa_gqa? stride_b_dk_expanded:stride_b_dk);
+  fmha_args.batch_stride_dv = is_group_mode ? 0 : (is_mqa_gqa? stride_b_dv_expanded:stride_b_dv);
+  fmha_args.batch_stride_dbias = is_group_mode ? 0 : h * max_seqlen_q * max_seqlen_k;
+  fmha_args.split_stride_dq_acc = static_cast<ck_tile::index_t>(is_group_mode ?
(max_tokens_q * h * d_qk) : (b * h * s_q * d_qk)); + + fmha_args.window_size_left = left; + fmha_args.window_size_right = right; + fmha_args.p_drop = p_drop; + fmha_args.p_undrop = p_undrop; + fmha_args.drop_seed_offset = std::pair{philox_seed_ptr, philox_offset_ptr}; + + // modify the max_seqlen_q for better performance in 0-length cases + // lse_thd_ptr used as buffer + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { + if(std::string(env_p) == "1"){ + if(ck_log_config){ + std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + } + fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); + fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); + } + } // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); - - float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_dbias, - s_randval, - deterministic, - uses_bwd_v3, - is_v3_atomic_fp32, - how_v3_bf16_cvt); + log_bwd_config(func_name, fmha_args, ck_log_config); + + float average_runtime = aiter::mha_bwd(fmha_args, stream_config); if(dump_path){ dump_bwd_timings(dump_path, average_runtime); } @@ -703,11 +689,117 @@ hipError_t ck_attn_bwd( //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); } + return hipSuccess; +} +hipError_t ck_attn_bwd( + DType dtype, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, + const void* q_ptr, + uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, + const void* k_ptr, + uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, + const void* v_ptr, + uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, + const void* bias_ptr, + const void* alibi_slope_ptr, + const void* o_ptr, + uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, + const void* lse_ptr, + const void* do_ptr, + uint64_t stride_b_do, uint64_t stride_h_do, uint64_t stride_s_do, + float scaling_factor, float dropout_probability, + void* philox_seed_ptr, void* philox_offset_ptr, + BiasType attn_bias_type, + MaskType attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + void* dq_ptr, + uint64_t stride_b_dq, uint64_t stride_h_dq, uint64_t stride_s_dq, + void* dq_acc_ptr, + void* dk_expanded_ptr, + void* dv_expanded_ptr, + uint64_t stride_b_dk_expanded, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, + uint64_t stride_b_dv_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, + void* dk_ptr, + uint64_t stride_b_dk, uint64_t stride_h_dk, uint64_t stride_s_dk, + void* dv_ptr, + uint64_t stride_b_dv, uint64_t stride_h_dv, uint64_t stride_s_dv, + void* dbias_expanded_ptr, + void* dbias_ptr, + void* lse_workspace_ptr, + bool deterministic, + bool uses_bwd_v3, + bool is_v3_atomic_fp32, + int how_v3_bf16_cvt, + hipStream_t stream){ + + bool has_dropout = (dropout_probability > 0.f); + bool has_dbias = dbias_ptr!=nullptr; + bool is_mqa_gqa = (h > hg); + bias_enum bias_type; + BiasShape bias_shape; + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + + bool ck_log_config 
= false; + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { + if (env_p != nullptr && std::string(env_p) == "1") + ck_log_config = true; + } + + hipError_t impl_status = _ck_attn_bwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, + bias_b, bias_h, + s_q, s_kv, + q_ptr, + stride_b_q, stride_h_q, stride_s_q, + k_ptr, + stride_b_k, stride_h_k, stride_s_k, + v_ptr, + stride_b_v, stride_h_v, stride_s_v, + bias_ptr, + alibi_slope_ptr, + nullptr, nullptr, + nullptr, nullptr, + o_ptr, + stride_b_o, stride_h_o, stride_s_o, + lse_ptr, + do_ptr, + stride_b_do, stride_h_do, stride_s_do, + scaling_factor, dropout_probability, + philox_seed_ptr, philox_offset_ptr, + attn_bias_type, + attn_mask_type, + window_size_left, window_size_right, + dq_ptr, + stride_b_dq, stride_h_dq, stride_s_dq, + dq_acc_ptr, + dk_expanded_ptr, + dv_expanded_ptr, + stride_b_dk_expanded, stride_h_dk_expanded, stride_s_dk_expanded, + stride_b_dv_expanded, stride_h_dv_expanded, stride_s_dv_expanded, + dk_ptr, + stride_b_dk, stride_h_dk, stride_s_dk, + dv_ptr, + stride_b_dv, stride_h_dv, stride_s_dv, + dbias_expanded_ptr, + dbias_ptr, + lse_workspace_ptr, + deterministic, + uses_bwd_v3, + is_v3_atomic_fp32, + how_v3_bf16_cvt, + false, + __FUNCTION__, + ck_log_config, + stream); + if (impl_status != hipSuccess) { + return impl_status; + } if(is_mqa_gqa){ dim3 grid(b, s_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ + if (ck_log_config){ std::cout<, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (ck_fused_attn_log_config){ + if (ck_log_config){ std::cout<, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (ck_fused_attn_log_config){ + if (ck_log_config){ std::cout<, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); } } return hipSuccess; } -hipError_t ck_attn_varlen_bwd( +hipError_t ck_attn_varlen_bwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t max_tokens_q, uint64_t max_tokens_kv, - const void* q_ptr, + const void* q_ptr, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, - const void* o_ptr, + const void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, - const void* lse_thd_ptr, - const void* do_ptr, + const void* lse_thd_ptr, + const void* do_ptr, uint64_t stride_h_do, uint64_t stride_s_do, float scaling_factor, float dropout_probability, void* philox_seed_ptr, void* philox_offset_ptr, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* dq_ptr, + void* dq_ptr, uint64_t stride_h_dq, uint64_t stride_s_dq, void* dq_acc_ptr, void* dk_expanded_ptr, void* dv_expanded_ptr, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, - void* dk_ptr, + void* dk_ptr, uint64_t stride_h_dk, uint64_t stride_s_dk, - void* dv_ptr, + void* dv_ptr, uint64_t stride_h_dv, uint64_t stride_s_dv, void* 
lse_workspace_ptr, bool deterministic, @@ -859,216 +951,69 @@ hipError_t ck_attn_varlen_bwd( bool is_v3_atomic_fp32, int how_v3_bf16_cvt, hipStream_t stream){ - - bool has_dropout = (dropout_probability > 0.f); - bool has_dbias = false; bool is_mqa_gqa = (h > hg); - /* CK input parameters */ - ck_tile::index_t batch = b; - ck_tile::index_t nhead = h; - ck_tile::index_t hdim_q = d_qk; - ck_tile::index_t nhead_k = hg; - ck_tile::index_t hdim_v = d_v; - ck_tile::index_t max_seqlen_q = s_q; - ck_tile::index_t max_seqlen_k = s_kv; - float scale_s = scaling_factor; - float p_drop = dropout_probability; - float p_undrop = 1.0 - p_drop; - bool is_group_mode = true; - bool s_randval = false; - - // THD does not work with bias - - ck_tile::index_t left, right; - left = window_size_left; - right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); - - bool ck_fused_attn_log_config = false; + bool ck_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); - // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; - - std::string data_type_str = get_data_type_str(dtype); - - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias not used in THD qkv layout - const ck_tile::index_t stride_bias = 0; - const ck_tile::index_t stride_o = stride_s_o; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_do = stride_s_do; - const ck_tile::index_t stride_dq = stride_s_dq; - const ck_tile::index_t stride_dk = stride_s_dk; - const ck_tile::index_t stride_dv = stride_s_dv; - const ck_tile::index_t stride_dk_expanded = stride_s_dk_expanded; - const ck_tile::index_t stride_dv_expanded = stride_s_dv_expanded; - const ck_tile::index_t stride_dq_acc = d_qk; //dq_acc of shape (nsplits, H, max_tokens_q, D_qk) - // bias not used in THD qkv layout - const ck_tile::index_t stride_dbias = 0; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_bias = 0; - const ck_tile::index_t nhead_stride_o = stride_h_o; - const ck_tile::index_t nhead_stride_randval = 0; - const ck_tile::index_t nhead_stride_do = stride_h_do; - // use packed lse - const ck_tile::index_t nhead_stride_lsed = max_tokens_q; - const ck_tile::index_t nhead_stride_dq = stride_h_dq; - const ck_tile::index_t nhead_stride_dk = stride_h_dk; - const ck_tile::index_t nhead_stride_dv = stride_h_dv; - const ck_tile::index_t nhead_stride_dk_expanded = stride_h_dk_expanded; - const ck_tile::index_t nhead_stride_dv_expanded = stride_h_dv_expanded; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_dbias = 0; - const ck_tile::index_t nhead_stride_dq_acc = max_tokens_q*d_qk; //dq_acc of shape (nsplits, H, max_tokens_q, D_qk) - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = 0; - const ck_tile::index_t batch_stride_k = 0; - const ck_tile::index_t batch_stride_v = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_bias = 0; - 
const ck_tile::index_t batch_stride_o = 0; - const ck_tile::index_t batch_stride_randval = 0; - const ck_tile::index_t batch_stride_do = 0; - const ck_tile::index_t batch_stride_lsed = 0; - const ck_tile::index_t batch_stride_dq = 0; - const ck_tile::index_t batch_stride_dk = 0; - const ck_tile::index_t batch_stride_dv = 0; - const ck_tile::index_t batch_stride_dk_expanded = 0; - const ck_tile::index_t batch_stride_dv_expanded = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_dbias = 0; - const ck_tile::index_t batch_stride_dq_acc = 0; //dq_acc of shape (nsplits, T, H, D) - const ck_tile::index_t split_stride_dq_acc = max_tokens_q*h*d_qk; - - return fmha_bwd_args{q_ptr, - k_ptr, - v_ptr, - nullptr, - o_ptr, - lse_thd_ptr, - do_ptr, - lse_workspace_ptr, - nullptr, - dq_ptr, - is_mqa_gqa? dk_expanded_ptr:dk_ptr, - is_mqa_gqa? dv_expanded_ptr:dv_ptr, - nullptr, //dbias_ptr - dq_acc_ptr, //dq_acc_buf - cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr - cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr - nullptr, /* seqlen_q_ptr */ - nullptr, /* seqlen_k_ptr */ - cu_seqlen_q_ptr, //cu_seqlen_q_ptr - cu_seqlen_kv_ptr, //cu_seqlen_k_ptr - max_seqlen_q, //seqlen_q, unused in group mode - max_seqlen_k, //seqlen_kv, unused in group mode - batch, - max_seqlen_q, - max_seqlen_k, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - stride_randval, - stride_do, - stride_dq_acc,//stride_dq_acc - stride_dq,//stride_dq - is_mqa_gqa? stride_dk_expanded:stride_dk, - is_mqa_gqa? stride_dv_expanded:stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, //nhead_stride_dq_acc - nhead_stride_dq, - is_mqa_gqa? nhead_stride_dk_expanded:nhead_stride_dk, - is_mqa_gqa? nhead_stride_dv_expanded:nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, //batch_stride_dq_acc - batch_stride_dq, - is_mqa_gqa? batch_stride_dk_expanded:batch_stride_dk, - is_mqa_gqa? 
batch_stride_dv_expanded:batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - left, - right, - static_cast(mask_type), - p_drop, - p_undrop, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); - - // modify the max_seqlen_q for better performance in 0-length cases - // lse_thd_ptr used as buffer - if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { - if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; - } - fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); - fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); - } + ck_log_config = true; } - // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); - - float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_enum::no_bias, - has_dbias, - s_randval, + hipError_t impl_status = _ck_attn_bwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, + 0, 0, + max_tokens_q, max_tokens_kv, + q_ptr, + 0, stride_h_q, stride_s_q, + k_ptr, + 0, stride_h_k, stride_s_k, + v_ptr, + 0, stride_h_v, stride_s_v, + nullptr, + nullptr, + cu_seqlen_q_ptr, cu_seqlen_kv_ptr, + cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, + o_ptr, + 0, stride_h_o, stride_s_o, + lse_thd_ptr, + do_ptr, + 0, stride_h_do, stride_s_do, + scaling_factor, dropout_probability, + philox_seed_ptr, philox_offset_ptr, + BiasType::no_bias, + attn_mask_type, + window_size_left, window_size_right, + dq_ptr, + 0, stride_h_dq, stride_s_dq, + dq_acc_ptr, + dk_expanded_ptr, + dv_expanded_ptr, + 0, stride_h_dk_expanded, stride_s_dk_expanded, + 0, stride_h_dv_expanded, stride_s_dv_expanded, + dk_ptr, + 0, stride_h_dk, stride_s_dk, + dv_ptr, + 0, stride_h_dv, stride_s_dv, + nullptr, + nullptr, + lse_workspace_ptr, deterministic, uses_bwd_v3, is_v3_atomic_fp32, - how_v3_bf16_cvt); - if(dump_path){ - dump_bwd_timings(dump_path, average_runtime); - } - if(average_runtime < 0){ - //TODO: better error out system - throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); + how_v3_bf16_cvt, + true, + __FUNCTION__, + ck_log_config, + stream); + if (impl_status != hipSuccess) { + return impl_status; } if(is_mqa_gqa){ dim3 grid(max_tokens_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ + if (ck_log_config){ std::cout<::type>(mask_type)<::type>(bias_type)<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))); + log_value("dropout_offset_ptr", std::get<1>(std::get>(fmha_args.drop_seed_offset))); } void dump_fwd_timings(const char* dump_path, float average_runtime){ @@ -126,17 +109,20 @@ void dump_fwd_timings(const char* dump_path, float average_runtime){ file << average_runtime << "\n"; } -hipError_t ck_attn_fwd( +hipError_t _ck_attn_fwd_impl( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - const void* q_ptr, + uint64_t max_tokens_q, + const void* q_ptr, uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_b_k, uint64_t stride_h_k, 
uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, const void* bias_ptr, const void* alibi_slope_ptr, + const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, bool is_training, float scaling_factor, float dropout_probability, @@ -144,11 +130,13 @@ hipError_t ck_attn_fwd( BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* o_ptr, + void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_ptr, bool uses_fwd_v3, int how_v3_bf16_cvt, + bool is_group_mode, + const char* func_name, hipStream_t stream){ bool has_dropout = (is_training && dropout_probability > 0.f); @@ -162,134 +150,137 @@ hipError_t ck_attn_fwd( ck_tile::index_t hdim_v = d_v; ck_tile::index_t max_seqlen_q = s_q; ck_tile::index_t max_seqlen_k = s_kv; + float scale_s = scaling_factor; float logits_soft_cap = 0.f; float p_drop = dropout_probability; - bool is_group_mode = false; - bool is_v_rowmajor = true; - bool has_logits_soft_cap = 0.f < logits_soft_cap; - bool do_fp8_static_quant = false; - - bias_enum bias_type; - BiasShape bias_shape; - std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); - + ck_tile::index_t left, right; left = window_size_left; right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - - bool ck_fused_attn_log_config = false; + + bool ck_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; + ck_log_config = true; } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_log_config}; std::string data_type_str = get_data_type_str(dtype); - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias is of shape [b, h , s_q, s_kv] - const ck_tile::index_t stride_bias = max_seqlen_k; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_o = stride_s_o; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - const ck_tile::index_t nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t nhead_stride_randval = 0; - // softmax_lse is of shape [b, h, s_q] - const ck_tile::index_t nhead_stride_lse = max_seqlen_q; - const ck_tile::index_t nhead_stride_o = stride_h_o; - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = stride_b_q; - const ck_tile::index_t batch_stride_k = stride_b_k; - const ck_tile::index_t batch_stride_v = stride_b_v; - const ck_tile::index_t batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: (bias_shape==BiasShape::kBHSS? 
bias_h* max_seqlen_q * max_seqlen_k: max_seqlen_q*max_seqlen_k); - //TODO: randval never used, can we remove it - const ck_tile::index_t batch_stride_randval = 0; - // softmax_lse is of shape [b, h, s_q] - const ck_tile::index_t batch_stride_lse = nhead * max_seqlen_q; - const ck_tile::index_t batch_stride_o = stride_b_o; - - return fmha_fwd_args{q_ptr, - k_ptr, - v_ptr, - bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr, - nullptr, //q_descale_ptr - nullptr, //k_descale_ptr - nullptr, //v_descale_ptr - nullptr,//rand_val_ptr - lse_ptr, - o_ptr, - nullptr, //seqstart_q_ptr - nullptr, //seqstart_k_ptr - nullptr, //seqlen_q_ptr - nullptr, //seqlen_k_ptr - nullptr, //cu_padded_q_ptr - nullptr, //cu_padded_k_ptr - max_seqlen_q, - max_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - bias_type==bias_enum::alibi? 0: stride_bias, // upstream TE only requires standard (vanilla) alibi slopes - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - left, - right, - 0, // sink_size - static_cast(mask_type), - 0, // min_seqlen_q - p_drop, - false, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); - - // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); - - float average_runtime = aiter::mha_fwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_lse, - quant_scale_enum::no_scale, - uses_fwd_v3, - false,//has_sink - how_v3_bf16_cvt); + // Handle bias + ck_tile::index_t _nhead_stride_bias = 0; + ck_tile::index_t _batch_stride_bias = 0; + bias_enum bias_type = bias_enum::no_bias; + const void* seqstart_q_ptr = nullptr; + const void* seqstart_k_ptr = nullptr; + if(is_group_mode){ + seqstart_q_ptr = cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr; + seqstart_k_ptr = cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr; + }else{ + BiasShape bias_shape; + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + _nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; + _batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: (bias_shape==BiasShape::kBHSS? 
bias_h* max_seqlen_q * max_seqlen_k: max_seqlen_q*max_seqlen_k); + } + const ck_tile::index_t nhead_stride_bias = _nhead_stride_bias; + const ck_tile::index_t batch_stride_bias = _batch_stride_bias; + + aiter::mha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_ptr; + fmha_args.k_ptr = k_ptr; + fmha_args.v_ptr = v_ptr; + + fmha_args.batch = batch; + fmha_args.seqlen_q = max_seqlen_q; // unused in group mode + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + + fmha_args.stride_q = stride_s_q; + fmha_args.stride_k = stride_s_k; + fmha_args.stride_v = stride_s_v; + fmha_args.nhead_stride_q = stride_h_q; + fmha_args.nhead_stride_k = stride_h_k; + fmha_args.nhead_stride_v = stride_h_v; + fmha_args.batch_stride_q = stride_b_q; + fmha_args.batch_stride_k = stride_b_k; + fmha_args.batch_stride_v = stride_b_v; + + fmha_args.bias_ptr = bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr; + fmha_args.lse_ptr = lse_ptr; + fmha_args.o_ptr = o_ptr; + + fmha_args.seqstart_q_ptr = seqstart_q_ptr; + fmha_args.seqstart_k_ptr = seqstart_k_ptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.seqlen_q_ptr = nullptr; + fmha_args.cu_seqlen_q_ptr = is_group_mode ? cu_seqlen_q_ptr : nullptr; + fmha_args.cu_seqlen_k_ptr = is_group_mode ? cu_seqlen_kv_ptr : nullptr; + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.seqlen_k = max_seqlen_k; // unused in group mode (or kvcache enabled) + fmha_args.max_seqlen_q = max_seqlen_q; + + fmha_args.scale_s = scale_s; + + fmha_args.logits_soft_cap = logits_soft_cap; + + // bias is of shape [b, h , s_q, s_kv] + fmha_args.stride_bias = is_group_mode? 0 : (bias_type==bias_enum::alibi? 0: max_seqlen_k); + fmha_args.stride_o = stride_s_o; + fmha_args.nhead_stride_bias = nhead_stride_bias; + fmha_args.batch_stride_bias = batch_stride_bias; + // softmax_lse is of shape [b, h, s_q] + fmha_args.nhead_stride_lse = is_group_mode? max_tokens_q : max_seqlen_q; + fmha_args.batch_stride_lse = is_group_mode? 
0 : nhead * max_seqlen_q; + fmha_args.nhead_stride_o = stride_h_o; + fmha_args.batch_stride_o = stride_b_o; + + fmha_args.window_size_left = left; + fmha_args.window_size_right = right; + fmha_args.mask_type = static_cast(mask_type); + + fmha_args.rand_val_ptr = nullptr; + + fmha_args.stride_randval = max_seqlen_k; + fmha_args.nhead_stride_randval = 0; // Unused + fmha_args.batch_stride_randval = 0; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.p_drop = p_drop; + fmha_args.s_randval = 0; + fmha_args.drop_seed_offset = std::pair{philox_seed_ptr, philox_offset_ptr}; + fmha_args.use_asm_v3 = uses_fwd_v3; + fmha_args.how_v3_bf16_cvt = how_v3_bf16_cvt; + fmha_args.v3_api_check = false; + fmha_args.data_type = data_type_str; + fmha_args.is_group_mode = is_group_mode; + fmha_args.bias_type = static_cast(bias_type); + fmha_args.has_lse = lse_ptr!=nullptr; + fmha_args.qscale_type = static_cast(quant_scale_enum::no_scale); + fmha_args.has_sink = false; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.sink_size = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + // print ck traits and fmha_args when needed + log_fwd_config(func_name, has_dropout, fmha_args, ck_log_config); + float average_runtime = aiter::mha_fwd(fmha_args, stream_config); if(dump_path){ dump_fwd_timings(dump_path, average_runtime); } @@ -300,15 +291,69 @@ hipError_t ck_attn_fwd( return hipSuccess; } +hipError_t ck_attn_fwd( + DType dtype, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, + const void* q_ptr, + uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, + const void* k_ptr, + uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, + const void* v_ptr, + uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, + const void* bias_ptr, + const void* alibi_slope_ptr, + bool is_training, + float scaling_factor, + float dropout_probability, + void* philox_seed_ptr, void* philox_offset_ptr, + BiasType attn_bias_type, + MaskType attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + void* o_ptr, + uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, + void* lse_ptr, + bool uses_fwd_v3, + int how_v3_bf16_cvt, + hipStream_t stream){ + + return _ck_attn_fwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, + 0, + q_ptr, stride_b_q, stride_h_q, stride_s_q, + k_ptr, stride_b_k, stride_h_k, stride_s_k, + v_ptr, stride_b_v, stride_h_v, stride_s_v, + bias_ptr, + alibi_slope_ptr, + nullptr, nullptr, + nullptr, nullptr, + is_training, + scaling_factor, + dropout_probability, + philox_seed_ptr, philox_offset_ptr, + attn_bias_type, + attn_mask_type, + window_size_left, window_size_right, + o_ptr, + stride_b_o, stride_h_o, stride_s_o, + lse_ptr, + uses_fwd_v3, + how_v3_bf16_cvt, + false, + __FUNCTION__, // func_name + stream + ); +} + hipError_t ck_attn_varlen_fwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t max_tokens_q, - const void* q_ptr, + const void* q_ptr, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_h_k, uint64_t stride_s_k, - 
const void* v_ptr, + const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, @@ -318,174 +363,40 @@ hipError_t ck_attn_varlen_fwd( void* philox_seed_ptr, void* philox_offset_ptr, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* o_ptr, + void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_thd_ptr, bool uses_fwd_v3, int how_v3_bf16_cvt, hipStream_t stream){ - bool has_dropout = (is_training && dropout_probability > 0.f); - bool has_lse = (lse_thd_ptr != nullptr); - - /* CK input parameters */ - ck_tile::index_t batch = b; - ck_tile::index_t nhead = h; - ck_tile::index_t hdim_q = d_qk; - ck_tile::index_t nhead_k = hg; - ck_tile::index_t hdim_v = d_v; - ck_tile::index_t max_seqlen_q = s_q; - ck_tile::index_t max_seqlen_kv = s_kv; - - float scale_s = scaling_factor; - float logits_soft_cap = 0.f; - float p_drop = dropout_probability; - bool is_group_mode = true; - bool is_v_rowmajor = true; - bool has_logits_soft_cap = 0.f < logits_soft_cap; - bool do_fp8_static_quant = false; - - // THD does not work with bias - - ck_tile::index_t left, right; - left = window_size_left; - right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); - - bias_enum bias_type = bias_enum::no_bias; - - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); - // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; - - - std::string data_type_str = get_data_type_str(dtype); - - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias not used in THD qkv layout - const ck_tile::index_t stride_bias = 0; - // randval not used - const ck_tile::index_t stride_randval = 0; - const ck_tile::index_t stride_o = stride_s_o; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_bias = 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t nhead_stride_randval = 0; - // use packed lse of shape [h, max_tokens_q] - const ck_tile::index_t nhead_stride_lse = max_tokens_q; - const ck_tile::index_t nhead_stride_o = stride_h_o; - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = 0; - const ck_tile::index_t batch_stride_k = 0; - const ck_tile::index_t batch_stride_v = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_bias = 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t batch_stride_randval = 0; - const ck_tile::index_t batch_stride_lse = 0; - const ck_tile::index_t batch_stride_o = 0; - - return fmha_fwd_args{q_ptr, - k_ptr, - v_ptr, - nullptr,//bias_ptr - nullptr, //q_descale_ptr - nullptr, //k_descale_ptr - nullptr, //v_descale_ptr - nullptr,//rand_val_ptr - lse_thd_ptr, - o_ptr, - cu_seqlen_q_padded_ptr==nullptr? 
cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr - cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr - nullptr, //seqlen_q_ptr - nullptr, //seqlen_k_ptr - cu_seqlen_q_ptr, //cu_seqlen_q_ptr - cu_seqlen_kv_ptr, //cu_seqlen_k_ptr - max_seqlen_q, //seqlen_q, unused in group mode - max_seqlen_kv, //seqlen_kv, unused in group mode - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - left, - right, - 0, // sink_size - static_cast(mask_type), - 0, // min_seqlen_q - p_drop, - false, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); - // modify the max_seqlen_q for better performance in 0-length cases - // lse_thd_ptr used as buffer - if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ - if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; - } - fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); - } - } - // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); - - float average_runtime = aiter::mha_fwd( - fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_lse, - quant_scale_enum::no_scale, - uses_fwd_v3, - false,//has_sink - how_v3_bf16_cvt); - if(dump_path){ - dump_fwd_timings(dump_path, average_runtime); - } - if(average_runtime < 0){ - //TODO: better error out system - throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); - } - return hipSuccess; + return _ck_attn_fwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, 0, 0, + max_tokens_q, + q_ptr, 0, stride_h_q, stride_s_q, + k_ptr, 0, stride_h_k, stride_s_k, + v_ptr, 0, stride_h_v, stride_s_v, + nullptr, + nullptr, + cu_seqlen_q_ptr, cu_seqlen_kv_ptr, + cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, + is_training, + scaling_factor, + dropout_probability, + philox_seed_ptr, philox_offset_ptr, + BiasType::no_bias, + attn_mask_type, + window_size_left, window_size_right, + o_ptr, + 0, stride_h_o, stride_s_o, + lse_thd_ptr, + uses_fwd_v3, + how_v3_bf16_cvt, + true, + __FUNCTION__, // func_name + stream + ); } }//namespace ck_fused_attn diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 26c92ca2b..309d31382 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -5,6 +5,12 @@ ************************************************************************/ #include +#include +#include +#include //once_flag + +#include + #include "ck_fused_attn_utils.hpp" #include "ck_fused_attn/ck_fused_attn.hpp" #include "mask.hpp" @@ -13,6 +19,32 @@ namespace ck_fused_attn{ +void set_aiter_asm_dir() { + static std::once_flag aiter_asm_dir_once; + 
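+  // set AITER_ASM_DIR exactly once per process; std::call_once keeps concurrent first calls safe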
std::call_once(aiter_asm_dir_once, []() {
+    Dl_info info;
+    // locate this shared library on disk to derive candidate aiter kernel paths
+    dladdr((void*)set_aiter_asm_dir, &info);
+    // "install" layout: the aiter assembly directory ships next to this library
+    auto install_lib_path = std::filesystem::path(info.dli_fname).parent_path() / "aiter";
+    const char* log_ck_config = std::getenv("NVTE_LOG_CK_CONFIG");
+    // "editable" layout: fall back to the aiter submodule's hsa directory in the source tree
+    auto editable_install_path = std::filesystem::path(info.dli_fname).parent_path().parent_path().parent_path() / "3rdparty" / "aiter" / "hsa";
+    for(const auto& path : {install_lib_path, editable_install_path}) {
+      if(std::filesystem::exists(path)) {
+        setenv("AITER_ASM_DIR", path.c_str(), 1);
+        if (log_ck_config && log_ck_config == std::string("1")) {
+          std::cout << "AITER_ASM_DIR set to: " << getenv("AITER_ASM_DIR") << std::endl;
+        }
+        return;
+      }
+      if(log_ck_config && log_ck_config == std::string("1")) {
+        std::cout << "Checked AITER_ASM_DIR path: " << path << " does not exist." << std::endl;
+      }
+    }
+  });
+}
+
+
+const bool aiterAsmDirInitialized = (set_aiter_asm_dir(), true);
+
 std::string get_data_type_str(DType dtype){
   std::string data_type_str;
   if(dtype==DType::kFloat16){