From f02c48e8a2089a1f89ac6c6a6af72fd7e8faefd2 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 2 Feb 2026 15:46:34 -0600 Subject: [PATCH 01/11] Initial inclusion of new API in fwd as well as part 1 of refactor --- .../include/ck_fused_attn/ck_fused_attn.hpp | 29 ++ .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 476 +++++++++--------- 2 files changed, 269 insertions(+), 236 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 54ee94786..7a3ce1c99 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -37,6 +37,35 @@ enum class BiasType{ }; +hipError_t _ck_attn_fwd_impl( + DType dtype, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, + uint64_t max_tokens_q, + const void* q_ptr, + uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, + const void* k_ptr, + uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, + const void* v_ptr, + uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, + const void* bias_ptr, + const void* alibi_slope_ptr, + const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, + bool is_training, + float scaling_factor, + float dropout_probability, + void* philox_seed_ptr, void* philox_offset_ptr, + BiasType attn_bias_type, + MaskType attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + void* o_ptr, + uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, + void* lse_ptr, + bool uses_fwd_v3, + int how_v3_bf16_cvt, + bool is_group_mode, + hipStream_t stream); + hipError_t ck_attn_fwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 8d7c4b5c9..b13f53fc7 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -16,7 +16,7 @@ namespace ck_fused_attn{ -// print the fmha traits and args when calling ck apis +// print the fmha traits and fmha_args when calling ck apis void log_fwd_config(const char* func_name, const std::string data_type_str, const bool is_group_mode, @@ -29,7 +29,7 @@ void log_fwd_config(const char* func_name, const bool do_fp8_static_quant, const bool uses_fwd_v3, const bool how_v3_bf16_cvt, - const fmha_fwd_args& fmha_args){ + const aiter::mha_fwd_args& fmha_args){ bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") @@ -126,9 +126,10 @@ void dump_fwd_timings(const char* dump_path, float average_runtime){ file << average_runtime << "\n"; } -hipError_t ck_attn_fwd( +hipError_t _ck_attn_fwd_impl( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, + uint64_t max_tokens_q, const void* q_ptr, uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, const void* k_ptr, @@ -137,6 +138,8 @@ hipError_t ck_attn_fwd( uint64_t stride_b_v, uint64_t stride_h_v, 
uint64_t stride_s_v, const void* bias_ptr, const void* alibi_slope_ptr, + const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, bool is_training, float scaling_factor, float dropout_probability, @@ -149,6 +152,7 @@ hipError_t ck_attn_fwd( void* lse_ptr, bool uses_fwd_v3, int how_v3_bf16_cvt, + bool is_group_mode, hipStream_t stream){ bool has_dropout = (is_training && dropout_probability > 0.f); @@ -162,17 +166,13 @@ hipError_t ck_attn_fwd( ck_tile::index_t hdim_v = d_v; ck_tile::index_t max_seqlen_q = s_q; ck_tile::index_t max_seqlen_k = s_kv; + float scale_s = scaling_factor; float logits_soft_cap = 0.f; float p_drop = dropout_probability; - bool is_group_mode = false; bool is_v_rowmajor = true; bool has_logits_soft_cap = 0.f < logits_soft_cap; bool do_fp8_static_quant = false; - - bias_enum bias_type; - BiasShape bias_shape; - std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); ck_tile::index_t left, right; left = window_size_left; @@ -190,106 +190,96 @@ hipError_t ck_attn_fwd( std::string data_type_str = get_data_type_str(dtype); - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias is of shape [b, h , s_q, s_kv] - const ck_tile::index_t stride_bias = max_seqlen_k; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_o = stride_s_o; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - const ck_tile::index_t nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t nhead_stride_randval = 0; - // softmax_lse is of shape [b, h, s_q] - const ck_tile::index_t nhead_stride_lse = max_seqlen_q; - const ck_tile::index_t nhead_stride_o = stride_h_o; - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = stride_b_q; - const ck_tile::index_t batch_stride_k = stride_b_k; - const ck_tile::index_t batch_stride_v = stride_b_v; - const ck_tile::index_t batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: (bias_shape==BiasShape::kBHSS? bias_h* max_seqlen_q * max_seqlen_k: max_seqlen_q*max_seqlen_k); - //TODO: randval never used, can we remove it - const ck_tile::index_t batch_stride_randval = 0; - // softmax_lse is of shape [b, h, s_q] - const ck_tile::index_t batch_stride_lse = nhead * max_seqlen_q; - const ck_tile::index_t batch_stride_o = stride_b_o; - - return fmha_fwd_args{q_ptr, - k_ptr, - v_ptr, - bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr, - nullptr, //q_descale_ptr - nullptr, //k_descale_ptr - nullptr, //v_descale_ptr - nullptr,//rand_val_ptr - lse_ptr, - o_ptr, - nullptr, //seqstart_q_ptr - nullptr, //seqstart_k_ptr - nullptr, //seqlen_q_ptr - nullptr, //seqlen_k_ptr - nullptr, //cu_padded_q_ptr - nullptr, //cu_padded_k_ptr - max_seqlen_q, - max_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - bias_type==bias_enum::alibi? 
0: stride_bias, // upstream TE only requires standard (vanilla) alibi slopes - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - left, - right, - 0, // sink_size - static_cast(mask_type), - 0, // min_seqlen_q - p_drop, - false, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); + // Handle bias + ck_tile::index_t _nhead_stride_bias = 0; + ck_tile::index_t _batch_stride_bias = 0; + bias_enum bias_type = bias_enum::no_bias; + const void* seqstart_q_ptr = nullptr; + const void* seqstart_k_ptr = nullptr; + if(is_group_mode){ + seqstart_q_ptr = cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr; + seqstart_k_ptr = cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr; + }else{ + BiasShape bias_shape; + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + _nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; + _batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: (bias_shape==BiasShape::kBHSS? bias_h* max_seqlen_q * max_seqlen_k: max_seqlen_q*max_seqlen_k); + } + const ck_tile::index_t nhead_stride_bias = _nhead_stride_bias; + const ck_tile::index_t batch_stride_bias = _batch_stride_bias; + + aiter::mha_fwd_args fmha_args; + fmha_args.q_ptr = q_ptr; + fmha_args.k_ptr = k_ptr; + fmha_args.v_ptr = v_ptr; + + fmha_args.batch = batch; + fmha_args.seqlen_q = max_seqlen_q; // unused in group mode + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + + fmha_args.stride_q = stride_s_q; + fmha_args.stride_k = stride_s_k; + fmha_args.stride_v = stride_s_v; + fmha_args.nhead_stride_q = stride_h_q; + fmha_args.nhead_stride_k = stride_h_k; + fmha_args.nhead_stride_v = stride_h_v; + fmha_args.batch_stride_q = stride_b_q; + fmha_args.batch_stride_k = stride_b_k; + fmha_args.batch_stride_v = stride_b_v; + + fmha_args.bias_ptr = bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr; + fmha_args.lse_ptr = lse_ptr; + fmha_args.o_ptr = o_ptr; + + fmha_args.seqstart_q_ptr = seqstart_q_ptr; + fmha_args.seqstart_k_ptr = seqstart_k_ptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.seqlen_k = max_seqlen_k; // unused in group mode (or kvcache enabled) + fmha_args.max_seqlen_q = max_seqlen_q; + + fmha_args.scale_s = scale_s; + + fmha_args.logits_soft_cap = logits_soft_cap; + + // bias is of shape [b, h , s_q, s_kv] + fmha_args.stride_bias = bias_type==bias_enum::alibi? 
0: max_seqlen_k; + fmha_args.stride_o = stride_s_o; + fmha_args.nhead_stride_bias = nhead_stride_bias; + fmha_args.batch_stride_bias = batch_stride_bias; + // softmax_lse is of shape [b, h, s_q] + fmha_args.nhead_stride_lse = max_seqlen_q; + fmha_args.batch_stride_lse = nhead * max_seqlen_q; + fmha_args.nhead_stride_o = stride_h_o; + fmha_args.batch_stride_o = stride_b_o; + + fmha_args.window_size_left = left; + fmha_args.window_size_right = right; + fmha_args.mask_type = static_cast(mask_type); + + fmha_args.rand_val_ptr = nullptr; + + fmha_args.stride_randval = max_seqlen_k; + fmha_args.nhead_stride_randval = 0; // Unused + fmha_args.batch_stride_randval = 0; + + fmha_args.p_drop = p_drop; + fmha_args.s_randval = 0; + fmha_args.use_asm_v3 = uses_fwd_v3; + fmha_args.how_v3_bf16_cvt = how_v3_bf16_cvt; + fmha_args.v3_api_check = false; + fmha_args.data_type = data_type_str; + fmha_args.is_group_mode = is_group_mode; + fmha_args.bias_type = static_cast(bias_type); + fmha_args.has_lse = lse_ptr!=nullptr; + fmha_args.qscale_type = static_cast(quant_scale_enum::no_scale); + fmha_args.has_sink = false; - // print ck traits and args when needed + // print ck traits and fmha_args when needed log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); - - float average_runtime = aiter::mha_fwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_lse, - quant_scale_enum::no_scale, - uses_fwd_v3, - false,//has_sink - how_v3_bf16_cvt); + float average_runtime = aiter::mha_fwd(fmha_args, stream_config); if(dump_path){ dump_fwd_timings(dump_path, average_runtime); } @@ -300,33 +290,33 @@ hipError_t ck_attn_fwd( return hipSuccess; } -hipError_t ck_attn_varlen_fwd( +hipError_t ck_attn_fwd( DType dtype, - uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, - uint64_t max_tokens_q, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, const void* q_ptr, - uint64_t stride_h_q, uint64_t stride_s_q, + uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, const void* k_ptr, - uint64_t stride_h_k, uint64_t stride_s_k, + uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, const void* v_ptr, - uint64_t stride_h_v, uint64_t stride_s_v, - const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, - const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, + uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, + const void* bias_ptr, + const void* alibi_slope_ptr, bool is_training, float scaling_factor, float dropout_probability, void* philox_seed_ptr, void* philox_offset_ptr, + BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, void* o_ptr, - uint64_t stride_h_o, uint64_t stride_s_o, - void* lse_thd_ptr, + uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, + void* lse_ptr, bool uses_fwd_v3, int how_v3_bf16_cvt, hipStream_t stream){ bool has_dropout = (is_training && dropout_probability > 0.f); - bool has_lse = (lse_thd_ptr != nullptr); + bool has_lse = (lse_ptr != nullptr); /* CK input parameters */ ck_tile::index_t batch = b; @@ -335,25 +325,23 @@ hipError_t ck_attn_varlen_fwd( ck_tile::index_t nhead_k = hg; ck_tile::index_t hdim_v = d_v; ck_tile::index_t max_seqlen_q = s_q; - ck_tile::index_t max_seqlen_kv = 
s_kv; - + ck_tile::index_t max_seqlen_k = s_kv; float scale_s = scaling_factor; float logits_soft_cap = 0.f; float p_drop = dropout_probability; - bool is_group_mode = true; bool is_v_rowmajor = true; bool has_logits_soft_cap = 0.f < logits_soft_cap; bool do_fp8_static_quant = false; - // THD does not work with bias + bias_enum bias_type; + BiasShape bias_shape; + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); ck_tile::index_t left, right; left = window_size_left; right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bias_enum bias_type = bias_enum::no_bias; - bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") @@ -363,121 +351,84 @@ hipError_t ck_attn_varlen_fwd( // print kernel name on verbose mode ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; - std::string data_type_str = get_data_type_str(dtype); - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias not used in THD qkv layout - const ck_tile::index_t stride_bias = 0; - // randval not used - const ck_tile::index_t stride_randval = 0; - const ck_tile::index_t stride_o = stride_s_o; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_bias = 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t nhead_stride_randval = 0; - // use packed lse of shape [h, max_tokens_q] - const ck_tile::index_t nhead_stride_lse = max_tokens_q; - const ck_tile::index_t nhead_stride_o = stride_h_o; - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = 0; - const ck_tile::index_t batch_stride_k = 0; - const ck_tile::index_t batch_stride_v = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_bias = 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t batch_stride_randval = 0; - const ck_tile::index_t batch_stride_lse = 0; - const ck_tile::index_t batch_stride_o = 0; - - return fmha_fwd_args{q_ptr, - k_ptr, - v_ptr, - nullptr,//bias_ptr - nullptr, //q_descale_ptr - nullptr, //k_descale_ptr - nullptr, //v_descale_ptr - nullptr,//rand_val_ptr - lse_thd_ptr, - o_ptr, - cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr - cu_seqlen_kv_padded_ptr==nullptr? 
cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr - nullptr, //seqlen_q_ptr - nullptr, //seqlen_k_ptr - cu_seqlen_q_ptr, //cu_seqlen_q_ptr - cu_seqlen_kv_ptr, //cu_seqlen_k_ptr - max_seqlen_q, //seqlen_q, unused in group mode - max_seqlen_kv, //seqlen_kv, unused in group mode - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - left, - right, - 0, // sink_size - static_cast(mask_type), - 0, // min_seqlen_q - p_drop, - false, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); - // modify the max_seqlen_q for better performance in 0-length cases - // lse_thd_ptr used as buffer - if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ - if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; - } - fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); - } - } - // print ck traits and args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); - - float average_runtime = aiter::mha_fwd( - fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_lse, - quant_scale_enum::no_scale, - uses_fwd_v3, - false,//has_sink - how_v3_bf16_cvt); + const ck_tile::index_t nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: (bias_shape==BiasShape::kBHSS? bias_h* max_seqlen_q * max_seqlen_k: max_seqlen_q*max_seqlen_k); + + aiter::mha_fwd_args fmha_args; + fmha_args.q_ptr = q_ptr; + fmha_args.k_ptr = k_ptr; + fmha_args.v_ptr = v_ptr; + + fmha_args.batch = batch; + fmha_args.seqlen_q = max_seqlen_q; // unused in group mode + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + + fmha_args.stride_q = stride_s_q; + fmha_args.stride_k = stride_s_k; + fmha_args.stride_v = stride_s_v; + fmha_args.nhead_stride_q = stride_h_q; + fmha_args.nhead_stride_k = stride_h_k; + fmha_args.nhead_stride_v = stride_h_v; + fmha_args.batch_stride_q = stride_b_q; + fmha_args.batch_stride_k = stride_b_k; + fmha_args.batch_stride_v = stride_b_v; + + fmha_args.bias_ptr = bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr; + fmha_args.lse_ptr = lse_ptr; + fmha_args.o_ptr = o_ptr; + + fmha_args.seqstart_q_ptr = nullptr; + fmha_args.seqstart_k_ptr = nullptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.seqlen_k = max_seqlen_k; // unused in group mode (or kvcache enabled) + fmha_args.max_seqlen_q = max_seqlen_q; + + fmha_args.scale_s = scale_s; + + fmha_args.logits_soft_cap = logits_soft_cap; + + // bias is of shape [b, h , s_q, s_kv] + fmha_args.stride_bias = bias_type==bias_enum::alibi? 
0: max_seqlen_k; + fmha_args.stride_o = stride_s_o; + fmha_args.nhead_stride_bias = nhead_stride_bias; + fmha_args.batch_stride_bias = batch_stride_bias; + // softmax_lse is of shape [b, h, s_q] + fmha_args.nhead_stride_lse = max_seqlen_q; + fmha_args.batch_stride_lse = nhead * max_seqlen_q; + fmha_args.nhead_stride_o = stride_h_o; + fmha_args.batch_stride_o = stride_b_o; + + fmha_args.window_size_left = left; + fmha_args.window_size_right = right; + fmha_args.mask_type = static_cast(mask_type); + + fmha_args.rand_val_ptr = nullptr; + + fmha_args.stride_randval = max_seqlen_k; + fmha_args.nhead_stride_randval = 0; // Unused + fmha_args.batch_stride_randval = 0; + + fmha_args.p_drop = p_drop; + fmha_args.s_randval = 0; + fmha_args.use_asm_v3 = uses_fwd_v3; + fmha_args.how_v3_bf16_cvt = how_v3_bf16_cvt; + fmha_args.v3_api_check = false; + fmha_args.data_type = data_type_str; + fmha_args.is_group_mode = false; + fmha_args.bias_type = static_cast(bias_type); + fmha_args.has_lse = lse_ptr!=nullptr; + fmha_args.qscale_type = static_cast(quant_scale_enum::no_scale); + fmha_args.has_sink = false; + + // print ck traits and fmha_args when needed + log_fwd_config(__FUNCTION__, data_type_str, false, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + float average_runtime = aiter::mha_fwd(fmha_args, stream_config); if(dump_path){ dump_fwd_timings(dump_path, average_runtime); } @@ -488,5 +439,58 @@ hipError_t ck_attn_varlen_fwd( return hipSuccess; } +hipError_t ck_attn_varlen_fwd( + DType dtype, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, + uint64_t max_tokens_q, + const void* q_ptr, + uint64_t stride_h_q, uint64_t stride_s_q, + const void* k_ptr, + uint64_t stride_h_k, uint64_t stride_s_k, + const void* v_ptr, + uint64_t stride_h_v, uint64_t stride_s_v, + const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, + bool is_training, + float scaling_factor, + float dropout_probability, + void* philox_seed_ptr, void* philox_offset_ptr, + MaskType attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + void* o_ptr, + uint64_t stride_h_o, uint64_t stride_s_o, + void* lse_thd_ptr, + bool uses_fwd_v3, + int how_v3_bf16_cvt, + hipStream_t stream){ + + return _ck_attn_fwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, 0, 0, + max_tokens_q, + q_ptr, 0, stride_h_q, stride_s_q, + k_ptr, 0, stride_h_k, stride_s_k, + v_ptr, 0, stride_h_v, stride_s_v, + nullptr, + nullptr, + cu_seqlen_q_ptr, cu_seqlen_kv_ptr, + cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, + is_training, + scaling_factor, + dropout_probability, + philox_seed_ptr, philox_offset_ptr, + BiasType::no_bias, + attn_mask_type, + window_size_left, window_size_right, + o_ptr, + 0, stride_h_o, stride_s_o, + lse_thd_ptr, + uses_fwd_v3, + how_v3_bf16_cvt, + true, + stream + ); +} + }//namespace ck_fused_attn From 0b0ad931d0a15aad9b5ddac407fe5539e3225388 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 3 Feb 2026 08:55:14 -0600 Subject: [PATCH 02/11] Initial implementation of refactor/API update across ALL CK funcs --- transformer_engine/__init__.py | 7 + .../include/ck_fused_attn/ck_fused_attn.hpp | 29 - .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 663 ++++++++---------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 185 ++--- 4 files changed, 364 insertions(+), 520 deletions(-) diff --git 
a/transformer_engine/__init__.py b/transformer_engine/__init__.py index 050abc8f7..3ff9f03eb 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -83,4 +83,11 @@ category=RuntimeWarning, ) +# Set AITER_ASM_DIR to point to the asm dir in the installed package +local_asm_dir = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", "3rdparty", "aiter", "hsa", +) +if "AITER_ASM_DIR" not in os.environ: + os.environ["AITER_ASM_DIR"] = local_asm_dir __version__ = str(metadata.version("transformer_engine")) diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 7a3ce1c99..54ee94786 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -37,35 +37,6 @@ enum class BiasType{ }; -hipError_t _ck_attn_fwd_impl( - DType dtype, - uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - uint64_t max_tokens_q, - const void* q_ptr, - uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, - uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, - uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, - const void* bias_ptr, - const void* alibi_slope_ptr, - const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, - const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, - bool is_training, - float scaling_factor, - float dropout_probability, - void* philox_seed_ptr, void* philox_offset_ptr, - BiasType attn_bias_type, - MaskType attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - void* o_ptr, - uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, - void* lse_ptr, - bool uses_fwd_v3, - int how_v3_bf16_cvt, - bool is_group_mode, - hipStream_t stream); - hipError_t ck_attn_fwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 3f51b96b6..3930cc06d 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -345,7 +345,7 @@ void log_bwd_config(const char* func_name, const bool uses_bwd_v3, const bool is_v3_atomic_fp32, const int how_v3_bf16_cvt, - const fmha_bwd_args& fmha_args){ + const aiter::mha_bwd_args& fmha_args){ bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { @@ -360,16 +360,16 @@ void log_bwd_config(const char* func_name, std::cout<<"hdim_q: "<::type>(mask_type)<::type>(bias_type)< 0.f); - bool has_dbias = dbias_ptr!=nullptr; + bool has_dbias = dbias_ptr != nullptr; bool is_mqa_gqa = (h > hg); /* CK input parameters */ @@ -518,17 +525,19 @@ hipError_t ck_attn_bwd( float scale_s = scaling_factor; float p_drop = dropout_probability; float p_undrop = 1.0 - p_drop; - bool is_group_mode = false; bool s_randval = false; - bias_enum bias_type; - BiasShape bias_shape; - std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + bias_enum bias_type = bias_enum::no_bias; + BiasShape bias_shape = 
BiasShape::k11SS; + if (!is_group_mode) { + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + } + ck_tile::index_t left, right; left = window_size_left; right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); + bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") @@ -539,163 +548,140 @@ hipError_t ck_attn_bwd( // print kernel name on verbose mode ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; - ck_tile::index_t shape_seqlen_q = seqlen_q; - ck_tile::index_t shape_seqlen_k = seqlen_k; - std::string data_type_str = get_data_type_str(dtype); - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias of shape (bias_b, bias_h, s_q, s_kv) - const ck_tile::index_t stride_bias = max_seqlen_k; - const ck_tile::index_t stride_o = stride_s_o; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_do = stride_s_do; - const ck_tile::index_t stride_dq = stride_s_dq; - const ck_tile::index_t stride_dk = stride_s_dk; - const ck_tile::index_t stride_dv = stride_s_dv; - const ck_tile::index_t stride_dk_expanded = stride_s_dk_expanded; - const ck_tile::index_t stride_dv_expanded = stride_s_dv_expanded; - const ck_tile::index_t stride_dq_acc = d_qk; //dq_acc of shape (nsplits, B, H, S, D) - // dbias is of the same shape as bias - // but ck only take dbias with BHSS - const ck_tile::index_t stride_dbias = max_seqlen_k; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias input can be of different shapes (11SS, 1HSS, B1SS, and BHSS), but dbias must be of BHSS - const ck_tile::index_t nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; - const ck_tile::index_t nhead_stride_o = stride_h_o; - const ck_tile::index_t nhead_stride_randval = - shape_seqlen_q * max_seqlen_k; - const ck_tile::index_t nhead_stride_do = stride_h_do; - const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; - const ck_tile::index_t nhead_stride_dq = stride_h_dq; - const ck_tile::index_t nhead_stride_dk = stride_h_dk; - const ck_tile::index_t nhead_stride_dv = stride_h_dv; - const ck_tile::index_t nhead_stride_dk_expanded = stride_h_dk_expanded; - const ck_tile::index_t nhead_stride_dv_expanded = stride_h_dv_expanded; - // dbias can only be of BHSS - const ck_tile::index_t nhead_stride_dbias = max_seqlen_q * max_seqlen_k; - const ck_tile::index_t nhead_stride_dq_acc = s_q*d_qk; //dq_acc of shape (nsplits, B, H, S, D) - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = stride_b_q; - const ck_tile::index_t batch_stride_k = stride_b_k; - const ck_tile::index_t batch_stride_v = stride_b_v; - // bias input can be of different shapes (11SS, 1HSS, B1SS, and BHSS), but dbias must be of BHSS - // for B1SS and BHSS, batch stride for bias are both bias_h x s_q x s_kv (bias_h==1 for B1SS and bias_h == h for BHSS) - const ck_tile::index_t batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 
0: bias_h* max_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_o = stride_b_o; - const ck_tile::index_t batch_stride_randval = - nhead * shape_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_do = stride_b_do; - const ck_tile::index_t batch_stride_lsed = nhead * max_seqlen_q; - const ck_tile::index_t batch_stride_dq = stride_b_dq; - const ck_tile::index_t batch_stride_dk = stride_b_dk; - const ck_tile::index_t batch_stride_dv = stride_b_dv; - const ck_tile::index_t batch_stride_dk_expanded = stride_b_dk_expanded; - const ck_tile::index_t batch_stride_dv_expanded = stride_b_dv_expanded; - // for dbias, use h since h can be different from bias_h - const ck_tile::index_t batch_stride_dbias = h* max_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_dq_acc = h*s_q*d_qk; //dq_acc of shape (nsplits, B, H, S, D) - const ck_tile::index_t split_stride_dq_acc = b * h * s_q * d_qk; - - return fmha_bwd_args{q_ptr, - k_ptr, - v_ptr, - bias_type==bias_enum::no_bias? nullptr : (bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr), - o_ptr, - lse_ptr, - do_ptr, - lse_workspace_ptr, - nullptr, - dq_ptr, - is_mqa_gqa? dk_expanded_ptr:dk_ptr, - is_mqa_gqa? dv_expanded_ptr:dv_ptr, - has_dbias? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr): nullptr, - dq_acc_ptr, //dq_acc_buf - nullptr,//seqstart_q_ptr - nullptr,//seqstart_k_ptr - nullptr, /* seqlen_q_ptr */ - nullptr, /* seqlen_k_ptr */ - nullptr, //cu_seqlen_q_ptr - nullptr, //cu_seqlen_k_ptr - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - max_seqlen_k, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - stride_q, - stride_k, - stride_v, - bias_type==bias_enum::alibi? 0: stride_bias, - stride_o, - stride_randval, - stride_do, - stride_dq_acc,//stride_dq_acc - stride_dq,//stride_dq - is_mqa_gqa? stride_dk_expanded:stride_dk, - is_mqa_gqa? stride_dv_expanded:stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, //nhead_stride_dq_acc - nhead_stride_dq, - is_mqa_gqa? nhead_stride_dk_expanded:nhead_stride_dk, - is_mqa_gqa? nhead_stride_dv_expanded:nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, //batch_stride_dq_acc - batch_stride_dq, - is_mqa_gqa? batch_stride_dk_expanded:batch_stride_dk, - is_mqa_gqa? 
batch_stride_dv_expanded:batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - left, - right, - static_cast(mask_type), - p_drop, - p_undrop, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); + aiter::mha_bwd_args fmha_args{}; + fmha_args.mask_type = static_cast(mask_type); + fmha_args.use_asm_v3 = uses_bwd_v3; + fmha_args.v3_atomic_fp32 = is_v3_atomic_fp32; + fmha_args.v3_bf16_cvt = how_v3_bf16_cvt; + fmha_args.v3_api_check = false; + + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.data_type = data_type_str; + fmha_args.is_group_mode = is_group_mode; + fmha_args.ck_mask_type = static_cast(mask_type); + fmha_args.bias_type = static_cast(bias_type); + fmha_args.has_dbias = (!is_group_mode) && has_dbias; + fmha_args.has_dropout = has_dropout; + fmha_args.is_store_randval = s_randval; + fmha_args.is_deterministic = deterministic; + + fmha_args.q_ptr = q_ptr; + fmha_args.k_ptr = k_ptr; + fmha_args.v_ptr = v_ptr; + fmha_args.bias_ptr = (bias_type==bias_enum::no_bias || is_group_mode) ? nullptr + : (bias_type==bias_enum::alibi? alibi_slope_ptr : bias_ptr); + fmha_args.o_ptr = o_ptr; + fmha_args.lse_ptr = is_group_mode ? lse_thd_ptr : lse_ptr; + fmha_args.do_ptr = do_ptr; + fmha_args.d_ptr = lse_workspace_ptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.dq_ptr = dq_ptr; + fmha_args.dk_ptr = is_mqa_gqa? dk_expanded_ptr:dk_ptr; + fmha_args.dv_ptr = is_mqa_gqa? dv_expanded_ptr:dv_ptr; + fmha_args.dbias_ptr = ((!is_group_mode) && has_dbias) + ? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr) + : nullptr; + fmha_args.dq_acc_ptr = dq_acc_ptr; + + if (is_group_mode) { + fmha_args.seqstart_q_ptr = cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr; + fmha_args.seqstart_k_ptr = cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr; + fmha_args.seqlen_q_ptr = nullptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.cu_seqlen_q_ptr = cu_seqlen_q_ptr; + fmha_args.cu_seqlen_k_ptr = cu_seqlen_kv_ptr; + } else { + fmha_args.seqstart_q_ptr = nullptr; + fmha_args.seqstart_k_ptr = nullptr; + fmha_args.seqlen_q_ptr = nullptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.cu_seqlen_q_ptr = nullptr; + fmha_args.cu_seqlen_k_ptr = nullptr; + } + + fmha_args.seqlen_q = is_group_mode ? max_seqlen_q : seqlen_q; + fmha_args.seqlen_k = is_group_mode ? max_seqlen_k : seqlen_k; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = max_seqlen_q; + fmha_args.max_seqlen_k = max_seqlen_k; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + fmha_args.scale = scale_s; + + // setup stride_* arguments + fmha_args.stride_q = stride_s_q; + fmha_args.stride_k = stride_s_k; + fmha_args.stride_v = stride_s_v; + fmha_args.stride_bias = (!is_group_mode && bias_type!=bias_enum::alibi) ? max_seqlen_k : 0; + fmha_args.stride_o = stride_s_o; + fmha_args.stride_randval = max_seqlen_k; + fmha_args.stride_do = stride_s_do; + fmha_args.stride_dq_acc = d_qk; + fmha_args.stride_dq = stride_s_dq; + fmha_args.stride_dk = is_mqa_gqa? stride_s_dk_expanded:stride_s_dk; + fmha_args.stride_dv = is_mqa_gqa? stride_s_dv_expanded:stride_s_dv; + fmha_args.stride_dbias = (!is_group_mode && bias_type!=bias_enum::alibi) ? max_seqlen_k : 0; + + // setup nhead_stride_* arguments + fmha_args.nhead_stride_q = stride_h_q; + fmha_args.nhead_stride_k = stride_h_k; + fmha_args.nhead_stride_v = stride_h_v; + fmha_args.nhead_stride_bias = (!is_group_mode && (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS)) + ? 
max_seqlen_q * max_seqlen_k + : 0; + fmha_args.nhead_stride_o = stride_h_o; + fmha_args.nhead_stride_randval = is_group_mode ? 0 : seqlen_q * max_seqlen_k; + fmha_args.nhead_stride_do = stride_h_do; + fmha_args.nhead_stride_lsed = is_group_mode ? max_tokens_q : max_seqlen_q; + fmha_args.nhead_stride_dq_acc = static_cast((is_group_mode ? max_tokens_q : s_q) * d_qk); + fmha_args.nhead_stride_dq = stride_h_dq; + fmha_args.nhead_stride_dk = is_mqa_gqa? stride_h_dk_expanded:stride_h_dk; + fmha_args.nhead_stride_dv = is_mqa_gqa? stride_h_dv_expanded:stride_h_dv; + fmha_args.nhead_stride_dbias = (!is_group_mode) ? max_seqlen_q * max_seqlen_k : 0; + + // setup batch_stride_* arguments + fmha_args.batch_stride_q = is_group_mode ? 0 : stride_b_q; + fmha_args.batch_stride_k = is_group_mode ? 0 : stride_b_k; + fmha_args.batch_stride_v = is_group_mode ? 0 : stride_b_v; + fmha_args.batch_stride_bias = (!is_group_mode && (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS)) + ? 0 + : (is_group_mode ? 0 : bias_h * max_seqlen_q * max_seqlen_k); + fmha_args.batch_stride_o = is_group_mode ? 0 : stride_b_o; + fmha_args.batch_stride_randval = is_group_mode ? 0 : nhead * seqlen_q * max_seqlen_k; + fmha_args.batch_stride_do = is_group_mode ? 0 : stride_b_do; + fmha_args.batch_stride_lsed = is_group_mode ? 0 : nhead * max_seqlen_q; + fmha_args.batch_stride_dq_acc = is_group_mode ? 0 : static_cast(h * s_q * d_qk); + fmha_args.batch_stride_dq = is_group_mode ? 0 : stride_b_dq; + fmha_args.batch_stride_dk = is_group_mode ? 0 : (is_mqa_gqa? stride_b_dk_expanded:stride_b_dk); + fmha_args.batch_stride_dv = is_group_mode ? 0 : (is_mqa_gqa? stride_b_dv_expanded:stride_b_dv); + fmha_args.batch_stride_dbias = is_group_mode ? 0 : h * max_seqlen_q * max_seqlen_k; + fmha_args.split_stride_dq_acc = static_cast(is_group_mode ? 
(max_tokens_q * h * d_qk) : (b * h * s_q * d_qk)); + + fmha_args.window_size_left = left; + fmha_args.window_size_right = right; + fmha_args.p_drop = p_drop; + fmha_args.p_undrop = p_undrop; + fmha_args.drop_seed_offset = std::pair{philox_seed_ptr, philox_offset_ptr}; + + // modify the max_seqlen_q for better performance in 0-length cases + // lse_thd_ptr used as buffer + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { + if(std::string(env_p) == "1"){ + if(ck_fused_attn_log_config){ + std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + } + fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); + fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); + } + } // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); - - float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_dbias, - s_randval, - deterministic, - uses_bwd_v3, - is_v3_atomic_fp32, - how_v3_bf16_cvt); + log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, fmha_args.has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + + float average_runtime = aiter::mha_bwd(fmha_args, stream_config); if(dump_path){ dump_bwd_timings(dump_path, average_runtime); } @@ -703,6 +689,111 @@ hipError_t ck_attn_bwd( //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); } + return hipSuccess; +} +hipError_t ck_attn_bwd( + DType dtype, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, + const void* q_ptr, + uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, + const void* k_ptr, + uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, + const void* v_ptr, + uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, + const void* bias_ptr, + const void* alibi_slope_ptr, + const void* o_ptr, + uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, + const void* lse_ptr, + const void* do_ptr, + uint64_t stride_b_do, uint64_t stride_h_do, uint64_t stride_s_do, + float scaling_factor, float dropout_probability, + void* philox_seed_ptr, void* philox_offset_ptr, + BiasType attn_bias_type, + MaskType attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + void* dq_ptr, + uint64_t stride_b_dq, uint64_t stride_h_dq, uint64_t stride_s_dq, + void* dq_acc_ptr, + void* dk_expanded_ptr, + void* dv_expanded_ptr, + uint64_t stride_b_dk_expanded, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, + uint64_t stride_b_dv_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, + void* dk_ptr, + uint64_t stride_b_dk, uint64_t stride_h_dk, uint64_t stride_s_dk, + void* dv_ptr, + uint64_t stride_b_dv, uint64_t stride_h_dv, uint64_t stride_s_dv, + void* dbias_expanded_ptr, + void* dbias_ptr, + void* lse_workspace_ptr, + bool deterministic, + bool uses_bwd_v3, + bool is_v3_atomic_fp32, + int how_v3_bf16_cvt, + hipStream_t stream){ + + bool has_dropout = (dropout_probability > 0.f); + bool has_dbias = dbias_ptr!=nullptr; + bool is_mqa_gqa = (h > hg); + 
bias_enum bias_type; + BiasShape bias_shape; + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + + bool ck_fused_attn_log_config = false; + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { + if (env_p != nullptr && std::string(env_p) == "1") + ck_fused_attn_log_config = true; + } + + hipError_t impl_status = _ck_attn_bwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, + bias_b, bias_h, + s_q, s_kv, + q_ptr, + stride_b_q, stride_h_q, stride_s_q, + k_ptr, + stride_b_k, stride_h_k, stride_s_k, + v_ptr, + stride_b_v, stride_h_v, stride_s_v, + bias_ptr, + alibi_slope_ptr, + nullptr, nullptr, + nullptr, nullptr, + o_ptr, + stride_b_o, stride_h_o, stride_s_o, + lse_ptr, + nullptr, + do_ptr, + stride_b_do, stride_h_do, stride_s_do, + scaling_factor, dropout_probability, + philox_seed_ptr, philox_offset_ptr, + attn_bias_type, + attn_mask_type, + window_size_left, window_size_right, + dq_ptr, + stride_b_dq, stride_h_dq, stride_s_dq, + dq_acc_ptr, + dk_expanded_ptr, + dv_expanded_ptr, + stride_b_dk_expanded, stride_h_dk_expanded, stride_s_dk_expanded, + stride_b_dv_expanded, stride_h_dv_expanded, stride_s_dv_expanded, + dk_ptr, + stride_b_dk, stride_h_dk, stride_s_dk, + dv_ptr, + stride_b_dv, stride_h_dv, stride_s_dv, + dbias_expanded_ptr, + dbias_ptr, + lse_workspace_ptr, + deterministic, + uses_bwd_v3, + is_v3_atomic_fp32, + how_v3_bf16_cvt, + false, + stream); + if (impl_status != hipSuccess) { + return impl_status; + } if(is_mqa_gqa){ dim3 grid(b, s_kv, hg); if (d_qk == d_v) { @@ -859,210 +950,62 @@ hipError_t ck_attn_varlen_bwd( bool is_v3_atomic_fp32, int how_v3_bf16_cvt, hipStream_t stream){ - - bool has_dropout = (dropout_probability > 0.f); - bool has_dbias = false; bool is_mqa_gqa = (h > hg); - /* CK input parameters */ - ck_tile::index_t batch = b; - ck_tile::index_t nhead = h; - ck_tile::index_t hdim_q = d_qk; - ck_tile::index_t nhead_k = hg; - ck_tile::index_t hdim_v = d_v; - ck_tile::index_t max_seqlen_q = s_q; - ck_tile::index_t max_seqlen_k = s_kv; - float scale_s = scaling_factor; - float p_drop = dropout_probability; - float p_undrop = 1.0 - p_drop; - bool is_group_mode = true; - bool s_randval = false; - - // THD does not work with bias - - ck_tile::index_t left, right; - left = window_size_left; - right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") ck_fused_attn_log_config = true; - } - const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); - // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; - - std::string data_type_str = get_data_type_str(dtype); - - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias not used in THD qkv layout - const ck_tile::index_t stride_bias = 0; - const ck_tile::index_t stride_o = stride_s_o; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_do = stride_s_do; - const ck_tile::index_t stride_dq = stride_s_dq; - const ck_tile::index_t stride_dk = stride_s_dk; - const ck_tile::index_t stride_dv = stride_s_dv; - const ck_tile::index_t stride_dk_expanded = stride_s_dk_expanded; - const ck_tile::index_t stride_dv_expanded = 
stride_s_dv_expanded; - const ck_tile::index_t stride_dq_acc = d_qk; //dq_acc of shape (nsplits, H, max_tokens_q, D_qk) - // bias not used in THD qkv layout - const ck_tile::index_t stride_dbias = 0; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_bias = 0; - const ck_tile::index_t nhead_stride_o = stride_h_o; - const ck_tile::index_t nhead_stride_randval = 0; - const ck_tile::index_t nhead_stride_do = stride_h_do; - // use packed lse - const ck_tile::index_t nhead_stride_lsed = max_tokens_q; - const ck_tile::index_t nhead_stride_dq = stride_h_dq; - const ck_tile::index_t nhead_stride_dk = stride_h_dk; - const ck_tile::index_t nhead_stride_dv = stride_h_dv; - const ck_tile::index_t nhead_stride_dk_expanded = stride_h_dk_expanded; - const ck_tile::index_t nhead_stride_dv_expanded = stride_h_dv_expanded; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_dbias = 0; - const ck_tile::index_t nhead_stride_dq_acc = max_tokens_q*d_qk; //dq_acc of shape (nsplits, H, max_tokens_q, D_qk) - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = 0; - const ck_tile::index_t batch_stride_k = 0; - const ck_tile::index_t batch_stride_v = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_bias = 0; - const ck_tile::index_t batch_stride_o = 0; - const ck_tile::index_t batch_stride_randval = 0; - const ck_tile::index_t batch_stride_do = 0; - const ck_tile::index_t batch_stride_lsed = 0; - const ck_tile::index_t batch_stride_dq = 0; - const ck_tile::index_t batch_stride_dk = 0; - const ck_tile::index_t batch_stride_dv = 0; - const ck_tile::index_t batch_stride_dk_expanded = 0; - const ck_tile::index_t batch_stride_dv_expanded = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_dbias = 0; - const ck_tile::index_t batch_stride_dq_acc = 0; //dq_acc of shape (nsplits, T, H, D) - const ck_tile::index_t split_stride_dq_acc = max_tokens_q*h*d_qk; - - return fmha_bwd_args{q_ptr, - k_ptr, - v_ptr, - nullptr, - o_ptr, - lse_thd_ptr, - do_ptr, - lse_workspace_ptr, - nullptr, - dq_ptr, - is_mqa_gqa? dk_expanded_ptr:dk_ptr, - is_mqa_gqa? dv_expanded_ptr:dv_ptr, - nullptr, //dbias_ptr - dq_acc_ptr, //dq_acc_buf - cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr - cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr - nullptr, /* seqlen_q_ptr */ - nullptr, /* seqlen_k_ptr */ - cu_seqlen_q_ptr, //cu_seqlen_q_ptr - cu_seqlen_kv_ptr, //cu_seqlen_k_ptr - max_seqlen_q, //seqlen_q, unused in group mode - max_seqlen_k, //seqlen_kv, unused in group mode - batch, - max_seqlen_q, - max_seqlen_k, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - stride_randval, - stride_do, - stride_dq_acc,//stride_dq_acc - stride_dq,//stride_dq - is_mqa_gqa? stride_dk_expanded:stride_dk, - is_mqa_gqa? stride_dv_expanded:stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, //nhead_stride_dq_acc - nhead_stride_dq, - is_mqa_gqa? nhead_stride_dk_expanded:nhead_stride_dk, - is_mqa_gqa? 
nhead_stride_dv_expanded:nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, //batch_stride_dq_acc - batch_stride_dq, - is_mqa_gqa? batch_stride_dk_expanded:batch_stride_dk, - is_mqa_gqa? batch_stride_dv_expanded:batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - left, - right, - static_cast(mask_type), - p_drop, - p_undrop, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); - - // modify the max_seqlen_q for better performance in 0-length cases - // lse_thd_ptr used as buffer - if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { - if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; - } - fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); - fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); - } } - // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); - - float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_enum::no_bias, - has_dbias, - s_randval, + hipError_t impl_status = _ck_attn_bwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, + 0, 0, + max_tokens_q, max_tokens_kv, + q_ptr, + 0, stride_h_q, stride_s_q, + k_ptr, + 0, stride_h_k, stride_s_k, + v_ptr, + 0, stride_h_v, stride_s_v, + nullptr, + nullptr, + cu_seqlen_q_ptr, cu_seqlen_kv_ptr, + cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, + o_ptr, + 0, stride_h_o, stride_s_o, + nullptr, + lse_thd_ptr, + do_ptr, + 0, stride_h_do, stride_s_do, + scaling_factor, dropout_probability, + philox_seed_ptr, philox_offset_ptr, + BiasType::no_bias, + attn_mask_type, + window_size_left, window_size_right, + dq_ptr, + 0, stride_h_dq, stride_s_dq, + dq_acc_ptr, + dk_expanded_ptr, + dv_expanded_ptr, + 0, stride_h_dk_expanded, stride_s_dk_expanded, + 0, stride_h_dv_expanded, stride_s_dv_expanded, + dk_ptr, + 0, stride_h_dk, stride_s_dk, + dv_ptr, + 0, stride_h_dv, stride_s_dv, + nullptr, + nullptr, + lse_workspace_ptr, deterministic, uses_bwd_v3, is_v3_atomic_fp32, - how_v3_bf16_cvt); - if(dump_path){ - dump_bwd_timings(dump_path, average_runtime); - } - if(average_runtime < 0){ - //TODO: better error out system - throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); + how_v3_bf16_cvt, + true, + stream); + if (impl_status != hipSuccess) { + return impl_status; } if(is_mqa_gqa){ dim3 grid(max_tokens_kv, hg); diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index b13f53fc7..571717540 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -28,7 +28,7 @@ void log_fwd_config(const char* func_name, const bool is_v_rowmajor, const bool do_fp8_static_quant, const bool uses_fwd_v3, - const bool how_v3_bf16_cvt, + const int how_v3_bf16_cvt, const aiter::mha_fwd_args& fmha_args){ bool ck_fused_attn_log_config = false; if (const char* env_p = 
std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { @@ -42,8 +42,8 @@ void log_fwd_config(const char* func_name, std::cout<::type>(mask_type)<{philox_seed_ptr, philox_offset_ptr}; fmha_args.use_asm_v3 = uses_fwd_v3; fmha_args.how_v3_bf16_cvt = how_v3_bf16_cvt; fmha_args.v3_api_check = false; @@ -276,6 +288,13 @@ hipError_t _ck_attn_fwd_impl( fmha_args.has_lse = lse_ptr!=nullptr; fmha_args.qscale_type = static_cast(quant_scale_enum::no_scale); fmha_args.has_sink = false; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.sink_size = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; // print ck traits and fmha_args when needed log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); @@ -315,128 +334,32 @@ hipError_t ck_attn_fwd( int how_v3_bf16_cvt, hipStream_t stream){ - bool has_dropout = (is_training && dropout_probability > 0.f); - bool has_lse = (lse_ptr != nullptr); - - /* CK input parameters */ - ck_tile::index_t batch = b; - ck_tile::index_t nhead = h; - ck_tile::index_t hdim_q = d_qk; - ck_tile::index_t nhead_k = hg; - ck_tile::index_t hdim_v = d_v; - ck_tile::index_t max_seqlen_q = s_q; - ck_tile::index_t max_seqlen_k = s_kv; - float scale_s = scaling_factor; - float logits_soft_cap = 0.f; - float p_drop = dropout_probability; - bool is_v_rowmajor = true; - bool has_logits_soft_cap = 0.f < logits_soft_cap; - bool do_fp8_static_quant = false; - - bias_enum bias_type; - BiasShape bias_shape; - std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); - - ck_tile::index_t left, right; - left = window_size_left; - right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); - - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); - // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; - - std::string data_type_str = get_data_type_str(dtype); - - const ck_tile::index_t nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: (bias_shape==BiasShape::kBHSS? bias_h* max_seqlen_q * max_seqlen_k: max_seqlen_q*max_seqlen_k); - - aiter::mha_fwd_args fmha_args; - fmha_args.q_ptr = q_ptr; - fmha_args.k_ptr = k_ptr; - fmha_args.v_ptr = v_ptr; - - fmha_args.batch = batch; - fmha_args.seqlen_q = max_seqlen_q; // unused in group mode - fmha_args.hdim_q = hdim_q; - fmha_args.hdim_v = hdim_v; - fmha_args.nhead_q = nhead; - fmha_args.nhead_k = nhead_k; - - fmha_args.stride_q = stride_s_q; - fmha_args.stride_k = stride_s_k; - fmha_args.stride_v = stride_s_v; - fmha_args.nhead_stride_q = stride_h_q; - fmha_args.nhead_stride_k = stride_h_k; - fmha_args.nhead_stride_v = stride_h_v; - fmha_args.batch_stride_q = stride_b_q; - fmha_args.batch_stride_k = stride_b_k; - fmha_args.batch_stride_v = stride_b_v; - - fmha_args.bias_ptr = bias_type==bias_enum::alibi? 
alibi_slope_ptr :bias_ptr; - fmha_args.lse_ptr = lse_ptr; - fmha_args.o_ptr = o_ptr; - - fmha_args.seqstart_q_ptr = nullptr; - fmha_args.seqstart_k_ptr = nullptr; - fmha_args.seqlen_k_ptr = nullptr; - fmha_args.seqlen_k = max_seqlen_k; // unused in group mode (or kvcache enabled) - fmha_args.max_seqlen_q = max_seqlen_q; - - fmha_args.scale_s = scale_s; - - fmha_args.logits_soft_cap = logits_soft_cap; - - // bias is of shape [b, h , s_q, s_kv] - fmha_args.stride_bias = bias_type==bias_enum::alibi? 0: max_seqlen_k; - fmha_args.stride_o = stride_s_o; - fmha_args.nhead_stride_bias = nhead_stride_bias; - fmha_args.batch_stride_bias = batch_stride_bias; - // softmax_lse is of shape [b, h, s_q] - fmha_args.nhead_stride_lse = max_seqlen_q; - fmha_args.batch_stride_lse = nhead * max_seqlen_q; - fmha_args.nhead_stride_o = stride_h_o; - fmha_args.batch_stride_o = stride_b_o; - - fmha_args.window_size_left = left; - fmha_args.window_size_right = right; - fmha_args.mask_type = static_cast(mask_type); - - fmha_args.rand_val_ptr = nullptr; - - fmha_args.stride_randval = max_seqlen_k; - fmha_args.nhead_stride_randval = 0; // Unused - fmha_args.batch_stride_randval = 0; - - fmha_args.p_drop = p_drop; - fmha_args.s_randval = 0; - fmha_args.use_asm_v3 = uses_fwd_v3; - fmha_args.how_v3_bf16_cvt = how_v3_bf16_cvt; - fmha_args.v3_api_check = false; - fmha_args.data_type = data_type_str; - fmha_args.is_group_mode = false; - fmha_args.bias_type = static_cast(bias_type); - fmha_args.has_lse = lse_ptr!=nullptr; - fmha_args.qscale_type = static_cast(quant_scale_enum::no_scale); - fmha_args.has_sink = false; - - // print ck traits and fmha_args when needed - log_fwd_config(__FUNCTION__, data_type_str, false, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); - float average_runtime = aiter::mha_fwd(fmha_args, stream_config); - if(dump_path){ - dump_fwd_timings(dump_path, average_runtime); - } - if(average_runtime < 0){ - //TODO: better error out system - throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); - } - return hipSuccess; + return _ck_attn_fwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, + 0, + q_ptr, stride_b_q, stride_h_q, stride_s_q, + k_ptr, stride_b_k, stride_h_k, stride_s_k, + v_ptr, stride_b_v, stride_h_v, stride_s_v, + bias_ptr, + alibi_slope_ptr, + nullptr, nullptr, + nullptr, nullptr, + is_training, + scaling_factor, + dropout_probability, + philox_seed_ptr, philox_offset_ptr, + attn_bias_type, + attn_mask_type, + window_size_left, window_size_right, + o_ptr, + stride_b_o, stride_h_o, stride_s_o, + lse_ptr, + uses_fwd_v3, + how_v3_bf16_cvt, + false, + stream + ); } hipError_t ck_attn_varlen_fwd( From c198cbd6b8deb49b9551e2fcd8071b49aeda59b3 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 6 Feb 2026 09:53:06 -0600 Subject: [PATCH 03/11] Updated logging --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 274 +++++++++--------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 203 ++++++------- 2 files changed, 225 insertions(+), 252 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 3930cc06d..dc7337409 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -333,125 +333,117 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and 
args passed into ck apis -void log_bwd_config(const char* func_name, - const std::string data_type_str, - const bool is_group_mode, - const mask_enum mask_type, - const bias_enum bias_type, - const bool has_dbias, - const bool has_dropout, - const bool is_store_randval, - const bool is_deterministic, - const bool uses_bwd_v3, - const bool is_v3_atomic_fp32, - const int how_v3_bf16_cvt, - const aiter::mha_bwd_args& fmha_args){ - - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } - if (ck_fused_attn_log_config) { - std::cout<::type>(mask_type)<::type>(bias_type)<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset)) + ); + log_value( + "dropout_offset_ptr", + std::get<1>(std::get>(fmha_args.drop_seed_offset)) + ); } void dump_bwd_timings(const char* dump_path, float average_runtime){ @@ -479,7 +471,6 @@ hipError_t _ck_attn_bwd_impl( const void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, const void* lse_ptr, - const void* lse_thd_ptr, const void* do_ptr, uint64_t stride_b_do, uint64_t stride_h_do, uint64_t stride_s_do, float scaling_factor, float dropout_probability, @@ -506,6 +497,8 @@ hipError_t _ck_attn_bwd_impl( bool is_v3_atomic_fp32, int how_v3_bf16_cvt, bool is_group_mode, + const char* func_name, + bool ck_log_config, hipStream_t stream){ bool has_dropout = (dropout_probability > 0.f); @@ -538,15 +531,10 @@ hipError_t _ck_attn_bwd_impl( right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; - } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_log_config}; std::string data_type_str = get_data_type_str(dtype); @@ -574,7 +562,7 @@ hipError_t _ck_attn_bwd_impl( fmha_args.bias_ptr = (bias_type==bias_enum::no_bias || is_group_mode) ? nullptr : (bias_type==bias_enum::alibi? alibi_slope_ptr : bias_ptr); fmha_args.o_ptr = o_ptr; - fmha_args.lse_ptr = is_group_mode ? 
lse_thd_ptr : lse_ptr; + fmha_args.lse_ptr = lse_ptr; fmha_args.do_ptr = do_ptr; fmha_args.d_ptr = lse_workspace_ptr; fmha_args.rand_val_ptr = nullptr; @@ -670,7 +658,7 @@ hipError_t _ck_attn_bwd_impl( // lse_thd_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(std::string(env_p) == "1"){ - if(ck_fused_attn_log_config){ + if(ck_log_config){ std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); @@ -679,7 +667,7 @@ hipError_t _ck_attn_bwd_impl( } // print ck traits and args when needed - log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, fmha_args.has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + log_bwd_config(func_name, fmha_args, ck_log_config); float average_runtime = aiter::mha_bwd(fmha_args, stream_config); if(dump_path){ @@ -739,10 +727,10 @@ hipError_t ck_attn_bwd( BiasShape bias_shape; std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); - bool ck_fused_attn_log_config = false; + bool ck_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; + ck_log_config = true; } hipError_t impl_status = _ck_attn_bwd_impl( @@ -763,7 +751,6 @@ hipError_t ck_attn_bwd( o_ptr, stride_b_o, stride_h_o, stride_s_o, lse_ptr, - nullptr, do_ptr, stride_b_do, stride_h_do, stride_s_do, scaling_factor, dropout_probability, @@ -790,6 +777,8 @@ hipError_t ck_attn_bwd( is_v3_atomic_fp32, how_v3_bf16_cvt, false, + __FUNCTION__, + ck_log_config, stream); if (impl_status != hipSuccess) { return impl_status; @@ -798,7 +787,7 @@ hipError_t ck_attn_bwd( dim3 grid(b, s_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ + if (ck_log_config){ std::cout<(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (ck_fused_attn_log_config){ + if (ck_log_config){ std::cout<(dbias_expanded_ptr), static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (ck_fused_attn_log_config){ + if (ck_log_config){ std::cout< hg); - bool ck_fused_attn_log_config = false; + bool ck_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; + ck_log_config = true; } hipError_t impl_status = _ck_attn_bwd_impl( @@ -975,7 +964,6 @@ hipError_t ck_attn_varlen_bwd( cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, o_ptr, 0, stride_h_o, stride_s_o, - nullptr, lse_thd_ptr, do_ptr, 0, stride_h_do, stride_s_do, @@ -1003,6 +991,8 @@ hipError_t ck_attn_varlen_bwd( is_v3_atomic_fp32, how_v3_bf16_cvt, true, + __FUNCTION__, + ck_log_config, stream); if (impl_status != hipSuccess) { return impl_status; @@ -1011,7 +1001,7 @@ hipError_t ck_attn_varlen_bwd( dim3 grid(max_tokens_kv, hg); if (d_qk == d_v) { dim3 block(d_qk); - if (ck_fused_attn_log_config){ + if (ck_log_config){ std::cout<::type>(mask_type)<::type>(bias_type)<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))<(std::get>(fmha_args.drop_seed_offset))); + log_value("dropout_offset_ptr", std::get<1>(std::get>(fmha_args.drop_seed_offset))); } void dump_fwd_timings(const char* dump_path, float average_runtime){ @@ -153,6 +136,7 @@ hipError_t 
_ck_attn_fwd_impl( bool uses_fwd_v3, int how_v3_bf16_cvt, bool is_group_mode, + const char* func_name, hipStream_t stream){ bool has_dropout = (is_training && dropout_probability > 0.f); @@ -170,23 +154,20 @@ hipError_t _ck_attn_fwd_impl( float scale_s = scaling_factor; float logits_soft_cap = 0.f; float p_drop = dropout_probability; - bool is_v_rowmajor = true; - bool has_logits_soft_cap = 0.f < logits_soft_cap; - bool do_fp8_static_quant = false; ck_tile::index_t left, right; left = window_size_left; right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - bool ck_fused_attn_log_config = false; + bool ck_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") - ck_fused_attn_log_config = true; + ck_log_config = true; } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_log_config}; std::string data_type_str = get_data_type_str(dtype); @@ -250,13 +231,13 @@ hipError_t _ck_attn_fwd_impl( fmha_args.logits_soft_cap = logits_soft_cap; // bias is of shape [b, h , s_q, s_kv] - fmha_args.stride_bias = bias_type==bias_enum::alibi? 0: max_seqlen_k; + fmha_args.stride_bias = is_group_mode? 0 : (bias_type==bias_enum::alibi? 0: max_seqlen_k); fmha_args.stride_o = stride_s_o; fmha_args.nhead_stride_bias = nhead_stride_bias; fmha_args.batch_stride_bias = batch_stride_bias; // softmax_lse is of shape [b, h, s_q] - fmha_args.nhead_stride_lse = max_seqlen_q; - fmha_args.batch_stride_lse = nhead * max_seqlen_q; + fmha_args.nhead_stride_lse = is_group_mode? max_tokens_q : max_seqlen_q; + fmha_args.batch_stride_lse = is_group_mode? 
0 : nhead * max_seqlen_q; fmha_args.nhead_stride_o = stride_h_o; fmha_args.batch_stride_o = stride_b_o; @@ -297,7 +278,7 @@ hipError_t _ck_attn_fwd_impl( fmha_args.block_scale_size_kv = 0; // print ck traits and fmha_args when needed - log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + log_fwd_config(func_name, has_dropout, fmha_args, ck_log_config); float average_runtime = aiter::mha_fwd(fmha_args, stream_config); if(dump_path){ dump_fwd_timings(dump_path, average_runtime); @@ -358,6 +339,7 @@ hipError_t ck_attn_fwd( uses_fwd_v3, how_v3_bf16_cvt, false, + __FUNCTION__, // func_name stream ); } @@ -411,6 +393,7 @@ hipError_t ck_attn_varlen_fwd( uses_fwd_v3, how_v3_bf16_cvt, true, + __FUNCTION__, // func_name stream ); } From a52bb322a58c01f4137d3a2dcaae82eaf6ec00cf Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 6 Feb 2026 11:42:16 -0600 Subject: [PATCH 04/11] Add script for comparing AITER/TE API --- tools/check_mha_fwd_args_usage.py | 81 +++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 tools/check_mha_fwd_args_usage.py diff --git a/tools/check_mha_fwd_args_usage.py b/tools/check_mha_fwd_args_usage.py new file mode 100644 index 000000000..65a6cc526 --- /dev/null +++ b/tools/check_mha_fwd_args_usage.py @@ -0,0 +1,81 @@ +import argparse +import re +from pathlib import Path +from typing import List, Set + +STRUCT_FIELD_RE = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") +STRUCT_END_RE = re.compile(r"^\s*};\s*$") + + +def extract_fields_from_header(text: str, struct_name: str) -> List[str]: + struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") + lines = text.splitlines() + in_struct = False + fields: List[str] = [] + buffer = "" + for line in lines: + if not in_struct: + if struct_start_re.search(line): + in_struct = True + continue + if STRUCT_END_RE.search(line): + break + # skip comments + stripped = line.strip() + if not stripped or stripped.startswith("//"): + continue + line_no_comment = re.sub(r"//.*", "", line) + buffer += " " + line_no_comment.strip() + if ";" not in line_no_comment: + continue + m = STRUCT_FIELD_RE.search(buffer) + if m: + fields.append(m.group(1)) + buffer = "" + return fields + + +def extract_usage_from_source(text: str, var_name: str) -> Set[str]: + assign_re = re.compile(rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b") + return set(assign_re.findall(text)) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") + parser.add_argument("--mode", choices=["fwd", "bwd"], required=True, help="Mode: fwd or bwd") + args = parser.parse_args() + + header_path = Path(f"3rdparty/aiter/csrc/include/mha_{args.mode}.h") + source_path = Path(f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{args.mode}.cpp") + header_text = header_path.read_text(encoding="utf-8") + source_text = source_path.read_text(encoding="utf-8") + + header_fields = extract_fields_from_header(header_text, f"mha_{args.mode}_args") + header_set = set(header_fields) + used_fields = extract_usage_from_source(source_text, f"fmha_args") + + missing_in_usage = sorted(header_set - used_fields) + unknown_in_header = sorted(used_fields - header_set) + + print(f"mha_{args.mode}_args fields in header:", len(header_set)) + print(f"mha_{args.mode}_args fields referenced in source:", len(used_fields)) + + if 
missing_in_usage: + print("\nFields present in header but not referenced in source:") + for name in missing_in_usage: + print(f" - {name}") + else: + print("\nAll header fields are referenced in source.") + + if unknown_in_header: + print("\nFields referenced in source but not in header:") + for name in unknown_in_header: + print(f" - {name}") + else: + print("\nNo unknown fields referenced in source.") + + return 0 + + +if __name__ == "__main__": + main() From 77f0a05a5f73f2a9fd46280d55fb6307cb1f571c Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 9 Feb 2026 12:12:02 -0600 Subject: [PATCH 05/11] Reconcile new AITER mask type --- .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index dc7337409..85ce6b489 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -16,6 +16,17 @@ namespace ck_fused_attn{ +int ck_to_aiter_mask_type(mask_enum mask_type, ck_tile::index_t left, ck_tile::index_t right){ + if( + mask_type == mask_enum::no_mask || + mask_type == mask_enum::window_generic + ) return 0; + if(left == -1 && right == 0){ + return mask_type == mask_enum::mask_top_left ? 1 : 2; + } + return 3; +} + // TODO: unify with binary search in TE/common/fused_attn(rocm)/util // no device std::upper_bound // in an increasing array with given size len, search for the index that: @@ -433,6 +444,7 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value("window_size_left", fmha_args.window_size_left); log_value("window_size_right", fmha_args.window_size_right); log_value("mask_type", fmha_args.mask_type); + log_value("ck_mask_type", fmha_args.ck_mask_type); log_value("bias_type", fmha_args.bias_type); log_value("p_drop", fmha_args.p_drop); log_value("p_undrop", fmha_args.p_undrop); @@ -539,7 +551,7 @@ hipError_t _ck_attn_bwd_impl( std::string data_type_str = get_data_type_str(dtype); aiter::mha_bwd_args fmha_args{}; - fmha_args.mask_type = static_cast(mask_type); + fmha_args.mask_type = ck_to_aiter_mask_type(mask_type, left, right); fmha_args.use_asm_v3 = uses_bwd_v3; fmha_args.v3_atomic_fp32 = is_v3_atomic_fp32; fmha_args.v3_bf16_cvt = how_v3_bf16_cvt; From 16372665c4aca27e0b0e5dcbae27d946422baaaf Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 9 Feb 2026 13:37:42 -0600 Subject: [PATCH 06/11] Updated API helper tool --- setup.py | 2 + tools/check_aiter_mha_args_usage.py | 96 +++++++++++++++++++ tools/check_mha_fwd_args_usage.py | 81 ---------------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 1 + 4 files changed, 99 insertions(+), 81 deletions(-) create mode 100644 tools/check_aiter_mha_args_usage.py delete mode 100644 tools/check_mha_fwd_args_usage.py diff --git a/setup.py b/setup.py index 1ae476311..13ae68793 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import List, Tuple import subprocess +import sys import setuptools from setuptools.command.egg_info import egg_info @@ -88,6 +89,7 @@ def setup_common_extension() -> CMakeExtension: cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF") elif os.getenv("NVTE_FUSED_ATTN_CK") or os.getenv("NVTE_FUSED_ATTN"): cmake_flags.append("-DUSE_FUSED_ATTN_CK=ON") + subprocess.run(sys.executable + " tools/check_aiter_mha_args_usage.py --mode both", shell=True, 
check=True) if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None: os.environ["NVTE_ENABLE_ROCSHMEM"] = '1' diff --git a/tools/check_aiter_mha_args_usage.py b/tools/check_aiter_mha_args_usage.py new file mode 100644 index 000000000..d483de5bd --- /dev/null +++ b/tools/check_aiter_mha_args_usage.py @@ -0,0 +1,96 @@ +import argparse +import re +from pathlib import Path +from typing import List, Set + +def parse_with_skip_comments(buffer, line, regex, outputs): + # skip comments + stripped = line.strip() + if not stripped or stripped.startswith("//"): + return + line_no_comment = re.sub(r"//.*", "", line) + buffer[0] += " " + line_no_comment.strip() + if ";" not in line_no_comment: + return + match = regex.search(buffer[0]) + if match: + outputs.append(match.group(1)) + buffer[0] = "" + + +def extract_fields_from_header(text: str, struct_name: str) -> List[str]: + struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") + struct_end_re = re.compile(r"^\s*};\s*$") + + struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") + lines = text.splitlines() + in_struct = False + fields: List[str] = [] + buffer = [""] + for line in lines: + if not in_struct: + if struct_start_re.search(line): + in_struct = True + continue + if struct_end_re.search(line): + break + parse_with_skip_comments(buffer, line, struct_field_re, fields) + return fields + + +def extract_usage_from_source(text: str, var_name: str) -> Set[str]: + assign_re = re.compile(rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=") + assignments = [] + lines = text.splitlines() + buffer = [""] + for line in lines: + parse_with_skip_comments(buffer, line, assign_re, assignments) + return set(assignments) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") + parser.add_argument("--mode", choices=["fwd", "bwd", "both"], required=True, help="Mode: fwd, bwd, or both") + args = parser.parse_args() + modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode] + mismatch = 0 + for mode in modes: + header_path = Path(f"3rdparty/aiter/csrc/include/mha_{mode}.h") + source_path = Path(f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp") + header_text = header_path.read_text(encoding="utf-8") + source_text = source_path.read_text(encoding="utf-8") + + header_fields = extract_fields_from_header(header_text, f"mha_{mode}_args") + header_set = set(header_fields) + used_fields = extract_usage_from_source(source_text, f"fmha_args") + + missing_in_usage = sorted(header_set - used_fields) + unknown_in_header = sorted(used_fields - header_set) + mismatch += len(missing_in_usage) + len(unknown_in_header) + + print(f"\nAnalyzing mha_{mode}_args\n") + print(f"mha_{mode}_args fields in header:", len(header_set)) + print(f"mha_{mode}_args fields referenced in source:", len(used_fields)) + + if missing_in_usage: + print("\nFields present in header but not referenced in source:") + for name in missing_in_usage: + print(f" - {name}") + else: + print("\nAll header fields are referenced in source.") + + if unknown_in_header: + print("\nFields referenced in source but not in header:") + for name in unknown_in_header: + print(f" - {name}") + else: + print("\nNo unknown fields referenced in source.") + + if mismatch: + print(f"\nTotal mismatched fields: {mismatch}") + return 1 + return 0 + + +if __name__ == "__main__": + main() diff --git a/tools/check_mha_fwd_args_usage.py 
b/tools/check_mha_fwd_args_usage.py deleted file mode 100644 index 65a6cc526..000000000 --- a/tools/check_mha_fwd_args_usage.py +++ /dev/null @@ -1,81 +0,0 @@ -import argparse -import re -from pathlib import Path -from typing import List, Set - -STRUCT_FIELD_RE = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") -STRUCT_END_RE = re.compile(r"^\s*};\s*$") - - -def extract_fields_from_header(text: str, struct_name: str) -> List[str]: - struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") - lines = text.splitlines() - in_struct = False - fields: List[str] = [] - buffer = "" - for line in lines: - if not in_struct: - if struct_start_re.search(line): - in_struct = True - continue - if STRUCT_END_RE.search(line): - break - # skip comments - stripped = line.strip() - if not stripped or stripped.startswith("//"): - continue - line_no_comment = re.sub(r"//.*", "", line) - buffer += " " + line_no_comment.strip() - if ";" not in line_no_comment: - continue - m = STRUCT_FIELD_RE.search(buffer) - if m: - fields.append(m.group(1)) - buffer = "" - return fields - - -def extract_usage_from_source(text: str, var_name: str) -> Set[str]: - assign_re = re.compile(rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b") - return set(assign_re.findall(text)) - - -def main() -> int: - parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") - parser.add_argument("--mode", choices=["fwd", "bwd"], required=True, help="Mode: fwd or bwd") - args = parser.parse_args() - - header_path = Path(f"3rdparty/aiter/csrc/include/mha_{args.mode}.h") - source_path = Path(f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{args.mode}.cpp") - header_text = header_path.read_text(encoding="utf-8") - source_text = source_path.read_text(encoding="utf-8") - - header_fields = extract_fields_from_header(header_text, f"mha_{args.mode}_args") - header_set = set(header_fields) - used_fields = extract_usage_from_source(source_text, f"fmha_args") - - missing_in_usage = sorted(header_set - used_fields) - unknown_in_header = sorted(used_fields - header_set) - - print(f"mha_{args.mode}_args fields in header:", len(header_set)) - print(f"mha_{args.mode}_args fields referenced in source:", len(used_fields)) - - if missing_in_usage: - print("\nFields present in header but not referenced in source:") - for name in missing_in_usage: - print(f" - {name}") - else: - print("\nAll header fields are referenced in source.") - - if unknown_in_header: - print("\nFields referenced in source but not in header:") - for name in unknown_in_header: - print(f" - {name}") - else: - print("\nNo unknown fields referenced in source.") - - return 0 - - -if __name__ == "__main__": - main() diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 6dd2ed352..78293411b 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -218,6 +218,7 @@ hipError_t _ck_attn_fwd_impl( fmha_args.seqstart_q_ptr = seqstart_q_ptr; fmha_args.seqstart_k_ptr = seqstart_k_ptr; fmha_args.seqlen_k_ptr = nullptr; + fmha_args.seqlen_q_ptr = nullptr; fmha_args.cu_seqlen_q_ptr = is_group_mode ? cu_seqlen_q_ptr : nullptr; fmha_args.cu_seqlen_k_ptr = is_group_mode ? 
cu_seqlen_kv_ptr : nullptr; fmha_args.block_scale_seqstart_q_ptr = nullptr; From 2cb6d82a77b7417d351c04b7fd22f20631ddaa5e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 9 Feb 2026 13:54:30 -0600 Subject: [PATCH 07/11] Formatting --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 94 +++++++++---------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 24 ++--- 2 files changed, 59 insertions(+), 59 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 85ce6b489..e1453aba5 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -56,12 +56,12 @@ __global__ void dk_dv_reduce( DataType *dv, //k,v, dk, dv guaranteed to have the same stride uint64_t stride_b_dkv, uint64_t stride_h_dkv, uint64_t stride_s_dkv){ - + uint64_t batch_idx = blockIdx.x; uint64_t seqlen_idx = blockIdx.y; uint64_t head_k_idx = blockIdx.z; uint64_t hdim_idx = threadIdx.x; - + // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -71,7 +71,7 @@ __global__ void dk_dv_reduce( assert(hdim_idx){ @@ -103,12 +103,12 @@ __global__ void dk_or_dv_reduce( DataType *dk_or_dv, //k,v, dk, dv guaranteed to have the same stride uint64_t stride_b_dk_or_dv, uint64_t stride_h_dk_or_dv, uint64_t stride_s_dk_or_dv){ - + uint64_t batch_idx = blockIdx.x; uint64_t seqlen_idx = blockIdx.y; uint64_t head_k_or_v_idx = blockIdx.z; uint64_t hdim_idx = threadIdx.x; - + // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -117,7 +117,7 @@ __global__ void dk_or_dv_reduce( assert(hdim_idx){ @@ -153,9 +153,9 @@ __global__ void dk_dv_reduce_thd( uint64_t seqlen_idx = blockIdx.x; uint64_t head_k_idx = blockIdx.y; uint64_t hdim_idx = threadIdx.x; - + assert(hdim_idx= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } @@ -175,7 +175,7 @@ __global__ void dk_dv_reduce_thd( uint64_t read_idx = head_k_idx*head_idx_offset*stride_h_dkv_expanded + seqlen_idx*stride_s_dkv_expanded + hdim_idx; uint64_t write_idx = head_k_idx*stride_h_dkv + seqlen_idx* stride_s_dkv + hdim_idx; - + for(uint64_t ii = 0; ii < head_idx_offset; ii++){ // bf16 requires special casting in CK if constexpr (std::is_same_v){ @@ -213,7 +213,7 @@ __global__ void dk_or_dv_reduce_thd( uint64_t seqlen_idx = blockIdx.x; uint64_t head_k_or_v_idx = blockIdx.y; uint64_t hdim_idx = threadIdx.x; - + assert(hdim_idx= *((cu_seqlen_kv_padded_ptr? 
cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ @@ -233,7 +233,7 @@ __global__ void dk_or_dv_reduce_thd( uint64_t read_idx = head_k_or_v_idx*head_idx_offset*stride_h_dk_or_dv_expanded + seqlen_idx*stride_s_dk_or_dv_expanded + hdim_idx; uint64_t write_idx = head_k_or_v_idx*stride_h_dk_or_dv + seqlen_idx* stride_s_dk_or_dv + hdim_idx; - + for(uint64_t ii = 0; ii < head_idx_offset; ii++){ // bf16 requires special casting in CK if constexpr (std::is_same_v){ @@ -259,7 +259,7 @@ __global__ void dbias_reduce_11ss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -289,7 +289,7 @@ __global__ void dbias_reduce_1hss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -319,7 +319,7 @@ __global__ void dbias_reduce_b1ss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -464,42 +464,42 @@ void dump_bwd_timings(const char* dump_path, float average_runtime){ file << average_runtime << "\n"; } -hipError_t _ck_attn_bwd_impl( +hipError_t _ck_attn_bwd_impl( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, uint64_t max_tokens_q, uint64_t max_tokens_kv, - const void* q_ptr, + const void* q_ptr, uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, const void* bias_ptr, const void* alibi_slope_ptr, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, - const void* o_ptr, + const void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, const void* lse_ptr, - const void* do_ptr, + const void* do_ptr, uint64_t stride_b_do, uint64_t stride_h_do, uint64_t stride_s_do, float scaling_factor, float dropout_probability, void* philox_seed_ptr, void* philox_offset_ptr, BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* dq_ptr, + void* dq_ptr, uint64_t stride_b_dq, uint64_t stride_h_dq, uint64_t stride_s_dq, void* dq_acc_ptr, void* dk_expanded_ptr, void* dv_expanded_ptr, uint64_t stride_b_dk_expanded, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, uint64_t stride_b_dv_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, - void* dk_ptr, + void* dk_ptr, uint64_t stride_b_dk, uint64_t stride_h_dk, uint64_t stride_s_dk, - void* dv_ptr, + void* dv_ptr, uint64_t stride_b_dv, uint64_t stride_h_dv, uint64_t stride_s_dv, void* dbias_expanded_ptr, void* dbias_ptr, @@ -691,37 +691,37 @@ hipError_t _ck_attn_bwd_impl( } return hipSuccess; } -hipError_t ck_attn_bwd( +hipError_t ck_attn_bwd( DType dtype, uint64_t b, 
uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - const void* q_ptr, + const void* q_ptr, uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, const void* bias_ptr, const void* alibi_slope_ptr, - const void* o_ptr, + const void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, - const void* lse_ptr, - const void* do_ptr, + const void* lse_ptr, + const void* do_ptr, uint64_t stride_b_do, uint64_t stride_h_do, uint64_t stride_s_do, float scaling_factor, float dropout_probability, void* philox_seed_ptr, void* philox_offset_ptr, BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* dq_ptr, + void* dq_ptr, uint64_t stride_b_dq, uint64_t stride_h_dq, uint64_t stride_s_dq, void* dq_acc_ptr, void* dk_expanded_ptr, void* dv_expanded_ptr, uint64_t stride_b_dk_expanded, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, uint64_t stride_b_dv_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, - void* dk_ptr, + void* dk_ptr, uint64_t stride_b_dk, uint64_t stride_h_dk, uint64_t stride_s_dk, - void* dv_ptr, + void* dv_ptr, uint64_t stride_b_dv, uint64_t stride_h_dv, uint64_t stride_s_dv, void* dbias_expanded_ptr, void* dbias_ptr, @@ -736,7 +736,7 @@ hipError_t ck_attn_bwd( bool has_dbias = dbias_ptr!=nullptr; bool is_mqa_gqa = (h > hg); bias_enum bias_type; - BiasShape bias_shape; + BiasShape bias_shape; std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); bool ck_log_config = false; @@ -883,7 +883,7 @@ hipError_t ck_attn_bwd( dbias_reduce_11ss, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ if (ck_log_config){ std::cout<, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ if (ck_log_config){ std::cout<, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); } } return hipSuccess; } -hipError_t ck_attn_varlen_bwd( +hipError_t ck_attn_varlen_bwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t max_tokens_q, uint64_t max_tokens_kv, - const void* q_ptr, + const void* q_ptr, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, - const void* o_ptr, + const void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, - const void* lse_thd_ptr, - const void* do_ptr, + const void* lse_thd_ptr, + const void* do_ptr, uint64_t stride_h_do, uint64_t stride_s_do, float scaling_factor, float dropout_probability, void* philox_seed_ptr, void* philox_offset_ptr, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* dq_ptr, + void* dq_ptr, uint64_t stride_h_dq, uint64_t stride_s_dq, void* 
dq_acc_ptr, void* dk_expanded_ptr, void* dv_expanded_ptr, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, - void* dk_ptr, + void* dk_ptr, uint64_t stride_h_dk, uint64_t stride_s_dk, - void* dv_ptr, + void* dv_ptr, uint64_t stride_h_dv, uint64_t stride_s_dv, void* lse_workspace_ptr, bool deterministic, diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 78293411b..8936ca34f 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -130,7 +130,7 @@ hipError_t _ck_attn_fwd_impl( BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* o_ptr, + void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_ptr, bool uses_fwd_v3, @@ -154,12 +154,12 @@ hipError_t _ck_attn_fwd_impl( float scale_s = scaling_factor; float logits_soft_cap = 0.f; float p_drop = dropout_probability; - + ck_tile::index_t left, right; left = window_size_left; right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); - + bool ck_log_config = false; if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { if (env_p != nullptr && std::string(env_p) == "1") @@ -277,7 +277,7 @@ hipError_t _ck_attn_fwd_impl( fmha_args.min_seqlen_q = 0; fmha_args.block_scale_size_q = 0; fmha_args.block_scale_size_kv = 0; - + // print ck traits and fmha_args when needed log_fwd_config(func_name, has_dropout, fmha_args, ck_log_config); float average_runtime = aiter::mha_fwd(fmha_args, stream_config); @@ -294,11 +294,11 @@ hipError_t _ck_attn_fwd_impl( hipError_t ck_attn_fwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - const void* q_ptr, + const void* q_ptr, uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, const void* bias_ptr, const void* alibi_slope_ptr, @@ -309,7 +309,7 @@ hipError_t ck_attn_fwd( BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* o_ptr, + void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_ptr, bool uses_fwd_v3, @@ -349,11 +349,11 @@ hipError_t ck_attn_varlen_fwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t max_tokens_q, - const void* q_ptr, + const void* q_ptr, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, @@ -363,7 +363,7 @@ hipError_t ck_attn_varlen_fwd( void* philox_seed_ptr, void* philox_offset_ptr, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* o_ptr, + void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_thd_ptr, bool uses_fwd_v3, From cf4aa9ee7525c9d91a0959870d772e46ca2e88ef Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: 
Mon, 9 Feb 2026 15:33:49 -0600
Subject: [PATCH 08/11] Added sys exit to script

---
 tools/check_aiter_mha_args_usage.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tools/check_aiter_mha_args_usage.py b/tools/check_aiter_mha_args_usage.py
index d483de5bd..760c017dc 100644
--- a/tools/check_aiter_mha_args_usage.py
+++ b/tools/check_aiter_mha_args_usage.py
@@ -2,6 +2,7 @@
 import re
 from pathlib import Path
 from typing import List, Set
+import sys
 
 def parse_with_skip_comments(buffer, line, regex, outputs):
@@ -93,4 +94,4 @@ def main() -> int:
 
 if __name__ == "__main__":
-    main()
+    sys.exit(main())

From e25cea8e42799f0e8d936658fe1d0beaf4c2be96 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Mon, 9 Feb 2026 15:55:23 -0600
Subject: [PATCH 09/11] Slightly better error message

---
 setup.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 0f4562772..5fe39220d 100644
--- a/setup.py
+++ b/setup.py
@@ -89,7 +89,14 @@ def setup_common_extension() -> CMakeExtension:
         cmake_flags.append("-DUSE_FUSED_ATTN_CK=OFF")
     elif os.getenv("NVTE_FUSED_ATTN_CK") or os.getenv("NVTE_FUSED_ATTN"):
         cmake_flags.append("-DUSE_FUSED_ATTN_CK=ON")
-        subprocess.run(sys.executable + " tools/check_aiter_mha_args_usage.py --mode both", shell=True, check=True)
+        try:
+            subprocess.run(
+                sys.executable + " tools/check_aiter_mha_args_usage.py --mode both",
+                shell=True, check=True
+            )
+        except subprocess.CalledProcessError:
+            print("Error checking AITER mha_args usage.")
+            sys.exit(1)
 
     if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))) and os.getenv("NVTE_ENABLE_ROCSHMEM") is None:
         os.environ["NVTE_ENABLE_ROCSHMEM"] = '1'

From 212247950488bbedb4276c71cd79e6facbd6735e Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Wed, 11 Feb 2026 10:00:32 -0600
Subject: [PATCH 10/11] Updated AITER_ASM_DIR implementation

---
 .gitignore | 1 +
 setup.py | 27 ++++++++++++--
 transformer_engine/__init__.py | 7 ----
 .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 35 +++++++++++++++++++
 4 files changed, 60 insertions(+), 10 deletions(-)

diff --git a/.gitignore b/.gitignore
index d3b18b358..ab0fa3618 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,3 +55,4 @@ artifacts/
 **/times.csv
 transformer_engine/build_info.txt
 transformer_engine/common/util/hip_nvml.*
+transformer_engine/lib/aiter

diff --git a/setup.py b/setup.py
index 5fe39220d..bee669a1e 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,8 @@
 from setuptools.command.build_ext import build_ext as BuildExtension
+from setuptools.command.editable_wheel import editable_wheel
+from setuptools.command.build import SubCommand
 
 os.environ["NVTE_PROJECT_BUILDING"] = "1"
@@ -57,6 +59,25 @@ def run(self):
         if not rocm_build():
             archs = cuda_archs()
 
+# A custom develop command only used for ROCm builds
+class EditableWheel(editable_wheel, SubCommand):
+    def run(self):
+        super().run()
+        if (
+            int(os.getenv("NVTE_FUSED_ATTN_CK", "1")) and
+            int(os.getenv("NVTE_FUSED_ATTN", "1"))
+        ):
+            # Ensure that the AITER ASM kernels are properly available at runtime
+            # by creating a symlink to them.
+            project_dir = Path(__file__).parent
+            asm_src_dir = project_dir / '3rdparty' / 'aiter' / 'hsa'
+            # Must be synced with
+            # TransformerEngine/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
+            asm_target_dir = project_dir / 'transformer_engine' / 'lib' / 'aiter'
+            if asm_src_dir.is_dir() and not asm_target_dir.is_dir():
+                print(f"Setting up symlink for AITER ASM kernels: {asm_target_dir} -> {asm_src_dir}")
+                asm_target_dir.symlink_to(asm_src_dir)
+
 class TimedBdist(bdist_wheel):
     """Helper class to measure build time"""
@@ -186,6 +207,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
 with open("README.rst", encoding="utf-8") as f:
     long_description = f.read()
 
+cmdclass = {"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
 # Settings for building top level empty package for dependency management.
 if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
     assert bool(
     ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
     te_cuda_vers = "rocm" if rocm_build() else "cu12"
     ext_modules = []
-    cmdclass = {}
     package_data = {}
     include_package_data = False
     install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],)
@@ -204,7 +225,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
 else:
     install_requires, test_requires = setup_requirements()
     ext_modules = [setup_common_extension()]
-    cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
+    cmdclass["editable_wheel"] = EditableWheel
     package_data = {
         "": ["VERSION.txt"],
         "transformer_engine.pytorch.triton_kernels.gmm": ["configs/*.json"],
@@ -250,7 +271,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
     long_description=long_description,
     long_description_content_type="text/x-rst",
     ext_modules=ext_modules,
-    cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
+    cmdclass=cmdclass,
     python_requires=">=3.8",
     classifiers=["Programming Language :: Python :: 3"],
     install_requires=install_requires,

diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py
index 3ff9f03eb..050abc8f7 100644
--- a/transformer_engine/__init__.py
+++ b/transformer_engine/__init__.py
@@ -83,11 +83,4 @@
     category=RuntimeWarning,
 )
 
-# Set AITER_ASM_DIR to point to the asm dir in the installed package
-local_asm_dir = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)),
-    "..", "3rdparty", "aiter", "hsa",
-)
-if "AITER_ASM_DIR" not in os.environ:
-    os.environ["AITER_ASM_DIR"] = local_asm_dir
 __version__ = str(metadata.version("transformer_engine"))

diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
index 26c92ca2b..74dd52d9d 100644
--- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
+++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
@@ -5,6 +5,9 @@
 ************************************************************************/
 #include
+#include <dlfcn.h>
+#include <filesystem>
+#include <mutex> //once_flag
 #include "ck_fused_attn_utils.hpp"
 #include "ck_fused_attn/ck_fused_attn.hpp"
 #include "mask.hpp"
@@ -13,6 +16,38 @@ namespace ck_fused_attn{
 
+void set_aiter_asm_dir() {
+  static std::once_flag aiter_asm_dir_once;
+  std::call_once(aiter_asm_dir_once, []() {
+    hipDeviceProp_t prop;
+    hipError_t res= hipGetDeviceProperties(&prop, 0);
+    if (res != hipSuccess) {
+      throw std::runtime_error(std::string(
"hipGetDeviceProperties failed with error: ") + hipGetErrorString(res)); + } + switch (prop.major*10 + prop.minor) { + case 94: // Gfx942 + case 95: // Gfx950 + break; + default: + // Unsupported V3 architecture + return; + } + Dl_info info; + dladdr((void*)set_aiter_asm_dir, &info); + setenv("AITER_ASM_DIR", + (std::filesystem::path(info.dli_fname).parent_path() / "aiter").c_str(), 1); + if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG")) { + if (std::string(env_p) == "1"){ + // Print the set environment variable for debugging purposes + std::cout << "AITER_ASM_DIR set to: " << getenv("AITER_ASM_DIR") << std::endl; + } + } + }); +} + +bool aiter_asm_dir_loaded = (set_aiter_asm_dir(), true); + std::string get_data_type_str(DType dtype){ std::string data_type_str; if(dtype==DType::kFloat16){ From 4817e7294a6b9c28aea7c336fbf0f922cf0aac9d Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 11 Feb 2026 10:46:40 -0600 Subject: [PATCH 11/11] Update AITER --- 3rdparty/aiter | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/aiter b/3rdparty/aiter index a64fa18e6..dd3bd2c93 160000 --- a/3rdparty/aiter +++ b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit a64fa18e60235994e4cbfd7059cc2f60d06e743f +Subproject commit dd3bd2c93ed2806f2f6ef88c80b025a510a7cf8a