From e54cb5a713036a3d019f66188ca7c62265fdeb63 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 6 Oct 2025 13:02:38 +0000 Subject: [PATCH 01/88] intial commit --- .../block/block_attention_bias_enum.hpp | 37 + .../unified_attention/block/block_dropout.hpp | 654 +++++++++ .../unified_attention/block/block_masking.hpp | 642 +++++++++ .../block/block_position_encoding.hpp | 205 +++ .../block/block_rotary_embedding.hpp | 108 ++ .../block/page_block_navigator.hpp | 358 +++++ .../ops/unified_attention/block/variants.hpp | 302 ++++ .../kernel/fmha_fwd_v3_kernel.hpp | 450 ++++++ .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 1258 +++++++++++++++++ ...ck_fmha_fwd_v3_pipeline_default_policy.hpp | 603 ++++++++ .../pipeline/block_fmha_pipeline_enum.hpp | 42 + .../pipeline/block_fmha_pipeline_problem.hpp | 60 + 12 files changed, 4719 insertions(+) create mode 100644 include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/block_dropout.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/block_masking.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp create mode 100644 include/ck_tile/ops/unified_attention/block/variants.hpp create mode 100644 include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp diff --git a/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp b/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp new file mode 100644 index 00000000000..e5be21e0489 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockAttentionBiasEnum +{ + NO_BIAS = 0, + ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale) + ALIBI = 2, // bias computed with position encoding, applied after scale +}; + +template +struct BlockAttentionBiasEnumToStr; + +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "bias"; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "alibi"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_dropout.hpp b/include/ck_tile/ops/unified_attention/block/block_dropout.hpp new file mode 100644 index 00000000000..8abdd54cd97 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_dropout.hpp @@ -0,0 +1,654 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and +// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, generated random +// numbers will be the same, they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host +// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp). +// +// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of +// random numbers (ph_subsequence). +// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and +// ph_offset). +// This means that subsequences are non-overlapping, reproducible and independent of mask or window. +// +// There are 3 modes (all produce the same results): +// * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates +// the entire 32x32 tile (64 * 16 = 32 * 32). +// * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4 +// warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock > +// MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions +// are needed for generating a 32x32 tile. +// * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2 +// warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp * +// WG::kM one warp can generate two 16x16 tiles. + +namespace detail { +// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values +constexpr index_t philox_per_tile = 64; +} // namespace detail + +struct NullBlockDropout +{ + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + (void)randval_dram_block_window_tmp; + (void)seqlen_qk_start; + + return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); + } +}; + +struct BlockDropout +{ + CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_, + bool is_store_randval_) + : ph_seed(amd_wave_read_first_lane(seed)), + ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * + detail::philox_per_tile)), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_), + is_store_randval(is_store_randval_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + constexpr index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; + + constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( + ck_tile::make_tuple(number{}, number{}, number{}), + ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto randval_lds_block_desc = transform_tensor_descriptor( + randval_lds_block_desc_0, + ck_tile::make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(ck_tile::make_tuple(number{}, number{}))), + ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), + ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); + + return randval_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t NIterPerWarp = 1; + + // The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution, + // because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + + // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. + constexpr auto randval_block_inner_part_dstr_encoding = + typename WarpGemmDispatcher::CWarpDstrEncoding{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + typename WG::CWarpDstrEncoding{}); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE void Run(void* randval_ptr, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // randval tile in LDS + auto randval_lds = make_tensor_view( + reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); + + auto randval_lds_window = make_tile_window( + randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + + const auto randval_lds_read_window = + make_tile_window(randval_lds_window.get_bottom_tensor_view(), + randval_lds_window.get_window_lengths(), + randval_lds_window.get_window_origin(), + MakeRandValLdsShuffleTileDistribution()); + + const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + auto generate_randval = [&](auto i_m0, auto i_n0) { + // Generate random numbers + uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; + const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; + const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); + if constexpr(IsWG32) + { + // Generate the whole 32x32 tile at once (each tile consists of random numbers taken + // from a separate subsequence of Philox) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0, wg_n0)); + const index_t ph_offset = get_lane_id(); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + else + { + // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether + // MIterPerWarp is equal to 1 or 2) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); + const index_t subtile_m0 = wg_m0 % 2; + if constexpr(get_warp_size() == 32) + { + const index_t ph_offset = (get_lane_id() & 15) + + (((get_lane_id() >> 4) & 1) << 5) + + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + } + else + { + const index_t subtile_n0 = (get_lane_id() >> 4) & 1; + const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); + ph.get_random_4x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); + } + } + } + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // Transpose randval using LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + const auto randval = load_tile(randval_lds_read_window); + block_sync_lds(); + return randval; + }; + + if(is_store_randval) + { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + const auto randval = generate_randval(i_m0, i_n0); + // save to Global + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + }); + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + }); + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + } + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + const auto randval = generate_randval(i_m0, i_n0); + // Drop values of P based on the generated probabilities + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = + tile_distributed_index()>{}; + constexpr auto p_idx1 = + tile_distributed_index(), + idx1.impl_.template at<2>()>{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); + }); + }); + }); + } + + const unsigned long long ph_seed; + const unsigned long long ph_head_offset; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; + const bool is_store_randval; +}; + +// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be +// replaced with NullBlockDropout. This requires changes in xformers and other libs. +template +struct BlockDropoutBwd; + +template +struct BlockDropoutBwd +{ + static constexpr bool IsDropout = false; + static constexpr bool IsStoreRandval = IsStoreRandval_; + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + (void)randval_dram_block_window_tmp; + (void)seqlen_qk_start; + + return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); + } +}; + +template +struct BlockDropoutBwd +{ + static constexpr bool IsDropout = true; + static constexpr bool IsStoreRandval = IsStoreRandval_; + + CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_) + : ph_seed(amd_wave_read_first_lane(seed)), + ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * + detail::philox_per_tile)), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + + constexpr auto randval_block_inner_part_dstr_encoding = + typename WarpGemmDispatcher::CWarpDstrEncoding{}; + static_assert( + std::is_same_v, + typename WG::CWarpDstrEncoding>); + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr bool IsWG32 = WG::kM == 32; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; + constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + auto generate_randval = [&](auto i_m0, auto i_n0) { + // Generate random numbers + uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; + const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; + const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); + if constexpr(IsWG32) + { + // Generate the whole 32x32 tile at once (each tile consists of random numbers + // taken from a separate subsequence of Philox) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0, wg_n0)); + const index_t ph_offset = get_lane_id(); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + else + { + // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether + // MIterPerWarp is equal to 1 or 2) + const unsigned long long ph_subsequence = + bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); + const index_t subtile_m0 = wg_m0 % 2; + if constexpr(get_warp_size() == 32) + { + const index_t ph_offset = (get_lane_id() & 15) + + (((get_lane_id() >> 4) & 1) << 5) + + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + ph.get_random_16x8(random_uint8_t, ph_subsequence); + } + } + else + { + const index_t subtile_n0 = (get_lane_id() >> 4) & 1; + const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); + const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); + if constexpr(MIterPerWarp == 1) + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); + ph.get_random_4x8( + random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); + } + else + { + static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); + ph.get_random_8x8( + random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); + } + } + } + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + return randval_dist_generated; + }; + + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + const auto randval = generate_randval(i_m0, i_n0); + // Drop values of P based on the generated probabilities, negative sign is used to + // distinguish such values ​​later in bwd pipeline. + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + constexpr auto p_idx0 = + tile_distributed_index(), + idx0.impl_.template at<1>(), + idx0.impl_.template at<2>()>{}; + constexpr auto p_idx1 = tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] + : -p_compute[p_idx]; + }); + }); + // save to Global + if constexpr(IsStoreRandval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {kMPerStep, 0}); + } + }); + if constexpr(IsStoreRandval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); + } + }); + if constexpr(IsStoreRandval) + { + move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); + } + } + + const unsigned long long ph_seed; + const unsigned long long ph_head_offset; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp new file mode 100644 index 00000000000..2c45945fac0 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -0,0 +1,642 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +enum struct GenericAttentionMaskEnum +{ + NO_MASK = 0, + + // below enum could be causal, or sliding window + MASK_FROM_TOP_LEFT = 1, + MASK_FROM_BOTTOM_RIGHT = 2, + + // this enum maybe not used by xformer/FA, since it's hard to + // specify left/right window for varlen case. put it here for + // debug purpose + MASK_GENERIC, +}; + +// clang-format off +/* generic Attention Mask Coordinate + use x(horizontal axis), y(vertical axis) to describe mask. + top-left corner is origin + + x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask) + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + 1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + l=7,-1/r=0(tl) l=7,-1/r=0(br) + + x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2 + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 + * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1 + l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl) + l=4/r=0(br) l=4/r=2(br) l=4/r=4(br) + + x=4/y=-1 x=6/y=-1 x=8/y=-1 + * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 + * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 + * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1 + * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1 + + x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r) + * * * * * * * * 1 * * * * * * * + * * * * * * * * 1 1 * * 1 * * * + * * * * * * * * 1 1 1 * 1 1 * * + 1 * * * * * * * 1 1 1 1 1 1 1 * + 1 1 * * * * * * 1 1 1 1 1 1 1 1 + + Validations: + x + y > 1 (x + y >= 2) + + Note: + y = seq_q, x = 1 -> top-left + y = seq_q, x = seq_k - seq_q + 1 -> bottom-right + y < seq_q, x < seq_k -> local-attn + y = seq_q, x = seq_k -> no mask + +*/ +namespace impl { + template struct MaskName; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mn"; }; + template<> struct MaskName { static constexpr const char * name = "mc"; }; + template<> struct MaskName { static constexpr const char * name = "mg"; }; +} +// clang-format on + +template +struct GenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask, + // else only upper-right could have mask + + static constexpr const char* name = impl::MaskName::name; + + CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_) + : GenericAttentionMask(0, 0, y_total_, x_total_) + { + } + + CK_TILE_HOST_DEVICE + GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.at(number<0>{})), + x(mask_coord.at(number<1>{})), + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + if constexpr(IsLocal) + { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + } + else + { + return 0; + } + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; + index_t x_end = min(i_y + x, x_total); + + if constexpr(IsLocal) + { + return i_x < x_start || i_x >= x_end; + } + else + { + return i_x >= x_end || i_y >= y_total; + } + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number, number) const + { + if constexpr(!IsMasking) + { + // TODO: no need to check begin + return (i_tile_left + TileWidth) > x_total; + } + else + { + if constexpr(IsLocal) + { + // check top-right corner > x or left-borrom corner < x + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + index_t x_end = min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > (i_tile_top + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = + i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + } + else + { + // only need to check top-right corner > x + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > x_end; + return top_right_edge; + } + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + +// clang-format off +namespace impl { + template struct SimplifiedMaskName; + template<> struct SimplifiedMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedMaskName { static constexpr const char * name = "mask"; }; +} +// clang-format on + +// this version only have 2 variation: masking and non-masking +// This is more friendly to codegen (e.g. need generate less kernel) +// ... with the trade-off that may have more instruction in causal mode +template +struct SimplifiedGenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + + static constexpr const char* name = impl::SimplifiedMaskName::name; + + CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) + : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) + { + } + + CK_TILE_HOST_DEVICE + SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.at(number<0>{})), + x(mask_coord.at(number<1>{})), + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + template + CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + + const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); + const index_t split_start = x_per_split * i_split; + const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); + + return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), + ck_tile::min(origin_end, split_end)); + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + return i_x >= x_total; + } + else + { + index_t x_start = -y + i_y + 1; // this could be negative, but it's fine + index_t x_end = min(i_y + x, x_total); // need min in case x is padded + + return i_x < x_start || i_x >= x_end || i_y >= y_total; + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + // return (i_x < x_total) && ((i_x + TileWidth) > x_total); + + // TODO: no need to check begin + return (i_x + TileWidth) > x_total; + } + else + { + // check top-right corner > x or left-borrom corner < x + index_t i_x_end = i_x + TileWidth; + index_t i_y_end = i_y + TileHeight; + // index_t x_end = min(i_y + x, x_total); + + bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad + bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad + // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + +// clang-format off +namespace impl { + template struct SimplifiedRatioMaskName; + template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "nomask"; }; + template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "mask"; }; +} +// clang-format on + +// this version is used for cases that the step length of y-direction changes greater than one. It +// means that the mask is not a regular triangular matrix. + +// clang-format off +/* y_ratio is used to describe the step length of y-direction changes + in certain performance optimization scenarios like merging seqlen + and qk_head_ratio, for example: + + x=1/y=6/y_ratio=2(top-left) + 1 * * * * * * * + 1 * * * * * * * + 1 1 * * * * * * + 1 1 * * * * * * + 1 1 1 * * * * * + 1 1 1 * * * * * + +*/ +// clang-format on +template +struct SimplifiedRatioAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + + static constexpr const char* name = impl::SimplifiedRatioMaskName::name; + + CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_) + : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{}) + { + } + + CK_TILE_HOST_DEVICE + SimplifiedRatioAttentionMask( + index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_) + : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast(y_ratio_mdiv_.get()), + /*x_=*/x_, + /*y_total_=*/y_total_, + /*x_total_=*/x_total_, + /*y_real_=*/y_real_, + /*y_ratio_=*/static_cast(y_ratio_mdiv_.get()), + /*y_ratio_mdiv_=*/y_ratio_mdiv_) + + { + } + CK_TILE_HOST_DEVICE + SimplifiedRatioAttentionMask(index_t y_, + index_t x_, + index_t y_total_, + index_t x_total_, + index_t y_real_, + index_t y_ratio_, + mdiv y_ratio_mdiv_) + : y(y_), + x(x_), + y_total(y_total_), + x_total(x_total_), + y_real(y_real_), + y_ratio(y_ratio_), + y_ratio_mdiv(y_ratio_mdiv_) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = -y_real + + static_cast(y_ratio_mdiv.div(static_cast(i_y))) + + 1; + + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + uint32_t y_offset = i_y + YTile - 1; + index_t tmp = min(static_cast(y_ratio_mdiv.div(y_offset)) + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck_tile::make_tuple(x_start, x_end); + } + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max((-x + i_x + 1) * y_ratio, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) + CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return i_x >= x_total; + } + else + { + index_t x_tmp = static_cast(y_ratio_mdiv.div(static_cast(i_y))); + index_t x_start = -y_real + x_tmp + 1; + index_t x_end = min(x_tmp + x, + x_total); // need min in case x is padded + return i_x < x_start || i_x >= x_end || i_y >= y_total; + } + } + + // if current tile is at the edge, means need per-pixel mask check. + // otherwise no need to check per-pixel + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // can be used as a fast-path to decide if do per-pixel check or not + template + CK_TILE_HOST_DEVICE constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + // the only case that need do following compare is under kPadSeqLenK + // ... for non-masking kernel. + // return (i_x < x_total) && ((i_x + TileWidth) > x_total); + + return (i_x + TileWidth) > x_total; + } + else + { + // check top-right corner > x or left-borrom corner < x + index_t i_x_end = i_x + TileWidth; + index_t i_y_end = i_y + TileHeight; + // index_t x_end = min(i_y + x, x_total); + uint32_t y_tmp = static_cast(i_y); + bool top_right_edge = i_x_end > min(static_cast(y_ratio_mdiv.div(y_tmp)) + x, + x_total); // consider right pad + bool bottom_left_edge = + i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad + return top_right_edge || bottom_left_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; + // y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y + index_t y_real; + index_t y_ratio; + mdiv y_ratio_mdiv; +}; + +// TODO: prefer use this function in host code +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + // TODO: below should all use sgpr arithmetic + index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1; + index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1; + + left_size = left_size < 0 ? left_size_tmp : left_size; + right_size = right_size < 0 ? right_size_tmp : right_size; + + index_t x_tmp = is_top_left ? 0 : x_total - y_total; + index_t y_tmp = is_top_left ? 0 : y_total - x_total; + + index_t x = 1 + right_size + x_tmp; + index_t y = 1 + left_size + y_tmp; + + return ck_tile::make_tuple(y, x, y_total, x_total); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + auto r = make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, is_top_left); + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp new file mode 100644 index 00000000000..703ec0967ab --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include +#include + +namespace ck_tile { + +enum struct PositionEncodingEnum +{ + NO = 0, + ALIBI = 1, +}; + +/* +VERTICAL: + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + +TOP_LEFT(but negative): + [0] 1 2 3 4 5 + 1 [0] 1 2 3 4 + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + +FROM_BOTTOM_RIGHT(but negative): + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + 4 3 2 1 [0] 1 + 5 4 3 2 1 [0] +*/ + +enum struct AlibiMode +{ + VERTICAL = 0, + FROM_TOP_LEFT = 1, // keep sync with mask enum + FROM_BOTTOM_RIGHT = 2, +}; + +template +struct Alibi +{ + static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32, + "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t"); + + // RowMajor here means if pixel within the same thread are along the row, or col + // this may impact the performance of update(), while the result are the same. + // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false + CK_TILE_HOST_DEVICE Alibi(DataType slope_, + index_t y_total_, + index_t x_total_, + AlibiMode mode_ = AlibiMode::VERTICAL) + { + slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_; + + shift_left_up = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + }(); + shift_right_down = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + }(); + mode = mode_; + } + + CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); } + + CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) + { + if constexpr(LogMaxSadOprndSize <= 16) + { + return sad_u16( + static_cast(x), static_cast(y), static_cast(acc)); + } + + return sad_u32(x, y, acc); + } + + CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx) + { + if constexpr(RowMajor) + { + // at least 3 instructions per row + index_t current_zero_point = + mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down; + + // for every threads, most of the pixels are along the row, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(col_idx + shift_left_up), + 0)); + pixel += slope * position; + } + else + { + // at least 3 instructions per col; + index_t current_zero_point = mode == AlibiMode::VERTICAL + ? row_idx + col_idx + shift_right_down + : col_idx + shift_right_down; + + // for every threads, most of the pixels are along the col, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(row_idx + shift_left_up), + 0)); + pixel += slope * position; + } + } + + DataType slope; // float? + index_t shift_left_up; // always possitive + index_t shift_right_down; // always possitive + AlibiMode mode; +}; + +template +struct EmptyPositionEncoding +{ + CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/) + { + } +}; + +// +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +template +CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, + index_t window_left_size, + index_t window_right_size, + index_t y_total, + index_t x_total, + GenericAttentionMaskEnum mask_enum) +{ + // assume mask_enum will never be NO_MASK, since if we do not have mask, it's + // totally OK to use constexpr + bool is_causal = window_left_size < 0 && window_right_size == 0; + AlibiMode alibi_mode = + is_causal ? AlibiMode::VERTICAL + : static_cast(mask_enum) /*either top-left or bottom-right*/; + return Alibi{slope, y_total, x_total, alibi_mode}; +} + +// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 +// Do we need a device version? +template +CK_TILE_HOST std::vector get_alibi_slopes(ck_tile::index_t nheads) +{ + auto get_slopes_power_of_2 = [](ck_tile::index_t n) { + float start = std::powf( + static_cast(2), + -std::powf(static_cast(2), -static_cast((integer_log2_floor(n) - 3)))); + + std::vector rtn; + for(auto i = 0; i < n; i++) + { + rtn.push_back(static_cast(start * std::powf(start, i))); + } + return rtn; + }; + if(is_power_of_two_integer(nheads)) + { + // power of 2 calculation + return get_slopes_power_of_2(nheads); + } + else + { + ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads); + auto v0 = get_slopes_power_of_2(closest_power_of_2); + auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2); + auto v1_sliced = [&](auto vec, ck_tile::index_t rem) { + std::vector sliced; + for(ck_tile::index_t i = 0; i < static_cast(vec.size()); i++) + { + if(i % 2 == 0) + sliced.push_back(vec[i]); + } + std::vector sliced_2(sliced.begin(), sliced.begin() + rem); + return sliced_2; + }(v1, nheads - closest_power_of_2); + v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end()); + return v0; + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp b/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp new file mode 100644 index 00000000000..51732792990 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class RotaryEmbeddingEnum +{ + NONE = 0, + INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc + HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc +}; + +template +struct RotaryEmbeddingEnumToStr; + +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = "inter"; +}; +template <> +struct RotaryEmbeddingEnumToStr +{ + static constexpr const char* name = "half"; +}; + +template +struct BlockRotaryEmbedding +{ + template + CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile, + OtherDramBlockWindow other_window, + RotaryCosDramBlockWindow rotary_cos_window, + RotarySinDramBlockWindow rotary_sin_window, + index_t rotary_dim, + index_t thread_end) + { + using DataType = typename remove_cvref_t::DataType; + + if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) + { + auto rotary_cos_tile = load_tile(rotary_cos_window); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + if(thread_end <= rotary_dim) + { + constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); + static_for<0, thread_buffer_size, 2>{}([&](auto idx) { + const auto left = type_convert(tile.thread_buf_[idx]); + const auto right = type_convert(tile.thread_buf_[idx + 1]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx / 2]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx / 2]); + + tile.thread_buf_[idx] = type_convert(left * cos - right * sin); + tile.thread_buf_[idx + 1] = type_convert(right * cos + left * sin); + }); + } + } + else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) + { + if(thread_end <= rotary_dim) + { + const bool is_left = (thread_end <= (rotary_dim / 2)); + + move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); + auto other_tile = load_tile(other_window); + + move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_cos_tile = load_tile(rotary_cos_window); + + move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); + auto rotary_sin_tile = load_tile(rotary_sin_window); + + constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); + static_for<0, thread_buffer_size, 1>{}([&](auto idx) { + const auto curr = type_convert(tile.thread_buf_[idx]); + const auto other = type_convert(other_tile.thread_buf_[idx]); + + const auto cos = + type_convert(rotary_cos_tile.thread_buf_[idx]); + const auto sin = + type_convert(rotary_sin_tile.thread_buf_[idx]); + + tile.thread_buf_[idx] = + type_convert(curr * cos + other * (is_left ? -sin : sin)); + }); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp b/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp new file mode 100644 index 00000000000..f1e6101d1d4 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp @@ -0,0 +1,358 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" + +namespace ck_tile { + +// assume that we have only 1 page-block/tensor view +template +struct TrivialPageBlockNavigator +{ + using DataType = typename TensorView::DataType; + using WindowOrigin = multi_index<2>; + + CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_) + : tensor_view(tensor_view_) + { + } + + template + CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin) const + { + return make_tuple(/*block_index=*/0, + ck_tile::make_tile_window(tensor_view, window_lengths, window_origin)); + } + + template + CK_TILE_HOST_DEVICE constexpr auto + make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin, + const TileDistribution& tile_distribution) const + { + return make_tuple( + /*block_index=*/0, + ck_tile::make_tile_window( + tensor_view, window_lengths, window_origin, tile_distribution)); + } + + template + CK_TILE_HOST_DEVICE static index_t + move_tile_window(index_t /*block_index*/, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) + { + ck_tile::move_tile_window(tile_window, step); + + return /*block_index=*/0; + } + + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t /*block_index*/, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step, + index_t /*id*/) const + { + + ck_tile::move_tile_window(tile_window, step); + return 0; + } + + template + CK_TILE_HOST_DEVICE index_t + prefetch_table_id(index_t /*block_index*/, + TileWindow /*tile_window*/, + const typename remove_cvref_t::BottomTensorIndex& /*step*/) const + { + return -1; + } + + CK_TILE_HOST_DEVICE static constexpr WindowOrigin + to_local_window_origin(const WindowOrigin& global_window_origin) + { + return global_window_origin; + } + + CK_TILE_HOST_DEVICE static constexpr WindowOrigin + to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin) + { + return local_window_origin; + } + + private: + TensorView tensor_view; +}; + +// default page-block navigator, assume that tensor view size is same as page-block size or smaller +// if tile window on last page-block +template +struct PageBlockNavigator +{ + using DataType = DataType_; + static_assert(std::is_same_v); + static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window"); + using WindowOrigin = multi_index<2>; + + CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t* physical_blocks_, + long_index_t block_stride_, + long_index_t fixed_offset_, + const int32_t* physical_block_indices_, + index_t num_blocks_, + index_t page_block_size_, + const TensorView& complete_view_, + const TensorView& last_view_) + : physical_blocks(reinterpret_cast(physical_blocks_)), + block_stride(block_stride_), + fixed_offset(fixed_offset_), + physical_block_indices(physical_block_indices_), + num_blocks(num_blocks_), + page_block_size(page_block_size_), + complete_view(complete_view_), + last_view(last_view_) + { + } + + template + CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin) const + { + const index_t block_index = get_block_index(window_origin); + const WindowOrigin local_window_origin = to_local_window_origin(window_origin); + + auto new_tile_window = + ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, + window_lengths, + local_window_origin); + new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); + + return make_tuple(block_index, new_tile_window); + } + + template + CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths, + const WindowOrigin& window_origin, + const TileDistribution& tile_distribution) const + { + const index_t block_index = get_block_index(window_origin); + const WindowOrigin local_window_origin = to_local_window_origin(window_origin); + + auto new_tile_window = + ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, + window_lengths, + local_window_origin, + tile_distribution); + new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); + + return make_tuple(block_index, new_tile_window); + } + + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) const + { + + ck_tile::move_tile_window(tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, tile_window.get_window_origin()); + const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); + + const index_t new_block_index = get_block_index(global_window_origin); + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(local_window_origin); + tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + + return new_block_index; + } + + template + CK_TILE_HOST_DEVICE index_t + move_tile_window(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step, + index_t id) const + { + ck_tile::move_tile_window(tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, tile_window.get_window_origin()); + const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); + + const index_t new_block_index = get_block_index(global_window_origin); + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(local_window_origin); + if(id >= 0) + tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride + + fixed_offset); + else + tile_window.set_bottom_tensor_view_data_ptr(nullptr); + + return new_block_index; + } + + template + CK_TILE_HOST_DEVICE index_t + prefetch_table_id(index_t block_index, + TileWindow& tile_window, + const typename remove_cvref_t::BottomTensorIndex& step) const + { + auto local_tile_window = tile_window; // not affect origin window + ck_tile::move_tile_window(local_tile_window, step); + + const WindowOrigin global_window_origin = + to_global_window_origin(block_index, local_tile_window.get_window_origin()); + const index_t new_block_index = get_block_index(global_window_origin); + + if(new_block_index < num_blocks) + { + return physical_block_indices[new_block_index]; + } + else + { + return -1; + } + } + + CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const + { + return block_index == num_blocks - 1; + } + + template + CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, + const TileWindow& tile_window) const + { + const index_t origin = tile_window.get_window_origin().at(number{}); + const index_t length = tile_window.get_window_lengths().at(number{}); + return (block_index < num_blocks - 1) && (page_block_size < origin + length); + } + + template + CK_TILE_HOST_DEVICE void + move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const + { + const multi_index<2> step = [&]() { + const index_t origin_diff = (block_index - new_block_index) * page_block_size; + if constexpr(VirtualDim == 0) + { + return make_multi_index(origin_diff, 0); + } + else + { + return make_multi_index(0, origin_diff); + } + }(); + + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); + tile_window.set_window_origin(tile_window.get_window_origin() + step); + tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + } + + CK_TILE_HOST_DEVICE WindowOrigin + to_local_window_origin(const WindowOrigin& global_window_origin) const + { + if constexpr(VirtualDim == 0) + { + const index_t length = global_window_origin.at(number<0>{}); + const index_t num_complete_blocks = integer_divide_floor(length, page_block_size); + return make_multi_index(length - page_block_size * num_complete_blocks, + global_window_origin.at(number<1>{})); + } + else + { + const index_t length = global_window_origin.at(number<1>{}); + const index_t num_complete_blocks = integer_divide_floor(length, page_block_size); + return make_multi_index(global_window_origin.at(number<0>{}), + length - page_block_size * num_complete_blocks); + } + } + + CK_TILE_HOST_DEVICE WindowOrigin + to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const + { + if constexpr(VirtualDim == 0) + { + return make_multi_index(block_index * page_block_size + + local_window_origin.at(number<0>{}), + local_window_origin.at(number<1>{})); + } + else + { + return make_multi_index(local_window_origin.at(number<0>{}), + block_index * page_block_size + + local_window_origin.at(number<1>{})); + } + } + + private: + CK_TILE_HOST_DEVICE + DataType* get_block_ptr(index_t block_index) const + { + if(block_index < num_blocks) + { + return physical_blocks + physical_block_indices[block_index] * block_stride + + fixed_offset; + } + else + { + return nullptr; + } + } + + CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const + { + return integer_divide_floor(global_window_origin.at(number{}), page_block_size); + } + + DataType* physical_blocks; + long_index_t block_stride; + long_index_t fixed_offset; + + const int32_t* physical_block_indices; + index_t num_blocks; + index_t page_block_size; + + TensorView complete_view; + TensorView last_view; +}; + +template +CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view) +{ + return TrivialPageBlockNavigator(tensor_view); +} + +template +CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t* physical_blocks, + long_index_t block_stride, + long_index_t fixed_offset, + const int32_t* physical_block_indices, + index_t num_blocks, + index_t page_block_size, + const TensorView& complete_view, + const TensorView& last_view) +{ + return PageBlockNavigator(physical_blocks, + block_stride, + fixed_offset, + physical_block_indices, + num_blocks, + page_block_size, + complete_view, + last_view); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/variants.hpp b/include/ck_tile/ops/unified_attention/block/variants.hpp new file mode 100644 index 00000000000..d8b0cdbb86b --- /dev/null +++ b/include/ck_tile/ops/unified_attention/block/variants.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include +#include + +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0 +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1 + +#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH +#endif + +#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM +#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0 +#endif + +namespace ck_tile { +namespace internal { +__device__ inline float +exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp) +{ +#if(defined(__gfx90a__) || defined(__gfx94__)) && \ + (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ + CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) + /// NOTICE: Make sure softmax_scale is stored in SGPR + float result, numerator, denominator; + asm volatile( + "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n" + "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n" + "v_rcp_f32_e32 %[denominator], %[denominator]\n" + "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n" + "v_mul_f32_e32 %[result], %[numerator], %[denominator]" + : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result) + : [softmax_scale] "s"(softmax_scale), + [logits] "v"(logits), + [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp)); + return result; +#else + return softmax_scale * logits * rcp(1.f + abs(logits * logits_soft_cap_rcp)); +#endif +} +} // namespace internal + +template +struct StandardAttentionParams +{ + __device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_) + : impl_mask(impl_mask_), sm_scale(sm_scale_) + { + } + + const ImplMask& impl_mask; + float sm_scale; +}; + +template +struct LogitsSoftCapParams +{ + __device__ + LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) + : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) + { + if(0.f < logits_soft_cap) + { + logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap); + } + else + { + logits_soft_cap_rcp = 0.f; + } + + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + __host__ + LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) + : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) + { + if(0.f < logits_soft_cap) + { + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap_rcp = 0.f; + } + + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + __device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_, + float sm_scale_, + float logits_soft_cap_, + float logits_soft_cap_rcp_) + : impl_mask(impl_mask_), + sm_scale(sm_scale_), + logits_soft_cap(logits_soft_cap_), + logits_soft_cap_rcp(logits_soft_cap_rcp_) + { + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + const ImplMask& impl_mask; + float sm_scale; + float logits_soft_cap; + float logits_soft_cap_rcp; +}; + +struct StandardAttention +{ + __device__ __host__ StandardAttention() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + return type_convert(q) * params.sm_scale; + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return logits; + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +template +struct LogitsSoftCap +{ + __device__ __host__ LogitsSoftCap() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + if constexpr(UseExp2) + { + return q; + } + else + { + return type_convert(q) * params.sm_scale; + } + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform(const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + if constexpr(UseExp2) + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); +#endif + } + else + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanhf(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return type_convert(logits) * + rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +constexpr uint32_t CUSTOM_MASK = 1U; +constexpr uint32_t SLIDING_WINDOW = 2U; +constexpr uint32_t LOGITS_SOFT_CAP = 4U; +constexpr uint32_t ALIBI = 8U; + +template +struct ComposedAttention +{ + static constexpr bool use_exp2 = UseExp2; + + static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0; + + __device__ __host__ ComposedAttention() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + if constexpr(use_logits_soft_cap && UseExp2) + { + return q; + } + return type_convert(q) * params.sm_scale; + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform(const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + if constexpr(use_logits_soft_cap) + { + if constexpr(UseExp2) + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); +#endif + } + else + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanhf(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return type_convert(logits) * + rcp(1.f + + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + } + return logits; + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp new file mode 100644 index 00000000000..9d164b639ed --- /dev/null +++ b/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp @@ -0,0 +1,450 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" + +#include +#include + +namespace ck_tile { + +template +struct FmhaFwdV3Kernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + // The attention is default causal + struct UnifiedAttentionCommonKargs + { + const void* q_ptr; + const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + void* o_ptr; + + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t num_queries_per_kv; + // scales + float scale_s; + float scale; + float scale_k; + float scale_v; + float scale_out; + + ck_tile::index_t total_num_q_blocks; + ck_tile::index_t query_stride_0; + ck_tile::index_t query_stride_1; + ck_tile::index_t stride_k_cache_0; + ck_tile::index_t stride_k_cache_1; + ck_tile::index_t stride_k_cache_2; + ck_tile::index_t stride_k_cache_3; + ck_tile::index_t stride_v_cache_0; + ck_tile::index_t stride_v_cache_1; + ck_tile::index_t stride_v_cache_2; + ck_tile::index_t stride_v_cache_3; + ck_tile::index_t output_stride_0; + ck_tile::index_t output_stride_1; + ck_tile::index_t HEAD_SIZE_PADDED; + }; + + + struct UnifiedAttentionVarlenKargs + { + const int32_t* block_tables_ptr; + const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* query_start_len_ptr; // [num_seqs+1] + + ck_tile::index_t num_seqs; // number of batches for q + ck_tile::index_t BLOCK_SIZE; // Block size for kv cache. to 2's exponent???? + ck_tile::index_t BLOCK_Q; // Block size for kv cache. to 2's exponent???? + }; + + struct Kargs { + UnifiedAttentionCommonKargs unifiedAttentionCommonKargs; + UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs; + }; + + // using Kargs = FmhaFwdGroupModeKargs; + + CK_TILE_HOST static constexpr Kargs MakeKargs( + const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* o_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t num_queries_per_kv, + float scale_s, + float scale, + float scale_k, + float scale_v, + float scale_out, + ck_tile::index_t total_num_q_blocks, + ck_tile::index_t query_stride_0, + ck_tile::index_t query_stride_1, + ck_tile::index_t stride_k_cache_0, + ck_tile::index_t stride_k_cache_1, + ck_tile::index_t stride_k_cache_2, + ck_tile::index_t stride_k_cache_3, + ck_tile::index_t stride_v_cache_0, + ck_tile::index_t stride_v_cache_1, + ck_tile::index_t stride_v_cache_2, + ck_tile::index_t stride_v_cache_3, + ck_tile::index_t output_stride_0, + ck_tile::index_t output_stride_1, + const int32_t* block_tables_ptr, + const int32_t* seq_lens_ptr, + const int32_t* query_start_len_ptr, + ck_tile::index_t num_seqs, + ck_tile::index_t BLOCK_SIZE, + ck_tile::index_t BLOCK_Q + ) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + hdim_q, + hdim_v, + num_head_q, + num_queries_per_kv, + static_cast(scale_s * ck_tile::log2e_v<>), + scale, + scale_k, + scale_v, + scale_out, + total_num_q_blocks, + query_stride_0, + query_stride_1, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + output_stride_0, + output_stride_1}, + { + block_tables_ptr, + seq_lens_ptr, + query_start_len_ptr, + num_seqs, + BLOCK_SIZE, + BLOCK_Q, + }}; + + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads, + ck_tile::index_t total_num_q_blocks) + { + return dim3(num_kv_heads * total_num_q_blocks, 0, 0); + } + + // CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads, + // ck_tile::index_t total_num_q_blocks) + // { + // // TODO: fix 3D grid + // return dim2(num_kv_heads, total_num_q_blocks); + // } + + // Binary search to find the sequence index for a given target index + CK_TILE_DEVICE static constexpr ck_tile::index_t + find_seq_idx(const int32_t* query_start_len_ptr, + ck_tile::index_t target_idx, + ck_tile::index_t num_seqs, + ck_tile::index_t BLOCK_Q, + bool use_q_block_mode) + { + ck_tile::index_t left = 0; + ck_tile::index_t right = num_seqs; + + while (left < right) + { + ck_tile::index_t mid = (left + right) / 2; + ck_tile::index_t val = query_start_len_ptr[mid]; + ck_tile::index_t mid_val = use_q_block_mode ? (val / BLOCK_Q + mid) : val; + + if (mid_val <= target_idx) + { + left = mid + 1; + } + else + { + right = mid; + } + } + + return left - 1; + } + + CK_TILE_DEVICE static constexpr auto + RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs) + { + using namespace ck_tile; + + constexpr index_t NUM_XCDS = 8; + const index_t GRID_MN = kargs.unifiedAttentionCommonKargs.total_num_q_blocks * + (kargs.unifiedAttentionCommonKargs.num_head_q); + + // Number of pids per XCD in the new arrangement + const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; + + // When GRID_MN cannot divide NUM_XCDS, some xcds will have + // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. + // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds + index_t tall_xcds = GRID_MN % NUM_XCDS; + tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; + + // Compute current XCD and local pid within the XCD + const index_t xcd = pid % NUM_XCDS; + const index_t local_pid = pid / NUM_XCDS; + + // Calculate new pid based on the new grouping + index_t remapped_pid = 0; // Initialize to avoid constexpr error + if(xcd < tall_xcds) + { + remapped_pid = xcd * pids_per_xcd + local_pid; + } + else + { + remapped_pid = tall_xcds * pids_per_xcd + + (xcd - tall_xcds) * (pids_per_xcd - 1) + + local_pid; + } + + return remapped_pid; + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) + { + using namespace ck_tile; + + ck_tile::index_t total_num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks; + // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, + // FmhaPipeline::kN1); + + const index_t i_tile_m = pid % total_num_q_blocks; // Query block index + const index_t i_tile_n = pid / total_num_q_blocks; // Head index + + return ck_tile::make_tuple(i_tile_m, i_tile_n); + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + using namespace ck_tile; + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + ck_tile::index_t pid = blockIdx.x; + + pid = RemapTileIndices(pid, kargs); + + // divide problem + const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs); + + const index_t seq_idx = find_seq_idx( + kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true + ); // which batch + + const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); + + const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + + const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); + const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx + 1]); + + const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index; + + // TODO check if we get the block size info from pipeline + if (q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q >= cur_batch_query_len) { + return; + } + + const index_t query_pos = q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q; + + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + + static_cast(kv_head_idx) * kargs.unifiedAttentionCommonKargs.num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1 + + static_cast(cur_batch_in_all_start_index) * kargs.unifiedAttentionCommonKargs.query_stride_0; + // const KDataType* k_ptr = + // reinterpret_cast(kargs.k_ptr) + + // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + // batch_offset_k; + // const VDataType* v_ptr = + // reinterpret_cast(kargs.v_ptr) + + // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + // batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.unifiedAttentionVarlenKargs.), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto q_dram_window = make_tile_window( + q_dram, + make_tuple(number{}, number{}), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {0, i_n1}); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto o_acc_tile = [&]() { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lse_dram_window, + mask, + kargs.scale_s, + smem_ptr); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp new file mode 100644 index 00000000000..b151b61028d --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -0,0 +1,1258 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" +#include "ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +#define ENABLE_ASM_MARKER 1 +#if ENABLE_ASM_MARKER +#define ASM_MARKER(marker) \ + __builtin_amdgcn_sched_barrier(0); \ + asm volatile("; [POYENC] " #marker); \ + __builtin_amdgcn_sched_barrier(0); +#else +#define ASM_MARKER(marker) +#endif + +#define ADD_SBARRIER_FOR_PHASE0 1 +#if !defined(CK_TILE_DISABLE_PACKED_FP32) +#define CK_TILE_DISABLE_PACKED_FP32 0 +#endif + +#define WARP_ID 0 +#define LANE_ID 0 + +#define ENABLE_DEBUG_STMTS 1 +#if ENABLE_DEBUG_STMTS +#define DEBUG_STMTS \ + if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID) +#else +#define DEBUG_STMTS if constexpr(false) +#endif + +namespace ck_tile { + +template +struct CoreLoopScheduler; + +template +struct CoreLoopScheduler +{ + template + CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, + ck_tile::number) + { + using namespace ck_tile; + + if constexpr(WaveGroup == 0) + { + if constexpr(Phase == 0) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 2) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + } + else + { + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 1) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 3) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + } + } +}; + +template +struct CoreLoopScheduler +{ + template + CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, + ck_tile::number) + { + using namespace ck_tile; + + if constexpr(WaveGroup == 0) + { + if constexpr(Phase == 0) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 2) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + } + else + { + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 1) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } + else if constexpr(Phase == 3) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + } + } +}; + +namespace detail { +CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c) +{ +#if CK_TILE_DISABLE_PACKED_FP32 + return a * b + c; +#else + float result; + asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]" + : [result] "=v"(result) + : [a] "v"(a), [b] "s"(b), [c] "v"(c)); + return result; +#endif +} + +CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + +CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + +CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b) +{ + fp16x2_t result; + asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} + +CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b) +{ + bf16x2_t result; + asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} + +CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) +{ + fp32x2_t result; + asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} +} // namespace detail + +template +struct UnifiedAttentionPipeline +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + + static_assert(std::is_same_v, + "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); + + using UnifiedAttentionShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; + + static constexpr ck_tile::index_t kM0 = UnifiedAttentionShape::kM0; + static constexpr ck_tile::index_t kN0 = UnifiedAttentionShape::kN0; + static constexpr ck_tile::index_t kK0 = UnifiedAttentionShape::kK0; + static constexpr ck_tile::index_t kN1 = UnifiedAttentionShape::kN1; + static constexpr ck_tile::index_t kK1 = UnifiedAttentionShape::kK1; + static constexpr ck_tile::index_t kQKHeaddim = UnifiedAttentionShape::kQKHeaddim; + static constexpr ck_tile::index_t kSubQKHeaddim = UnifiedAttentionShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr ck_tile::index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr ck_tile::index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr ck_tile::index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr ck_tile::index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr ck_tile::index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + return 2; + } + }(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // create another LDS buffer for p + return ck_tile::max(kM0 * kN1 * sizeof(PDataType), + Policy::template GetSmemSize() + + kM0 * kN0 * sizeof(PDataType)); + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc() + { + using namespace ck_tile; + constexpr auto lds_block_desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number<1>{}, + number<1>{}); + + return lds_block_desc; + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D() + { + using namespace ck_tile; + constexpr auto lds_block_desc = make_naive_tensor_descriptor( + make_tuple(number{}), make_tuple(number<1>{}), number<1>{}, number<1>{}); + + return lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc) + { + using namespace ck_tile; + + auto tensor_view = + make_tensor_view(reinterpret_cast(base), desc); + return make_tile_window(tensor_view, desc.get_lengths(), {0, 0}); + } + + // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7 + template + CK_TILE_DEVICE static constexpr void s_waitcnt() + { + // vmcnt use bits {[15:14],[3:0]} + // expcnt use bits [6:4] + // lgkmcnt use bits [11:8] + __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) | + ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8)); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt() + { + s_waitcnt(); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt() + { + s_waitcnt<63, Lgkmcnt>(); + } + + template + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + [[maybe_unused]] const VElementFunction& v_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); + auto s_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto s_lds_window = + make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); + + auto p_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto p_lds_window = + make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); + + auto o_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto o_lds_window = + make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); + + auto m_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc1D()); + [[maybe_unused]] auto m_lds_window = + make_tile_window(m_lds, make_tuple(number{}), {0}); + + const index_t warp_group_id = get_warp_id() / 4; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + auto q_dram_window = make_tile_window_linear( + q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution()); + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto k_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + auto v_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + statically_indexed_array( + nullptr, + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution())), + 2> + k_lds_window_load; + + statically_indexed_array( + nullptr, + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution())), + 2> + v_lds_window_load; + + decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())) q_tile; + + union kv_tile_type + { + CK_TILE_DEVICE kv_tile_type() {} + + decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile; + + decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile; + } kv_tile; + + union sp_compute_type + { + CK_TILE_DEVICE sp_compute_type() {} + + decltype(gemm_0.MakeCBlockTile()) sp_compute; + decltype(make_static_distributed_tensor( + Policy::template MakePRegTileDistribution())) p; + }; + statically_indexed_array sp; + + decltype(gemm_1.MakeCBlockTile()) o_acc; + constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd() + // instructions should we move to fmha_alu1() + static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); + + decltype(block_tile_reduce( + sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m; + decltype(m) l; + + // initialize k_lds_window and v_lds_window + static_for<0, 2, 1>{}([&](auto idx) { + k_lds_window_load(idx) = make_tile_window( + make_lds_tile_window( + static_cast(smem_ptr) + (idx)*Policy::template GetSmemSizeKV(), + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution()); + }); + + static_for<0, 2, 1>{}([&](auto idx) { + v_lds_window_load(idx) = + make_tile_window(make_lds_tile_window( + static_cast(smem_ptr) + + (idx + 2) * Policy::template GetSmemSizeKV(), + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution()); + }); + + { + auto origin_q = load_tile(q_dram_window); + auto transformed_q = tile_elementwise_in(q_element_func, origin_q); + + q_tile = transformed_q; + } + + clear_tile(o_acc); + set_tile(m, bit_cast(0xff7fffff)); // a bit larger than -infinity + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + index_t kv_token_start = seqlen_k_start; + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + k_dram_window.init_raw(); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + v_dram_window.init_raw(); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + static_assert(1 == k0_loops); + static_assert(1 == k1_loops); + static_assert(kN0 == kK1); + + constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; + static_assert(NumWarpGroups == 2); + + [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) { + printf("[POYENC] %s (size=%d): %5.2f", + name, + decltype(dist_tensor.thread_buf_)::size(), + ck_tile::type_convert(dist_tensor.thread_buf_[0])); + static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) { + printf(", %5.2f", ck_tile::type_convert(dist_tensor.thread_buf_[i])); + }); + printf("\n"); + }; + + [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) { + const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{}); + const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + if constexpr(true || num_rows < num_cols) + { + for(int row = 0; row < num_rows; ++row) + { + int offset = desc.calculate_offset(make_tuple(row, 0)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + row, + ck_tile::type_convert(data[offset])); + for(int col = 1; col < num_cols; ++col) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + else + { + for(int col = 0; col < num_cols; ++col) + { + int offset = desc.calculate_offset(make_tuple(0, col)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + col, + ck_tile::type_convert(data[offset])); + for(int row = 1; row < num_rows; ++row) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + }; + + [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) { + const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + int offset = desc.calculate_offset(make_tuple(0)); + printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert(data[offset])); + for(int e = 1; e < num_elems; ++e) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(e)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + }; + + // K_mem_su_ld_insts = 1 for 32 x 128 + // V_mem_su_ld_insts = 1 for 128 x 32 + constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access(); + constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); + + auto K_mem_load = [&](auto k_lds_write_idx) { + async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + + /// FIXME: use the future-predicting method to move the window + // move K tile windows + move_tile_window(k_dram_window, {kN0, 0}); + }; + + auto K_lds_load = [&](auto k_lds_read_idx) { + kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx)); + }; + + auto V_mem_load = [&](auto v_lds_write_idx) { + async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + + /// FIXME: use the future-predicting method to move the window + move_tile_window(v_dram_window, {kK1, 0}); + }; + + auto V_lds_load = [&](auto v_lds_read_idx) { + kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + }; + + decltype(m) m_old; + SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd() + /// TODO: remove the sp_delta and use sp_compute directly + statically_indexed_array{}).sp_compute), 2> sp_delta; + + auto fmha_alu0 = [&](auto sp_reg_idx) { + m_old = m; // m{j-1} + static_assert(m.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowmax value"); + auto m_latest = block_tile_reduce( + sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(m_latest.thread_buf_[0]), + bit_cast(m_latest.thread_buf_[0]), + false, + false); + /// TODO: eliminate 2 redudant v_max_f32 instructions generated by the compiler + m_latest.thread_buf_[0] = f_max(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(m_latest, f_max, bool_constant{}); +#endif + m = m_latest; + + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv( + sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx)); + }); + }); + /// TODO: move some fmha_alu1() code here if necessary + }; + + auto fmha_alu1 = [&](auto sp_reg_idx) { + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp(sp_reg_idx).sp_compute(i_j_idx) = + ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx)); + }); + }); + + auto rowsum_p = block_tile_reduce( + sp(sp_reg_idx).sp_compute, + sequence<1>{}, + f_sum, + SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + static_assert(rowsum_p.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowsum value"); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(rowsum_p.thread_buf_[0]), + bit_cast(rowsum_p.thread_buf_[0]), + false, + false); + rowsum_p.thread_buf_[0] = f_sum(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#endif + + // l{j} + /// Note: The compiler keeps moving the following instructions elsewhere because 'l' + /// is first consumed later. To anchor them here, we rewrite the final addition in + /// inline assembly to create a dependency, forcing the dependent instructions to + /// be emitted at this point. + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx])); + + l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); + }); + + // update partial o_acc [0, fmha_alu_D_reg_cnt) + static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) { + o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale); + }); + + /// Note: The compiler keeps sinking the conversion instructions because the + /// result 'p' is only consumed later. To anchor them here, we rewrite + /// the cast_tile() call as inline assembly, forcing the conversions to be + /// emitted at this point. + static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); + static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { + float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); + float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); + if constexpr(std::is_same_v) + { + auto casted = detail::cvt_pk_fp16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + else + { + auto casted = detail::cvt_pk_bf16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + }); + + /// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly + /// can interfere with the behavior of sched_group_barrier(), so ending the phase here + /// avoids unintended reordering. + }; + + auto gemm = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); + } + }; + + auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); + fmha_alu0(number<1>{} - sp_reg_idx); + } + }; + + auto fmha_alu_D_upd = [&] { + o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0])); + + fp32x2_t pk_o_acc_scale; + pk_o_acc_scale.x = o_acc_scale; + pk_o_acc_scale.y = o_acc_scale; + + static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0); +#if CK_TILE_DISABLE_PACKED_FP32 + static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size()); + static_for{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); +#endif + + constexpr auto issued_D_reg_cnt = +#if CK_TILE_DISABLE_PACKED_FP32 + fmha_alu_D_reg_cnt + 2 +#else + fmha_alu_D_reg_cnt +#endif + ; + /// NOTICE: Use inline asm v_pk_mul_f32 to reduce latency. The fmha_alu_D_upd() call + /// should be placed at the end of a phase. + // update partial o_acc after [issued_D_reg_cnt] + static_for{}([&](auto idx) { + fp32x2_t input; + input.x = o_acc.thread_buf_[idx]; + input.y = o_acc.thread_buf_[idx + 1]; + + auto output = detail::pk_mul_f32(input, pk_o_acc_scale); + + o_acc.thread_buf_[idx] = output.x; + o_acc.thread_buf_[idx + 1] = output.y; + }); + }; + + auto fmha_mask = [&](auto sp_reg_idx) { + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), kv_token_start, number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(sp(sp_reg_idx).sp_compute, + -numeric::infinity(), + [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = kv_token_start + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + }; + + auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) { + if constexpr(load_type == 0) + { + V_mem_load(mem_wr_idx); + K_lds_load(lds_rd_idx); + } + else + { + K_mem_load(mem_wr_idx); + V_lds_load(lds_rd_idx); + } + }; + + auto core_loop = [&](auto cl_p) { + auto gemm0 = number<0>{}; + auto gemm1 = number<1>{}; + + auto memV = number<0>{}; + auto memK = number<1>{}; + + using Scheduler = CoreLoopScheduler; + + auto iteration = [&](auto pi) { + auto xdl_SP_p01_reg_idx = number<1>{} - pi; + auto xdl_SP_p23_reg_idx = pi; + + auto K_w0_lds_wr_idx = number<1>{} - pi; + auto V_w0_lds_wr_idx = pi; + auto K_w0_lds_rd_idx = pi; + auto V_w0_lds_rd_idx = pi; + + auto K_w4_lds_wr_idx = number<1>{} - pi; + auto V_w4_lds_wr_idx = number<1>{} - pi; + auto K_w4_lds_rd_idx = number<1>{} - pi; + auto V_w4_lds_rd_idx = pi; + + bool result = true; + + if constexpr(cl_p == 0) + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave0-3 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave0-3 (pi=1)"); + } + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + Scheduler::schedule(cl_p, number<1>{}); + fmha_mask(xdl_SP_p01_reg_idx); + + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave0-3"); + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 0"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<2>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); + + Scheduler::schedule(cl_p, number<3>{}); + kv_token_start += kN0; + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + } + else + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave4-7 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave4-7 (pi=1)"); + } + cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<1>{}); + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave4-7"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); + Scheduler::schedule(cl_p, number<2>{}); + fmha_mask(xdl_SP_p01_reg_idx); + + kv_token_start += kN0; + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<3>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + } + return result; + }; + return iteration(number<0>{}) && iteration(number<1>{}); + }; + + auto fmha_post_process = [&](auto d) { + auto ps_pi = number<1>{} - d; + auto V_lds_rd_idx = ps_pi; + + if(1 < num_total_loop) + { + s_waitcnt_vmcnt(); + } + else + { + s_waitcnt_vmcnt<0>(); + } + __builtin_amdgcn_s_barrier(); + + V_lds_load(V_lds_rd_idx); + fmha_alu1(ps_pi); + + s_waitcnt_lgkmcnt<0>(); + + auto xdl_SP_p23_reg_idx = ps_pi; + gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{}); + }; + + // pre-stage + { + ASM_MARKER("before pre-stage"); + // (1) load K0 to LDS & VGPR + K_mem_load(number<0>{}); // mem_K0 + + s_waitcnt_vmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + K_lds_load(number<0>{}); // lds_K0 + + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + // (2) prefetch K1 and V0 to LDS in parallel with GEMM0 + if(1 < num_total_loop) + { + K_mem_load(number<1>{}); // mem_K1 + } + V_mem_load(number<0>{}); // mem_V0 + + // (3) mfma (Q*K0) + softmax + gemm(number<0>{}, /*gemm_idx=*/number<0>{}); + + fmha_mask(number<0>{}); + /// TODO: find better way to map fmha_alu(0,96) call + fmha_alu0(number<0>{}); + fmha_alu_D_upd(); + + kv_token_start += kN0; + ++i_total_loops; + if(num_total_loop <= i_total_loops) + { + goto label_main_loops_exit; + } + + if(2 < num_total_loop) + { + K_mem_load(number<0>{}); // mem_K2 + + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + } + + ASM_MARKER("end pre-stage"); + } + + if(1 < num_total_loop) + { + if(warp_group_id == 0) + { + V_mem_load(number<1>{}); // V1 + K_lds_load(number<1>{}); // K1 + + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<0>{})) + ; + } + if(warp_group_id != 0) + { + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<1>{})) + ; + } + } + label_main_loops_exit: + if(num_total_loop % 2) + { + fmha_post_process(number<1>{}); + } + if(!(num_total_loop % 2)) + { + fmha_post_process(number<0>{}); + } + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]); + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp new file mode 100644 index 00000000000..bfbb1a93f04 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -0,0 +1,603 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" + +namespace ck_tile { + +struct UnifiedAttentionPipelineDefaultPolicy +{ + static constexpr ck_tile::index_t NumWarpPerGroup = 4; + static constexpr ck_tile::index_t NumThreadPerWarpGroup = + NumWarpPerGroup * ck_tile::get_warp_size(); + + // TODO: GetAlignment*() currently didn't consider if need padding or not + // so in pipeline still need check padding requirement + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentK() + { + using namespace ck_tile; + using KDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(KDataType); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentV() + { + using namespace ck_tile; + using VDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentK(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentV(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + // compute the endcoding before transpose + constexpr auto v_block_dstr = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(v_block_dstr_encode), + typename Problem::VDataType>::TransposedDstrEncode{}); + + return v_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto GetQKBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, + typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here + return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + /// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here + return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV2CustomPolicy; + + return BlockGemmARegBRegCRegV2{}; + } + + template + CK_TILE_DEVICE static constexpr auto GetPVBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::UnifiedAttentionShape::Gemm1BlockWarps, + typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>; + /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass + /// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single + using WarpGemm = WarpGemmDispatcher{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV2CustomPolicy; + return BlockGemmARegBRegCRegV2{}; + } + + static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords + static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords + + template + CK_TILE_DEVICE static constexpr auto + MakeKLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = WarpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (WarpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return max(SingleKSize, SingleVSize); + } + + template + CK_TILE_DEVICE static constexpr auto + MakeVLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentV(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return v_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() + { + using namespace ck_tile; + + static_assert(MakeKLdsLoadBlockDescriptor().get_element_space_size() == + MakeKLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t k_element_space_size = + MakeKLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(MakeVLdsLoadBlockDescriptor().get_element_space_size() == + MakeVLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t v_element_space_size = + MakeVLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <= + GetSingleSmemElementSpaceSize()); + + /// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() & + /// MakeVLdsBlockDescriptor() + static_assert(std::is_same_v); + constexpr index_t kv_element_space_size_in_bytes = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return kv_element_space_size_in_bytes; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return 4 * GetSmemSizeKV(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp new file mode 100644 index 00000000000..45a1c8f4b87 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockFmhaPipelineEnum +{ + QRKSVS = 0, + QRKSVS_ASYNC, + QSKSVS, + QRKSVS_ASYNC_TRLOAD, +}; + +template +struct BlockFmhaPipelineEnumToStr; + +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qs"; +}; + +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async_trload"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp new file mode 100644 index 00000000000..8c8ccc3bd27 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" + +namespace ck_tile { + +template +struct UnifiedAttentionPipelineProblem +{ + // TODO kM0 and KN1?? + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + // first gemm accumulation dtype + using SaccDataType = remove_cvref_t; + // Softmax dtype + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + // data type for A matrix of second gemm + using PDataType = remove_cvref_t; + // data type for second gemm accumulation + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using UnifiedAttentionShape = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps; + static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size(); + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; + static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kHasDropout = Traits::kHasDropout; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; +} \ No newline at end of file From 191f17903866bdec19ddda69869576f965604f6e Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 9 Oct 2025 08:47:19 +0000 Subject: [PATCH 02/88] unified attention rename --- ...ernel.hpp => unified_attention_kernel.hpp} | 23 +++++++++++++++---- ...ine.hpp => unified_attention_pipeline.hpp} | 0 ...ied_attention_pipeline_default_policy.hpp} | 0 ...pp => unified_attention_pipeline_enum.hpp} | 0 ...=> unified_attention_pipeline_problem.hpp} | 0 5 files changed, 19 insertions(+), 4 deletions(-) rename include/ck_tile/ops/unified_attention/kernel/{fmha_fwd_v3_kernel.hpp => unified_attention_kernel.hpp} (94%) rename include/ck_tile/ops/unified_attention/pipeline/{block_fmha_fwd_v3_pipeline.hpp => unified_attention_pipeline.hpp} (100%) rename include/ck_tile/ops/unified_attention/pipeline/{block_fmha_fwd_v3_pipeline_default_policy.hpp => unified_attention_pipeline_default_policy.hpp} (100%) rename include/ck_tile/ops/unified_attention/pipeline/{block_fmha_pipeline_enum.hpp => unified_attention_pipeline_enum.hpp} (100%) rename include/ck_tile/ops/unified_attention/pipeline/{block_fmha_pipeline_problem.hpp => unified_attention_pipeline_problem.hpp} (100%) diff --git a/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp similarity index 94% rename from include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp rename to include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 9d164b639ed..f49e560a963 100644 --- a/include/ck_tile/ops/unified_attention/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -84,6 +84,7 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_seqs; // number of batches for q ck_tile::index_t BLOCK_SIZE; // Block size for kv cache. to 2's exponent???? ck_tile::index_t BLOCK_Q; // Block size for kv cache. to 2's exponent???? + ck_tile::index_t BLOCK_M; // Block size for kv cache. to 2's exponent???? }; struct Kargs { @@ -125,7 +126,8 @@ struct FmhaFwdV3Kernel const int32_t* query_start_len_ptr, ck_tile::index_t num_seqs, ck_tile::index_t BLOCK_SIZE, - ck_tile::index_t BLOCK_Q + ck_tile::index_t BLOCK_Q, + ck_tile::index_t BLOCK_M ) { Kargs kargs{{q_ptr, @@ -161,6 +163,7 @@ struct FmhaFwdV3Kernel num_seqs, BLOCK_SIZE, BLOCK_Q, + BLOCK_M }}; return kargs; @@ -301,7 +304,17 @@ struct FmhaFwdV3Kernel } const index_t query_pos = q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q; + const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; + const index_t context_len = seq_len - cur_batch_query_len; + + + const index_t max_seq_prefix_len = ( + context_len + + q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q + + (kargs.unifiedAttentionVarlenKargs.BLOCK_M - 1) // num_queries_per_kv + + 1 + ); // for simplicity, batch stride we just modify the pointer const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + @@ -323,14 +336,16 @@ struct FmhaFwdV3Kernel const auto q_dram = [&]() { const auto q_dram_naive = make_naive_tensor_view( q_ptr, - make_tuple(kargs.seqlen_q, kargs.unifiedAttentionVarlenKargs.), - make_tuple(kargs.stride_q, 1), + make_tuple(seq_len, kargs.unifiedAttentionCommonKargs.HEAD_SIZE_PADDED), + make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, 1), number{}, number<1>{}); return pad_tensor_view( q_dram_naive, - make_tuple(number{}, number{}), + // block sizes + make_tuple(number{}, number{}), + // bool defining should we pad sequence{}); }(); const auto k_dram = [&]() { diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp similarity index 100% rename from include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline.hpp rename to include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp similarity index 100% rename from include/ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp rename to include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp similarity index 100% rename from include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_enum.hpp rename to include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp diff --git a/include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp similarity index 100% rename from include/ck_tile/ops/unified_attention/pipeline/block_fmha_pipeline_problem.hpp rename to include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp From 436eb3a4f8b90bbb8baf1fcf48afbf2616a95ba9 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Fri, 10 Oct 2025 12:08:16 +0000 Subject: [PATCH 03/88] transform q tensor view --- .../kernel/unified_attention_kernel.hpp | 110 ++++++++++++------ 1 file changed, 76 insertions(+), 34 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index f49e560a963..6973e36d551 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -285,6 +285,11 @@ struct FmhaFwdV3Kernel // divide problem const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs); + // grid size is (num_kv_heads, total_num_q_blocks) + // total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + // q.shape[0] is total number of query tokens across all batches + // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token + const index_t seq_idx = find_seq_idx( kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true ); // which batch @@ -316,38 +321,77 @@ struct FmhaFwdV3Kernel + 1 ); - // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + - static_cast(kv_head_idx) * kargs.unifiedAttentionCommonKargs.num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1 + - static_cast(cur_batch_in_all_start_index) * kargs.unifiedAttentionCommonKargs.query_stride_0; - // const KDataType* k_ptr = - // reinterpret_cast(kargs.k_ptr) + - // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - // batch_offset_k; - // const VDataType* v_ptr = - // reinterpret_cast(kargs.v_ptr) + - // static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - // batch_offset_v; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_o + - batch_offset_o; - + const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr) + const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr) + ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr) + // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { - const auto q_dram_naive = make_naive_tensor_view( + const index_t qheads = kargs.unifiedAttentionCommonKargs.num_head_q; + const index_t kheads = kargs.unifiedAttentionCommonKargs.num_head_k; + const index_t qheads_per_kv = qheads / kheads; + const index_t BLOCK_Q = kargs.unifiedAttentionVarlenKargs.BLOCK_Q; // = BLOCK_M / qheads_per_kv + const index_t BLOCK_D = kargs.unifiedAttentionVarlenKargs.BLOCK_D; // BLOCK_SIZE along head dim + const index_t D = kargs.unifiedAttentionCommonKargs.HEAD_SIZE; // head dim + const index_t cu_seqlens = kPadSeqLenQ; + + + const auto q_dram_base = make_naive_tensor_view( q_ptr, - make_tuple(seq_len, kargs.unifiedAttentionCommonKargs.HEAD_SIZE_PADDED), - make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, 1), + make_tuple(cu_seqlens, qheads, D), + make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), number{}, number<1>{}); - return pad_tensor_view( - q_dram_naive, + const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with BLOCK_D + q_dram_base, // block sizes - make_tuple(number{}, number{}), - // bool defining should we pad - sequence{}); + make_tuple(BLOCK_Q, 1, BLOCK_D), + sequence{} + ); + + const auto q_dram_unmerged = transform_tensor_view( + q_dram_pad, + make_tuple( + make_pass_through_transform(kPadSeqLenQ), + make_unmerge_transform(make_tuple(qheads / kheads, kheads)), + make_pass_through_transform(BLOCK_D) + ), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}) + ); + + const auto q_dram_permuted = transform_tensor_view( + q_dram_unmerged, + make_tuple( + make_pass_through_transform(qheads / kheads), + make_pass_through_transform(kPadSeqLenQ), + make_pass_through_transform(kheads), + make_pass_through_transform(BLOCK_D) + ), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}, sequence<3>{}) + ); + const auto q_dram_merged = transform_tensor_view( + q_dram_permuted, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, kPadSeqLenQ, kheads) + ), + make_pass_through_transform(BLOCK_D) + ), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}) + ); + return q_dram_merged; }(); + auto q_dram_window = make_tile_window( + q_dram, + make_tuple(number{}, number{}), + {kv_head_idx * kPadSeqLenQ + q_block_global_idx*BLOCK_Q, 0} + ); + const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, @@ -361,6 +405,8 @@ struct FmhaFwdV3Kernel make_tuple(number{}, number{}), sequence{}); }(); + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, @@ -374,19 +420,15 @@ struct FmhaFwdV3Kernel make_tuple(number{}, number{}), sequence{}); }(); + auto v_dram_window = make_tile_window( + v_dram, make_tuple(number{}, number{}), + {0, i_n1}); + - auto q_dram_window = make_tile_window( - q_dram, - make_tuple(number{}, number{}), - {i_m0, 0}); - auto k_dram_window = make_tile_window( - k_dram, make_tuple(number{}, number{}), {0, 0}); + - auto v_dram_window = - make_tile_window(v_dram, - make_tuple(number{}, number{}), - {0, i_n1}); + // lse auto lse_dram_window = [&, i_nhead_ = i_nhead]() { From df604932194f8f680e7e3092c15883e93fc4a9a4 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Oct 2025 13:25:19 +0000 Subject: [PATCH 04/88] refactor --- .../kernel/unified_attention_kernel.hpp | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 6973e36d551..f3170f69ccb 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -7,8 +7,10 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" +#include #include #include +#include namespace ck_tile { @@ -33,6 +35,18 @@ struct FmhaFwdV3Kernel static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + // TODO add yjese + static constexpr index_t HEAD_SIZE = FmhaPipeline::HEAD_SIZE; + static constexpr index_t HEAD_SIZE_PADDED = FmhaPipeline::HEAD_SIZE_PADDED; + + // BLOCK_Q = BLOCK_M // num_queries_per_kv + // BLOCK_Q is the block size for q seqlen + static constexpr index_t BLOCK_Q = FmhaPipeline::BLOCK_Q; + // static constexpr index_t BLOCK_M = FmhaPipeline::BLOCK_M; + // BLOCK size for K seqlen + static constexpr index_t BLOCK_SIZE = FmhaPipeline::BLOCK_SIZE; + + // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. @@ -50,7 +64,7 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_head_q; // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case - ck_tile::index_t num_queries_per_kv; + const ck_tile::index_t num_queries_per_kv; // scales float scale_s; float scale; @@ -71,7 +85,6 @@ struct FmhaFwdV3Kernel ck_tile::index_t stride_v_cache_3; ck_tile::index_t output_stride_0; ck_tile::index_t output_stride_1; - ck_tile::index_t HEAD_SIZE_PADDED; }; @@ -82,9 +95,6 @@ struct FmhaFwdV3Kernel const int32_t* query_start_len_ptr; // [num_seqs+1] ck_tile::index_t num_seqs; // number of batches for q - ck_tile::index_t BLOCK_SIZE; // Block size for kv cache. to 2's exponent???? - ck_tile::index_t BLOCK_Q; // Block size for kv cache. to 2's exponent???? - ck_tile::index_t BLOCK_M; // Block size for kv cache. to 2's exponent???? }; struct Kargs { @@ -102,7 +112,7 @@ struct FmhaFwdV3Kernel ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, - ck_tile::index_t num_queries_per_kv, + const ck_tile::index_t num_queries_per_kv, float scale_s, float scale, float scale_k, @@ -125,9 +135,6 @@ struct FmhaFwdV3Kernel const int32_t* seq_lens_ptr, const int32_t* query_start_len_ptr, ck_tile::index_t num_seqs, - ck_tile::index_t BLOCK_SIZE, - ck_tile::index_t BLOCK_Q, - ck_tile::index_t BLOCK_M ) { Kargs kargs{{q_ptr, @@ -160,10 +167,7 @@ struct FmhaFwdV3Kernel block_tables_ptr, seq_lens_ptr, query_start_len_ptr, - num_seqs, - BLOCK_SIZE, - BLOCK_Q, - BLOCK_M + num_seqs }}; return kargs; @@ -279,6 +283,9 @@ struct FmhaFwdV3Kernel __shared__ char smem_ptr[GetSmemSize()]; ck_tile::index_t pid = blockIdx.x; + index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; + + const index_t BLOCK_M = BLOCK_Q * kargs.unifiedAttentionCommonKargs.num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -304,11 +311,11 @@ struct FmhaFwdV3Kernel const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index; // TODO check if we get the block size info from pipeline - if (q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q >= cur_batch_query_len) { + if (q_block_local_idx * BLOCK_Q >= cur_batch_query_len) { return; } - const index_t query_pos = q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q; + const index_t query_pos = q_block_local_idx * BLOCK_Q; const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; const index_t context_len = seq_len - cur_batch_query_len; @@ -316,11 +323,15 @@ struct FmhaFwdV3Kernel const index_t max_seq_prefix_len = ( context_len - + q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q - + (kargs.unifiedAttentionVarlenKargs.BLOCK_M - 1) // num_queries_per_kv + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + 1 ); + // for simplicity, batch stride we just modify the pointer + index_t num_head_q = kargs.unifiedAttentionCommonKargs.num_head_q; + + // Q/K/V DRAM and DRAM window const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr) const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr) @@ -331,8 +342,6 @@ struct FmhaFwdV3Kernel const index_t qheads = kargs.unifiedAttentionCommonKargs.num_head_q; const index_t kheads = kargs.unifiedAttentionCommonKargs.num_head_k; const index_t qheads_per_kv = qheads / kheads; - const index_t BLOCK_Q = kargs.unifiedAttentionVarlenKargs.BLOCK_Q; // = BLOCK_M / qheads_per_kv - const index_t BLOCK_D = kargs.unifiedAttentionVarlenKargs.BLOCK_D; // BLOCK_SIZE along head dim const index_t D = kargs.unifiedAttentionCommonKargs.HEAD_SIZE; // head dim const index_t cu_seqlens = kPadSeqLenQ; From 1f4648dab5567922184c80c07004b254598cf055 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Oct 2025 15:27:36 +0000 Subject: [PATCH 05/88] refactor. and fixed q transformation --- .../kernel/unified_attention_kernel.hpp | 111 +++++------------- 1 file changed, 30 insertions(+), 81 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index f3170f69ccb..67d6372c31f 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -273,8 +273,6 @@ struct FmhaFwdV3Kernel return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } - - CK_TILE_DEVICE void operator()(Kargs kargs) const { using namespace ck_tile; @@ -320,7 +318,6 @@ struct FmhaFwdV3Kernel const index_t context_len = seq_len - cur_batch_query_len; - const index_t max_seq_prefix_len = ( context_len + q_block_local_idx * BLOCK_Q @@ -330,42 +327,29 @@ struct FmhaFwdV3Kernel // for simplicity, batch stride we just modify the pointer index_t num_head_q = kargs.unifiedAttentionCommonKargs.num_head_q; + index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; // Q/K/V DRAM and DRAM window - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) - const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr) - const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr) - ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr) + const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr); + const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr); + const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr); + ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { - const index_t qheads = kargs.unifiedAttentionCommonKargs.num_head_q; - const index_t kheads = kargs.unifiedAttentionCommonKargs.num_head_k; - const index_t qheads_per_kv = qheads / kheads; - const index_t D = kargs.unifiedAttentionCommonKargs.HEAD_SIZE; // head dim - const index_t cu_seqlens = kPadSeqLenQ; - - const auto q_dram_base = make_naive_tensor_view( q_ptr, - make_tuple(cu_seqlens, qheads, D), + make_tuple(seq_len, num_head_q, HEAD_SIZE), make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), number{}, number<1>{}); - const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with BLOCK_D - q_dram_base, - // block sizes - make_tuple(BLOCK_Q, 1, BLOCK_D), - sequence{} - ); - const auto q_dram_unmerged = transform_tensor_view( - q_dram_pad, + q_dram_base, make_tuple( - make_pass_through_transform(kPadSeqLenQ), - make_unmerge_transform(make_tuple(qheads / kheads, kheads)), - make_pass_through_transform(BLOCK_D) + make_pass_through_transform(seq_len), + make_unmerge_transform(make_tuple(num_head_q / num_queries_per_kv, num_queries_per_kv)), + make_pass_through_transform(HEAD_SIZE_PADDED) ), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}) @@ -374,10 +358,10 @@ struct FmhaFwdV3Kernel const auto q_dram_permuted = transform_tensor_view( q_dram_unmerged, make_tuple( - make_pass_through_transform(qheads / kheads), - make_pass_through_transform(kPadSeqLenQ), - make_pass_through_transform(kheads), - make_pass_through_transform(BLOCK_D) + make_pass_through_transform(num_head_q / num_queries_per_kv), + make_pass_through_transform(seq_len), + make_pass_through_transform(num_queries_per_kv), + make_pass_through_transform(HEAD_SIZE_PADDED) ), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}, sequence<3>{}) @@ -386,19 +370,30 @@ struct FmhaFwdV3Kernel q_dram_permuted, make_tuple( make_merge_transform_v3_division_mod( - make_tuple(number{}, kPadSeqLenQ, kheads) + make_tuple(num_head_q / num_queries_per_kv, seq_len, num_queries_per_kv) ), - make_pass_through_transform(BLOCK_D) + make_pass_through_transform(HEAD_SIZE_PADDED) ), make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), make_tuple(sequence<0>{}, sequence<1>{}) ); - return q_dram_merged; + + const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + q_dram_merged, + // block sizes + make_tuple(BLOCK_Q, HEAD_SIZE_PADDED), + sequence{} + ); + + return q_dram_pad; }(); + + // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) + // stride for dim 0 (num_queries_per_kv * seq_len, num_queries_per_kv, 1) auto q_dram_window = make_tile_window( q_dram, - make_tuple(number{}, number{}), - {kv_head_idx * kPadSeqLenQ + q_block_global_idx*BLOCK_Q, 0} + make_tuple(BLOCK_Q, HEAD_SIZE_PADDED), + {kv_head_idx * seq_len * num_queries_per_kv + q_block_global_idx * num_queries_per_kv, 0} ); const auto k_dram = [&]() { @@ -434,52 +429,6 @@ struct FmhaFwdV3Kernel {0, i_n1}); - - - - - - // lse - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; - - const auto lse_dram = [&]() { - const auto lse_dram_naive = make_naive_tensor_view( - lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); - - FmhaMask mask = [&]() { - if constexpr(kHasMask) - return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); - else - return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; - }(); - auto o_acc_tile = [&]() { return FmhaPipeline{}(q_dram_window, k_dram_window, From bc6385f389abbd5b744cf4457df5b1f262263aac Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 13 Oct 2025 10:01:38 +0000 Subject: [PATCH 06/88] Some refactor --- .../kernel/unified_attention_kernel.hpp | 62 +++++++++---------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 67d6372c31f..bb38df6b264 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -58,9 +58,7 @@ struct FmhaFwdV3Kernel const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size] void* o_ptr; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - + ck_tile::index_t num_blks; ck_tile::index_t num_head_q; // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case @@ -88,7 +86,7 @@ struct FmhaFwdV3Kernel }; - struct UnifiedAttentionVarlenKargs + struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs { const int32_t* block_tables_ptr; const int32_t* seq_lens_ptr; // seq len in each batch @@ -97,20 +95,15 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_seqs; // number of batches for q }; - struct Kargs { - UnifiedAttentionCommonKargs unifiedAttentionCommonKargs; - UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs; - }; - // using Kargs = FmhaFwdGroupModeKargs; + using Kargs = UnifiedAttentionVarlenKargs; CK_TILE_HOST static constexpr Kargs MakeKargs( const void* q_ptr, const void* k_ptr, const void* v_ptr, void* o_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, + ck_tile::index_t num_blks, ck_tile::index_t num_head_q, const ck_tile::index_t num_queries_per_kv, float scale_s, @@ -134,15 +127,14 @@ struct FmhaFwdV3Kernel const int32_t* block_tables_ptr, const int32_t* seq_lens_ptr, const int32_t* query_start_len_ptr, - ck_tile::index_t num_seqs, + ck_tile::index_t num_seqs ) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, - hdim_q, - hdim_v, + num_blks, num_head_q, num_queries_per_kv, static_cast(scale_s * ck_tile::log2e_v<>), @@ -221,9 +213,13 @@ struct FmhaFwdV3Kernel { using namespace ck_tile; + const index_t num_head_q = kargs.num_head_q; + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + const index_t num_head_k = kargs.num_queries_per_kv; + constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.unifiedAttentionCommonKargs.total_num_q_blocks * - (kargs.unifiedAttentionCommonKargs.num_head_q); + const index_t GRID_MN = kargs.total_num_q_blocks * + (kargs.num_head_q); // Number of pids per XCD in the new arrangement const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; @@ -258,7 +254,7 @@ struct FmhaFwdV3Kernel { using namespace ck_tile; - ck_tile::index_t total_num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks; + ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks; // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, // FmhaPipeline::kN1); @@ -281,9 +277,9 @@ struct FmhaFwdV3Kernel __shared__ char smem_ptr[GetSmemSize()]; ck_tile::index_t pid = blockIdx.x; - index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; + index_t num_queries_per_kv = kargs.num_queries_per_kv; - const index_t BLOCK_M = BLOCK_Q * kargs.unifiedAttentionCommonKargs.num_queries_per_kv; + const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -296,15 +292,15 @@ struct FmhaFwdV3Kernel // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token const index_t seq_idx = find_seq_idx( - kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true + kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true ); // which batch - const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index; @@ -314,7 +310,7 @@ struct FmhaFwdV3Kernel } const index_t query_pos = q_block_local_idx * BLOCK_Q; - const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; + const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; const index_t context_len = seq_len - cur_batch_query_len; @@ -326,21 +322,22 @@ struct FmhaFwdV3Kernel ); // for simplicity, batch stride we just modify the pointer - index_t num_head_q = kargs.unifiedAttentionCommonKargs.num_head_q; - index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; + const index_t num_head_q = kargs.num_head_q; + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + const index_t num_head_k = num_head_q / num_queries_per_kv; // Q/K/V DRAM and DRAM window - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr); - const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr); - const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr); - ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr); + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr); + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr); + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr); + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_base = make_naive_tensor_view( q_ptr, make_tuple(seq_len, num_head_q, HEAD_SIZE), - make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), + make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, number<1>{}); @@ -378,6 +375,7 @@ struct FmhaFwdV3Kernel make_tuple(sequence<0>{}, sequence<1>{}) ); + // TODO are we padding the tensor view or the block here? const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_merged, // block sizes @@ -399,7 +397,7 @@ struct FmhaFwdV3Kernel const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(num_b, BLOCK_SIZE, num_head_k, HEAD_SIZE), make_tuple(kargs.stride_k, 1), number{}, number<1>{}); From 36a65b19687444c3b55d1b4cdcf2eb61d7f37198 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 13 Oct 2025 10:05:23 +0000 Subject: [PATCH 07/88] refactor --- .../kernel/unified_attention_kernel.hpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index bb38df6b264..eef8ca4f790 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -212,10 +212,6 @@ struct FmhaFwdV3Kernel RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs) { using namespace ck_tile; - - const index_t num_head_q = kargs.num_head_q; - const index_t num_queries_per_kv = kargs.num_queries_per_kv; - const index_t num_head_k = kargs.num_queries_per_kv; constexpr index_t NUM_XCDS = 8; const index_t GRID_MN = kargs.total_num_q_blocks * @@ -277,9 +273,12 @@ struct FmhaFwdV3Kernel __shared__ char smem_ptr[GetSmemSize()]; ck_tile::index_t pid = blockIdx.x; - index_t num_queries_per_kv = kargs.num_queries_per_kv; const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv; + // for simplicity, batch stride we just modify the pointer + const index_t num_head_q = kargs.num_head_q; + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + const index_t num_head_k = num_head_q / num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -321,11 +320,6 @@ struct FmhaFwdV3Kernel + 1 ); - // for simplicity, batch stride we just modify the pointer - const index_t num_head_q = kargs.num_head_q; - const index_t num_queries_per_kv = kargs.num_queries_per_kv; - const index_t num_head_k = num_head_q / num_queries_per_kv; - // Q/K/V DRAM and DRAM window const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr); const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr); From 2d6dab29ebae0ab5564efe734ac3279a7f908514 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 13 Oct 2025 10:18:23 +0000 Subject: [PATCH 08/88] refactor the q tensor view transformation --- .../kernel/unified_attention_kernel.hpp | 58 +++++++------------ 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 67d6372c31f..31bf24fa31c 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/core/numeric/math.hpp" #include #include @@ -314,7 +315,7 @@ struct FmhaFwdV3Kernel } const index_t query_pos = q_block_local_idx * BLOCK_Q; - const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; + const index_t seq_len = kargs.unifiedAttentionVarlenKargs.seq_lens_ptr[seq_idx]; // should be cu_seqlens_q rather const index_t context_len = seq_len - cur_batch_query_len; @@ -330,62 +331,47 @@ struct FmhaFwdV3Kernel index_t num_queries_per_kv = kargs.unifiedAttentionCommonKargs.num_queries_per_kv; // Q/K/V DRAM and DRAM window - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr); + index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start + index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start + index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; + const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + q_ptr_offset; const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr); const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr); ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr); + + index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q; + bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0); + // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_base = make_naive_tensor_view( q_ptr, - make_tuple(seq_len, num_head_q, HEAD_SIZE), + make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), number{}, number<1>{}); - const auto q_dram_unmerged = transform_tensor_view( - q_dram_base, - make_tuple( - make_pass_through_transform(seq_len), - make_unmerge_transform(make_tuple(num_head_q / num_queries_per_kv, num_queries_per_kv)), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}) - ); + const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + q_dram_base, + // block sizes + make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), + sequence{} + ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) - const auto q_dram_permuted = transform_tensor_view( - q_dram_unmerged, - make_tuple( - make_pass_through_transform(num_head_q / num_queries_per_kv), - make_pass_through_transform(seq_len), - make_pass_through_transform(num_queries_per_kv), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}, sequence<3>{}) - ); const auto q_dram_merged = transform_tensor_view( - q_dram_permuted, + q_dram_pad, make_tuple( - make_merge_transform_v3_division_mod( - make_tuple(num_head_q / num_queries_per_kv, seq_len, num_queries_per_kv) + make_merge_transform( + make_tuple(seq_len, num_queries_per_kv) ), make_pass_through_transform(HEAD_SIZE_PADDED) ), - make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}) ); - const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED - q_dram_merged, - // block sizes - make_tuple(BLOCK_Q, HEAD_SIZE_PADDED), - sequence{} - ); - - return q_dram_pad; + return q_dram_merged; }(); // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) From af94aaf1cbb1b4bbd2145a02112ee4c695ff5e45 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 13 Oct 2025 10:22:52 +0000 Subject: [PATCH 09/88] refactor the q tensor view transformation --- .../kernel/unified_attention_kernel.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ffceec8aa2a..ea1c4f3bf0b 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -322,13 +322,13 @@ struct FmhaFwdV3Kernel ); // Q/K/V DRAM and DRAM window - index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.unifiedAttentionCommonKargs.query_stride_0; // move the pointer to the batch start - index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1; // move the pointer to the correct head group start + index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start + index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; - const QDataType* q_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.q_ptr) + q_ptr_offset; - const KDataType* k_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.k_ptr); - const VDataType* v_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.v_ptr); - ODataType* o_ptr = reinterpret_cast(kargs.unifiedAttentionCommonKargs.o_ptr); + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr); + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr); + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q; @@ -339,7 +339,7 @@ struct FmhaFwdV3Kernel const auto q_dram_base = make_naive_tensor_view( q_ptr, make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), - make_tuple(kargs.unifiedAttentionCommonKargs.query_stride_0, kargs.unifiedAttentionCommonKargs.query_stride_1, 1), + make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, number<1>{}); From 55fc6d71518b9ca73d0dee800555e6c967715cdd Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 13 Oct 2025 10:28:02 +0000 Subject: [PATCH 10/88] kv tensor view --- .../kernel/unified_attention_kernel.hpp | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index eef8ca4f790..5f00e2cc921 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -320,10 +320,13 @@ struct FmhaFwdV3Kernel + 1 ); + + index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2; + // Q/K/V DRAM and DRAM window const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr); - const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr); - const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr); + const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + kv_head_offset; + const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); // Q/K/V DRAM and DRAM window @@ -391,30 +394,32 @@ struct FmhaFwdV3Kernel const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(num_b, BLOCK_SIZE, num_head_k, HEAD_SIZE), - make_tuple(kargs.stride_k, 1), + make_tuple(kargs.num_blks, BLOCK_SIZE, num_head_k, HEAD_SIZE), + make_tuple(kargs.stride_k_cache_3, kargs.stride_k_cache_2, kargs.stride_k_cache_1, kargs.stride_k_cache_0), number{}, number<1>{}); return pad_tensor_view( k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + // TODO can the BLOCK_SIZE_RAW needs padding? + make_tuple(1, BLOCK_SIZE, 1, HEAD_SIZE_PADDED), + sequence{}); }(); - auto k_dram_window = make_tile_window( - k_dram, make_tuple(number{}, number{}), {0, 0}); + + // auto k_dram_window = make_tile_window( + // k_dram, make_tuple(number{}, number{}), {0, 0}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), + make_tuple(kargs.num_blks, BLOCK_SIZE, num_head_k, HEAD_SIZE), + make_tuple(kargs.stride_v_cache_3, kargs.stride_v_cache_2, kargs.stride_v_cache_1, kargs.stride_v_cache_0), number{}, number<1>{}); return pad_tensor_view( v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + make_tuple(1, BLOCK_SIZE, 1, HEAD_SIZE_PADDED), + sequence{}); }(); auto v_dram_window = make_tile_window( v_dram, make_tuple(number{}, number{}), From 16129a794aafc989319ed7561ad169115865c097 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 13 Oct 2025 10:30:08 +0000 Subject: [PATCH 11/88] stride fix --- .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 3313f3a9ec7..5d194e21677 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -381,7 +381,7 @@ struct FmhaFwdV3Kernel const auto k_dram_naive = make_naive_tensor_view( k_ptr, make_tuple(kargs.num_blks, BLOCK_SIZE, num_head_k, HEAD_SIZE), - make_tuple(kargs.stride_k_cache_3, kargs.stride_k_cache_2, kargs.stride_k_cache_1, kargs.stride_k_cache_0), + make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_2, kargs.stride_k_cache_3), number{}, number<1>{}); @@ -398,7 +398,7 @@ struct FmhaFwdV3Kernel const auto v_dram_naive = make_naive_tensor_view( v_ptr, make_tuple(kargs.num_blks, BLOCK_SIZE, num_head_k, HEAD_SIZE), - make_tuple(kargs.stride_v_cache_3, kargs.stride_v_cache_2, kargs.stride_v_cache_1, kargs.stride_v_cache_0), + make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_2, kargs.stride_v_cache_3), number{}, number<1>{}); From b721f79f994d620d7adc25ebcf69f4e0a10dc631 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 13 Oct 2025 10:30:11 +0000 Subject: [PATCH 12/88] fix --- .../unified_attention/kernel/unified_attention_kernel.hpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ea1c4f3bf0b..5765fd858c6 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -354,15 +354,13 @@ struct FmhaFwdV3Kernel q_dram_pad, make_tuple( make_merge_transform( - make_tuple(seq_len, num_queries_per_kv) + make_tuple(seq_len_padded, num_queries_per_kv) ), make_pass_through_transform(HEAD_SIZE_PADDED) ), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}) - ); - - // TODO are we padding the tensor view or the block here? + ); // flattens the first two dims, head dim is the fastest changing dim in the merged dim return q_dram_merged; }(); From 6ba25b7e8492fd5fa88e91d29a11135490af8f3d Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 13 Oct 2025 10:34:55 +0000 Subject: [PATCH 13/88] add commenting --- .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ca71e36c3b8..9955ca24b38 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -346,7 +346,7 @@ struct FmhaFwdV3Kernel number{}, number<1>{}); - const auto q_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_base, // block sizes make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), @@ -363,7 +363,7 @@ struct FmhaFwdV3Kernel ), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}) - ); // flattens the first two dims, head dim is the fastest changing dim in the merged dim + ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim return q_dram_merged; }(); From be58d51d36ad28bee005b22fd7d65f7e159fe424 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 13 Oct 2025 11:32:28 +0000 Subject: [PATCH 14/88] o ptr and window --- .../kernel/unified_attention_kernel.hpp | 47 ++++++++++++++----- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 5d194e21677..423aea9054a 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -328,14 +328,19 @@ struct FmhaFwdV3Kernel index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; + + index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start + index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start + index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + kv_head_offset; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q; - bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0); + const bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { @@ -370,11 +375,11 @@ struct FmhaFwdV3Kernel }(); // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) - // stride for dim 0 (num_queries_per_kv * seq_len, num_queries_per_kv, 1) + // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) auto q_dram_window = make_tile_window( q_dram, make_tuple(BLOCK_Q, HEAD_SIZE_PADDED), - {kv_head_idx * seq_len * num_queries_per_kv + q_block_global_idx * num_queries_per_kv, 0} + {q_block_global_idx * num_queries_per_kv * HEAD_SIZE_PADDED, 0} ); const auto k_dram = [&]() { @@ -424,23 +429,39 @@ struct FmhaFwdV3Kernel // O DRAM and O DRAM window auto o_dram = [&]() { - const auto o_dram_naive = make_naive_tensor_view( + const auto o_dram_base = make_naive_tensor_view( o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), + make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), + make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1), number{}, number<1>{}); - return pad_tensor_view( - o_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + const auto o_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + o_dram_base, + // block sizes + make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), + sequence{} + ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) + + const auto o_dram_merged = transform_tensor_view( + o_dram_pad, + make_tuple( + make_merge_transform( + make_tuple(seq_len, num_queries_per_kv) + ), + make_pass_through_transform(HEAD_SIZE_PADDED) + ), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}) + ); + + return o_dram_merged; }(); auto o_dram_window = make_tile_window(o_dram, - make_tuple(number{}, number{}), - {i_m0, i_n1}); + make_tuple(BLOCK_M, HEAD_SIZE_PADDED), + {q_block_global_idx * num_queries_per_kv * HEAD_SIZE_PADDED, 0}); EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } From 6a7fa959b74ed085bfe057f6251723081a8a24c5 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 13 Oct 2025 12:53:43 +0000 Subject: [PATCH 15/88] kv tensor view and initial window --- .../kernel/unified_attention_kernel.hpp | 62 ++++++++++++++----- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 18bdfe184b8..e8cd551417c 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -381,40 +381,70 @@ struct FmhaFwdV3Kernel ); const auto k_dram = [&]() { + // HEAD dim is skipped as defined in the ptrs const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.num_blks, BLOCK_SIZE, num_head_k, HEAD_SIZE), - make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_2, kargs.stride_k_cache_3), + make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_3), number{}, number<1>{}); - return pad_tensor_view( + const auto k_dram_pad = pad_tensor_view( k_dram_naive, // TODO can the BLOCK_SIZE_RAW needs padding? - make_tuple(1, BLOCK_SIZE, 1, HEAD_SIZE_PADDED), - sequence{}); + make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); + + + const auto k_dram_merged = transform_tensor_view( + k_dram_pad, + make_tuple( + make_merge_transform( + make_tuple(kargs.num_blks, BLOCK_SIZE) + ), + make_pass_through_transform(HEAD_SIZE_PADDED) + ), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}) + ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim + + return k_dram_merged; }(); - // auto k_dram_window = make_tile_window( - // k_dram, make_tuple(number{}, number{}), {0, 0}); - const auto v_dram = [&]() { + auto k_dram_window = make_tile_window( + k_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0}); + + const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.num_blks, BLOCK_SIZE, num_head_k, HEAD_SIZE), - make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_2, kargs.stride_v_cache_3), + make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_3), number{}, number<1>{}); - return pad_tensor_view( + const auto v_dram_pad = pad_tensor_view( v_dram_naive, - make_tuple(1, BLOCK_SIZE, 1, HEAD_SIZE_PADDED), - sequence{}); + make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); + + const auto v_dram_merged = transform_tensor_view( + v_dram_pad, + make_tuple( + make_merge_transform( + make_tuple(kargs.num_blks, BLOCK_SIZE) + ), + make_pass_through_transform(HEAD_SIZE_PADDED) + ), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}) + ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim + + return v_dram_merged; }(); + auto v_dram_window = make_tile_window( - v_dram, make_tuple(number{}, number{}), - {0, i_n1}); + v_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0}); - auto o_acc_tile = [&]() { return FmhaPipeline{}(q_dram_window, k_dram_window, From b37c35609020d128027a045e025a8de105859260 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 14 Oct 2025 09:36:28 +0000 Subject: [PATCH 16/88] fix q window origin --- .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index e8cd551417c..152af051f83 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -377,7 +377,7 @@ struct FmhaFwdV3Kernel auto q_dram_window = make_tile_window( q_dram, make_tuple(BLOCK_Q, HEAD_SIZE_PADDED), - {q_block_global_idx * num_queries_per_kv * HEAD_SIZE_PADDED, 0} + {0, 0} ); const auto k_dram = [&]() { @@ -414,7 +414,7 @@ struct FmhaFwdV3Kernel auto k_dram_window = make_tile_window( k_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0}); - const auto v_dram = [&]() { + const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), From c3d27abfb8dc5ff4c3ad235b882fc1b08536ec74 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 14 Oct 2025 09:49:54 +0000 Subject: [PATCH 17/88] fix q window --- .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 152af051f83..ba2da985f48 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -376,8 +376,8 @@ struct FmhaFwdV3Kernel // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) auto q_dram_window = make_tile_window( q_dram, - make_tuple(BLOCK_Q, HEAD_SIZE_PADDED), - {0, 0} + make_tuple(BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED), + {query_pos * num_queries_per_kv, 0} ); const auto k_dram = [&]() { From e1120fffb0c4d635bcbd7c859ffa354944cb6526 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 14 Oct 2025 09:58:27 +0000 Subject: [PATCH 18/88] pipeline api --- .../kernel/unified_attention_kernel.hpp | 13 ++++++++++++- .../pipeline/unified_attention_pipeline.hpp | 8 ++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index e8cd551417c..386319f28b9 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -96,7 +96,6 @@ struct FmhaFwdV3Kernel ck_tile::index_t num_seqs; // number of batches for q }; - using Kargs = UnifiedAttentionVarlenKargs; CK_TILE_HOST static constexpr Kargs MakeKargs( @@ -332,6 +331,8 @@ struct FmhaFwdV3Kernel index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; + index_t block_table_offset = seq_idx * kargs.block_table_stride; + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + kv_head_offset; @@ -445,10 +446,20 @@ struct FmhaFwdV3Kernel auto v_dram_window = make_tile_window( v_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0}); + // Create mask for causal attention + auto mask = [&]() { + return make_casual_mask(query_pos, BLOCK_Q, max_seq_prefix_len, BLOCK_SIZE); + }(); + + // Define LSE dram window (or use a dummy if not needed by pipeline) + auto lse_dram_window = make_dummy_tile_window(); + auto o_acc_tile = [&]() { return FmhaPipeline{}(q_dram_window, k_dram_window, v_dram_window, + block_tables_ptr, + block_table_offset, lse_dram_window, mask, kargs.scale_s, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b151b61028d..7bc3dc1d7dc 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -391,6 +391,8 @@ struct UnifiedAttentionPipeline [[maybe_unused]] const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile [[maybe_unused]] const VElementFunction& v_element_func, + const void* block_tables_ptr, + index_t block_table_offset, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, @@ -402,6 +404,8 @@ struct UnifiedAttentionPipeline { using namespace ck_tile; + index_t block_idx_prev = 0; + static_assert( std::is_same_v> && std::is_same_v> && @@ -1231,6 +1235,8 @@ struct UnifiedAttentionPipeline CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const void* block_tables_ptr, + index_t block_table_offset, LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale_s, @@ -1244,6 +1250,8 @@ struct UnifiedAttentionPipeline identity{}, v_dram_block_window_tmp, identity{}, + block_tables_ptr, + block_table_offset, lse_dram_block_window_tmp, identity{}, identity{}, From c87f2e3ca97a03902e0b237003f7e7a8c83b96d0 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 14 Oct 2025 09:59:47 +0000 Subject: [PATCH 19/88] o window change --- .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 17f56c08e8d..f752aeabd81 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -500,7 +500,7 @@ struct FmhaFwdV3Kernel auto o_dram_window = make_tile_window(o_dram, make_tuple(BLOCK_M, HEAD_SIZE_PADDED), - {q_block_global_idx * num_queries_per_kv * HEAD_SIZE_PADDED, 0}); + {query_pos * num_queries_per_kv, 0}); EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } From ec29289bb18df4ea918700ff28cc891832747395 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 14 Oct 2025 12:04:11 +0000 Subject: [PATCH 20/88] kv paging --- .../pipeline/unified_attention_pipeline.hpp | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 7bc3dc1d7dc..15a5e339ea9 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -404,7 +404,6 @@ struct UnifiedAttentionPipeline { using namespace ck_tile; - index_t block_idx_prev = 0; static_assert( std::is_same_v> && @@ -577,22 +576,26 @@ struct UnifiedAttentionPipeline } } + index_t i_total_loops = 0; + index_t kv_blk_idx = block_tables_ptr[block_table_offset + i_total_loops]; + index_t kv_blk_idx_prev = 0; + + auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}, + {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}, // TODO: hdim split? + {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); // prefetch K tile - index_t i_total_loops = 0; constexpr index_t k0_loops = kQKHeaddim / kK0; constexpr index_t k1_loops = kN0 / kK1; static_assert(1 == k0_loops); @@ -685,7 +688,10 @@ struct UnifiedAttentionPipeline /// FIXME: use the future-predicting method to move the window // move K tile windows - move_tile_window(k_dram_window, {kN0, 0}); + auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(), + k_dram_window.get_window_lengths(), + {(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0}, + Policy::template MakeVDramTileDistribution()); }; auto K_lds_load = [&](auto k_lds_read_idx) { @@ -696,7 +702,10 @@ struct UnifiedAttentionPipeline async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); /// FIXME: use the future-predicting method to move the window - move_tile_window(v_dram_window, {kK1, 0}); + auto v_dram_window = make_tile_window(v_dram_window.get_bottom_tensor_view(), + v_dram_window.get_window_lengths(), + {(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0}, + Policy::template MakeVDramTileDistribution()); }; auto V_lds_load = [&](auto v_lds_read_idx) { From b940a75328a08f5bc8f1525fd2de074511203e37 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 14 Oct 2025 12:19:20 +0000 Subject: [PATCH 21/88] Comments --- .../unified_attention/pipeline/unified_attention_pipeline.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 15a5e339ea9..8d37d422916 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -685,7 +685,7 @@ struct UnifiedAttentionPipeline auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); - + // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx as the index /// FIXME: use the future-predicting method to move the window // move K tile windows auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(), From 4d232d59cc50891d88157a72adef3c43dc68b9fd Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 14 Oct 2025 12:34:33 +0000 Subject: [PATCH 22/88] fix seq_len -> cur_batch_query_len --- .../kernel/unified_attention_kernel.hpp | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ba2da985f48..f4f4ebe726f 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -29,6 +29,7 @@ struct FmhaFwdV3Kernel using VDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; @@ -339,14 +340,14 @@ struct FmhaFwdV3Kernel ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - index_t seq_len_padded = integer_divide_ceil(seq_len, BLOCK_Q) * BLOCK_Q; - const bool is_seq_len_aligned = (seq_len % BLOCK_Q == 0); + index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; + const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_base = make_naive_tensor_view( q_ptr, - make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), + make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, number<1>{}); @@ -355,14 +356,14 @@ struct FmhaFwdV3Kernel q_dram_base, // block sizes make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), - sequence{} + sequence{} ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( q_dram_pad, make_tuple( make_merge_transform( - make_tuple(seq_len_padded, num_queries_per_kv) + make_tuple(query_len_padded, num_queries_per_kv) ), make_pass_through_transform(HEAD_SIZE_PADDED) ), @@ -445,6 +446,18 @@ struct FmhaFwdV3Kernel auto v_dram_window = make_tile_window( v_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0}); + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.BLOCK_M, + kargs.BLOCK_SIZE, + cur_batch_query_len, + seq_len, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{cur_batch_query_len, seq_len}; + }(); + auto o_acc_tile = [&]() { return FmhaPipeline{}(q_dram_window, k_dram_window, From 853fa21566f6a1fe4237289c61db772b1bbfeb3f Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 15 Oct 2025 11:58:44 +0000 Subject: [PATCH 23/88] Example boostrap --- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 27 +- .../01_unified_attention/CMakeLists.txt | 238 +++ .../ck_tile/01_unified_attention/README.md | 159 ++ example/ck_tile/01_unified_attention/bias.hpp | 114 ++ .../01_unified_attention/codegen/__init__.py | 0 .../codegen/cmake_config.py | 5 + .../codegen/cpp_symbol_map.py | 138 ++ .../codegen/ops/__init__.py | 0 .../codegen/ops/fmha_batch_prefill.py | 633 ++++++ .../codegen/ops/fmha_bwd.py | 929 +++++++++ .../codegen/ops/fmha_fwd.py | 783 ++++++++ .../codegen/ops/fmha_fwd_appendkv.py | 376 ++++ .../codegen/ops/fmha_fwd_splitkv.py | 885 ++++++++ .../codegen/ops/fmha_pagedkv_prefill.py | 591 ++++++ .../01_unified_attention/codegen/utils.py | 21 + .../example_fmha_fwd_v3.cpp | 616 ++++++ .../ck_tile/01_unified_attention/generate.py | 132 ++ .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 14 + .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 14 + .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 14 + .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 14 + example/ck_tile/01_unified_attention/mask.hpp | 167 ++ .../01_unified_attention/misc/gamc.png | Bin 0 -> 30073 bytes .../ck_tile/01_unified_attention/rotary.hpp | 84 + .../script/benchmark_bwd.sh | 20 + .../script/benchmark_fwd.sh | 53 + .../script/benchmark_fwd_v3.sh | 42 + .../script/fmha_bwd_known_fails_gfx90a.txt | 0 .../script/fmha_bwd_known_fails_gfx942.txt | 0 .../script/fmha_bwd_known_fails_gfx950.txt | 0 .../script/fmha_fwd_known_fails_gfx90a.txt | 0 .../script/fmha_fwd_known_fails_gfx942.txt | 0 .../script/fmha_fwd_known_fails_gfx950.txt | 0 .../script/run_full_test.sh | 48 + .../script/smoke_test_bwd.sh | 90 + .../script/smoke_test_fwd.sh | 281 +++ .../unified_attention.cpp | 60 + .../unified_attention.hpp | 74 + .../unified_attention_impl.hpp | 158 ++ .../unified_attention_runner.hpp | 1789 +++++++++++++++++ .../ck_tile/01_unified_attention/utils.hpp | 244 +++ .../kernel/unified_attention_kernel.hpp | 55 +- .../pipeline/tile_unified_attention_shape.hpp | 68 + .../tile_unified_attention_traits.hpp | 24 + .../pipeline/unified_attention_pipeline.hpp | 106 +- .../unified_attention_pipeline_problem.hpp | 6 +- 46 files changed, 8967 insertions(+), 105 deletions(-) create mode 100644 example/ck_tile/01_unified_attention/CMakeLists.txt create mode 100644 example/ck_tile/01_unified_attention/README.md create mode 100644 example/ck_tile/01_unified_attention/bias.hpp create mode 100644 example/ck_tile/01_unified_attention/codegen/__init__.py create mode 100644 example/ck_tile/01_unified_attention/codegen/cmake_config.py create mode 100644 example/ck_tile/01_unified_attention/codegen/cpp_symbol_map.py create mode 100644 example/ck_tile/01_unified_attention/codegen/ops/__init__.py create mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_batch_prefill.py create mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_bwd.py create mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd.py create mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_appendkv.py create mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_splitkv.py create mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_pagedkv_prefill.py create mode 100644 example/ck_tile/01_unified_attention/codegen/utils.py create mode 100644 example/ck_tile/01_unified_attention/example_fmha_fwd_v3.cpp create mode 100644 example/ck_tile/01_unified_attention/generate.py create mode 100644 example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_mask.cpp create mode 100644 example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_nmask.cpp create mode 100644 example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_mask.cpp create mode 100644 example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_nmask.cpp create mode 100644 example/ck_tile/01_unified_attention/mask.hpp create mode 100644 example/ck_tile/01_unified_attention/misc/gamc.png create mode 100644 example/ck_tile/01_unified_attention/rotary.hpp create mode 100755 example/ck_tile/01_unified_attention/script/benchmark_bwd.sh create mode 100755 example/ck_tile/01_unified_attention/script/benchmark_fwd.sh create mode 100755 example/ck_tile/01_unified_attention/script/benchmark_fwd_v3.sh create mode 100644 example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt create mode 100644 example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx942.txt create mode 100644 example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx950.txt create mode 100644 example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt create mode 100644 example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx942.txt create mode 100644 example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx950.txt create mode 100755 example/ck_tile/01_unified_attention/script/run_full_test.sh create mode 100755 example/ck_tile/01_unified_attention/script/smoke_test_bwd.sh create mode 100755 example/ck_tile/01_unified_attention/script/smoke_test_fwd.sh create mode 100644 example/ck_tile/01_unified_attention/unified_attention.cpp create mode 100644 example/ck_tile/01_unified_attention/unified_attention.hpp create mode 100644 example/ck_tile/01_unified_attention/unified_attention_impl.hpp create mode 100644 example/ck_tile/01_unified_attention/unified_attention_runner.hpp create mode 100644 example/ck_tile/01_unified_attention/utils.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp create mode 100644 include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index 194675f9627..b067a9acae7 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -20,31 +20,22 @@ #include "fmha_fwd_v3.hpp" #include "mask.hpp" -#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ +#define INST_UNIFIED_ATTENTION_V3_DISPATCH(kernel_traits) \ template <> \ - std::pair fmha_fwd_v3_kernel_dispatch( \ - const fmha_fwd_v3_args& args, const stream_config& config) \ + std::pair unified_attention_kernel_dispatch( \ + const unified_attention_args& args, const stream_config& config) \ { \ return std::make_pair(true, \ - fmha_fwd_v3_kernel_launch(args, config)); \ + unified_attention_kernel_launch(args, config)); \ } namespace ck_tile { -template -struct fmha_fwd_v3_problem_traits; +template +struct unified_attention_problem_traits; template <> -struct fmha_fwd_v3_problem_traits -{ - using qkvp_dtype = ck_tile::half_t; - using acc_dtype = float; - using o_dtype = ck_tile::half_t; - using lse_dtype = float; -}; - -template <> -struct fmha_fwd_v3_problem_traits +struct unified_attention_problem_traits { using qkvp_dtype = ck_tile::bf16_t; using acc_dtype = float; @@ -52,8 +43,8 @@ struct fmha_fwd_v3_problem_traits using lse_dtype = float; }; -template -struct fmha_fwd_v3_kernel_traits +template +struct unified_attention_kernel_traits { static constexpr auto date_type = DataType; static constexpr bool is_variable_seqlen = IsVariableSeqlen; diff --git a/example/ck_tile/01_unified_attention/CMakeLists.txt b/example/ck_tile/01_unified_attention/CMakeLists.txt new file mode 100644 index 00000000000..b8ca26193d6 --- /dev/null +++ b/example/ck_tile/01_unified_attention/CMakeLists.txt @@ -0,0 +1,238 @@ +set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +# Currently only gfx9 archs are supported by FMHA +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") +if(NOT INST_TARGETS) + message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + +# validate user-specified fmha_fwd API list +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") +set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") +if(BUILD_TESTING) + # Build instances of all APIs for tests + set(FMHA_FWD_ENABLE_APIS "all") +endif() +if(FMHA_FWD_ENABLE_APIS STREQUAL "all") + set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) +endif() + +foreach(api ${FMHA_FWD_ENABLE_APIS}) + if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) + message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.") + endif() +endforeach() + +# "fwd" is a must-have api for the fmha_fwd example, add it if not specified +if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) + list(PREPEND FMHA_FWD_ENABLE_APIS "fwd") +endif() + +file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py +) +# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change +set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") + +string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +set(FMHA_FWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FMHA_FWD_APIS} + --optdim 32,64,128,256 + # --filter fmha_fwd... +) +set(FMHA_BWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api bwd + --receipt 3 + --optdim 32,64,96,128,256 + # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... +) + +# Reduce building time by disabling instances that are not currently used in the gtests +# TODO: Consider to use a special receipt for testing only, or even two receipts: a small subset of +# instances for quick CI runs and a larger subset for scheduled runs (the tests skip tests when +# there is no corresponding instance for parameters). +if(BUILD_TESTING) + # Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) +endif() + +# generate a list of kernels, but not actually emit files at config sta +execute_process( + COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") +endif() + +execute_process( + COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") +endif() + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory +# as current cmake list, otherwise will not figure out the dependency properly +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${FMHA_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} +) + +add_custom_command( + OUTPUT ${FMHA_BWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} +) + +set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") +set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") + +message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}") +add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) +target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}") +add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) +target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS}) +set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS) +set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS) +set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS) +set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +# ... because they are auto-generated +list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) + +# Allow comparing floating points directly in order to check sentinel values +list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) + +# NOTE: this is dangerous since will change the whole kernel to flush denormals +# WIP with compiler team for an exp2 intrinsic..., then remove this +if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 ON) +endif() + +if(FMHA_FWD_FAST_EXP2) + list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) +else() + list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0) +endif() +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero) + +# conditionally enable call to the fwd_splitkv API in fmha_fwd example and tests +if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) +else() + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) +endif() + +# conditionally enable call to the fwd_appendkv API in fmha_fwd example and tests +if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) +else() + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) +endif() + +# conditionally enable call to the pagedkv_prefill API in fmha_fwd example and tests +if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) +else() + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) +endif() + +# conditionally specify the use of OCP_FP8 +if(CK_USE_OCP_FP8) + list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) + +target_compile_options(${FMHA_FWD_INSTANCES} + PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} + INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) +target_compile_options(${FMHA_BWD_INSTANCES} + PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS} + INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS}) + +set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") +set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") + +message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") +# not using add_example_executable() to add this target, since we don't want this to be included in +# "make all/install/check" +add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp) +target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES}) +target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") +# not using add_example_executable() to add this target, since we don't want this to be included in +# "make all/install/check" +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) +target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) +target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +# add fmha_fwd_v3 example +set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") +message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") + +add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" +) +target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE + fmha_fwd_v3.cpp + ${FMHA_FWD_V3_INSTANCES} +) + +set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) +list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + -fgpu-flush-denormals-to-zero + -Wno-undefined-func-template + --save-temps +) +set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) + +check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) +if(HAS_DISABLE_PACKED_FP32) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + -mllvm --amdgpu-disable-packed-fp32=1 + ) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS + -DCK_TILE_DISABLE_PACKED_FP32=1 + ) +endif() + +target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/01_unified_attention/README.md b/example/ck_tile/01_unified_attention/README.md new file mode 100644 index 00000000000..2b872cb9b51 --- /dev/null +++ b/example/ck_tile/01_unified_attention/README.md @@ -0,0 +1,159 @@ +# fused multi-head attention + +This folder contains example for fmha(fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +../script/cmake-ck-dev.sh ../ +make tile_example_fmha_fwd -j +``` +This will result in an executable `build/bin/tile_example_fmha_fwd` + +## kernel +The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. + +There are 2 template parameters for this kernel template. +* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). +* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. + +## codegen +To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. + +## executable +`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change) +``` +args: + -v weather do CPU validation or not (default:1) + -mode kernel mode. 0:batch, 1:group (default:0) + -b batch size (default:2) + -h num of head, for q (default:8) + -h_k num of head, for k/v, -1 means equal to h (default:-1) + if not equal to h, then this is GQA/MQA case + -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) + total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) + -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) + also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode) + -s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1) + Provide positive strides per-batch to simulate physical padding on Q + -s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1) + for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride + along seqlen, instead of packed, same as xformer kv_padding, + must be greater than or equal to s_k + -d head dim for q, k (default:128) + -d_v head dim for v, -1 means equal to d (default:-1) + -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) + note when squant=1, this value will be modified by range_q/k + -range_q per-tensor quantization range of q. used if squant=1. (default:16) + -range_k per-tensor quantization range of k. used if squant=1. (default:16) + -range_v per-tensor quantization range of v. used if squant=1. (default:16) + -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) + -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) + -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) + 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. + calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o + -iperm permute input (default:1) + if true, will be b*h*s*d, else b*s*h*d + -operm permute output (default:1) + -bias n or 0, no bias (default:n) + e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s + a(libi) or 2, alibi with 1*h. a:1, b*h + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) + 't', top-left causal mask, 'b', bottom-r causal mask + 't:l,r', top-left sliding window attn(swa) with FA style left right size + 'b:l,r', bottom-r sliding window attn(swa) with FA style left right size + 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa + 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa + 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) + -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) + -lse 0 not store lse, 1 store lse (default:0) + -kname if set to 1 will print kernel name (default:0) + -init init method. ui, uniform random int, ni, normalized random int (default:uf) + uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization + -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) + -drop_seed seed for random number generator (default:1) +-drop_offset offset for random number generator (default:0) + -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) + -num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1) + -warmup number of iterations before benchmark the kernel (default:5) + -repeat number of iterations to benchmark the kernel (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:fmha_fwd.json) + -q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override +-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override +``` +Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. +Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with + batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case + +## Padding Examples +Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively. + +Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively. + +## support features +Currently we are still in rapid development stage, so more features/optimizations will be coming soon. + +### hdim +Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default) + +### group/batch mode +Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's term), by setting `-mode` = `0` or `1`. In `group mode` different kind of attention mask is also supported(see below) + +### MQA/GQA +By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers. + +### input/output permute, and `b*s*3*h*d` +If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`. + +### attention bias +Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number. + +### alibi +alibi is supported + +### lse +For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` + +### vlayout +We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts. + +### attention mask +we support `causal mask` and `sliding window attention(swa)` mask in both batch and group mode, either from top-left or bottom-right. +Underneath, we unify the mask expression into `generic attention mask coordinate`, providing an uniformed approach for each batch to locate the corresponding pixel need to be masked out. +![](misc/gamc.png) + +Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline. + +| mask case| cmdline | FA style | xformer style | +|----------|:-------------:|:-------------:|:-------------:| +| no mask | `-mask=0`(default) | | | +| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` | +| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` | +| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` | +| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` | + +Note FA use bottom-right by default to express swa case, here we require you explicitly specify top-left/bottom-right. + +### dropout +TBD + +### sequence padding and variable length support +We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths. + +**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment. + +**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste. + +Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. + +## FP8 experimental support +As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. + +Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. diff --git a/example/ck_tile/01_unified_attention/bias.hpp b/example/ck_tile/01_unified_attention/bias.hpp new file mode 100644 index 00000000000..c07232a13a9 --- /dev/null +++ b/example/ck_tile/01_unified_attention/bias.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "e" || t == "elementwise") + { + info.type = bias_enum::elementwise_bias; + info.rank_info = std::stoi(v); + if(info.rank_info < 0 || info.rank_info > 2) + throw std::invalid_argument("invalid bias rank: " + str); + } + else if(t == "a" || t == "alibi") + { + info.type = bias_enum::alibi; + info.rank_info = std::stoi(v); + if(info.rank_info < 0 || info.rank_info > 1) + throw std::invalid_argument("invalid bias rank: " + str); + } + else + { + throw std::invalid_argument("invalid bias value: " + str); + } + } + else if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str == "1" || str == "e" || str == "elementwise") + { + info.type = bias_enum::elementwise_bias; + } + else if(str == "2" || str == "a" || str == "alibi") + { + info.type = bias_enum::alibi; + } + else + { + throw std::invalid_argument("invalid bias value: " + str); + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/01_unified_attention/codegen/__init__.py b/example/ck_tile/01_unified_attention/codegen/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_unified_attention/codegen/cmake_config.py b/example/ck_tile/01_unified_attention/codegen/cmake_config.py new file mode 100644 index 00000000000..03ebfd67021 --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/cmake_config.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file diff --git a/example/ck_tile/01_unified_attention/codegen/cpp_symbol_map.py b/example/ck_tile/01_unified_attention/codegen/cpp_symbol_map.py new file mode 100644 index 00000000000..81d34484a54 --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/cpp_symbol_map.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +FWD_DTYPE_MAP = { + "fp32" : "FmhaFwdFp32", + "fp16" : "FmhaFwdFp16", + "bf16" : "FmhaFwdBf16", + "fp8" : "FmhaFwdFp8", + "fp8fp16": "FmhaFwdFp8Fp16", + "fp8bf16": "FmhaFwdFp8Bf16", + "fp8fp32": "FmhaFwdFp8Fp32" +} + +BWD_DTYPE_MAP = { + "fp32": "FmhaBwdFp32", + "fp16": "FmhaBwdFp16", + "bf16": "FmhaBwdBf16" +} + +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +_MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + +_MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +def get_mask_map(mask : str): + if mask == "generic": + return _MASK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_MAP + else: + assert False + return None + +_MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +_MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + +def get_mask_check_map(mask : str): + if mask == "generic": + return _MASK_CHECK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + +DROPOUT_MAP = { + "no" : "ck_tile::BlockDropoutBwd", + "dropout_wg32" : "ck_tile::BlockDropoutBwd", + "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", + "dropout_wg16" : "ck_tile::BlockDropoutBwd", + "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" +} + +DROPOUT_CHECK_MAP = { + "no" : "t.has_dropout == false", + "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true", + "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true", +} + +ROPE_MAP = { + "no" : "ck_tile::RotaryEmbeddingEnum::NONE", + "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", + "half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED" +} + +ROPE_CHECK_MAP = { + "no" : "rope_enum::none", + "inter" : "rope_enum::interleaved", + "half" : "rope_enum::half_rotated" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", + "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", +} + +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false", + True : "true", + False : "false", +} diff --git a/example/ck_tile/01_unified_attention/codegen/ops/__init__.py b/example/ck_tile/01_unified_attention/codegen/ops/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_batch_prefill.py new file mode 100644 index 00000000000..e2f69fa49ab --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/ops/fmha_batch_prefill.py @@ -0,0 +1,633 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_BATCH_PREFILL_PIPELINE_MAP = { + "qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + false, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + +#include + +template<> +float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_batch_prefill_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp" +FMHA_FWD_API=""" +#include + +namespace {{ +bool get_num_cus(unsigned& num_cu) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cu = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + return fmha_batch_prefill_(s, a); + }} +""" + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return 'true' + else: + return f'{self.bool_expr}' + + def __and__(self, other): + return CppConstraint(f'({str(self)}) && ({str(other)})') + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + constraint : CppConstraint + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) + +class KernelComponentFactory: + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + @staticmethod + def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + else: + assert False + return pipelines + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == 'fp16' or dtype == 'bf16': + if 128 in result.keys(): + result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + return result + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = CustomFactory.get_hdim_tile_size_dict(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_bwd.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_bwd.py new file mode 100644 index 00000000000..7319ef7ea1a --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/ops/fmha_bwd.py @@ -0,0 +1,929 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Tuple, Dict, Literal, Any +from collections import defaultdict + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * +from codegen.utils import update_file + + +FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_bwd.hpp" +""" + +FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile:: + sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; +using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; +using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; +using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>; +using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>; +using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaBwdTraits<{F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; +using fmha_dropout_{F_idx} = {F_dropout}; + +using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_{F_idx}, + {F_mode}, + {F_deterministic}, + fmha_mask_{F_idx}, + fmha_dropout_{F_idx}, + {F_trload}, + fmha_bwd_trait_{F_idx}>; + +using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline; + +using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + false, + ({F_dpad} > 0)>>; + +using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, + false, + ({F_dvpad} > 0)>>; + +using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, + false, + ({F_dpad} > 0)>>; + +using fmha_bwd_dq_dk_dv_kernel_{F_idx} = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, + {F_dtype}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_dropout_{F_idx}, + {F_bias}, + {F_dbias}, + {F_dpad}, + {F_dvpad}, + {F_deterministic}, + {F_trload}, + {F_maxq}, + {F_bn0}>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template <> +int fmha_bwd_dq_dk_dv_maxq_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::kMaxSeqLenQ; +}} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" +FMHA_BWD_API=""" +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + if constexpr (!std::is_same_v) + {{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} + ); + }} + else + {{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); + }} +}} + +template <> +float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); + float r = -1; +{F_dispatch} + return r; +}} +""" + +def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str: + lines = [ + f"{'if' if if_ == 0 else 'else if'}({F_cond})", + "{", + *[' ' + line for line in F_body.split('\n') if line.strip() != ''], + "}", + ] + return '\n'.join(' ' * indent + line for line in lines) + '\n' + + +FMHA_BWD_API_INNER_DISPATCH=""" +{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && + ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; + r = fmha_bwd_>(s, a); + return r; +}} +""" + +# M0 size for 1d kernels (dot/convert) +M0_1D = 64 + +# GEMM0: Q@K=S^T +# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) +# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) +# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) +# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) +# Is it necessary to distinguish between K0~K4? +@dataclass(frozen=True) +class FmhaBwdDQDKDVTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along gemm0 unroll(F_bhdq) + F_bk1 : int # tile size along gemm1 unroll(F_bm0) + F_bk2 : int # tile size along gemm2 unroll(F_bhdv) + F_bk3 : int # tile size along gemm3 unroll(F_bm0) + F_bk4 : int # tile size along gemm4 unroll(F_bn0) + F_bhdq : int # q head_dim + F_bhdv : int # v head_dim + F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2 + F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4 + F_rk2 : int # number of warps along k seqlen (not used) in gemm4 + F_wm0 : int # warp size along m in gemm0/gemm2/gemm4 + F_wn0 : int # warp size along n in gemm0/gemm2/gemm4 + F_wk0 : int # warp size along k in gemm0/gemm2/gemm4 + F_wm1 : int # warp size along m in gemm1/gemm3 + F_wn1 : int # warp size along n in gemm1/gemm3 + F_wk1 : int # warp size along k in gemm1/gemm3 + F_occupancy : int # occupancy + max_seq_q : int = 0 + + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}" + +@dataclass(frozen=True) +class FmhaBwdDQDKDVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_dpad : Literal[0, 8 ,1] + F_dvpad : Literal[0, 8 ,1] + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_deterministic : str # + mask_impl : str # + F_trload : str # + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = BWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_dpad = self.F_dpad, + F_dvpad = self.F_dvpad, + F_bias = BIAS_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = DROPOUT_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_deterministic = BOOL_MAP[self.F_deterministic], + F_trload = BOOL_MAP[self.F_trload], + F_maxq = self.F_tile.max_seq_q + ) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_dpad : n += f'd{self.F_dpad}' + if self.F_dvpad : n += f'dv{self.F_dvpad}' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_dbias == 't' : n += '_dbias' + else: n += '_ndbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_dropout != 'no' : n += f'_{self.F_dropout}' + else: n += '_ndropout' + + if self.F_deterministic == 't' : n += '_deterministic' + else: n += '_ndeterministic' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +# TODO: design a more practical way to do it +# this is current supported tile size. +def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + if dtype == 'fp32' and tr_load == 'f': + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': + return [ + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + ] + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': + return [ + FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), + FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), + ] + else: + return [] + +FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_dot_do_o_trait_{F_idx} = + ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>; + +using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = M0 = */ 64, + {F_hdim}, + {F_mode}, + fmha_bwd_dot_do_o_trait_{F_idx}>; + +using fmha_bwd_dot_do_o_{F_idx} = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_{F_idx} = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_{F_idx} = + fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass(frozen=True) +class FmhaBwdOGradDotOKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_spad : str # true/false + F_dvpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = BWD_DTYPE_MAP[self.F_dtype], + F_spad = BOOL_MAP[self.F_spad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + else: n += '_npad' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_convert_dq_trait_{F_idx} = + ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>; + +using fmha_bwd_convert_dq_pipeline_problem_{F_idx} = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + {F_bm0}, + {F_bn0}, + {F_hdim}, + {F_mode}, + {F_deterministic}, + fmha_bwd_convert_dq_trait_{F_idx}>; + +using fmha_bwd_convert_dq_{F_idx} = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_{F_idx} = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, + {F_dtype}, + {F_mode}, + {F_spad}, + {F_dpad}, + {F_deterministic}, + {F_bn0}>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass(frozen=True) +class FmhaBwdConvertQGradKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_spad : str # true/false + F_dpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int # + F_deterministic : str # + disabled : bool # sometimes this kernel is not used + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = BWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_bm0, + F_bn0 = self.F_bn0, + F_spad = BOOL_MAP[self.F_spad], + F_dpad = BOOL_MAP[self.F_dpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy, + F_deterministic = BOOL_MAP[self.F_deterministic]) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dpad == 't' : n += 'd' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + else: n += '_npad' + if self.F_deterministic == 't' : n += '_deterministic' + else: n += '_ndeterministic' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +@dataclass(frozen=True) +class FmhaBwdApiTrait: + idx : int # this is not a tunable, but a counter to differentiate symbol + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : int + dtype : str # data type + mode : str # value from MODE_MAP + tile : FmhaBwdDQDKDVTileSize + mask : str + bias : str + dbias : str + dropout : str + spad1d : str # spad for 1d kernels (dot/convert) + dpad : Literal[0, 1, 8] + dvpad : Literal[0, 1, 8] + deterministic : str + mask_impl : str + tr_load : str + + @property + def bm0(self) -> int: + return self.tile.F_bm0 + @property + def bn0(self) -> int: + return self.tile.F_bn0 + @property + def bhdq(self) -> int: + return self.tile.F_bhdq + @property + def bhdv(self) -> int: + return self.tile.F_bhdv + + @property + def scheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad1d == 't': + return f'a.seqlen_q % {M0_1D} != 0' + else: # self.spad1d == 'f' + return f'a.seqlen_q % {M0_1D} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0' + else: return f'a.hdim_q % {self.dpad} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0' + else: return f'a.hdim_v % {self.dvpad} == 0' + + @property + def extra_cond(self) -> str: + if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128: + return "&& (a.seqlen_k <= 256)" + else: + return "" + + @property + def convert_dq_bn0(self) -> int: + return self.tile.F_bn0 if self.deterministic == 't' else 0 + + @property + def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + F_dvpad = 't' if self.dvpad else 'f' + return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, + F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + + @property + def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: + return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, + F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, + F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load) + + @property + def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + F_dpad = 't' if self.dpad else 'f' + return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, + F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, + F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), + F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))) + + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: + # TODO: do we need to check duplication? + self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait)) + + @staticmethod + def if_(i: int) -> str: + return 'if' if i == 0 else 'else if' + + def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str: + inners = "" + i = 0 + for trait in traits: + inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], + F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], + F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad, + F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, + F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, + F_convert_dq_bn0=trait.convert_dq_bn0) + i += 1 + return inners + + @staticmethod + def trload_sort_key(tf): + return 0 if tf == 't' else 1 # sort 't' before 'f' + + @staticmethod + def max_seq_q_sort_key(max_seq_q): + return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end + + @staticmethod + def max_seq_q_cond(max_seq_q: int) -> str: + if max_seq_q == 0: + return 'true /* no seqlen_q limit */' + else: + return f'a.seqlen_q <= {max_seq_q}' + + @staticmethod + def dtype_cond(dtype: str) -> str: + return f't.data_type.compare("{dtype}") == 0' + + @staticmethod + def hdim_cond(hdim: int) -> str: + return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}' + + @property + def api(self) -> str: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true /* no trload requirement */" + } + per_tr_load = '' + for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key): + per_max_seq_q = '' + for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key): + per_dtypes = '' + for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]): + per_hdim_case = '' + for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]): + traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim] + inners = self._api_innders(traits) + per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners) + per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case) + per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes) + per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += ' (void)t ; (void)s ; (void)a; (void)has_load_tr;' + result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) + return result.replace('\n\n', '\n') + +def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: + if filter_list == '': + filter_list = '*@*@*' + filters = filter_list.split('@') + filters.extend(['*'] * (3 - len(filters))) + filter_dot_do_o = filters[0] + filter_convert_dq = filters[1] + filter_dq_dk_dv = filters[2] + + # use dict as ordered set + gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {} + gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {} + gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} + api_pool = FmhaBwdApiPool(mask_impl) + + for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): + tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) + dpad_options = itertools.product(*([[0, 8, 1]] * 2)) + tf = ["t", "f"] + for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product( + tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf): + assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" + hdim = tile.F_bhdq + if (mode == "group") and (spad1d == "f"): + continue + if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0: + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + if ("wg32" in dropout): + continue + if tr_load == "t": + continue # tr_load cannot work with dpad or dvpad + else: # tr_load == "f" + # do not generate instance with only 1 of dpad/dvpad being 8 + if dpad != dvpad and dpad == 8: + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load) + + if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): + continue + if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): + continue + if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): + continue + + # Flash attention integration + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + elif receipt == 3: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue + # Aiter (mha_bwd) integration + elif receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 400: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= dpad == dvpad + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [64, 128] + cond &= dpad == dvpad + cond &= mode == 'batch' + cond &= bias == 'no' + cond &= dropout == 'no' + cond &= mask == 's_no' + cond &= deterministic == "f" + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + + gen_dot_do_o[t.dot_do_o_kernel] = True + gen_dq_dk_dv[t.dq_dk_dv_kernel] = True + if not t.convert_dq_kernel.disabled: + gen_convert_dq[t.convert_dq_kernel] = True + api_pool.register_dq_dk_dv_traits(t) + + return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) + +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) + update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api) + for k in kernels_dot_do_o: + update_file(output_dir / k.filename, k.template) + for k in kernels_convert_dq: + update_file(output_dir / k.filename, k.template) + for k in kernels_dq_dk_dv: + update_file(output_dir / k.filename, k.template) + + +def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: + _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( + filter_list, receipt, mask_impl, optdim_list + ) + with file_path.open("a") as f: + for k in kernels_dot_do_o: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_dq_dk_dv: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_convert_dq: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd.py new file mode 100644 index 00000000000..f898d5f7b26 --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd.py @@ -0,0 +1,783 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * +from codegen.utils import update_file + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 48 : 48, + 64 : 64, + 96 : 128, + 128: 128, + 192: 192, + 256: 256 +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}, + {F_skip}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdKernel; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return 'true' + else: + return f'{self.bool_expr}' + + def __and__(self, other): + return CppConstraint(f'({str(self)}) && ({str(other)})') + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + skip : str + tr_load : str + constraint : CppConstraint + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag in ['qr_async', 'qr_async_trload']: + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr', 'qs']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + def seqtune(self, max_bm0 : int) -> str: + if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/' + else: + return f'a.seqlen_q <= {self.bm0}' + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' + elif self.pipeline_tag in ['qr', 'qs']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' + elif self.pipeline_tag == 'qr_async_trload': + if self.skpad == 't' : return 'true' + else: return 'true' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_skip : str # true/false + F_trload : str # true/false + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' + + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true" + } + + per_tr_load =str() + for tr_load in ["t", "f"]: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + max_bm0 = max((t.bm0 for t in traits), default=0) + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_trload = BOOL_MAP[self.F_pipeline.F_trload]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp32': + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } + elif dtype == 'fp16' or dtype == 'bf16': + return { + (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + elif dtype == 'fp8' or dtype == 'fp8bf16': + return { + (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } + elif dtype == 'fp8fp32': + return { + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + pipelines = [] + if dtype in ['fp32']: + squant = 'f' + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + elif dtype in ['fp16', 'bf16']: + squant = 'f' + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + if hdim == 256 and hdim_v == 256: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + else: + if bias == "bias": + # TODO: rocm 6.2 compiler problem if using qr_async for bias case + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) + if receipt == 1 and bias != "bias": + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: + # no need lse/dropout kernels + for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'bf8']: + # TODO + None + else: + assert False + return pipelines + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == 'fp16' or dtype == 'bf16': + if (128, 128) in result.keys(): + result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + return result + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory + + for dtype in FWD_DTYPE_MAP.keys(): + d = factory.get_hdim_tile_size_dict(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, next_tile in zip(tiles, tiles[1:]): + assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0' + for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if (hdim, hdim_v) == (192, 128): + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': + continue + if dtype != 'fp32': + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + if dtype == 'fp8bf16': + cond &= hdim == 128 + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + if dtype == 'fp8bf16': + cond &= hdim == 128 + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] + cond &= pipeline.F_vlayout == 'row' + if dtype == 'fp8bf16': + cond &= hdim == 128 + if not cond: + continue + elif receipt == 888: + cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32'] + cond &= pipeline.F_vlayout == 'row' + cond &= hdim == 128 + if not cond: + continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [48, 128] + cond &= mode == 'batch' + cond &= pipeline.F_bias == 'no' + cond &= pipeline.F_lse == 'f' + cond &= pipeline.F_dropout == 'f' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + cond &= pipeline.F_mask == 's_no' + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_appendkv.py new file mode 100644 index 00000000000..38491b56c40 --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_appendkv.py @@ -0,0 +1,376 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + +from codegen.ops.fmha_fwd import ( + FmhaFwdApiTrait, + DTYPE_BITS, + FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_DTYPE, + FMHA_FWD_API_PER_HDIM_CASE, +) + + +FMHA_FWD_APPENDKV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_occupancy}>; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + {F_bs}, + {F_bsk}, + {F_bd}, + {F_bdv}, + {F_vlayout}, + {F_rope}, + {F_pagedkv}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< + fmha_pipeline_problem_{F_idx}>; + +using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel; + +using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, + {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; + +#include + +template<> +float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp" +FMHA_FWD_APPENDKV_API=""" +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{ + using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; + return fmha_fwd_appendkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdAppendKVApiTrait: + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + bs : int # tile size along q seqlen + bsk : int # tile size along k seqlen + bd : int # tile size along qk gemm unroll + bdv : int # tile size along kv gemm unroll + vlayout : str + spad : str + skpad : str + dpad : str + dvpad : str + rope : str # key from ROPE_MAP + pagedkv : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\ + f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}' + + @property + def scheck(self) -> str: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/' + else : return f'a.seqlen_q % {self.bs} == 0' + + @property + def skcheck(self) -> str: + # we do not check all the values in a.seqlen_k_ptr + return 'true' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bd} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {self.bdv} == 0' + +@dataclass +class FmhaFwdAppendKVPipeline: + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_rope : str # key from ROPE_MAP + F_pagedkv : str # t/f + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_rope != 'no': n += f'_{self.F_rope}' + if self.F_pagedkv == 't': n += '_pagedkv' + return n + +class FmhaFwdAppendKVApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], + F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdAppendKVTileSize: + F_bs : int # tile size along q seqlen + F_bsk : int # tile size along k seqlen + F_bd : int # tile size along qk gemm unroll + F_bdv : int # tile size along kv gemm unroll + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdAppendKVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaFwdAppendKVTileSize + F_pipeline : FmhaFwdAppendKVPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_APPENDKV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bs = self.F_tile.F_bs, + F_bsk = self.F_tile.F_bsk, + F_bd = self.F_tile.F_bd, + F_bdv = self.F_tile.F_bdv, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_rope = ROPE_MAP[self.F_pipeline.F_rope], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], + F_occupancy = self.F_tile.F_occupancy) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdAppendKVApiTrait: + return FmhaFwdAppendKVApiTrait( + hdim=str(self.F_hdim), + dtype=self.F_dtype, + bs=self.F_tile.F_bs, + bsk=self.F_tile.F_bsk, + bd=self.F_tile.F_bd, + bdv=self.F_tile.F_bdv, + vlayout=self.F_pipeline.F_vlayout, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + rope=self.F_pipeline.F_rope, + pagedkv=self.F_pipeline.F_pagedkv) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1), + '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1) + } + else: + return None + +def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while + # applying rotary embedding, so I just use 't' in inter/half pipelines + for vlayout in ['row', 'col']: + for pagedkv in ["t", "f"]: + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv)) + + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv)) + + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv)) + elif dtype in ['fp8', 'bf8']: + # rope/paged-kv is not supported + pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdAppendKVApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str in d.keys(): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + k = FmhaFwdAppendKVKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + if not cond: + continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_fwd_appendkv_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_splitkv.py new file mode 100644 index 00000000000..281357ef1ec --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_splitkv.py @@ -0,0 +1,885 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple, Union + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + +from codegen.ops.fmha_fwd import ( + FmhaFwdTileSize, + FmhaFwdApiTrait, + FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_DTYPE, + FMHA_FWD_API_PER_HDIM_CASE, +) + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + # 160: 160, + 256: 256 +} + +FMHA_FWD_SPLITKV_PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", +} + +FMHA_FWD_SPLITKV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_mask_{F_idx} = {F_mask}; + +namespace {{ +template +struct instance {{ +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + /*kHasBiasGrad=*/false, + {F_lse}, + {F_squant}, + {F_pagedkv}, + kHasUnevenSplits, + kMergeNumHeadGroupsSeqLenQ, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::OaccDataType, + fmha_shape, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + fmha_trait>; + +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; + +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, + false, false>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVKernel; + +static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_dvpad}>; + +#include + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wtautological-compare" + +namespace {{ +template +void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ + if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> + || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ + if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ + instance::run(s, a); + }} else {{ + instance::run(s, a); + }} + }} else {{ + instance::run(s, a); + }} +}} +}} // anonymous namespace + +#pragma clang diagnostic pop + +template<> +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + if constexpr({F_mode} == false) {{ // batch mode + // we don't check every seqlen_k values for kvcache + if (a.seqlen_k_ptr != nullptr) {{ + run_instance(s, a); + // make sure F_bn0 is divisible by F_bk1 + }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + run_instance(s, a); + }} else {{ + run_instance(s, a); + }} + }} else {{ + run_instance(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_get_name_() +{{ + using k_ = instance::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +namespace {{ +template +struct instance {{ +using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, + {F_dvpad}, + {F_lse}, + {F_squant}, + kLogMaxSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_hdim}, + {F_mode}, + {F_bn1}, + fmha_trait>; + +using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + fmha_pipeline_problem>; + +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + false, false>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVCombineKernel; + +static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1}, + {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + if (a.num_splits <= 8) {{ + instance<3>::run(s, a); + }} else if (a.num_splits <= 16) {{ + instance<4>::run(s, a); + }} else if (a.num_splits <= 32) {{ + instance<5>::run(s, a); + }} else if (a.num_splits <= 64) {{ + instance<6>::run(s, a); + }} else if (a.num_splits <= 128) {{ + instance<7>::run(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_combine_get_name_() +{{ + using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" +FMHA_FWD_SPLITKV_API=""" +#include + +template +float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + if(s.log_level_ > 0) + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() + << std::flush; + + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + ); +}} + +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + + // get combine kernel tile sizes + using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; + constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; + + // make sure we can reuse the padding flags in combine kernels + static_assert({F_bm0} % kM0 == 0); + static_assert({F_bn1} % 32 == 0); + + if (t.has_lse) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ + return -1; + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} + }} +""" + +@dataclass +class FmhaFwdSplitKVApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + mask : str + logits : str + bias : str # + lse : str # + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + pagedkv : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ + f'{self.dvpad}-{self.pagedkv}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: + if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdSplitKVPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_squant : str # + F_pagedkv : str # t/f + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + + if self.F_pagedkv == 't' : n += '_pagedkv' + else: n += '_npagedkv' + return n + +@dataclass +class FmhaFwdSplitKVCombinePipeline: + tag : str + + F_spad : str # true/false + F_dvpad : str # + F_lse : str # + F_squant : str # + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + return n + +class FmhaFwdSplitKVApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdSplitKVCombineTileSize: + F_bn1 : int # tile size along v head_dim + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bn1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdSplitKVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdSplitKVPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdSplitKVApiTrait: + return FmhaFwdSplitKVApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + logits=self.F_pipeline.F_logits, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + squant=self.F_pipeline.F_squant, + pagedkv=self.F_pipeline.F_pagedkv, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +@dataclass +class FmhaFwdSplitKVCombineKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdSplitKVCombineTileSize + F_pipeline : FmhaFwdSplitKVCombinePipeline + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bn1 = self.F_tile.F_bn1, + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_mode = MODE_MAP[self.F_mode]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } + else: + return None + +def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdSplitKVCombineTileSize(32, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + '96' : FmhaFwdSplitKVCombineTileSize(32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + # '160' : FmhaFwdSplitKVCombineTileSize(32, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, -1), + } + else: + return None + +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: + Pipeline = FmhaFwdSplitKVPipeline + Kernel = FmhaFwdSplitKVKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + elif dtype in ['fp8', 'bf8']: + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask)) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdSplitKVApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # Flash attention integration + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16, bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]: + Pipeline = FmhaFwdSplitKVCombinePipeline + Kernel = FmhaFwdSplitKVCombineKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(Pipeline('unused', spad, dvpad, lse, squant)) + elif dtype in ['fp8', 'bf8']: + # no need lse kernels + pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant)) + else: + assert False + return pipelines + + gen = list() + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # Aiter(mha_varlen_fwd) integration + if receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + + gen.append(k) + + return gen + +def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: + file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME + file_path.write_text(api_pool.api) + +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (2 - len(filter_list))) + + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_fwd_splitkv_api(api_pool, output_dir) + +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (2 - len(filter_list))) + + with file_path.open('a') as f: + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_pagedkv_prefill.py new file mode 100644 index 00000000000..3624b7b387e --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/ops/fmha_pagedkv_prefill.py @@ -0,0 +1,591 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_FWD_PAGEDKV_PIPELINE_MAP = { + "qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, //lse + {F_pagedkv}, //pagedkv + {F_squant}, + {F_occupancy}, + {F_skip}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdPagedKVKernel; + +using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + +#include + +template<> +float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + return fmha_fwd_pagedkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + pagedkv : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + skip : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_pagedkv : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_skip : str # true/false + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + + if self.F_pagedkv == 't' : n += '_pagedkv' + else: n += '_npagedkv' + + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + pagedkv=self.F_pipeline.F_pagedkv, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + # if pipeline.F_pagedkv == 'f': + # continue + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' : + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/utils.py b/example/ck_tile/01_unified_attention/codegen/utils.py new file mode 100644 index 00000000000..e3bbb18c427 --- /dev/null +++ b/example/ck_tile/01_unified_attention/codegen/utils.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import os.path as path + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. + + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) diff --git a/example/ck_tile/01_unified_attention/example_fmha_fwd_v3.cpp b/example/ck_tile/01_unified_attention/example_fmha_fwd_v3.cpp new file mode 100644 index 00000000000..7ddb65a2dbc --- /dev/null +++ b/example/ck_tile/01_unified_attention/example_fmha_fwd_v3.cpp @@ -0,0 +1,616 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fmha_fwd.hpp" +#include "fmha_fwd_v3.hpp" +#include "mask.hpp" + +auto parse_cmd_args(int argc, char* argv[]) -> std::pair +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("prec", "fp16", "data type. fp16/bf16") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", "3328", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q & k") + .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "0", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "0", "permute output") + .insert("causal", "0", "0: no mask, 1: causal mask") + .insert("v", "1", "0:no verify, 1:verify") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "30", "number of iterations to benchmark the kernel") + // Optional effective seqlen override (exclude PAD) for batch mode + .insert("q_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override.") + .insert("kv_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override."); + + bool result = arg_parser.parse(argc, argv); + return std::make_pair(result, arg_parser); +} + +enum class TensorLayout +{ + bhsd, + bshd, +}; + +std::ostream& operator<<(std::ostream& stream, TensorLayout layout) +{ + switch(layout) + { + case TensorLayout::bhsd: return stream << "bhsd"; + case TensorLayout::bshd: return stream << "bshd"; + default: return stream << "unknown"; + } +} + +struct Problem +{ + explicit Problem(const ck_tile::ArgParser& args) + { + data_type = args.get_str("prec") == "fp16" + ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 + : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; + batch = args.get_int("b"); + seqlen_q = args.get_int("s"); + seqlen_k = args.get_int("s_k"); + if(seqlen_k < 0) + { + seqlen_k = seqlen_q; + } + nhead_q = args.get_int("h"); + nhead_kv = args.get_int("h_k"); + if(nhead_kv < 0) + { + nhead_kv = nhead_q; + } + hdim = args.get_int("d"); + softmax_scale = args.get_float("scale_s"); + if(softmax_scale == .0f) + softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); + + const auto is_causal = args.get_bool("causal"); + if(is_causal) + { + mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); + } + else + { + mask = mask_info::decode("0", seqlen_q, seqlen_k); + } + + input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; + output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; + q_eff_lens = args.get_int_vec("q_eff_lens"); + kv_eff_lens = args.get_int_vec("kv_eff_lens"); + } + + std::vector get_query_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_q, seqlen_q, hdim}; + } + else + { + return {batch, seqlen_q, nhead_q, hdim}; + } + } + + std::vector get_key_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_kv, seqlen_k, hdim}; + } + else + { + return {batch, seqlen_k, nhead_kv, hdim}; + } + } + + std::vector get_value_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_kv, seqlen_k, hdim}; + } + else + { + return {batch, seqlen_k, nhead_kv, hdim}; + } + } + + std::vector get_output_shape() const + { + if(output_layout == TensorLayout::bhsd) + { + return {batch, nhead_q, seqlen_q, hdim}; + } + else + { + return {batch, seqlen_q, nhead_q, hdim}; + } + } + + ck_tile::fmha_fwd_v3_args::data_type_enum data_type; + ck_tile::index_t batch; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_kv; + ck_tile::index_t hdim; + float softmax_scale; + mask_info mask; + TensorLayout input_layout; + TensorLayout output_layout; + std::vector q_eff_lens; + std::vector kv_eff_lens; +}; + +struct RunConfig +{ + explicit RunConfig(const ck_tile::ArgParser& args) + { + seed = args.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + kernel_warmup = args.get_int("warmup"); + kernel_repeat = args.get_int("repeat"); + verify = args.get_bool("v"); + } + + std::optional seed; + int kernel_warmup; + int kernel_repeat; + bool verify; +}; + +template +auto generate_qkv(const Problem& problem, + [[maybe_unused]] std::optional seed = std::nullopt) + -> std::tuple, + ck_tile::HostTensor, + ck_tile::HostTensor> +{ + ck_tile::HostTensor q(problem.get_query_shape()); + ck_tile::HostTensor k(problem.get_key_shape()); + ck_tile::HostTensor v(problem.get_value_shape()); + + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); + + return std::make_tuple(q, k, v); +} + +namespace host { +template +CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, + const ck_tile::HostTensor& k_bshd, + const ck_tile::HostTensor& v_bshd, + const mask_info& mask, + ck_tile::HostTensor& o_bshd, + const QElementOp& q_element_op = {}, + const KElementOp& k_element_op = {}, + const VElementOp& v_element_op = {}, + const SAccElementOp& s_acc_element_op = {}) +{ + const int batch_size = q_bshd.mDesc.get_lengths()[0]; + const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; + const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; + const int nhead_q = q_bshd.mDesc.get_lengths()[2]; + const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; + const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; + const int hdim_v = v_bshd.mDesc.get_lengths()[3]; + + const int nr = nhead_q / nhead_kv; + + ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); + ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); + ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); + ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); + ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); + + // do computation for each batch + for(int b = 0; b < batch_size; ++b) + { + // copy per-batch data from input tensors + // clang-format off + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); + k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); + v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + // clang-format on + ck_tile::reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, seqlen_q, seqlen_kv)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + } + + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, ck_tile::identity{}); + + ck_tile::reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); + + // copy resulting per-batch data to the output tensor + o_host_ref.ForEach( + [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); + } +} +} // namespace host + +template +bool run_impl(const Problem& problem, const RunConfig& run_config) +{ + auto [q, k, v] = generate_qkv(problem, run_config.seed); + + ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); + /// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v + ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q.data()); + k_buf.ToDevice(k.data()); + v_buf.ToDevice(v.data()); + // Ensure output buffer is zero-initialized so padded regions compare cleanly + o_buf.SetZero(); + + ck_tile::fmha_fwd_v3_args args{}; + + args.data_type = problem.data_type; + args.batch = problem.batch; + args.seqlen_q = problem.seqlen_q; + args.seqlen_k = problem.seqlen_k; + args.nhead_q = problem.nhead_q; + args.nhead_kv = problem.nhead_kv; + args.hdim_qk = problem.hdim; + args.hdim_v = problem.hdim; + args.softmax_scale = problem.softmax_scale; + + args.window_size_left = problem.mask.left; + args.window_size_right = problem.mask.right; + args.mask_type = static_cast(problem.mask.type); + + // bshd: (batch, seqlen_q, nhead_q, hdim) + // bhsd: (batch, nhead_q, seqlen_q, hdim) + args.q_ptr = q_buf.GetDeviceBuffer(); + args.stride_q = + problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; + args.nhead_stride_q = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; + args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; + + // bshd: (batch, seqlen_k, nhead_kv, hdim) + // bhsd: (batch, nhead_kv, seqlen_k, hdim) + args.k_ptr = k_buf.GetDeviceBuffer(); + args.stride_k = + problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; + args.nhead_stride_k = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; + args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; + + // bshd: (batch, seqlen_k, nhead_kv, hdim) + // bhsd: (batch, nhead_kv, seqlen_k, hdim) + args.v_ptr = v_buf.GetDeviceBuffer(); + args.stride_v = + problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; + args.nhead_stride_v = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; + args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; + + // bshd: (batch, seqlen_q, nhead_q, hdim) + // bhsd: (batch, nhead_q, seqlen_q, hdim) + args.o_ptr = o_buf.GetDeviceBuffer(); + args.stride_o = + problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; + args.nhead_stride_o = problem.output_layout == TensorLayout::bshd + ? problem.hdim + : problem.seqlen_q * problem.hdim; + args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + + // Optional cumulative seqlen overrides (exclude PAD) + const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; + const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; + + auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { + std::vector eff; + if(!opt_vec.empty() && opt_vec[0] != -1) + { + eff.assign(opt_vec.begin(), opt_vec.end()); + if(eff.size() < static_cast(problem.batch)) + { + eff.resize(problem.batch, eff.back()); + } + } + else + { + eff.assign(problem.batch, fallback); + } + return eff; + }; + + const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); + const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); + + // Calculate cumulative sums for kernel arguments if varlen is used + std::vector cuq_cum, cukv_cum; + auto calculate_cumulative = [&](const std::vector& per_batch_vec, + std::vector& cum_vec) { + cum_vec.resize(per_batch_vec.size() + 1); + cum_vec[0] = 0; + for(std::size_t i = 0; i < per_batch_vec.size(); ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + }; + + if(has_varlen_q) + { + calculate_cumulative(eff_q_vec, cuq_cum); + } + if(has_varlen_k) + { + calculate_cumulative(eff_kv_vec, cukv_cum); + } + + ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); + ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); + cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); + cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); + args.cu_seqlen_q_ptr = + !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) + : nullptr; + args.cu_seqlen_kv_ptr = + !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) + : nullptr; + + ck_tile::stream_config stream_config{nullptr, + true, + /*log_level=*/0, + run_config.kernel_warmup, + run_config.kernel_repeat}; + + auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); + if(!result) + { + std::cerr << "faild to run fmha_fwd_v3()" << std::endl; + return false; + } + + std::size_t flop = [&] { + if(problem.mask.type == mask_enum::no_mask) + { + return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * + problem.hdim; + } + else + { + /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. + return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * + problem.hdim; + } + }(); + float tflops = static_cast(flop) / 1.e9 / time; + + std::cout << "[" << problem.data_type << "|"; + if(problem.input_layout == problem.output_layout) + { + std::cout << problem.input_layout; + } + else + { + std::cout << problem.input_layout << "-" << problem.output_layout; + } + std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv + << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim + << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed + << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops + << " TFlops" << std::endl; + + if(!run_config.verify) + { + return true; + } + + // transpose tensor descriptors from bhsd to bshd if necessary + if(problem.input_layout != TensorLayout::bshd) + { + q = q.transpose({0, 2, 1, 3}); + k = k.transpose({0, 2, 1, 3}); + v = v.transpose({0, 2, 1, 3}); + } + + ck_tile::HostTensor o_ref(problem.get_output_shape()); + if(problem.output_layout != TensorLayout::bshd) + { + o_ref = o_ref.transpose({0, 2, 1, 3}); + } + + // If variable lengths are provided, compute per-batch references + // with the effective lengths; else compute a single full reference. + if(has_varlen_q || has_varlen_k) + { + // Variable-length aware verification: zero-fill padded region and only compute valid part. + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } + } + } + else + { + // No varlen override: compute the full reference once + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + } + + ck_tile::HostTensor o(problem.get_output_shape()); + o_buf.FromDevice(o.data()); + + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); +} + +int main(int argc, char* argv[]) +{ + auto [parse_result, args] = parse_cmd_args(argc, argv); + if(!parse_result) + { + std::cerr << "failed to parse command line arguments" << std::endl; + } + + Problem problem(args); + RunConfig run_config(args); + + const auto run = [&] { + if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) + { + return run_impl(problem, run_config); + } + else + { + return run_impl(problem, run_config); + } + }; + + return !run(); +} diff --git a/example/ck_tile/01_unified_attention/generate.py b/example/ck_tile/01_unified_attention/generate.py new file mode 100644 index 00000000000..03173305118 --- /dev/null +++ b/example/ck_tile/01_unified_attention/generate.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import pkgutil +import sys +from typing import List, Optional + +import codegen.ops +from codegen.cmake_config import * + + +class HandlerId(IntEnum): + LIST_BLOBS = 0 + WRITE_BLOBS = 1 + +# inspect all modules under 'codegen.ops' and register API handlers +ops = [] +for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): + full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) +unwanted_prefix = 'fmha_' +handlers = dict( + [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, + (op.list_blobs, op.write_blobs)) for op in ops] +) +assert 0 < len(handlers) + +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + + for api, kernel_filter in zip(api_list, filters_list): + handler = handlers[api][HandlerId.WRITE_BLOBS] + handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) + +# list all the files that will be generated +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: + assert output_file is not None + file_path = Path(output_file) + + # create an empty file / drop its contents if it exists + open(file_path, "w").close() + + for api, kernel_filter in zip(api_list, filters_list): + handler = handlers[api][HandlerId.LIST_BLOBS] + handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK fmha kernel", + ) + parser.add_argument( + "-d", + "--direction", # we keep 'direction' option for backward compatibility + "-a", + "--api", + default='fwd', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory" + ) + parser.add_argument( + "-l", + "--list_blobs", + required=False, + help="list all the kernels to a file" + ) + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + default='', + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-m", + "--mask", + default="simplified", + required=False, + help="mask implementation, simplified/generic" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ + " 1: generate more instance to cover all hdim\n" + \ + " 2: Only generate instance for Flash attention integration\n" + \ + " 4: Only generate instance for PyTorch integration\n" + \ + " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" + ) + + parser.add_argument( + "--optdim", + default='-1', + required=False, + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ + "eg. --optdim=32,64,128,256" + ) + + args = parser.parse_args() + api_list = args.direction.split(',') + filter_list = args.filter.split(',') + filter_list.extend([''] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(',')] + + if args.list_blobs is not None: + list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) + else: + write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_mask.cpp new file mode 100644 index 00000000000..d99838d17c0 --- /dev/null +++ b/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_mask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_nmask.cpp new file mode 100644 index 00000000000..a6806b95d7a --- /dev/null +++ b/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_nmask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_mask.cpp new file mode 100644 index 00000000000..a710efd2cbd --- /dev/null +++ b/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_mask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_nmask.cpp new file mode 100644 index 00000000000..d8fcd7d97db --- /dev/null +++ b/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_nmask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + unified_attention_kernel_traits; + +INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/mask.hpp b/example/ck_tile/01_unified_attention/mask.hpp new file mode 100644 index 00000000000..2dfe0e7c529 --- /dev/null +++ b/example/ck_tile/01_unified_attention/mask.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + tmp.seqlen_q = seqlen_q; + tmp.seqlen_k = seqlen_k; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = std::stoi(v); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else if(t == "t" || t == "b" || t == "g") + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + throw std::invalid_argument("invalid mask value: " + str); + } + ck_tile::index_t v0 = std::stoi(v.substr(0, found_1)); + ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1)); + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.type = mask_enum::window_generic; + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? + tmp.right = v1; + } + } + else + { + throw std::invalid_argument("invalid mask value: " + str); + } + } + else if(str == "0") + { + tmp.type = mask_enum::no_mask; + } + else if(str == "1" || str == "t") + { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + } + else if(str == "2" || str == "b") + { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + } + else + { + throw std::invalid_argument("invalid mask value: " + str); + } + return tmp; + } + + ck_tile::index_t get_unmaskarea() const + { + if(type == mask_enum::no_mask) + return seqlen_q * seqlen_k; + ck_tile::index_t area = 0; + for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) + { + ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast(0)); + ck_tile::index_t x_end = std::min(i_y + x, seqlen_k); + if(x_end > x_start) + { + area += (x_end - x_start); + } + } + return area; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/01_unified_attention/misc/gamc.png b/example/ck_tile/01_unified_attention/misc/gamc.png new file mode 100644 index 0000000000000000000000000000000000000000..2c96951f30f99ff0345706c4f3031ab2a623ce1a GIT binary patch literal 30073 zcmeFZ2UOI_w=QY|6&g@UO;$vb0u2ZV2q;L--6TPB5+o=R8w3?;6c8jQ0hQQ9$r4%y zn;;-4&}3;uk~B#|gGhSSfHTfH^Z(y-*E;XMyVg6i)~s1W*RRsvwX5p;zFkjrw3MkR zm?;h%I6$SUat(gqz~Srz2M!4xI|Tl6lpw`@;DE;!)oWMvya#7fv~P+I-P?{}aWh#u zbqu4KOqcZeq0M8_Wbp~S5~_y3tU|(t!i=nzuWR1k{-t70JnBMOWE{i58yi}M`sBXB zBNPw%9L(U=oni%jHgR9nPd}647PuY#tFjrs zD)@qX2C~m_7US8WxArRkrU)Tt?4_ezPI7vFO#RDh1c6|{7Ubwg>2Dyx*y37#VsesK z$l+;jH-d25?Bk2{)EENZW^QQuikQ$?BOT&p-8LFIfbhZ_2sxy0Xx=3D6F%!TR6=c(YwEcNzo zes6J;@!t-;_lErmx-x-6fta(dB*dl4lY3-H&}(AIY2?!1hW{-KuXru_!%;-G<6r{K zhGx$n9MtYO{XP?mT&XF-7lXMvy*{t`sY;1thPQA6eruBu>lm)*%5bk*y&$rVJ+q?_ z*Qj=yf@ew53?l1!`HwoPrLDv#kgYf8gS3bBwLzwQTaxQJiVh}FNG>(|BtUvwv zXV6q=qu||*>#rwh#0z`B6qRPOKR!f^%z4~*;Svpb znhUUxF^-Wq>oo6Msm0HWYmC#_A)ys9V>zRr+S7T2rqGp4E7)=U6N_6EeR^GzD-@~f z4Hd3&PuXt~C-KMfHj>41Wn1As@Q*eFZxfp%sDAgin^|UZ@JB^1+RB89QmLwAhYT)V zuT}`8mH0!T@VbzoOoX`Ltfkn`{6x*-TG_XCo@xp!bCYWC+Ia3@ClpV;j+bUhkI@(S zBJ|0a`tDELA8FUJWWsRkcepxGzGo9gKN*gx3yu4Xl|<&k`Hz1*|M-J&g0)Aw&@F8T zBA2>rL&c!t&sx0ZJ?MO*Qruj!$D6(vJ2rcZgRB+V+KcXE%USMfQkTmk#@=Owu)eKU zZ9gCQl+yQ@a+=t{`5Vb;4x6Lc_n#cp2*Wp2D{w65U9Vqh!VMYPdlBI1mRLtYyezpj zCgrZ{1D0i~7bVG838W%T@Q^mzg-L`Xf?!(udgtmYjz#SOC6|4BgI-L*Y`!%H zEXo4ore4L{H)@TteB_4KpRU@K|13S^_mdNE%P}*YfDC-f-ldNhU}JuYGswGuc^S8r zCPbt@<;&f+TCfrT*=zIx=Yv-PG{KGI#?5y|sDrPv8;6Ola^T*=*ZK%6`fiiSpKXl# z6;9myId7Vi^@NW8x|^>4o7bnw%W;7L(xu;q&{_k<`ax}!MFtqX=J?1L&tK;%2BR2` z{SgA|Tbr(0>7D3yGb2C*Gtpup=8%=FAhIMVm^oGGI ztc^Exr19garuB)Xk_|!zncQUo^3`6US*w$&wHc>gUW-nwoBSTr{{4!@cgk#^&taGz ze(86n2s2;*vaZ2~^ql{*Wj%M4f*l(Q{rZScMbXurmfiSFMiE1#8`b6_kG1%uOTeWZxP7m-cIPmR zAjhmMqxLc&&|sA#E*UV5Eh z^zcs#-jk+MtT2M;)oI(sO!3} zy{A@z_+ZnsNrlc#8;LNv-XVAdzO=$*F4*u<0yErw4K^TBB-|eSO1;GeDu-F+kzl0S zc@qOgPN{onb7Pig0*AgM$J&~#gwU1lt!vx!BLO?zJ3c?k=k)wOmJK<+VO{VFo|gZv zJgb7qLgb2Tz|J>m!nQPNeQj5b+WC5z?J7R6cMs^VWuS566YGl;gLls*${5^LoA8Ck ziY=?n6(=G1#Xnq_5oe4jf){KDO3++pFvUFY_CL0ja~e96@?5rgO|AHZ20v}ZgpV@E9_gr{dsIUe2Bgr#(|6q%bK(BA;gHIVVgs z&b2pl?T#=L^<)mtz@yPV5SpDAYufJShHp_1$gkY+!{1<=q0+sTf(%$1Z7!1i@olbh zggW3ysrja3`;dg(PMw4+zSet@)A>OUO~B7ZA?p$U<%j@#&%uoDIK`+Gr?*AewVU)(QD>AGXXoR9NX)Z=0D_rU_aIp(PFOVn?H;0{TK~i zs9n|?Je0@bKeN$Rnpotw*!V%v+tI1{PNoxq5oHNBec~|?WZCA&4^?^*Vq;>JFqk`DdJc8fmL)a>zA&=66|mXZ=U>ym zs*aBs$|#98)nGzJ)^aq%OyD2HAQcU;a(!W5q80R(9={r9WOKfIZF}j(f1$Jrh72o9 z?ED;daaMoAB^g5{p z>itiXv6w~g{LCbpk@Ze|WnJpv<>otf0UqalmM5R~F0D3(I_l-#QGwV(l^FX)(kM_T zb2ta$pjj3C>rPKBbq16gzZB#B&?8yi^(q?ZJ*R16riUOLEpy)M=|Y92A^>Nkw(Es`D|NN8;utYQ)=dOQG;h4L(F@tWu z%WA(`IjiPJpZrx(jK2qj4awh)Lj`QAdIVLhl<8@Wspc;-uj@6&zs0pZDgU+RTI|&! zfg!y6)^)i*?tS$!UpcIWKg#)urDL8vq1+D9dtK2Ing>=O%V|lo6IN;h8YTNt<{7eJ zAjZRB&utdSRl*q`CBP}8^Y>jY!|ay0ZHhh}s@OOPS936GjOdeKhWEY5utQ4;TnF zI0Hq0LL)5arYCCWz zw&|&9V@45~4bd6H6#-|tnkmmKgiV84CM_zZ3^Oo&3}jfNEO=wlz$PNI?Dt|KpZB`k z_!yT5gYYNOL9Cf*bBz zavlDyfN0o@VzXo=T;iw>f1@n6-)7IfKqoj6icXd}{o>gAEHiG&a+hTU)mdQ=jPlV*0GU#pQ&emXES z{h}L@Ey1WvURlRQ4`0}1L!HccXj#EaYMg0)?4#D_`wHjizO z_b6gNuf}J{6R7YOP*r{?cK~TP^}s#R772rSpkumch~%!YFfMvXbRDT*7Erx1J($KE zSW=le&yax|SccxAmg6}V3C zidht>tBUGmSlkWUfDS z>i5Sr9ir7__Z(*YzGk~J8TQ4g^=G*ASA%eQWG%^E+`EU8k`8`J3&b+$K+ZK;E~1*w z8KUFWrK~d5-BU?^sTy@Z&QA{*Ad*HEa|+C>@v~Y!XK)??$C8=+sKer79AJ!U;|lPi zBcMVCE?2c=P^#CdRVwel+Ux~xaK8js(_te?d&A+-^ro++8#Rlw3ILy5M8968)J2z@{hHIKA$EimYAH)qB4jy8OMR zUHALsJft~4aR0mLVp3=W2Hkg&%*Vy}rmx2X3Oq+@q`vsv^mlw@Esh-owQCqeiu8g; zR1tqR@+}+;GhOw^&}64uVVgTnQjR@05wYrfNlCqN?~{Q>BcZ*jU%bkEXh|r9Hxvzg zj?UgH^mVqd3~yPw<}jH)DSX$e7mLRwAAwju4v8@@wcK4fdrL%%KPba&Ve?z3?{Jcd zwSWDUa$^Bv&LwT*(4| zR(P$XHCe*7EUrj|sCHWg5+#wTxp(yXEhxpx>!~r#i189x3ar|VipI5+6q7>w-|0l6 z|B72xqO1_dx=zrVo@35mf8LLUCSLS#A%3ylpedsNXGD`H6x2|l8Dp7uL*H&$#Vk_I zXNbL~LAX#gs=vza#p7qlnL@A|WTHqhcCYMSFH|rdhG⪼0mN|Vz2CejK)`kn5K(< z`5$`>cO3=+EeEIbS3KYC8U?y?{{OyVRAic+rnT}CzQPTId^&_WX#5B43gseBnmxkH zG|G+DzVF<9Jr!U;sFKKU1%$DWp)&j@*vhy1I%i10xU{SrE!)OVo;CO z@ZdR*o@(`+2C@p?p0GbFOE-!}dLT5!ZH#VRbQ!;ZeWq3M_*LagEA@EmigGe1=0q)Z zj7h9xQ`VzBUZzC%f-0t4cWTr7N55lvN@;hY+-os5DrBFR^~41^H-#(N+10~a_;uoJ za_LZVXt}k;;)gGV%iAuq@S@KiDEE}W`Suo@a>z>=1Fw)e+YOqd+NsKgF0PZ**UL3O zn6^L38ZfIi_pw84TZyiHzL6jC@}MEH!#>p9G(~nqZo%uUb^C*php);JwY1WFM_0MF z%-`b$_aI6&2Re1Q>uqntt>8Rb*J~dr#d5c94>@Zg+b4P~7vRN?DY_r+)06hxf>Pjj)A;#-DiuEVek6nbYfpKyS>njiygMqZ?`vzhzd6 z{xq|S8to%n(~zl}UaR%-60aoOSjZjnn->#mE`z3}qDCIJb(n+^k+&SmhwQVndD_#J&Z@@a zZc&^YGjg?HX)nefDUC$#kAMwvK^?Qz((Y)*-PWE|;tQ zK_OIxYt0D5H%e9h8Rby%#M$6pDSZMmTBcT(=*ZHQ@i%81hTZnQ?58)^__FSQ?j=|#qG#G*L)5#=@MOgM)3--Y z)R^wER57e{;7r#&N_gjkYF@hP1OOVixEr&pREjwB_MCSzOX@%c{ofYujCI^u14fsN zxy&aShOEW1V!JSucIV$4w_l1+|IRkywc8FI&Gpo47FSGO=y_-?rFjk4N$K~s@n|IU zG4RG_;=df9*809}F~?Lvs_H##p7%rdX1n*B>3T5|(%l6~|4CEHqbW9>!a>8RDt@NE z$tN$;n}}b^Fr(#7zdB`>g7@ZmKmY0TmcN9p_c`y=4CqRH%q(ir#d30_*lA0!F@uCN z--IeHJdn1#$Ppr_h=r*?Vt*-!24i_)BcXAz7KJZp=v%jKwYYpq{`s`b!qKU*#j`Bv zzGG^)S4D@uI2PNUo%-Y`Ir-|Yo=64>{gy&4#&LLbC`7O`^f2(!>y)rGam+ZQo7V$h zo1e6DROuYX+O*X+ky<-fW4XC9nD+U`#mzCe)wdL(9+||Ijv^Wt?95XkmQYLvzKd^U zqb6WyOM7Q(w66lODAVndH|ka^x0>rRTV`wE4qrt}2RzM`kX`gtaF5WRgH;dX}bu)v3)=gPr>1qwJxDW*rW4kNovc zgpuOlm~W=hF>ry>TA%b9E=`oXi{I`p+Mb+}+nD~$zsOV*Ml@DK55krkl>-Ml-i`iD zzZ1j8)P&7)5Z>(ofvR-XP=$O2G!x6L)SgYO#xEXz-JIaQT1fM*iaViWI=uIS{iP2@ zR}=%Rq(g42V(2BCnm0DmHkRdn_%V!drSJ?(7Br789aWycC-pAoODfy`K-DBoW>`sB zX=?{0T?L~A&lUf$^mTR8dT)GiY=b9XC~el{Y-4v%YoVQ_pskGyV=)~iRl4XW6HO9m{0Y__ zQ;aua^=y8!W!})PeCr`@rkyeTlA!E!#0+kuXopxAutBr486@O)W^>dY{}Y!~DdfM^ zC%gF_>xK8>AtC8?#I1>e4~B1lSlPu3)ormDJ3rBQGU$y5)WTD%bj!~}4rUzzi4C=Q zD?V}NCYjIA&8cOYohrgU%yGtn-%C2a734=dzv{o4@SdG}GjKqqY_ZupaQNAp<)s<& zHN%yJ^L^=#DagE+WjjC9?6NgQi~PoD0&cBV*8XRFnHD0dueNM|MD_gjluBj5ZFKF_ zE1M%{0^TnAPyc7UJByhV84Ov{&M+Ilz3Uyo?2F_g@iDVp)8^beMprE*PnUvi5kTEA zNX!x`6p5i)9vEliF>1DS2+a66xCSNDTmi`mwH90eOx$%E>!$#( zC@0|+yb$YhlHr~q1wIaNBh7GS9L3t%>~*h64B6@e6E@qf}2U>&rIQ(tJrIm*(ywYpG7?PYgF8^A@es@ zu~R$#0XugO?KN9)MhPH?zmV=<^(ke(!kyo%$%xgeSE0lbL``%-$W?S@*krm@@E&~8 zIvf!V5MGu2h5#~Q_2A$CQY=6`lz5^h^q|Oty-kXA$$y-UU-Yh>iHaG!R3msWL;j(H z)gQ2^-wQA)eKIQ2JWKd+n>QA4;!Y)iM+NTnCpzPlJZgauBK75e6Kw_{JfKqq{;^UO zj2j6H28?@{CMn82hLh0nee}ez5JWV+kb9`&mN;Mo-M^5U0aXGUmDJkCaB_OnEYbbb zG`c%E^laeJE;PGMjXefXKa<~~*_vG#kvciP1Ir^(sIHWP8M_m(M0jQCvY`%e&i`JJ z8tw|PGW-`RR3yOBml%J8d@&)bMFR|iiyxNz-wCUNLhI6yvnpg*C*9u=@Ol10FF_V6 z>IfkYdm-bw3Q<J=(1g)$1c>$keN&R5)g*@`<~e@1xUQTT6C^3&6yED7kS z8W_a-QLFeel2T#*L5|nRJwrpqNv01oeMRMQ5)S#5M=YQuN<32>0M2J$M?1q?4nnNK ziDWlI!HWbtnX#m^3Saqpjs^5Vic_>2wFQ(!j*}R=XGo7Hy-+0*ZH^{E;m2lelJ(pU zV4>jDvKzy2dFo|NFcQJ1<}3>|0~S0(8BvF!O{ZgDAxOY|K}S>oL)FU?De0`Y07E`z z5Sf4k%n4M<7!-VH>14=vk%*-Efp}6`*QRqlg$>aSfL5El$Lb`=8gjM(^tT=+ro#s# zR!2xORs5te)L#2&T0nb_710JtX5<~)EeRNvJ;EU9H*h<9IP~mCmQ?dWnzZNm?8{$W zt!+I#!~Pdg_NWxL=EC<&%Wdttguns`0|WbJqhxR#Y~hz?nj!Io=zYeeT|WdmqmS21 zSj9S5qH!604Z?uHvknZ2A%Wd$Nx#5uGexg_FDYitcO&8ve%Q+>Ij2!5-1Pv&`gN;d z6KO60MYB%{#lBnUYwI}(!l>|ATEN*yTS@IlQCG>ZDl=f#VPO;dyN(K`TA*Za(p#N= zO*)9WhNA;cB%>KcB9{;Rzhl&0y?`+}YqPGt?VFOJnMy^&iS}d($w4dSh=~Izw9>gW z;3{O!S2G{JgK5aTSR zHy%hNA>R8EqXxu+_hr0?YU`cnrtt^hXF${fA?Wu4kh50#jQi?O01n5%`7+G*wXfi> zs|Xd$zH}6*hi9{&?E{J=DRIB{8KV3?O#b3Zahg0KXkYbz<$;wP$OPC|96vbJ=nu$? zKide|qJsOH6x3{q{7h8$`Uf~8n=EB*r!9$tv07>o9QH0rGyF}s`8@FwL@djM$ zJnL<`_6tUea$LI|4WaZ(gE{>Pm8X;_oE#y6duAiT&OF`$x0$8*6T1OU#$+e;Z6$Tm zwFx_#Zd?<6l7uhLS=O%P{sY9(qtOButJOb;ns`KA48{9q9a_5R_%ycdZ2R=SIn$K)wn9!6&G`Pc^YW zr0O)7!j39G7R5|$eet{OZ*vw5!HVPDCDYWoGdL^56HM;9omZvSLb?`poFSs9Ul?97 zJ`U;UE_bV(!|Z<-;R7CSqx*KvYiD~SZTlm($T|D#g3X|_*bTGOZp|JFH!sR>L?}zt zuSjUl#ts0CSnpQCfZLQdnDyu^cy#P@8OT(#WRSDo49p-J2?74{kHhsgm$KjxcapDiN6*c?O_LTYHpl zhS^yFZrb_W@m;=g`@bYmRRsaUrcqmV^GDB74$Ri-&?rs7YN&aE&&K4`&e~*qGNX`L zhQRBequ?%NJ3HjMy}0n9qts)teKnr$xXL!l6IwgAR`M^w1<+bZHf~u5UE~Ug^hB2P;%k;&(c&9 z;S#C1t4-}FwU>@}-5hmo-;OghqJBRMq{V;-fE>=e+#3FOU@z#9BE!GYrk=D(==}2C z%-Iih^X7WT=$6CmbR;QA+vwWa_w0zQH&amjUCbKkv=^|V=z%oVc;=|S)tKNpatPus z(4_9leMCP1Z3fsD4t%TM-xzMkKm1eh17xd!#%@2DXdR+s1K*^b;n?}nY|JlOyD?u= z&0jr6v8y-R#;oDn`WTA&t(2{7n7PeONLMQK3jq;*62zzXi6InReRR0s+mj1BV?u9N zhaB3}@r~f_%Mc8`OLbP#%mO;haNT z=;Fh2xX~q~ss>8B>ZcEJ-U?UUJ14)j0qvsDW|{ps79Y4N`so8JiN?>QrmsFGHQCIy zcVIU!0b3YWujo}GQaP{#BZaB@&i?f>^9oQIwPia$Vy~MOl-O|u!PKt<=xGQ!o&6bb zt&Nw^YyYm3Kuv*cF{r{;IVRq@_lO*sJ-Wb&~ zOLD93-3_9^#6q)>^jl2S{`1b2J-4A9wUIB9pCehiJHOa2e|=bN#+&@_zKp%RjKpYc zK-tQkc6`%R#gK`UhYZu*^%5p2`IIUZds(ANvfh|MzZO97Ll79Xx5aqrY~m$9Cm|Fv zyLf<_4=r(d$g8dCN^!E)yr2>ZocUJqKFKh-eX`n!&%bgQ5^ITwd2}5nEy6!M=4pxpMr=_a94lR$ukFNMlDY(`v^$I_dGD zw{#${5^6#tNw5z%fKMiHk;JgLNp&D)wMro9^LvZ4EF)juM0^+~x!fa?tKHRzA_}*Eb!fpI=V0gU7IgIiM2+m<&Vz$BCcR^H@l(pz7shkjbM0B# z&>UHSYk<`i?%Qm%3xz;_c|!sJPohFGPxLM6vK8^W1z9)1H%B&&tRK8XrGp z^jt}J2lFJ2bOfmJ1zo~pwtZ&9ym_~EW|p28V;0X~mC z93@kH+yU}(lu}`_4*DjKBvoKI6@uF){^iMlC&Sm8TjVSHu4e3GzgC;ZYn6E!gC z`*B-p^XV5aKY7>SBxa|E`1$HuF}?sysp8a1!)l?sn+;`$cUZKP zSBe z-d^4!PQX5MS4wo?q5}k?P36W%%)}R4t5+rvItjgNR`&>| zvWEuPD%^$M^6+PA=RWQ>jB`IKCp3H#-poWpOKz;+giFT%p(6m|+)GpgF7MP3`}H00 zHyl3JY)PWDrlXgWsxON}MnRQzz)1xoJ6U?1yJ*$-cu2CpQALLG(N|N}NfMoxZhSIw^$S;UBWM<$ zsjX-6r`lXWpeo0i6lT|{(Wo$f<(gB8v#>oVLI&U<8p>(ASlpccl+p&2Azzg3+8dZTGX)k*p7*bQ`WX#(y=fex)-)n@kog%ybus`zw$< zLtd@7U_2{s*57GOf$g{5-6{%bgoE~31V%c@zyA0iL;Mk*MsW7EF*ouyam$&tiHA#I zV}@JwSD~oRb0K13Z2V9DfHhhL^H*DA4mjPn<>{5w!{1vQo~4|Up18O5(?iU;@{3{Jtp!Z5Rsj6(qSMk#{fFb)X{AAbbcUm(=2F z+!Xr1!c5Bf?b|b7HsE2e#6b`>6JS6UPbOR=w0CYwu_$GM7+T@f+67-gO{F*ozE8Pz zx7d*VVDY&uk~FQDTOwljbQ@q4L~v+D~)HP2HB(mQjJwV`T{i$EXZEYkN_t zx!V$1^SlyN^2zR+#Wd^VUFtww4lWbXS9uIC;}Qh3YH< z>>wuMwpRg!%e+#!(4;15Ur$^2^}6HRM`OWp-C7;xhoml zqBjd-x2{mi%tKCzKF!PPp4tFcM|H2)ME`-XKo_0+O6B#mpN$Jvtp?gLFFDhKP=2O_ z-+_mSFtm-p`;UspNv6n$R;@Rxcla#z+Tp0!bho!nGbYHqb}51OIs)s3iiJr zV#Bw?)G!weqzqlR1ce?SjSz7mUD37sn(MpW6&~feELZi_hEYHMc_AX>##MTMPnXTG3FF+_~#{#&WBqP~W0^LZff3sf7M&4V;~v^#fai zd$Q2SsfgR*<+gQagj1ygreR30Xfxow6g7K zMZ4f+HpkAA$w=Cxx}SN`LID$@i#3I>|K$i%M!$qFvmHGt_B`{I;S%Dj*|&BzD>J34 zbqc*p>>HtW2vwsID*;;9Pezr4E2=dun3D!Knx@}}4EKuBO2fqYIIHnJ2a?2HEuEdw z>bZ@HiY8Q8Rc}FpzJ#>to9l@SuO}ulNPO7Gf>s%x0{e!umY!FZm)>gY{0eg zL()gAJrLk{16G6v!Q&{o(C`_RW~QTecw;i-&OQq5Po&dByfW>@CF|OH+*7@l;%LZo zH-GM_BJHVYVu?rV#`jUrohnA=j~&mEh)IZjj=;T^%yB;>zf3vh6zte&y4-0q9ONjy zOr_T4JcC(M+e+iYEsyt1?Xqub70l>^E~8_bAYs#viY^f;n%mG`2Q^YOUu^ z|Mw8a3C+#p6#IKYdZgZG6pp2?Tt05FS3olUwsAyf3-5XBsrrnFiR_^r_dnz+wWR;n zVje8aKcYu%Z_V66dr3BWh^Ys0_B%yR=na&yGKI&?FXb&GY2mKou{{n4wjpL>i5hy{ zq4Cu{9rmU820MoQnWrkDap!ej_I6|w@jv9AExX1%OF$l;zIlL{OXna=9 zu)u<1VNxLWttJO<(d~z~%14Vlj&~+*MYFOBjA2_Qx!@FPRn}$x8fB-idObh zRDi^1yU_^v;n=``2)28aPsn}TjP{N1W)(jnE6g5$*4JGdaRKG+h!jcXZS;?1|JbHjk4Nd1M9OWBk4JYGg;E3KEPRrZNC|#%$J`3Q3 z#gBmc#Vz8FM(&lyuh#IcENshu%Hl707niiQRVe}#T&t!lo+t^dMm zQ5e0)F3h%BLxU{Goz>t%F$6v!6!fq2y*n=JYXff0xas!B&I+Kd;acnWbgqXb>RpiK zcF46z(BS8GYIOujHQe;H;b_OXxtz; zcaW)uIG{ehlVfcV#Xk}!0W$CK`s^t}!M^!^Itc|X)6~Z7L^eG2g%(o1x ze5t22>IToHKub)sK%#R&IR{EmXRg&?R=yG$HX}Eof=>+aEjq248rf%iZZAB&#O9II z;tM6YDI`7nUsoKX!9)TT=ZRcBj%%*2241ke_=f;q6IB6b?S)m-ZW}f_!q$jMI4Oyd zY>nx_IqV%)+de_=;8wtymI&RjjX3@}(*hdX2D5bqwsd3U_S_wUJ0)!mm{#FNSQt@$ zKt_>qHN|L3$7gNWt5oRuLc>quOrPW+e0ACvkf8ccEN#Nq#t)F3_$X8<(HaUoVHn{K zuDg*)kNcJ}lFb?K^Z^^n3rRl=N&g@WK{K|_^Hp0g0?DG(AF$TvBuAIL57tTNL6q(! z6osR-U86I79jcpQ#%2Mv?I5`+>mOhZd;^u-g#eJdgr#*?ed`nbA*D2LzERiutnEQh z;aqNUa?<*Q?RmE!{nYGXsmEun z^x8qN958E)v4)@FhuZl9lGL)yBE%9HCO5kw?eW-K4#!!W{>B_6NOD2k58-Rh-~=vp z5D=2soAtXj$EYUmrwRc!HNB&5x-_kv127r-#*^uaLcNojmQ3HGbc`{pbMsSXE$1;a^k>A%lk>_9dU zt7BL*>!&%7lZj6~0}z>waMh*&|Ll zADz;~s5YAhG9M-jLDMyp7?54^H4_f3{o>e*Uqo$8)Gobx1f((T=hw9Cvag4Df6=^| zHoN5P!9Ph^xCn{64LR^9XUmsZBEcT9Wmq~ke{{L;z2_&U>^hT7O4fHy%@fC6wz`f#AiN_35(gxOwQ9=Q8)`m<6@6fZL!mT9&#Gk#{=;>4St`80G*8<~Vrx?3>H#O00t-C~;p z9ho-EiAOCogKpBF<8&Zdz^~pMYcjNxW-7nrc}-~=r1AcHj2+Bqx+v1A*Br`oOIo5@tLyd_?O2a%j_xP zpCfO}QRKzolH;DAO+GjAQ(c~mYU0!SY>;<`X-`N#4__3|BO=62dU&rd?b&_{3)@?3 za$9M^=0jSxviD!0+Rle`PhZqrL|uPQJGheP!jK6{`g zPC1GpTxat>8An^+6xBR$LgW9!WVXBQk+RAl)!iUUoUiB!8C}el^4F@V}s%U2>4s#sfV+4yQTFQvTtS z;GXs|InNq%CRY{hbtf=L@j}W}OYjiwgxM|MmH9WH{Hmfvwj?=2b$oqArmGabsb(MX ze@*S5KHI@y@RD9JlVSZLvx2E`>?^V3J)&vKtaQR^Klt-*SjP{Y3aO`qyC%kv>~Ri7Gh|0r|TMtSTv$cjOCi@5HEiytlmr4>W)@Exli#yiXFlvfEU5x1pYDyGW? zit-b=!#;H!tzxWWzFYTs!G}ZesPa@@Re}d@J+tnO#M!U9=Z0&X-i9q145f{nJG0p+ zA{cC)LlK&JQ}Y$^hcLcQ+Biecw2I9|otGRabZp_<#JcS@*PYR{y6-_kkK{H#n52Or zG#2d)26OoT=#XC6W#B8{*0Bn$0~PWIYL#D)Gr3ne*a%kJ6N2YCOOr{jK>aO;B^DyS zQ|~msDjeEU;3!MY+)-V>MK3eiO5`ntyJj96Df@8j{buDO&Lnw)SX$TiOu$c&-6JN= zeD6Qyv?N$0^c<^jkk{{0GjQosCzKx(*|B`_7%b!V?3IA;o$>geX|A>A$iI?Q_TOXb zH8&3t9FeS*dnzLng}X&#vyHuSQu7QkNqZf%MHGq-$ba#g^%+Sb0pvWrQ0}8DnAIzj zauB-CUE5*?I;d0{d89fF!Rreo<&!X$gx2o#%)Z&E-xNFWLhbvAA^-l~Z|&33yr&zs zLII^*{HUerd)707?~IDr=mNMVNiQ?Sg$Uw7nu{O*!rP5P7Zdm38j$*8stF+8$86Pk zrIt%XBQXFiS6{mYK|H2F8y>N-BLkZ)f;V7!ZD*0ud}FR`f`rxze~6PYKmj>4Blg>B zf)RLl;24xn=tT@zxGu^d2Q2OFTKrLmHWFHDl_21o4*^eqkl+^KVtx-540tiGK=$2T zM*WKs%q_NmeK?v>#{LWSklzD&v+(G?d(^2?b7oFzrN-Kp7s_gvrM<|@?|RbHRn0eb zkT+8DW`gsGTCaw!buNItzB2u`hA|E4s#DV33ewYefnIBsT|F~lW%8Zx{PN9D zH6sgx?4OF0REIOCD1Tt9Bp7o+y3z?BiuXBpzsauqSQDO3eh}eT?KJH`KfYWdB^Jo? z=gj=W0GpoSSi=0O>$QXISTwQ`ncSGk3uIc=?E_X%mClNDSe}I1A20K|C+W#9Pg53} z(L&d7)DWalw_~->v;_8Ng!Uf^E$hFfQ+ppy0hGgRLDpSJfC%lH2i&TamuH-y2yhDz z2{54w_CK(~!+TBDc`bQ7e|oBoltO*UWLE!Oe2YvjNyZm7TvD-FyjU3DxmZ0t9tGh9pzjTTC+r{t zz(>0__VeJ)l;rev5AiTbbRWy8J9u@j*>i>9-_M)RsIKOF0 zB}(+J$L3qu7T!LucXi1b2MJ-+QG5BXd`F2ed}!dIGb+}m%F=Z`HowZWNOkd$8Xw5} z^`r||ow*A19ww_sOhAyk%S3?W>h5V3vCrsY@S;i8J)9msVGh%JJ9;?kJb*J6v=u7N zhANt07%7t0**|$x2TnTRP>`EFgHx)!2{(U!!wc_&nRGy;Elts!K|K9yHNsuPC;J9q zjTs>GV74xa%}Xw!j{Jb$@$HR@hJB)c{z}S#YS0&JICX4smI=Pr$6C{b;vkBT!9POZ z8t_XGfPj$h6&%Q##9%j1Q&7ZE(KUe;1eG}zt&o`{d`fzdLn1J2mrET$A{tFuSs0~+QG zI8k4^93W+a|MbVbrFr6yS};!f>y}eSk}}oPQHwxo@Sdp?xa0g8Z91)(NWkqL58yXZ>Zr#6xuf5(+kq9(nVS2*9`<1m&`iid@_Y8-@Q-( zZ(y^4gU`PjBu|l+*-wRvg5I(8WMs`p{Km~UR8Pm0zJ9+zTxAQmY$P=8((TW?d+qs0 z(_IiJnq8DHWix0(N@SUN#Vt6j*~mtEz5Osg`Iw1X$5G`sZ%%I%;oXxF03Pr=uIBj! zOJefIZQoaBsT@M&{(&EdXewR)pW?1O9?G_DSC(X}>_W1nkTtuaQi?{N#=b<1E&IMq zmO_?fuZ)mHG?o~$8?x_P3{A{f%FfuDknOt$PtWsy@B6&ppWpn&?>F~7=Y3u0bzbLn z9>;my8Fc3PmBW~7cy*0#R=y}=kmlXw-y)38W`cz zcJD$ui%L1F5r_`=&3sIAqd6M?UuQ&c0c+of{H$J1!9+>%^pAh?U4TNC{2xyCU)>vW z2dG$m36NGK_=^Y79eU5OwdN3h(2a_8Sv=XkSG$1w%x(m>~*f zEu~3dEQi@s67!8&@e-}OS&vOBc5gZMJMe>2X60galHZswp9CPkb@j4@5G=B_@3SS6 z1%oCs2T+XEKN{p3jXfAaj;phD0fA;RbxX<=ORCt^Ik%u9=3mfE=k*V+3Fy$JCGw~ zw~mEiskHVdk7U}S3{_nE0N(_|Kmkwc%R=2X7d~6mb)nif!!Vfx&sU9s*Qee`M2eq1 zpr3>w&kzB7W^?<)Zn8{(&5uc5kw}f-`bMkET5IW>l`h|lma@9AO=-G$PtfowX?Tav za?Wi}P%@6@;Lf;YG`}* zX;VC7j?>-gjzEqjQ!z2kA|2!&cSHWU>rS@hQZxSz5cXgFWv`$zQI!FqSSR=TNke=> z3SKMBWi&)&LIj)Yt|?ID?&UnEsf@R_DC0cO4YY_dnz5w8v8NaFc+%{n8@&!IDiRxE4ibU+Ic z5Kd_32u%PYBx%P<+9N3g9AfEZdB@N4fjv56r7t>QBX3c&&U72onozUSxK8ITCm={l z;kjTeWNxboItN zt&4f0-4bdb>wG2vId;HlFyoIT%1mleR$bLtBMVp#9onA46DR&S%l}E^sxBB0pqg70 zO@y!A&=)PMQmOO3BK@`|u1!qmmfJE;{o+9P*LxZjC?)L-6El!#o2Yw>3%IbUx4S4Bm9VOy zTPdtxjTJaPtq(!r??=_Mz|eTZalc!$FH}Kneoq`SUyx15pnD-a5P^(9b|#yj7reHq)e@ov{F z=W=w^YT}7(j4kqJvH!Y^nOU(_)Ti^D9xylhxP-1HvT|vvHY1^sua#^ZQ4kiBpY^X^ z5PaOi7zbwD^eDYxmQ1B!X}l%vd0R^xH6tUt#$Q)BdJx$h zhTbpX-53WKo##Cgb;ZD0#Jh6UHLT`vHFTCa2}&| z(fwO6g*jh>;YkCgI4!3io+lsV;i)B6o;@L*tU&Vh6zt|wX6hh?01%-$DT-#YWyYcK zjhsOjSf_{uoYj$d{f?XIy|;|N!~sHzVsV466Gb9V(>2ig2m$ruzg95-Qc)x{OpCUj zF}zlxg;s(Eq%oGzEkEvQh#z$5qStI@p8&{A2Q{!VrzmQ{vyFP zH}NYS!Yryk_q(`m7D5InGbCmiOSvl9q@ttThae;2z+A$w9&bI611 zqttCO?UTfKxV%8CHc!L!?;G+EzlP1cYHm6fv`PNJ<14x)ans#Mfkb?Kr$6s7#^&J5 zSzH;cZhWOk1}X*wpM`?i{;HT~m~ik&9h|Owb1d_Fnss-qqo#|cO57=MA6?EXh!AZD zRsYSbCQz!{S2r~AT?#$wB>H*f(;iF}`NGXpnp?$PxuS;SrV>rSY_bEVW53~-950jf z=44h=wc6MgDd7npWrs2KP200$6q={8&*9b)IBtGrM1#|V-LcVDyGG1HRRSxWk*+um zZt7yWZq9<=K=O4wWkhbl5mMk5!{sLz8Su23+7Db67Nup;MXzGl^z^d2*J8U&XV_~r z1+u!)&Sd8Atk{>%5@!jbasmdvkk3FiF%TdlC;bwsX~AhQvZ-v89d=Gk5uJ77+o#T> zZP8cb39X=_>E}Gxgd9~^CGi^ZRoebdE4s+9$4iSeA9l7AyX6K9Qo4>y3Su3Zqnr1svS3Q^e+SN8c3*b zE*a^KU@-37;3QMNosBI#M_96(&r>&}=(gXZ;>`Z?&SZfDNf$HJQX<^7Z_(%d+SCy} z+~jJNwLsd7N;<@vcW#OOu2_Z%cX(w4bEpfXa33S)2e7j_{2~Qg)AxQ33*^~h{s})U zeRI0F*HxV{W=D9-lD75O`P;=qwsOYC$PXo>tuw0|%a_V2eeCJId_8LPG)!?_g(DL% z_j1qK)3D;MGZ%kdE!bWXdVW#w`><@*+L%WO(?FMVm@WgHb%%GMm|Cx9OMKX%EJWZwU@95C(|9(T@7hM zPox&Vk{d8mt*xbc?>bmcRSXgnNdwVuoiceP`Eip49|#h0vp;I6nw0ry&XEs>7FAjf zn6j?}F|qdA4Y`~*{Yk?YGw=)`ujfaeFn>DJ$+QnE!CYt2e}Ka$N`5PKuN6gYPKY`$&Tg+I!z1_v&Cw){7}x~HoD02gO2j@l>I8Q zRW^2mX6((}Kmp7oN#4+XrAW~JH1vAbZO3O(Zf!@Z~(}gQzmRpEkE%0Fx`_e95_BAa3ObqysgF0 z5H{E)d|poJve%cVT_Jf@`2!>6!u8Q)eX;sZKl*y8T>cnx)phyqASo;|BS0XSFD#R@ z3@N2q*CA!aH+b!D+76LJ^9ctA0oe;a&=btmn}GZz-jsnqQ9`foVgtd@iCnsko|-Kw z6KWIk4*jf{t0nhzk^T2|86u$LI!Z?zvU9ovw826EFw-2KLh_H)C?LtfnUq z(=<>H@$<{)1$vzqNz3OqQ%He(*+)R$oTth#!Qm*H`2n8+=tiqX0kw{|Vp@;Qo~fbR zbLz)vu#NZSeL2u5MD#|A}RV zX#L~Z6C5qK&_}#jIpc-2cp5B1m-Y!E^kbyW2rBkouz^w%n*RwrW=5Mfl?0Z@cb|T- zcMl*idUi4z5mVVEDbyeI<%;Soa~XIV?L=xG0mWbv*+CrQNi_mC*ciA&@J~huVbduY zyQC<}_{m0I@n*bL2J7CH$w)eJzgIItu?vOu2L zgxJvImx+mm=X|-heq6);paK`2J%9x2u?Cg&lUOQo%SnnFQ5$99S zUfc8tR%6fi%~Gk%u}}+N;b7auA9b3iR)w7QS|N9SwI1&wSia5*%IVvlBFO?4=0+lA zW2YNUFac!>E?(zRrPbA5b4xjWS#t(w?gUW2BFzuZ)vI^r@{{UIUwjRdYM+rj1*y<* zpRFFOoBi2kOM+_-->pXVnMMV}M?TiOh(L@pESWW(<|N1d4Et=?rInL!WZ!$!PzcbR zspi$5B)v}yzR+c2LncB(wdfd$;2&>rF9%o`&@6gZ)Ve!v9Hk8X%%dtv3VJKiuUEJqrZlM?Qs zW*sD+)ZQvFx1aU6xed*vy}&!<-(5M3-`@hHuc_2cWSTf#%_6duSHD*Bdb3`*zX{Nu zOGHUR+l-Y5=DJ<H26-_*CX=& z?m{)er#GZCl)XkflE`{yHflJHg6q`8l?P~?R}~vnUEQ+E{o{Pb>)O!Vt>73CkzV|~ z)t1s&m*bR*NVR>QJA2NyA(1iQNQo=qq*SM{6~1G%>gU+cx+J$bUA%8OZ7yEv%0*&R zScFfH?AYVUMk8o`Ro!(cO9E`aG#W`c$bI8I)wT=6-(4HXgaiIMZ1JfdfOtddzwcVkhH&^{_#2CD%7usUhEoZr4q`2io@qB{+ z%^{8tS6w05kQ-KEwc_T#ofZ)bS7~vkIYDQG8@rDJg{arfi%PnJQcfL*;?9tQ3oRa*25yv(+&I*!q;LXO5w9CYXOUq7bxP6s z-{JP%Ba)8mzYq~jkdb^`t@^9WuT+Vc;#@Yq~v zZ4Vcax!IEzDI78_IvjH-@h%d2a1v<+xj9L_o1_7eVKRmo8S@(S(8hB9<5#dfzT!{Z zE$pH?-fIg2F$4j?uCZXg_FF>&nr|NiFRGRK5BFCDQq6DN7L`8^20Y3KQTbN;mA>9L zUVgH=;N3)Pbu||V{_X*?QNAoY{qly?}<$g_%UY+s@@T&!_ zXxmrYaHCat0zMg1btylfF@(xQ9WC8(|dQ?Tfc7;iI{cVuhik6_t80IVr`*t%_~Im4@#6(2At?^zl-=9xv#uG z9bM#n@=7b4$I6I~AsgBo!%3y4EtoYer8^LvS!Oe|{K=_|77>2c+t7$(>1UCak1DZW zv-|b0EPQ`(xwyhDSGT^ID2jd2?QnYg>r%lxlSr}Sju@Pev&-(-1tv6L6(P3n+Z%^oOR^`HQdZ(Q~ExLXPOo^6e+Nd221r;Yf8f zZ?5dQ@noj*Ck@~sCConZ7t=0%FT`#SVYu$zkWVy*WIasD|D1)HSXy|AnN|4p>(oXQ zL1wTwqVPgQ==jpgLR3%$8RdZCrtwEFvLE83gs#yCSZr=mjCDi@F_km#>S+^x$7;u_ zn|EvK`!qe2?ukT9 zH!P)xf*yg7&HD2{UdJ;4pPSo^wA>zOdr3S!=GNn%>(L6k3Y!qx2R1&qK2w-M}FF zry1F9aSagt@K}U|8NcAYxs>iHOwNSQ+SFum*bnv^KLa<^R{FKKR;QgNlf+|Tb+@1j zi~?tx|J#y?(!T;BP4s*=0X?1X(6Tq6*aH(9`FD++Ce?Pu* uCMMxplu>>d__WlV9uVdQpWT3+L#|9QWlatdsU#6!wHtS?f4Fu(@P7cv2xV;m literal 0 HcmV?d00001 diff --git a/example/ck_tile/01_unified_attention/rotary.hpp b/example/ck_tile/01_unified_attention/rotary.hpp new file mode 100644 index 00000000000..346f2a5e7ee --- /dev/null +++ b/example/ck_tile/01_unified_attention/rotary.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include + +// keep sync with RotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +template +std::tuple, ck_tile::HostTensor> +generate_rotary_cos_sin(ck_tile::index_t seqlen, + ck_tile::index_t rotary_dim, + std::optional seed = std::nullopt) +{ + // return dummy tensors if we won't apply RoPE at all + if(rotary_dim <= 0) + { + ck_tile::HostTensor dummy({1, 1}); + return std::make_tuple(dummy, dummy); + } + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_real_distribution generator(0.0f, 1.0f); + + const ck_tile::index_t num_rows = seqlen * 2; + const ck_tile::index_t num_cols = rotary_dim / 2; + + using std::begin, std::end; + + ck_tile::HostTensor angle({num_rows, num_cols}); + std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); + + ck_tile::HostTensor cos({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { + return ck_tile::type_convert(std::cos(origin_value)); + }); + + ck_tile::HostTensor sin({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { + return ck_tile::type_convert(std::sin(origin_value)); + }); + + return std::make_tuple(cos, sin); +} + +template +std::tuple, ck_tile::HostTensor> +slice_rotary_cos_sin(const ck_tile::HostTensor& cos, + const ck_tile::HostTensor& sin, + ck_tile::index_t seqlen_offset, + ck_tile::index_t seqlen) +{ + assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); + assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); + + assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); + + const ck_tile::index_t num_rows = seqlen; + const ck_tile::index_t num_cols = cos.get_length(1); + + ck_tile::HostTensor cos_pt({num_rows, num_cols}); + cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); + + ck_tile::HostTensor sin_pt({num_rows, num_cols}); + sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); + + return std::make_tuple(cos_pt, sin_pt); +} diff --git a/example/ck_tile/01_unified_attention/script/benchmark_bwd.sh b/example/ck_tile/01_unified_attention/script/benchmark_bwd.sh new file mode 100755 index 00000000000..cfd792906ce --- /dev/null +++ b/example/ck_tile/01_unified_attention/script/benchmark_bwd.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done diff --git a/example/ck_tile/01_unified_attention/script/benchmark_fwd.sh b/example/ck_tile/01_unified_attention/script/benchmark_fwd.sh new file mode 100755 index 00000000000..31ad8000392 --- /dev/null +++ b/example/ck_tile/01_unified_attention/script/benchmark_fwd.sh @@ -0,0 +1,53 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 64 128 256 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done + +#Padding Benchmarks: batch mode (baseline vs low/med/high pad) +prec="fp16" +base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no pad) +$EXE $base_batch_args + +# low pad (≈90–95% effective) +$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + +# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) +seqlens_q="1024,768,512,256" +seqlens_k="1024,768,512,256" +base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no physical pad) +$EXE $base_group_args + +# low physical pad +$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + +# medium physical pad +$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + +# high physical pad +$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/01_unified_attention/script/benchmark_fwd_v3.sh b/example/ck_tile/01_unified_attention/script/benchmark_fwd_v3.sh new file mode 100755 index 00000000000..a3f7d68eb3a --- /dev/null +++ b/example/ck_tile/01_unified_attention/script/benchmark_fwd_v3.sh @@ -0,0 +1,42 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)" +VALID=0 + +for causal in 0 1 ; do +for prec in "fp16" "bf16" ; do +for hdim in 128 ; do +for perm in 0 ; do + +$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID + +$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID + +done +done +done +done + +# Padding benchmark comparisons for v3 (batch mode only) +# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== +prec="fp16" +base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" + +# baseline (no pad) +$EXE $base_v3_args + +# low pad (≈90–95% effective) +$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx942.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx950.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt b/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx942.txt b/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx942.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx950.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/01_unified_attention/script/run_full_test.sh b/example/ck_tile/01_unified_attention/script/run_full_test.sh new file mode 100755 index 00000000000..5c2a5a4b3d0 --- /dev/null +++ b/example/ck_tile/01_unified_attention/script/run_full_test.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/ +# +# run the script as "./run_full_test.sh +# input arguments: +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# host name : $hostname +# gpu architecture: e.g., gfx90a, or gfx942, etc. + +set -euo pipefail + +#get the command line arguments: +export env_type=$1 +echo 'Environment type: ' $env_type +export branch=$2 +echo 'Branch name: ' $branch +export host_name=$3 +echo 'Host name: ' $host_name +export GPU_arch=$4 +echo 'GPU_arch: ' $GPU_arch + +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + #get GPU_arch and number of compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + +#run verification tests +time example/ck_tile/01_fmha/script/smoke_test_fwd.sh +time example/ck_tile/01_fmha/script/smoke_test_bwd.sh + +#run performance benchmarks +export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" +print_log_header $fmha_fwd_log $env_type $branch $host_name +time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log + +export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log" +print_log_header $fmha_bwd_log $env_type $branch $host_name +time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log + diff --git a/example/ck_tile/01_unified_attention/script/smoke_test_bwd.sh b/example/ck_tile/01_unified_attention/script/smoke_test_bwd.sh new file mode 100755 index 00000000000..cd51dde2d4e --- /dev/null +++ b/example/ck_tile/01_unified_attention/script/smoke_test_bwd.sh @@ -0,0 +1,90 @@ +#!/bin/bash +# TODO: run this script from CK root or build directory +set -euo pipefail + +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_bwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" +KNAME=1 +GPU_arch=${GPU_arch:-""} +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi + +export CK_WARMUP=0 +export CK_REPEAT=1 + +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"} + +COMMON_ARGS='-v=1' + +run_exe() { + set +ex + $EXE $@ + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + +test_h_s_mask() { + run_exe -b=1 -h=4 -h_k=2 -s=259 $@ + run_exe -b=2 -h=2 -s=516 -s_k=253 $@ + run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@ + run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@ + run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@ + run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@ +} + +set -x +# main tests +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 256 ; do +for mode in 0 1 ; do +for bias in "n" "a" ; do +for dbias in 0 ; do +for p_drop in 0.0 0.2 ; do +for deterministic in 0 ; do +test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +done +done +done +done +done +done +done +done + +# additional cases +for hdim in 40 48 72 96 ; do +test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +done +set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) diff --git a/example/ck_tile/01_unified_attention/script/smoke_test_fwd.sh b/example/ck_tile/01_unified_attention/script/smoke_test_fwd.sh new file mode 100755 index 00000000000..fca6b8d0cd3 --- /dev/null +++ b/example/ck_tile/01_unified_attention/script/smoke_test_fwd.sh @@ -0,0 +1,281 @@ +#!/bin/bash +# TODO: run this script from CK root or build directory +set -euo pipefail + +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_fwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" +KNAME=1 +GPU_arch=$GPU_arch +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi + +export CK_WARMUP=0 +export CK_REPEAT=1 + +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"} + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 + +TEST_SPLITKV=0 +TEST_APPENDKV=0 +# options: +# -s: run splitkv tests +# -a: run appendkv tests +while getopts ":sa" opt; do + case "${opt}" in + s) + TEST_SPLITKV=1 + ;; + a) + TEST_APPENDKV=1 + ;; + *) + ;; + esac +done + +run_exe() { + set +ex + $EXE $@ + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + +run_fp16_bf16_tests() { + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" + + if [ $TEST_SPLITKV -eq 1 ] ; then + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" + fi + + for prec in "fp16" "bf16" ; do + for mode in 1 0 ; do + for perm in 0 1 ; do + for hdim in 32 64 128 256 ; do + for lse in 0 1 ; do + for bias in "n" "e" "a" ; do + for p_drop in 0.0 0.2 ; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do + + # run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done ; done ; done +} + +run_fp8_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8bf16_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8fp32_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp16_appendkv_tests() { + for s in $(seq 63 1 65) ; do + for s_k in 65 129 ; do + for s_knew in 0 64 $s_k ; do + for hdim in 32 64 128 256 ; do + for ri in 0 1 ; do + for rdim in 0 16 32 $hdim ; do + for page_block_size in 0 128 ; do + for cache_batch_idx in 0 1 ; do + + run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done +} + +run_padding_smoke_tests() { + # Padding-only smoke tests for batch/group mode using COMMON_ARGS + local prec="fp16" + + # Batch mode: padding via effective lengths (exclude PAD) + # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches + local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low pad (≈90–95% effective) + $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + # medium pad (≈60–75% effective) + $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + # high pad (≈30–40% effective) + $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + + # Group mode: padding via physical stride along seqlen + local seqlens_q="1024,768,512,256" + local seqlens_k="1024,768,512,256" + local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low physical pad + $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + # medium physical pad + $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + # high physical pad + $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 +} + +run_padding_basic_boundary_tests() { + # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) + local prec + local perm + + # Group mode: Q&K padded with per-batch different strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ + -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # slightly larger, uneven padding strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ + -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # only K padded; Q unpadded + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ + -s=55 -s_k=256 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # use cu_seqlen overrides to skip tail PAD + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ + -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ + -q_eff_lens=55,60 -kv_eff_lens=200,256 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # no padding (equal), mixed Q/KV, all len=1 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # highly variable logical lengths + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ + -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ + -s=256,129 -s_k=256,129 -s_kpad=256 \ + -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ + -kname=$KNAME $COMMON_ARGS + done +} + +set -x + +run_fp16_bf16_tests +run_padding_smoke_tests +run_padding_basic_boundary_tests +run_fp8_tests +run_fp8bf16_tests +run_fp8fp32_tests + +if [ $TEST_APPENDKV -eq 1 ] ; then + run_fp16_appendkv_tests +fi + +set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) diff --git a/example/ck_tile/01_unified_attention/unified_attention.cpp b/example/ck_tile/01_unified_attention/unified_attention.cpp new file mode 100644 index 00000000000..8c2b22f0a29 --- /dev/null +++ b/example/ck_tile/01_unified_attention/unified_attention.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "unified_attention.hpp" +#include "unified_attention_impl.hpp" +#include "mask.hpp" + +namespace ck_tile { + +std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type) +{ + switch(data_type) + { + case unified_attention_args::data_type_enum::fp16: return stream << "fp16"; + case unified_attention_args::data_type_enum::bf16: return stream << "bf16"; + default: return stream << "unknown"; + } +} + +std::pair unified_attention(const unified_attention_args& args, const stream_config& config) +{ + if(args.data_type == unified_attention_args::data_type_enum::fp16) + { + if(args.mask_type == static_cast(mask_enum::no_mask)) + { + using kernel_traits = + unified_attention_kernel_traits; + + return unified_attention_kernel_dispatch(args, config); + } + else + { + using kernel_traits = + unified_attention_kernel_traits; + + return unified_attention_kernel_dispatch(args, config); + } + } + else if(args.data_type == unified_attention_args::data_type_enum::bf16) + { + if(args.mask_type == static_cast(mask_enum::no_mask)) + { + using kernel_traits = + unified_attention_kernel_traits; + + return unified_attention_kernel_dispatch(args, config); + } + else + { + using kernel_traits = + unified_attention_kernel_traits; + + return unified_attention_kernel_dispatch(args, config); + } + } + + return std::make_pair(false, -1.f); +} + +} // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp new file mode 100644 index 00000000000..63a348a69c9 --- /dev/null +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/host/stream_config.hpp" + +namespace ck_tile { + +struct unified_attention_args +{ + enum class data_type_enum + { + fp16, + bf16 + }; + + data_type_enum data_type; + // bool is_varlen; + index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and + // window_size_right == 0). + + index_t num_blks; + index_t num_head_q; + index_t num_queries_per_kv; + + // TODO window + float scale_s; + float scale; + float scale_k; + float scale_v; + float scale_out; + + index_t total_num_q_blocks; + + const void* q_ptr; + index_t query_stride_0; + index_t query_stride_1; + + const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + index_t stride_k_cache_0; + index_t stride_k_cache_1; + index_t stride_k_cache_2; + index_t stride_k_cache_3; + + const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + index_t stride_v_cache_0; + index_t stride_v_cache_1; + index_t stride_v_cache_2; + index_t stride_v_cache_3; + + void* o_ptr; + index_t output_stride_0; + index_t output_stride_1; + + const int32_t* block_tables_ptr; + const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* query_start_len_ptr; // [num_seqs+1] + + index_t num_seqs; // number of batches for q +}; + +std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type); + +// return value: +// first = whether the kernel was launched (true = launched, false = skipped) +// second = elapsed time (ms) of the kernel launch, valid only if first == true +std::pair unified_attention(const unified_attention_args& args, const stream_config& config); + +} // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp new file mode 100644 index 00000000000..952ebfe0fa4 --- /dev/null +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" + +#include "unified_attention.hpp" +#include "mask.hpp" + +#define INST_unified_attention_DISPATCH(kernel_traits) \ + template <> \ + std::pair unified_attention_kernel_dispatch( \ + const unified_attention_args& args, const stream_config& config) \ + { \ + return std::make_pair(true, \ + unified_attention_kernel_launch(args, config)); \ + } + +namespace ck_tile { + +template +struct unified_attention_problem_traits; + +template <> +struct unified_attention_problem_traits +{ + using qkvp_dtype = ck_tile::half_t; + using acc_dtype = float; + using o_dtype = ck_tile::half_t; + using lse_dtype = float; +}; + +template <> +struct unified_attention_problem_traits +{ + using qkvp_dtype = ck_tile::bf16_t; + using acc_dtype = float; + using o_dtype = ck_tile::bf16_t; + using lse_dtype = float; +}; + +template +struct unified_attention_kernel_traits +{ + static constexpr auto date_type = DataType; + static constexpr bool is_masking = IsMasking; + + // BLOCK_Q BLOCK_SIZE HEAD_SIZE N1 K1 + using unified_attention_block_tile = sequence<128, 128, 128>; + using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; + using unified_attention_block_warps = sequence<8, 1, 1>; + + using unified_attention_shape = TileUnifiedAttentionShape; + + using unified_attention_traits = TileUnifiedAttentionTraits; + + using funified_attention_mask = GenericAttentionMask; + + using funified_attention_pipeline_problem = + UnifiedAttentionPipelineProblem::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::lse_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::o_dtype, + funified_attention_shape, + funified_attention_mask, + funified_attention_traits>; + + using funified_attention_pipeline = BlockFunified_attentionFwdV3Pipeline; + + using epilogue = Default2DEpilogue< + Default2DEpilogueProblem::acc_dtype, + typename unified_attention_problem_traits::o_dtype, + true, // kPadM + true, // kPadM + true // UseRawStore + >>; + + using kernel = UnifiedAttentionKernel; +}; + +template +float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) +{ + + auto kargs = Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.o_ptr, + args.num_blks, + args.num_head_q, + args.num_queries_per_kv, + args.scale_s, + args.scale, + args.scale_k, + args.scale_v, + args.scale_out, + args.total_num_q_blocks, + args.query_stride_0, + args.query_stride_1, + args.stride_k_cache_0, + args.stride_k_cache_1, + args.stride_k_cache_2, + args.stride_k_cache_3, + args.stride_v_cache_0, + args.stride_v_cache_1, + args.stride_v_cache_2, + args.stride_v_cache_3, + args.output_stride_0, + args.output_stride_1, + args.block_tables_ptr, + args.seq_lens_ptr, + args.query_start_len_ptr, + args.num_seqs + ); + + dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, args.total_num_q_blocks); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; + + return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} + +// return value: +// first = whether the kernel was launched (true = launched, false = skipped) +// second = elapsed time (ms) of the kernel launch, valid only if first == true +template +std::pair unified_attention_kernel_dispatch(const unified_attention_args& args, + const stream_config& config); + +} // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/unified_attention_runner.hpp b/example/ck_tile/01_unified_attention/unified_attention_runner.hpp new file mode 100644 index 00000000000..0703af71e33 --- /dev/null +++ b/example/ck_tile/01_unified_attention/unified_attention_runner.hpp @@ -0,0 +1,1789 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/host.hpp" +#include "ck_tile/ref/naive_attention.hpp" +#include "fmha_fwd.hpp" +#include "utils.hpp" +#include "ck_tile/utility/json_dump.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()" +#endif + +enum class fwd_result +{ + success, + failure, + invalid_args, + no_instance, +}; + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + using TypeConfig = FmhaFwdTypeConfig; + using ODataType = typename TypeConfig::ODataType; + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + double rtol = 0; + double atol = 16 * (o_dtype_max > 240 ? 2 : 1); + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) +{ + // If we have enough to almost fill the SMs, then just use 1 split + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + max_splits = std::min({max_splits, num_SMs}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) + { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) + { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +int override_num_splits_if_necessary( + int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +{ + (void)hdim_v; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return num_splits; + } + + hipDeviceProp_t props{}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return num_splits; + } + + // tile size should match the generate.py + const int kM0 = 64; + + const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); + + if(num_splits < 1 && p_drop == 0.0f) + { + return num_splits_heuristic( + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); + } + + return num_splits; +} + +template +fwd_result fmha_fwd_run(mode_enum mode, + ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t nhead_k, + std::vector seqlen_qs, + std::vector seqlen_ks, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t seqlen_knew, + std::vector seqlen_qpads, + std::vector seqlen_kpads, + std::vector q_eff_lens_per_batch, + std::vector kv_eff_lens_per_batch, + ck_tile::index_t rotary_dim, + bool i_perm, + bool o_perm, + float scale_s, + float logits_soft_cap, + bool is_v_rowmajor, + bool lse, + ck_tile::index_t page_block_size, + bool use_cache_batch_idx, + std::string bias_str, + float p_drop, + uint64_t drop_seed, + uint64_t drop_offset, + bool drop_prefs, + std::string mask_str, + bool squant, + bool is_rotary_interleaved, + ck_tile::index_t num_splits, + std::string init_method, + uint32_t seed, + int do_validation, + const ck_tile::stream_config& stream_config, + std::optional json = std::nullopt) +{ + const std::string data_type = []() { + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) + return "fp16"; + else if constexpr(std::is_same_v) + return "bf16"; + else if constexpr(std::is_same_v) + return "fp8"; + else if constexpr(std::is_same_v) + return "bf8"; + else if constexpr(std::is_same_v) + return "fp8bf16"; + else if constexpr(std::is_same_v) + return "fp8fp32"; + else + static_assert(false); + }(); + + if(nhead_k < 0) + nhead_k = nhead; + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return fwd_result::invalid_args; + } + + std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}()); + auto next_seed = [&random_engine]() { return static_cast(random_engine()); }; + + if(hdim_v < 0) + hdim_v = hdim_q; + +#if !CK_TILE_FMHA_FWD_APPENDKV_API + if(seqlen_knew != 0) + { + std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option" + << std::endl; + seqlen_knew = 0; + } +#endif + if(seqlen_knew < 0) + { + seqlen_knew = randint(1, seqlen_qs[0], random_engine); + } + + if constexpr(!(std::is_same_v || + std::is_same_v)) + { + if(0 < rotary_dim) + { + std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl; + return fwd_result::invalid_args; + } + } +#if !CK_TILE_FMHA_FWD_APPENDKV_API + else if(0 < rotary_dim) + { + std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option" + << std::endl; + rotary_dim = 0; + } +#endif + // to use fmha_fwd_appendkv(), make sure it's in batch mode + const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); + if(need_append_kvcache && mode == mode_enum::group) + { + std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl; + mode = mode_enum::batch; + } + if(!(rotary_dim <= hdim_q)) + { + std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; + return fwd_result::invalid_args; + } + else if(!(rotary_dim % 16 == 0)) + { + std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl; + return fwd_result::invalid_args; + } + +#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ + CK_TILE_FMHA_FWD_PAGEDKV_API)) + if(0 < page_block_size) + { + std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" + << std::endl; + page_block_size = 0; + } +#endif + if(!(page_block_size % 128 == 0)) + { + std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" + << std::endl; + return fwd_result::invalid_args; + } + +#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API) + if(use_cache_batch_idx) + { + std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } +#else + if(use_cache_batch_idx) + { + if(0 < page_block_size) + { + std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + else if(mode == mode_enum::group) + { + std::cerr << "group mode will not use cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + } +#endif + const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); + + // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv) + const bool has_group_padding = + (mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) || + (mode == mode_enum::group && (seqlen_kpads[0] >= 0)); + const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() || + !kv_eff_lens_per_batch.empty())); + const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); + const bool using_pagedkv = (0 < page_block_size); + const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; + if((using_appendkv || using_pagedkv || using_splitkv) && + (has_group_padding || has_batch_efflens)) + { + std::cerr << "Padding (physical or effective lengths) is not supported with " + "appendkv/splitkv/pagedkv pipelines" + << std::endl; + return fwd_result::invalid_args; + } + + std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = + generate_missing_seqlens(mode, + batch, + seqlen_qs, + seqlen_ks, + seqlen_kpads, + /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, + need_append_kvcache, + random_engine); + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb]) + { + std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl; + return fwd_result::invalid_args; + } + } + // compute kvcache seqlen_k (before appending knew/vnew) + auto cache_seqlen_ks = seqlen_ks; + std::transform(cache_seqlen_ks.begin(), + cache_seqlen_ks.end(), + cache_seqlen_ks.begin(), + [&](auto seqlen_k) { return seqlen_k - seqlen_knew; }); + +#if 0 + std::cout << "seqlen_qs: " << seqlen_qs << std::endl; + std::cout << "seqlen_ks: " << seqlen_ks << std::endl; + std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl; + std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl; +#endif + + if(scale_s == .0f) + scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + + bias_info bias = bias_info::decode(bias_str); + + mask_info mask = + mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore + + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return fwd_result::invalid_args; + } + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + +#if !CK_TILE_FMHA_FWD_SPLITKV_API + if(num_splits != 1) + { + std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl; + num_splits = 1; + } +#endif + + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + + // Optional padded Q seqstarts (group-mode only) + std::vector seqstart_q_with_padding_host; + if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) + { + if(seqlen_qpads.size() < static_cast(batch)) + { + seqlen_qpads.resize(batch, seqlen_qpads.back()); + } + if(seqlen_qpads.size() == static_cast(batch)) + { + seqstart_q_with_padding_host = to_seqstarts( + ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); + } + } + + // Optional batch-mode cumulative seqlen overrides + std::vector cuq_cum, cukv_cum; + if(mode == mode_enum::batch) + { + auto calculate_cumulative = [&](std::vector& per_batch_vec, + std::vector& cum_vec) { + if(!per_batch_vec.empty() && per_batch_vec[0] != -1) + { + if(per_batch_vec.size() < static_cast(batch)) + { + per_batch_vec.resize(batch, per_batch_vec.back()); + } + cum_vec.resize(batch + 1); + cum_vec[0] = 0; + for(int i = 0; i < batch; ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + } + }; + + calculate_cumulative(q_eff_lens_per_batch, cuq_cum); + calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); + } + + using TypeConfig = FmhaFwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = std::numeric_limits::min(); + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + + static_cast(2) * mask.get_unmaskarea() * hdim_v); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(ODataType) * real_seqlen_q * hdim_v); + num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k); + } + } + + const ck_tile::index_t max_num_page_blocks = + (0 < page_block_size + ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size)) + : 0); + + // legalize num_splits according to other options + if(num_splits < 1) + { + num_splits = override_num_splits_if_necessary( + batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + } + if(128 < num_splits) + { + std::cerr << "num_splits greater than 128 is not supported" << std::endl; + return fwd_result::invalid_args; + } +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API + if(0 < p_drop && (1 < num_splits || use_kvcache)) + { + std::cerr << "dropout is not supported by split-kv kernels. ignoring the 'p_drop' option" + << std::endl; + p_drop = 0.0f; + } +#endif + + static const auto get_lengths = [](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen + const ck_tile::index_t shape_seqlen_q_lse = + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch + ? seqlen_qs[0] + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() + : seqstart_q_with_padding_host.back())); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_ks[0] + : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() + : seqstart_k_with_padding_host.back())); + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + 0 < page_block_size + ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) + : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode + ck_tile::HostTensor knew_host( + 0 < seqlen_knew + ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor v_host( + 0 < page_block_size + ? (is_v_rowmajor + ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) + : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) + : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); + ck_tile::HostTensor vnew_host( + 0 < seqlen_knew + ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) + : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew)) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor bias_host( + bias.type == bias_enum::elementwise_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + + auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin( + std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, next_seed()); + + ck_tile::HostTensor lse_acc_host( + 1 < num_splits || use_kvcache + ? std::array{shape_batch, nhead, num_splits, shape_seqlen_q} + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor o_acc_host( + 1 < num_splits || use_kvcache ? std::array{shape_batch, + nhead, + num_splits, + shape_seqlen_q, + hdim_v} + : std::array{1, 1, 1, 1, 1}); + + // batch mode of lse data layout is [batch, nhead, seqlen_q] + // group mode of lse data layout is [nhead, total_seqlen_q] + ck_tile::HostTensor lse_host( + lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} + : std::array{1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + + ck_tile::HostTensor block_table_host( + 0 < page_block_size ? std::array{batch, max_num_page_blocks / batch} + : std::array{1, 1}); + + ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx + ? std::array{batch} + : std::array{1}); + float max_o = 5.0; + if(init_method == "ui" || init_method == "0") + { + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(knew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(vnew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}( + bias_host); + } + else if(init_method == "ni") + { + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(knew_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(vnew_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}( + bias_host); + } + else if(init_method == "uf" || init_method == "1") + { + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(knew_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(vnew_host); + ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(bias_host); + } + else if(init_method == "nf") + { + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(q_host); + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(knew_host); + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(vnew_host); + ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(bias_host); + } + else if(init_method == "tf" || init_method == "2") + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(knew_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(vnew_host); + ck_tile::FillTrigValue{}(bias_host); + } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == static_cast(nhead)); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } + iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); + iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() + ? 0 + : seqstart_q_with_padding_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_padded_buf( + seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 + : cuq_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem cu_seqlen_kv_buf( + cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || + 0 <= seqlen_kpads[0] + ? seqlen_ks.size() * sizeof(int32_t) + : 0); + ck_tile::DeviceMem cache_seqlen_k_buf( + need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + + float scale_p = 1.f; + float scale_o = 1.f; + if(squant) + { + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p and v is the same type + // Q tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = q_dtype_max / max_value; + + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // K tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + float scale = k_dtype_max / max_value; + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // V tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = k_dtype_max / max_value; + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + + scale_o = (1.0 / p_dtype_max) / scale; + } + + scale_p = p_dtype_max; + + if constexpr(std::is_same_v) + { + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o = scale_o * o_dtype_max / max_o; + } + } + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + knew_buf.ToDevice(knew_host.data()); + vnew_buf.ToDevice(vnew_host.data()); + bias_buf.ToDevice(bias_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + // Keep logical starts in seqstart_k; pass padded K via separate pointer + seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_q_padded_buf.ToDevice( + seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); + seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr + : seqstart_k_with_padding_host.data()); + cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); + cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); + seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] + ? seqlen_ks.data() + : nullptr); + cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); + rotary_cos_buf.ToDevice(rotary_cos_host.data()); + rotary_sin_buf.ToDevice(rotary_sin_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); + block_table_buf.ToDevice(block_table_host.data()); + cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if(permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if(iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + + std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm) + << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] + << "/" << seqlen_ks[0] + << (seqlen_kpads[0] < 0 ? "" + : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) + << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias + << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant + << ", mask:" << mask << ", v:" << (is_v_rowmajor ? "r" : "c"); +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(0 < rotary_dim) + { + std::cout << ", rotary_dim:" << rotary_dim << "(" + << (is_rotary_interleaved ? "inter" : "half") << ")"; + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API + if(1 < num_splits) + { + std::cout << ", num_splits:" << num_splits; + } + if(0 < page_block_size) + { + std::cout << ", page_block_size:" << page_block_size; + } + if(use_cache_batch_idx) + { + std::cout << ", cache_batch_idx:" << use_cache_batch_idx; + } +#endif + // Padding / effective length diagnostic logging + auto print_vec = [&](const char* label, const std::vector& v) { + if(v.empty()) + return; + std::cout << ", " << label << ":["; + for(std::size_t i = 0; i < v.size(); ++i) + { + if(i) + std::cout << ","; + std::cout << v[i]; + } + std::cout << "]"; + }; + + if(has_group_padding) + { + bool has_qpad = !seqstart_q_with_padding_host.empty(); + bool has_kpad = (seqlen_kpads[0] >= 0); + if(has_qpad) + { + print_vec("q_logical", seqlen_qs); + print_vec("q_padded", seqlen_qpads); + } + if(has_kpad) + { + print_vec("k_logical", seqlen_ks); + print_vec("k_padded", seqlen_kpads); + } + } + else if(has_batch_efflens) + { + // derive effective lengths from cumulative arrays if present + if(!cuq_cum.empty()) + { + std::vector eff_q(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_q[b_i] = static_cast(cuq_cum[b_i + 1] - cuq_cum[b_i]); + print_vec("q_eff", eff_q); + } + if(!cukv_cum.empty()) + { + std::vector eff_kv(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_kv[b_i] = static_cast(cukv_cum[b_i + 1] - cukv_cum[b_i]); + print_vec("kv_eff", eff_kv); + } + } + + std::cout << std::flush; + + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + + if constexpr(std::is_same_v>) + { + traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved + : rope_enum::half_rotated) + : rope_enum::none); + } + else // fmha_fwd_traits or fmha_splitkv_traits + { + traits.is_group_mode = (mode == mode_enum::group); + traits.has_logits_soft_cap = 0.f < logits_soft_cap; + traits.mask_type = mask.type; + traits.bias_type = bias.type; + traits.has_lse = lse; + traits.do_fp8_static_quant = squant; + + if constexpr(std::is_same_v>) + { + traits.has_dropout = (p_drop > 0.0f); + } + else if constexpr(std::is_same_v>) + { + traits.use_pagedkv = (0 < page_block_size); + } + } + }; + + const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size) + : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_knew : nhead_k * seqlen_knew; + }(); + const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o_acc = (hdim_v); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = + (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q) + : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); + const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) + : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + else + return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size) + : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); + }(); + const ck_tile::index_t nhead_stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? seqlen_knew * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_knew : seqlen_knew; + }(); + const ck_tile::index_t nhead_stride_bias = + (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = + (0 < page_block_size ? (nhead_k * page_block_size * hdim_q) + : (nhead_k * shape_seqlen_k * hdim_q)); + const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); + const ck_tile::index_t batch_stride_v = + (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) + : (nhead_k * hdim_v * shape_seqlen_k)); + const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + // setup split_stride_* arguments (only used in split-kv kernel) + const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; // unused in group mode + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + if constexpr(std::is_same_v>) + { + args.knew_ptr = knew_buf.GetDeviceBuffer(); + args.vnew_ptr = vnew_buf.GetDeviceBuffer(); + args.seqlen_knew = seqlen_knew; + + args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer(); + + args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr); + args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr); + args.rotary_dim = rotary_dim; + args.has_mask = (mask.type != mask_enum::no_mask); + + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.stride_knew = stride_knew; + args.stride_vnew = stride_vnew; + args.nhead_stride_knew = nhead_stride_knew; + args.nhead_stride_vnew = nhead_stride_vnew; + args.batch_stride_knew = batch_stride_knew; + args.batch_stride_vnew = batch_stride_vnew; + } + else // fmha_fwd_args or fmha_fwd_splitkv_args + { + args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(); + args.lse_ptr = lse_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); + + args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + args.scale_p = scale_p; + args.scale_o = scale_o; + + args.logits_soft_cap = logits_soft_cap; + + args.stride_bias = + (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); + args.stride_o = stride_o; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + + if constexpr(std::is_same_v>) + { + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + + args.p_drop = p_drop; + args.s_randval = s_randval; + if(drop_prefs) + { + args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); + } + + // Group-mode: optional physical padded starts for Q/K + if(mode == mode_enum::group) + { + args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() + ? nullptr + : seqstart_q_padded_buf.GetDeviceBuffer()); + args.seqstart_padded_k_ptr = + (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); + } + + // Batch-mode: optional cumulative effective seqlen overrides + if(mode == mode_enum::batch) + { + args.cu_seqlen_q_ptr = cuq_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_q_buf.GetDeviceBuffer()); + args.cu_seqlen_kv_ptr = cukv_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_kv_buf.GetDeviceBuffer()); + } + } + else if constexpr(std::is_same_v>) + { + args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); + args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); + + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.num_splits = num_splits; + + args.stride_o_acc = stride_o_acc; + args.nhead_stride_lse_acc = nhead_stride_lse_acc; + args.nhead_stride_o_acc = nhead_stride_o_acc; + args.batch_stride_lse_acc = batch_stride_lse_acc; + args.batch_stride_o_acc = batch_stride_o_acc; + args.split_stride_lse_acc = split_stride_lse_acc; + args.split_stride_o_acc = split_stride_o_acc; + } + else if constexpr(std::is_same_v>) + { + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + } + } + }; + + auto run_appendkv = [&](const ck_tile::stream_config& sc) { +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(need_append_kvcache) + { + fmha_fwd_appendkv_traits fwd_appendkv_traits; + init_traits(fwd_appendkv_traits); + + fmha_fwd_appendkv_args fwd_appendkv_args; + init_args(fwd_appendkv_args); + + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc); + } +#endif + return 0.0f; + }; + const float appendkv_ave_time = run_appendkv(stream_config); + if(appendkv_ave_time < 0.0f) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return fwd_result::no_instance; + } + + auto run_fwd = [&](const ck_tile::stream_config& sc) { +#if CK_TILE_FMHA_FWD_PAGEDKV_API + if(1 == num_splits && use_kvcache) + { + fmha_fwd_pagedkv_traits fmha_pagedkv_traits; + init_traits(fmha_pagedkv_traits); + + fmha_fwd_pagedkv_args fmha_pagedkv_args; + init_args(fmha_pagedkv_args); + + const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc); +#if CK_TILE_FMHA_FWD_SPLITKV_API + // If there is no instance for these args, fallback to fmha_fwd_splitkv + if(ave_time >= 0.0f) + return ave_time; +#else + return ave_time; +#endif + } +#endif // CK_TILE_FMHA_FWD_PAGEDKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(1 < num_splits || use_kvcache) + { + fmha_fwd_splitkv_traits fmha_splitkv_traits; + init_traits(fmha_splitkv_traits); + + fmha_fwd_splitkv_args fmha_splitkv_args; + init_args(fmha_splitkv_args); + + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc); + } +#endif // CK_TILE_FMHA_FWD_SPLITKV_API + fmha_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_fwd_args fmha_args; + init_args(fmha_args); + + return fmha_fwd(fmha_traits, fmha_args, sc); + }; + const float fwd_ave_time = run_fwd(stream_config); + if(fwd_ave_time < 0.0f) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return fwd_result::no_instance; + } + + const float ave_time = appendkv_ave_time + fwd_ave_time; + const float tflops = static_cast(flop) / 1.E9 / ave_time; + const float gb_per_sec = num_byte / 1.E6 / ave_time; + if(stream_config.time_kernel_) + { + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) + << gb_per_sec << " GB/s" << std::flush; + } + + bool pass = true; + if(do_validation == 0) + { + std::cout << std::flush << std::endl; + } + else if(do_validation == 2) + { + // NOTE: use gpu to do validation + ck_tile::naive_attention_fwd_traits naive_t; + naive_t.q_type = data_type; + naive_t.k_type = data_type; + naive_t.v_type = data_type; + naive_t.o_type = data_type; + naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; + naive_t.variation = 0; // TODO? + naive_t.quant_algo = 0; + + ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); + + ck_tile::naive_attention_fwd_args naive_a; + naive_a.q_ptr = q_buf.GetDeviceBuffer(); + naive_a.k_ptr = k_buf.GetDeviceBuffer(); + naive_a.v_ptr = v_buf.GetDeviceBuffer(); + naive_a.o_ptr = o_naive_buf.GetDeviceBuffer(); + naive_a.scale_s = scale_s; + naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer + naive_a.page_table_ptr = + nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn) + naive_a.hdim = hdim_q; + naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different + naive_a.batch_q = batch; + naive_a.batch_kv = batch; + naive_a.batch_ratio_kv = 1; // batch_q / batch_kv + naive_a.seqlen_q = seqlen_qs[0]; + naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field + naive_a.nhead_q = nhead; + naive_a.nhead_kv = nhead_k; + naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv + naive_a.page_size = 0; // if paged, the seqlen-kv for each block + + ck_tile::stream_config naive_s{}; + + naive_attention_fwd(naive_t, naive_a, naive_s); + + auto o_naive_ref = o_naive_buf.ToHost(); + o_buf.FromDevice(o_host.data()); // TODO: ugly + + auto [rtol_, atol_] = get_elimit(init_method); + pass = ck_tile::check_err( + o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_); + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + else + { +#if CK_TILE_FMHA_FWD_APPENDKV_API + // When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times + // when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. + if(0 < rotary_dim && stream_config.time_kernel_) + { + const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0}; + q_buf.ToDevice(q_host.data()); + run_appendkv(stream_config2); + run_fwd(stream_config2); + } +#endif + o_buf.FromDevice(o_host.data()); + lse_buf.FromDevice(lse_host.data()); + randval_buf.FromDevice(randval_host.data()); + + constexpr bool supports_squant = std::is_same_v || + std::is_same_v || + std::is_same_v; + + auto p_compute_element_func = [&]() { + if constexpr(supports_squant) + return ck_tile::scales{scale_p}; + else + return ck_tile::identity{}; + }(); + + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v && supports_squant) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else if constexpr(supports_squant) + return ck_tile::scales{scale_o}; + else + return ck_tile::identity{}; + }(); + + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + if(mode == mode_enum::batch) + { + if(!cuq_cum.empty()) + { + real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; + } + if(!cukv_cum.empty()) + { + real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; + } + } + + // adjust matrix index according to the mode + const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t cache_b_idx = + (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); + const ck_tile::index_t query_offset = + (mode == mode_enum::batch + ? 0 + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] + : seqstart_k_with_padding_host[wb])); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); + // clang-format on + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // optionally apply RoPE to the q_host_ref + if(0 < rotary_dim) + { + decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin( + rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); + + ck_tile::reference_batched_rotary_position_embedding( + q_host_ref, + rotary_cos_slice, + rotary_sin_slice, + is_rotary_interleaved, + q_host_ref_ro, + /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask); + + q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API + if(0 < page_block_size) + { + // clang-format off + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); }); + // clang-format on + } + else +#endif + { + // clang-format off + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); + // clang-format on + } + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Knew to the end of K + if(0 < seqlen_knew) + { + ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); + // clang-format off + if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); }); + else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); }); + // clang-format on + + // optionally apply RoPE to the knew_host_ref + auto* real_knew_host_ref = &knew_host_ref; + std::optional knew_host_ref_ro; + if(0 < rotary_dim) + { + knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin( + rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); + + ck_tile::reference_batched_rotary_position_embedding(knew_host_ref, + rotary_cos_slice, + rotary_sin_slice, + is_rotary_interleaved, + knew_host_ref_ro.value()); + + real_knew_host_ref = &knew_host_ref_ro.value(); + } + + (*real_knew_host_ref).ForEach([&](auto& self, auto i) { + k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); + }); + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API + if(0 < page_block_size) + { + if(is_v_rowmajor) + { + // clang-format off + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); }); + // clang-format on + } + else + { + // clang-format off + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); }); + // clang-format on + } + } + else +#endif + { + if(is_v_rowmajor) + { + // clang-format off + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); + // clang-format on + } + else + { + // clang-format off + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); + // clang-format on + } + } + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Vnew to the end of V + if(0 < seqlen_knew) + { + ck_tile::HostTensor vnew_host_ref({nhead, hdim_v, seqlen_knew}); + if(is_v_rowmajor) + { + // clang-format off + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); }); + // clang-format on + } + else + { + // clang-format off + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); + // clang-format on + } + + vnew_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); + }); + } +#endif + + // reference + ck_tile:: + reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s)); + + if(0.f < logits_soft_cap) + { + ck_tile::reference_unary_elementwise( + s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) { + return ck_tile::type_convert( + logits_soft_cap * + std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); + }); + } + + if(bias.type == bias_enum::elementwise_bias) + { + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); + else bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL + ? current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + SaccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + const ck_tile::HostTensor masked_s_host_ref = s_host_ref; + if(lse) + { + ck_tile:: + reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile:: + reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func); + } + + if(p_drop > 0) + { + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::reference_batched_dropout_randval( + randval_host_ref, wb, drop_seed, drop_offset); + ck_tile::reference_batched_dropout( + p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + + ck_tile::HostTensor randval_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_result.ForEach([&](auto& self, const auto& idx) { + self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); + }); + masked_s_host_ref.ForEach([&](const auto& self, const auto& idx) { + // Ignore all masked values in validation check + if(std::isinf(self(idx))) + { + randval_host_ref(idx) = 0; + randval_host_result(idx) = 0; + } + }); + bool cur_pass = ck_tile::check_err(randval_host_result, + randval_host_ref, + "DROPOUT RANDVAL Error: Incorrect results!"); + pass &= cur_pass; + if(!cur_pass) + { + break; + } + } + + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + + ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck_tile::check_err(o_host_result, + o_host_ref, + std::string("OUT Error: Incorrect results!"), + rtol, + atol); + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "OUT mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + + if(lse) + { + ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + const ck_tile::index_t query_offset_lse = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); + }); + + cur_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); + + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + if(json) + { + dump_fmha_fwd_json_results(*json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + squant, + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass ? fwd_result::success : fwd_result::failure; +} diff --git a/example/ck_tile/01_unified_attention/utils.hpp b/example/ck_tile/01_unified_attention/utils.hpp new file mode 100644 index 00000000000..7f44d871804 --- /dev/null +++ b/example/ck_tile/01_unified_attention/utils.hpp @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/container/span.hpp" + +enum class mode_enum +{ + batch = 0, + group +}; + +std::ostream& operator<<(std::ostream& stream, mode_enum mode) +{ + return stream << (mode == mode_enum::batch ? "batch" : "group"); +} + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +std::vector to_seqstarts(ck_tile::span seqlens) +{ + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + assert(seqstarts.size() == seqlens.size() + 1); + return seqstarts; +} + +template +std::vector generate_seqlens(mode_enum mode, + unsigned count, + int32_t seqlen_avg, + int32_t seqlen_min, // if not negative, clamp min + int32_t seqlen_max, // if not negative, clamp max + RandomEngine& random_engine) +{ + assert(0 < count); + + seqlen_min = (0 < seqlen_min ? seqlen_min : 1); + seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits::max()); + assert(seqlen_min <= seqlen_max); + + std::vector seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max)); + + if(mode == mode_enum::group && 1 < count) + { + using size_type = std::vector::size_type; + + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens is in range [seqlen_min, seqlen_max] + if(seqlens[to_decrease] == seqlen_min) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + if(seqlens[to_increase] >= seqlen_max) + { + continue; + } + + --seqlens[to_decrease]; + ++seqlens[to_increase]; + } + } + + return seqlens; +} + +// return random integer generated uniformly in range [low, high] +template +auto randint(Int low, + Int high, + RandomEngine& random_engine) -> std::enable_if_t, Int> +{ + std::uniform_int_distribution dist(low, high); + return dist(random_engine); +} + +// return random integers generated uniformly in range [low, high] +template +auto randints(ForwardIterator first, + ForwardIterator last, + Int low, + Int high, + RandomEngine& random_engine) -> std::enable_if_t> +{ + std::uniform_int_distribution dist(low, high); + + std::generate(first, last, [&] { return dist(random_engine); }); +} + +/* + * generate missing values in *_val randomly when the number of values is smaller than batch + * example (assume batch=3) + * q_val=1,2,3 k_val=4,5,6 -> OK + * q_val=1,2,3 -> OK, k same as q + * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q + * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element + * q_val=1,2,3,4 -> OK, but ignore exceed one + * + * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q + * q_val=1,2 k_val=4 -> not OK, k must have same splits with q + */ +template +std::tuple, + std::vector, + std::vector> +generate_missing_seqlens(mode_enum mode, + ck_tile::index_t batch, + const std::vector& q_val, + const std::vector& k_val, + const std::vector& k_pad_val, + ck_tile::index_t seqlen_k_min, + bool need_append_kvcache, + RandomEngine& random_engine) +{ + if(mode == mode_enum::batch) + { + ck_tile::index_t q = q_val[0]; + ck_tile::index_t k = k_val[0]; + + auto s_q = std::vector(batch, q); + auto s_k = [&] { + const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); + std::vector seqlen_ks(batch, seqlen_k_max); + + if(1 < batch && need_append_kvcache) + { + // to keep the original s_k value, we always use seqlen_k_max in first batch + randints(std::next(seqlen_ks.begin()), + seqlen_ks.end(), + seqlen_k_min, + seqlen_k_max, + random_engine); + return seqlen_ks; + } + + return seqlen_ks; + }(); + auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + + // s_k should be greater than or equal to seqlen_k_min if provided + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + + return std::make_tuple(s_q, s_k, s_kpad); + } + else + { + std::vector s_q; + std::vector s_k; + std::vector s_kpad; + ck_tile::index_t idx = 0; + for(; idx < std::min(static_cast(q_val.size()), batch); ++idx) + { + ck_tile::index_t q = q_val[idx]; + ck_tile::index_t k = + k_val[std::min(idx, static_cast(k_val.size()) - 1)]; + ck_tile::index_t kp = + k_pad_val.empty() + ? -1 + : k_pad_val[std::min(idx, static_cast(k_pad_val.size()) - 1)]; + + s_q.push_back(q); + s_k.push_back(k < 0 ? q : k); + s_kpad.push_back(kp); + + // s_k should be greater than or equal to seqlen_k_min + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + } + if(idx < batch) + { + auto rem_q = + generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine); + auto rem_k = generate_seqlens( + mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine); + + s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); + s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); + s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + } + return std::make_tuple(s_q, s_k, s_kpad); + } +} + +template +std::enable_if_t> iota_shuffle(RandomAccessIterator first, + RandomAccessIterator last, + Int value, + RandomEngine& random_engine) +{ + std::iota(first, last, value); + std::shuffle(first, last, random_engine); +} diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index fc067a40cef..729dc0f3a9a 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -15,38 +15,35 @@ namespace ck_tile { -template -struct FmhaFwdV3Kernel +template +struct UnifiedAttentionKernel { - using FmhaPipeline = ck_tile::remove_cvref_t; + using UnifiedAttentionPipeline = ck_tile::remove_cvref_t; using EpiloguePipeline = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; - static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static constexpr ck_tile::index_t kBlockSize = UnifiedAttentionPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); - using QDataType = ck_tile::remove_cvref_t; - using KDataType = ck_tile::remove_cvref_t; - using VDataType = ck_tile::remove_cvref_t; - using ODataType = ck_tile::remove_cvref_t; - using SaccDataType = ck_tile::remove_cvref_t; - using FmhaMask = ck_tile::remove_cvref_t; + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; - static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ; + static constexpr bool kPadHeadDim = UnifiedAttentionPipeline::kPadHeadDim; // TODO add yjese - static constexpr index_t HEAD_SIZE = FmhaPipeline::HEAD_SIZE; - static constexpr index_t HEAD_SIZE_PADDED = FmhaPipeline::HEAD_SIZE_PADDED; + static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE; + static constexpr index_t HEAD_SIZE_PADDED = UnifiedAttentionPipeline::HEAD_SIZE_PADDED; // BLOCK_Q = BLOCK_M // num_queries_per_kv // BLOCK_Q is the block size for q seqlen - static constexpr index_t BLOCK_Q = FmhaPipeline::BLOCK_Q; - // static constexpr index_t BLOCK_M = FmhaPipeline::BLOCK_M; + static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; + // static constexpr index_t BLOCK_M = UnifiedAttentionPipeline::BLOCK_M; // BLOCK size for K seqlen - static constexpr index_t BLOCK_SIZE = FmhaPipeline::BLOCK_SIZE; + static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE; // kargs use aggregate initializer, so no constructor will provided @@ -253,7 +250,7 @@ struct FmhaFwdV3Kernel ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks; // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, - // FmhaPipeline::kN1); + // UnifiedAttentionPipeline::kN1); const index_t i_tile_m = pid % total_num_q_blocks; // Query block index const index_t i_tile_n = pid / total_num_q_blocks; // Head index @@ -261,9 +258,11 @@ struct FmhaFwdV3Kernel return ck_tile::make_tuple(i_tile_m, i_tile_n); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -350,7 +349,7 @@ struct FmhaFwdV3Kernel q_ptr, make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), - number{}, + number{}, number<1>{}); const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED @@ -388,7 +387,7 @@ struct FmhaFwdV3Kernel k_ptr, make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_3), - number{}, + number{}, number<1>{}); const auto k_dram_pad = pad_tensor_view( @@ -421,7 +420,7 @@ struct FmhaFwdV3Kernel v_ptr, make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_3), - number{}, + number{}, number<1>{}); const auto v_dram_pad = pad_tensor_view( @@ -460,7 +459,7 @@ struct FmhaFwdV3Kernel }(); auto o_acc_tile = [&]() { - return FmhaPipeline{}(q_dram_window, + return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, block_tables_ptr, @@ -477,7 +476,7 @@ struct FmhaFwdV3Kernel o_ptr, make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1), - number{}, + number{}, number<1>{}); const auto o_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp new file mode 100644 index 00000000000..d0704626e97 --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length() +{ + if constexpr(Headdim == 48) + return 48; + else if constexpr(Headdim == 96) + return 128; + else if constexpr(Headdim == 160) + return 256; + else if constexpr(Headdim == 192) + return 192; + else if constexpr(is_power_of_two_integer(Headdim)) + return Headdim; + else + static_assert(Headdim == 0, + "only Headdim of 48, 96, 160, 192 and power-of-two is supported"); +}; + +template +struct TileUnifiedAttentionShape +{ + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + + static constexpr index_t NumGemm0Warps = + reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + static constexpr index_t NumGemm1Warps = + reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}); + static_assert(NumGemm1Warps % NumGemm0Warps == 0); + + static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); + + static constexpr index_t BLOCK_Q = BlockTile::at(number<0>{}); // tile size along q seqlen + // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head) + static constexpr index_t BLOCK_SIZE = BlockTile::at(number<1>{}); // BLOCK size for K seqlen + static constexpr index_t HEAD_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen + + // static constexpr index_t kQKHeaddim = + // BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at + // // once (or repeately load Q as a whole tile) + // static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); + + static constexpr index_t HEAD_SIZE_PADDED = ceil_to_qualified_tile_length(); + + // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen + static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; + using VLayout = std::conditional_t; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp new file mode 100644 index 00000000000..682e93fd56f --- /dev/null +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" + +namespace ck_tile { + + +template +struct TileUnifiedAttentionTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDim = kPadHeadDim; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; +} \ No newline at end of file diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 8d37d422916..892a5a3db03 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -4,8 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" -#include "ck_tile/ops/unified_attention/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" #define ENABLE_ASM_MARKER 1 @@ -270,21 +269,15 @@ struct UnifiedAttentionPipeline static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; - static constexpr ck_tile::index_t kM0 = UnifiedAttentionShape::kM0; - static constexpr ck_tile::index_t kN0 = UnifiedAttentionShape::kN0; - static constexpr ck_tile::index_t kK0 = UnifiedAttentionShape::kK0; - static constexpr ck_tile::index_t kN1 = UnifiedAttentionShape::kN1; - static constexpr ck_tile::index_t kK1 = UnifiedAttentionShape::kK1; - static constexpr ck_tile::index_t kQKHeaddim = UnifiedAttentionShape::kQKHeaddim; - static constexpr ck_tile::index_t kSubQKHeaddim = UnifiedAttentionShape::kSubQKHeaddim; + static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; + static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; + static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; + static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED; - static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(HEAD_SIZE_PADDED <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kPadHeadDim = Problem::kPadHeadDim; static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -308,12 +301,12 @@ struct UnifiedAttentionPipeline } }(); - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize(index_t num_queries_per_kv) { // create another LDS buffer for p - return ck_tile::max(kM0 * kN1 * sizeof(PDataType), + return ck_tile::max(BLOCK_Q * num_queries_per_kv * HEAD_SIZE_PADDED * sizeof(PDataType), Policy::template GetSmemSize() + - kM0 * kN0 * sizeof(PDataType)); + BLOCK_Q * num_queries_per_kv * BLOCK_SIZE * sizeof(PDataType)); } // for debug only @@ -391,6 +384,7 @@ struct UnifiedAttentionPipeline [[maybe_unused]] const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile [[maybe_unused]] const VElementFunction& v_element_func, + index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile @@ -411,39 +405,39 @@ struct UnifiedAttentionPipeline std::is_same_v>, "wrong!"); - static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + static_assert(BLOCK_Q == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); + static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize(num_queries_per_kv)); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto s_lds_window = - make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); auto p_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + Policy::template GetSmemSize()), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto p_lds_window = - make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto o_lds_window = - make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); auto m_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + Policy::template GetSmemSize()), - MakeSimpleLdsDesc1D()); + MakeSimpleLdsDesc1D()); [[maybe_unused]] auto m_lds_window = - make_tile_window(m_lds, make_tuple(number{}), {0}); + make_tile_window(m_lds, make_tuple(number{}), {0}); const index_t warp_group_id = get_warp_id() / 4; @@ -550,9 +544,9 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE); index_t kv_token_start = seqlen_k_start; // check early exit if no work to do @@ -596,11 +590,11 @@ struct UnifiedAttentionPipeline v_dram_window.init_raw(); // prefetch K tile - constexpr index_t k0_loops = kQKHeaddim / kK0; - constexpr index_t k1_loops = kN0 / kK1; + constexpr index_t k0_loops = 1; + constexpr index_t k1_loops = 1; static_assert(1 == k0_loops); static_assert(1 == k1_loops); - static_assert(kN0 == kK1); + // static_assert(BLOCK_SIZE == HEAD_SIZE_PADDED); constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; static_assert(NumWarpGroups == 2); @@ -832,21 +826,21 @@ struct UnifiedAttentionPipeline clear_tile(sp(sp_reg_idx).sp_compute); // initialize C gemm_0(sp(sp_reg_idx).sp_compute, get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), + sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, + sequence{}), get_slice_tile(kv_tile.k_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{})); + sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, + sequence{})); } else { gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, - sequence<0, (k1_loops - 1) * kK1>{}, - sequence{}), + sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, + sequence{}), get_slice_tile(kv_tile.v_tile, - sequence<0, (k1_loops - 1) * kK1>{}, - sequence{})); + sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, + sequence{})); } }; @@ -856,21 +850,21 @@ struct UnifiedAttentionPipeline clear_tile(sp(sp_reg_idx).sp_compute); // initialize C gemm_0(sp(sp_reg_idx).sp_compute, get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), + sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, + sequence{}), get_slice_tile(kv_tile.k_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{})); + sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, + sequence{})); } else { gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, - sequence<0, (k1_loops - 1) * kK1>{}, - sequence{}), + sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, + sequence{}), get_slice_tile(kv_tile.v_tile, - sequence<0, (k1_loops - 1) * kK1>{}, - sequence{})); + sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, + sequence{})); fmha_alu0(number<1>{} - sp_reg_idx); } }; @@ -915,7 +909,7 @@ struct UnifiedAttentionPipeline if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( - q_origin.at(number<0>{}), kv_token_start, number{}, number{}); + q_origin.at(number<0>{}), kv_token_start, number{}, number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, @@ -1026,7 +1020,7 @@ struct UnifiedAttentionPipeline cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); Scheduler::schedule(cl_p, number<3>{}); - kv_token_start += kN0; + kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1073,7 +1067,7 @@ struct UnifiedAttentionPipeline Scheduler::schedule(cl_p, number<2>{}); fmha_mask(xdl_SP_p01_reg_idx); - kv_token_start += kN0; + kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1151,7 +1145,7 @@ struct UnifiedAttentionPipeline fmha_alu0(number<0>{}); fmha_alu_D_upd(); - kv_token_start += kN0; + kv_token_start += BLOCK_SIZE; ++i_total_loops; if(num_total_loop <= i_total_loops) { diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index 8c8ccc3bd27..3a8975b1603 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -19,6 +19,7 @@ template struct UnifiedAttentionPipelineProblem { @@ -39,6 +40,7 @@ struct UnifiedAttentionPipelineProblem using ODataType = remove_cvref_t; using UnifiedAttentionShape = remove_cvref_t; using Traits = remove_cvref_t; + using FmhaMask = remove_cvref_t; static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps; static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps; @@ -46,9 +48,7 @@ struct UnifiedAttentionPipelineProblem // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kPadHeadDim = Traits::kPadHeadDim; static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; static constexpr auto BiasEnum = Traits::BiasEnum; From 63c17b723636999f722975a99625c1409f129120 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Thu, 16 Oct 2025 08:54:07 +0000 Subject: [PATCH 24/88] correct masking by transforming y_idx = y_idx / num_queries_per_kv --- .../unified_attention/block/block_masking.hpp | 416 ++---------------- .../kernel/unified_attention_kernel.hpp | 2 + 2 files changed, 38 insertions(+), 380 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp index 2c45945fac0..958f036edf2 100644 --- a/include/ck_tile/ops/unified_attention/block/block_masking.hpp +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -85,22 +85,25 @@ struct GenericAttentionMask static constexpr const char* name = impl::MaskName::name; - CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_) - : GenericAttentionMask(0, 0, y_total_, x_total_) + // New constructor accepting repeat_idx with default value 1 + CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx = 1) + : GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx) { } CK_TILE_HOST_DEVICE - GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) - : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx = 1) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx) { } + template - CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord) + CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, index_t repeat_idx = 1) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), y_total(mask_coord.at(number<2>{})), - x_total(mask_coord.at(number<3>{})) + x_total(mask_coord.at(number<3>{})), + repeat_idx(repeat_idx) { } @@ -111,17 +114,20 @@ struct GenericAttentionMask CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { + // Transform the y index according to repeat_idx + index_t y_eff = i_y / repeat_idx; + if constexpr(!IsMasking) { return ck_tile::make_tuple(0, x_total); } else { - // get the tile start/end range assum we loop over along X tile by tile + // get the tile start/end range assuming we loop over along X tile by tile index_t x_start = [&]() { if constexpr(IsLocal) { - index_t tmp = max(-y + i_y + 1, 0); + index_t tmp = max(-y + y_eff + 1, 0); return (tmp / XTile) * XTile; // round to tile aligned } else @@ -133,7 +139,7 @@ struct GenericAttentionMask // TODO: end could be negative, we ignore clamp here, and let caller to check // ... in which case end-start is negative index_t x_end = [&]() { - index_t tmp = min(i_y + YTile - 1 + x, x_total); + index_t tmp = min(y_eff + YTile - 1 + x, x_total); return ((tmp + XTile - 1) / XTile) * XTile; }(); @@ -143,7 +149,7 @@ struct GenericAttentionMask // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) - // TODO: y_end still could be negative, so end-start could be negative(need check) + // Note: this function does not take a dynamic y index so no transform is needed template CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongY(index_t i_x, number, number) const @@ -154,7 +160,7 @@ struct GenericAttentionMask } else { - // get the tile start/end range assum we loop over along Y tile by tile + // get the tile start/end range assuming we loop over along Y tile by tile index_t y_start = [&]() { index_t tmp = max(-x + i_x + 1, 0); return (tmp / YTile) * YTile; // round to tile aligned @@ -174,6 +180,9 @@ struct GenericAttentionMask // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { + // Transform the y index according to repeat_idx + index_t y_eff = i_y / repeat_idx; + if constexpr(!IsMasking) { return i_x >= x_total; @@ -181,8 +190,8 @@ struct GenericAttentionMask else { // no need to do min/max here, since i_x will never be < 0 or >= x_total - index_t x_start = -y + i_y + 1; - index_t x_end = min(i_y + x, x_total); + index_t x_start = -y + y_eff + 1; + index_t x_end = min(y_eff + x, x_total); if constexpr(IsLocal) { @@ -190,19 +199,22 @@ struct GenericAttentionMask } else { - return i_x >= x_end || i_y >= y_total; + return i_x >= x_end || y_eff >= y_total; } } } // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() + // Attention! assume the index passed in this function is within range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number, number) const { + // Transform the y index according to repeat_idx + index_t y_eff = i_tile_top / repeat_idx; + if constexpr(!IsMasking) { // TODO: no need to check begin @@ -212,12 +224,12 @@ struct GenericAttentionMask { if constexpr(IsLocal) { - // check top-right corner > x or left-borrom corner < x + // check top-right corner > x or left-bottom corner < x index_t i_tile_right = i_tile_left + TileWidth; - index_t i_tile_bottom = i_tile_top + TileHeight; - index_t x_end = min(i_tile_top + x, x_total); + index_t i_tile_bottom = y_eff + TileHeight; + index_t x_end = min(y_eff + x, x_total); - bool top_right_edge = i_tile_right > (i_tile_top + x); + bool top_right_edge = i_tile_right > (y_eff + x); bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now @@ -228,7 +240,7 @@ struct GenericAttentionMask { // only need to check top-right corner > x index_t i_tile_right = i_tile_left + TileWidth; - index_t x_end = min(i_tile_top + x, x_total); + index_t x_end = min(y_eff + x, x_total); bool top_right_edge = i_tile_right > x_end; return top_right_edge; @@ -236,369 +248,12 @@ struct GenericAttentionMask } } - private: +private: index_t y, x; index_t y_total, x_total; + index_t repeat_idx; }; -// clang-format off -namespace impl { - template struct SimplifiedMaskName; - template<> struct SimplifiedMaskName { static constexpr const char * name = "nomask"; }; - template<> struct SimplifiedMaskName { static constexpr const char * name = "mask"; }; -} -// clang-format on - -// this version only have 2 variation: masking and non-masking -// This is more friendly to codegen (e.g. need generate less kernel) -// ... with the trade-off that may have more instruction in causal mode -template -struct SimplifiedGenericAttentionMask -{ - static constexpr bool IsMasking = IsMasking_; // false will disable masking - - static constexpr const char* name = impl::SimplifiedMaskName::name; - - CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) - : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) - { - } - - CK_TILE_HOST_DEVICE - SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) - : y(y_), x(x_), y_total(y_total_), x_total(x_total_) - { - } - template - CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) - : y(mask_coord.at(number<0>{})), - x(mask_coord.at(number<1>{})), - y_total(mask_coord.at(number<2>{})), - x_total(mask_coord.at(number<3>{})) - { - } - - // to get the loop length along X axis, return index:[start, end), end-start=length - // use this if need loop over X axis tile by tile (like k-seqlen loopover) - // TODO: x_end still could be negative, so end-start could be negative(need check) - template - CK_TILE_HOST_DEVICE constexpr auto - GetTileRangeAlongX(index_t i_y, number, number) const - { - if constexpr(!IsMasking) - { - return ck_tile::make_tuple(0, x_total); - } - else - { - // get the tile start/end range assum we loop over along X tile by tile - index_t x_start = [&]() { - index_t tmp = max(-y + i_y + 1, 0); - return (tmp / XTile) * XTile; // round to tile aligned - }(); - - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t x_end = [&]() { - index_t tmp = min(i_y + YTile - 1 + x, x_total); - return ((tmp + XTile - 1) / XTile) * XTile; - }(); - - return ck_tile::make_tuple(x_start, x_end); - } - } - - template - CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, - number height, - number width, - index_t num_splits, - index_t i_split) const - { - auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); - - const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); - const index_t split_start = x_per_split * i_split; - const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); - - return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), - ck_tile::min(origin_end, split_end)); - } - - // to get the loop length along Y axis, return index:[start, end), end-start=length - // use this if need loop over Y axis tile by tile (like q-seqlen loopover) - // TODO: y_end still could be negative, so end-start could be negative(need check) - template - CK_TILE_HOST_DEVICE constexpr auto - GetTileRangeAlongY(index_t i_x, number, number) const - { - if constexpr(!IsMasking) - { - return ck_tile::make_tuple(0, y_total); - } - else - { - // get the tile start/end range assum we loop over along Y tile by tile - index_t y_start = [&]() { - index_t tmp = max(-x + i_x + 1, 0); - return (tmp / YTile) * YTile; // round to tile aligned - }(); - - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t y_end = [&]() { - index_t tmp = min(i_x + XTile - 1 + y, y_total); - return ((tmp + YTile - 1) / YTile) * YTile; - }(); - - return ck_tile::make_tuple(y_start, y_end); - } - } - - // per-pixel check if out-of-bound, if true, need mask a value(like -INF) - CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const - { - if constexpr(!IsMasking) - { - // the only case that need do following compare is under kPadSeqLenK - // ... for non-masking kernel. - return i_x >= x_total; - } - else - { - index_t x_start = -y + i_y + 1; // this could be negative, but it's fine - index_t x_end = min(i_y + x, x_total); // need min in case x is padded - - return i_x < x_start || i_x >= x_end || i_y >= y_total; - } - } - - // if current tile is at the edge, means need per-pixel mask check. - // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() - // can be used as a fast-path to decide if do per-pixel check or not - template - CK_TILE_HOST_DEVICE constexpr auto - IsEdgeTile(index_t i_y, index_t i_x, number, number) const - { - if constexpr(!IsMasking) - { - // the only case that need do following compare is under kPadSeqLenK - // ... for non-masking kernel. - // return (i_x < x_total) && ((i_x + TileWidth) > x_total); - - // TODO: no need to check begin - return (i_x + TileWidth) > x_total; - } - else - { - // check top-right corner > x or left-borrom corner < x - index_t i_x_end = i_x + TileWidth; - index_t i_y_end = i_y + TileHeight; - // index_t x_end = min(i_y + x, x_total); - - bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad - bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad - // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now - - return top_right_edge || bottom_left_edge; - } - } - - private: - index_t y, x; - index_t y_total, x_total; -}; - -// clang-format off -namespace impl { - template struct SimplifiedRatioMaskName; - template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "nomask"; }; - template<> struct SimplifiedRatioMaskName { static constexpr const char * name = "mask"; }; -} -// clang-format on - -// this version is used for cases that the step length of y-direction changes greater than one. It -// means that the mask is not a regular triangular matrix. - -// clang-format off -/* y_ratio is used to describe the step length of y-direction changes - in certain performance optimization scenarios like merging seqlen - and qk_head_ratio, for example: - - x=1/y=6/y_ratio=2(top-left) - 1 * * * * * * * - 1 * * * * * * * - 1 1 * * * * * * - 1 1 * * * * * * - 1 1 1 * * * * * - 1 1 1 * * * * * - -*/ -// clang-format on -template -struct SimplifiedRatioAttentionMask -{ - static constexpr bool IsMasking = IsMasking_; // false will disable masking - - static constexpr const char* name = impl::SimplifiedRatioMaskName::name; - - CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_) - : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{}) - { - } - - CK_TILE_HOST_DEVICE - SimplifiedRatioAttentionMask( - index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_) - : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast(y_ratio_mdiv_.get()), - /*x_=*/x_, - /*y_total_=*/y_total_, - /*x_total_=*/x_total_, - /*y_real_=*/y_real_, - /*y_ratio_=*/static_cast(y_ratio_mdiv_.get()), - /*y_ratio_mdiv_=*/y_ratio_mdiv_) - - { - } - CK_TILE_HOST_DEVICE - SimplifiedRatioAttentionMask(index_t y_, - index_t x_, - index_t y_total_, - index_t x_total_, - index_t y_real_, - index_t y_ratio_, - mdiv y_ratio_mdiv_) - : y(y_), - x(x_), - y_total(y_total_), - x_total(x_total_), - y_real(y_real_), - y_ratio(y_ratio_), - y_ratio_mdiv(y_ratio_mdiv_) - { - } - - // to get the loop length along X axis, return index:[start, end), end-start=length - // use this if need loop over X axis tile by tile (like k-seqlen loopover) - // TODO: x_end still could be negative, so end-start could be negative(need check) - template - CK_TILE_HOST_DEVICE constexpr auto - GetTileRangeAlongX(index_t i_y, number, number) const - { - if constexpr(!IsMasking) - { - return ck_tile::make_tuple(0, x_total); - } - else - { - // get the tile start/end range assum we loop over along X tile by tile - index_t x_start = [&]() { - index_t tmp = -y_real + - static_cast(y_ratio_mdiv.div(static_cast(i_y))) + - 1; - - return (tmp / XTile) * XTile; // round to tile aligned - }(); - - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t x_end = [&]() { - uint32_t y_offset = i_y + YTile - 1; - index_t tmp = min(static_cast(y_ratio_mdiv.div(y_offset)) + x, x_total); - return ((tmp + XTile - 1) / XTile) * XTile; - }(); - - return ck_tile::make_tuple(x_start, x_end); - } - } - - // to get the loop length along Y axis, return index:[start, end), end-start=length - // use this if need loop over Y axis tile by tile (like q-seqlen loopover) - // TODO: y_end still could be negative, so end-start could be negative(need check) - template - CK_TILE_HOST_DEVICE constexpr auto - GetTileRangeAlongY(index_t i_x, number, number) const - { - if constexpr(!IsMasking) - { - return ck_tile::make_tuple(0, y_total); - } - else - { - // get the tile start/end range assum we loop over along Y tile by tile - index_t y_start = [&]() { - index_t tmp = max((-x + i_x + 1) * y_ratio, 0); - return (tmp / YTile) * YTile; // round to tile aligned - }(); - - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t y_end = [&]() { - index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total); - return ((tmp + YTile - 1) / YTile) * YTile; - }(); - - return ck_tile::make_tuple(y_start, y_end); - } - } - - // per-pixel check if out-of-bound, if true, need mask a value(like -INF) - CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const - { - if constexpr(!IsMasking) - { - return i_x >= x_total; - } - else - { - index_t x_tmp = static_cast(y_ratio_mdiv.div(static_cast(i_y))); - index_t x_start = -y_real + x_tmp + 1; - index_t x_end = min(x_tmp + x, - x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end || i_y >= y_total; - } - } - - // if current tile is at the edge, means need per-pixel mask check. - // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() - // can be used as a fast-path to decide if do per-pixel check or not - template - CK_TILE_HOST_DEVICE constexpr auto - IsEdgeTile(index_t i_y, index_t i_x, number, number) const - { - if constexpr(!IsMasking) - { - // the only case that need do following compare is under kPadSeqLenK - // ... for non-masking kernel. - // return (i_x < x_total) && ((i_x + TileWidth) > x_total); - - return (i_x + TileWidth) > x_total; - } - else - { - // check top-right corner > x or left-borrom corner < x - index_t i_x_end = i_x + TileWidth; - index_t i_y_end = i_y + TileHeight; - // index_t x_end = min(i_y + x, x_total); - uint32_t y_tmp = static_cast(i_y); - bool top_right_edge = i_x_end > min(static_cast(y_ratio_mdiv.div(y_tmp)) + x, - x_total); // consider right pad - bool bottom_left_edge = - i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad - return top_right_edge || bottom_left_edge; - } - } - - private: - index_t y, x; - index_t y_total, x_total; - // y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y - index_t y_real; - index_t y_ratio; - mdiv y_ratio_mdiv; -}; // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate @@ -633,10 +288,11 @@ make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, + index_t repeat_idx = 1, bool is_top_left = true) { auto r = make_generic_attention_mask_coordinates_from_lr_window( left_size, right_size, y_total, x_total, is_top_left); - return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total, repeat_idx}; } } // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index fc067a40cef..25351656ac3 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -29,7 +29,9 @@ struct FmhaFwdV3Kernel using VDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; From 62932576c4b450a1fc1656c50aefa16957d22665 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:02:08 +0000 Subject: [PATCH 25/88] use correct mask in kernel --- .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 70dae4d1f45..54e60284eb8 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -460,10 +460,11 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.BLOCK_M, + kargs.BLOCK_Q, kargs.BLOCK_SIZE, cur_batch_query_len, seq_len, + num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); else return FmhaMask{cur_batch_query_len, seq_len}; From aa4908ac14d8c32c5076689b2500547e2b08b1d1 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:18:38 +0000 Subject: [PATCH 26/88] fix mask --- .../kernel/unified_attention_kernel.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 54e60284eb8..aebd04cb72a 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -460,11 +460,11 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - kargs.BLOCK_Q, - kargs.BLOCK_SIZE, - cur_batch_query_len, - seq_len, - num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times + cur_batch_query_len, // extend length + seq_len - cur_batch_query_len, // context length + cur_batch_query_len, // extend length + seq_len, // key length (context + extend) + num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times along x dim of the tile kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); else return FmhaMask{cur_batch_query_len, seq_len}; From 072de3842f27f5250aba4c17d9906a03de9e6d30 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:23:39 +0000 Subject: [PATCH 27/88] comment --- .../unified_attention/kernel/unified_attention_kernel.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index aebd04cb72a..57061dbe187 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -460,10 +460,10 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - cur_batch_query_len, // extend length - seq_len - cur_batch_query_len, // context length - cur_batch_query_len, // extend length - seq_len, // key length (context + extend) + cur_batch_query_len, // x (i.e. extend) + seq_len - cur_batch_query_len, // y (i.e. context) + cur_batch_query_len, // x_total + seq_len, // y_total (x + y) num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times along x dim of the tile kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); else From 9940bd07f6da1f90693bcd3cb27878228de22d32 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Thu, 16 Oct 2025 11:23:46 +0000 Subject: [PATCH 28/88] fix order in mask caller --- .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 57061dbe187..e19f8c4fce3 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -460,10 +460,10 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - cur_batch_query_len, // x (i.e. extend) seq_len - cur_batch_query_len, // y (i.e. context) - cur_batch_query_len, // x_total + cur_batch_query_len, // x (i.e. extend) seq_len, // y_total (x + y) + cur_batch_query_len, // x_total num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times along x dim of the tile kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); else From af9167abad794d11d39e9925ca02cac141ce056f Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 17 Oct 2025 09:05:10 +0000 Subject: [PATCH 29/88] example --- ...d_v3.cpp => example_unified_attention.cpp} | 251 ++++++++---------- .../unified_attention.hpp | 3 +- .../unified_attention_impl.hpp | 1 - .../kernel/unified_attention_kernel.hpp | 11 - 4 files changed, 113 insertions(+), 153 deletions(-) rename example/ck_tile/01_unified_attention/{example_fmha_fwd_v3.cpp => example_unified_attention.cpp} (75%) diff --git a/example/ck_tile/01_unified_attention/example_fmha_fwd_v3.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp similarity index 75% rename from example/ck_tile/01_unified_attention/example_fmha_fwd_v3.cpp rename to example/ck_tile/01_unified_attention/example_unified_attention.cpp index 7ddb65a2dbc..78c711804d3 100644 --- a/example/ck_tile/01_unified_attention/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -30,16 +30,23 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair std::pair(hdim)); + query_lens = args.get_int_vec("query_lens"); + kv_lens = args.get_int_vec("kv_lens"); + // softmax_scale = args.get_float("scale_s"); + // if(softmax_scale == .0f) + // softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); - const auto is_causal = args.get_bool("causal"); - if(is_causal) - { - mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); - } - else - { - mask = mask_info::decode("0", seqlen_q, seqlen_k); - } - input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; - q_eff_lens = args.get_int_vec("q_eff_lens"); - kv_eff_lens = args.get_int_vec("kv_eff_lens"); + // TODO + // mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); + + // q_eff_lens = args.get_int_vec("q_eff_lens"); + // kv_eff_lens = args.get_int_vec("kv_eff_lens"); } std::vector get_query_shape() const { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } + return {batch * seqlen_q, nhead_q, hdim}; } std::vector get_key_shape() const { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } + return {num_blks, BLOCK_SIZE, nhead_kv, hdim}; } std::vector get_value_shape() const { - if(input_layout == TensorLayout::bhsd) - { - return {batch, nhead_kv, seqlen_k, hdim}; - } - else - { - return {batch, seqlen_k, nhead_kv, hdim}; - } + return {num_blks, BLOCK_SIZE, nhead_kv, hdim}; } std::vector get_output_shape() const { - if(output_layout == TensorLayout::bhsd) - { - return {batch, nhead_q, seqlen_q, hdim}; - } - else - { - return {batch, seqlen_q, nhead_q, hdim}; - } + return {batch * seqlen_q, nhead_q, hdim}; + } ck_tile::fmha_fwd_v3_args::data_type_enum data_type; ck_tile::index_t batch; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; + ck_tile::index_t num_blks; + ck_tile::index_t BLOCK_SIZE; + ck_tile::index_t max_seqlen_q; // sequal seq len, in thd format + ck_tile::index_t max_context_len; ck_tile::index_t nhead_q; ck_tile::index_t nhead_kv; ck_tile::index_t hdim; - float softmax_scale; + float scale_s; + float scale; + float scale_k; + float scale_v; mask_info mask; - TensorLayout input_layout; TensorLayout output_layout; - std::vector q_eff_lens; - std::vector kv_eff_lens; + std::vector query_lens; + std::vector kv_lens; }; struct RunConfig @@ -226,6 +199,7 @@ auto generate_qkv(const Problem& problem, return std::make_tuple(q, k, v); } + namespace host { template (problem.mask.type); - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) + args.num_head_q = problem.nhead_q; + args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv; + args.mask_type = 2; + args.hdim = problem.hdim; + + args.BLOCK_SIZE = problem.BLOCK_SIZE; + args.num_blks = problem.num_blks; + + // args.query_lens = problem.query_lens + // args.kv_lens = problem.kv_lens + args.q_ptr = q_buf.GetDeviceBuffer(); - args.stride_q = - problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_q = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; - args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) + args.query_stride_0 = problem.hdim * problem.nhead_q; + args.query_stride_0 = problem.hdim; + args.k_ptr = k_buf.GetDeviceBuffer(); - args.stride_k = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_k = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_k, nhead_kv, hdim) - // bhsd: (batch, nhead_kv, seqlen_k, hdim) + + args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * problem.BLOCK_SIZE; + args.stride_k_cache_1 = problem.hdim * problem.nhead_kv; + args.stride_k_cache_2 = problem.hdim; + args.stride_k_cache_3 = 1; + args.v_ptr = v_buf.GetDeviceBuffer(); - args.stride_v = - problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; - args.nhead_stride_v = - problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; - args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; - - // bshd: (batch, seqlen_q, nhead_q, hdim) - // bhsd: (batch, nhead_q, seqlen_q, hdim) + args.stride_v_cache_0 = args.stride_k_cache_0; + args.stride_v_cache_1 = args.stride_k_cache_1; + args.stride_v_cache_2 = args.stride_k_cache_2; + args.stride_v_cache_3 = args.stride_k_cache_3; + args.o_ptr = o_buf.GetDeviceBuffer(); - args.stride_o = - problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; - args.nhead_stride_o = problem.output_layout == TensorLayout::bshd - ? problem.hdim - : problem.seqlen_q * problem.hdim; - args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + args.output_stride_0 = query_stride_0; + args.output_stride_1 = query_stride_1; // Optional cumulative seqlen overrides (exclude PAD) - const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; - const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; - auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { std::vector eff; if(!opt_vec.empty() && opt_vec[0] != -1) @@ -416,11 +372,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) return eff; }; - const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); - const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); + const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024); + const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024); // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cuq_cum, cukv_cum; + std::vector cu_query_lens ; + auto calculate_cumulative = [&](const std::vector& per_batch_vec, std::vector& cum_vec) { cum_vec.resize(per_batch_vec.size() + 1); @@ -428,26 +385,42 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) for(std::size_t i = 0; i < per_batch_vec.size(); ++i) cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; }; +mask_type + calculate_cumulative(eff_query_lens, cu_query_lens); - if(has_varlen_q) - { - calculate_cumulative(eff_q_vec, cuq_cum); - } - if(has_varlen_k) - { - calculate_cumulative(eff_kv_vec, cukv_cum); + ck_tile::DeviceMem seq_lens_buf(kv_lens.size()); + ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size()); + + seq_lens_buf.ToDevice(kv_lens.data()); + query_start_len_buf.ToDevice(cu_query_lens.data()); + + args.seq_lens_ptr =reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); + args.query_start_len_ptr =reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); + + + auto max_kv_len = std::max_element(problem.kv_lens.begin(), problem.kv_lens.end()); + + index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE + + // Create block_tables + ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * sizeof(ck_tile::index_t)); + + // Allocate host memory for block_tables + std::vector block_tables_host(problem.batch * max_num_blocks_per_seq); + + // Fill block_tables with random integers between 0 and num_blocks-1 + std::mt19937 rng(run_config.seed ? *run_config.seed : std::random_device{}()); + std::uniform_int_distribution dist(0, problem.num_blks - 1); + for (size_t i = 0; i < block_tables_host.size(); ++i) { + block_tables_host[i] = dist(rng); } - ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); - ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); - cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); - cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); - args.cu_seqlen_q_ptr = - !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) - : nullptr; - args.cu_seqlen_kv_ptr = - !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) - : nullptr; + // Copy to device + block_tables_buf.ToDevice(block_tables_host.data()); + + // Set pointer in args + args.block_tables_ptr = reinterpret_cast(block_tables_buf.GetDeviceBuffer()); + ck_tile::stream_config stream_config{nullptr, true, @@ -455,7 +428,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) run_config.kernel_warmup, run_config.kernel_repeat}; - auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); + auto [result, time] = ck_tile::unified_attention(args, stream_config); if(!result) { std::cerr << "faild to run fmha_fwd_v3()" << std::endl; diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index 63a348a69c9..25f787246ba 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -28,6 +28,7 @@ struct unified_attention_args index_t num_head_q; index_t num_queries_per_kv; + index_t hdim; // TODO window float scale_s; float scale; @@ -35,8 +36,6 @@ struct unified_attention_args float scale_v; float scale_out; - index_t total_num_q_blocks; - const void* q_ptr; index_t query_stride_0; index_t query_stride_1; diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 952ebfe0fa4..5ac20d354eb 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -122,7 +122,6 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.scale_k, args.scale_v, args.scale_out, - args.total_num_q_blocks, args.query_stride_0, args.query_stride_1, args.stride_k_cache_0, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 57061dbe187..6d4d2a25007 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -24,23 +24,12 @@ struct UnifiedAttentionKernel static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); -<<<<<<< HEAD - using QDataType = ck_tile::remove_cvref_t; - using KDataType = ck_tile::remove_cvref_t; - using VDataType = ck_tile::remove_cvref_t; - using ODataType = ck_tile::remove_cvref_t; - using SaccDataType = ck_tile::remove_cvref_t; - - using FmhaMask = ck_tile::remove_cvref_t; - static constexpr bool kHasMask = FmhaMask::IsMasking; -======= using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; ->>>>>>> 853fa21566f6a1fe4237289c61db772b1bbfeb3f static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ; static constexpr bool kPadHeadDim = UnifiedAttentionPipeline::kPadHeadDim; From f4e8f791fd5b9ee469179c48ac799f2f04400ceb Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 17 Oct 2025 11:03:39 +0000 Subject: [PATCH 30/88] fixing args --- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 32 ++++++++++++------- .../example_unified_attention.cpp | 7 ++-- .../unified_attention.hpp | 1 + .../unified_attention_impl.hpp | 27 +++++++++------- 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index b067a9acae7..457139fbfff 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -20,22 +20,31 @@ #include "fmha_fwd_v3.hpp" #include "mask.hpp" -#define INST_UNIFIED_ATTENTION_V3_DISPATCH(kernel_traits) \ +#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ template <> \ - std::pair unified_attention_kernel_dispatch( \ - const unified_attention_args& args, const stream_config& config) \ + std::pair fmha_fwd_v3_kernel_dispatch( \ + const fmha_fwd_v3_args& args, const stream_config& config) \ { \ return std::make_pair(true, \ - unified_attention_kernel_launch(args, config)); \ + fmha_fwd_v3_kernel_launch(args, config)); \ } namespace ck_tile { -template -struct unified_attention_problem_traits; +template +struct fmha_fwd_v3_problem_traits; template <> -struct unified_attention_problem_traits +struct fmha_fwd_v3_problem_traits +{ + using qkvp_dtype = ck_tile::half_t; + using acc_dtype = float; + using o_dtype = ck_tile::half_t; + using lse_dtype = float; +}; + +template <> +struct fmha_fwd_v3_problem_traits { using qkvp_dtype = ck_tile::bf16_t; using acc_dtype = float; @@ -43,14 +52,13 @@ struct unified_attention_problem_traits -struct unified_attention_kernel_traits +template +struct fmha_fwd_v3_kernel_traits { static constexpr auto date_type = DataType; static constexpr bool is_variable_seqlen = IsVariableSeqlen; - static constexpr bool is_masking = IsMasking; - - // M0 N0 K0 N1 K1 + static constexpr bool is_masking = IsMasking + // M0 N0 K0 N1 K1 using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; using fmha_warp_gemm_shape = sequence<32, 32, 16>; using fmha_block_warps = sequence<8, 1, 1>; diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 78c711804d3..558dba164a3 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -320,8 +320,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.data_type = problem.data_type; args.num_seqs = problem.batch; - args.seqlen_q = problem.seqlen_q; - args.seqlen_k = problem.seqlen_k; + // args.seqlen_q = problem.seqlen_q; + // args.seqlen_k = problem.seqlen_k; args.num_head_q = problem.nhead_q; args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv; args.mask_type = 2; @@ -332,7 +332,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // args.query_lens = problem.query_lens // args.kv_lens = problem.kv_lens - + args.num_tokens = problem.batch * problem.seqlen_q; args.q_ptr = q_buf.GetDeviceBuffer(); args.query_stride_0 = problem.hdim * problem.nhead_q; args.query_stride_0 = problem.hdim; @@ -385,7 +385,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) for(std::size_t i = 0; i < per_batch_vec.size(); ++i) cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; }; -mask_type calculate_cumulative(eff_query_lens, cu_query_lens); ck_tile::DeviceMem seq_lens_buf(kv_lens.size()); diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index 25f787246ba..cf616ef9a51 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -24,6 +24,7 @@ struct unified_attention_args index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and // window_size_right == 0). + index_t num_tokens; // total number of tokens in query index_t num_blks; index_t num_head_q; index_t num_queries_per_kv; diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 5ac20d354eb..4fa0bdab0db 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,16 +58,16 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; - // BLOCK_Q BLOCK_SIZE HEAD_SIZE N1 K1 + // BLOCK_Q BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence<128, 128, 128>; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; using unified_attention_block_warps = sequence<8, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape; @@ -77,9 +77,9 @@ struct unified_attention_kernel_traits -1 // kBlockPerCu >; - using funified_attention_mask = GenericAttentionMask; + using unified_attention_mask = GenericAttentionMask; - using funified_attention_pipeline_problem = + using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem::qkvp_dtype, typename unified_attention_problem_traits::qkvp_dtype, typename unified_attention_problem_traits::qkvp_dtype, @@ -89,11 +89,11 @@ struct unified_attention_kernel_traits typename unified_attention_problem_traits::qkvp_dtype, typename unified_attention_problem_traits::acc_dtype, typename unified_attention_problem_traits::o_dtype, - funified_attention_shape, - funified_attention_mask, - funified_attention_traits>; + unified_attention_shape, + unified_attention_mask, + unified_attention_traits>; - using funified_attention_pipeline = BlockFunified_attentionFwdV3Pipeline; + using unified_attention_pipeline = Blockunified_attentionFwdV3Pipeline; using epilogue = Default2DEpilogue< Default2DEpilogueProblem::acc_dtype, @@ -103,7 +103,7 @@ struct unified_attention_kernel_traits true // UseRawStore >>; - using kernel = UnifiedAttentionKernel; + using kernel = UnifiedAttentionKernel; }; template @@ -140,7 +140,10 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.num_seqs ); - dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, args.total_num_q_blocks); + index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs + + + dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); constexpr dim3 blocks = Kernel::BlockSize(); constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; From 3f963d4074f52dab2130dfc50cf952b88040276c Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 20 Oct 2025 06:15:13 +0000 Subject: [PATCH 31/88] modified cmake files at unified attention example. Now cmake works, but getting compile errors (expected atm) --- .../01_unified_attention/CMakeLists.txt | 397 ++++++++---------- example/ck_tile/CMakeLists.txt | 2 +- 2 files changed, 187 insertions(+), 212 deletions(-) diff --git a/example/ck_tile/01_unified_attention/CMakeLists.txt b/example/ck_tile/01_unified_attention/CMakeLists.txt index b8ca26193d6..2150ea09b72 100644 --- a/example/ck_tile/01_unified_attention/CMakeLists.txt +++ b/example/ck_tile/01_unified_attention/CMakeLists.txt @@ -1,213 +1,193 @@ -set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) -# Currently only gfx9 archs are supported by FMHA -list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") -if(NOT INST_TARGETS) - message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") - return() -endif() - -# validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") -set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING - "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") -if(BUILD_TESTING) - # Build instances of all APIs for tests - set(FMHA_FWD_ENABLE_APIS "all") -endif() -if(FMHA_FWD_ENABLE_APIS STREQUAL "all") - set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) -endif() - -foreach(api ${FMHA_FWD_ENABLE_APIS}) - if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) - message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.") - endif() -endforeach() - -# "fwd" is a must-have api for the fmha_fwd example, add it if not specified -if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) - list(PREPEND FMHA_FWD_ENABLE_APIS "fwd") -endif() - -file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS - ${CMAKE_CURRENT_LIST_DIR}/generate.py - ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py -) -# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change -set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") - -string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") -set(FMHA_FWD_CODE_GEN_COMMON_ARGS - ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} - --optdim 32,64,128,256 - # --filter fmha_fwd... -) -set(FMHA_BWD_CODE_GEN_COMMON_ARGS - ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd - --receipt 3 - --optdim 32,64,96,128,256 - # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... -) - -# Reduce building time by disabling instances that are not currently used in the gtests -# TODO: Consider to use a special receipt for testing only, or even two receipts: a small subset of -# instances for quick CI runs and a larger subset for scheduled runs (the tests skip tests when -# there is no corresponding instance for parameters). -if(BUILD_TESTING) - # Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill - list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) -endif() - -# generate a list of kernels, but not actually emit files at config sta -execute_process( - COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} - --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt - RESULT_VARIABLE ret -) -if(ret AND NOT ret EQUAL 0) - message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") -endif() - -execute_process( - COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} - --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt - RESULT_VARIABLE ret -) -if(ret AND NOT ret EQUAL 0) - message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") -endif() - -# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory -# as current cmake list, otherwise will not figure out the dependency properly -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) - -add_custom_command( - OUTPUT ${FMHA_FWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} - --output_dir ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${CODE_GEN_SCRIPTS} -) - -add_custom_command( - OUTPUT ${FMHA_BWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} - --output_dir ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${CODE_GEN_SCRIPTS} -) - -set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") -set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") - -message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}") -add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) -target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS}) -set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) - -message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}") -add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) -target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS}) -set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) - -set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS) -set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS) -set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS) -set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS) - -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -# ... because they are auto-generated -list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) -list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) - -# Allow comparing floating points directly in order to check sentinel values -list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) -list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) - -# NOTE: this is dangerous since will change the whole kernel to flush denormals -# WIP with compiler team for an exp2 intrinsic..., then remove this -if(NOT DEFINED FMHA_FWD_FAST_EXP2) - set(FMHA_FWD_FAST_EXP2 ON) -endif() - -if(FMHA_FWD_FAST_EXP2) - list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) -else() - list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0) -endif() -list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero) - -# conditionally enable call to the fwd_splitkv API in fmha_fwd example and tests -if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) - list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) -else() - list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) -endif() - -# conditionally enable call to the fwd_appendkv API in fmha_fwd example and tests -if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS) - list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) -else() - list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) -endif() - -# conditionally enable call to the pagedkv_prefill API in fmha_fwd example and tests -if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) - list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) -else() - list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) -endif() - -# conditionally specify the use of OCP_FP8 -if(CK_USE_OCP_FP8) - list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) - list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) -endif() - -# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream -list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) -list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) - -target_compile_options(${FMHA_FWD_INSTANCES} - PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} - INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) -target_compile_options(${FMHA_BWD_INSTANCES} - PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS} - INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS}) - -set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") -set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") - -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") -# not using add_example_executable() to add this target, since we don't want this to be included in -# "make all/install/check" -add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp) -target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES}) -target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) - -message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") -# not using add_example_executable() to add this target, since we don't want this to be included in -# "make all/install/check" -add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) -target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) -target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) - -# add fmha_fwd_v3 example -set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") +# Commented out: FMHA fwd/bwd instance generation and codegen commands not used by unified_attention +# +# set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +# # Currently only gfx9 archs are supported by FMHA +# list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") +# if(NOT INST_TARGETS) +# message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +# return() +# endif() +# +# # validate user-specified fmha_fwd API list +# set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") +# set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING +# "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") +# if(BUILD_TESTING) +# # Build instances of all APIs for tests +# set(FMHA_FWD_ENABLE_APIS "all") +# endif() +# if(FMHA_FWD_ENABLE_APIS STREQUAL "all") +# set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) +# endif() +# +# foreach(api ${FMHA_FWD_ENABLE_APIS}) +# if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) +# message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.") +# endif() +# endforeach() +# +# # "fwd" is a must-have api for the fmha_fwd example, add it if not specified +# if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) +# list(PREPEND FMHA_FWD_ENABLE_APIS "fwd") +# endif() +# +# file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS +# ${CMAKE_CURRENT_LIST_DIR}/generate.py +# ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py +# ) +# set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") +# +# string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +# set(FMHA_FWD_CODE_GEN_COMMON_ARGS +# ${CMAKE_CURRENT_LIST_DIR}/generate.py +# --api ${FMHA_FWD_APIS} +# --optdim 32,64,128,256 +# ) +# set(FMHA_BWD_CODE_GEN_COMMON_ARGS +# ${CMAKE_CURRENT_LIST_DIR}/generate.py +# --api bwd +# --receipt 3 +# --optdim 32,64,96,128,256 +# ) +# +# if(BUILD_TESTING) +# list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) +# endif() +# +# execute_process( +# COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} +# --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt +# RESULT_VARIABLE ret +# ) +# if(ret AND NOT ret EQUAL 0) +# message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") +# endif() +# +# execute_process( +# COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} +# --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt +# RESULT_VARIABLE ret +# ) +# if(ret AND NOT ret EQUAL 0) +# message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") +# endif() +# +# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) +# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) +# +# add_custom_command( +# OUTPUT ${FMHA_FWD_GEN_BLOBS} +# COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} +# --output_dir ${CMAKE_CURRENT_BINARY_DIR} +# DEPENDS ${CODE_GEN_SCRIPTS} +# ) +# +# add_custom_command( +# OUTPUT ${FMHA_BWD_GEN_BLOBS} +# COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} +# --output_dir ${CMAKE_CURRENT_BINARY_DIR} +# DEPENDS ${CODE_GEN_SCRIPTS} +# ) +# +# set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") +# set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") +# +# message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}") +# add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) +# target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +# set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +# set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) +# +# message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}") +# add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) +# target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS}) +# set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +# set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) +# +# set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS) +# set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS) +# set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS) +# set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS) +# +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) +# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template) +# +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) +# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal) +# +# if(NOT DEFINED FMHA_FWD_FAST_EXP2) +# set(FMHA_FWD_FAST_EXP2 ON) +# endif() +# +# if(FMHA_FWD_FAST_EXP2) +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) +# else() +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0) +# endif() +# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero) +# +# if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) +# else() +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) +# endif() +# +# if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS) +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) +# else() +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) +# endif() +# +# if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) +# else() +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) +# endif() +# +# if(CK_USE_OCP_FP8) +# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +# endif() +# +# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +# list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +# +# target_compile_options(${FMHA_FWD_INSTANCES} +# PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} +# INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) +# target_compile_options(${FMHA_BWD_INSTANCES} +# PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS} +# INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS}) +# +# set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") +# set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") +# +# message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") +# add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp) +# target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES}) +# target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# +# message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") +# add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) +# target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) +# target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +# +# set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) + +# --- Unified Attention target (kept) --- + +set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention") message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") -add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) +add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_unified_attention.cpp) target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS - "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" + "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" ) target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE - fmha_fwd_v3.cpp + unified_attention.cpp ${FMHA_FWD_V3_INSTANCES} ) @@ -222,17 +202,12 @@ set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) if(HAS_DISABLE_PACKED_FP32) list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS - -mllvm --amdgpu-disable-packed-fp32=1 + -mllvm --amdgpu-disable-packed-fp32=1 ) list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS - -DCK_TILE_DISABLE_PACKED_FP32=1 + -DCK_TILE_DISABLE_PACKED_FP32=1 ) endif() target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) -# TODO: we have to turn off this global prop, otherwise the progress bar generated -# by cmake will print too many files, execvp: /bin/sh: Argument list too long -# however, this property may affect global -# TODO: consider codegen a makefile by us -set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 931f2d315f4..9ad279f009a 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -1,7 +1,7 @@ include_directories(AFTER ${CMAKE_CURRENT_LIST_DIR} ) - +add_subdirectory(01_unified_attention) add_subdirectory(01_fmha) add_subdirectory(02_layernorm2d) add_subdirectory(03_gemm) From 9fda954253d015b8fc6c15a3fa7f25fbe3324f46 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 20 Oct 2025 13:16:19 +0000 Subject: [PATCH 32/88] Compiling fix --- .../ck_tile/01_unified_attention/example_unified_attention.cpp | 3 +-- ...d128_bf16_mask.cpp => unified_attention_d128_bf16_mask.cpp} | 0 ...28_bf16_nmask.cpp => unified_attention_d128_bf16_nmask.cpp} | 0 ...d128_fp16_mask.cpp => unified_attention_d128_fp16_mask.cpp} | 0 ...28_fp16_nmask.cpp => unified_attention_d128_fp16_nmask.cpp} | 0 .../ck_tile/01_unified_attention/unified_attention_runner.hpp | 2 +- .../pipeline/tile_unified_attention_traits.hpp | 2 +- .../pipeline/unified_attention_pipeline_problem.hpp | 2 +- 8 files changed, 4 insertions(+), 5 deletions(-) rename example/ck_tile/01_unified_attention/instances/{fmha_fwd_v3_d128_bf16_mask.cpp => unified_attention_d128_bf16_mask.cpp} (100%) rename example/ck_tile/01_unified_attention/instances/{fmha_fwd_v3_d128_bf16_nmask.cpp => unified_attention_d128_bf16_nmask.cpp} (100%) rename example/ck_tile/01_unified_attention/instances/{fmha_fwd_v3_d128_fp16_mask.cpp => unified_attention_d128_fp16_mask.cpp} (100%) rename example/ck_tile/01_unified_attention/instances/{fmha_fwd_v3_d128_fp16_nmask.cpp => unified_attention_d128_fp16_nmask.cpp} (100%) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 558dba164a3..dc0e389f0a4 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -22,8 +22,7 @@ #include #include -#include "fmha_fwd.hpp" -#include "fmha_fwd_v3.hpp" +#include "unified_attention.hpp" #include "mask.hpp" auto parse_cmd_args(int argc, char* argv[]) -> std::pair diff --git a/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_mask.cpp rename to example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp diff --git a/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_bf16_nmask.cpp rename to example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp diff --git a/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_mask.cpp rename to example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp diff --git a/example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/instances/fmha_fwd_v3_d128_fp16_nmask.cpp rename to example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp diff --git a/example/ck_tile/01_unified_attention/unified_attention_runner.hpp b/example/ck_tile/01_unified_attention/unified_attention_runner.hpp index 0703af71e33..7da84b8a927 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_runner.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_runner.hpp @@ -5,7 +5,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/ref/naive_attention.hpp" -#include "fmha_fwd.hpp" +#include "unified_attention.hpp" #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index 682e93fd56f..a285c308761 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -21,4 +21,4 @@ struct TileUnifiedAttentionTraits static constexpr bool kStoreLSE = kStoreLSE_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -} \ No newline at end of file +} diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index 3a8975b1603..3f676bf01dd 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -57,4 +57,4 @@ struct UnifiedAttentionPipelineProblem static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -} \ No newline at end of file +} From 97e7527eb1721c5d243797b17da45d1d72b67f53 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 20 Oct 2025 14:03:15 +0000 Subject: [PATCH 33/88] fixing compile errors... --- .../unified_attention/block/block_masking.hpp | 12 +++---- .../kernel/unified_attention_kernel.hpp | 36 ++++++++++--------- .../pipeline/unified_attention_pipeline.hpp | 30 +++------------- 3 files changed, 30 insertions(+), 48 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp index 958f036edf2..87868a56a18 100644 --- a/include/ck_tile/ops/unified_attention/block/block_masking.hpp +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -86,24 +86,24 @@ struct GenericAttentionMask static constexpr const char* name = impl::MaskName::name; // New constructor accepting repeat_idx with default value 1 - CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx = 1) - : GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx) + CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) + : GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx_) { } CK_TILE_HOST_DEVICE - GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx = 1) - : y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx) + GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx_) { } template - CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, index_t repeat_idx = 1) + CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, index_t repeat_idx_ = 1) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), y_total(mask_coord.at(number<2>{})), x_total(mask_coord.at(number<3>{})), - repeat_idx(repeat_idx) + repeat_idx(repeat_idx_) { } diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 42e0dd25ffc..b5d46c754fd 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -5,7 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" #include "ck_tile/core/numeric/math.hpp" #include @@ -30,9 +30,12 @@ struct UnifiedAttentionKernel using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; - + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool kPadSeqLenK = UnifiedAttentionPipeline::kPadSeqLenK; static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ; - static constexpr bool kPadHeadDim = UnifiedAttentionPipeline::kPadHeadDim; + static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV; // TODO add yjese static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE; @@ -181,7 +184,7 @@ struct UnifiedAttentionKernel find_seq_idx(const int32_t* query_start_len_ptr, ck_tile::index_t target_idx, ck_tile::index_t num_seqs, - ck_tile::index_t BLOCK_Q, + ck_tile::index_t block_q, bool use_q_block_mode) { ck_tile::index_t left = 0; @@ -191,7 +194,7 @@ struct UnifiedAttentionKernel { ck_tile::index_t mid = (left + right) / 2; ck_tile::index_t val = query_start_len_ptr[mid]; - ck_tile::index_t mid_val = use_q_block_mode ? (val / BLOCK_Q + mid) : val; + ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val; if (mid_val <= target_idx) { @@ -276,9 +279,9 @@ struct UnifiedAttentionKernel const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv; // for simplicity, batch stride we just modify the pointer - const index_t num_head_q = kargs.num_head_q; + // const index_t num_head_q = kargs.num_head_q; const index_t num_queries_per_kv = kargs.num_queries_per_kv; - const index_t num_head_k = num_head_q / num_queries_per_kv; + // const index_t num_head_k = num_head_q / num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -311,14 +314,14 @@ struct UnifiedAttentionKernel const index_t query_pos = q_block_local_idx * BLOCK_Q; const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = seq_len - cur_batch_query_len; + // const index_t context_len = seq_len - cur_batch_query_len; - const index_t max_seq_prefix_len = ( - context_len - + q_block_local_idx * BLOCK_Q - + (BLOCK_M - 1) // num_queries_per_kv - + 1 - ); + // const index_t max_seq_prefix_len = ( + // context_len + // + q_block_local_idx * BLOCK_Q + // + (BLOCK_M - 1) // num_queries_per_kv + // + 1 + // ); index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2; @@ -463,9 +466,8 @@ struct UnifiedAttentionKernel return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, - block_tables_ptr, + kargs.block_tables_ptr, block_table_offset, - lse_dram_window, mask, kargs.scale_s, smem_ptr); @@ -484,7 +486,7 @@ struct UnifiedAttentionKernel o_dram_base, // block sizes make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), - sequence{} + sequence{} ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) const auto o_dram_merged = transform_tensor_view( diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 892a5a3db03..5c1a91fb22b 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -277,8 +277,10 @@ struct UnifiedAttentionPipeline static_assert(HEAD_SIZE_PADDED <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadHeadDim = Problem::kPadHeadDim; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + // static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -387,7 +389,6 @@ struct UnifiedAttentionPipeline index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, - LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, @@ -554,15 +555,7 @@ struct UnifiedAttentionPipeline { if(num_total_loop <= 0) { - if constexpr(kStoreLSE) - { - auto lse = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse, -numeric::infinity()); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } + // Note: here occ are all cleard, return it // Note: q loaded but no fence, ignore it. @@ -1193,19 +1186,6 @@ struct UnifiedAttentionPipeline fmha_post_process(number<0>{}); } - // store lse - if constexpr(kStoreLSE) - { - auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); - sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]); - }); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } // finally, O constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); From d68a541c1994c552b7baef484682373ccb3be843 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 20 Oct 2025 15:04:47 +0000 Subject: [PATCH 34/88] fixing compile errors... --- .../unified_attention_d128_bf16_mask.cpp | 2 +- .../unified_attention_d128_fp16_nmask.cpp | 2 +- example/ck_tile/01_unified_attention/mask.hpp | 2 +- .../unified_attention_impl.hpp | 5 ++-- include/ck_tile/ops/unified_attention.hpp | 30 +++++++++++++++++++ .../block/block_position_encoding.hpp | 2 +- .../tile_unified_attention_traits.hpp | 8 ++--- .../pipeline/unified_attention_pipeline.hpp | 19 ++++-------- 8 files changed, 47 insertions(+), 23 deletions(-) create mode 100644 include/ck_tile/ops/unified_attention.hpp diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp index d99838d17c0..72717026bc5 100644 --- a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp +++ b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp index d8fcd7d97db..6a2a9984d1f 100644 --- a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp +++ b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_unified_attention/mask.hpp b/example/ck_tile/01_unified_attention/mask.hpp index 2dfe0e7c529..33f9bf72a9b 100644 --- a/example/ck_tile/01_unified_attention/mask.hpp +++ b/example/ck_tile/01_unified_attention/mask.hpp @@ -7,7 +7,7 @@ #include #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/unified_attention.hpp" // keep this in sync with ck_tile::GenericAttentionMaskEnum enum class mask_enum diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 4fa0bdab0db..65f17fa251d 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -85,6 +85,7 @@ struct unified_attention_kernel_traits typename unified_attention_problem_traits::qkvp_dtype, typename unified_attention_problem_traits::acc_dtype, typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, typename unified_attention_problem_traits::lse_dtype, typename unified_attention_problem_traits::qkvp_dtype, typename unified_attention_problem_traits::acc_dtype, @@ -93,7 +94,7 @@ struct unified_attention_kernel_traits unified_attention_mask, unified_attention_traits>; - using unified_attention_pipeline = Blockunified_attentionFwdV3Pipeline; + using unified_attention_pipeline = UnifiedAttentionPipeline; using epilogue = Default2DEpilogue< Default2DEpilogueProblem::acc_dtype, @@ -140,7 +141,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.num_seqs ); - index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs + index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); diff --git a/include/ck_tile/ops/unified_attention.hpp b/include/ck_tile/ops/unified_attention.hpp new file mode 100644 index 00000000000..62e6c58acb7 --- /dev/null +++ b/include/ck_tile/ops/unified_attention.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + + +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" + +// Block-level components +#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/unified_attention/block/block_dropout.hpp" +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/block/block_position_encoding.hpp" +#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/unified_attention/block/page_block_navigator.hpp" +#include "ck_tile/ops/unified_attention/block/variants.hpp" + +// Kernel-level components +#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" + +// Pipeline-level components +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" +#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp" +#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" + diff --git a/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp index 703ec0967ab..3dd36a712da 100644 --- a/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp +++ b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/unified_attention/block/block_masking.hpp" #include #include diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index a285c308761..f10b0644876 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -4,20 +4,20 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" namespace ck_tile { template struct TileUnifiedAttentionTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadHeadDim = kPadHeadDim; + static constexpr bool kPadHeadDim = kPadHeadDim_; static constexpr bool kStoreLSE = kStoreLSE_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 5c1a91fb22b..0b7d3137579 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -389,7 +389,6 @@ struct UnifiedAttentionPipeline index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, - const LSEElementFunction& lse_element_func, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, @@ -564,7 +563,8 @@ struct UnifiedAttentionPipeline } index_t i_total_loops = 0; - index_t kv_blk_idx = block_tables_ptr[block_table_offset + i_total_loops]; + const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; index_t kv_blk_idx_prev = 0; @@ -674,11 +674,7 @@ struct UnifiedAttentionPipeline async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx as the index /// FIXME: use the future-predicting method to move the window - // move K tile windows - auto k_dram_window = make_tile_window(k_dram_window.get_bottom_tensor_view(), - k_dram_window.get_window_lengths(), - {(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0}, - Policy::template MakeVDramTileDistribution()); + k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; auto K_lds_load = [&](auto k_lds_read_idx) { @@ -687,12 +683,9 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - - /// FIXME: use the future-predicting method to move the window - auto v_dram_window = make_tile_window(v_dram_window.get_bottom_tensor_view(), - v_dram_window.get_window_lengths(), - {(block_tables_ptr[block_table_offset + i_total_loops]) * BLOCK_SIZE, 0}, - Policy::template MakeVDramTileDistribution()); + // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + /// FIXME: use the future-predicting method to move the window + v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; auto V_lds_load = [&](auto v_lds_read_idx) { From f72b994b00939bf0840aa08276ab27787bd8855e Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 20 Oct 2025 15:53:35 +0000 Subject: [PATCH 35/88] More compilation fixes --- .../01_unified_attention/CMakeLists.txt | 9 + .../example_unified_attention.cpp | 168 ++++++++---------- .../unified_attention_d128_bf16_nmask.cpp | 2 +- .../unified_attention_d128_fp16_mask.cpp | 2 +- .../unified_attention_impl.hpp | 8 +- .../kernel/unified_attention_kernel.hpp | 9 +- .../tile_unified_attention_traits.hpp | 2 - .../pipeline/unified_attention_pipeline.hpp | 8 +- .../unified_attention_pipeline_problem.hpp | 2 - 9 files changed, 92 insertions(+), 118 deletions(-) diff --git a/example/ck_tile/01_unified_attention/CMakeLists.txt b/example/ck_tile/01_unified_attention/CMakeLists.txt index 2150ea09b72..11c413d192d 100644 --- a/example/ck_tile/01_unified_attention/CMakeLists.txt +++ b/example/ck_tile/01_unified_attention/CMakeLists.txt @@ -178,6 +178,15 @@ # --- Unified Attention target (kept) --- +# +set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +# Currently only gfx9 archs are supported by FMHA +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9") +if(NOT INST_TARGETS) + message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention") message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index dc0e389f0a4..0885179c778 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -154,7 +154,6 @@ struct Problem float scale_k; float scale_v; mask_info mask; - TensorLayout output_layout; std::vector query_lens; std::vector kv_lens; }; @@ -350,8 +349,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.stride_v_cache_3 = args.stride_k_cache_3; args.o_ptr = o_buf.GetDeviceBuffer(); - args.output_stride_0 = query_stride_0; - args.output_stride_1 = query_stride_1; + args.output_stride_0 = args.query_stride_0; + args.output_stride_1 = args.query_stride_1; // Optional cumulative seqlen overrides (exclude PAD) auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { @@ -386,19 +385,19 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }; calculate_cumulative(eff_query_lens, cu_query_lens); - ck_tile::DeviceMem seq_lens_buf(kv_lens.size()); + ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size()); ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size()); - seq_lens_buf.ToDevice(kv_lens.data()); + seq_lens_buf.ToDevice(eff_kv_lens.data()); query_start_len_buf.ToDevice(cu_query_lens.data()); args.seq_lens_ptr =reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); args.query_start_len_ptr =reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); - auto max_kv_len = std::max_element(problem.kv_lens.begin(), problem.kv_lens.end()); + int max_kv_len = std::max_element(eff_kv_lens.begin(), eff_kv_lens.end()); - index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE + ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE; // Create block_tables ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * sizeof(ck_tile::index_t)); @@ -433,30 +432,24 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) return false; } - std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else - { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - }(); + // std::size_t flop = [&] { + // if(problem.mask.type == mask_enum::no_mask) + // { + // return 4 * args.num_tokens * problem.nhead_q * + // problem.hdim; + // } + // else + // { + // /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. + // return 2 * args.num_tokens * problem.nhead_q * + // problem.hdim; + // } + // }(); + // TODO fix this + std::size_t flop = 1; float tflops = static_cast(flop) / 1.e9 / time; std::cout << "[" << problem.data_type << "|"; - if(problem.input_layout == problem.output_layout) - { - std::cout << problem.input_layout; - } - else - { - std::cout << problem.input_layout << "-" << problem.output_layout; - } std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed @@ -469,85 +462,70 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) } // transpose tensor descriptors from bhsd to bshd if necessary - if(problem.input_layout != TensorLayout::bshd) - { - q = q.transpose({0, 2, 1, 3}); - k = k.transpose({0, 2, 1, 3}); - v = v.transpose({0, 2, 1, 3}); - } - - ck_tile::HostTensor o_ref(problem.get_output_shape()); - if(problem.output_layout != TensorLayout::bshd) - { - o_ref = o_ref.transpose({0, 2, 1, 3}); - } + // if(problem.input_layout != TensorLayout::bshd) + // { + // q = q.transpose({0, 2, 1, 3}); + // k = k.transpose({0, 2, 1, 3}); + // v = v.transpose({0, 2, 1, 3}); + // } + + // ck_tile::HostTensor o_ref(problem.get_output_shape()); + // if(problem.output_layout != TensorLayout::bshd) + // { + // o_ref = o_ref.transpose({0, 2, 1, 3}); + // } // If variable lengths are provided, compute per-batch references // with the effective lengths; else compute a single full reference. - if(has_varlen_q || has_varlen_k) + // Variable-length aware verification: zero-fill padded region and only compute valid part. + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) { - // Variable-length aware verification: zero-fill padded region and only compute valid part. - o_ref.SetZero(); + const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); - for(int b = 0; b < problem.batch; ++b) + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) + for(int h = 0; h < problem.nhead_q; ++h) { - for(int h = 0; h < problem.nhead_q; ++h) + for(int d = 0; d < problem.hdim; ++d) { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } + o_ref(b, s, h, d) = o_b(0, s, h, d); } } } } - else - { - // No varlen override: compute the full reference once - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - } + ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp index a6806b95d7a..391103891a9 100644 --- a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp +++ b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp index a710efd2cbd..f2cc00f8356 100644 --- a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp +++ b/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp @@ -7,7 +7,7 @@ namespace ck_tile { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 65f17fa251d..f83209a2c45 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -20,7 +20,7 @@ #include "unified_attention.hpp" #include "mask.hpp" -#define INST_unified_attention_DISPATCH(kernel_traits) \ +#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \ template <> \ std::pair unified_attention_kernel_dispatch( \ const unified_attention_args& args, const stream_config& config) \ @@ -73,7 +73,6 @@ struct unified_attention_kernel_traits using unified_attention_traits = TileUnifiedAttentionTraits; @@ -110,6 +109,7 @@ struct unified_attention_kernel_traits template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { + index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, @@ -123,6 +123,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.scale_k, args.scale_v, args.scale_out, + total_num_q_blocks, args.query_stride_0, args.query_stride_1, args.stride_k_cache_0, @@ -141,9 +142,6 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.num_seqs ); - index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; - - dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); constexpr dim3 blocks = Kernel::BlockSize(); constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index b5d46c754fd..0ce14afdd0a 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -156,12 +156,11 @@ struct UnifiedAttentionKernel stride_v_cache_3, output_stride_0, output_stride_1}, - { block_tables_ptr, seq_lens_ptr, query_start_len_ptr, num_seqs - }}; + }; return kargs; } @@ -344,7 +343,7 @@ struct UnifiedAttentionKernel index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; - const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); + // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { @@ -359,7 +358,7 @@ struct UnifiedAttentionKernel q_dram_base, // block sizes make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), - sequence{} + sequence{} ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( @@ -486,7 +485,7 @@ struct UnifiedAttentionKernel o_dram_base, // block sizes make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), - sequence{} + sequence{} ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) const auto o_dram_merged = transform_tensor_view( diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index f10b0644876..b27a09a1b41 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -12,13 +12,11 @@ namespace ck_tile { template struct TileUnifiedAttentionTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; static constexpr bool kPadHeadDim = kPadHeadDim_; - static constexpr bool kStoreLSE = kStoreLSE_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; } diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 0b7d3137579..3dee9b4ad84 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -256,7 +256,6 @@ struct UnifiedAttentionPipeline using VDataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; using SMPLComputeDataType = ck_tile::remove_cvref_t; - using LSEDataType = ck_tile::remove_cvref_t; using PDataType = ck_tile::remove_cvref_t; using OaccDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; @@ -372,11 +371,9 @@ struct UnifiedAttentionPipeline template @@ -1206,14 +1203,12 @@ struct UnifiedAttentionPipeline template + typename VDramBlockWindowTmp> CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const void* block_tables_ptr, index_t block_table_offset, - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale_s, void* smem_ptr) const @@ -1228,7 +1223,6 @@ struct UnifiedAttentionPipeline identity{}, block_tables_ptr, block_table_offset, - lse_dram_block_window_tmp, identity{}, identity{}, identity{}, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index 3f676bf01dd..d21d8316afe 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -51,8 +51,6 @@ struct UnifiedAttentionPipelineProblem static constexpr bool kPadHeadDim = Traits::kPadHeadDim; static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; - static constexpr auto BiasEnum = Traits::BiasEnum; - static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; From e14487230867650db8fbb8aed56bb4b440e0e1fc Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Thu, 23 Oct 2025 08:11:55 +0000 Subject: [PATCH 36/88] change to BLOCK_M in shape definitions --- .../01_unified_attention/unified_attention_impl.hpp | 3 ++- .../pipeline/tile_unified_attention_shape.hpp | 2 +- .../pipeline/unified_attention_pipeline.hpp | 13 +++++++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index f83209a2c45..f9a7f1476e3 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,7 +58,7 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; - // BLOCK_Q BLOCK_SIZE HEAD_SIZE + // BLOCK_M BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence<128, 128, 128>; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; using unified_attention_block_warps = sequence<8, 1, 1>; @@ -109,6 +109,7 @@ struct unified_attention_kernel_traits template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { + index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index d0704626e97..914619dc5a5 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -47,7 +47,7 @@ struct TileUnifiedAttentionShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static constexpr index_t BLOCK_Q = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along q seqlen // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head) static constexpr index_t BLOCK_SIZE = BlockTile::at(number<1>{}); // BLOCK size for K seqlen static constexpr index_t HEAD_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3dee9b4ad84..a8734ac0ab0 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -268,7 +268,8 @@ struct UnifiedAttentionPipeline static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; - static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; + static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M; + static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED; @@ -302,12 +303,12 @@ struct UnifiedAttentionPipeline } }(); - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize(index_t num_queries_per_kv) + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { // create another LDS buffer for p - return ck_tile::max(BLOCK_Q * num_queries_per_kv * HEAD_SIZE_PADDED * sizeof(PDataType), + return ck_tile::max(BLOCK_M * HEAD_SIZE_PADDED * sizeof(PDataType), Policy::template GetSmemSize() + - BLOCK_Q * num_queries_per_kv * BLOCK_SIZE * sizeof(PDataType)); + BLOCK_M * BLOCK_SIZE * sizeof(PDataType)); } // for debug only @@ -394,7 +395,7 @@ struct UnifiedAttentionPipeline void* smem_ptr) const { using namespace ck_tile; - + index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; static_assert( std::is_same_v> && @@ -409,7 +410,7 @@ struct UnifiedAttentionPipeline BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize(num_queries_per_kv)); + static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); From 3c0e6d37bf23d5c6de4d9003a2dba996a0dc26c4 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Thu, 23 Oct 2025 09:47:30 +0000 Subject: [PATCH 37/88] fixing bugs --- .../unified_attention_impl.hpp | 4 +- .../kernel/unified_attention_kernel.hpp | 10 ++-- .../pipeline/tile_unified_attention_shape.hpp | 2 +- .../pipeline/unified_attention_pipeline.hpp | 4 +- ...fied_attention_pipeline_default_policy.hpp | 48 +++++++++---------- 5 files changed, 36 insertions(+), 32 deletions(-) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index f9a7f1476e3..1c04806a045 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -110,7 +110,9 @@ template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { - index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; + index_t BLOCK_Q = Kernel::BLOCK_M / args.num_queries_per_kv; + + index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 0ce14afdd0a..639a9db5b0a 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -43,8 +43,8 @@ struct UnifiedAttentionKernel // BLOCK_Q = BLOCK_M // num_queries_per_kv // BLOCK_Q is the block size for q seqlen - static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; - // static constexpr index_t BLOCK_M = UnifiedAttentionPipeline::BLOCK_M; + /// static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; + static constexpr index_t BLOCK_M = UnifiedAttentionPipeline::BLOCK_M; // BLOCK size for K seqlen static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE; @@ -276,10 +276,12 @@ struct UnifiedAttentionKernel ck_tile::index_t pid = blockIdx.x; - const index_t BLOCK_M = BLOCK_Q * kargs.num_queries_per_kv; + const index_t num_queries_per_kv = kargs.num_queries_per_kv; + + const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; // for simplicity, batch stride we just modify the pointer // const index_t num_head_q = kargs.num_head_q; - const index_t num_queries_per_kv = kargs.num_queries_per_kv; + // const index_t num_head_k = num_head_q / num_queries_per_kv; pid = RemapTileIndices(pid, kargs); diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index 914619dc5a5..8b453723ac4 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -47,7 +47,7 @@ struct TileUnifiedAttentionShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head) static constexpr index_t BLOCK_SIZE = BlockTile::at(number<1>{}); // BLOCK size for K seqlen static constexpr index_t HEAD_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index a8734ac0ab0..af4d79759fd 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -395,7 +395,7 @@ struct UnifiedAttentionPipeline void* smem_ptr) const { using namespace ck_tile; - index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; + constexpr index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; static_assert( std::is_same_v> && @@ -410,7 +410,7 @@ struct UnifiedAttentionPipeline BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * BLOCK_Q * BLOCK_SIZE <= GetSmemSize()); + static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index bfbb1a93f04..72b757c6688 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -92,8 +92,8 @@ struct UnifiedAttentionPipelineDefaultPolicy { using namespace ck_tile; - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -126,8 +126,8 @@ struct UnifiedAttentionPipelineDefaultPolicy { using namespace ck_tile; - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -197,8 +197,8 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{}); constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{}); - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -233,9 +233,9 @@ struct UnifiedAttentionPipelineDefaultPolicy typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize, - TileGemmShape, + TileGemmShape, typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>; @@ -279,9 +279,9 @@ struct UnifiedAttentionPipelineDefaultPolicy typename Problem::VDataType, typename Problem::OaccDataType, Problem::kBlockSize, - TileGemmShape, + TileGemmShape, typename Problem::UnifiedAttentionShape::Gemm1BlockWarps, typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>; /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass @@ -317,8 +317,8 @@ struct UnifiedAttentionPipelineDefaultPolicy using namespace ck_tile; // K is always k-major, we use async-copy to load into LDS - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -374,8 +374,8 @@ struct UnifiedAttentionPipelineDefaultPolicy using namespace ck_tile; // K is always k-major, we use async-copy to load into LDS - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -423,8 +423,8 @@ struct UnifiedAttentionPipelineDefaultPolicy { // this function assume K/V can share smem constexpr index_t SingleKSize = [&]() { - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -447,8 +447,8 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kKPack = GetSmemKPackK(); static_assert(PixelsPerRow % kKPack == 0); constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; static_assert(kNPerBlock % NPerRow == 0); static_assert(kKPerBlock % kKPack == 0); @@ -465,8 +465,8 @@ struct UnifiedAttentionPipelineDefaultPolicy using namespace ck_tile; /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -522,8 +522,8 @@ struct UnifiedAttentionPipelineDefaultPolicy using namespace ck_tile; /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); From 0d2a9badba821b38c8bdc49c91e2fe441dffadba Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 11:17:46 +0000 Subject: [PATCH 38/88] fixed example --- .../example_unified_attention.cpp | 368 +++++++++--------- .../kernel/unified_attention_kernel.hpp | 2 + 2 files changed, 192 insertions(+), 178 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 0885179c778..b77bb1d19db 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -94,8 +94,8 @@ struct Problem explicit Problem(const ck_tile::ArgParser& args) { data_type = args.get_str("prec") == "fp16" - ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 - : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; + ? ck_tile::unified_attention_args::data_type_enum::fp16 + : ck_tile::unified_attention_args::data_type_enum::bf16; batch = args.get_int("b"); max_seqlen_q = args.get_int("s"); max_context_len = args.get_int("s_k"); @@ -107,21 +107,32 @@ struct Problem hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); kv_lens = args.get_int_vec("kv_lens"); - // softmax_scale = args.get_float("scale_s"); - // if(softmax_scale == .0f) - // softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); + // Calculate scale_s + scale_s = args.get_float("scale_s"); + if(scale_s == 0.0f) + scale_s = 1.0f / ck_tile::sqrt(static_cast(hdim)); - // TODO - // mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); + // Initialize other scales + scale = args.get_float("scale"); + scale_k = args.get_float("scale_k"); + scale_v = args.get_float("scale_v"); - // q_eff_lens = args.get_int_vec("q_eff_lens"); - // kv_eff_lens = args.get_int_vec("kv_eff_lens"); + // Calculate sums of query_lens and kv_lens if provided + // int64_t kv_lens_sum = 0; + + for (const auto& len : query_lens) { + num_tokens += len; + } + + // for (const auto& len : kv_lens) { + // kv_lens_sum += len; + // } } std::vector get_query_shape() const { - return {batch * seqlen_q, nhead_q, hdim}; + return {num_tokens, nhead_q, hdim}; } std::vector get_key_shape() const @@ -136,11 +147,11 @@ struct Problem std::vector get_output_shape() const { - return {batch * seqlen_q, nhead_q, hdim}; + return {num_tokens, nhead_q, hdim}; } - ck_tile::fmha_fwd_v3_args::data_type_enum data_type; + ck_tile::unified_attention_args::data_type_enum data_type; ck_tile::index_t batch; ck_tile::index_t num_blks; ck_tile::index_t BLOCK_SIZE; @@ -149,6 +160,7 @@ struct Problem ck_tile::index_t nhead_q; ck_tile::index_t nhead_kv; ck_tile::index_t hdim; + ck_tile::index_t num_tokens; float scale_s; float scale; float scale_k; @@ -198,104 +210,104 @@ auto generate_qkv(const Problem& problem, } -namespace host { -template -CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, - const ck_tile::HostTensor& k_bshd, - const ck_tile::HostTensor& v_bshd, - const mask_info& mask, - ck_tile::HostTensor& o_bshd, - const QElementOp& q_element_op = {}, - const KElementOp& k_element_op = {}, - const VElementOp& v_element_op = {}, - const SAccElementOp& s_acc_element_op = {}) -{ - const int batch_size = q_bshd.mDesc.get_lengths()[0]; - const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; - const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; - const int nhead_q = q_bshd.mDesc.get_lengths()[2]; - const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; - const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; - const int hdim_v = v_bshd.mDesc.get_lengths()[3]; - - const int nr = nhead_q / nhead_kv; - - ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); - ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); - ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); - ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - - ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); - ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - - // do computation for each batch - for(int b = 0; b < batch_size; ++b) - { - // copy per-batch data from input tensors - // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); - k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); - v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); - // clang-format on - ck_tile::reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } +// namespace host { +// template +// CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, +// const ck_tile::HostTensor& k_bshd, +// const ck_tile::HostTensor& v_bshd, +// const mask_info& mask, +// ck_tile::HostTensor& o_bshd, +// const QElementOp& q_element_op = {}, +// const KElementOp& k_element_op = {}, +// const VElementOp& v_element_op = {}, +// const SAccElementOp& s_acc_element_op = {}) +// { + // const int batch_size = q_bshd.mDesc.get_lengths()[0]; + // const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; + // const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; + // const int nhead_q = q_bshd.mDesc.get_lengths()[2]; + // const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; + // const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; + // const int hdim_v = v_bshd.mDesc.get_lengths()[3]; + + // const int nr = nhead_q / nhead_kv; + + // ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); + // ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); + // ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); + // ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); + + // ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); + // ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); + + // // do computation for each batch + // for(int b = 0; b < batch_size; ++b) + // { + // // copy per-batch data from input tensors + // // clang-format off + // q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); + // k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); + // v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + // // clang-format on + // ck_tile::reference_batched_gemm( + // q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); + + // if(mask.type == mask_enum::no_mask) + // { + // ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); + // } + // else if(mask.type == mask_enum::window_generic) + // { + // ck_tile::reference_batched_masking( + // s_host_ref, + // ck_tile::make_generic_attention_mask_from_lr_window( + // mask.left, mask.right, seqlen_q, seqlen_kv)); + // } + // else + // { + // // if left window size is negative, means causal + // // else means generic (for current batch) + // if(mask.left < 0) + // ck_tile::reference_batched_masking( + // s_host_ref, + // ck_tile::make_generic_attention_mask_from_lr_window( + // mask.left, + // mask.right, + // seqlen_q, + // seqlen_kv, + // mask.type == mask_enum::mask_top_left)); + // else + // ck_tile::reference_batched_masking( + // s_host_ref, + // ck_tile::make_generic_attention_mask_from_lr_window( + // mask.left, + // mask.right, + // seqlen_q, + // seqlen_kv, + // mask.type == mask_enum::mask_top_left)); + // } - ck_tile::reference_batched_softmax( - s_host_ref, p_host_ref, ck_tile::identity{}); + // ck_tile::reference_batched_softmax( + // s_host_ref, p_host_ref, ck_tile::identity{}); - ck_tile::reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); + // ck_tile::reference_batched_gemm( + // p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - // copy resulting per-batch data to the output tensor - o_host_ref.ForEach( - [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); - } -} -} // namespace host + // // copy resulting per-batch data to the output tensor + // o_host_ref.ForEach( + // [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); + // } +// } +// } // namespace host template bool run_impl(const Problem& problem, const RunConfig& run_config) @@ -325,12 +337,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.mask_type = 2; args.hdim = problem.hdim; - args.BLOCK_SIZE = problem.BLOCK_SIZE; args.num_blks = problem.num_blks; // args.query_lens = problem.query_lens // args.kv_lens = problem.kv_lens - args.num_tokens = problem.batch * problem.seqlen_q; args.q_ptr = q_buf.GetDeviceBuffer(); args.query_stride_0 = problem.hdim * problem.nhead_q; args.query_stride_0 = problem.hdim; @@ -373,6 +383,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024); const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024); + args.num_tokens = std::accumulate(eff_query_lens.begin(), eff_query_lens.end(), 0); + // Calculate cumulative sums for kernel arguments if varlen is used std::vector cu_query_lens ; @@ -394,7 +406,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.seq_lens_ptr =reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); args.query_start_len_ptr =reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); - int max_kv_len = std::max_element(eff_kv_lens.begin(), eff_kv_lens.end()); ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE; @@ -446,20 +457,20 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // } // }(); // TODO fix this - std::size_t flop = 1; - float tflops = static_cast(flop) / 1.e9 / time; + // std::size_t flop = 1; + // float tflops = static_cast(flop) / 1.e9 / time; - std::cout << "[" << problem.data_type << "|"; - std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim - << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed - << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops" << std::endl; + // std::cout << "[" << problem.data_type << "|"; + // std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv + // << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim + // << ", scale_s:" << problem.sacle_s << ", mask:" << problem.mask << std::fixed + // << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops + // << " TFlops" << std::endl; - if(!run_config.verify) - { - return true; - } + // if(!run_config.verify) + // { + // return true; + // } // transpose tensor descriptors from bhsd to bshd if necessary // if(problem.input_layout != TensorLayout::bshd) @@ -478,65 +489,66 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // If variable lengths are provided, compute per-batch references // with the effective lengths; else compute a single full reference. // Variable-length aware verification: zero-fill padded region and only compute valid part. - o_ref.SetZero(); + // o_ref.SetZero(); - for(int b = 0; b < problem.batch; ++b) - { - const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - continue; - - // Slice current batch from inputs (bshd) and build single-batch tensors - ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // Copy effective region - q_b.ForEach([&](auto& self, auto idx) { - // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); - }); - k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - host::fmha_fwd(q_b, - k_b, - v_b, - problem.mask, - o_b, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); - - // Scatter into o_ref's bshd descriptor memory - for(int s = 0; s < seqlen_q_eff; ++s) - { - for(int h = 0; h < problem.nhead_q; ++h) - { - for(int d = 0; d < problem.hdim; ++d) - { - o_ref(b, s, h, d) = o_b(0, s, h, d); - } - } - } - } + // for(int b = 0; b < problem.batch; ++b) + // { + // const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + // const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + + // if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + // continue; + + // // Slice current batch from inputs (bshd) and build single-batch tensors + // ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + // ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + // ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + // ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // // Copy effective region + // q_b.ForEach([&](auto& self, auto idx) { + // // idx: [0, s, h, d] + // self(idx) = q(b, idx[1], idx[2], idx[3]); + // }); + // k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + // host::fmha_fwd(q_b, + // k_b, + // v_b, + // problem.mask, + // o_b, + // ck_tile::identity{}, + // ck_tile::identity{}, + // ck_tile::identity{}, + // ck_tile::scales{problem.scale_s}); + + // // Scatter into o_ref's bshd descriptor memory + // for(int s = 0; s < seqlen_q_eff; ++s) + // { + // for(int h = 0; h < problem.nhead_q; ++h) + // { + // for(int d = 0; d < problem.hdim; ++d) + // { + // o_ref(b, s, h, d) = o_b(0, s, h, d); + // } + // } + // } + // } - ck_tile::HostTensor o(problem.get_output_shape()); - o_buf.FromDevice(o.data()); + // ck_tile::HostTensor o(problem.get_output_shape()); + // o_buf.FromDevice(o.data()); - const auto [rtol, atol] = [&] { - if constexpr(std::is_same_v) - return std::make_tuple(1e-3, 1e-3); - else - return std::make_tuple(1e-2, 1e-2); - }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + // const auto [rtol, atol] = [&] { + // if constexpr(std::is_same_v) + // return std::make_tuple(1e-3, 1e-3); + // else + // return std::make_tuple(1e-2, 1e-2); + // }(); + // return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + return true; } int main(int argc, char* argv[]) @@ -551,7 +563,7 @@ int main(int argc, char* argv[]) RunConfig run_config(args); const auto run = [&] { - if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) + if(problem.data_type == ck_tile::unified_attention_args::data_type_enum::fp16) { return run_impl(problem, run_config); } diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 639a9db5b0a..d0d4e8ecf21 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -126,6 +126,7 @@ struct UnifiedAttentionKernel ck_tile::index_t output_stride_0, ck_tile::index_t output_stride_1, const int32_t* block_tables_ptr, + ck_tile::index_t block_table_stride, const int32_t* seq_lens_ptr, const int32_t* query_start_len_ptr, ck_tile::index_t num_seqs @@ -157,6 +158,7 @@ struct UnifiedAttentionKernel output_stride_0, output_stride_1}, block_tables_ptr, + block_table_stride, seq_lens_ptr, query_start_len_ptr, num_seqs From 3bcef5953631926b97f80322f591704340059a84 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 11:25:07 +0000 Subject: [PATCH 39/88] block table stride fix --- .../ck_tile/01_unified_attention/example_unified_attention.cpp | 1 + example/ck_tile/01_unified_attention/unified_attention.hpp | 1 + example/ck_tile/01_unified_attention/unified_attention_impl.hpp | 1 + .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 1 + 4 files changed, 4 insertions(+) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index b77bb1d19db..02fd6e8db5c 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -428,6 +428,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // Set pointer in args args.block_tables_ptr = reinterpret_cast(block_tables_buf.GetDeviceBuffer()); + args.block_table_stride = max_num_blocks_per_seq; ck_tile::stream_config stream_config{nullptr, diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index cf616ef9a51..50462d31102 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -58,6 +58,7 @@ struct unified_attention_args index_t output_stride_1; const int32_t* block_tables_ptr; + index_t block_table_stride; const int32_t* seq_lens_ptr; // seq len in each batch const int32_t* query_start_len_ptr; // [num_seqs+1] diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 1c04806a045..8d321acfe5f 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -140,6 +140,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.output_stride_0, args.output_stride_1, args.block_tables_ptr, + args.block_table_stride, args.seq_lens_ptr, args.query_start_len_ptr, args.num_seqs diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index d0d4e8ecf21..f9ea3d0b506 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -91,6 +91,7 @@ struct UnifiedAttentionKernel struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs { const int32_t* block_tables_ptr; + ck_tile::index_t block_table_stride; const int32_t* seq_lens_ptr; // seq len in each batch const int32_t* query_start_len_ptr; // [num_seqs+1] From 5bf72d2bcb7df4f9253ed8dd88421fc038af45b3 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Thu, 23 Oct 2025 11:40:48 +0000 Subject: [PATCH 40/88] fixing bugs --- .../kernel/unified_attention_kernel.hpp | 5 +++-- .../pipeline/unified_attention_pipeline.hpp | 22 +++++++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index f9ea3d0b506..44236a734c3 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -460,8 +460,8 @@ struct UnifiedAttentionKernel cur_batch_query_len, // x (i.e. extend) seq_len, // y_total (x + y) cur_batch_query_len, // x_total - num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv times along x dim of the tile - kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + num_queries_per_kv // the same sequence index is repeated num_queries_per_kv times along x dim of the tile + ); else return FmhaMask{cur_batch_query_len, seq_len}; }(); @@ -470,6 +470,7 @@ struct UnifiedAttentionKernel return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, + num_queries_per_kv, kargs.block_tables_ptr, block_table_offset, mask, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index af4d79759fd..b2cb1a3da0d 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -278,8 +278,8 @@ struct UnifiedAttentionPipeline static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDim; // static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -1208,6 +1208,7 @@ struct UnifiedAttentionPipeline CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, FmhaMask mask, @@ -1216,12 +1217,29 @@ struct UnifiedAttentionPipeline { using namespace ck_tile; + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + [[maybe_unused]] const VElementFunction& v_element_func, + index_t num_queries_per_kv, + const void* block_tables_ptr, + index_t block_table_offset, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + void* smem_ptr) const + return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, identity{}, v_dram_block_window_tmp, identity{}, + num_queries_per_kv, block_tables_ptr, block_table_offset, identity{}, From e03ed35944c5362cb610de71a8b4412f2bb8dc9b Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 11:42:15 +0000 Subject: [PATCH 41/88] fix the vector max --- .../example_unified_attention.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 02fd6e8db5c..50ac6ea94cc 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -406,7 +406,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.seq_lens_ptr =reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); args.query_start_len_ptr =reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); - int max_kv_len = std::max_element(eff_kv_lens.begin(), eff_kv_lens.end()); + + auto max_element = [&](const std::vector& opt_vec) { + ck_tile::index_t max = opt_vec[0]; + for (ck_tile::index_t i: opt_vec) { + if (i > max){ + max = i; + } + } + return max; + }; + + ck_tile::index_t max_kv_len = max_element(eff_kv_lens); ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE; From 6ea56bec34a3d56ceac3a63594a516ba66b21f02 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 11:44:14 +0000 Subject: [PATCH 42/88] removed redundent code --- .../pipeline/unified_attention_pipeline.hpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b2cb1a3da0d..1bbaacdb1e6 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -1217,22 +1217,6 @@ struct UnifiedAttentionPipeline { using namespace ck_tile; - CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - [[maybe_unused]] const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - [[maybe_unused]] const VElementFunction& v_element_func, - index_t num_queries_per_kv, - const void* block_tables_ptr, - index_t block_table_offset, - [[maybe_unused]] const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, - FmhaMask mask, - float scale_s, - void* smem_ptr) const - return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, From 3bb29bfd6cfb9e5a24929d9ac531208c2a310d42 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 11:49:35 +0000 Subject: [PATCH 43/88] Fixed pipeline args --- .../unified_attention/pipeline/unified_attention_pipeline.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 1bbaacdb1e6..b8534b7a6ae 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -1229,7 +1229,6 @@ struct UnifiedAttentionPipeline identity{}, identity{}, identity{}, - identity{}, mask, scale_s, smem_ptr); From ebf1c4c305db789197bd90b96b65cb1850993d15 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 11:57:36 +0000 Subject: [PATCH 44/88] const blockq --- .../unified_attention/pipeline/unified_attention_pipeline.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b8534b7a6ae..ec4b3355ded 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -395,7 +395,8 @@ struct UnifiedAttentionPipeline void* smem_ptr) const { using namespace ck_tile; - constexpr index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; + // TODO do we make the num_queries_per_kv and num_head conexpr??? + const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; static_assert( std::is_same_v> && From d18f8e46bf6d5edf98baa92f8cde544c9caa1e46 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 12:02:10 +0000 Subject: [PATCH 45/88] Fixed block Q with M --- .../ops/unified_attention/kernel/unified_attention_kernel.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 44236a734c3..820d129f615 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -470,7 +470,6 @@ struct UnifiedAttentionKernel return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, - num_queries_per_kv, kargs.block_tables_ptr, block_table_offset, mask, From 89cfdb35e05b0ff790b294e072a2a5d1d3b871e6 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 23 Oct 2025 12:02:18 +0000 Subject: [PATCH 46/88] Fixed block Q with M --- .../pipeline/unified_attention_pipeline.hpp | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index ec4b3355ded..48f1f4deb8c 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -384,7 +384,6 @@ struct UnifiedAttentionPipeline [[maybe_unused]] const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile [[maybe_unused]] const VElementFunction& v_element_func, - index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, @@ -395,16 +394,13 @@ struct UnifiedAttentionPipeline void* smem_ptr) const { using namespace ck_tile; - // TODO do we make the num_queries_per_kv and num_head conexpr??? - const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; - static_assert( std::is_same_v> && std::is_same_v> && std::is_same_v>, "wrong!"); - static_assert(BLOCK_Q == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -414,29 +410,29 @@ struct UnifiedAttentionPipeline static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto s_lds_window = - make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); auto p_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + Policy::template GetSmemSize()), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto p_lds_window = - make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto o_lds_window = - make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); auto m_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + Policy::template GetSmemSize()), - MakeSimpleLdsDesc1D()); + MakeSimpleLdsDesc1D()); [[maybe_unused]] auto m_lds_window = - make_tile_window(m_lds, make_tuple(number{}), {0}); + make_tile_window(m_lds, make_tuple(number{}), {0}); const index_t warp_group_id = get_warp_id() / 4; @@ -543,7 +539,7 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE); index_t kv_token_start = seqlen_k_start; @@ -812,7 +808,7 @@ struct UnifiedAttentionPipeline gemm_0(sp(sp_reg_idx).sp_compute, get_slice_tile(q_tile, sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence{}), get_slice_tile(kv_tile.k_tile, sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, sequence{})); @@ -822,7 +818,7 @@ struct UnifiedAttentionPipeline gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence{}), get_slice_tile(kv_tile.v_tile, sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, sequence{})); @@ -836,7 +832,7 @@ struct UnifiedAttentionPipeline gemm_0(sp(sp_reg_idx).sp_compute, get_slice_tile(q_tile, sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence{}), get_slice_tile(kv_tile.k_tile, sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, sequence{})); @@ -846,7 +842,7 @@ struct UnifiedAttentionPipeline gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence{}), get_slice_tile(kv_tile.v_tile, sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, sequence{})); @@ -894,7 +890,7 @@ struct UnifiedAttentionPipeline if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( - q_origin.at(number<0>{}), kv_token_start, number{}, number{}); + q_origin.at(number<0>{}), kv_token_start, number{}, number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, @@ -1209,7 +1205,6 @@ struct UnifiedAttentionPipeline CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - index_t num_queries_per_kv, const void* block_tables_ptr, index_t block_table_offset, FmhaMask mask, @@ -1224,7 +1219,6 @@ struct UnifiedAttentionPipeline identity{}, v_dram_block_window_tmp, identity{}, - num_queries_per_kv, block_tables_ptr, block_table_offset, identity{}, From 22c5c209771d8ab0d99c3254149b70cfe92d5715 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 24 Oct 2025 09:45:32 +0000 Subject: [PATCH 47/88] Debugging window size --- .../kernel/unified_attention_kernel.hpp | 6 +++--- .../pipeline/unified_attention_pipeline.hpp | 6 ++++++ .../unified_attention_pipeline_default_policy.hpp | 13 +++++++------ script/cmake-ck-dev.sh | 2 +- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 820d129f615..405d93f1244 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -384,7 +384,7 @@ struct UnifiedAttentionKernel // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) auto q_dram_window = make_tile_window( q_dram, - make_tuple(BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED), + make_tuple(BLOCK_M, HEAD_SIZE_PADDED), {query_pos * num_queries_per_kv, 0} ); @@ -481,7 +481,7 @@ struct UnifiedAttentionKernel auto o_dram = [&]() { const auto o_dram_base = make_naive_tensor_view( o_ptr, - make_tuple(seq_len, num_queries_per_kv, HEAD_SIZE), + make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1), number{}, number<1>{}); @@ -497,7 +497,7 @@ struct UnifiedAttentionKernel o_dram_pad, make_tuple( make_merge_transform( - make_tuple(seq_len, num_queries_per_kv) + make_tuple(query_len_padded, num_queries_per_kv) ), make_pass_through_transform(HEAD_SIZE_PADDED) ), diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 48f1f4deb8c..56bfd515450 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -400,6 +400,12 @@ struct UnifiedAttentionPipeline std::is_same_v>, "wrong!"); + // TODO remove these static asserts + static_assert(BLOCK_M == 128, "BLOCK_M == 128"); // pass so BLOCK_M=128 + // static_assert(QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] == 128, "QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] == 128"); + static_assert(QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] == 0, "QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] == 0"); + // passed so windows length[0] = 0 why? + static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index 72b757c6688..32f97aba50d 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -130,9 +130,10 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; - constexpr index_t WarpSize = ck_tile::get_warp_size(); + constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64 constexpr index_t KVector = GetAlignmentV(); // this is for global load + // 4 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave @@ -140,11 +141,11 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - constexpr index_t N0 = NumIssues; - constexpr index_t N1 = LaneGroups; - constexpr index_t N2 = NumWarps; - constexpr index_t K0 = LanesPerK; - constexpr index_t K1 = KVector; + constexpr index_t N0 = NumIssues; // 8 + constexpr index_t N1 = LaneGroups; // 2 + constexpr index_t N2 = NumWarps; // 8 + constexpr index_t K0 = LanesPerK; // 32 + constexpr index_t K1 = KVector; // 4 return make_static_tile_distribution( tile_distribution_encoding, diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 6220009b03c..c03d8e8f0a5 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -17,7 +17,7 @@ else MY_PROJECT_SOURCE=".." fi -GPU_TARGETS="gfx908;gfx90a;gfx942" +GPU_TARGETS="gfx950" if [ $# -ge 1 ]; then case "$1" in From d5c8315affac9f8774b38710d4ed8a4333a69058 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 28 Oct 2025 11:32:48 +0000 Subject: [PATCH 48/88] fixed window creation number<>{} --- .../unified_attention/kernel/unified_attention_kernel.hpp | 8 ++++---- .../pipeline/unified_attention_pipeline.hpp | 6 ------ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 405d93f1244..f436cd8a2d4 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -384,7 +384,7 @@ struct UnifiedAttentionKernel // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) auto q_dram_window = make_tile_window( q_dram, - make_tuple(BLOCK_M, HEAD_SIZE_PADDED), + make_tuple(number{}, number{}), {query_pos * num_queries_per_kv, 0} ); @@ -420,7 +420,7 @@ struct UnifiedAttentionKernel }(); auto k_dram_window = make_tile_window( - k_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0}); + k_dram, make_tuple(number{}, number{}), {0, 0}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( @@ -451,7 +451,7 @@ struct UnifiedAttentionKernel }(); auto v_dram_window = make_tile_window( - v_dram, make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), {0, 0}); + v_dram, make_tuple(number{}, number{}), {0, 0}); FmhaMask mask = [&]() { if constexpr(kHasMask) @@ -510,7 +510,7 @@ struct UnifiedAttentionKernel auto o_dram_window = make_tile_window(o_dram, - make_tuple(BLOCK_M, HEAD_SIZE_PADDED), + make_tuple(number{}, number{}), {query_pos * num_queries_per_kv, 0}); EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 56bfd515450..48f1f4deb8c 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -400,12 +400,6 @@ struct UnifiedAttentionPipeline std::is_same_v>, "wrong!"); - // TODO remove these static asserts - static_assert(BLOCK_M == 128, "BLOCK_M == 128"); // pass so BLOCK_M=128 - // static_assert(QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] == 128, "QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] == 128"); - static_assert(QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] == 0, "QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] == 0"); - // passed so windows length[0] = 0 why? - static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && From 98f15eedcf5a4639c562f21b60747ea49db63641 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 5 Nov 2025 11:44:10 +0000 Subject: [PATCH 49/88] Added block q --- .../01_unified_attention/unified_attention_impl.hpp | 4 ++-- .../kernel/unified_attention_kernel.hpp | 11 +++++++---- .../pipeline/tile_unified_attention_shape.hpp | 5 +++-- .../pipeline/unified_attention_pipeline.hpp | 2 ++ 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 8d321acfe5f..5e561d1cf04 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,8 +58,8 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; - // BLOCK_M BLOCK_SIZE HEAD_SIZE - using unified_attention_block_tile = sequence<128, 128, 128>; + // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE + using unified_attention_block_tile = sequence<128, 32, 128, 128>; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; using unified_attention_block_warps = sequence<8, 1, 1>; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index f436cd8a2d4..83261d41444 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -45,6 +45,7 @@ struct UnifiedAttentionKernel // BLOCK_Q is the block size for q seqlen /// static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; static constexpr index_t BLOCK_M = UnifiedAttentionPipeline::BLOCK_M; + static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; // BLOCK size for K seqlen static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE; @@ -281,7 +282,7 @@ struct UnifiedAttentionKernel const index_t num_queries_per_kv = kargs.num_queries_per_kv; - const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; + // const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; // for simplicity, batch stride we just modify the pointer // const index_t num_head_q = kargs.num_head_q; @@ -357,12 +358,12 @@ struct UnifiedAttentionKernel make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, - number<1>{}); + number<2>{}); const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_base, // block sizes - make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), + make_tuple(number{}, number<1>{}, number{}), sequence{} ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) @@ -372,13 +373,15 @@ struct UnifiedAttentionKernel make_merge_transform( make_tuple(query_len_padded, num_queries_per_kv) ), - make_pass_through_transform(HEAD_SIZE_PADDED) + make_pass_through_transform(number{}) ), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}) ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim + return q_dram_merged; }(); + // static_assert(q_dram.desc_[number<0>{}] == 0, "q_dram.get_bottom_tensor_view()[number<0>{}] == 0"); // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index 8b453723ac4..790b0614a67 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -48,9 +48,10 @@ struct TileUnifiedAttentionShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) + static constexpr index_t BLOCK_Q = BlockTile::at(number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head) - static constexpr index_t BLOCK_SIZE = BlockTile::at(number<1>{}); // BLOCK size for K seqlen - static constexpr index_t HEAD_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen + static constexpr index_t BLOCK_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen + static constexpr index_t HEAD_SIZE = BlockTile::at(number<3>{}); // BLOCK size for K seqlen // static constexpr index_t kQKHeaddim = // BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 48f1f4deb8c..0afce5c287a 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -269,6 +269,7 @@ struct UnifiedAttentionPipeline static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M; + static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; @@ -443,6 +444,7 @@ struct UnifiedAttentionPipeline auto q_dram_window = make_tile_window_linear( q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution()); + // auto q_dram_window = q_dram_block_window_tmp; // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; From 1d3304ab9ed41d3ef5e656a95b6d506bd796dc9d Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 7 Nov 2025 12:31:31 +0000 Subject: [PATCH 50/88] revert change on fmha --- example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index 457139fbfff..194675f9627 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -57,8 +57,9 @@ struct fmha_fwd_v3_kernel_traits { static constexpr auto date_type = DataType; static constexpr bool is_variable_seqlen = IsVariableSeqlen; - static constexpr bool is_masking = IsMasking - // M0 N0 K0 N1 K1 + static constexpr bool is_masking = IsMasking; + + // M0 N0 K0 N1 K1 using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; using fmha_warp_gemm_shape = sequence<32, 32, 16>; using fmha_block_warps = sequence<8, 1, 1>; From 47c9d0a131f24ef17770b928516ef86437586e11 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 10 Nov 2025 11:34:52 +0000 Subject: [PATCH 51/88] More compilation fix --- .../unified_attention_impl.hpp | 2 +- .../kernel/unified_attention_kernel.hpp | 24 ++++++++---- .../pipeline/unified_attention_pipeline.hpp | 38 +++++++++++-------- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 5e561d1cf04..64aead84f5b 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -59,7 +59,7 @@ struct unified_attention_kernel_traits static constexpr bool is_masking = IsMasking; // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE - using unified_attention_block_tile = sequence<128, 32, 128, 128>; + using unified_attention_block_tile = sequence<256, 64, 128, 128>; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; using unified_attention_block_warps = sequence<8, 1, 1>; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 83261d41444..ee4eeab9204 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -319,16 +319,24 @@ struct UnifiedAttentionKernel const index_t query_pos = q_block_local_idx * BLOCK_Q; const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - // const index_t context_len = seq_len - cur_batch_query_len; + const index_t context_len = seq_len - cur_batch_query_len; - // const index_t max_seq_prefix_len = ( - // context_len - // + q_block_local_idx * BLOCK_Q - // + (BLOCK_M - 1) // num_queries_per_kv - // + 1 - // ); + index_t _max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ); + + if (seq_len < _max_seq_prefix_len) { + _max_seq_prefix_len = seq_len; + } + const auto max_seq_prefix_len = _max_seq_prefix_len; + const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + // TODO sliding window + const index_t num_blocks_start = 0; index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2; // Q/K/V DRAM and DRAM window @@ -473,6 +481,8 @@ struct UnifiedAttentionKernel return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, v_dram_window, + num_blocks, + num_blocks_start, kargs.block_tables_ptr, block_table_offset, mask, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 0afce5c287a..b2541ab74e8 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -277,8 +277,7 @@ struct UnifiedAttentionPipeline static_assert(HEAD_SIZE_PADDED <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + // static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim; static constexpr bool kPadHeadDimV = Problem::kPadHeadDim; // static constexpr bool kStoreLSE = Problem::kStoreLSE; @@ -385,6 +384,8 @@ struct UnifiedAttentionPipeline [[maybe_unused]] const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile [[maybe_unused]] const VElementFunction& v_element_func, + const index_t num_blocks, + const index_t num_blocks_start, const void* block_tables_ptr, index_t block_table_offset, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, @@ -540,16 +541,18 @@ struct UnifiedAttentionPipeline clear_tile(l); const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + // const auto [seqlen_k_start, seqlen_k_end] = + // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE); - index_t kv_token_start = seqlen_k_start; + // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE); + const auto num_total_loop = num_blocks; + // index_t kv_token_start = seqlen_k_start; + // TODO check is paddings kPadSeqLenK // check early exit if no work to do - if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + if constexpr(FmhaMask::IsMasking) { - if(num_total_loop <= 0) + if(num_total_loop - num_blocks_start <= 0) { @@ -559,7 +562,7 @@ struct UnifiedAttentionPipeline } } - index_t i_total_loops = 0; + index_t i_total_loops = num_blocks_start; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; index_t kv_blk_idx_prev = 0; @@ -889,10 +892,10 @@ struct UnifiedAttentionPipeline }; auto fmha_mask = [&](auto sp_reg_idx) { - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + if constexpr(FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile( - q_origin.at(number<0>{}), kv_token_start, number{}, number{}); + q_origin.at(number<0>{}), i_total_loops * BLOCK_SIZE, number{}, number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, @@ -900,7 +903,7 @@ struct UnifiedAttentionPipeline [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = kv_token_start + tile_idx.at(number<1>{}); + const auto col = i_total_loops * BLOCK_SIZE + tile_idx.at(number<1>{}); return mask.IsOutOfBound(row, col); }); } @@ -975,6 +978,7 @@ struct UnifiedAttentionPipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + // TODO what is this??? Scheduler::schedule(cl_p, number<1>{}); fmha_mask(xdl_SP_p01_reg_idx); @@ -1003,7 +1007,7 @@ struct UnifiedAttentionPipeline cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); Scheduler::schedule(cl_p, number<3>{}); - kv_token_start += BLOCK_SIZE; + // kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1050,7 +1054,7 @@ struct UnifiedAttentionPipeline Scheduler::schedule(cl_p, number<2>{}); fmha_mask(xdl_SP_p01_reg_idx); - kv_token_start += BLOCK_SIZE; + // kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1128,7 +1132,7 @@ struct UnifiedAttentionPipeline fmha_alu0(number<0>{}); fmha_alu_D_upd(); - kv_token_start += BLOCK_SIZE; + // kv_token_start += BLOCK_SIZE; ++i_total_loops; if(num_total_loop <= i_total_loops) { @@ -1207,6 +1211,8 @@ struct UnifiedAttentionPipeline CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const index_t num_blocks, + const index_t num_blocks_start, const void* block_tables_ptr, index_t block_table_offset, FmhaMask mask, @@ -1221,6 +1227,8 @@ struct UnifiedAttentionPipeline identity{}, v_dram_block_window_tmp, identity{}, + num_blocks, + num_blocks_start, block_tables_ptr, block_table_offset, identity{}, From 618ed6defb2d7959ae3dcf3db43be74a241141e6 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 11 Nov 2025 14:35:26 +0000 Subject: [PATCH 52/88] cmake list update --- .../01_unified_attention/CMakeLists.txt | 34 +- .../example_unified_attention.cpp | 243 ++++++----- .../unified_attention.cpp | 12 +- .../unified_attention.hpp | 10 +- .../unified_attention_impl.hpp | 112 ++--- include/ck_tile/ops/unified_attention.hpp | 16 +- .../unified_attention/block/block_masking.hpp | 14 +- .../kernel/unified_attention_kernel.hpp | 402 +++++++++--------- .../pipeline/tile_unified_attention_shape.hpp | 11 +- .../tile_unified_attention_traits.hpp | 5 +- .../pipeline/unified_attention_pipeline.hpp | 66 +-- ...fied_attention_pipeline_default_policy.hpp | 59 +-- .../unified_attention_pipeline_problem.hpp | 18 +- 13 files changed, 505 insertions(+), 497 deletions(-) diff --git a/example/ck_tile/01_unified_attention/CMakeLists.txt b/example/ck_tile/01_unified_attention/CMakeLists.txt index 11c413d192d..45f67f3e0d6 100644 --- a/example/ck_tile/01_unified_attention/CMakeLists.txt +++ b/example/ck_tile/01_unified_attention/CMakeLists.txt @@ -187,36 +187,42 @@ if(NOT INST_TARGETS) return() endif() -set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention") -message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") +set(EXAMPLE_UNIFIED_ATTENTION "tile_example_unified_attention") +message(DEBUG "adding example ${EXAMPLE_UNIFIED_ATTENTION}") -add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_unified_attention.cpp) -target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS +add_executable(${EXAMPLE_UNIFIED_ATTENTION} EXCLUDE_FROM_ALL example_unified_attention.cpp) +target_include_directories(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +file(GLOB UNIFIED_ATTENTION_INSTANCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" ) -target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE +target_sources(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE unified_attention.cpp - ${FMHA_FWD_V3_INSTANCES} + ${UNIFIED_ATTENTION_INSTANCES} ) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) -list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS +set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS) +list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero -Wno-undefined-func-template --save-temps ) -set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) +set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS) check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) if(HAS_DISABLE_PACKED_FP32) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS -mllvm --amdgpu-disable-packed-fp32=1 ) - list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS + list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS -DCK_TILE_DISABLE_PACKED_FP32=1 ) endif() -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) -target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) +target_compile_options(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS}) +target_compile_definitions(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 50ac6ea94cc..b103043b708 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -93,20 +93,20 @@ struct Problem { explicit Problem(const ck_tile::ArgParser& args) { - data_type = args.get_str("prec") == "fp16" - ? ck_tile::unified_attention_args::data_type_enum::fp16 - : ck_tile::unified_attention_args::data_type_enum::bf16; - batch = args.get_int("b"); - max_seqlen_q = args.get_int("s"); + data_type = args.get_str("prec") == "fp16" + ? ck_tile::unified_attention_args::data_type_enum::fp16 + : ck_tile::unified_attention_args::data_type_enum::bf16; + batch = args.get_int("b"); + max_seqlen_q = args.get_int("s"); max_context_len = args.get_int("s_k"); - num_blks = args.get_int("nb"); - BLOCK_SIZE = args.get_int("bs"); - nhead_q = args.get_int("h"); - nhead_kv = args.get_int("h_k"); + num_blks = args.get_int("nb"); + BLOCK_SIZE = args.get_int("bs"); + nhead_q = args.get_int("h"); + nhead_kv = args.get_int("h_k"); - hdim = args.get_int("d"); + hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); - kv_lens = args.get_int_vec("kv_lens"); + kv_lens = args.get_int_vec("kv_lens"); // Calculate scale_s scale_s = args.get_float("scale_s"); @@ -114,14 +114,15 @@ struct Problem scale_s = 1.0f / ck_tile::sqrt(static_cast(hdim)); // Initialize other scales - scale = args.get_float("scale"); + scale = args.get_float("scale"); scale_k = args.get_float("scale_k"); scale_v = args.get_float("scale_v"); // Calculate sums of query_lens and kv_lens if provided // int64_t kv_lens_sum = 0; - for (const auto& len : query_lens) { + for(const auto& len : query_lens) + { num_tokens += len; } @@ -130,10 +131,7 @@ struct Problem // } } - std::vector get_query_shape() const - { - return {num_tokens, nhead_q, hdim}; - } + std::vector get_query_shape() const { return {num_tokens, nhead_q, hdim}; } std::vector get_key_shape() const { @@ -145,11 +143,7 @@ struct Problem return {num_blks, BLOCK_SIZE, nhead_kv, hdim}; } - std::vector get_output_shape() const - { - return {num_tokens, nhead_q, hdim}; - - } + std::vector get_output_shape() const { return {num_tokens, nhead_q, hdim}; } ck_tile::unified_attention_args::data_type_enum data_type; ck_tile::index_t batch; @@ -209,7 +203,6 @@ auto generate_qkv(const Problem& problem, return std::make_tuple(q, k, v); } - // namespace host { // template q_host_ref({nhead_q, seqlen_q, hdim_qk}); - // ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); - // ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); - // ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - - // ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); - // ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - - // // do computation for each batch - // for(int b = 0; b < batch_size; ++b) - // { - // // copy per-batch data from input tensors - // // clang-format off - // q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); - // k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); - // v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); - // // clang-format on - // ck_tile::reference_batched_gemm( - // q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - - // if(mask.type == mask_enum::no_mask) - // { - // ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - // } - // else if(mask.type == mask_enum::window_generic) - // { - // ck_tile::reference_batched_masking( - // s_host_ref, - // ck_tile::make_generic_attention_mask_from_lr_window( - // mask.left, mask.right, seqlen_q, seqlen_kv)); - // } - // else - // { - // // if left window size is negative, means causal - // // else means generic (for current batch) - // if(mask.left < 0) - // ck_tile::reference_batched_masking( - // s_host_ref, - // ck_tile::make_generic_attention_mask_from_lr_window( - // mask.left, - // mask.right, - // seqlen_q, - // seqlen_kv, - // mask.type == mask_enum::mask_top_left)); - // else - // ck_tile::reference_batched_masking( - // s_host_ref, - // ck_tile::make_generic_attention_mask_from_lr_window( - // mask.left, - // mask.right, - // seqlen_q, - // seqlen_kv, - // mask.type == mask_enum::mask_top_left)); - // } - - // ck_tile::reference_batched_softmax( - // s_host_ref, p_host_ref, ck_tile::identity{}); - - // ck_tile::reference_batched_gemm( - // p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - - // // copy resulting per-batch data to the output tensor - // o_host_ref.ForEach( - // [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); - // } +// const int batch_size = q_bshd.mDesc.get_lengths()[0]; +// const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; +// const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; +// const int nhead_q = q_bshd.mDesc.get_lengths()[2]; +// const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; +// const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; +// const int hdim_v = v_bshd.mDesc.get_lengths()[3]; + +// const int nr = nhead_q / nhead_kv; + +// ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); +// ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); +// ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); +// ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); + +// ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); +// ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); + +// // do computation for each batch +// for(int b = 0; b < batch_size; ++b) +// { +// // copy per-batch data from input tensors +// // clang-format off +// q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , +// idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], +// idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = +// v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); +// // clang-format on +// ck_tile::reference_batched_gemm( +// q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); + +// if(mask.type == mask_enum::no_mask) +// { +// ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); +// } +// else if(mask.type == mask_enum::window_generic) +// { +// ck_tile::reference_batched_masking( +// s_host_ref, +// ck_tile::make_generic_attention_mask_from_lr_window( +// mask.left, mask.right, seqlen_q, seqlen_kv)); +// } +// else +// { +// // if left window size is negative, means causal +// // else means generic (for current batch) +// if(mask.left < 0) +// ck_tile::reference_batched_masking( +// s_host_ref, +// ck_tile::make_generic_attention_mask_from_lr_window( +// mask.left, +// mask.right, +// seqlen_q, +// seqlen_kv, +// mask.type == mask_enum::mask_top_left)); +// else +// ck_tile::reference_batched_masking( +// s_host_ref, +// ck_tile::make_generic_attention_mask_from_lr_window( +// mask.left, +// mask.right, +// seqlen_q, +// seqlen_kv, +// mask.type == mask_enum::mask_top_left)); +// } + +// ck_tile::reference_batched_softmax( +// s_host_ref, p_host_ref, ck_tile::identity{}); + +// ck_tile::reference_batched_gemm( +// p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); + +// // copy resulting per-batch data to the output tensor +// o_host_ref.ForEach( +// [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); +// } // } // } // namespace host @@ -328,20 +322,20 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::unified_attention_args args{}; - args.data_type = problem.data_type; - args.num_seqs = problem.batch; + args.data_type = problem.data_type; + args.num_seqs = problem.batch; // args.seqlen_q = problem.seqlen_q; // args.seqlen_k = problem.seqlen_k; - args.num_head_q = problem.nhead_q; - args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv; - args.mask_type = 2; - args.hdim = problem.hdim; + args.num_head_q = problem.nhead_q; + args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv; + args.mask_type = 2; + args.hdim = problem.hdim; args.num_blks = problem.num_blks; // args.query_lens = problem.query_lens // args.kv_lens = problem.kv_lens - args.q_ptr = q_buf.GetDeviceBuffer(); + args.q_ptr = q_buf.GetDeviceBuffer(); args.query_stride_0 = problem.hdim * problem.nhead_q; args.query_stride_0 = problem.hdim; @@ -352,13 +346,13 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.stride_k_cache_2 = problem.hdim; args.stride_k_cache_3 = 1; - args.v_ptr = v_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); args.stride_v_cache_0 = args.stride_k_cache_0; args.stride_v_cache_1 = args.stride_k_cache_1; args.stride_v_cache_2 = args.stride_k_cache_2; args.stride_v_cache_3 = args.stride_k_cache_3; - args.o_ptr = o_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); args.output_stride_0 = args.query_stride_0; args.output_stride_1 = args.query_stride_1; @@ -380,13 +374,13 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) return eff; }; - const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024); - const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024); + const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024); + const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024); args.num_tokens = std::accumulate(eff_query_lens.begin(), eff_query_lens.end(), 0); // Calculate cumulative sums for kernel arguments if varlen is used - std::vector cu_query_lens ; + std::vector cu_query_lens; auto calculate_cumulative = [&](const std::vector& per_batch_vec, std::vector& cum_vec) { @@ -403,14 +397,16 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) seq_lens_buf.ToDevice(eff_kv_lens.data()); query_start_len_buf.ToDevice(cu_query_lens.data()); - args.seq_lens_ptr =reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); - args.query_start_len_ptr =reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); - + args.seq_lens_ptr = reinterpret_cast(seq_lens_buf.GetDeviceBuffer()); + args.query_start_len_ptr = + reinterpret_cast(query_start_len_buf.GetDeviceBuffer()); auto max_element = [&](const std::vector& opt_vec) { ck_tile::index_t max = opt_vec[0]; - for (ck_tile::index_t i: opt_vec) { - if (i > max){ + for(ck_tile::index_t i : opt_vec) + { + if(i > max) + { max = i; } } @@ -419,10 +415,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::index_t max_kv_len = max_element(eff_kv_lens); - ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE; + ck_tile::index_t max_num_blocks_per_seq = + (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE; // Create block_tables - ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * + sizeof(ck_tile::index_t)); // Allocate host memory for block_tables std::vector block_tables_host(problem.batch * max_num_blocks_per_seq); @@ -430,7 +428,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // Fill block_tables with random integers between 0 and num_blocks-1 std::mt19937 rng(run_config.seed ? *run_config.seed : std::random_device{}()); std::uniform_int_distribution dist(0, problem.num_blks - 1); - for (size_t i = 0; i < block_tables_host.size(); ++i) { + for(size_t i = 0; i < block_tables_host.size(); ++i) + { block_tables_host[i] = dist(rng); } @@ -438,10 +437,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) block_tables_buf.ToDevice(block_tables_host.data()); // Set pointer in args - args.block_tables_ptr = reinterpret_cast(block_tables_buf.GetDeviceBuffer()); + args.block_tables_ptr = + reinterpret_cast(block_tables_buf.GetDeviceBuffer()); args.block_table_stride = max_num_blocks_per_seq; - ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -476,7 +475,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv // << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim // << ", scale_s:" << problem.sacle_s << ", mask:" << problem.mask << std::fixed - // << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops + // << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << + // tflops // << " TFlops" << std::endl; // if(!run_config.verify) @@ -548,7 +548,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // } // } // } - // ck_tile::HostTensor o(problem.get_output_shape()); // o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_unified_attention/unified_attention.cpp b/example/ck_tile/01_unified_attention/unified_attention.cpp index 8c2b22f0a29..fb3e37e1e06 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/unified_attention.cpp @@ -7,7 +7,8 @@ namespace ck_tile { -std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type) +std::ostream& operator<<(std::ostream& stream, + const unified_attention_args::data_type_enum& data_type) { switch(data_type) { @@ -17,14 +18,16 @@ std::ostream& operator<<(std::ostream& stream, const unified_attention_args::dat } } -std::pair unified_attention(const unified_attention_args& args, const stream_config& config) +std::pair unified_attention(const unified_attention_args& args, + const stream_config& config) { if(args.data_type == unified_attention_args::data_type_enum::fp16) { if(args.mask_type == static_cast(mask_enum::no_mask)) { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; return unified_attention_kernel_dispatch(args, config); } @@ -41,7 +44,8 @@ std::pair unified_attention(const unified_attention_args& args, con if(args.mask_type == static_cast(mask_enum::no_mask)) { using kernel_traits = - unified_attention_kernel_traits; + unified_attention_kernel_traits; return unified_attention_kernel_dispatch(args, config); } diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index 50462d31102..083a3acd85e 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -52,24 +52,26 @@ struct unified_attention_args index_t stride_v_cache_1; index_t stride_v_cache_2; index_t stride_v_cache_3; - + void* o_ptr; index_t output_stride_0; index_t output_stride_1; const int32_t* block_tables_ptr; index_t block_table_stride; - const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* seq_lens_ptr; // seq len in each batch const int32_t* query_start_len_ptr; // [num_seqs+1] index_t num_seqs; // number of batches for q }; -std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type); +std::ostream& operator<<(std::ostream& stream, + const unified_attention_args::data_type_enum& data_type); // return value: // first = whether the kernel was launched (true = launched, false = skipped) // second = elapsed time (ms) of the kernel launch, valid only if first == true -std::pair unified_attention(const unified_attention_args& args, const stream_config& config); +std::pair unified_attention(const unified_attention_args& args, + const stream_config& config); } // namespace ck_tile diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 64aead84f5b..dc3104e4f23 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -20,13 +20,13 @@ #include "unified_attention.hpp" #include "mask.hpp" -#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \ - template <> \ - std::pair unified_attention_kernel_dispatch( \ - const unified_attention_args& args, const stream_config& config) \ - { \ - return std::make_pair(true, \ - unified_attention_kernel_launch(args, config)); \ +#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \ + template <> \ + std::pair unified_attention_kernel_dispatch( \ + const unified_attention_args& args, const stream_config& config) \ + { \ + return std::make_pair( \ + true, unified_attention_kernel_launch(args, config)); \ } namespace ck_tile { @@ -55,8 +55,8 @@ struct unified_attention_problem_traits struct unified_attention_kernel_traits { - static constexpr auto date_type = DataType; - static constexpr bool is_masking = IsMasking; + static constexpr auto date_type = DataType; + static constexpr bool is_masking = IsMasking; // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence<256, 64, 128, 128>; @@ -64,34 +64,34 @@ struct unified_attention_kernel_traits using unified_attention_block_warps = sequence<8, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape; + unified_attention_block_warps, + unified_attention_warp_gemm_shape, + unified_attention_block_warps, + unified_attention_warp_gemm_shape, + true // IsVLayoutRowMajor + >; using unified_attention_traits = TileUnifiedAttentionTraits; + false, // kPadHeadDimQ + -1 // kBlockPerCu + >; using unified_attention_mask = GenericAttentionMask; - using unified_attention_pipeline_problem = - UnifiedAttentionPipelineProblem::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::lse_dtype, - typename unified_attention_problem_traits::qkvp_dtype, - typename unified_attention_problem_traits::acc_dtype, - typename unified_attention_problem_traits::o_dtype, - unified_attention_shape, - unified_attention_mask, - unified_attention_traits>; + using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem< + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::lse_dtype, + typename unified_attention_problem_traits::qkvp_dtype, + typename unified_attention_problem_traits::acc_dtype, + typename unified_attention_problem_traits::o_dtype, + unified_attention_shape, + unified_attention_mask, + unified_attention_traits>; using unified_attention_pipeline = UnifiedAttentionPipeline; @@ -107,11 +107,12 @@ struct unified_attention_kernel_traits }; template -float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) +float unified_attention_kernel_launch(const unified_attention_args& args, + const stream_config& config) { - + index_t BLOCK_Q = Kernel::BLOCK_M / args.num_queries_per_kv; - + index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, @@ -128,26 +129,25 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.scale_out, total_num_q_blocks, args.query_stride_0, - args.query_stride_1, - args.stride_k_cache_0, - args.stride_k_cache_1, - args.stride_k_cache_2, - args.stride_k_cache_3, - args.stride_v_cache_0, - args.stride_v_cache_1, - args.stride_v_cache_2, - args.stride_v_cache_3, - args.output_stride_0, - args.output_stride_1, - args.block_tables_ptr, - args.block_table_stride, - args.seq_lens_ptr, - args.query_start_len_ptr, - args.num_seqs - ); - - dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); - constexpr dim3 blocks = Kernel::BlockSize(); + args.query_stride_1, + args.stride_k_cache_0, + args.stride_k_cache_1, + args.stride_k_cache_2, + args.stride_k_cache_3, + args.stride_v_cache_0, + args.stride_v_cache_1, + args.stride_v_cache_2, + args.stride_v_cache_3, + args.output_stride_0, + args.output_stride_1, + args.block_tables_ptr, + args.block_table_stride, + args.seq_lens_ptr, + args.query_start_len_ptr, + args.num_seqs); + + dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); + constexpr dim3 blocks = Kernel::BlockSize(); constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); @@ -158,6 +158,6 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const // second = elapsed time (ms) of the kernel launch, valid only if first == true template std::pair unified_attention_kernel_dispatch(const unified_attention_args& args, - const stream_config& config); + const stream_config& config); } // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention.hpp b/include/ck_tile/ops/unified_attention.hpp index 62e6c58acb7..20eee5a819e 100644 --- a/include/ck_tile/ops/unified_attention.hpp +++ b/include/ck_tile/ops/unified_attention.hpp @@ -3,12 +3,6 @@ #pragma once - -#include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/common/utils.hpp" - -// Block-level components #include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/unified_attention/block/block_dropout.hpp" #include "ck_tile/ops/unified_attention/block/block_masking.hpp" @@ -16,15 +10,15 @@ #include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" #include "ck_tile/ops/unified_attention/block/page_block_navigator.hpp" #include "ck_tile/ops/unified_attention/block/variants.hpp" - -// Kernel-level components #include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" - -// Pipeline-level components #include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" #include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" - +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/unified_attention/block/block_masking.hpp b/include/ck_tile/ops/unified_attention/block/block_masking.hpp index 87868a56a18..33ca84d2c5c 100644 --- a/include/ck_tile/ops/unified_attention/block/block_masking.hpp +++ b/include/ck_tile/ops/unified_attention/block/block_masking.hpp @@ -86,19 +86,22 @@ struct GenericAttentionMask static constexpr const char* name = impl::MaskName::name; // New constructor accepting repeat_idx with default value 1 - CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) + CK_TILE_HOST_DEVICE + GenericAttentionMask(index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) : GenericAttentionMask(0, 0, y_total_, x_total_, repeat_idx_) { } CK_TILE_HOST_DEVICE - GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) + GenericAttentionMask( + index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t repeat_idx_ = 1) : y(y_), x(x_), y_total(y_total_), x_total(x_total_), repeat_idx(repeat_idx_) { } template - CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, index_t repeat_idx_ = 1) + CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord, + index_t repeat_idx_ = 1) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), y_total(mask_coord.at(number<2>{})), @@ -248,13 +251,12 @@ struct GenericAttentionMask } } -private: + private: index_t y, x; index_t y_total, x_total; index_t repeat_idx; }; - // TODO: prefer use this function in host code // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask @@ -289,7 +291,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size, index_t y_total, index_t x_total, index_t repeat_idx = 1, - bool is_top_left = true) + bool is_top_left = true) { auto r = make_generic_attention_mask_coordinates_from_lr_window( left_size, right_size, y_total, x_total, is_top_left); diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ee4eeab9204..ac7a06a961d 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -18,8 +18,8 @@ namespace ck_tile { template struct UnifiedAttentionKernel { - using UnifiedAttentionPipeline = ck_tile::remove_cvref_t; - using EpiloguePipeline = ck_tile::remove_cvref_t; + using UnifiedAttentionPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t kBlockSize = UnifiedAttentionPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockPerCu = UnifiedAttentionPipeline::kBlockPerCu; static_assert(kBlockPerCu > 0); @@ -29,18 +29,18 @@ struct UnifiedAttentionKernel using VDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; - using FmhaMask = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; - + static constexpr bool kPadSeqLenK = UnifiedAttentionPipeline::kPadSeqLenK; static constexpr bool kPadSeqLenQ = UnifiedAttentionPipeline::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV; - // TODO add yjese - static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE; + // TODO add yjese + static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE; static constexpr index_t HEAD_SIZE_PADDED = UnifiedAttentionPipeline::HEAD_SIZE_PADDED; - + // BLOCK_Q = BLOCK_M // num_queries_per_kv // BLOCK_Q is the block size for q seqlen /// static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; @@ -49,7 +49,6 @@ struct UnifiedAttentionKernel // BLOCK size for K seqlen static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE; - // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. @@ -88,12 +87,11 @@ struct UnifiedAttentionKernel ck_tile::index_t output_stride_1; }; - - struct UnifiedAttentionVarlenKargs: UnifiedAttentionCommonKargs + struct UnifiedAttentionVarlenKargs : UnifiedAttentionCommonKargs { const int32_t* block_tables_ptr; ck_tile::index_t block_table_stride; - const int32_t* seq_lens_ptr; // seq len in each batch + const int32_t* seq_lens_ptr; // seq len in each batch const int32_t* query_start_len_ptr; // [num_seqs+1] ck_tile::index_t num_seqs; // number of batches for q @@ -101,38 +99,36 @@ struct UnifiedAttentionKernel using Kargs = UnifiedAttentionVarlenKargs; - CK_TILE_HOST static constexpr Kargs MakeKargs( - const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* o_ptr, - ck_tile::index_t num_blks, - ck_tile::index_t num_head_q, - const ck_tile::index_t num_queries_per_kv, - float scale_s, - float scale, - float scale_k, - float scale_v, - float scale_out, - ck_tile::index_t total_num_q_blocks, - ck_tile::index_t query_stride_0, - ck_tile::index_t query_stride_1, - ck_tile::index_t stride_k_cache_0, - ck_tile::index_t stride_k_cache_1, - ck_tile::index_t stride_k_cache_2, - ck_tile::index_t stride_k_cache_3, - ck_tile::index_t stride_v_cache_0, - ck_tile::index_t stride_v_cache_1, - ck_tile::index_t stride_v_cache_2, - ck_tile::index_t stride_v_cache_3, - ck_tile::index_t output_stride_0, - ck_tile::index_t output_stride_1, - const int32_t* block_tables_ptr, - ck_tile::index_t block_table_stride, - const int32_t* seq_lens_ptr, - const int32_t* query_start_len_ptr, - ck_tile::index_t num_seqs - ) + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* o_ptr, + ck_tile::index_t num_blks, + ck_tile::index_t num_head_q, + const ck_tile::index_t num_queries_per_kv, + float scale_s, + float scale, + float scale_k, + float scale_v, + float scale_out, + ck_tile::index_t total_num_q_blocks, + ck_tile::index_t query_stride_0, + ck_tile::index_t query_stride_1, + ck_tile::index_t stride_k_cache_0, + ck_tile::index_t stride_k_cache_1, + ck_tile::index_t stride_k_cache_2, + ck_tile::index_t stride_k_cache_3, + ck_tile::index_t stride_v_cache_0, + ck_tile::index_t stride_v_cache_1, + ck_tile::index_t stride_v_cache_2, + ck_tile::index_t stride_v_cache_3, + ck_tile::index_t output_stride_0, + ck_tile::index_t output_stride_1, + const int32_t* block_tables_ptr, + ck_tile::index_t block_table_stride, + const int32_t* seq_lens_ptr, + const int32_t* query_start_len_ptr, + ck_tile::index_t num_seqs) { Kargs kargs{{q_ptr, k_ptr, @@ -146,31 +142,30 @@ struct UnifiedAttentionKernel scale_k, scale_v, scale_out, - total_num_q_blocks, - query_stride_0, - query_stride_1, - stride_k_cache_0, - stride_k_cache_1, - stride_k_cache_2, - stride_k_cache_3, - stride_v_cache_0, - stride_v_cache_1, - stride_v_cache_2, - stride_v_cache_3, - output_stride_0, - output_stride_1}, - block_tables_ptr, - block_table_stride, - seq_lens_ptr, - query_start_len_ptr, - num_seqs - }; + total_num_q_blocks, + query_stride_0, + query_stride_1, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + output_stride_0, + output_stride_1}, + block_tables_ptr, + block_table_stride, + seq_lens_ptr, + query_start_len_ptr, + num_seqs}; return kargs; } CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads, - ck_tile::index_t total_num_q_blocks) + ck_tile::index_t total_num_q_blocks) { return dim3(num_kv_heads * total_num_q_blocks, 0, 0); } @@ -190,16 +185,16 @@ struct UnifiedAttentionKernel ck_tile::index_t block_q, bool use_q_block_mode) { - ck_tile::index_t left = 0; + ck_tile::index_t left = 0; ck_tile::index_t right = num_seqs; - - while (left < right) + + while(left < right) { - ck_tile::index_t mid = (left + right) / 2; - ck_tile::index_t val = query_start_len_ptr[mid]; + ck_tile::index_t mid = (left + right) / 2; + ck_tile::index_t val = query_start_len_ptr[mid]; ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val; - - if (mid_val <= target_idx) + + if(mid_val <= target_idx) { left = mid + 1; } @@ -208,32 +203,31 @@ struct UnifiedAttentionKernel right = mid; } } - + return left - 1; } - - CK_TILE_DEVICE static constexpr auto - RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs) + + CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid, + const Kargs& kargs) { using namespace ck_tile; constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.total_num_q_blocks * - (kargs.num_head_q); - + const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q); + // Number of pids per XCD in the new arrangement const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; - + // When GRID_MN cannot divide NUM_XCDS, some xcds will have // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds index_t tall_xcds = GRID_MN % NUM_XCDS; - tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; - + tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; + // Compute current XCD and local pid within the XCD - const index_t xcd = pid % NUM_XCDS; + const index_t xcd = pid % NUM_XCDS; const index_t local_pid = pid / NUM_XCDS; - + // Calculate new pid based on the new grouping index_t remapped_pid = 0; // Initialize to avoid constexpr error if(xcd < tall_xcds) @@ -242,15 +236,15 @@ struct UnifiedAttentionKernel } else { - remapped_pid = tall_xcds * pids_per_xcd + - (xcd - tall_xcds) * (pids_per_xcd - 1) + - local_pid; + remapped_pid = + tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid; } - + return remapped_pid; } - CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) + CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, + const Kargs& kargs) { using namespace ck_tile; @@ -258,8 +252,8 @@ struct UnifiedAttentionKernel // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, // UnifiedAttentionPipeline::kN1); - const index_t i_tile_m = pid % total_num_q_blocks; // Query block index - const index_t i_tile_n = pid / total_num_q_blocks; // Head index + const index_t i_tile_m = pid % total_num_q_blocks; // Query block index + const index_t i_tile_n = pid / total_num_q_blocks; // Head index return ck_tile::make_tuple(i_tile_m, i_tile_n); } @@ -268,7 +262,8 @@ struct UnifiedAttentionKernel CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + return ck_tile::max(UnifiedAttentionPipeline::GetSmemSize(), + EpiloguePipeline::GetSmemSize()); } CK_TILE_DEVICE void operator()(Kargs kargs) const @@ -285,7 +280,7 @@ struct UnifiedAttentionKernel // const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; // for simplicity, batch stride we just modify the pointer // const index_t num_head_q = kargs.num_head_q; - + // const index_t num_head_k = num_head_q / num_queries_per_kv; pid = RemapTileIndices(pid, kargs); @@ -296,65 +291,76 @@ struct UnifiedAttentionKernel // grid size is (num_kv_heads, total_num_q_blocks) // total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs // q.shape[0] is total number of query tokens across all batches - // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. One query token group shares one kv token + // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. + // One query token group shares one kv token - const index_t seq_idx = find_seq_idx( - kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true - ); // which batch + const index_t seq_idx = find_seq_idx(kargs.query_start_len_ptr, + q_block_global_idx, + kargs.num_seqs, + BLOCK_Q, + true); // which batch - const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = + amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); - const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + const index_t q_block_local_idx = + amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = + amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t cur_batch_in_all_stop_index = + amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); - const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index; + const index_t cur_batch_query_len = + cur_batch_in_all_stop_index - cur_batch_in_all_start_index; // TODO check if we get the block size info from pipeline - if (q_block_local_idx * BLOCK_Q >= cur_batch_query_len) { + if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) + { return; } const index_t query_pos = q_block_local_idx * BLOCK_Q; - const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; + const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; const index_t context_len = seq_len - cur_batch_query_len; - index_t _max_seq_prefix_len = ( - context_len - + q_block_local_idx * BLOCK_Q - + (BLOCK_M - 1) // num_queries_per_kv - + 1 - ); + index_t _max_seq_prefix_len = + (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv + + 1); - if (seq_len < _max_seq_prefix_len) { + if(seq_len < _max_seq_prefix_len) + { _max_seq_prefix_len = seq_len; } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; // TODO sliding window const index_t num_blocks_start = 0; - index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2; - + index_t kv_head_offset = kv_head_idx * kargs.stride_k_cache_2; + // Q/K/V DRAM and DRAM window - index_t q_ptr_offset_0 = cur_batch_in_all_start_index * kargs.query_stride_0; // move the pointer to the batch start - index_t q_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.query_stride_1; // move the pointer to the correct head group start + index_t q_ptr_offset_0 = cur_batch_in_all_start_index * + kargs.query_stride_0; // move the pointer to the batch start + index_t q_ptr_offset_1 = + kv_head_idx * num_queries_per_kv * + kargs.query_stride_1; // move the pointer to the correct head group start index_t q_ptr_offset = q_ptr_offset_0 + q_ptr_offset_1; - index_t o_ptr_offset_0 = cur_batch_in_all_start_index * kargs.output_stride_0; // move the pointer to the batch start - index_t o_ptr_offset_1 = kv_head_idx * num_queries_per_kv * kargs.output_stride_1; // move the pointer to the correct head group start - index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; + index_t o_ptr_offset_0 = cur_batch_in_all_start_index * + kargs.output_stride_0; // move the pointer to the batch start + index_t o_ptr_offset_1 = + kv_head_idx * num_queries_per_kv * + kargs.output_stride_1; // move the pointer to the correct head group start + index_t o_ptr_offset = o_ptr_offset_0 + o_ptr_offset_1; index_t block_table_offset = seq_idx * kargs.block_table_stride; - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + q_ptr_offset; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + kv_head_offset; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); @@ -368,37 +374,35 @@ struct UnifiedAttentionKernel number{}, number<2>{}); - const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED - q_dram_base, - // block sizes - make_tuple(number{}, number<1>{}, number{}), - sequence{} - ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) + const auto q_dram_pad = + pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + q_dram_base, + // block sizes + make_tuple(number{}, number<1>{}, number{}), + sequence{}); // pads to (seq_len_padded, num_head_q, + // HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( - q_dram_pad, - make_tuple( - make_merge_transform( - make_tuple(query_len_padded, num_queries_per_kv) - ), - make_pass_through_transform(number{}) - ), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}) - ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim - + q_dram_pad, + make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, + sequence<1>{})); // flattens the first two dims, head idx is the fastest + // changing dim in the merged dim + return q_dram_merged; }(); - // static_assert(q_dram.desc_[number<0>{}] == 0, "q_dram.get_bottom_tensor_view()[number<0>{}] == 0"); + // static_assert(q_dram.desc_[number<0>{}] == 0, + // "q_dram.get_bottom_tensor_view()[number<0>{}] == 0"); // Q has the shape (k_head, seq_len, num_queries_per_kv, head_dim) // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) - auto q_dram_window = make_tile_window( - q_dram, - make_tuple(number{}, number{}), - {query_pos * num_queries_per_kv, 0} - ); - + auto q_dram_window = + make_tile_window(q_dram, + make_tuple(number{}, number{}), + {query_pos * num_queries_per_kv, 0}); + const auto k_dram = [&]() { // HEAD dim is skipped as defined in the ptrs const auto k_dram_naive = make_naive_tensor_view( @@ -408,24 +412,19 @@ struct UnifiedAttentionKernel number{}, number<1>{}); - const auto k_dram_pad = pad_tensor_view( - k_dram_naive, - // TODO can the BLOCK_SIZE_RAW needs padding? - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); - + const auto k_dram_pad = pad_tensor_view(k_dram_naive, + // TODO can the BLOCK_SIZE_RAW needs padding? + make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); const auto k_dram_merged = transform_tensor_view( - k_dram_pad, - make_tuple( - make_merge_transform( - make_tuple(kargs.num_blks, BLOCK_SIZE) - ), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}) - ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim + k_dram_pad, + make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), + make_pass_through_transform(HEAD_SIZE_PADDED)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, + sequence<1>{})); // flattens the first two dims, head idx is the fastest + // changing dim in the merged dim return k_dram_merged; }(); @@ -441,53 +440,50 @@ struct UnifiedAttentionKernel number{}, number<1>{}); - const auto v_dram_pad = pad_tensor_view( - v_dram_naive, - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); + const auto v_dram_pad = pad_tensor_view(v_dram_naive, + make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); const auto v_dram_merged = transform_tensor_view( - v_dram_pad, - make_tuple( - make_merge_transform( - make_tuple(kargs.num_blks, BLOCK_SIZE) - ), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}) - ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim + v_dram_pad, + make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), + make_pass_through_transform(HEAD_SIZE_PADDED)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, + sequence<1>{})); // flattens the first two dims, head idx is the fastest + // changing dim in the merged dim return v_dram_merged; }(); auto v_dram_window = make_tile_window( v_dram, make_tuple(number{}, number{}), {0, 0}); - + FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( seq_len - cur_batch_query_len, // y (i.e. context) - cur_batch_query_len, // x (i.e. extend) - seq_len, // y_total (x + y) - cur_batch_query_len, // x_total - num_queries_per_kv // the same sequence index is repeated num_queries_per_kv times along x dim of the tile + cur_batch_query_len, // x (i.e. extend) + seq_len, // y_total (x + y) + cur_batch_query_len, // x_total + num_queries_per_kv // the same sequence index is repeated num_queries_per_kv + // times along x dim of the tile ); else return FmhaMask{cur_batch_query_len, seq_len}; }(); - + auto o_acc_tile = [&]() { return UnifiedAttentionPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - num_blocks, - num_blocks_start, - kargs.block_tables_ptr, - block_table_offset, - mask, - kargs.scale_s, - smem_ptr); + k_dram_window, + v_dram_window, + num_blocks, + num_blocks_start, + kargs.block_tables_ptr, + block_table_offset, + mask, + kargs.scale_s, + smem_ptr); }(); // O DRAM and O DRAM window @@ -499,24 +495,20 @@ struct UnifiedAttentionKernel number{}, number<1>{}); - const auto o_dram_pad = pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED - o_dram_base, - // block sizes - make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), - sequence{} - ); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED) + const auto o_dram_pad = + pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + o_dram_base, + // block sizes + make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), + sequence{}); // pads to (seq_len_padded, num_head_q, + // HEAD_SIZE_PADDED) const auto o_dram_merged = transform_tensor_view( - o_dram_pad, - make_tuple( - make_merge_transform( - make_tuple(query_len_padded, num_queries_per_kv) - ), - make_pass_through_transform(HEAD_SIZE_PADDED) - ), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}) - ); + o_dram_pad, + make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), + make_pass_through_transform(HEAD_SIZE_PADDED)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); return o_dram_merged; }(); diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index 790b0614a67..de7762e1219 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -47,11 +47,14 @@ struct TileUnifiedAttentionShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static constexpr index_t BLOCK_M = BlockTile::at(number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) - static constexpr index_t BLOCK_Q = BlockTile::at(number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) - // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * num_queries_per_kv (q_head//kv_head) + static constexpr index_t BLOCK_M = BlockTile::at( + number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) + static constexpr index_t BLOCK_Q = BlockTile::at( + number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) + // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * + // num_queries_per_kv (q_head//kv_head) static constexpr index_t BLOCK_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen - static constexpr index_t HEAD_SIZE = BlockTile::at(number<3>{}); // BLOCK size for K seqlen + static constexpr index_t HEAD_SIZE = BlockTile::at(number<3>{}); // BLOCK size for K seqlen // static constexpr index_t kQKHeaddim = // BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index b27a09a1b41..40ec0fd0aa7 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -9,14 +9,13 @@ namespace ck_tile { - template struct TileUnifiedAttentionTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; - static constexpr bool kPadHeadDim = kPadHeadDim_; + static constexpr bool kPadHeadDim = kPadHeadDim_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; -} +} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b2541ab74e8..486acc42435 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -268,14 +268,15 @@ struct UnifiedAttentionPipeline static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; - static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M; - static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; - - static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; - static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; - static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED; + static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M; + static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; - static_assert(HEAD_SIZE_PADDED <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; + static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; + static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED; + + static_assert(HEAD_SIZE_PADDED <= 256, + "hdim bigger than 256 is not suitable for this pipeline!"); // static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim; @@ -402,12 +403,13 @@ struct UnifiedAttentionPipeline std::is_same_v>, "wrong!"); - static_assert(BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); + static_assert( + BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); auto s_lds = make_tensor_view( @@ -542,9 +544,11 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); // const auto [seqlen_k_start, seqlen_k_end] = - // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, + // number{}); - // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, BLOCK_SIZE); + // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, + // BLOCK_SIZE); const auto num_total_loop = num_blocks; // index_t kv_token_start = seqlen_k_start; @@ -554,7 +558,6 @@ struct UnifiedAttentionPipeline { if(num_total_loop - num_blocks_start <= 0) { - // Note: here occ are all cleard, return it // Note: q loaded but no fence, ignore it. @@ -562,11 +565,11 @@ struct UnifiedAttentionPipeline } } - index_t i_total_loops = num_blocks_start; - const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; - index_t kv_blk_idx_prev = 0; - + index_t i_total_loops = num_blocks_start; + const ck_tile::index_t* block_tables_ptr_ = + reinterpret_cast(block_tables_ptr); + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + index_t kv_blk_idx_prev = 0; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -672,9 +675,10 @@ struct UnifiedAttentionPipeline auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); - // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx as the index + // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx + // as the index /// FIXME: use the future-predicting method to move the window - k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); + k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; auto K_lds_load = [&](auto k_lds_read_idx) { @@ -683,9 +687,9 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; - /// FIXME: use the future-predicting method to move the window - v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); + // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + /// FIXME: use the future-predicting method to move the window + v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; auto V_lds_load = [&](auto v_lds_read_idx) { @@ -894,8 +898,10 @@ struct UnifiedAttentionPipeline auto fmha_mask = [&](auto sp_reg_idx) { if constexpr(FmhaMask::IsMasking) { - bool need_perpixel_check = mask.IsEdgeTile( - q_origin.at(number<0>{}), i_total_loops * BLOCK_SIZE, number{}, number{}); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + i_total_loops * BLOCK_SIZE, + number{}, + number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, @@ -903,7 +909,8 @@ struct UnifiedAttentionPipeline [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = i_total_loops * BLOCK_SIZE + tile_idx.at(number<1>{}); + const auto col = + i_total_loops * BLOCK_SIZE + tile_idx.at(number<1>{}); return mask.IsOutOfBound(row, col); }); } @@ -1180,7 +1187,6 @@ struct UnifiedAttentionPipeline fmha_post_process(number<0>{}); } - // finally, O constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index 32f97aba50d..3d5b46c1762 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -141,11 +141,11 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - constexpr index_t N0 = NumIssues; // 8 + constexpr index_t N0 = NumIssues; // 8 constexpr index_t N1 = LaneGroups; // 2 - constexpr index_t N2 = NumWarps; // 8 - constexpr index_t K0 = LanesPerK; // 32 - constexpr index_t K1 = KVector; // 4 + constexpr index_t N2 = NumWarps; // 8 + constexpr index_t K0 = LanesPerK; // 32 + constexpr index_t K1 = KVector; // 4 return make_static_tile_distribution( tile_distribution_encoding, @@ -259,13 +259,13 @@ struct UnifiedAttentionPipelineDefaultPolicy } }(); - using BlockGemmPolicy = - BlockGemmARegBRegCRegV2CustomPolicy; + using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, + decltype(warp_gemm), + GemmLoopOrder::MNK>; return BlockGemmARegBRegCRegV2{}; } @@ -287,24 +287,25 @@ struct UnifiedAttentionPipelineDefaultPolicy typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>; /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass /// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single - using WarpGemm = WarpGemmDispatcher{}), - Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), - Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), - true, - false, - false, - WGAttrNumAccessEnum::Double>; - - using BlockGemmPolicy = - BlockGemmARegBRegCRegV2CustomPolicy; + using WarpGemm = + WarpGemmDispatcher{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + typename Problem::UnifiedAttentionShape::Gemm1BlockWarps, + WarpGemm, + GemmLoopOrder::MNK>; return BlockGemmARegBRegCRegV2{}; } diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index d21d8316afe..f2caaa23df9 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -19,23 +19,23 @@ template struct UnifiedAttentionPipelineProblem { // TODO kM0 and KN1?? - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; // first gemm accumulation dtype - using SaccDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; // Softmax dtype using SMPLComputeDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; using RandValOutputDataType = remove_cvref_t; // data type for A matrix of second gemm - using PDataType = remove_cvref_t; - // data type for second gemm accumulation + using PDataType = remove_cvref_t; + // data type for second gemm accumulation using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using UnifiedAttentionShape = remove_cvref_t; @@ -48,11 +48,11 @@ struct UnifiedAttentionPipelineProblem // attributes from traits static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadHeadDim = Traits::kPadHeadDim; + static constexpr bool kPadHeadDim = Traits::kPadHeadDim; static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; -} +} // namespace ck_tile From 3b56f0f0c3ff26b41f1cb9d748b0b153f0a45e05 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 12 Nov 2025 13:59:13 +0000 Subject: [PATCH 53/88] fixed block sizes --- .../pipeline/unified_attention_pipeline.hpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 486acc42435..ad43decc99e 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -407,8 +407,8 @@ struct UnifiedAttentionPipeline BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); @@ -428,8 +428,8 @@ struct UnifiedAttentionPipeline auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); - [[maybe_unused]] auto o_lds_window = - make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); + [[maybe_unused]] auto o_lds_window = make_tile_window( + o_lds, make_tuple(number{}, number{}), {0, 0}); auto m_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + @@ -826,11 +826,11 @@ struct UnifiedAttentionPipeline { gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, - sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, + sequence{}), get_slice_tile(kv_tile.v_tile, - sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{})); + sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, + sequence{})); } }; From 07bb33866b1ae113f65e7844d66d49a4fccdf5a8 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 12 Nov 2025 14:00:51 +0000 Subject: [PATCH 54/88] block_shape fixes --- .../pipeline/unified_attention_pipeline.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index ad43decc99e..b8799ac9f17 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -850,11 +850,11 @@ struct UnifiedAttentionPipeline { gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, - sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, + sequence{}), get_slice_tile(kv_tile.v_tile, - sequence<0, (k1_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{})); + sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, + sequence{})); fmha_alu0(number<1>{} - sp_reg_idx); } }; From f4392ddaafe35ab72b8ba3c6548f93378902d434 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 13 Nov 2025 08:14:52 +0000 Subject: [PATCH 55/88] block_shape fixes --- .../unified_attention/pipeline/unified_attention_pipeline.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index b8799ac9f17..a575230ef6b 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -830,7 +830,7 @@ struct UnifiedAttentionPipeline sequence{}), get_slice_tile(kv_tile.v_tile, sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, - sequence{})); + sequence{})); } }; @@ -854,7 +854,7 @@ struct UnifiedAttentionPipeline sequence{}), get_slice_tile(kv_tile.v_tile, sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, - sequence{})); + sequence{})); fmha_alu0(number<1>{} - sp_reg_idx); } }; From db7224e0671dbb3f83e063ec0ab92d5a88f3bf6a Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 13 Nov 2025 14:02:24 +0000 Subject: [PATCH 56/88] Fixed impl --- .../01_unified_attention/unified_attention_impl.hpp | 7 ++++++- .../unified_attention/kernel/unified_attention_kernel.hpp | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index dc3104e4f23..995ea8fdf83 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,8 +58,13 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; + // TODO please fix this + static constexpr index_t BLOCK_M = 256; + static constexpr index_t num_head_per_kv = 4; + static constexpr index_t BLOCK_Q = BLOCK_M / num_head_per_kv; + // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE - using unified_attention_block_tile = sequence<256, 64, 128, 128>; + using unified_attention_block_tile = sequence; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; using unified_attention_block_warps = sequence<8, 1, 1>; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index ac7a06a961d..1cf7698b61b 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -167,7 +167,7 @@ struct UnifiedAttentionKernel CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads, ck_tile::index_t total_num_q_blocks) { - return dim3(num_kv_heads * total_num_q_blocks, 0, 0); + return dim3(num_kv_heads * total_num_q_blocks); } // CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads, From 4a13749f7fe9a76b5f15d2be66c72b57c99daeaf Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 17 Nov 2025 07:33:10 +0000 Subject: [PATCH 57/88] fix to example --- .../01_unified_attention/example_unified_attention.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index b103043b708..6e3708acdd0 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -30,9 +30,9 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair Date: Mon, 17 Nov 2025 07:36:42 +0000 Subject: [PATCH 58/88] add handling for -1 k heads arg --- .../01_unified_attention/example_unified_attention.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 6e3708acdd0..b8c65d7c0a6 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -32,7 +32,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair Date: Mon, 17 Nov 2025 08:27:19 +0000 Subject: [PATCH 59/88] refactor to clearer BLOCK Q logic --- .../example_unified_attention.cpp | 19 ++++++++---------- .../unified_attention_impl.hpp | 20 +++++++++++-------- .../kernel/unified_attention_kernel.hpp | 9 ++------- 3 files changed, 22 insertions(+), 26 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index b8c65d7c0a6..d34d8d9593b 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -30,11 +30,11 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; + using unified_attention_block_tile = sequence; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; + // need to have 8 warps per workgroup to have warp specialization using unified_attention_block_warps = sequence<8, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { - - index_t BLOCK_Q = Kernel::BLOCK_M / args.num_queries_per_kv; - + index_t BLOCK_Q = Kernel::BLOCK_Q; + assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; - auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 1cf7698b61b..2f1b5746553 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -170,13 +170,6 @@ struct UnifiedAttentionKernel return dim3(num_kv_heads * total_num_q_blocks); } - // CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads, - // ck_tile::index_t total_num_q_blocks) - // { - // // TODO: fix 3D grid - // return dim2(num_kv_heads, total_num_q_blocks); - // } - // Binary search to find the sequence index for a given target index CK_TILE_DEVICE static constexpr ck_tile::index_t find_seq_idx(const int32_t* query_start_len_ptr, @@ -277,6 +270,8 @@ struct UnifiedAttentionKernel const index_t num_queries_per_kv = kargs.num_queries_per_kv; + assert(BLOCK_M / num_queries_per_kv == BLOCK_Q); + // const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; // for simplicity, batch stride we just modify the pointer // const index_t num_head_q = kargs.num_head_q; From 5d2a9e5f16d60fa97412a60a7ecd8cd347135b30 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 17 Nov 2025 09:46:31 +0000 Subject: [PATCH 60/88] deving the test... --- .../example_unified_attention.cpp | 389 +++++++++--------- 1 file changed, 199 insertions(+), 190 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index d34d8d9593b..4f58a13c510 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -73,6 +73,13 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + enum class TensorLayout { bhsd, @@ -204,105 +211,105 @@ auto generate_qkv(const Problem& problem, return std::make_tuple(q, k, v); } -// namespace host { -// template -// CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, -// const ck_tile::HostTensor& k_bshd, -// const ck_tile::HostTensor& v_bshd, -// const mask_info& mask, -// ck_tile::HostTensor& o_bshd, -// const QElementOp& q_element_op = {}, -// const KElementOp& k_element_op = {}, -// const VElementOp& v_element_op = {}, -// const SAccElementOp& s_acc_element_op = {}) -// { -// const int batch_size = q_bshd.mDesc.get_lengths()[0]; -// const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; -// const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; -// const int nhead_q = q_bshd.mDesc.get_lengths()[2]; -// const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; -// const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; -// const int hdim_v = v_bshd.mDesc.get_lengths()[3]; - -// const int nr = nhead_q / nhead_kv; - -// ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); -// ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); -// ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); -// ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); - -// ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); -// ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - -// // do computation for each batch -// for(int b = 0; b < batch_size; ++b) -// { -// // copy per-batch data from input tensors -// // clang-format off -// q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , -// idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], -// idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = -// v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); -// // clang-format on -// ck_tile::reference_batched_gemm( -// q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - -// if(mask.type == mask_enum::no_mask) -// { -// ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); -// } -// else if(mask.type == mask_enum::window_generic) -// { -// ck_tile::reference_batched_masking( -// s_host_ref, -// ck_tile::make_generic_attention_mask_from_lr_window( -// mask.left, mask.right, seqlen_q, seqlen_kv)); -// } -// else -// { -// // if left window size is negative, means causal -// // else means generic (for current batch) -// if(mask.left < 0) -// ck_tile::reference_batched_masking( -// s_host_ref, -// ck_tile::make_generic_attention_mask_from_lr_window( -// mask.left, -// mask.right, -// seqlen_q, -// seqlen_kv, -// mask.type == mask_enum::mask_top_left)); -// else -// ck_tile::reference_batched_masking( -// s_host_ref, -// ck_tile::make_generic_attention_mask_from_lr_window( -// mask.left, -// mask.right, -// seqlen_q, -// seqlen_kv, -// mask.type == mask_enum::mask_top_left)); -// } - -// ck_tile::reference_batched_softmax( -// s_host_ref, p_host_ref, ck_tile::identity{}); - -// ck_tile::reference_batched_gemm( -// p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - -// // copy resulting per-batch data to the output tensor -// o_host_ref.ForEach( -// [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); -// } -// } -// } // namespace host +namespace host { +template +CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, + const ck_tile::HostTensor& k_bshd, + const ck_tile::HostTensor& v_bshd, + const mask_info& mask, + ck_tile::HostTensor& o_bshd, + const QElementOp& q_element_op = {}, + const KElementOp& k_element_op = {}, + const VElementOp& v_element_op = {}, + const SAccElementOp& s_acc_element_op = {}) +{ +const int batch_size = q_bshd.mDesc.get_lengths()[0]; +const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; +const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; +const int nhead_q = q_bshd.mDesc.get_lengths()[2]; +const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; +const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; +const int hdim_v = v_bshd.mDesc.get_lengths()[3]; + +const int nr = nhead_q / nhead_kv; + +ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); +ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); +ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); +ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); + +ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); +ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); + +// do computation for each batch +for(int b = 0; b < batch_size; ++b) +{ + // copy per-batch data from input tensors + // clang-format off + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , + idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], + idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = + v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + // clang-format on + ck_tile::reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, seqlen_q, seqlen_kv)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + } + + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, ck_tile::identity{}); + + ck_tile::reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); + + // copy resulting per-batch data to the output tensor + o_host_ref.ForEach( + [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); +} +} +} // namespace host template bool run_impl(const Problem& problem, const RunConfig& run_config) @@ -455,111 +462,113 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) return false; } - // std::size_t flop = [&] { - // if(problem.mask.type == mask_enum::no_mask) - // { - // return 4 * args.num_tokens * problem.nhead_q * - // problem.hdim; - // } - // else - // { - // /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - // return 2 * args.num_tokens * problem.nhead_q * - // problem.hdim; - // } - // }(); + std::size_t flop = [&] { + if(problem.mask.type == mask_enum::no_mask) + { + return 4 * args.num_tokens * problem.nhead_q * + problem.hdim; + } + else + { + /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. + return 2 * args.num_tokens * problem.nhead_q * + problem.hdim; + } + }(); // TODO fix this // std::size_t flop = 1; - // float tflops = static_cast(flop) / 1.e9 / time; + float tflops = static_cast(flop) / 1.e9 / time; - // std::cout << "[" << problem.data_type << "|"; - // std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - // << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim - // << ", scale_s:" << problem.sacle_s << ", mask:" << problem.mask << std::fixed - // << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << - // tflops - // << " TFlops" << std::endl; + std::cout << "[" << problem.data_type << "|"; + std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv + << ", d:" << problem.hdim + << ", mask:" << problem.mask << std::fixed + << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << + tflops + << " TFlops" << std::endl; // if(!run_config.verify) // { // return true; // } - // transpose tensor descriptors from bhsd to bshd if necessary - // if(problem.input_layout != TensorLayout::bshd) - // { - // q = q.transpose({0, 2, 1, 3}); - // k = k.transpose({0, 2, 1, 3}); - // v = v.transpose({0, 2, 1, 3}); - // } - - // ck_tile::HostTensor o_ref(problem.get_output_shape()); - // if(problem.output_layout != TensorLayout::bshd) - // { - // o_ref = o_ref.transpose({0, 2, 1, 3}); - // } - - // If variable lengths are provided, compute per-batch references + // variable lengths are provided -> compute per-batch references // with the effective lengths; else compute a single full reference. // Variable-length aware verification: zero-fill padded region and only compute valid part. - // o_ref.SetZero(); + ck_tile::HostTensor o_ref(problem.get_output_shape()); + o_ref.SetZero(); - // for(int b = 0; b < problem.batch; ++b) - // { - // const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; - // const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; - - // if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) - // continue; - - // // Slice current batch from inputs (bshd) and build single-batch tensors - // ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - // ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - // ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); - // ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); - - // // Copy effective region - // q_b.ForEach([&](auto& self, auto idx) { - // // idx: [0, s, h, d] - // self(idx) = q(b, idx[1], idx[2], idx[3]); - // }); - // k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); - // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); - - // // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) - // host::fmha_fwd(q_b, - // k_b, - // v_b, - // problem.mask, - // o_b, - // ck_tile::identity{}, - // ck_tile::identity{}, - // ck_tile::identity{}, - // ck_tile::scales{problem.scale_s}); - - // // Scatter into o_ref's bshd descriptor memory - // for(int s = 0; s < seqlen_q_eff; ++s) - // { - // for(int h = 0; h < problem.nhead_q; ++h) - // { - // for(int d = 0; d < problem.hdim; ++d) - // { - // o_ref(b, s, h, d) = o_b(0, s, h, d); - // } - // } - // } - // } + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_query_lens[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_lens[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { + // kv cache is paged + ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE); + ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col; + ck_tile::index_t block_idx = block_tables_host[block_table_offset]; + + self(idx) = k(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]); + + }); + v_b.ForEach([&](auto& self, auto idx) { + ck_tile::index_t table_col = int(idx[1] / problem.BLOCK_SIZE); + ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col; + ck_tile::index_t block_idx = block_tables_host[block_table_offset]; + + self(idx) = v(block_idx, idx[1] % problem.BLOCK_SIZE, idx[2], idx[3]); + }); + // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.scale_s}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } + } - // ck_tile::HostTensor o(problem.get_output_shape()); - // o_buf.FromDevice(o.data()); + ck_tile::HostTensor o(problem.get_output_shape()); + o_buf.FromDevice(o.data()); - // const auto [rtol, atol] = [&] { - // if constexpr(std::is_same_v) - // return std::make_tuple(1e-3, 1e-3); - // else - // return std::make_tuple(1e-2, 1e-2); - // }(); - // return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); return true; } From 5e2fd848b9ba83fe643ce89a9ce8f6b750564af8 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 17 Nov 2025 10:04:30 +0000 Subject: [PATCH 61/88] remove unneeded args --- .../example_unified_attention.cpp | 50 ++++--------------- .../unified_attention_impl.hpp | 19 ++++--- 2 files changed, 19 insertions(+), 50 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index d34d8d9593b..ee675cb871f 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -30,15 +30,14 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair std::pair get_query_shape() const { return {num_tokens, nhead_q, hdim}; } @@ -150,8 +124,6 @@ struct Problem ck_tile::index_t batch; ck_tile::index_t num_blks; ck_tile::index_t BLOCK_SIZE; - ck_tile::index_t max_seqlen_q; // sequal seq len, in thd format - ck_tile::index_t max_context_len; ck_tile::index_t nhead_q; ck_tile::index_t nhead_kv; ck_tile::index_t hdim; @@ -334,8 +306,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.num_blks = problem.num_blks; - // args.query_lens = problem.query_lens - // args.kv_lens = problem.kv_lens args.q_ptr = q_buf.GetDeviceBuffer(); args.query_stride_0 = problem.hdim * problem.nhead_q; args.query_stride_1 = problem.hdim; diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 01b91f2644f..191825740fd 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,21 +58,19 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; - - static constexpr index_t BLOCK_M = 256; - static constexpr index_t BLOCK_SIZE = 32; - static constexpr index_t HEAD_SIZE = 128; - - + static constexpr index_t BLOCK_M = 256; + static constexpr index_t BLOCK_SIZE = 32; + static constexpr index_t HEAD_SIZE = 128; + // TODO please fix this to support also other num_qhead_per_kvhead static constexpr index_t num_qhead_per_kvhead = 4; - static constexpr index_t BLOCK_Q = BLOCK_M / num_qhead_per_kvhead; + static constexpr index_t BLOCK_Q = BLOCK_M / num_qhead_per_kvhead; // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; // need to have 8 warps per workgroup to have warp specialization - using unified_attention_block_warps = sequence<8, 1, 1>; + using unified_attention_block_warps = sequence<8, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape Date: Mon, 17 Nov 2025 12:40:31 +0000 Subject: [PATCH 62/88] Assert block_size num_queries_per_kv --- .../example_unified_attention.cpp | 46 +++++++++---------- .../unified_attention.hpp | 1 + .../unified_attention_impl.hpp | 10 ++-- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 0e3b425287b..2195ad03a3c 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -25,6 +25,9 @@ #include "unified_attention.hpp" #include "mask.hpp" +const ck_tile::index_t BLOCK_SIZE = 32; +const ck_tile::index_t num_queries_per_kv = 4; + auto parse_cmd_args(int argc, char* argv[]) -> std::pair { ck_tile::ArgParser arg_parser; @@ -37,7 +40,6 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; @@ -120,6 +120,10 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { index_t BLOCK_Q = Kernel::BLOCK_Q; + assert(args.num_queries_per_kv == Kernel::num_queries_per_kv && + "argument num_queries_per_kv must equal compiled num_queries_per_kv"); + assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && + "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; From ff28bd21bae1f6fa49ae6f2e7c70f8af2f67a395 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 17 Nov 2025 13:55:54 +0000 Subject: [PATCH 63/88] flops and mem calculation --- .../example_unified_attention.cpp | 49 ++++++++++++++----- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 2195ad03a3c..d1eb6f54258 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -31,8 +31,9 @@ const ck_tile::index_t num_queries_per_kv = 4; auto parse_cmd_args(int argc, char* argv[]) -> std::pair { ck_tile::ArgParser arg_parser; - arg_parser.insert("prec", "fp16", "data type. fp16/bf16") - .insert("b", "3", "batch size") + arg_parser + .insert("prec", "fp16", "data type. fp16/bf16") + // .insert("b", "3", "batch size") .insert("h_k", "8", "num head for k/v. num head for q is 4 times this") // .insert("h_k", // "-1", @@ -88,7 +89,6 @@ struct Problem data_type = args.get_str("prec") == "fp16" ? ck_tile::unified_attention_args::data_type_enum::fp16 : ck_tile::unified_attention_args::data_type_enum::bf16; - batch = args.get_int("b"); num_blks = args.get_int("nb"); nhead_kv = args.get_int("h_k"); // TODO: support other GQA/MQA cases than just 4x @@ -97,6 +97,7 @@ struct Problem hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); kv_lens = args.get_int_vec("kv_lens"); + batch = std::max(query_lens.size(), kv_lens.size()); // Calculate scale_s scale_s = args.get_float("scale_s"); @@ -432,25 +433,49 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) } std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * args.num_tokens * problem.nhead_q * problem.hdim; - } - else + long flop_result = 0; + + for(size_t b = 0; b < eff_query_lens.size(); ++b) { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - return 2 * args.num_tokens * problem.nhead_q * problem.hdim; + long query_lens = eff_query_lens[b]; + long kv_lens = eff_kv_lens[b]; + long valid_out_elements = 0; + + // Causal logic for valid output elements + if(query_lens > kv_lens) + { + valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2; + } + else + { + valid_out_elements = + query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2); + } + + flop_result += 2 * problem.nhead_q * valid_out_elements * (problem.hdim + problem.hdim); } + return flop_result; }(); // TODO fix this // std::size_t flop = 1; float tflops = static_cast(flop) / 1.e9 / time; + long mem = 0; + + mem += problem.num_tokens * problem.nhead_q * problem.hdim * 2 * 2; // q and o, fp16 + // Count unique block indices used in block_tables_host + std::unordered_set unique_blocks(block_tables_host.begin(), + block_tables_host.end()); + mem += unique_blocks.size() * BLOCK_SIZE * problem.nhead_kv * problem.hdim * 2 * + 2; // k and v, fp16 + mem += problem.batch * max_num_blocks_per_seq * 4; // int32 block table + mem += problem.batch * 4; // int32 seq_lens_ptr std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv << ", d:" << problem.hdim << ", mask:" << problem.mask << std::fixed << ", " - << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops" << std::endl; + << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops + << " TFlops, " << std::setprecision(2) + << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; // if(!run_config.verify) // { From 8f44fc959344845cbd41be1c1c0fd1526fba1a98 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 17 Nov 2025 14:49:09 +0000 Subject: [PATCH 64/88] mem calculation fixed --- .../01_unified_attention/example_unified_attention.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index d1eb6f54258..58d1d27d68c 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -465,10 +465,9 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // Count unique block indices used in block_tables_host std::unordered_set unique_blocks(block_tables_host.begin(), block_tables_host.end()); - mem += unique_blocks.size() * BLOCK_SIZE * problem.nhead_kv * problem.hdim * 2 * - 2; // k and v, fp16 - mem += problem.batch * max_num_blocks_per_seq * 4; // int32 block table - mem += problem.batch * 4; // int32 seq_lens_ptr + mem += unique_blocks.size() * problem.nhead_kv * problem.hdim * 2 * 2; // k and v, fp16 + mem += problem.batch * max_num_blocks_per_seq * 4; // int32 block table + mem += problem.batch * 4; // int32 seq_lens_ptr std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv From 6ef0b9da8c9f654a1fbc55bd2007f8c3f4f26578 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Tue, 18 Nov 2025 08:57:30 +0000 Subject: [PATCH 65/88] fixing --- .../example_unified_attention.cpp | 98 +++++++++---------- .../kernel/unified_attention_kernel.hpp | 10 +- 2 files changed, 50 insertions(+), 58 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index d1eb6f54258..fdf9c7122a8 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -34,7 +34,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair& q_bshd, const ck_tile::HostTensor& k_bshd, const ck_tile::HostTensor& v_bshd, - const mask_info& mask, + // const mask_info& mask, ck_tile::HostTensor& o_bshd, const QElementOp& q_element_op = {}, const KElementOp& k_element_op = {}, @@ -222,61 +223,34 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); - // do computation for each batch for(int b = 0; b < batch_size; ++b) { // copy per-batch data from input tensors // clang-format off - q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , - idx[2]); }); k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], - idx[0] / nr, idx[2]); }); v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = - v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , + idx[2]); }); + k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], + idx[0] / nr, idx[2]); }); + v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = + v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); // clang-format on ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, seqlen_q, seqlen_kv)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - seqlen_q, - seqlen_kv, - mask.type == mask_enum::mask_top_left)); - } - + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + -1, + 0, + seqlen_q, + seqlen_kv, + true)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); - ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - + // copy resulting per-batch data to the output tensor o_host_ref.ForEach( [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); @@ -528,7 +502,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) host::fmha_fwd(q_b, k_b, v_b, - problem.mask, + // problem.mask, o_b, ck_tile::identity{}, ck_tile::identity{}, @@ -551,13 +525,31 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); - const auto [rtol, atol] = [&] { - if constexpr(std::is_same_v) - return std::make_tuple(1e-3, 1e-3); - else - return std::make_tuple(1e-2, 1e-2); - }(); - return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + // const auto [rtol, atol] = [&] { + // if constexpr(std::is_same_v) + // return std::make_tuple(1e-3, 1e-3); + // else + // return std::make_tuple(1e-2, 1e-2); + // }(); + + // Print some of the output data for debugging + std::cout << "\nFirst few elements of output tensor o:" << std::endl; + for(int b = 0; b < std::min(2, static_cast(problem.batch)); ++b) { + std::cout << "Batch " << b << ":" << std::endl; + for(int s = 0; s < std::min(5, static_cast(eff_query_lens[b])); ++s) { + for(int h = 0; h < std::min(2, static_cast(problem.nhead_q)); ++h) { + for(int d = 0; d < std::min(4, static_cast(problem.hdim)); ++d) { + std::cout << "o[" << b << "][" << s << "][" << h << "][" << d << "] = " + << static_cast(o(b, s, h, d)) + << std::endl; + } + } + } + } + + + + return 1; // ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } int main(int argc, char* argv[]) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 2f1b5746553..1d34ee56701 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -321,7 +321,7 @@ struct UnifiedAttentionKernel const index_t context_len = seq_len - cur_batch_query_len; index_t _max_seq_prefix_len = - (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv + (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + 1); if(seq_len < _max_seq_prefix_len) @@ -457,10 +457,10 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - seq_len - cur_batch_query_len, // y (i.e. context) - cur_batch_query_len, // x (i.e. extend) - seq_len, // y_total (x + y) - cur_batch_query_len, // x_total + -1, + 0, + cur_batch_query_len, // y_total + seq_len, // x_total num_queries_per_kv // the same sequence index is repeated num_queries_per_kv // times along x dim of the tile ); From de995fea7116cceed2d01e195ecc583b079c61cf Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 18 Nov 2025 13:04:58 +0000 Subject: [PATCH 66/88] Various fixes --- .../example_unified_attention.cpp | 53 +++++++++++-------- .../unified_attention_impl.hpp | 2 +- .../kernel/unified_attention_kernel.hpp | 42 ++++++--------- .../pipeline/unified_attention_pipeline.hpp | 1 - 4 files changed, 49 insertions(+), 49 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index fdf9c7122a8..50eac35c3f2 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -343,8 +343,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }; calculate_cumulative(eff_query_lens, cu_query_lens); - ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size()); - ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size()); + ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t)); seq_lens_buf.ToDevice(eff_kv_lens.data()); query_start_len_buf.ToDevice(cu_query_lens.data()); @@ -525,31 +525,40 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); - // const auto [rtol, atol] = [&] { - // if constexpr(std::is_same_v) - // return std::make_tuple(1e-3, 1e-3); - // else - // return std::make_tuple(1e-2, 1e-2); - // }(); + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); - // Print some of the output data for debugging - std::cout << "\nFirst few elements of output tensor o:" << std::endl; - for(int b = 0; b < std::min(2, static_cast(problem.batch)); ++b) { - std::cout << "Batch " << b << ":" << std::endl; - for(int s = 0; s < std::min(5, static_cast(eff_query_lens[b])); ++s) { - for(int h = 0; h < std::min(2, static_cast(problem.nhead_q)); ++h) { - for(int d = 0; d < std::min(4, static_cast(problem.hdim)); ++d) { - std::cout << "o[" << b << "][" << s << "][" << h << "][" << d << "] = " - << static_cast(o(b, s, h, d)) - << std::endl; + size_t total = static_cast(problem.num_tokens) * + static_cast(problem.nhead_q) * + static_cast(problem.hdim); + + size_t nonzero = 0; + + for (int b = 0; b < problem.batch; ++b) { + for (int s = 0; s < eff_query_lens[b]; ++s) { + for (int h = 0; h < problem.nhead_q; ++h) { + for (int d = 0; d < problem.hdim; ++d) { + if (static_cast(o(b, s, h, d)) != 0.0f) { + nonzero++; + } } } } } - - - - return 1; // ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + + float percent = (total > 0) + ? (100.0f * static_cast(nonzero) / static_cast(total)) + : 0.0f; + + std::cout << "\nNon-zero elements in output tensor o: " + << nonzero << " / " << total + << " (" << percent << "%)\n"; + + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } int main(int argc, char* argv[]) diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 8b1536f52a8..855c99f8419 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -124,7 +124,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, "argument num_queries_per_kv must equal compiled num_queries_per_kv"); assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); - assert(BLOCK_Q == args.num_head_q / args.num_queries_per_kv && + assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv && "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 1d34ee56701..969a9aac823 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -184,7 +184,7 @@ struct UnifiedAttentionKernel while(left < right) { ck_tile::index_t mid = (left + right) / 2; - ck_tile::index_t val = query_start_len_ptr[mid]; + ck_tile::index_t val = amd_wave_read_first_lane(query_start_len_ptr[mid]); ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val; if(mid_val <= target_idx) @@ -206,7 +206,7 @@ struct UnifiedAttentionKernel using namespace ck_tile; constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q); + const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q / kargs.num_queries_per_kv); // Number of pids per XCD in the new arrangement const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; @@ -245,10 +245,7 @@ struct UnifiedAttentionKernel // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, // UnifiedAttentionPipeline::kN1); - const index_t i_tile_m = pid % total_num_q_blocks; // Query block index - const index_t i_tile_n = pid / total_num_q_blocks; // Head index - - return ck_tile::make_tuple(i_tile_m, i_tile_n); + return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -277,7 +274,6 @@ struct UnifiedAttentionKernel // const index_t num_head_q = kargs.num_head_q; // const index_t num_head_k = num_head_q / num_queries_per_kv; - pid = RemapTileIndices(pid, kargs); // divide problem @@ -295,19 +291,15 @@ struct UnifiedAttentionKernel BLOCK_Q, true); // which batch - const index_t q_block_start_idx = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); + const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx; - const index_t q_block_local_idx = - amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); - const index_t cur_batch_in_all_start_index = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx]); - const index_t cur_batch_in_all_stop_index = - amd_wave_read_first_lane(kargs.query_start_len_ptr[seq_idx + 1]); + const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx]; + const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; const index_t cur_batch_query_len = - cur_batch_in_all_stop_index - cur_batch_in_all_start_index; + amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); // TODO check if we get the block size info from pipeline if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) @@ -315,14 +307,14 @@ struct UnifiedAttentionKernel return; } - const index_t query_pos = q_block_local_idx * BLOCK_Q; + const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = seq_len - cur_batch_query_len; + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); index_t _max_seq_prefix_len = - (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - + 1); + amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + + 1)); if(seq_len < _max_seq_prefix_len) { @@ -330,7 +322,7 @@ struct UnifiedAttentionKernel } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = (max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window const index_t num_blocks_start = 0; @@ -357,7 +349,7 @@ struct UnifiedAttentionKernel const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q; + index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window @@ -367,20 +359,20 @@ struct UnifiedAttentionKernel make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, - number<2>{}); + number<1>{}); const auto q_dram_pad = pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED q_dram_base, // block sizes - make_tuple(number{}, number<1>{}, number{}), + make_tuple(number{}, 1, HEAD_SIZE_PADDED), sequence{}); // pads to (seq_len_padded, num_head_q, // HEAD_SIZE_PADDED) const auto q_dram_merged = transform_tensor_view( q_dram_pad, make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), - make_pass_through_transform(number{})), + make_pass_through_transform(HEAD_SIZE_PADDED)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); // flattens the first two dims, head idx is the fastest diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index a575230ef6b..3d941f5503d 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" - #define ENABLE_ASM_MARKER 1 #if ENABLE_ASM_MARKER #define ASM_MARKER(marker) \ From f552cd7841d40f06987f994bf609d2ff73ca2904 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 20 Nov 2025 11:34:39 +0000 Subject: [PATCH 67/88] ref data copying --- .../01_unified_attention/example_unified_attention.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 50eac35c3f2..5bc65447468 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -32,7 +32,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::index_t seq_q_off = cu_query_lens[b]; // Copy effective region q_b.ForEach([&](auto& self, auto idx) { // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); + self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]); }); k_b.ForEach([&](auto& self, auto idx) { // kv cache is paged @@ -516,7 +517,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) { for(int d = 0; d < problem.hdim; ++d) { - o_ref(b, s, h, d) = o_b(0, s, h, d); + o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d); } } } From f2fbc44b7bf4dbfb0e8a5f031eb7398e9e218a37 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:20:04 +0000 Subject: [PATCH 68/88] fix --- .../example_unified_attention.cpp | 17 +++++++++-------- .../kernel/unified_attention_kernel.hpp | 18 +++++++++--------- .../pipeline/unified_attention_pipeline.hpp | 9 +++++---- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 5bc65447468..5d8f3fb4351 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -238,14 +238,14 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - -1, - 0, - seqlen_q, - seqlen_kv, - true)); + // ck_tile::reference_batched_masking( + // s_host_ref, + // ck_tile::make_generic_attention_mask_from_lr_window( + // -1, + // 0, + // seqlen_q, + // seqlen_kv, + // true)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -526,6 +526,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); + const auto [rtol, atol] = [&] { if constexpr(std::is_same_v) return std::make_tuple(1e-3, 1e-3); diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 969a9aac823..366a75e2dfe 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -310,18 +310,18 @@ struct UnifiedAttentionKernel const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); + // const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - index_t _max_seq_prefix_len = - amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - + 1)); + // index_t _max_seq_prefix_len = + // amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + // + 1)); - if(seq_len < _max_seq_prefix_len) - { - _max_seq_prefix_len = seq_len; - } + // if(seq_len < _max_seq_prefix_len) + // { + // _max_seq_prefix_len = seq_len; + // } - const auto max_seq_prefix_len = _max_seq_prefix_len; + const auto max_seq_prefix_len = seq_len; // _max_seq_prefix_len; const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3d941f5503d..3bb30149bf4 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -897,10 +897,11 @@ struct UnifiedAttentionPipeline auto fmha_mask = [&](auto sp_reg_idx) { if constexpr(FmhaMask::IsMasking) { - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - i_total_loops * BLOCK_SIZE, - number{}, - number{}); + bool need_perpixel_check = false; + // mask.IsEdgeTile(q_origin.at(number<0>{}), + // i_total_loops * BLOCK_SIZE, + // number{}, + // number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, From 76d1866537c8edd804a43b3cc8ff01cb97abc3a1 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 24 Nov 2025 10:26:26 +0000 Subject: [PATCH 69/88] Pipeline minor fixes --- .../pipeline/unified_attention_pipeline.hpp | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3bb30149bf4..5844285ffeb 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -410,7 +410,7 @@ struct UnifiedAttentionPipeline HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * BLOCK_SIZE <= GetSmemSize()); + static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), MakeSimpleLdsDesc()); @@ -426,7 +426,7 @@ struct UnifiedAttentionPipeline auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto o_lds_window = make_tile_window( o_lds, make_tuple(number{}, number{}), {0, 0}); @@ -542,16 +542,9 @@ struct UnifiedAttentionPipeline clear_tile(l); const auto q_origin = q_dram_window.get_window_origin(); - // const auto [seqlen_k_start, seqlen_k_end] = - // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, - // number{}); - // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, - // BLOCK_SIZE); const auto num_total_loop = num_blocks; - // index_t kv_token_start = seqlen_k_start; - // TODO check is paddings kPadSeqLenK // check early exit if no work to do if constexpr(FmhaMask::IsMasking) { @@ -567,20 +560,19 @@ struct UnifiedAttentionPipeline index_t i_total_loops = num_blocks_start; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; - index_t kv_blk_idx_prev = 0; + index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + i_total_loops]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, + {kv_blk_idx_intial * BLOCK_SIZE, 0}, Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {(kv_blk_idx - kv_blk_idx_prev) * BLOCK_SIZE, 0}, // TODO: hdim split? + {kv_blk_idx_intial * BLOCK_SIZE, 0}, Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); @@ -676,6 +668,7 @@ struct UnifiedAttentionPipeline async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx // as the index + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; /// FIXME: use the future-predicting method to move the window k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -686,7 +679,7 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - // kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; /// FIXME: use the future-predicting method to move the window v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -985,7 +978,6 @@ struct UnifiedAttentionPipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); - // TODO what is this??? Scheduler::schedule(cl_p, number<1>{}); fmha_mask(xdl_SP_p01_reg_idx); From b3c5cd0c762f77988618b43d6fc59c9298803452 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Mon, 24 Nov 2025 15:32:33 +0000 Subject: [PATCH 70/88] Fixed the block_table --- .../example_unified_attention.cpp | 10 ++--- .../kernel/unified_attention_kernel.hpp | 41 ++++++------------- .../pipeline/unified_attention_pipeline.hpp | 16 +++++--- 3 files changed, 25 insertions(+), 42 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 5d8f3fb4351..765c55a5ced 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -35,11 +35,6 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair( k_ptr, - make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), - make_tuple(kargs.stride_k_cache_0, kargs.stride_k_cache_1, kargs.stride_k_cache_3), + make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.stride_k_cache_1, kargs.stride_k_cache_3), number{}, number<1>{}); const auto k_dram_pad = pad_tensor_view(k_dram_naive, // TODO can the BLOCK_SIZE_RAW needs padding? - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); + make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); - const auto k_dram_merged = transform_tensor_view( - k_dram_pad, - make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), - make_pass_through_transform(HEAD_SIZE_PADDED)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, - sequence<1>{})); // flattens the first two dims, head idx is the fastest - // changing dim in the merged dim - - return k_dram_merged; + return k_dram_pad; }(); auto k_dram_window = make_tile_window( @@ -422,25 +414,16 @@ struct UnifiedAttentionKernel const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.num_blks, BLOCK_SIZE, HEAD_SIZE), - make_tuple(kargs.stride_v_cache_0, kargs.stride_v_cache_1, kargs.stride_v_cache_3), + make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.stride_v_cache_1, kargs.stride_v_cache_3), number{}, number<1>{}); const auto v_dram_pad = pad_tensor_view(v_dram_naive, - make_tuple(1, BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); - - const auto v_dram_merged = transform_tensor_view( - v_dram_pad, - make_tuple(make_merge_transform(make_tuple(kargs.num_blks, BLOCK_SIZE)), - make_pass_through_transform(HEAD_SIZE_PADDED)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, - sequence<1>{})); // flattens the first two dims, head idx is the fastest - // changing dim in the merged dim + make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), + sequence{}); - return v_dram_merged; + return v_dram_pad; }(); auto v_dram_window = make_tile_window( diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 5844285ffeb..3fbfb0cd9ed 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -544,6 +544,8 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); const auto num_total_loop = num_blocks; + index_t k_block_table_off = num_blocks_start; + index_t v_block_table_off = num_blocks_start; // check early exit if no work to do if constexpr(FmhaMask::IsMasking) @@ -557,10 +559,11 @@ struct UnifiedAttentionPipeline } } + // TODO check correctness of this index_t i_total_loops = num_blocks_start; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + i_total_loops]; + index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + k_block_table_off]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -668,7 +671,9 @@ struct UnifiedAttentionPipeline async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx // as the index - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + + k_block_table_off++; + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + k_block_table_off]; /// FIXME: use the future-predicting method to move the window k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -679,7 +684,9 @@ struct UnifiedAttentionPipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + i_total_loops]; + v_block_table_off++; + + index_t kv_blk_idx = block_tables_ptr_[block_table_offset + v_block_table_off]; /// FIXME: use the future-predicting method to move the window v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); }; @@ -1006,7 +1013,6 @@ struct UnifiedAttentionPipeline cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); Scheduler::schedule(cl_p, number<3>{}); - // kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1053,7 +1059,6 @@ struct UnifiedAttentionPipeline Scheduler::schedule(cl_p, number<2>{}); fmha_mask(xdl_SP_p01_reg_idx); - // kv_token_start += BLOCK_SIZE; if(num_total_loop <= ++i_total_loops) { result = false; @@ -1131,7 +1136,6 @@ struct UnifiedAttentionPipeline fmha_alu0(number<0>{}); fmha_alu_D_upd(); - // kv_token_start += BLOCK_SIZE; ++i_total_loops; if(num_total_loop <= i_total_loops) { From cc7caf4d7dfbde0368c5c6c385b462f3aa109b76 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 25 Nov 2025 09:27:40 +0000 Subject: [PATCH 71/88] correct results --- .../example_unified_attention.cpp | 40 +++++++++++++++---- .../kernel/unified_attention_kernel.hpp | 19 +++++---- .../pipeline/unified_attention_pipeline.hpp | 9 ++--- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 765c55a5ced..74f0e3f80f1 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -233,14 +233,14 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_gemm( q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); - // ck_tile::reference_batched_masking( - // s_host_ref, - // ck_tile::make_generic_attention_mask_from_lr_window( - // -1, - // 0, - // seqlen_q, - // seqlen_kv, - // true)); + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + -1, + 0, + seqlen_q, + seqlen_kv, + true)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -556,6 +556,30 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) << nonzero << " / " << total << " (" << percent << "%)\n"; + // std::cout << "\n=== Complete Output Tensor (o) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o(tok, h, d)) << " "; + // } + // std::cout << "\n"; + // } + // } + + // std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o_ref(tok, h, d)) << " "; + // } + // std::cout << "\n"; + // } + // } + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 90433da0fb2..735a8c42529 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -310,19 +310,18 @@ struct UnifiedAttentionKernel const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - // const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - // index_t _max_seq_prefix_len = - // amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - // + 1)); + index_t _max_seq_prefix_len = + amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + + 1)); - // if(seq_len < _max_seq_prefix_len) - // { - // _max_seq_prefix_len = seq_len; - // } + if(seq_len < _max_seq_prefix_len) + { + _max_seq_prefix_len = seq_len; + } - // const auto max_seq_prefix_len = _max_seq_prefix_len; - const auto max_seq_prefix_len = seq_len; + const auto max_seq_prefix_len = _max_seq_prefix_len; const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 3fbfb0cd9ed..d661c342f95 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -897,11 +897,10 @@ struct UnifiedAttentionPipeline auto fmha_mask = [&](auto sp_reg_idx) { if constexpr(FmhaMask::IsMasking) { - bool need_perpixel_check = false; - // mask.IsEdgeTile(q_origin.at(number<0>{}), - // i_total_loops * BLOCK_SIZE, - // number{}, - // number{}); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + i_total_loops * BLOCK_SIZE, + number{}, + number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, From 6a2ac8f758ac7d314a29ebe3b3c09d052499115b Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 09:16:30 +0000 Subject: [PATCH 72/88] causal mask fix --- .../example_unified_attention.cpp | 47 ++++++++----------- .../unified_attention.hpp | 8 ++++ .../kernel/unified_attention_kernel.hpp | 3 +- .../pipeline/unified_attention_pipeline.hpp | 2 +- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 74f0e3f80f1..a4458eeffc2 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -70,13 +70,6 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair; - using GenericMask = ck_tile::GenericAttentionMask; - using CausalMask = ck_tile::GenericAttentionMask; -}; - struct Problem { explicit Problem(const ck_tile::ArgParser& args) @@ -235,12 +228,13 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_masking( s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( + ck_tile::make_generic_attention_mask_from_lr_window( -1, 0, seqlen_q, seqlen_kv, - true)); + 1, + false)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( @@ -557,29 +551,28 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) << " (" << percent << "%)\n"; // std::cout << "\n=== Complete Output Tensor (o) ===\n"; - // for (int tok = 0; tok < problem.num_tokens; ++tok) { - // std::cout << "Token " << tok << ":\n"; - // for (int h = 0; h < problem.nhead_q; ++h) { - // std::cout << " Head " << h << ": "; - // for (int d = 0; d < problem.hdim; ++d) { - // std::cout << static_cast(o(tok, h, d)) << " "; - // } - // std::cout << "\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o(tok, h, d)) << " "; // } + // std::cout << "\n"; // } + // } - // std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n"; - // for (int tok = 0; tok < problem.num_tokens; ++tok) { - // std::cout << "Token " << tok << ":\n"; - // for (int h = 0; h < problem.nhead_q; ++h) { - // std::cout << " Head " << h << ": "; - // for (int d = 0; d < problem.hdim; ++d) { - // std::cout << static_cast(o_ref(tok, h, d)) << " "; - // } - // std::cout << "\n"; + // std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n"; + // for (int tok = 0; tok < problem.num_tokens; ++tok) { + // std::cout << "Token " << tok << ":\n"; + // for (int h = 0; h < problem.nhead_q; ++h) { + // std::cout << " Head " << h << ": "; + // for (int d = 0; d < problem.hdim; ++d) { + // std::cout << static_cast(o_ref(tok, h, d)) << " "; // } + // std::cout << "\n"; // } - + // } return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index f418a4a0d9e..ed3e1e6b50f 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/ops/unified_attention.hpp" namespace ck_tile { @@ -76,3 +77,10 @@ std::pair unified_attention(const unified_attention_args& args, const stream_config& config); } // namespace ck_tile + +struct UnifiedAttentionMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 735a8c42529..2727c563c0c 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -435,8 +435,9 @@ struct UnifiedAttentionKernel 0, cur_batch_query_len, // y_total seq_len, // x_total - num_queries_per_kv // the same sequence index is repeated num_queries_per_kv + num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv // times along x dim of the tile + false ); else return FmhaMask{cur_batch_query_len, seq_len}; diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index d661c342f95..105bfcca6e7 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -899,7 +899,7 @@ struct UnifiedAttentionPipeline { bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), i_total_loops * BLOCK_SIZE, - number{}, + number{}, number{}); if(need_perpixel_check) { From c641d0d42c38e74b8c463f03e3aca09c9cc81974 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 09:24:52 +0000 Subject: [PATCH 73/88] non zero calculation fix --- .../example_unified_attention.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index a4458eeffc2..401dd904966 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -436,7 +436,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", d:" << problem.hdim << ", mask:" << problem.mask << std::fixed << ", " + << ", d:" << problem.hdim << ", mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; @@ -530,13 +530,11 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) size_t nonzero = 0; - for (int b = 0; b < problem.batch; ++b) { - for (int s = 0; s < eff_query_lens[b]; ++s) { - for (int h = 0; h < problem.nhead_q; ++h) { - for (int d = 0; d < problem.hdim; ++d) { - if (static_cast(o(b, s, h, d)) != 0.0f) { + for (int tok = 0; tok < problem.num_tokens; ++tok) { + for (int h = 0; h < problem.nhead_q; ++h) { + for (int d = 0; d < problem.hdim; ++d) { + if (static_cast(o(tok, h, d)) != 0.0f) { nonzero++; - } } } } From eeb419845df84a224e4bbd300c6285481996395c Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 10:32:28 +0000 Subject: [PATCH 74/88] fmha v3 flops calculation --- .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 7ddb65a2dbc..4bb740217ef 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -463,17 +463,33 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) } std::size_t flop = [&] { - if(problem.mask.type == mask_enum::no_mask) - { - return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; - } - else + long flop_result = 0; + + for(int b = 0; b < problem.batch; ++b) { - /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. - return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * - problem.hdim; + long query_lens = has_varlen_q ? eff_q_vec[b] : problem.seqlen_q; + long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k; + long valid_out_elements = 0; + + if(problem.mask.type == mask_enum::no_mask) { + valid_out_elements = kv_lens * query_lens; + } else { + if(query_lens > kv_lens) + { + valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2; + } + else + { + valid_out_elements = + query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2); + } + + } + // Causal logic for valid output elements + + flop_result += 2 * problem.nhead_q * valid_out_elements * (problem.hdim + problem.hdim); } + return flop_result; }(); float tflops = static_cast(flop) / 1.e9 / time; From 3131ebf1dfb227cddd03eefbc5a60ab10a307b55 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 13:28:35 +0000 Subject: [PATCH 75/88] simplified kernel pid logic --- .../kernel/unified_attention_kernel.hpp | 42 +------------------ 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 2727c563c0c..396bd6d2b84 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -200,52 +200,15 @@ struct UnifiedAttentionKernel return left - 1; } - CK_TILE_DEVICE static constexpr auto RemapTileIndices(const ck_tile::index_t pid, - const Kargs& kargs) - { - using namespace ck_tile; - - constexpr index_t NUM_XCDS = 8; - const index_t GRID_MN = kargs.total_num_q_blocks * (kargs.num_head_q / kargs.num_queries_per_kv); - - // Number of pids per XCD in the new arrangement - const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS; - - // When GRID_MN cannot divide NUM_XCDS, some xcds will have - // pids_per_xcd pids, the other will have pids_per_xcd - 1 pids. - // We calculate the number of xcds that have pids_per_xcd pids as tall_xcds - index_t tall_xcds = GRID_MN % NUM_XCDS; - tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds; - - // Compute current XCD and local pid within the XCD - const index_t xcd = pid % NUM_XCDS; - const index_t local_pid = pid / NUM_XCDS; - - // Calculate new pid based on the new grouping - index_t remapped_pid = 0; // Initialize to avoid constexpr error - if(xcd < tall_xcds) - { - remapped_pid = xcd * pids_per_xcd + local_pid; - } - else - { - remapped_pid = - tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid; - } - - return remapped_pid; - } CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) { using namespace ck_tile; - ck_tile::index_t total_num_q_blocks = kargs.total_num_q_blocks; - // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, - // UnifiedAttentionPipeline::kN1); + ck_tile::index_t num_head_kv = kargs.num_head_q / kargs.num_queries_per_kv; - return ck_tile::make_tuple(pid / total_num_q_blocks, pid % total_num_q_blocks); + return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } @@ -274,7 +237,6 @@ struct UnifiedAttentionKernel // const index_t num_head_q = kargs.num_head_q; // const index_t num_head_k = num_head_q / num_queries_per_kv; - pid = RemapTileIndices(pid, kargs); // divide problem const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs); From 60ca9484b4471cf86e057a915ad72694d82b2d0d Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 27 Nov 2025 15:07:03 +0000 Subject: [PATCH 76/88] refined benchmarking --- .../example_unified_attention.cpp | 84 ++++++++++++++++++- 1 file changed, 81 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 401dd904966..d4313d79d51 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -35,7 +35,10 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair std::pair std::pair std::pair& query_lens_input, + const std::vector& kv_lens_input, + bool varlen) -> std::pair, std::vector> +{ + // If both query_lens and kv_lens are provided, return them directly + if(!query_lens_input.empty() && !kv_lens_input.empty()) + { + return std::make_pair(query_lens_input, kv_lens_input); + } + + std::vector query_lens; + std::vector kv_lens; + + if(!varlen) + { + // Fixed length mode: fill with max seqlen + query_lens.assign(batch, max_seqlen_q); + kv_lens.assign(batch, max_seqlen_kv); + } + else + { + // Variable length mode: generate random lengths up to max + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution q_dist(1, max_seqlen_q); + std::uniform_int_distribution kv_dist(1, max_seqlen_kv); + + query_lens.resize(batch); + kv_lens.resize(batch); + + for(ck_tile::index_t i = 0; i < batch; ++i) + { + query_lens[i] = q_dist(gen); + kv_lens[i] = kv_dist(gen); + } + } + + return std::make_pair(query_lens, kv_lens); +} + struct Problem { explicit Problem(const ck_tile::ArgParser& args) @@ -82,10 +129,30 @@ struct Problem // TODO: support other GQA/MQA cases than just 4x nhead_q = nhead_kv * num_queries_per_kv; + ck_tile::index_t max_seqlen_q = args.get_int("s"); + ck_tile::index_t max_seqlen_kv = args.get_int("s_k"); + + if (max_seqlen_kv == -1) { + max_seqlen_kv = max_seqlen_q; + } + hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); kv_lens = args.get_int_vec("kv_lens"); assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b"); + batch = args.get_int("b"); + + bool varlen = args.get_bool("varlen"); + auto [query_lens_, kv_lens_] = seqlen_preprocess( + batch, + max_seqlen_q, + max_seqlen_kv, + query_lens, + kv_lens, + varlen); + + query_lens = query_lens_; + kv_lens = kv_lens_; batch = query_lens.size(); // Calculate scale_s @@ -436,7 +503,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", d:" << problem.hdim << ", mask:" << "causal mask" << std::fixed << ", " + << ", d:" << problem.hdim << ", scale_s:" << problem.scale_s + << ", query_lens:["; + for (size_t i = 0; i < problem.query_lens.size(); ++i) { + std::cout << problem.query_lens[i]; + if (i < problem.query_lens.size() - 1) std::cout << ","; + } + std::cout << "], kv_lens:["; + for (size_t i = 0; i < problem.kv_lens.size(); ++i) { + std::cout << problem.kv_lens[i]; + if (i < problem.kv_lens.size() - 1) std::cout << ","; + } + std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; From 7078de91d8dc75374c4a86249949333e1db67905 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Wed, 3 Dec 2025 13:08:29 +0000 Subject: [PATCH 77/88] adding PAGE_BLOCK_SIZE >= BLOCK_SIZE optionality, now it regresses perf when it should improve? --- .../example_unified_attention.cpp | 30 ++++---- .../unified_attention.hpp | 3 +- .../unified_attention_impl.hpp | 3 +- .../kernel/unified_attention_kernel.hpp | 16 +++- .../pipeline/unified_attention_pipeline.hpp | 75 ++++++++++++++----- 5 files changed, 90 insertions(+), 37 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 335d23b1544..61f64c5a921 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -25,8 +25,8 @@ #include "unified_attention.hpp" #include "mask.hpp" -const ck_tile::index_t BLOCK_SIZE = 32; -const ck_tile::index_t num_queries_per_kv = 4; +// const ck_tile::index_t page_blk_size = 32; +const ck_tile::index_t num_queries_per_kv = 1; auto parse_cmd_args(int argc, char* argv[]) -> std::pair { @@ -60,13 +60,14 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair get_key_shape() const { - return {num_blks, BLOCK_SIZE, nhead_kv, hdim}; + return {num_blks, page_blk_size, nhead_kv, hdim}; } std::vector get_value_shape() const { - return {num_blks, BLOCK_SIZE, nhead_kv, hdim}; + return {num_blks, page_blk_size, nhead_kv, hdim}; } std::vector get_output_shape() const { return {num_tokens, nhead_q, hdim}; } @@ -191,6 +194,7 @@ struct Problem ck_tile::index_t nhead_q; ck_tile::index_t nhead_kv; ck_tile::index_t hdim; + ck_tile::index_t page_blk_size; ck_tile::index_t num_tokens; float scale_s; float scale; @@ -338,7 +342,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.num_seqs = problem.batch; args.num_head_q = problem.nhead_q; args.num_queries_per_kv = num_queries_per_kv; - args.BLOCK_SIZE = BLOCK_SIZE; + args.page_blk_size = problem.page_blk_size; args.mask_type = 2; args.hdim = problem.hdim; @@ -350,7 +354,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.k_ptr = k_buf.GetDeviceBuffer(); - args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * BLOCK_SIZE; + args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * problem.page_blk_size; args.stride_k_cache_1 = problem.hdim * problem.nhead_kv; args.stride_k_cache_2 = problem.hdim; args.stride_k_cache_3 = 1; @@ -424,7 +428,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::index_t max_kv_len = max_element(eff_kv_lens); - ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size; // Create block_tables ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * @@ -551,18 +555,18 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) }); k_b.ForEach([&](auto& self, auto idx) { // kv cache is paged - ck_tile::index_t table_col = int(idx[1] / BLOCK_SIZE); + ck_tile::index_t table_col = int(idx[1] / problem.page_blk_size); ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col; ck_tile::index_t block_idx = block_tables_host[block_table_offset]; - self(idx) = k(block_idx, idx[1] % BLOCK_SIZE, idx[2], idx[3]); + self(idx) = k(block_idx, idx[1] % problem.page_blk_size, idx[2], idx[3]); }); v_b.ForEach([&](auto& self, auto idx) { - ck_tile::index_t table_col = int(idx[1] / BLOCK_SIZE); + ck_tile::index_t table_col = int(idx[1] / problem.page_blk_size); ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col; ck_tile::index_t block_idx = block_tables_host[block_table_offset]; - self(idx) = v(block_idx, idx[1] % BLOCK_SIZE, idx[2], idx[3]); + self(idx) = v(block_idx, idx[1] % problem.page_blk_size, idx[2], idx[3]); }); // v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index ed3e1e6b50f..3a366901589 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -29,7 +29,8 @@ struct unified_attention_args index_t num_blks; index_t num_head_q; index_t num_queries_per_kv; - index_t BLOCK_SIZE; + index_t page_blk_size; + //index_t BLOCK_SIZE; index_t hdim; // TODO window diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 855c99f8419..021e7753365 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -63,7 +63,7 @@ struct unified_attention_kernel_traits static constexpr index_t HEAD_SIZE = 128; // TODO please fix this to support also other num_queries_per_kv - static constexpr index_t num_queries_per_kv = 4; + static constexpr index_t num_queries_per_kv = 1; static constexpr index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE @@ -139,6 +139,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, args.scale_k, args.scale_v, args.scale_out, + args.page_blk_size, total_num_q_blocks, args.query_stride_0, args.query_stride_1, diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 396bd6d2b84..4b983d503a1 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -56,8 +56,8 @@ struct UnifiedAttentionKernel struct UnifiedAttentionCommonKargs { const void* q_ptr; - const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size] - const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size] + const void* k_ptr; // [num_blks, page_blk_size, num_kv_heads, head_size] + const void* v_ptr; // [num_blks, page_blk_size, num_kv_heads, head_size] void* o_ptr; ck_tile::index_t num_blks; @@ -72,6 +72,8 @@ struct UnifiedAttentionKernel float scale_v; float scale_out; + ck_tile::index_t page_blk_size; + ck_tile::index_t total_num_q_blocks; ck_tile::index_t query_stride_0; ck_tile::index_t query_stride_1; @@ -111,6 +113,7 @@ struct UnifiedAttentionKernel float scale_k, float scale_v, float scale_out, + ck_tile::index_t page_blk_size, ck_tile::index_t total_num_q_blocks, ck_tile::index_t query_stride_0, ck_tile::index_t query_stride_1, @@ -142,6 +145,7 @@ struct UnifiedAttentionKernel scale_k, scale_v, scale_out, + page_blk_size, total_num_q_blocks, query_stride_0, query_stride_1, @@ -356,7 +360,7 @@ struct UnifiedAttentionKernel // HEAD dim is skipped as defined in the ptrs const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.num_blks * kargs.page_blk_size, HEAD_SIZE), make_tuple(kargs.stride_k_cache_1, kargs.stride_k_cache_3), number{}, number<1>{}); @@ -375,7 +379,7 @@ struct UnifiedAttentionKernel const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.num_blks * BLOCK_SIZE, HEAD_SIZE), + make_tuple(kargs.num_blks * kargs.page_blk_size, HEAD_SIZE), make_tuple(kargs.stride_v_cache_1, kargs.stride_v_cache_3), number{}, number<1>{}); @@ -405,6 +409,9 @@ struct UnifiedAttentionKernel return FmhaMask{cur_batch_query_len, seq_len}; }(); + const index_t kv_page_size_in_blocks = kargs.page_blk_size / BLOCK_SIZE; + assert(kv_page_size_in_blocks >= 1); // BLOCK_SIZE <= page_blk_size + auto o_acc_tile = [&]() { return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, @@ -413,6 +420,7 @@ struct UnifiedAttentionKernel num_blocks_start, kargs.block_tables_ptr, block_table_offset, + kv_page_size_in_blocks, mask, kargs.scale_s, smem_ptr); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 105bfcca6e7..1452e4bad1a 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -388,6 +388,7 @@ struct UnifiedAttentionPipeline const index_t num_blocks_start, const void* block_tables_ptr, index_t block_table_offset, + const index_t kv_page_size_in_blocks, [[maybe_unused]] const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, @@ -546,6 +547,7 @@ struct UnifiedAttentionPipeline const auto num_total_loop = num_blocks; index_t k_block_table_off = num_blocks_start; index_t v_block_table_off = num_blocks_start; + // check early exit if no work to do if constexpr(FmhaMask::IsMasking) @@ -561,21 +563,23 @@ struct UnifiedAttentionPipeline // TODO check correctness of this index_t i_total_loops = num_blocks_start; + const index_t PAGE_BLOCK_SIZE = kv_page_size_in_blocks * BLOCK_SIZE; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - index_t kv_blk_idx_intial = block_tables_ptr_[block_table_offset + k_block_table_off]; + assert(k_block_table_off == v_block_table_off); // because of the following line + index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_table_off]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {kv_blk_idx_intial * BLOCK_SIZE, 0}, + {kv_blk_idx_initial * PAGE_BLOCK_SIZE, 0}, Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {kv_blk_idx_intial * BLOCK_SIZE, 0}, + {kv_blk_idx_initial * PAGE_BLOCK_SIZE, 0}, Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); @@ -667,28 +671,61 @@ struct UnifiedAttentionPipeline constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access(); constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); + // Page block index tracking + // const index_t kv_page_size_in_blocks = + // PAGE_BLOCK_SIZE / BLOCK_SIZE; + index_t k_block_i_inside_page = 0; + index_t v_block_i_inside_page = 0; auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); - // TODO maybe needs i_total_loops as argument. Or maybe needs to use the k_lds_write_idx - // as the index - - k_block_table_off++; - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + k_block_table_off]; - /// FIXME: use the future-predicting method to move the window - k_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); - }; - - auto K_lds_load = [&](auto k_lds_read_idx) { - kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx)); + // prefetch next K tile (only if not at the end of loop) + if (k_block_table_off * kv_page_size_in_blocks + k_block_i_inside_page + 1 >= num_total_loop) + { + return; + } + // Update block index inside the page + ++k_block_i_inside_page; + if(k_block_i_inside_page < kv_page_size_in_blocks) + { + // Staying inside the page, just move the window + move_tile_window(k_dram_window, {BLOCK_SIZE, 0}); + } + else + { + // Moving outside the page, fetch new physical page index + k_block_table_off++; + index_t k_page_blk_idx = block_tables_ptr_[block_table_offset + k_block_table_off]; + k_dram_window.set_window_origin({k_page_blk_idx * PAGE_BLOCK_SIZE, 0}); + k_block_i_inside_page = 0; + } }; auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - v_block_table_off++; + // prefetch next V tile (only if not at the end of loop) + if (v_block_table_off * kv_page_size_in_blocks + v_block_i_inside_page + 1 >= num_total_loop) + { + return; + } + // Update the block index inside the page + ++v_block_i_inside_page; + if(v_block_i_inside_page < kv_page_size_in_blocks) + { + // Staying inside the page, just move the window + move_tile_window(v_dram_window, {BLOCK_SIZE, 0}); + } + else + { + // Moving outside the page, fetch new physical page index + v_block_table_off++; + index_t v_page_blk_idx = block_tables_ptr_[block_table_offset + v_block_table_off]; + v_dram_window.set_window_origin({v_page_blk_idx * PAGE_BLOCK_SIZE, 0}); + v_block_i_inside_page = 0; + } + }; - index_t kv_blk_idx = block_tables_ptr_[block_table_offset + v_block_table_off]; - /// FIXME: use the future-predicting method to move the window - v_dram_window.set_window_origin({kv_blk_idx * BLOCK_SIZE, 0}); + auto K_lds_load = [&](auto k_lds_read_idx) { + kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx)); }; auto V_lds_load = [&](auto v_lds_read_idx) { @@ -1216,6 +1253,7 @@ struct UnifiedAttentionPipeline const index_t num_blocks_start, const void* block_tables_ptr, index_t block_table_offset, + const index_t kv_page_size_in_blocks, FmhaMask mask, float scale_s, void* smem_ptr) const @@ -1232,6 +1270,7 @@ struct UnifiedAttentionPipeline num_blocks_start, block_tables_ptr, block_table_offset, + kv_page_size_in_blocks, identity{}, identity{}, identity{}, From 345758971e103b1efa6a0b1059e5f380952487a3 Mon Sep 17 00:00:00 2001 From: Juuso Korhonen <40278371+juuso-oskari@users.noreply.github.com> Date: Wed, 3 Dec 2025 13:36:48 +0000 Subject: [PATCH 78/88] fix --- .../example_unified_attention.cpp | 18 +++++++++--------- .../pipeline/unified_attention_pipeline.hpp | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 61f64c5a921..2b538a8e974 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -52,7 +52,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair std::pair seed; @@ -522,10 +522,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; - // if(!run_config.verify) - // { - // return true; - // } + if(!run_config.verify) + { + return true; + } // variable lengths are provided -> compute per-batch references // with the effective lengths; else compute a single full reference. diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 1452e4bad1a..4cb5637c3ab 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -694,7 +694,7 @@ struct UnifiedAttentionPipeline { // Moving outside the page, fetch new physical page index k_block_table_off++; - index_t k_page_blk_idx = block_tables_ptr_[block_table_offset + k_block_table_off]; + index_t k_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + k_block_table_off]); k_dram_window.set_window_origin({k_page_blk_idx * PAGE_BLOCK_SIZE, 0}); k_block_i_inside_page = 0; } @@ -718,7 +718,7 @@ struct UnifiedAttentionPipeline { // Moving outside the page, fetch new physical page index v_block_table_off++; - index_t v_page_blk_idx = block_tables_ptr_[block_table_offset + v_block_table_off]; + index_t v_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + v_block_table_off]); v_dram_window.set_window_origin({v_page_blk_idx * PAGE_BLOCK_SIZE, 0}); v_block_i_inside_page = 0; } From 73aed1b57cceda50f66e1bfd369fce188b07f970 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 11 Dec 2025 09:21:55 +0000 Subject: [PATCH 79/88] remove if statements --- .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 8 +- .../example_unified_attention.cpp | 115 +++++++++--------- .../unified_attention.hpp | 2 +- .../kernel/unified_attention_kernel.hpp | 34 +++--- .../pipeline/unified_attention_pipeline.hpp | 75 ++++-------- 5 files changed, 106 insertions(+), 128 deletions(-) diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 074d69c2f99..0c1b6255004 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -471,9 +471,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) long kv_lens = has_varlen_k ? eff_kv_vec[b] : problem.seqlen_k; long valid_out_elements = 0; - if(problem.mask.type == mask_enum::no_mask) { + if(problem.mask.type == mask_enum::no_mask) + { valid_out_elements = kv_lens * query_lens; - } else { + } + else + { if(query_lens > kv_lens) { valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2; @@ -483,7 +486,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) valid_out_elements = query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2); } - } // Causal logic for valid output elements diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 2b538a8e974..e43a8df76ef 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -34,7 +34,10 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair std::pair& query_lens_input, - const std::vector& kv_lens_input, - bool varlen) -> std::pair, std::vector> + ck_tile::index_t max_seqlen_q, + ck_tile::index_t max_seqlen_kv, + const std::vector& query_lens_input, + const std::vector& kv_lens_input, + bool varlen) -> std::pair, std::vector> { // If both query_lens and kv_lens are provided, return them directly if(!query_lens_input.empty() && !kv_lens_input.empty()) @@ -107,11 +110,11 @@ auto seqlen_preprocess(ck_tile::index_t batch, query_lens.resize(batch); kv_lens.resize(batch); - + for(ck_tile::index_t i = 0; i < batch; ++i) { query_lens[i] = q_dist(gen); - kv_lens[i] = kv_dist(gen); + kv_lens[i] = kv_dist(gen); } } @@ -131,31 +134,27 @@ struct Problem nhead_q = nhead_kv * num_queries_per_kv; ck_tile::index_t max_seqlen_q = args.get_int("s"); - ck_tile::index_t max_seqlen_kv = args.get_int("s_k"); + ck_tile::index_t max_seqlen_kv = args.get_int("s_k"); - if (max_seqlen_kv == -1) { + if(max_seqlen_kv == -1) + { max_seqlen_kv = max_seqlen_q; } - + hdim = args.get_int("d"); query_lens = args.get_int_vec("query_lens"); kv_lens = args.get_int_vec("kv_lens"); - assert(query_lens.size() == kv_lens.size() && "query_lens and kv_lens must have the same length b"); - batch = args.get_int("b"); + assert(query_lens.size() == kv_lens.size() && + "query_lens and kv_lens must have the same length b"); + batch = args.get_int("b"); page_blk_size = args.get_int("page_blk_size"); - bool varlen = args.get_bool("varlen"); - auto [query_lens_, kv_lens_] = seqlen_preprocess( - batch, - max_seqlen_q, - max_seqlen_kv, - query_lens, - kv_lens, - varlen); + auto [query_lens_, kv_lens_] = + seqlen_preprocess(batch, max_seqlen_q, max_seqlen_kv, query_lens, kv_lens, varlen); query_lens = query_lens_; - kv_lens = kv_lens_; + kv_lens = kv_lens_; batch = query_lens.size(); // Calculate scale_s @@ -164,9 +163,9 @@ struct Problem scale_s = 1.0f / ck_tile::sqrt(static_cast(hdim)); // Initialize other scales - scale = args.get_float("scale"); - scale_k = args.get_float("scale_k"); - scale_v = args.get_float("scale_v"); + scale = args.get_float("scale"); + scale_k = args.get_float("scale_k"); + scale_v = args.get_float("scale_v"); num_tokens = 0; for(const auto& len : query_lens) { @@ -300,17 +299,12 @@ CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, ck_tile::reference_batched_masking( s_host_ref, ck_tile::make_generic_attention_mask_from_lr_window( - -1, - 0, - seqlen_q, - seqlen_kv, - 1, - false)); + -1, 0, seqlen_q, seqlen_kv, 1, false)); ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, ck_tile::identity{}); ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); - + // copy resulting per-batch data to the output tensor o_host_ref.ForEach( [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); @@ -342,7 +336,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.num_seqs = problem.batch; args.num_head_q = problem.nhead_q; args.num_queries_per_kv = num_queries_per_kv; - args.page_blk_size = problem.page_blk_size; + args.page_blk_size = problem.page_blk_size; args.mask_type = 2; args.hdim = problem.hdim; @@ -428,7 +422,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::index_t max_kv_len = max_element(eff_kv_lens); - ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size; + ck_tile::index_t max_num_blocks_per_seq = + (max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size; // Create block_tables ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * @@ -506,20 +501,22 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) std::cout << "[" << problem.data_type << "|"; std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv - << ", d:" << problem.hdim << ", scale_s:" << problem.scale_s - << ", query_lens:["; - for (size_t i = 0; i < problem.query_lens.size(); ++i) { + << ", d:" << problem.hdim << ", scale_s:" << problem.scale_s << ", query_lens:["; + for(size_t i = 0; i < problem.query_lens.size(); ++i) + { std::cout << problem.query_lens[i]; - if (i < problem.query_lens.size() - 1) std::cout << ","; + if(i < problem.query_lens.size() - 1) + std::cout << ","; } std::cout << "], kv_lens:["; - for (size_t i = 0; i < problem.kv_lens.size(); ++i) { + for(size_t i = 0; i < problem.kv_lens.size(); ++i) + { std::cout << problem.kv_lens[i]; - if (i < problem.kv_lens.size() - 1) std::cout << ","; + if(i < problem.kv_lens.size() - 1) + std::cout << ","; } - std::cout << "], mask:" << "causal mask" << std::fixed << ", " - << std::setprecision(8) << time << " ms, " << std::setprecision(2) << tflops - << " TFlops, " << std::setprecision(2) + std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time + << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << (static_cast(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl; if(!run_config.verify) @@ -597,37 +594,37 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); - const auto [rtol, atol] = [&] { if constexpr(std::is_same_v) return std::make_tuple(1e-3, 1e-3); else return std::make_tuple(1e-2, 1e-2); }(); - - size_t total = static_cast(problem.num_tokens) * - static_cast(problem.nhead_q) * + + size_t total = static_cast(problem.num_tokens) * static_cast(problem.nhead_q) * static_cast(problem.hdim); size_t nonzero = 0; - for (int tok = 0; tok < problem.num_tokens; ++tok) { - for (int h = 0; h < problem.nhead_q; ++h) { - for (int d = 0; d < problem.hdim; ++d) { - if (static_cast(o(tok, h, d)) != 0.0f) { - nonzero++; + for(int tok = 0; tok < problem.num_tokens; ++tok) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + if(static_cast(o(tok, h, d)) != 0.0f) + { + nonzero++; } } } } - float percent = (total > 0) - ? (100.0f * static_cast(nonzero) / static_cast(total)) - : 0.0f; + float percent = + (total > 0) ? (100.0f * static_cast(nonzero) / static_cast(total)) : 0.0f; - std::cout << "\nNon-zero elements in output tensor o: " - << nonzero << " / " << total - << " (" << percent << "%)\n"; + std::cout << "\nNon-zero elements in output tensor o: " << nonzero << " / " << total << " (" + << percent << "%)\n"; // std::cout << "\n=== Complete Output Tensor (o) ===\n"; // for (int tok = 0; tok < problem.num_tokens; ++tok) { @@ -652,7 +649,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // std::cout << "\n"; // } // } - return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); } int main(int argc, char* argv[]) diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index 3a366901589..64f340c5562 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -30,7 +30,7 @@ struct unified_attention_args index_t num_head_q; index_t num_queries_per_kv; index_t page_blk_size; - //index_t BLOCK_SIZE; + // index_t BLOCK_SIZE; index_t hdim; // TODO window diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 4b983d503a1..480d0f3ee92 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -204,7 +204,6 @@ struct UnifiedAttentionKernel return left - 1; } - CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs) { @@ -259,13 +258,14 @@ struct UnifiedAttentionKernel const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx; - const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); + const index_t q_block_local_idx = + amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); const index_t cur_batch_in_all_start_index = kargs.query_start_len_ptr[seq_idx]; - const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; + const index_t cur_batch_in_all_stop_index = kargs.query_start_len_ptr[seq_idx + 1]; const index_t cur_batch_query_len = - amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); + amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); // TODO check if we get the block size info from pipeline if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) @@ -276,11 +276,10 @@ struct UnifiedAttentionKernel const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; - const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); + const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); - index_t _max_seq_prefix_len = - amd_wave_read_first_lane((context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) - + 1)); + index_t _max_seq_prefix_len = amd_wave_read_first_lane( + (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + 1)); if(seq_len < _max_seq_prefix_len) { @@ -288,7 +287,8 @@ struct UnifiedAttentionKernel } const auto max_seq_prefix_len = _max_seq_prefix_len; - const index_t num_blocks = amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); + const index_t num_blocks = + amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); // TODO sliding window const index_t num_blocks_start = 0; @@ -315,7 +315,8 @@ struct UnifiedAttentionKernel const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + kv_head_offset; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; - index_t query_len_padded = amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); + index_t query_len_padded = + amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); // Q/K/V DRAM and DRAM window @@ -397,21 +398,20 @@ struct UnifiedAttentionKernel FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( - -1, + -1, 0, cur_batch_query_len, // y_total - seq_len, // x_total - num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv - // times along x dim of the tile - false - ); + seq_len, // x_total + num_queries_per_kv, // the same sequence index is repeated num_queries_per_kv + // times along x dim of the tile + false); else return FmhaMask{cur_batch_query_len, seq_len}; }(); const index_t kv_page_size_in_blocks = kargs.page_blk_size / BLOCK_SIZE; assert(kv_page_size_in_blocks >= 1); // BLOCK_SIZE <= page_blk_size - + auto o_acc_tile = [&]() { return UnifiedAttentionPipeline{}(q_dram_window, k_dram_window, diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 4cb5637c3ab..6cc8ee954af 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -545,9 +545,8 @@ struct UnifiedAttentionPipeline const auto q_origin = q_dram_window.get_window_origin(); const auto num_total_loop = num_blocks; - index_t k_block_table_off = num_blocks_start; - index_t v_block_table_off = num_blocks_start; - + index_t k_block_idx = 0; + index_t v_block_idx = 0; // check early exit if no work to do if constexpr(FmhaMask::IsMasking) @@ -562,12 +561,13 @@ struct UnifiedAttentionPipeline } // TODO check correctness of this - index_t i_total_loops = num_blocks_start; + index_t i_total_loops = num_blocks_start; const index_t PAGE_BLOCK_SIZE = kv_page_size_in_blocks * BLOCK_SIZE; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); - assert(k_block_table_off == v_block_table_off); // because of the following line - index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_table_off]; + assert(k_block_idx == v_block_idx); // because of the following line + block_table_offset += num_blocks_start; + index_t kv_blk_idx_initial = block_tables_ptr_[block_table_offset + k_block_idx]; auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), @@ -672,56 +672,35 @@ struct UnifiedAttentionPipeline constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); // Page block index tracking - // const index_t kv_page_size_in_blocks = + // const index_t kv_page_size_in_blocks = // PAGE_BLOCK_SIZE / BLOCK_SIZE; - index_t k_block_i_inside_page = 0; - index_t v_block_i_inside_page = 0; + // index_t kv_block_idx = 0; + // only for block 0 and thread + if(blockIdx.x == 0 && threadIdx.x == 0) {} auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); - // prefetch next K tile (only if not at the end of loop) - if (k_block_table_off * kv_page_size_in_blocks + k_block_i_inside_page + 1 >= num_total_loop) - { - return; - } - // Update block index inside the page - ++k_block_i_inside_page; - if(k_block_i_inside_page < kv_page_size_in_blocks) - { - // Staying inside the page, just move the window - move_tile_window(k_dram_window, {BLOCK_SIZE, 0}); - } - else - { - // Moving outside the page, fetch new physical page index - k_block_table_off++; - index_t k_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + k_block_table_off]); - k_dram_window.set_window_origin({k_page_blk_idx * PAGE_BLOCK_SIZE, 0}); - k_block_i_inside_page = 0; - } + k_block_idx++; + + index_t k_page_blk_idx = + block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)]; + k_dram_window.set_window_origin( + {k_page_blk_idx * PAGE_BLOCK_SIZE + + (k_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE, + 0}); }; auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); // prefetch next V tile (only if not at the end of loop) - if (v_block_table_off * kv_page_size_in_blocks + v_block_i_inside_page + 1 >= num_total_loop) - { - return; - } - // Update the block index inside the page - ++v_block_i_inside_page; - if(v_block_i_inside_page < kv_page_size_in_blocks) - { - // Staying inside the page, just move the window - move_tile_window(v_dram_window, {BLOCK_SIZE, 0}); - } - else - { - // Moving outside the page, fetch new physical page index - v_block_table_off++; - index_t v_page_blk_idx = amd_wave_read_first_lane(block_tables_ptr_[block_table_offset + v_block_table_off]); - v_dram_window.set_window_origin({v_page_blk_idx * PAGE_BLOCK_SIZE, 0}); - v_block_i_inside_page = 0; - } + v_block_idx++; + + index_t v_page_blk_idx = + block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)]; + v_dram_window.set_window_origin( + {v_page_blk_idx * PAGE_BLOCK_SIZE + + (v_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE, + 0}); + // we assume that v load is always after k }; auto K_lds_load = [&](auto k_lds_read_idx) { From e13494cc3abaaa784165500279f8ecca51b16963 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 11 Dec 2025 13:34:27 +0000 Subject: [PATCH 80/88] refactor --- .../unified_attention_impl.hpp | 16 +- include/ck_tile/ops/unified_attention.hpp | 7 - .../block/block_attention_bias_enum.hpp | 37 - .../unified_attention/block/block_dropout.hpp | 654 ------------------ .../block/block_position_encoding.hpp | 205 ------ .../block/block_rotary_embedding.hpp | 108 --- .../block/page_block_navigator.hpp | 358 ---------- .../ops/unified_attention/block/variants.hpp | 302 -------- .../kernel/unified_attention_kernel.hpp | 102 ++- .../pipeline/tile_unified_attention_shape.hpp | 13 +- .../tile_unified_attention_traits.hpp | 2 - .../pipeline/unified_attention_pipeline.hpp | 109 ++- ...fied_attention_pipeline_default_policy.hpp | 50 +- .../unified_attention_pipeline_enum.hpp | 42 -- .../unified_attention_pipeline_problem.hpp | 1 - 15 files changed, 144 insertions(+), 1862 deletions(-) delete mode 100644 include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp delete mode 100644 include/ck_tile/ops/unified_attention/block/block_dropout.hpp delete mode 100644 include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp delete mode 100644 include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp delete mode 100644 include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp delete mode 100644 include/ck_tile/ops/unified_attention/block/variants.hpp delete mode 100644 include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 021e7753365..8087c4b8e64 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,16 +58,16 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; - static constexpr index_t BLOCK_M = 256; + static constexpr index_t kBlockM = 256; static constexpr index_t BLOCK_SIZE = 32; static constexpr index_t HEAD_SIZE = 128; // TODO please fix this to support also other num_queries_per_kv static constexpr index_t num_queries_per_kv = 1; - static constexpr index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; + static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv; - // BLOCK_M BLOCK_Q BLOCK_SIZE HEAD_SIZE - using unified_attention_block_tile = sequence; + // kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE + using unified_attention_block_tile = sequence; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; // need to have 8 warps per workgroup to have warp specialization using unified_attention_block_warps = sequence<8, 1, 1>; @@ -119,14 +119,14 @@ template float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config) { - index_t BLOCK_Q = Kernel::BLOCK_Q; + index_t kBlockQ = Kernel::kBlockQ; assert(args.num_queries_per_kv == Kernel::num_queries_per_kv && "argument num_queries_per_kv must equal compiled num_queries_per_kv"); assert(args.BLOCK_SIZE == Kernel::BLOCK_SIZE && "argument BLOCK_SIZE must equal compiled BLOCK_SIZE"); - assert(BLOCK_Q == BLOCK_M / args.num_queries_per_kv && - "BLOCK_Q must equal BLOCK_M / num_queries_per_kv"); - index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs; + assert(kBlockQ == kBlockM / args.num_queries_per_kv && + "kBlockQ must equal kBlockM / num_queries_per_kv"); + index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs; auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, diff --git a/include/ck_tile/ops/unified_attention.hpp b/include/ck_tile/ops/unified_attention.hpp index 20eee5a819e..53ca1da684d 100644 --- a/include/ck_tile/ops/unified_attention.hpp +++ b/include/ck_tile/ops/unified_attention.hpp @@ -3,19 +3,12 @@ #pragma once -#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/unified_attention/block/block_dropout.hpp" #include "ck_tile/ops/unified_attention/block/block_masking.hpp" -#include "ck_tile/ops/unified_attention/block/block_position_encoding.hpp" -#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" -#include "ck_tile/ops/unified_attention/block/page_block_navigator.hpp" -#include "ck_tile/ops/unified_attention/block/variants.hpp" #include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp" #include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp" #include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" -#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_interleaved_pk_type.hpp" diff --git a/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp b/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp deleted file mode 100644 index e5be21e0489..00000000000 --- a/include/ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -namespace ck_tile { - -// This class is used for codegen pattern matching -enum class BlockAttentionBiasEnum -{ - NO_BIAS = 0, - ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale) - ALIBI = 2, // bias computed with position encoding, applied after scale -}; - -template -struct BlockAttentionBiasEnumToStr; - -template <> -struct BlockAttentionBiasEnumToStr -{ - static constexpr const char* name = ""; -}; -template <> -struct BlockAttentionBiasEnumToStr -{ - static constexpr const char* name = "bias"; -}; -template <> -struct BlockAttentionBiasEnumToStr -{ - static constexpr const char* name = "alibi"; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_dropout.hpp b/include/ck_tile/ops/unified_attention/block/block_dropout.hpp deleted file mode 100644 index 8abdd54cd97..00000000000 --- a/include/ck_tile/ops/unified_attention/block/block_dropout.hpp +++ /dev/null @@ -1,654 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" - -namespace ck_tile { - -// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and -// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, generated random -// numbers will be the same, they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host -// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp). -// -// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of -// random numbers (ph_subsequence). -// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and -// ph_offset). -// This means that subsequences are non-overlapping, reproducible and independent of mask or window. -// -// There are 3 modes (all produce the same results): -// * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates -// the entire 32x32 tile (64 * 16 = 32 * 32). -// * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4 -// warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock > -// MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions -// are needed for generating a 32x32 tile. -// * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2 -// warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp * -// WG::kM one warp can generate two 16x16 tiles. - -namespace detail { -// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values -constexpr index_t philox_per_tile = 64; -} // namespace detail - -struct NullBlockDropout -{ - template - CK_TILE_HOST_DEVICE static constexpr auto - MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - index_t seqlen_qk_start) - { - (void)randval_dram_block_window_tmp; - (void)seqlen_qk_start; - - return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); - } -}; - -struct BlockDropout -{ - CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, - index_t i_head, - index_t nheads, - unsigned long long seed, - unsigned long long offset, - float rp_undrop_, - uint8_t p_undrop_in_uint8_t_, - bool is_store_randval_) - : ph_seed(amd_wave_read_first_lane(seed)), - ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * - detail::philox_per_tile)), - rp_undrop(rp_undrop_), - p_undrop_in_uint8_t(p_undrop_in_uint8_t_), - is_store_randval(is_store_randval_) - { - } - - template - CK_TILE_HOST_DEVICE static constexpr auto - MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - index_t seqlen_qk_start) - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr bool IsWG32 = WG::kM == 32; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; - constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; - - const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); - auto randval_dram_window = [&]() { - if constexpr(IsFwd) - { - return make_tile_window( - randval_dram_block_window_tmp.get_bottom_tensor_view(), - ck_tile::make_tuple(number{}, number{}), - {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N - } - else - { - return make_tile_window( - randval_dram_block_window_tmp.get_bottom_tensor_view(), - ck_tile::make_tuple(number{}, number{}), - {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N - } - }(); - - return randval_dram_window; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr bool IsWG32 = WG::kM == 32; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; - constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; - constexpr index_t kN1 = 8; - constexpr index_t kN0 = kNPerStep / kN1; - - constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( - ck_tile::make_tuple(number{}, number{}, number{}), - ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto randval_lds_block_desc = transform_tensor_descriptor( - randval_lds_block_desc_0, - ck_tile::make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(ck_tile::make_tuple(number{}, number{}))), - ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), - ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); - - return randval_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr bool IsWG32 = WG::kM == 32; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; - constexpr index_t NIterPerWarp = 1; - - // The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution, - // because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once - constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 0>>{}; - - // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. - constexpr auto randval_block_inner_part_dstr_encoding = - typename WarpGemmDispatcher::CWarpDstrEncoding{}; - - constexpr auto randval_block_part_dstr_encode = - detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, - randval_block_inner_part_dstr_encoding); - - return make_static_tile_distribution(randval_block_part_dstr_encode); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr bool IsWG32 = WG::kM == 32; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; - constexpr index_t NIterPerWarp = 1; - - constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto randval_block_part_dstr_encode = - detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, - typename WG::CWarpDstrEncoding{}); - - return make_static_tile_distribution(randval_block_part_dstr_encode); - } - - template - CK_TILE_HOST_DEVICE void Run(void* randval_ptr, - const index_t start_n0_idx, - PComputeWindow& p_compute, - RandValDramWindow& randval_dram_window) const - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr bool IsWG32 = WG::kM == 32; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t kNPerBlock = BlockGemmShape::kN; - constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; - constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; - - // randval tile in LDS - auto randval_lds = make_tensor_view( - reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); - - auto randval_lds_window = make_tile_window( - randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); - - // register distribute - auto randval_dist_generated = - make_static_distributed_tensor(MakeRandValTileDistribution()); - - const auto randval_lds_read_window = - make_tile_window(randval_lds_window.get_bottom_tensor_view(), - randval_lds_window.get_window_lengths(), - randval_lds_window.get_window_origin(), - MakeRandValLdsShuffleTileDistribution()); - - const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() % NWarp; - - auto generate_randval = [&](auto i_m0, auto i_n0) { - // Generate random numbers - uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; - const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; - const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); - if constexpr(IsWG32) - { - // Generate the whole 32x32 tile at once (each tile consists of random numbers taken - // from a separate subsequence of Philox) - const unsigned long long ph_subsequence = - bit_cast(make_uint2(wg_m0, wg_n0)); - const index_t ph_offset = get_lane_id(); - const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); - static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); - ph.get_random_16x8(random_uint8_t, ph_subsequence); - } - else - { - // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether - // MIterPerWarp is equal to 1 or 2) - const unsigned long long ph_subsequence = - bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); - const index_t subtile_m0 = wg_m0 % 2; - if constexpr(get_warp_size() == 32) - { - const index_t ph_offset = (get_lane_id() & 15) + - (((get_lane_id() >> 4) & 1) << 5) + - ((wg_n0 % 2) << 4); - const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); - if constexpr(MIterPerWarp == 1) - { - static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); - ph.get_random_8x8( - random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); - } - else - { - static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); - ph.get_random_16x8(random_uint8_t, ph_subsequence); - } - } - else - { - const index_t subtile_n0 = (get_lane_id() >> 4) & 1; - const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); - const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); - if constexpr(MIterPerWarp == 1) - { - static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); - ph.get_random_4x8( - random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); - } - else - { - static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); - ph.get_random_8x8( - random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); - } - } - } - - constexpr auto randval_dist_generated_spans = - decltype(randval_dist_generated)::get_distributed_spans(); - int i_random_idx = 0; - sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; - }); - }); - // Transpose randval using LDS - store_tile(randval_lds_window, randval_dist_generated); - block_sync_lds(); - const auto randval = load_tile(randval_lds_read_window); - block_sync_lds(); - return randval; - }; - - if(is_store_randval) - { - static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - const auto randval = generate_randval(i_m0, i_n0); - // save to Global - const auto randval_store = cast_tile(randval); - store_tile(randval_dram_window, randval_store); - move_tile_window(randval_dram_window, {0, kNPerStep}); - }); - move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); - }); - move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); - } - static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - const auto randval = generate_randval(i_m0, i_n0); - // Drop values of P based on the generated probabilities - constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); - sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto p_idx0 = - tile_distributed_index()>{}; - constexpr auto p_idx1 = - tile_distributed_index(), - idx1.impl_.template at<2>()>{}; - constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t - ? p_compute[p_idx] * rp_undrop - : PComputeDataType(0); - }); - }); - }); - }); - } - - const unsigned long long ph_seed; - const unsigned long long ph_head_offset; - const float rp_undrop; - const uint8_t p_undrop_in_uint8_t; - const bool is_store_randval; -}; - -// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be -// replaced with NullBlockDropout. This requires changes in xformers and other libs. -template -struct BlockDropoutBwd; - -template -struct BlockDropoutBwd -{ - static constexpr bool IsDropout = false; - static constexpr bool IsStoreRandval = IsStoreRandval_; - - template - CK_TILE_HOST_DEVICE static constexpr auto - MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - index_t seqlen_qk_start) - { - (void)randval_dram_block_window_tmp; - (void)seqlen_qk_start; - - return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); - } -}; - -template -struct BlockDropoutBwd -{ - static constexpr bool IsDropout = true; - static constexpr bool IsStoreRandval = IsStoreRandval_; - - CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, - index_t i_head, - index_t nheads, - unsigned long long seed, - unsigned long long offset, - float rp_undrop_, - uint8_t p_undrop_in_uint8_t_) - : ph_seed(amd_wave_read_first_lane(seed)), - ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) * - detail::philox_per_tile)), - rp_undrop(rp_undrop_), - p_undrop_in_uint8_t(p_undrop_in_uint8_t_) - { - } - - template - CK_TILE_HOST_DEVICE static constexpr auto - MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - index_t seqlen_qk_start) - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr bool IsWG32 = WG::kM == 32; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; - constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; - - const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); - auto randval_dram_window = [&]() { - if constexpr(IsFwd) - { - return make_tile_window( - randval_dram_block_window_tmp.get_bottom_tensor_view(), - ck_tile::make_tuple(number{}, number{}), - {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N - } - else - { - return make_tile_window( - randval_dram_block_window_tmp.get_bottom_tensor_view(), - ck_tile::make_tuple(number{}, number{}), - {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N - } - }(); - - return randval_dram_window; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr bool IsWG32 = WG::kM == 32; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; - constexpr index_t NIterPerWarp = 1; - - constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<1, 0>>{}; - - constexpr auto randval_block_inner_part_dstr_encoding = - typename WarpGemmDispatcher::CWarpDstrEncoding{}; - static_assert( - std::is_same_v, - typename WG::CWarpDstrEncoding>); - - constexpr auto randval_block_part_dstr_encode = - detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, - randval_block_inner_part_dstr_encoding); - - return make_static_tile_distribution(randval_block_part_dstr_encode); - } - - template - CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, - const index_t start_n0_idx, - PComputeWindow& p_compute, - RandValDramWindow& randval_dram_window) const - { - constexpr auto config = - BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr bool IsWG32 = WG::kM == 32; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - using BlockGemmShape = remove_cvref_t; - constexpr index_t kMPerBlock = BlockGemmShape::kM; - constexpr index_t kNPerBlock = BlockGemmShape::kN; - constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1; - constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM; - constexpr index_t kNPerStep = NWarp * WG::kN; - - // register distribute - auto randval_dist_generated = - make_static_distributed_tensor(MakeRandValTileDistribution()); - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() % NWarp; - - auto generate_randval = [&](auto i_m0, auto i_n0) { - // Generate random numbers - uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize]; - const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp; - const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp); - if constexpr(IsWG32) - { - // Generate the whole 32x32 tile at once (each tile consists of random numbers - // taken from a separate subsequence of Philox) - const unsigned long long ph_subsequence = - bit_cast(make_uint2(wg_m0, wg_n0)); - const index_t ph_offset = get_lane_id(); - const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); - static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); - ph.get_random_16x8(random_uint8_t, ph_subsequence); - } - else - { - // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether - // MIterPerWarp is equal to 1 or 2) - const unsigned long long ph_subsequence = - bit_cast(make_uint2(wg_m0 / 2, wg_n0 / 2)); - const index_t subtile_m0 = wg_m0 % 2; - if constexpr(get_warp_size() == 32) - { - const index_t ph_offset = (get_lane_id() & 15) + - (((get_lane_id() >> 4) & 1) << 5) + - ((wg_n0 % 2) << 4); - const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); - if constexpr(MIterPerWarp == 1) - { - static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); - ph.get_random_8x8( - random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1); - } - else - { - static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); - ph.get_random_16x8(random_uint8_t, ph_subsequence); - } - } - else - { - const index_t subtile_n0 = (get_lane_id() >> 4) & 1; - const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4); - const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset); - if constexpr(MIterPerWarp == 1) - { - static_assert(randval_dist_generated.kThreadElementSpaceSize == 4); - ph.get_random_4x8( - random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0); - } - else - { - static_assert(randval_dist_generated.kThreadElementSpaceSize == 8); - ph.get_random_8x8( - random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0); - } - } - } - - constexpr auto randval_dist_generated_spans = - decltype(randval_dist_generated)::get_distributed_spans(); - int i_random_idx = 0; - sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); - randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; - }); - }); - return randval_dist_generated; - }; - - static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { - static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { - const auto randval = generate_randval(i_m0, i_n0); - // Drop values of P based on the generated probabilities, negative sign is used to - // distinguish such values ​​later in bwd pipeline. - constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); - sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { - constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); - constexpr auto p_idx0 = - tile_distributed_index(), - idx0.impl_.template at<1>(), - idx0.impl_.template at<2>()>{}; - constexpr auto p_idx1 = tile_distributed_index{}; - constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); - p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t - ? p_compute[p_idx] - : -p_compute[p_idx]; - }); - }); - // save to Global - if constexpr(IsStoreRandval) - { - const auto randval_store = cast_tile(randval); - store_tile(randval_dram_window, randval_store); - move_tile_window(randval_dram_window, {kMPerStep, 0}); - } - }); - if constexpr(IsStoreRandval) - { - move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); - } - }); - if constexpr(IsStoreRandval) - { - move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); - } - } - - const unsigned long long ph_seed; - const unsigned long long ph_head_offset; - const float rp_undrop; - const uint8_t p_undrop_in_uint8_t; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp b/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp deleted file mode 100644 index 3dd36a712da..00000000000 --- a/include/ck_tile/ops/unified_attention/block/block_position_encoding.hpp +++ /dev/null @@ -1,205 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/unified_attention/block/block_masking.hpp" -#include -#include - -namespace ck_tile { - -enum struct PositionEncodingEnum -{ - NO = 0, - ALIBI = 1, -}; - -/* -VERTICAL: - [0] 1 2 3 4 5 - [0] 1 2 3 4 5 - [0] 1 2 3 4 5 - [0] 1 2 3 4 5 - -TOP_LEFT(but negative): - [0] 1 2 3 4 5 - 1 [0] 1 2 3 4 - 2 1 [0] 1 2 3 - 3 2 1 [0] 1 2 - -FROM_BOTTOM_RIGHT(but negative): - 2 1 [0] 1 2 3 - 3 2 1 [0] 1 2 - 4 3 2 1 [0] 1 - 5 4 3 2 1 [0] -*/ - -enum struct AlibiMode -{ - VERTICAL = 0, - FROM_TOP_LEFT = 1, // keep sync with mask enum - FROM_BOTTOM_RIGHT = 2, -}; - -template -struct Alibi -{ - static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32, - "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t"); - - // RowMajor here means if pixel within the same thread are along the row, or col - // this may impact the performance of update(), while the result are the same. - // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false - CK_TILE_HOST_DEVICE Alibi(DataType slope_, - index_t y_total_, - index_t x_total_, - AlibiMode mode_ = AlibiMode::VERTICAL) - { - slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_; - - shift_left_up = [&]() { - if(RowMajor) - { - return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; - } - else - { - return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; - } - }(); - shift_right_down = [&]() { - if(RowMajor) - { - return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; - } - else - { - return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; - } - }(); - mode = mode_; - } - - CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); } - - CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) - { - if constexpr(LogMaxSadOprndSize <= 16) - { - return sad_u16( - static_cast(x), static_cast(y), static_cast(acc)); - } - - return sad_u32(x, y, acc); - } - - CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx) - { - if constexpr(RowMajor) - { - // at least 3 instructions per row - index_t current_zero_point = - mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down; - - // for every threads, most of the pixels are along the row, below operation should be - // the main hot spot. - auto position = type_convert(sad(bit_cast(current_zero_point), - bit_cast(col_idx + shift_left_up), - 0)); - pixel += slope * position; - } - else - { - // at least 3 instructions per col; - index_t current_zero_point = mode == AlibiMode::VERTICAL - ? row_idx + col_idx + shift_right_down - : col_idx + shift_right_down; - - // for every threads, most of the pixels are along the col, below operation should be - // the main hot spot. - auto position = type_convert(sad(bit_cast(current_zero_point), - bit_cast(row_idx + shift_left_up), - 0)); - pixel += slope * position; - } - } - - DataType slope; // float? - index_t shift_left_up; // always possitive - index_t shift_right_down; // always possitive - AlibiMode mode; -}; - -template -struct EmptyPositionEncoding -{ - CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/) - { - } -}; - -// -// can convert from the FA style left/right to our generic coordinate -// if left_size < 0 && right_size = 0, it is normal causal mask -// local is left_size >=0 or right_size >=0 -template -CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, - index_t window_left_size, - index_t window_right_size, - index_t y_total, - index_t x_total, - GenericAttentionMaskEnum mask_enum) -{ - // assume mask_enum will never be NO_MASK, since if we do not have mask, it's - // totally OK to use constexpr - bool is_causal = window_left_size < 0 && window_right_size == 0; - AlibiMode alibi_mode = - is_causal ? AlibiMode::VERTICAL - : static_cast(mask_enum) /*either top-left or bottom-right*/; - return Alibi{slope, y_total, x_total, alibi_mode}; -} - -// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 -// Do we need a device version? -template -CK_TILE_HOST std::vector get_alibi_slopes(ck_tile::index_t nheads) -{ - auto get_slopes_power_of_2 = [](ck_tile::index_t n) { - float start = std::powf( - static_cast(2), - -std::powf(static_cast(2), -static_cast((integer_log2_floor(n) - 3)))); - - std::vector rtn; - for(auto i = 0; i < n; i++) - { - rtn.push_back(static_cast(start * std::powf(start, i))); - } - return rtn; - }; - if(is_power_of_two_integer(nheads)) - { - // power of 2 calculation - return get_slopes_power_of_2(nheads); - } - else - { - ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads); - auto v0 = get_slopes_power_of_2(closest_power_of_2); - auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2); - auto v1_sliced = [&](auto vec, ck_tile::index_t rem) { - std::vector sliced; - for(ck_tile::index_t i = 0; i < static_cast(vec.size()); i++) - { - if(i % 2 == 0) - sliced.push_back(vec[i]); - } - std::vector sliced_2(sliced.begin(), sliced.begin() + rem); - return sliced_2; - }(v1, nheads - closest_power_of_2); - v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end()); - return v0; - } -} -} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp b/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp deleted file mode 100644 index 51732792990..00000000000 --- a/include/ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -namespace ck_tile { - -// This class is used for codegen pattern matching -enum class RotaryEmbeddingEnum -{ - NONE = 0, - INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc - HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc -}; - -template -struct RotaryEmbeddingEnumToStr; - -template <> -struct RotaryEmbeddingEnumToStr -{ - static constexpr const char* name = ""; -}; -template <> -struct RotaryEmbeddingEnumToStr -{ - static constexpr const char* name = "inter"; -}; -template <> -struct RotaryEmbeddingEnumToStr -{ - static constexpr const char* name = "half"; -}; - -template -struct BlockRotaryEmbedding -{ - template - CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile, - OtherDramBlockWindow other_window, - RotaryCosDramBlockWindow rotary_cos_window, - RotarySinDramBlockWindow rotary_sin_window, - index_t rotary_dim, - index_t thread_end) - { - using DataType = typename remove_cvref_t::DataType; - - if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) - { - auto rotary_cos_tile = load_tile(rotary_cos_window); - auto rotary_sin_tile = load_tile(rotary_sin_window); - - if(thread_end <= rotary_dim) - { - constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); - static_for<0, thread_buffer_size, 2>{}([&](auto idx) { - const auto left = type_convert(tile.thread_buf_[idx]); - const auto right = type_convert(tile.thread_buf_[idx + 1]); - - const auto cos = - type_convert(rotary_cos_tile.thread_buf_[idx / 2]); - const auto sin = - type_convert(rotary_sin_tile.thread_buf_[idx / 2]); - - tile.thread_buf_[idx] = type_convert(left * cos - right * sin); - tile.thread_buf_[idx + 1] = type_convert(right * cos + left * sin); - }); - } - } - else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) - { - if(thread_end <= rotary_dim) - { - const bool is_left = (thread_end <= (rotary_dim / 2)); - - move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); - auto other_tile = load_tile(other_window); - - move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); - auto rotary_cos_tile = load_tile(rotary_cos_window); - - move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)}); - auto rotary_sin_tile = load_tile(rotary_sin_window); - - constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size(); - static_for<0, thread_buffer_size, 1>{}([&](auto idx) { - const auto curr = type_convert(tile.thread_buf_[idx]); - const auto other = type_convert(other_tile.thread_buf_[idx]); - - const auto cos = - type_convert(rotary_cos_tile.thread_buf_[idx]); - const auto sin = - type_convert(rotary_sin_tile.thread_buf_[idx]); - - tile.thread_buf_[idx] = - type_convert(curr * cos + other * (is_left ? -sin : sin)); - }); - } - } - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp b/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp deleted file mode 100644 index f1e6101d1d4..00000000000 --- a/include/ck_tile/ops/unified_attention/block/page_block_navigator.hpp +++ /dev/null @@ -1,358 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/core/tensor/tile_window.hpp" - -namespace ck_tile { - -// assume that we have only 1 page-block/tensor view -template -struct TrivialPageBlockNavigator -{ - using DataType = typename TensorView::DataType; - using WindowOrigin = multi_index<2>; - - CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_) - : tensor_view(tensor_view_) - { - } - - template - CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths, - const WindowOrigin& window_origin) const - { - return make_tuple(/*block_index=*/0, - ck_tile::make_tile_window(tensor_view, window_lengths, window_origin)); - } - - template - CK_TILE_HOST_DEVICE constexpr auto - make_tile_window(const WindowLengths& window_lengths, - const WindowOrigin& window_origin, - const TileDistribution& tile_distribution) const - { - return make_tuple( - /*block_index=*/0, - ck_tile::make_tile_window( - tensor_view, window_lengths, window_origin, tile_distribution)); - } - - template - CK_TILE_HOST_DEVICE static index_t - move_tile_window(index_t /*block_index*/, - TileWindow& tile_window, - const typename remove_cvref_t::BottomTensorIndex& step) - { - ck_tile::move_tile_window(tile_window, step); - - return /*block_index=*/0; - } - - template - CK_TILE_HOST_DEVICE index_t - move_tile_window(index_t /*block_index*/, - TileWindow& tile_window, - const typename remove_cvref_t::BottomTensorIndex& step, - index_t /*id*/) const - { - - ck_tile::move_tile_window(tile_window, step); - return 0; - } - - template - CK_TILE_HOST_DEVICE index_t - prefetch_table_id(index_t /*block_index*/, - TileWindow /*tile_window*/, - const typename remove_cvref_t::BottomTensorIndex& /*step*/) const - { - return -1; - } - - CK_TILE_HOST_DEVICE static constexpr WindowOrigin - to_local_window_origin(const WindowOrigin& global_window_origin) - { - return global_window_origin; - } - - CK_TILE_HOST_DEVICE static constexpr WindowOrigin - to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin) - { - return local_window_origin; - } - - private: - TensorView tensor_view; -}; - -// default page-block navigator, assume that tensor view size is same as page-block size or smaller -// if tile window on last page-block -template -struct PageBlockNavigator -{ - using DataType = DataType_; - static_assert(std::is_same_v); - static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window"); - using WindowOrigin = multi_index<2>; - - CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t* physical_blocks_, - long_index_t block_stride_, - long_index_t fixed_offset_, - const int32_t* physical_block_indices_, - index_t num_blocks_, - index_t page_block_size_, - const TensorView& complete_view_, - const TensorView& last_view_) - : physical_blocks(reinterpret_cast(physical_blocks_)), - block_stride(block_stride_), - fixed_offset(fixed_offset_), - physical_block_indices(physical_block_indices_), - num_blocks(num_blocks_), - page_block_size(page_block_size_), - complete_view(complete_view_), - last_view(last_view_) - { - } - - template - CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths, - const WindowOrigin& window_origin) const - { - const index_t block_index = get_block_index(window_origin); - const WindowOrigin local_window_origin = to_local_window_origin(window_origin); - - auto new_tile_window = - ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, - window_lengths, - local_window_origin); - new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); - - return make_tuple(block_index, new_tile_window); - } - - template - CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths, - const WindowOrigin& window_origin, - const TileDistribution& tile_distribution) const - { - const index_t block_index = get_block_index(window_origin); - const WindowOrigin local_window_origin = to_local_window_origin(window_origin); - - auto new_tile_window = - ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view, - window_lengths, - local_window_origin, - tile_distribution); - new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index)); - - return make_tuple(block_index, new_tile_window); - } - - template - CK_TILE_HOST_DEVICE index_t - move_tile_window(index_t block_index, - TileWindow& tile_window, - const typename remove_cvref_t::BottomTensorIndex& step) const - { - - ck_tile::move_tile_window(tile_window, step); - - const WindowOrigin global_window_origin = - to_global_window_origin(block_index, tile_window.get_window_origin()); - const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); - - const index_t new_block_index = get_block_index(global_window_origin); - /// TODO: only update necessary attributes - tile_window.bottom_tensor_view_.desc_ = - (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); - tile_window.set_window_origin(local_window_origin); - tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); - - return new_block_index; - } - - template - CK_TILE_HOST_DEVICE index_t - move_tile_window(index_t block_index, - TileWindow& tile_window, - const typename remove_cvref_t::BottomTensorIndex& step, - index_t id) const - { - ck_tile::move_tile_window(tile_window, step); - - const WindowOrigin global_window_origin = - to_global_window_origin(block_index, tile_window.get_window_origin()); - const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); - - const index_t new_block_index = get_block_index(global_window_origin); - /// TODO: only update necessary attributes - tile_window.bottom_tensor_view_.desc_ = - (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); - tile_window.set_window_origin(local_window_origin); - if(id >= 0) - tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride + - fixed_offset); - else - tile_window.set_bottom_tensor_view_data_ptr(nullptr); - - return new_block_index; - } - - template - CK_TILE_HOST_DEVICE index_t - prefetch_table_id(index_t block_index, - TileWindow& tile_window, - const typename remove_cvref_t::BottomTensorIndex& step) const - { - auto local_tile_window = tile_window; // not affect origin window - ck_tile::move_tile_window(local_tile_window, step); - - const WindowOrigin global_window_origin = - to_global_window_origin(block_index, local_tile_window.get_window_origin()); - const index_t new_block_index = get_block_index(global_window_origin); - - if(new_block_index < num_blocks) - { - return physical_block_indices[new_block_index]; - } - else - { - return -1; - } - } - - CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const - { - return block_index == num_blocks - 1; - } - - template - CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, - const TileWindow& tile_window) const - { - const index_t origin = tile_window.get_window_origin().at(number{}); - const index_t length = tile_window.get_window_lengths().at(number{}); - return (block_index < num_blocks - 1) && (page_block_size < origin + length); - } - - template - CK_TILE_HOST_DEVICE void - move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const - { - const multi_index<2> step = [&]() { - const index_t origin_diff = (block_index - new_block_index) * page_block_size; - if constexpr(VirtualDim == 0) - { - return make_multi_index(origin_diff, 0); - } - else - { - return make_multi_index(0, origin_diff); - } - }(); - - /// TODO: only update necessary attributes - tile_window.bottom_tensor_view_.desc_ = - (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); - tile_window.set_window_origin(tile_window.get_window_origin() + step); - tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); - } - - CK_TILE_HOST_DEVICE WindowOrigin - to_local_window_origin(const WindowOrigin& global_window_origin) const - { - if constexpr(VirtualDim == 0) - { - const index_t length = global_window_origin.at(number<0>{}); - const index_t num_complete_blocks = integer_divide_floor(length, page_block_size); - return make_multi_index(length - page_block_size * num_complete_blocks, - global_window_origin.at(number<1>{})); - } - else - { - const index_t length = global_window_origin.at(number<1>{}); - const index_t num_complete_blocks = integer_divide_floor(length, page_block_size); - return make_multi_index(global_window_origin.at(number<0>{}), - length - page_block_size * num_complete_blocks); - } - } - - CK_TILE_HOST_DEVICE WindowOrigin - to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const - { - if constexpr(VirtualDim == 0) - { - return make_multi_index(block_index * page_block_size + - local_window_origin.at(number<0>{}), - local_window_origin.at(number<1>{})); - } - else - { - return make_multi_index(local_window_origin.at(number<0>{}), - block_index * page_block_size + - local_window_origin.at(number<1>{})); - } - } - - private: - CK_TILE_HOST_DEVICE - DataType* get_block_ptr(index_t block_index) const - { - if(block_index < num_blocks) - { - return physical_blocks + physical_block_indices[block_index] * block_stride + - fixed_offset; - } - else - { - return nullptr; - } - } - - CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const - { - return integer_divide_floor(global_window_origin.at(number{}), page_block_size); - } - - DataType* physical_blocks; - long_index_t block_stride; - long_index_t fixed_offset; - - const int32_t* physical_block_indices; - index_t num_blocks; - index_t page_block_size; - - TensorView complete_view; - TensorView last_view; -}; - -template -CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view) -{ - return TrivialPageBlockNavigator(tensor_view); -} - -template -CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t* physical_blocks, - long_index_t block_stride, - long_index_t fixed_offset, - const int32_t* physical_block_indices, - index_t num_blocks, - index_t page_block_size, - const TensorView& complete_view, - const TensorView& last_view) -{ - return PageBlockNavigator(physical_blocks, - block_stride, - fixed_offset, - physical_block_indices, - num_blocks, - page_block_size, - complete_view, - last_view); -} - -} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/block/variants.hpp b/include/ck_tile/ops/unified_attention/block/variants.hpp deleted file mode 100644 index d8b0cdbb86b..00000000000 --- a/include/ck_tile/ops/unified_attention/block/variants.hpp +++ /dev/null @@ -1,302 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include -#include - -#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0 -#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1 - -#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT -#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH -#endif - -#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM -#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0 -#endif - -namespace ck_tile { -namespace internal { -__device__ inline float -exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp) -{ -#if(defined(__gfx90a__) || defined(__gfx94__)) && \ - (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ - CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) - /// NOTICE: Make sure softmax_scale is stored in SGPR - float result, numerator, denominator; - asm volatile( - "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n" - "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n" - "v_rcp_f32_e32 %[denominator], %[denominator]\n" - "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n" - "v_mul_f32_e32 %[result], %[numerator], %[denominator]" - : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result) - : [softmax_scale] "s"(softmax_scale), - [logits] "v"(logits), - [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp)); - return result; -#else - return softmax_scale * logits * rcp(1.f + abs(logits * logits_soft_cap_rcp)); -#endif -} -} // namespace internal - -template -struct StandardAttentionParams -{ - __device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_) - : impl_mask(impl_mask_), sm_scale(sm_scale_) - { - } - - const ImplMask& impl_mask; - float sm_scale; -}; - -template -struct LogitsSoftCapParams -{ - __device__ - LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) - : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) - { - if(0.f < logits_soft_cap) - { - logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap); - } - else - { - logits_soft_cap_rcp = 0.f; - } - - // move computation here to prevent compiler from generating inefficient instruction - // sequence - if constexpr(UseExp2) - { - logits_soft_cap = log2e_v * logits_soft_cap; - logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; - } - } - - __host__ - LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) - : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) - { - if(0.f < logits_soft_cap) - { - logits_soft_cap_rcp = 1.f / logits_soft_cap; - } - else - { - logits_soft_cap_rcp = 0.f; - } - - // move computation here to prevent compiler from generating inefficient instruction - // sequence - if constexpr(UseExp2) - { - logits_soft_cap = log2e_v * logits_soft_cap; - logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; - } - } - - __device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_, - float sm_scale_, - float logits_soft_cap_, - float logits_soft_cap_rcp_) - : impl_mask(impl_mask_), - sm_scale(sm_scale_), - logits_soft_cap(logits_soft_cap_), - logits_soft_cap_rcp(logits_soft_cap_rcp_) - { - // move computation here to prevent compiler from generating inefficient instruction - // sequence - if constexpr(UseExp2) - { - logits_soft_cap = log2e_v * logits_soft_cap; - logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; - } - } - - const ImplMask& impl_mask; - float sm_scale; - float logits_soft_cap; - float logits_soft_cap_rcp; -}; - -struct StandardAttention -{ - __device__ __host__ StandardAttention() = default; - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) const - { - return type_convert(q) * params.sm_scale; - } - - /// NOTICE: For better performance, we simpliy transform thread buffer without calculating - /// qo_idx/kv_idx. - template - __device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params, - T logits, - [[maybe_unused]] uint32_t batch_idx, - /*uint32_t qo_idx, uint32_t kv_idx,*/ - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - return logits; - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, - [[maybe_unused]] uint32_t batch_idx, - uint32_t qo_idx, - uint32_t kv_idx, - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); - } -}; - -template -struct LogitsSoftCap -{ - __device__ __host__ LogitsSoftCap() = default; - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) const - { - if constexpr(UseExp2) - { - return q; - } - else - { - return type_convert(q) * params.sm_scale; - } - } - - /// NOTICE: For better performance, we simpliy transform thread buffer without calculating - /// qo_idx/kv_idx. - template - __device__ __forceinline__ T LogitsTransform(const Params& params, - T logits, - [[maybe_unused]] uint32_t batch_idx, - /*uint32_t qo_idx, uint32_t kv_idx,*/ - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - if constexpr(UseExp2) - { -#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH - return params.logits_soft_cap * - tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); -#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return internal::exp2_soft_sign_impl( - params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); -#endif - } - else - { -#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH - return params.logits_soft_cap * - tanhf(type_convert(logits) * params.logits_soft_cap_rcp); -#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return type_convert(logits) * - rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); -#endif - } - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, - [[maybe_unused]] uint32_t batch_idx, - uint32_t qo_idx, - uint32_t kv_idx, - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); - } -}; - -constexpr uint32_t CUSTOM_MASK = 1U; -constexpr uint32_t SLIDING_WINDOW = 2U; -constexpr uint32_t LOGITS_SOFT_CAP = 4U; -constexpr uint32_t ALIBI = 8U; - -template -struct ComposedAttention -{ - static constexpr bool use_exp2 = UseExp2; - - static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0; - - __device__ __host__ ComposedAttention() = default; - - template - __device__ __forceinline__ T QueryTransform(const Params& params, T q) const - { - if constexpr(use_logits_soft_cap && UseExp2) - { - return q; - } - return type_convert(q) * params.sm_scale; - } - - /// NOTICE: For better performance, we simpliy transform thread buffer without calculating - /// qo_idx/kv_idx. - template - __device__ __forceinline__ T LogitsTransform(const Params& params, - T logits, - [[maybe_unused]] uint32_t batch_idx, - /*uint32_t qo_idx, uint32_t kv_idx,*/ - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - if constexpr(use_logits_soft_cap) - { - if constexpr(UseExp2) - { -#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH - return params.logits_soft_cap * - tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); -#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return internal::exp2_soft_sign_impl( - params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); -#endif - } - else - { -#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH - return params.logits_soft_cap * - tanhf(type_convert(logits) * params.logits_soft_cap_rcp); -#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return type_convert(logits) * - rcp(1.f + - abs(type_convert(logits) * params.logits_soft_cap_rcp)); -#endif - } - } - return logits; - } - - template - __device__ __forceinline__ bool LogitsMask(const Params& params, - [[maybe_unused]] uint32_t batch_idx, - uint32_t qo_idx, - uint32_t kv_idx, - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp index 480d0f3ee92..1a69afad201 100644 --- a/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp +++ b/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp @@ -37,17 +37,16 @@ struct UnifiedAttentionKernel static constexpr bool kPadHeadDimQ = UnifiedAttentionPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = UnifiedAttentionPipeline::kPadHeadDimV; - // TODO add yjese - static constexpr index_t HEAD_SIZE = UnifiedAttentionPipeline::HEAD_SIZE; - static constexpr index_t HEAD_SIZE_PADDED = UnifiedAttentionPipeline::HEAD_SIZE_PADDED; - - // BLOCK_Q = BLOCK_M // num_queries_per_kv - // BLOCK_Q is the block size for q seqlen - /// static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; - static constexpr index_t BLOCK_M = UnifiedAttentionPipeline::BLOCK_M; - static constexpr index_t BLOCK_Q = UnifiedAttentionPipeline::BLOCK_Q; + static constexpr index_t kHeadDim = UnifiedAttentionPipeline::kHeadDim; + static constexpr index_t kHeadDimPadded = UnifiedAttentionPipeline::kHeadDimPadded; + + // kBlockQ = kBlockM // num_queries_per_kv + // kBlockQ is the block size for q seqlen + /// static constexpr index_t kBlockQ = UnifiedAttentionPipeline::kBlockQ; + static constexpr index_t kBlockM = UnifiedAttentionPipeline::kBlockM; + static constexpr index_t kBlockQ = UnifiedAttentionPipeline::kBlockQ; // BLOCK size for K seqlen - static constexpr index_t BLOCK_SIZE = UnifiedAttentionPipeline::BLOCK_SIZE; + static constexpr index_t kPageBlockSize = UnifiedAttentionPipeline::kPageBlockSize; // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size @@ -56,8 +55,8 @@ struct UnifiedAttentionKernel struct UnifiedAttentionCommonKargs { const void* q_ptr; - const void* k_ptr; // [num_blks, page_blk_size, num_kv_heads, head_size] - const void* v_ptr; // [num_blks, page_blk_size, num_kv_heads, head_size] + const void* k_ptr; // [num_blks, page_size, num_kv_heads, head_size] + const void* v_ptr; // [num_blks, page_size, num_kv_heads, head_size] void* o_ptr; ck_tile::index_t num_blks; @@ -72,7 +71,7 @@ struct UnifiedAttentionKernel float scale_v; float scale_out; - ck_tile::index_t page_blk_size; + ck_tile::index_t page_size; ck_tile::index_t total_num_q_blocks; ck_tile::index_t query_stride_0; @@ -113,7 +112,7 @@ struct UnifiedAttentionKernel float scale_k, float scale_v, float scale_out, - ck_tile::index_t page_blk_size, + ck_tile::index_t page_size, ck_tile::index_t total_num_q_blocks, ck_tile::index_t query_stride_0, ck_tile::index_t query_stride_1, @@ -145,7 +144,7 @@ struct UnifiedAttentionKernel scale_k, scale_v, scale_out, - page_blk_size, + page_size, total_num_q_blocks, query_stride_0, query_stride_1, @@ -213,7 +212,6 @@ struct UnifiedAttentionKernel return ck_tile::make_tuple(pid % num_head_kv, pid / num_head_kv); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() @@ -233,30 +231,27 @@ struct UnifiedAttentionKernel const index_t num_queries_per_kv = kargs.num_queries_per_kv; - assert(BLOCK_M / num_queries_per_kv == BLOCK_Q); + assert(kBlockM / num_queries_per_kv == kBlockQ); - // const index_t BLOCK_Q = BLOCK_M / num_queries_per_kv; // for simplicity, batch stride we just modify the pointer // const index_t num_head_q = kargs.num_head_q; - // const index_t num_head_k = num_head_q / num_queries_per_kv; - // divide problem const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs); // grid size is (num_kv_heads, total_num_q_blocks) - // total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + // total_num_q_blocks = q.shape[0] // kBlockQ + num_seqs // q.shape[0] is total number of query tokens across all batches - // one q_block spans BLOCK_Q = BLOCK_M // num_queries_per_kv number of query token groups. + // one q_block spans kBlockQ = kBlockM // num_queries_per_kv number of query token groups. // One query token group shares one kv token const index_t seq_idx = find_seq_idx(kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, - BLOCK_Q, + kBlockQ, true); // which batch - const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / BLOCK_Q + seq_idx; + const index_t q_block_start_idx = kargs.query_start_len_ptr[seq_idx] / kBlockQ + seq_idx; const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx); @@ -268,18 +263,18 @@ struct UnifiedAttentionKernel amd_wave_read_first_lane(cur_batch_in_all_stop_index - cur_batch_in_all_start_index); // TODO check if we get the block size info from pipeline - if(q_block_local_idx * BLOCK_Q >= cur_batch_query_len) + if(q_block_local_idx * kBlockQ >= cur_batch_query_len) { return; } - const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * BLOCK_Q); + const index_t query_pos = amd_wave_read_first_lane(q_block_local_idx * kBlockQ); const index_t seq_len = kargs.seq_lens_ptr[seq_idx]; const index_t context_len = amd_wave_read_first_lane(seq_len - cur_batch_query_len); index_t _max_seq_prefix_len = amd_wave_read_first_lane( - (context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) + 1)); + (context_len + q_block_local_idx * kBlockQ + (kBlockM - 1) + 1)); if(seq_len < _max_seq_prefix_len) { @@ -288,7 +283,7 @@ struct UnifiedAttentionKernel const auto max_seq_prefix_len = _max_seq_prefix_len; const index_t num_blocks = - amd_wave_read_first_lane((max_seq_prefix_len + BLOCK_SIZE - 1) / BLOCK_SIZE); + amd_wave_read_first_lane((max_seq_prefix_len + kPageBlockSize - 1) / kPageBlockSize); // TODO sliding window const index_t num_blocks_start = 0; @@ -316,30 +311,30 @@ struct UnifiedAttentionKernel ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + o_ptr_offset; index_t query_len_padded = - amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q); - // const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0); + amd_wave_read_first_lane(integer_divide_ceil(cur_batch_query_len, kBlockQ) * kBlockQ); + // const bool is_query_len_padded = (cur_batch_query_len % kBlockQ == 0); // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_base = make_naive_tensor_view( q_ptr, - make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), + make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim), make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1), number{}, number<1>{}); const auto q_dram_pad = - pad_tensor_view( // aling seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + pad_tensor_view( // aling seqlen with kBlockQ and head dim with kHeadDimPadded q_dram_base, // block sizes - make_tuple(number{}, 1, HEAD_SIZE_PADDED), + make_tuple(number{}, 1, kHeadDimPadded), sequence{}); // pads to (seq_len_padded, num_head_q, - // HEAD_SIZE_PADDED) + // kHeadDimPadded) const auto q_dram_merged = transform_tensor_view( q_dram_pad, make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), - make_pass_through_transform(HEAD_SIZE_PADDED)), + make_pass_through_transform(kHeadDimPadded)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); // flattens the first two dims, head idx is the fastest @@ -354,46 +349,47 @@ struct UnifiedAttentionKernel // stride for dim 0 (num_queries_per_kv * head_dim, head_dim, 1) auto q_dram_window = make_tile_window(q_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {query_pos * num_queries_per_kv, 0}); const auto k_dram = [&]() { // HEAD dim is skipped as defined in the ptrs const auto k_dram_naive = make_naive_tensor_view( k_ptr, - make_tuple(kargs.num_blks * kargs.page_blk_size, HEAD_SIZE), + make_tuple(kargs.num_blks * kargs.page_size, kHeadDim), make_tuple(kargs.stride_k_cache_1, kargs.stride_k_cache_3), number{}, number<1>{}); - const auto k_dram_pad = pad_tensor_view(k_dram_naive, - // TODO can the BLOCK_SIZE_RAW needs padding? - make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), - sequence{}); + const auto k_dram_pad = + pad_tensor_view(k_dram_naive, + // TODO can the kPageBlockSize_RAW needs padding? + make_tuple(kPageBlockSize, kHeadDimPadded), + sequence{}); return k_dram_pad; }(); auto k_dram_window = make_tile_window( - k_dram, make_tuple(number{}, number{}), {0, 0}); + k_dram, make_tuple(number{}, number{}), {0, 0}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.num_blks * kargs.page_blk_size, HEAD_SIZE), + make_tuple(kargs.num_blks * kargs.page_size, kHeadDim), make_tuple(kargs.stride_v_cache_1, kargs.stride_v_cache_3), number{}, number<1>{}); const auto v_dram_pad = pad_tensor_view(v_dram_naive, - make_tuple(BLOCK_SIZE, HEAD_SIZE_PADDED), + make_tuple(kPageBlockSize, kHeadDimPadded), sequence{}); return v_dram_pad; }(); auto v_dram_window = make_tile_window( - v_dram, make_tuple(number{}, number{}), {0, 0}); + v_dram, make_tuple(number{}, number{}), {0, 0}); FmhaMask mask = [&]() { if constexpr(kHasMask) @@ -409,8 +405,8 @@ struct UnifiedAttentionKernel return FmhaMask{cur_batch_query_len, seq_len}; }(); - const index_t kv_page_size_in_blocks = kargs.page_blk_size / BLOCK_SIZE; - assert(kv_page_size_in_blocks >= 1); // BLOCK_SIZE <= page_blk_size + const index_t kv_page_size_in_blocks = kargs.page_size / kPageBlockSize; + assert(kv_page_size_in_blocks >= 1); // kPageBlockSize <= page_size auto o_acc_tile = [&]() { return UnifiedAttentionPipeline{}(q_dram_window, @@ -430,23 +426,23 @@ struct UnifiedAttentionKernel auto o_dram = [&]() { const auto o_dram_base = make_naive_tensor_view( o_ptr, - make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE), + make_tuple(cur_batch_query_len, num_queries_per_kv, kHeadDim), make_tuple(kargs.output_stride_0, kargs.output_stride_1, 1), number{}, number<1>{}); const auto o_dram_pad = - pad_tensor_view( // aling cu_seqlen with BLOCK_Q and head dim with HEAD_SIZE_PADDED + pad_tensor_view( // aling cu_seqlen with kBlockQ and head dim with kHeadDimPadded o_dram_base, // block sizes - make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED), + make_tuple(kBlockQ, 1, kHeadDimPadded), sequence{}); // pads to (seq_len_padded, num_head_q, - // HEAD_SIZE_PADDED) + // kHeadDimPadded) const auto o_dram_merged = transform_tensor_view( o_dram_pad, make_tuple(make_merge_transform(make_tuple(query_len_padded, num_queries_per_kv)), - make_pass_through_transform(HEAD_SIZE_PADDED)), + make_pass_through_transform(kHeadDimPadded)), make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -455,7 +451,7 @@ struct UnifiedAttentionKernel auto o_dram_window = make_tile_window(o_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {query_pos * num_queries_per_kv, 0}); EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index de7762e1219..1cdafd2429f 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -47,21 +47,22 @@ struct TileUnifiedAttentionShape static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps); - static constexpr index_t BLOCK_M = BlockTile::at( + static constexpr index_t kBlockM = BlockTile::at( number<0>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) - static constexpr index_t BLOCK_Q = BlockTile::at( + static constexpr index_t kBlockQ = BlockTile::at( number<1>{}); // tile size along the flattened batch dimension (: num_queries_per_kv * BS) - // static constexpr index_t BLOCK_M = BlockTile::at(number<1>{}); // tile size along q seqlen * + // static constexpr index_t kBlockM = BlockTile::at(number<1>{}); // tile size along q seqlen * // num_queries_per_kv (q_head//kv_head) - static constexpr index_t BLOCK_SIZE = BlockTile::at(number<2>{}); // BLOCK size for K seqlen - static constexpr index_t HEAD_SIZE = BlockTile::at(number<3>{}); // BLOCK size for K seqlen + static constexpr index_t kPageBlockSize = + BlockTile::at(number<2>{}); // BLOCK size for K seqlen + static constexpr index_t kHeadDim = BlockTile::at(number<3>{}); // BLOCK size for K seqlen // static constexpr index_t kQKHeaddim = // BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at // // once (or repeately load Q as a whole tile) // static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); - static constexpr index_t HEAD_SIZE_PADDED = ceil_to_qualified_tile_length(); + static constexpr index_t kHeadDimPadded = ceil_to_qualified_tile_length(); // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp index 40ec0fd0aa7..8b01a5722da 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp @@ -4,8 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/unified_attention/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 6cc8ee954af..f70819d9282 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -267,15 +267,14 @@ struct UnifiedAttentionPipeline static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; - static constexpr ck_tile::index_t BLOCK_M = UnifiedAttentionShape::BLOCK_M; - static constexpr ck_tile::index_t BLOCK_Q = UnifiedAttentionShape::BLOCK_Q; + static constexpr ck_tile::index_t kBlockM = UnifiedAttentionShape::kBlockM; + static constexpr ck_tile::index_t kBlockQ = UnifiedAttentionShape::kBlockQ; - static constexpr ck_tile::index_t BLOCK_SIZE = UnifiedAttentionShape::BLOCK_SIZE; - static constexpr ck_tile::index_t HEAD_SIZE = UnifiedAttentionShape::HEAD_SIZE; - static constexpr ck_tile::index_t HEAD_SIZE_PADDED = UnifiedAttentionShape::HEAD_SIZE_PADDED; + static constexpr ck_tile::index_t kPageBlockSize = UnifiedAttentionShape::kPageBlockSize; + static constexpr ck_tile::index_t kHeadDim = UnifiedAttentionShape::kHeadDim; + static constexpr ck_tile::index_t kHeadDimPadded = UnifiedAttentionShape::kHeadDimPadded; - static_assert(HEAD_SIZE_PADDED <= 256, - "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kHeadDimPadded <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); // static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDim; @@ -306,9 +305,9 @@ struct UnifiedAttentionPipeline CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { // create another LDS buffer for p - return ck_tile::max(BLOCK_M * HEAD_SIZE_PADDED * sizeof(PDataType), + return ck_tile::max(kBlockM * kHeadDimPadded * sizeof(PDataType), Policy::template GetSmemSize() + - BLOCK_M * BLOCK_SIZE * sizeof(PDataType)); + kBlockM * kPageBlockSize * sizeof(PDataType)); } // for debug only @@ -404,39 +403,39 @@ struct UnifiedAttentionPipeline "wrong!"); static_assert( - BLOCK_M == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - BLOCK_SIZE == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - HEAD_SIZE_PADDED == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - BLOCK_SIZE == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - HEAD_SIZE_PADDED == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kBlockM == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kPageBlockSize == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kHeadDimPadded == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kPageBlockSize == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kHeadDimPadded == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - static_assert(sizeof(SaccDataType) * BLOCK_SIZE * BLOCK_M <= GetSmemSize()); + static_assert(sizeof(SaccDataType) * kPageBlockSize * kBlockM <= GetSmemSize()); auto s_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); - [[maybe_unused]] auto s_lds_window = - make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); + MakeSimpleLdsDesc()); + [[maybe_unused]] auto s_lds_window = make_tile_window( + s_lds, make_tuple(number{}, number{}), {0, 0}); auto p_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + Policy::template GetSmemSize()), - MakeSimpleLdsDesc()); - [[maybe_unused]] auto p_lds_window = - make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); + MakeSimpleLdsDesc()); + [[maybe_unused]] auto p_lds_window = make_tile_window( + p_lds, make_tuple(number{}, number{}), {0, 0}); auto o_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr)), - MakeSimpleLdsDesc()); + MakeSimpleLdsDesc()); [[maybe_unused]] auto o_lds_window = make_tile_window( - o_lds, make_tuple(number{}, number{}), {0, 0}); + o_lds, make_tuple(number{}, number{}), {0, 0}); auto m_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + Policy::template GetSmemSize()), - MakeSimpleLdsDesc1D()); + MakeSimpleLdsDesc1D()); [[maybe_unused]] auto m_lds_window = - make_tile_window(m_lds, make_tuple(number{}), {0}); + make_tile_window(m_lds, make_tuple(number{}), {0}); const index_t warp_group_id = get_warp_id() / 4; @@ -561,8 +560,8 @@ struct UnifiedAttentionPipeline } // TODO check correctness of this - index_t i_total_loops = num_blocks_start; - const index_t PAGE_BLOCK_SIZE = kv_page_size_in_blocks * BLOCK_SIZE; + index_t i_total_loops = num_blocks_start; + const index_t PageSize = kv_page_size_in_blocks * kPageBlockSize; const ck_tile::index_t* block_tables_ptr_ = reinterpret_cast(block_tables_ptr); assert(k_block_idx == v_block_idx); // because of the following line @@ -572,14 +571,14 @@ struct UnifiedAttentionPipeline auto k_dram_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {kv_blk_idx_initial * PAGE_BLOCK_SIZE, 0}, + {kv_blk_idx_initial * PageSize, 0}, Policy::template MakeKDramTileDistribution()); k_dram_window.init_raw(); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {kv_blk_idx_initial * PAGE_BLOCK_SIZE, 0}, + {kv_blk_idx_initial * PageSize, 0}, Policy::template MakeVDramTileDistribution()); v_dram_window.init_raw(); @@ -588,7 +587,7 @@ struct UnifiedAttentionPipeline constexpr index_t k1_loops = 1; static_assert(1 == k0_loops); static_assert(1 == k1_loops); - // static_assert(BLOCK_SIZE == HEAD_SIZE_PADDED); + // static_assert(kPageBlockSize == kHeadDimPadded); constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; static_assert(NumWarpGroups == 2); @@ -673,7 +672,7 @@ struct UnifiedAttentionPipeline // Page block index tracking // const index_t kv_page_size_in_blocks = - // PAGE_BLOCK_SIZE / BLOCK_SIZE; + // PageSize / kPageBlockSize; // index_t kv_block_idx = 0; // only for block 0 and thread if(blockIdx.x == 0 && threadIdx.x == 0) {} @@ -684,8 +683,8 @@ struct UnifiedAttentionPipeline index_t k_page_blk_idx = block_tables_ptr_[block_table_offset + (k_block_idx / kv_page_size_in_blocks)]; k_dram_window.set_window_origin( - {k_page_blk_idx * PAGE_BLOCK_SIZE + - (k_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE, + {k_page_blk_idx * PageSize + + (k_block_idx % kv_page_size_in_blocks) * kPageBlockSize, 0}); }; @@ -697,8 +696,8 @@ struct UnifiedAttentionPipeline index_t v_page_blk_idx = block_tables_ptr_[block_table_offset + (v_block_idx / kv_page_size_in_blocks)]; v_dram_window.set_window_origin( - {v_page_blk_idx * PAGE_BLOCK_SIZE + - (v_block_idx % kv_page_size_in_blocks) * BLOCK_SIZE, + {v_page_blk_idx * PageSize + + (v_block_idx % kv_page_size_in_blocks) * kPageBlockSize, 0}); // we assume that v load is always after k }; @@ -831,21 +830,21 @@ struct UnifiedAttentionPipeline clear_tile(sp(sp_reg_idx).sp_compute); // initialize C gemm_0(sp(sp_reg_idx).sp_compute, get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence<0, (k0_loops - 1) * kHeadDimPadded>{}, + sequence{}), get_slice_tile(kv_tile.k_tile, - sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{})); + sequence<0, (k0_loops - 1) * kHeadDimPadded>{}, + sequence{})); } else { gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, - sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, - sequence{}), + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{}), get_slice_tile(kv_tile.v_tile, - sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, - sequence{})); + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{})); } }; @@ -855,21 +854,21 @@ struct UnifiedAttentionPipeline clear_tile(sp(sp_reg_idx).sp_compute); // initialize C gemm_0(sp(sp_reg_idx).sp_compute, get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{}), + sequence<0, (k0_loops - 1) * kHeadDimPadded>{}, + sequence{}), get_slice_tile(kv_tile.k_tile, - sequence<0, (k0_loops - 1) * HEAD_SIZE_PADDED>{}, - sequence{})); + sequence<0, (k0_loops - 1) * kHeadDimPadded>{}, + sequence{})); } else { gemm_1(o_acc, get_slice_tile(sp(sp_reg_idx).p, - sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, - sequence{}), + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{}), get_slice_tile(kv_tile.v_tile, - sequence<0, (k1_loops - 1) * BLOCK_SIZE>{}, - sequence{})); + sequence<0, (k1_loops - 1) * kPageBlockSize>{}, + sequence{})); fmha_alu0(number<1>{} - sp_reg_idx); } }; @@ -914,9 +913,9 @@ struct UnifiedAttentionPipeline if constexpr(FmhaMask::IsMasking) { bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - i_total_loops * BLOCK_SIZE, - number{}, - number{}); + i_total_loops * kPageBlockSize, + number{}, + number{}); if(need_perpixel_check) { set_tile_if(sp(sp_reg_idx).sp_compute, @@ -925,7 +924,7 @@ struct UnifiedAttentionPipeline const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = - i_total_loops * BLOCK_SIZE + tile_idx.at(number<1>{}); + i_total_loops * kPageBlockSize + tile_idx.at(number<1>{}); return mask.IsOutOfBound(row, col); }); } diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index 3d5b46c1762..3e926470905 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" @@ -92,8 +94,8 @@ struct UnifiedAttentionPipelineDefaultPolicy { using namespace ck_tile; - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -126,8 +128,8 @@ struct UnifiedAttentionPipelineDefaultPolicy { using namespace ck_tile; - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); // 64 @@ -198,8 +200,8 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{}); constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{}); - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -234,9 +236,9 @@ struct UnifiedAttentionPipelineDefaultPolicy typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize, - TileGemmShape, + TileGemmShape, typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>; @@ -280,9 +282,9 @@ struct UnifiedAttentionPipelineDefaultPolicy typename Problem::VDataType, typename Problem::OaccDataType, Problem::kBlockSize, - TileGemmShape, + TileGemmShape, typename Problem::UnifiedAttentionShape::Gemm1BlockWarps, typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>; /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass @@ -319,8 +321,8 @@ struct UnifiedAttentionPipelineDefaultPolicy using namespace ck_tile; // K is always k-major, we use async-copy to load into LDS - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -376,8 +378,8 @@ struct UnifiedAttentionPipelineDefaultPolicy using namespace ck_tile; // K is always k-major, we use async-copy to load into LDS - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -425,8 +427,8 @@ struct UnifiedAttentionPipelineDefaultPolicy { // this function assume K/V can share smem constexpr index_t SingleKSize = [&]() { - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -449,8 +451,8 @@ struct UnifiedAttentionPipelineDefaultPolicy constexpr index_t kKPack = GetSmemKPackK(); static_assert(PixelsPerRow % kKPack == 0); constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kHeadDim; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; static_assert(kNPerBlock % NPerRow == 0); static_assert(kKPerBlock % kKPack == 0); @@ -467,8 +469,8 @@ struct UnifiedAttentionPipelineDefaultPolicy using namespace ck_tile; /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); @@ -524,8 +526,8 @@ struct UnifiedAttentionPipelineDefaultPolicy using namespace ck_tile; /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension - constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::BLOCK_SIZE; - constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::HEAD_SIZE; + constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kPageBlockSize; + constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kHeadDim; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps; constexpr index_t WarpSize = ck_tile::get_warp_size(); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp deleted file mode 100644 index 45a1c8f4b87..00000000000 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_enum.hpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -namespace ck_tile { - -// This class is used for codegen pattern matching -enum class BlockFmhaPipelineEnum -{ - QRKSVS = 0, - QRKSVS_ASYNC, - QSKSVS, - QRKSVS_ASYNC_TRLOAD, -}; - -template -struct BlockFmhaPipelineEnumToStr; - -template <> -struct BlockFmhaPipelineEnumToStr -{ - static constexpr const char* name = "qr"; -}; -template <> -struct BlockFmhaPipelineEnumToStr -{ - static constexpr const char* name = "qr_async"; -}; -template <> -struct BlockFmhaPipelineEnumToStr -{ - static constexpr const char* name = "qs"; -}; - -template <> -struct BlockFmhaPipelineEnumToStr -{ - static constexpr const char* name = "qr_async_trload"; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp index f2caaa23df9..2b655c74b3f 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp" namespace ck_tile { From 284de6f12628d8a16fbc6f8295b020469bce2572 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 11 Dec 2025 14:13:18 +0000 Subject: [PATCH 81/88] Gemm dispatch --- .../pipeline/tile_unified_attention_shape.hpp | 1 - ...fied_attention_pipeline_default_policy.hpp | 48 +++++++++++-------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp index 1cdafd2429f..68c53401cdf 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp @@ -57,7 +57,6 @@ struct TileUnifiedAttentionShape BlockTile::at(number<2>{}); // BLOCK size for K seqlen static constexpr index_t kHeadDim = BlockTile::at(number<3>{}); // BLOCK size for K seqlen - // static constexpr index_t kQKHeaddim = // BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at // // once (or repeately load Q as a whole tile) // static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index 3e926470905..ecbfc5a427d 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -242,31 +242,41 @@ struct UnifiedAttentionPipelineDefaultPolicy typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use - /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here - return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{}; - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use - /// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here - return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{}; - } - }(); + // constexpr auto warp_gemm = []() { + // if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + // /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here + // return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{}; + // } + // else if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + // /// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here + // return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{}; + // } + // }(); + using WarpGemm = + WarpGemmDispatcher{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}), + Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false>; using BlockGemmPolicy = BlockGemmARegBRegCRegV2CustomPolicy< typename Problem::QDataType, typename Problem::KDataType, typename Problem::SaccDataType, typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, - decltype(warp_gemm), + WarpGemm, GemmLoopOrder::MNK>; return BlockGemmARegBRegCRegV2{}; From 4dacd3340c546a17b0a953051497a3f77a4b41a0 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Thu, 11 Dec 2025 14:14:54 +0000 Subject: [PATCH 82/88] dispatcher --- ...ified_attention_pipeline_default_policy.hpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp index ecbfc5a427d..b0f8b26af68 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp @@ -242,24 +242,6 @@ struct UnifiedAttentionPipelineDefaultPolicy typename Problem::UnifiedAttentionShape::Gemm0BlockWarps, typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>; - // constexpr auto warp_gemm = []() { - // if constexpr(std::is_same_v && - // std::is_same_v && - // std::is_same_v) - // { - // /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use - // /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here - // return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{}; - // } - // else if constexpr(std::is_same_v && - // std::is_same_v && - // std::is_same_v) - // { - // /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use - // /// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here - // return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{}; - // } - // }(); using WarpGemm = WarpGemmDispatcher Date: Fri, 12 Dec 2025 09:43:23 +0000 Subject: [PATCH 83/88] Fixes --- example/ck_tile/01_unified_attention/bias.hpp | 114 -- .../01_unified_attention/codegen/__init__.py | 0 .../codegen/cmake_config.py | 5 - .../codegen/cpp_symbol_map.py | 138 -- .../codegen/ops/__init__.py | 0 .../codegen/ops/fmha_batch_prefill.py | 633 ------ .../codegen/ops/fmha_bwd.py | 929 --------- .../codegen/ops/fmha_fwd.py | 783 -------- .../codegen/ops/fmha_fwd_appendkv.py | 376 ---- .../codegen/ops/fmha_fwd_splitkv.py | 885 -------- .../codegen/ops/fmha_pagedkv_prefill.py | 591 ------ .../01_unified_attention/codegen/utils.py | 21 - .../ck_tile/01_unified_attention/generate.py | 132 -- .../script/benchmark_bwd.sh | 20 - .../script/benchmark_fwd_v3.sh | 42 - .../unified_attention_runner.hpp | 1789 ----------------- .../ck_tile/01_unified_attention/utils.hpp | 244 --- .../CMakeLists.txt | 0 .../README.md | 20 +- .../example_unified_attention.cpp | 0 .../unified_attention_d128_bf16_mask.cpp | 0 .../unified_attention_d128_bf16_nmask.cpp | 0 .../unified_attention_d128_fp16_mask.cpp | 0 .../unified_attention_d128_fp16_nmask.cpp | 0 .../mask.hpp | 0 .../misc/gamc.png | Bin .../rotary.hpp | 0 .../script/benchmark_fwd.sh | 2 +- .../script/fmha_bwd_known_fails_gfx90a.txt | 0 .../script/fmha_bwd_known_fails_gfx942.txt | 0 .../script/fmha_bwd_known_fails_gfx950.txt | 0 .../script/fmha_fwd_known_fails_gfx90a.txt | 0 .../script/fmha_fwd_known_fails_gfx942.txt | 0 .../script/fmha_fwd_known_fails_gfx950.txt | 0 .../script/run_full_test.sh | 0 .../script/smoke_test_bwd.sh | 0 .../script/smoke_test_fwd.sh | 0 .../unified_attention.cpp | 0 .../unified_attention.hpp | 0 .../unified_attention_impl.hpp | 0 example/ck_tile/CMakeLists.txt | 3 +- .../pipeline/unified_attention_pipeline.hpp | 270 +-- 42 files changed, 43 insertions(+), 6954 deletions(-) delete mode 100644 example/ck_tile/01_unified_attention/bias.hpp delete mode 100644 example/ck_tile/01_unified_attention/codegen/__init__.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/cmake_config.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/cpp_symbol_map.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/ops/__init__.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_batch_prefill.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_bwd.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_appendkv.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_splitkv.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/ops/fmha_pagedkv_prefill.py delete mode 100644 example/ck_tile/01_unified_attention/codegen/utils.py delete mode 100644 example/ck_tile/01_unified_attention/generate.py delete mode 100755 example/ck_tile/01_unified_attention/script/benchmark_bwd.sh delete mode 100755 example/ck_tile/01_unified_attention/script/benchmark_fwd_v3.sh delete mode 100644 example/ck_tile/01_unified_attention/unified_attention_runner.hpp delete mode 100644 example/ck_tile/01_unified_attention/utils.hpp rename example/ck_tile/{01_unified_attention => 42_unified_attention}/CMakeLists.txt (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/README.md (85%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/example_unified_attention.cpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/instances/unified_attention_d128_bf16_mask.cpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/instances/unified_attention_d128_bf16_nmask.cpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/instances/unified_attention_d128_fp16_mask.cpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/instances/unified_attention_d128_fp16_nmask.cpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/mask.hpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/misc/gamc.png (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/rotary.hpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/benchmark_fwd.sh (96%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/fmha_bwd_known_fails_gfx90a.txt (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/fmha_bwd_known_fails_gfx942.txt (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/fmha_bwd_known_fails_gfx950.txt (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/fmha_fwd_known_fails_gfx90a.txt (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/fmha_fwd_known_fails_gfx942.txt (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/fmha_fwd_known_fails_gfx950.txt (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/run_full_test.sh (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/smoke_test_bwd.sh (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/script/smoke_test_fwd.sh (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/unified_attention.cpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/unified_attention.hpp (100%) rename example/ck_tile/{01_unified_attention => 42_unified_attention}/unified_attention_impl.hpp (100%) diff --git a/example/ck_tile/01_unified_attention/bias.hpp b/example/ck_tile/01_unified_attention/bias.hpp deleted file mode 100644 index c07232a13a9..00000000000 --- a/example/ck_tile/01_unified_attention/bias.hpp +++ /dev/null @@ -1,114 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha.hpp" - -// keep sync with BlockAttentionBiasEnum -enum class bias_enum -{ - no_bias = 0, - elementwise_bias = 1, - alibi = 2, -}; - -struct bias_info -{ - bias_enum type; - /* - * simple dispatch logic - * - * if type == elementwise_bias: - * if rank_info == 0: - * bias is 1*1*s*s - * elif rank_info == 1: - * bias is 1*h*s*s - * elif rank_info == 2: - * bias is b*h*s*s - * - * elif type == alibi: - * if rank_info == 0: - * alibi in 1*h - * elif rank_info == 1: - * alibi in b*h - */ - int rank_info; - - void serialize(std::ostream& os) const - { - if(type == bias_enum::no_bias) - os << "n"; - else if(type == bias_enum::elementwise_bias) - { - os << "e"; - if(rank_info != 0) - { - os << "[" << rank_info << "]"; - } - } - else if(type == bias_enum::alibi) - { - os << "alibi"; - if(rank_info != 0) - { - os << "[" << rank_info << "]"; - } - } - } - - static bias_info decode(std::string str) - { - bias_info info{bias_enum::no_bias, 0}; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string t = str.substr(0, found_0); - std::string v = str.substr(found_0 + 1); - if(t == "e" || t == "elementwise") - { - info.type = bias_enum::elementwise_bias; - info.rank_info = std::stoi(v); - if(info.rank_info < 0 || info.rank_info > 2) - throw std::invalid_argument("invalid bias rank: " + str); - } - else if(t == "a" || t == "alibi") - { - info.type = bias_enum::alibi; - info.rank_info = std::stoi(v); - if(info.rank_info < 0 || info.rank_info > 1) - throw std::invalid_argument("invalid bias rank: " + str); - } - else - { - throw std::invalid_argument("invalid bias value: " + str); - } - } - else if(str == "0" || str == "n") - { - info.type = bias_enum::no_bias; - } - else if(str == "1" || str == "e" || str == "elementwise") - { - info.type = bias_enum::elementwise_bias; - } - else if(str == "2" || str == "a" || str == "alibi") - { - info.type = bias_enum::alibi; - } - else - { - throw std::invalid_argument("invalid bias value: " + str); - } - return info; - } - - friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) - { - bi.serialize(os); - return os; - } -}; diff --git a/example/ck_tile/01_unified_attention/codegen/__init__.py b/example/ck_tile/01_unified_attention/codegen/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/example/ck_tile/01_unified_attention/codegen/cmake_config.py b/example/ck_tile/01_unified_attention/codegen/cmake_config.py deleted file mode 100644 index 03ebfd67021..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/cmake_config.py +++ /dev/null @@ -1,5 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file diff --git a/example/ck_tile/01_unified_attention/codegen/cpp_symbol_map.py b/example/ck_tile/01_unified_attention/codegen/cpp_symbol_map.py deleted file mode 100644 index 81d34484a54..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/cpp_symbol_map.py +++ /dev/null @@ -1,138 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -FWD_DTYPE_MAP = { - "fp32" : "FmhaFwdFp32", - "fp16" : "FmhaFwdFp16", - "bf16" : "FmhaFwdBf16", - "fp8" : "FmhaFwdFp8", - "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16", - "fp8fp32": "FmhaFwdFp8Fp32" -} - -BWD_DTYPE_MAP = { - "fp32": "FmhaBwdFp32", - "fp16": "FmhaBwdFp16", - "bf16": "FmhaBwdBf16" -} - -MASK_IMPL = { - "generic" : "ck_tile::GenericAttentionMask", - "simplified" : "ck_tile::SimplifiedGenericAttentionMask" -} - -_MASK_SIMPLIFIED_MAP = { - "s_no" : "ck_tile::SimplifiedGenericAttentionMask", - "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", -} - -_MASK_MAP = { - "no" : "FmhaMasks::NoMask", - "causal" : "FmhaMasks::CausalMask", - "generic" : "FmhaMasks::GenericMask" -} - -def get_mask_map(mask : str): - if mask == "generic": - return _MASK_MAP - elif mask == "simplified": - return _MASK_SIMPLIFIED_MAP - else: - assert False - return None - -_MASK_CHECK_MAP = { - "no" : "t.mask_type == mask_enum::no_mask", - "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", - "generic" : "t.mask_type == mask_enum::window_generic", -} - -_MASK_SIMPLIFIED_CHECK_MAP = { - "s_no" : "t.mask_type == mask_enum::no_mask", - "s_mask" : "t.mask_type != mask_enum::no_mask", -} - -def get_mask_check_map(mask : str): - if mask == "generic": - return _MASK_CHECK_MAP - elif mask == "simplified": - return _MASK_SIMPLIFIED_CHECK_MAP - else: - assert False - return None - -BIAS_MAP = { - "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", - "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", - "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" -} - -# TODO: this is ugly -BIAS_CHECK_MAP = { - "no" : "bias_enum::no_bias", - "bias" : "bias_enum::elementwise_bias", - "alibi" : "bias_enum::alibi" -} - -DROPOUT_MAP = { - "no" : "ck_tile::BlockDropoutBwd", - "dropout_wg32" : "ck_tile::BlockDropoutBwd", - "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", - "dropout_wg16" : "ck_tile::BlockDropoutBwd", - "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" -} - -DROPOUT_CHECK_MAP = { - "no" : "t.has_dropout == false", - "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false", - "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true", - "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false", - "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true", -} - -ROPE_MAP = { - "no" : "ck_tile::RotaryEmbeddingEnum::NONE", - "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", - "half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED" -} - -ROPE_CHECK_MAP = { - "no" : "rope_enum::none", - "inter" : "rope_enum::interleaved", - "half" : "rope_enum::half_rotated" -} - -MODE_MAP = { - "batch" : "false", - "group" : "true" -} - -LAYOUT_MAP = { - "row" : "true", - "col" : "false" -} - -PIPELINE_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", - "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", -} - -PIPELINE_ENUM_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", - "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", - "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", -} - -BOOL_MAP = { - "t" : "true", - "f" : "false", - True : "true", - False : "false", -} diff --git a/example/ck_tile/01_unified_attention/codegen/ops/__init__.py b/example/ck_tile/01_unified_attention/codegen/ops/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_batch_prefill.py deleted file mode 100644 index e2f69fa49ab..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/ops/fmha_batch_prefill.py +++ /dev/null @@ -1,633 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass, field -import fnmatch -import itertools -from pathlib import Path -from typing import List, Optional, Tuple - -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * - - -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} - -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, - 128: 128, - 256: 256 -} - -FMHA_BATCH_PREFILL_PIPELINE_MAP = { - "qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", -} - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "ck_tile/ops/fmha/block/variants.hpp" -#include "fmha_fwd.hpp" -""" - -FMHA_FWD_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - false, - {F_lse}, - {F_dropout}, - {F_squant}, - {F_occupancy}>; - -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; - -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - false, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; - -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; - -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; - -#include - -template<> -float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_batch_prefill_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp" -FMHA_FWD_API=""" -#include - -namespace {{ -bool get_num_cus(unsigned& num_cu) {{ - int device; - auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device"); - return false; - }} - - hipDeviceProp_t props{{}}; - status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device properties"); - return false; - }} - - num_cu = props.multiProcessorCount; - return true; -}} - -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ - const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; - const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 - - return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{ - float r = -1; - - [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate - - unsigned num_cus; - if (!get_num_cus(num_cus)) {{ - return r; - }} - - [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ - return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); - }}; - -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; - return fmha_batch_prefill_(s, a); - }} -""" - -@dataclass -class CppConstraint: - bool_expr: str = None - - def __str__(self): - if self.bool_expr is None: - return 'true' - else: - return f'{self.bool_expr}' - - def __and__(self, other): - return CppConstraint(f'({str(self)}) && ({str(other)})') - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag : str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - constraint : CppConstraint - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' - - @property - def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - - @property - def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' - else: assert False - - @property - def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False - - @property - def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False - -@dataclass -class FmhaFwdPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' - - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' - - if self.F_dropout == 't' : n += '_dropout' - else: n += '_ndropout' - - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - return n - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], - F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) - -@dataclass -class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) - - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) - -class KernelComponentFactory: - @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - } - else: - return None - - @staticmethod - def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - else: - assert False - return pipelines - -class CustomFactory(KernelComponentFactory): - @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == 'fp16' or dtype == 'bf16': - if 128 in result.keys(): - result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) - return result - -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - - gen = list() - api_pool = FmhaFwdApiPool(mask_impl) - - for dtype in FWD_DTYPE_MAP.keys(): - d = CustomFactory.get_hdim_tile_size_dict(dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): - for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)): - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - if hdim == 192 and tile.F_bn1 == 128: - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': - continue - # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): - continue - k = FmhaFwdKernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'batch' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # Aiter(mha_batch_prefill) integration - elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # aiter::mha_batch_prefill C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' - if not cond: - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) - -def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - -def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_bwd.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_bwd.py deleted file mode 100644 index 7319ef7ea1a..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/ops/fmha_bwd.py +++ /dev/null @@ -1,929 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass -import fnmatch -import itertools -from pathlib import Path -from typing import List, Tuple, Dict, Literal, Any -from collections import defaultdict - -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * -from codegen.utils import update_file - - -FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "fmha_bwd.hpp" -""" - -FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile:: - sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; -using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; -using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; -using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; -using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>; -using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>; -using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>; - -// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape -// G0&G2 -> GSdP -// G1&G3 -> GdKV -// G4 -> GdQ -using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; - -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaBwdTraits<{F_dpad}, - {F_dvpad}, - {F_bias}, - {F_dbias}, - {F_occupancy}>; -using fmha_mask_{F_idx} = {F_mask}; -using fmha_dropout_{F_idx} = {F_dropout}; - -using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< - typename FmhaBwdTypeConfig::QDataType, - typename FmhaBwdTypeConfig::KDataType, - typename FmhaBwdTypeConfig::VDataType, - typename FmhaBwdTypeConfig::GemmDataType, - typename FmhaBwdTypeConfig::LSEDataType, - typename FmhaBwdTypeConfig::AccDataType, - typename FmhaBwdTypeConfig::DDataType, - typename FmhaBwdTypeConfig::BiasDataType, - typename FmhaBwdTypeConfig::RandValOutputDataType, - typename FmhaBwdTypeConfig::ODataType, - typename FmhaBwdTypeConfig::OGradDataType, - typename FmhaBwdTypeConfig::QGradDataType, - typename FmhaBwdTypeConfig::KGradDataType, - typename FmhaBwdTypeConfig::VGradDataType, - typename FmhaBwdTypeConfig::BiasGradDataType, - fmha_bwd_shape_{F_idx}, - {F_mode}, - {F_deterministic}, - fmha_mask_{F_idx}, - fmha_dropout_{F_idx}, - {F_trload}, - fmha_bwd_trait_{F_idx}>; - -using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline; - -using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, - false, - ({F_dpad} > 0)>>; - -using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, - false, - ({F_dvpad} > 0)>>; - -using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, - false, - ({F_dpad} > 0)>>; - -using fmha_bwd_dq_dk_dv_kernel_{F_idx} = - ck_tile::FmhaBwdDQDKDVKernel; - -using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, - {F_dtype}, - {F_mode}, - fmha_mask_{F_idx}, - fmha_dropout_{F_idx}, - {F_bias}, - {F_dbias}, - {F_dpad}, - {F_dvpad}, - {F_deterministic}, - {F_trload}, - {F_maxq}, - {F_bn0}>; - -#include - -template <> -float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template <> -void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, - fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( - ck_tile::stream_config{{s.stream_id_}}); -}} - -template <> -int fmha_bwd_dq_dk_dv_maxq_() -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - return k_::kMaxSeqLenQ; -}} - -template <> -std::string fmha_bwd_dq_dk_dv_get_name_() -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - return k_::GetName(); -}} -""" - -FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" -FMHA_BWD_API=""" -#include - -template -float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - if constexpr (!std::is_same_v) - {{ - if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} - ); - }} - else - {{ - if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} - ); - }} -}} - -template <> -float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ - [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); - float r = -1; -{F_dispatch} - return r; -}} -""" - -def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str: - lines = [ - f"{'if' if if_ == 0 else 'else if'}({F_cond})", - "{", - *[' ' + line for line in F_body.split('\n') if line.strip() != ''], - "}", - ] - return '\n'.join(' ' * indent + line for line in lines) + '\n' - - -FMHA_BWD_API_INNER_DISPATCH=""" -{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; - r = fmha_bwd_>(s, a); - return r; -}} -""" - -# M0 size for 1d kernels (dot/convert) -M0_1D = 64 - -# GEMM0: Q@K=S^T -# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) -# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) -# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) -# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) -# Is it necessary to distinguish between K0~K4? -@dataclass(frozen=True) -class FmhaBwdDQDKDVTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along gemm0 unroll(F_bhdq) - F_bk1 : int # tile size along gemm1 unroll(F_bm0) - F_bk2 : int # tile size along gemm2 unroll(F_bhdv) - F_bk3 : int # tile size along gemm3 unroll(F_bm0) - F_bk4 : int # tile size along gemm4 unroll(F_bn0) - F_bhdq : int # q head_dim - F_bhdv : int # v head_dim - F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 - F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 - F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2 - F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 - F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3 - F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3 - F_rm2 : int # number of warps along q seqlen (block warps) in gemm4 - F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4 - F_rk2 : int # number of warps along k seqlen (not used) in gemm4 - F_wm0 : int # warp size along m in gemm0/gemm2/gemm4 - F_wn0 : int # warp size along n in gemm0/gemm2/gemm4 - F_wk0 : int # warp size along k in gemm0/gemm2/gemm4 - F_wm1 : int # warp size along m in gemm1/gemm3 - F_wn1 : int # warp size along n in gemm1/gemm3 - F_wk1 : int # warp size along k in gemm1/gemm3 - F_occupancy : int # occupancy - max_seq_q : int = 0 - - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}" - -@dataclass(frozen=True) -class FmhaBwdDQDKDVKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_tile : FmhaBwdDQDKDVTileSize - F_dpad : Literal[0, 8 ,1] - F_dvpad : Literal[0, 8 ,1] - F_bias : str # - F_dbias : str # - F_dropout : str # - F_mask : str # value from MASK_MAP - F_mode : str # value from MODE_MAP - F_deterministic : str # - mask_impl : str # - F_trload : str # - - @property - def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = BWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bk1 = self.F_tile.F_bk1, - F_bk2 = self.F_tile.F_bk2, - F_bk3 = self.F_tile.F_bk3, - F_bk4 = self.F_tile.F_bk4, - F_bhdq = self.F_tile.F_bhdq, - F_bhdv = self.F_tile.F_bhdv, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_rm2 = self.F_tile.F_rm2, - F_rn2 = self.F_tile.F_rn2, - F_rk2 = self.F_tile.F_rk2, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_dpad = self.F_dpad, - F_dvpad = self.F_dvpad, - F_bias = BIAS_MAP[self.F_bias], - F_dbias = BOOL_MAP[self.F_dbias], - F_dropout = DROPOUT_MAP[self.F_dropout], - F_occupancy = self.F_tile.F_occupancy, - F_mask = get_mask_map(self.mask_impl)[self.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_deterministic = BOOL_MAP[self.F_deterministic], - F_trload = BOOL_MAP[self.F_trload], - F_maxq = self.F_tile.max_seq_q - ) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_dpad : n += f'd{self.F_dpad}' - if self.F_dvpad : n += f'dv{self.F_dvpad}' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_dbias == 't' : n += '_dbias' - else: n += '_ndbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' - - if self.F_dropout != 'no' : n += f'_{self.F_dropout}' - else: n += '_ndropout' - - if self.F_deterministic == 't' : n += '_deterministic' - else: n += '_ndeterministic' - - if self.F_trload == 't' : n += '_trload' - else: n += '_ntrload' - return n - - @property - def filename(self) -> str: - return self.name + ".cpp" - -# TODO: design a more practical way to do it -# this is current supported tile size. -def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: - if dtype == 'fp32' and tr_load == 'f': - return [ - # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, - FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), - ] - elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': - return [ - FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - ] - elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': - return [ - FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), - FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), - FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), - ] - else: - return [] - -FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_bwd_dot_do_o_trait_{F_idx} = - ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>; - -using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< - typename FmhaBwdTypeConfig::ODataType, - typename FmhaBwdTypeConfig::OGradDataType, - typename FmhaBwdTypeConfig::DDataType, - /* BlockSize = M0 = */ 64, - {F_hdim}, - {F_mode}, - fmha_bwd_dot_do_o_trait_{F_idx}>; - -using fmha_bwd_dot_do_o_{F_idx} = - typename ck_tile::BlockFmhaBwdOGradDotO; - -using fmha_bwd_dot_do_o_kernel_{F_idx} = - ck_tile::FmhaBwdOGradDotOKernel; - -using dot_do_o_trait_{F_idx} = - fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; - -#include - -template <> -float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template <> -void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( - ck_tile::stream_config{{s.stream_id_}}); -}} - -template <> -std::string fmha_bwd_dot_do_o_get_name_() -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - return k_::GetName(); -}} -""" - -@dataclass(frozen=True) -class FmhaBwdOGradDotOKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_spad : str # true/false - F_dvpad : str # - F_mode : str # value from MODE_MAP - F_occupancy : int - - @property - def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = BWD_DTYPE_MAP[self.F_dtype], - F_spad = BOOL_MAP[self.F_spad], - F_dvpad = BOOL_MAP[self.F_dvpad], - F_mode = MODE_MAP[self.F_mode], - F_occupancy = self.F_occupancy) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" - if pn != '' : n += f'_{pn}' - else: n += '_npad' - return n - - @property - def filename(self) -> str: - return self.name + ".cpp" - -FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_bwd_convert_dq_trait_{F_idx} = - ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>; - -using fmha_bwd_convert_dq_pipeline_problem_{F_idx} = - ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< - typename FmhaBwdTypeConfig::AccDataType, - typename FmhaBwdTypeConfig::QGradDataType, - /* BlockSize = */ 256, - {F_bm0}, - {F_bn0}, - {F_hdim}, - {F_mode}, - {F_deterministic}, - fmha_bwd_convert_dq_trait_{F_idx}>; - -using fmha_bwd_convert_dq_{F_idx} = - typename ck_tile::BlockFmhaBwdConvertQGrad; - -using fmha_bwd_convert_dq_kernel_{F_idx} = - ck_tile::FmhaBwdConvertQGradKernel; - -using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, - {F_dtype}, - {F_mode}, - {F_spad}, - {F_dpad}, - {F_deterministic}, - {F_bn0}>; - -#include - -template <> -float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template <> -void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, - fmha_bwd_args a) -{{ - using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( - ck_tile::stream_config{{s.stream_id_}}); -}} - -template <> -std::string fmha_bwd_convert_dq_get_name_() -{{ - using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; - return k_::GetName(); -}} -""" - -@dataclass(frozen=True) -class FmhaBwdConvertQGradKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_spad : str # true/false - F_dpad : str # - F_mode : str # value from MODE_MAP - F_occupancy : int # - F_deterministic : str # - disabled : bool # sometimes this kernel is not used - - @property - def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = BWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_bm0, - F_bn0 = self.F_bn0, - F_spad = BOOL_MAP[self.F_spad], - F_dpad = BOOL_MAP[self.F_dpad], - F_mode = MODE_MAP[self.F_mode], - F_occupancy = self.F_occupancy, - F_deterministic = BOOL_MAP[self.F_deterministic]) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dpad == 't' : n += 'd' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" - if pn != '' : n += f'_{pn}' - else: n += '_npad' - if self.F_deterministic == 't' : n += '_deterministic' - else: n += '_ndeterministic' - return n - - @property - def filename(self) -> str: - return self.name + ".cpp" - -@dataclass(frozen=True) -class FmhaBwdApiTrait: - idx : int # this is not a tunable, but a counter to differentiate symbol - # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : int - dtype : str # data type - mode : str # value from MODE_MAP - tile : FmhaBwdDQDKDVTileSize - mask : str - bias : str - dbias : str - dropout : str - spad1d : str # spad for 1d kernels (dot/convert) - dpad : Literal[0, 1, 8] - dvpad : Literal[0, 1, 8] - deterministic : str - mask_impl : str - tr_load : str - - @property - def bm0(self) -> int: - return self.tile.F_bm0 - @property - def bn0(self) -> int: - return self.tile.F_bn0 - @property - def bhdq(self) -> int: - return self.tile.F_bhdq - @property - def bhdv(self) -> int: - return self.tile.F_bhdv - - @property - def scheck(self) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad1d == 't': - return f'a.seqlen_q % {M0_1D} != 0' - else: # self.spad1d == 'f' - return f'a.seqlen_q % {M0_1D} == 0' - - @property - def dcheck(self) -> str: - if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0' - else: return f'a.hdim_q % {self.dpad} == 0' - - @property - def dvcheck(self) -> str: - if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0' - else: return f'a.hdim_v % {self.dvpad} == 0' - - @property - def extra_cond(self) -> str: - if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128: - return "&& (a.seqlen_k <= 256)" - else: - return "" - - @property - def convert_dq_bn0(self) -> int: - return self.tile.F_bn0 if self.deterministic == 't' else 0 - - @property - def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 - - F_dvpad = 't' if self.dvpad else 'f' - return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, - F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) - - @property - def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: - return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, - F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, - F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load) - - @property - def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 - - F_dpad = 't' if self.dpad else 'f' - return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, - F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), - F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) - -class FmhaBwdApiPool: - def __init__(self, mask_impl): - self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))) - - self.mask_impl = mask_impl - - def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: - # TODO: do we need to check duplication? - self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait)) - - @staticmethod - def if_(i: int) -> str: - return 'if' if i == 0 else 'else if' - - def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str: - inners = "" - i = 0 - for trait in traits: - inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], - F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad, - F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, - F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, - F_convert_dq_bn0=trait.convert_dq_bn0) - i += 1 - return inners - - @staticmethod - def trload_sort_key(tf): - return 0 if tf == 't' else 1 # sort 't' before 'f' - - @staticmethod - def max_seq_q_sort_key(max_seq_q): - return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end - - @staticmethod - def max_seq_q_cond(max_seq_q: int) -> str: - if max_seq_q == 0: - return 'true /* no seqlen_q limit */' - else: - return f'a.seqlen_q <= {max_seq_q}' - - @staticmethod - def dtype_cond(dtype: str) -> str: - return f't.data_type.compare("{dtype}") == 0' - - @staticmethod - def hdim_cond(hdim: int) -> str: - return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}' - - @property - def api(self) -> str: - tr_load_cond_map = { - "t": "has_load_tr", - "f": "true /* no trload requirement */" - } - per_tr_load = '' - for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key): - per_max_seq_q = '' - for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key): - per_dtypes = '' - for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]): - per_hdim_case = '' - for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]): - traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim] - inners = self._api_innders(traits) - per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners) - per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case) - per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes) - per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4) - if not per_tr_load: - # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; (void)s ; (void)a; (void)has_load_tr;' - result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) - return result.replace('\n\n', '\n') - -def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: - if filter_list == '': - filter_list = '*@*@*' - filters = filter_list.split('@') - filters.extend(['*'] * (3 - len(filters))) - filter_dot_do_o = filters[0] - filter_convert_dq = filters[1] - filter_dq_dk_dv = filters[2] - - # use dict as ordered set - gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {} - gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {} - gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} - api_pool = FmhaBwdApiPool(mask_impl) - - for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): - tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) - dpad_options = itertools.product(*([[0, 8, 1]] * 2)) - tf = ["t", "f"] - for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product( - tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf): - assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" - hdim = tile.F_bhdq - if (mode == "group") and (spad1d == "f"): - continue - if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0: - continue - if ((bias == "no" or bias == "alibi") and dbias == "t"): - continue - if ("wg32" in dropout): - continue - if tr_load == "t": - continue # tr_load cannot work with dpad or dvpad - else: # tr_load == "f" - # do not generate instance with only 1 of dpad/dvpad being 8 - if dpad != dvpad and dpad == 8: - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load) - - if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): - continue - if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): - continue - if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): - continue - - # Flash attention integration - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - elif receipt == 3: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dpad == dvpad - cond &= deterministic == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'bias'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - cond &= deterministic == "f" - if not cond: - continue - # Aiter (mha_bwd) integration - elif receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - if not cond: - continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - - # fp32 only, all variations - if receipt == 800: - cond = dtype == 'fp32' - cond &= dpad == dvpad - if not cond: - continue - # fp32 only, minimal set of parameters - elif receipt == 801: - cond = dtype == 'fp32' - cond &= hdim in [64, 128] - cond &= dpad == dvpad - cond &= mode == 'batch' - cond &= bias == 'no' - cond &= dropout == 'no' - cond &= mask == 's_no' - cond &= deterministic == "f" - if not cond: - continue - else: - # Don't build fp32 by default - if dtype == 'fp32': - continue - - gen_dot_do_o[t.dot_do_o_kernel] = True - gen_dq_dk_dv[t.dq_dk_dv_kernel] = True - if not t.convert_dq_kernel.disabled: - gen_convert_dq[t.convert_dq_kernel] = True - api_pool.register_dq_dk_dv_traits(t) - - return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) - -def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) - update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api) - for k in kernels_dot_do_o: - update_file(output_dir / k.filename, k.template) - for k in kernels_convert_dq: - update_file(output_dir / k.filename, k.template) - for k in kernels_dq_dk_dv: - update_file(output_dir / k.filename, k.template) - - -def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: - _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( - filter_list, receipt, mask_impl, optdim_list - ) - with file_path.open("a") as f: - for k in kernels_dot_do_o: - f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") - for k in kernels_dq_dk_dv: - f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") - for k in kernels_convert_dq: - f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd.py deleted file mode 100644 index f898d5f7b26..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd.py +++ /dev/null @@ -1,783 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass, field -import fnmatch -import itertools -import os -from pathlib import Path -from typing import List, Optional, Tuple - -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * -from codegen.utils import update_file - - -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} - -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 48 : 48, - 64 : 64, - 96 : 128, - 128: 128, - 192: 192, - 256: 256 -} - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "ck_tile/ops/fmha/block/variants.hpp" -#include "fmha_fwd.hpp" -""" - -FMHA_FWD_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - false, - {F_lse}, - {F_dropout}, - {F_squant}, - {F_occupancy}, - {F_skip}>; - -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; - -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - {F_trload}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; - -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel; - -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; - -#include - -template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" -FMHA_FWD_API=""" -#include - -#include - -namespace {{ -bool get_num_cus(unsigned& num_cus) {{ - int device; - auto status = hipGetDevice(&device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device"); - return false; - }} - - hipDeviceProp_t props{{}}; - status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) {{ - fprintf(stderr, "failed to get device properties"); - return false; - }} - - num_cus = props.multiProcessorCount; - return true; -}} - -unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ - const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; - const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 - - return batch * nheads * num_m_blocks * num_n_blocks; -}} -}} // namespace - -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ - float r = -1; - - [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate - - unsigned num_cus; - if (!get_num_cus(num_cus)) {{ - return r; - }} - - [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ - return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); - }}; - - [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); - -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ -{F_dtype_case} - }} -""" - -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; - return fmha_fwd_(s, a); - }} -""" - -@dataclass -class CppConstraint: - bool_expr: str = None - - def __str__(self): - if self.bool_expr is None: - return 'true' - else: - return f'{self.bool_expr}' - - def __and__(self, other): - return CppConstraint(f'({str(self)}) && ({str(other)})') - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag : str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - skip : str - tr_load : str - constraint : CppConstraint - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' - - @property - def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag in ['qr_async', 'qr_async_trload']: - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr', 'qs']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - - def seqtune(self, max_bm0 : int) -> str: - if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/' - else: - return f'a.seqlen_q <= {self.bm0}' - - @property - def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' - else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' - elif self.pipeline_tag in ['qr', 'qs']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' - elif self.pipeline_tag == 'qr_async_trload': - if self.skpad == 't' : return 'true' - else: return 'true' - else: assert False - - @property - def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False - - @property - def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False - -@dataclass -class FmhaFwdPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP - F_skip : str # true/false - F_trload : str # true/false - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' - - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' - - if self.F_dropout == 't' : n += '_dropout' - else: n += '_ndropout' - - if self.F_skip == 't' : n += '_skip' - else: n += '_nskip' - - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - - if self.F_trload == 't' : n += '_trload' - else: n += '_ntrload' - - return n - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - hdim = trait.hdim, trait.bn1 - if hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][hdim] = list() - - self.pool[trait.dtype][hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - tr_load_cond_map = { - "t": "has_load_tr", - "f": "true" - } - - per_tr_load =str() - for tr_load in ["t", "f"]: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] - max_bm0 = max((t.bm0 for t in traits), default=0) - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) - if not per_tr_load: - # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) - -@dataclass -class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) - - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_skip = BOOL_MAP[self.F_pipeline.F_skip], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_trload = BOOL_MAP[self.F_pipeline.F_trload]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - skip=self.F_pipeline.F_skip, - tr_load=self.F_pipeline.F_trload, - constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) - -class KernelComponentFactory: - # TODO: design a more practical way to do it - # this is current supported tile size per hdim - @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp32': - return { - # bm0, bn0, bk0, bn1, bk1, - ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], - } - elif dtype == 'fp16' or dtype == 'bf16': - return { - (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], - (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - } - elif dtype == 'fp8' or dtype == 'fp8bf16': - return { - (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - } - elif dtype == 'fp8fp32': - return { - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - } - else: - return None - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - @staticmethod - def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - pipelines = [] - if dtype in ['fp32']: - squant = 'f' - for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - elif dtype in ['fp16', 'bf16']: - squant = 'f' - for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - else: - if bias == "bias": - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": - pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) - pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) - if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: - # no need lse/dropout kernels - for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'bf8']: - # TODO - None - else: - assert False - return pipelines - -class CustomFactory(KernelComponentFactory): - @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == 'fp16' or dtype == 'bf16': - if (128, 128) in result.keys(): - result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) - return result - -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - gen = list() - api_pool = FmhaFwdApiPool(mask_impl) - - factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory - - for dtype in FWD_DTYPE_MAP.keys(): - d = factory.get_hdim_tile_size_dict(dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): - for tile, next_tile in zip(tiles, tiles[1:]): - assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0' - for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - if (hdim, hdim_v) == (192, 128): - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': - continue - if dtype != 'fp32': - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): - continue - # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): - continue - k = FmhaFwdKernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - cond &= pipeline.F_skip == 'f' - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - cond &= mode == 'batch' - cond &= pipeline.F_skip == 'f' - cond &= pipeline.F_logits == 'f' - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ['fp16', 'bf16', 'fp8bf16'] - cond &= mode == 'batch' - cond &= pipeline.F_vlayout == 'row' - if dtype == 'fp8bf16': - cond &= hdim == 128 - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ['fp16', 'bf16', 'fp8bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - if dtype == 'fp8bf16': - cond &= hdim == 128 - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16', 'fp8bf16'] - cond &= pipeline.F_vlayout == 'row' - if dtype == 'fp8bf16': - cond &= hdim == 128 - if not cond: - continue - elif receipt == 888: - cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32'] - cond &= pipeline.F_vlayout == 'row' - cond &= hdim == 128 - if not cond: - continue - - # fp32 only, all variations - if receipt == 800: - cond = dtype == 'fp32' - cond &= pipeline.F_skip == 'f' - cond &= pipeline.F_logits == 'f' - if not cond: - continue - # fp32 only, minimal set of parameters - elif receipt == 801: - cond = dtype == 'fp32' - cond &= hdim in [48, 128] - cond &= mode == 'batch' - cond &= pipeline.F_bias == 'no' - cond &= pipeline.F_lse == 'f' - cond &= pipeline.F_dropout == 'f' - cond &= pipeline.F_skip == 'f' - cond &= pipeline.F_logits == 'f' - cond &= pipeline.F_mask == 's_no' - if not cond: - continue - else: - # Don't build fp32 by default - if dtype == 'fp32': - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - update_file(autogen_dir / kernel.filename, kernel.template) - -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) - -def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - -def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_appendkv.py deleted file mode 100644 index 38491b56c40..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_appendkv.py +++ /dev/null @@ -1,376 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass -import fnmatch -import itertools -from pathlib import Path -from typing import List, Optional, Tuple - -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * - -from codegen.ops.fmha_fwd import ( - FmhaFwdApiTrait, - DTYPE_BITS, - FMHA_FWD_KERNEL_HEADER, - FMHA_FWD_API_PER_DTYPE, - FMHA_FWD_API_PER_HDIM_CASE, -) - - -FMHA_FWD_APPENDKV_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_occupancy}>; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - {F_bs}, - {F_bsk}, - {F_bd}, - {F_bdv}, - {F_vlayout}, - {F_rope}, - {F_pagedkv}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< - fmha_pipeline_problem_{F_idx}>; - -using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel; - -using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, - {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; - -#include - -template<> -float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp" -FMHA_FWD_APPENDKV_API=""" -float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{ - using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; - return fmha_fwd_appendkv_(s, a); - }} -""" - -@dataclass -class FmhaFwdAppendKVApiTrait: - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - bs : int # tile size along q seqlen - bsk : int # tile size along k seqlen - bd : int # tile size along qk gemm unroll - bdv : int # tile size along kv gemm unroll - vlayout : str - spad : str - skpad : str - dpad : str - dvpad : str - rope : str # key from ROPE_MAP - pagedkv : str - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\ - f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}' - - @property - def scheck(self) -> str: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/' - else : return f'a.seqlen_q % {self.bs} == 0' - - @property - def skcheck(self) -> str: - # we do not check all the values in a.seqlen_k_ptr - return 'true' - - @property - def dcheck(self) -> str: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bd} == 0' - - @property - def dvcheck(self) -> str: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bdv} == 0' - -@dataclass -class FmhaFwdAppendKVPipeline: - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_rope : str # key from ROPE_MAP - F_pagedkv : str # t/f - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - if self.F_rope != 'no': n += f'_{self.F_rope}' - if self.F_pagedkv == 't': n += '_pagedkv' - return n - -class FmhaFwdAppendKVApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], - F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], - F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) - -@dataclass -class FmhaFwdAppendKVTileSize: - F_bs : int # tile size along q seqlen - F_bsk : int # tile size along k seqlen - F_bd : int # tile size along qk gemm unroll - F_bdv : int # tile size along kv gemm unroll - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - @property - def name(self) -> str: - return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdAppendKVKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_tile : FmhaFwdAppendKVTileSize - F_pipeline : FmhaFwdAppendKVPipeline - mask_impl : str - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_APPENDKV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bs = self.F_tile.F_bs, - F_bsk = self.F_tile.F_bsk, - F_bd = self.F_tile.F_bd, - F_bdv = self.F_tile.F_bdv, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_rope = ROPE_MAP[self.F_pipeline.F_rope], - F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], - F_occupancy = self.F_tile.F_occupancy) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdAppendKVApiTrait: - return FmhaFwdAppendKVApiTrait( - hdim=str(self.F_hdim), - dtype=self.F_dtype, - bs=self.F_tile.F_bs, - bsk=self.F_tile.F_bsk, - bd=self.F_tile.F_bd, - bdv=self.F_tile.F_bdv, - vlayout=self.F_pipeline.F_vlayout, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - rope=self.F_pipeline.F_rope, - pagedkv=self.F_pipeline.F_pagedkv) - -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1), - '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), - '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), - '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), - '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), - '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1) - } - else: - return None - -def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while - # applying rotary embedding, so I just use 't' in inter/half pipelines - for vlayout in ['row', 'col']: - for pagedkv in ["t", "f"]: - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv)) - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv)) - - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv)) - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv)) - - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv)) - pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv)) - elif dtype in ['fp8', 'bf8']: - # rope/paged-kv is not supported - pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: - # TODO - None - else: - assert False - return pipelines - - gen = list() - api_pool = FmhaFwdAppendKVApiPool(mask_impl) - - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str in d.keys(): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - k = FmhaFwdAppendKVKernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - if not cond: - continue - - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' - if not cond: - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) - -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - write_fwd_appendkv_api(api_pool, output_dir) - -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: - _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_splitkv.py deleted file mode 100644 index 281357ef1ec..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/ops/fmha_fwd_splitkv.py +++ /dev/null @@ -1,885 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass -import fnmatch -import itertools -from pathlib import Path -from typing import List, Optional, Tuple, Union - -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * - -from codegen.ops.fmha_fwd import ( - FmhaFwdTileSize, - FmhaFwdApiTrait, - FMHA_FWD_KERNEL_HEADER, - FMHA_FWD_API_PER_DTYPE, - FMHA_FWD_API_PER_HDIM_CASE, -) - - -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} - -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, - 128: 128, - # 160: 160, - 256: 256 -} - -FMHA_FWD_SPLITKV_PIPELINE_MAP = { - "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", - "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", -} - -FMHA_FWD_SPLITKV_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; -using fmha_mask_{F_idx} = {F_mask}; - -namespace {{ -template -struct instance {{ -using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - /*kHasBiasGrad=*/false, - {F_lse}, - {F_squant}, - {F_pagedkv}, - kHasUnevenSplits, - kMergeNumHeadGroupsSeqLenQ, - {F_occupancy}>; - -using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::OaccDataType, - fmha_shape, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - fmha_trait>; - -using fmha_pipeline = {F_pipeline}< - fmha_pipeline_problem>; - -/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving -/// store_tile_raw() data corruption issue -using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, - false, false>>; - -using fmha_kernel = - ck_tile::FmhaFwdSplitKVKernel; - -static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) -{{ - using k_ = fmha_kernel; - auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} -}}; -}} - -using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, - {F_dvpad}>; - -#include - -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wtautological-compare" - -namespace {{ -template -void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ - if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS - && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> - || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ - if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ - instance::run(s, a); - }} else {{ - instance::run(s, a); - }} - }} else {{ - instance::run(s, a); - }} -}} -}} // anonymous namespace - -#pragma clang diagnostic pop - -template<> -void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) -{{ - if constexpr({F_mode} == false) {{ // batch mode - // we don't check every seqlen_k values for kvcache - if (a.seqlen_k_ptr != nullptr) {{ - run_instance(s, a); - // make sure F_bn0 is divisible by F_bk1 - }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - run_instance(s, a); - }} else {{ - run_instance(s, a); - }} - }} else {{ - run_instance(s, a); - }} -}} - -template<> -std::string fmha_fwd_splitkv_get_name_() -{{ - using k_ = instance::fmha_kernel; /// FIXME: choose real kernel type - return k_::GetName(); -}} -""" - -FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -namespace {{ -template -struct instance {{ -using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, - {F_dvpad}, - {F_lse}, - {F_squant}, - kLogMaxSplits, - {F_occupancy}>; - -using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - {F_hdim}, - {F_mode}, - {F_bn1}, - fmha_trait>; - -using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< - fmha_pipeline_problem>; - -/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving -/// store_tile_raw() data corruption issue -using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - false, false>>; - -using fmha_kernel = - ck_tile::FmhaFwdSplitKVCombineKernel; - -static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) -{{ - using k_ = fmha_kernel; - auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} -}}; -}} - -using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1}, - {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; - -#include - -template<> -void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) -{{ - if (a.num_splits <= 8) {{ - instance<3>::run(s, a); - }} else if (a.num_splits <= 16) {{ - instance<4>::run(s, a); - }} else if (a.num_splits <= 32) {{ - instance<5>::run(s, a); - }} else if (a.num_splits <= 64) {{ - instance<6>::run(s, a); - }} else if (a.num_splits <= 128) {{ - instance<7>::run(s, a); - }} -}} - -template<> -std::string fmha_fwd_splitkv_combine_get_name_() -{{ - using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type - return k_::GetName(); -}} -""" - -FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" -FMHA_FWD_SPLITKV_API=""" -#include - -template -float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) -{{ - if(s.log_level_ > 0) - std::cout - << ", " << fmha_fwd_splitkv_get_name_() - << ", " << fmha_fwd_splitkv_combine_get_name_() - << std::flush; - - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} - ); -}} - -float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - - // get combine kernel tile sizes - using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; - constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; - - // make sure we can reuse the padding flags in combine kernels - static_assert({F_bm0} % kM0 == 0); - static_assert({F_bn1} % 32 == 0); - - if (t.has_lse) {{ - if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ - return -1; - }} else {{ - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; - - return fmha_fwd_splitkv_(s, a); - }} - }} else {{ - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; - - return fmha_fwd_splitkv_(s, a); - }} - }} -""" - -@dataclass -class FmhaFwdSplitKVApiTrait: - pipeline_tag : str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - mask : str - logits : str - bias : str # - lse : str # - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - pagedkv : str - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ - f'{self.dvpad}-{self.pagedkv}' - - @property - def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - - @property - def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: - if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0' - else: assert False - - @property - def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False - - @property - def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False - -@dataclass -class FmhaFwdSplitKVPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_squant : str # - F_pagedkv : str # t/f - F_mask : str # value from MASK_MAP - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' - - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' - - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - - if self.F_pagedkv == 't' : n += '_pagedkv' - else: n += '_npagedkv' - return n - -@dataclass -class FmhaFwdSplitKVCombinePipeline: - tag : str - - F_spad : str # true/false - F_dvpad : str # - F_lse : str # - F_squant : str # - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' - - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - return n - -class FmhaFwdSplitKVApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], - F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) - -@dataclass -class FmhaFwdSplitKVCombineTileSize: - F_bn1 : int # tile size along v head_dim - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - @property - def name(self) -> str: - return f"b{self.F_bn1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdSplitKVKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdSplitKVPipeline - mask_impl : str - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_SPLITKV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdSplitKVApiTrait: - return FmhaFwdSplitKVApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - logits=self.F_pipeline.F_logits, - mask=self.F_pipeline.F_mask, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - squant=self.F_pipeline.F_squant, - pagedkv=self.F_pipeline.F_pagedkv, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) - -@dataclass -class FmhaFwdSplitKVCombineKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdSplitKVCombineTileSize - F_pipeline : FmhaFwdSplitKVCombinePipeline - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bn1 = self.F_tile.F_bn1, - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_occupancy = self.F_tile.F_occupancy, - F_mode = MODE_MAP[self.F_mode]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - # '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - } - else: - return None - -def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : FmhaFwdSplitKVCombineTileSize(32, -1), - '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - '96' : FmhaFwdSplitKVCombineTileSize(32, -1), - '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - # '160' : FmhaFwdSplitKVCombineTileSize(32, -1), - '256' : FmhaFwdSplitKVCombineTileSize(32, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - '256' : FmhaFwdSplitKVCombineTileSize(32, -1), - } - else: - return None - -def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: - Pipeline = FmhaFwdSplitKVPipeline - Kernel = FmhaFwdSplitKVKernel - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): - pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - - pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - - pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) - elif dtype in ['fp8', 'bf8']: - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask)) - elif dtype in ['fp8fp16', 'fp8bf16']: - # TODO - None - else: - assert False - return pipelines - - gen = list() - api_pool = FmhaFwdSplitKVApiPool(mask_impl) - - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): - continue - k = Kernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # Flash attention integration - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16, bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - cond &= mode == 'batch' - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # aiter::mha_fwd_splikv C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' - if not cond: - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]: - Pipeline = FmhaFwdSplitKVCombinePipeline - Kernel = FmhaFwdSplitKVCombineKernel - - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]): - pipelines.append(Pipeline('unused', spad, dvpad, lse, squant)) - elif dtype in ['fp8', 'bf8']: - # no need lse kernels - pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant)) - else: - assert False - return pipelines - - gen = list() - - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - if mode == "group": - if pipeline.F_spad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - k = Kernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # Aiter(mha_varlen_fwd) integration - if receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue - # aiter::mha_fwd_splikv C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' - if not cond: - continue - - gen.append(k) - - return gen - -def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME - file_path.write_text(api_pool.api) - -def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (2 - len(filter_list))) - - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - write_fwd_splitkv_api(api_pool, output_dir) - -def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (2 - len(filter_list))) - - with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_unified_attention/codegen/ops/fmha_pagedkv_prefill.py deleted file mode 100644 index 3624b7b387e..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/ops/fmha_pagedkv_prefill.py +++ /dev/null @@ -1,591 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import copy -from dataclasses import dataclass -import fnmatch -import itertools -from pathlib import Path -from typing import List, Optional, Tuple - -from codegen.cmake_config import * -from codegen.cpp_symbol_map import * - - -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} - -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, - 128: 128, - 256: 256 -} - -FMHA_FWD_PAGEDKV_PIPELINE_MAP = { - "qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" -} - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "ck_tile/ops/fmha/block/variants.hpp" -#include "fmha_fwd.hpp" -""" - -FMHA_FWD_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, - ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, - ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, - ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, - {F_vlayout}>; - -using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_logits}, - {F_bias}, - false, - {F_lse}, //lse - {F_pagedkv}, //pagedkv - {F_squant}, - {F_occupancy}, - {F_skip}>; - -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; - -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_variant_{F_idx}, - fmha_mask_{F_idx}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; - -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdPagedKVKernel; - -using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; - -#include - -template<> -float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); - const dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp" -FMHA_FWD_API=""" -float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; - return fmha_fwd_pagedkv_(s, a); - }} -""" - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag : str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - pagedkv : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - skip : str - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' - - @property - def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr_pagedkv', 'qs']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - - @property - def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr_pagedkv', 'qs']: - if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0' - else: assert False - - @property - def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr_pagedkv', 'qs']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False - - @property - def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr_pagedkv', 'qs']: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False - -@dataclass -class FmhaFwdPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_pagedkv : str # - F_squant : str # - F_mask : str # value from MASK_MAP - F_skip : str # true/false - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' - - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' - - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' - - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' - - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' - - if self.F_skip == 't' : n += '_skip' - else: n += '_nskip' - - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - - if self.F_pagedkv == 't' : n += '_pagedkv' - else: n += '_npagedkv' - - return n - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) - -@dataclass -class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_skip = BOOL_MAP[self.F_pipeline.F_skip], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - pagedkv=self.F_pipeline.F_pagedkv, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - skip=self.F_pipeline.F_skip) - -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - } - else: - return None - -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - elif dtype in ['fp8', 'bf8']: - # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: - # TODO - None - else: - assert False - return pipelines - - gen = list() - api_pool = FmhaFwdApiPool(mask_impl) - - for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - # if pipeline.F_pagedkv == 'f': - # continue - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - if hdim == 192 and tile.F_bn1 == 128: - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' : - continue - # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): - continue - k = FmhaFwdKernel(F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - cond &= pipeline.F_skip == 'f' - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - cond &= pipeline.F_skip == 'f' - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'batch' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue - - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == 'fp32' - if not cond: - continue - - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) - -def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - -def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_unified_attention/codegen/utils.py b/example/ck_tile/01_unified_attention/codegen/utils.py deleted file mode 100644 index e3bbb18c427..00000000000 --- a/example/ck_tile/01_unified_attention/codegen/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import os.path as path - - -def update_file(file_path, content): - """Update the file at file_path with the given content if it differs from the existing content. - - It avoids unnecessary touching of the file which triggers rebuilds - """ - - existing_content = "" - if path.exists(file_path): - with open(file_path, "r") as file: - existing_content = file.read() - if existing_content == content: - return - with open(file_path, "w") as file: - file.write(content) diff --git a/example/ck_tile/01_unified_attention/generate.py b/example/ck_tile/01_unified_attention/generate.py deleted file mode 100644 index 03173305118..00000000000 --- a/example/ck_tile/01_unified_attention/generate.py +++ /dev/null @@ -1,132 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import argparse -from enum import IntEnum -from pathlib import Path -import pkgutil -import sys -from typing import List, Optional - -import codegen.ops -from codegen.cmake_config import * - - -class HandlerId(IntEnum): - LIST_BLOBS = 0 - WRITE_BLOBS = 1 - -# inspect all modules under 'codegen.ops' and register API handlers -ops = [] -for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): - full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) - ops.append(importer.find_spec(module_name).loader.load_module(module_name)) -unwanted_prefix = 'fmha_' -handlers = dict( - [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, - (op.list_blobs, op.write_blobs)) for op in ops] -) -assert 0 < len(handlers) - -def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: - if output_dir is None: - output_dir = Path(__file__).parent - else: - output_dir = Path(output_dir) / GEN_DIR - - output_dir.mkdir(parents=True, exist_ok=True) - - for api, kernel_filter in zip(api_list, filters_list): - handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) - -# list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: - assert output_file is not None - file_path = Path(output_file) - - # create an empty file / drop its contents if it exists - open(file_path, "w").close() - - for api, kernel_filter in zip(api_list, filters_list): - handler = handlers[api][HandlerId.LIST_BLOBS] - handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog="generate", - description="gen API for CK fmha kernel", - ) - parser.add_argument( - "-d", - "--direction", # we keep 'direction' option for backward compatibility - "-a", - "--api", - default='fwd', - required=False, - help="supply API(s) to generate (default: fwd). separated by comma." - ) - parser.add_argument( - "-o", - "--output_dir", - required=False, - help="write all the blobs into a directory" - ) - parser.add_argument( - "-l", - "--list_blobs", - required=False, - help="list all the kernels to a file" - ) - # TODO: if using filter, must apply same value to output_dir and list_blobs - parser.add_argument( - "-f", - "--filter", - default='', - required=False, - help="filter out kernels that need to generate, using fnmatch module" - ) - - parser.add_argument( - "-m", - "--mask", - default="simplified", - required=False, - help="mask implementation, simplified/generic" - ) - - parser.add_argument( - "-r", - "--receipt", - default=0, - required=False, - help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ - " 1: generate more instance to cover all hdim\n" + \ - " 2: Only generate instance for Flash attention integration\n" + \ - " 4: Only generate instance for PyTorch integration\n" + \ - " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ - " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ - " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ - " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ - " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" - ) - - parser.add_argument( - "--optdim", - default='-1', - required=False, - help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ - "eg. --optdim=32,64,128,256" - ) - - args = parser.parse_args() - api_list = args.direction.split(',') - filter_list = args.filter.split(',') - filter_list.extend([''] * (len(api_list) - len(filter_list))) - optdim_list = [int(hdim) for hdim in args.optdim.split(',')] - - if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) - else: - write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/01_unified_attention/script/benchmark_bwd.sh b/example/ck_tile/01_unified_attention/script/benchmark_bwd.sh deleted file mode 100755 index cfd792906ce..00000000000 --- a/example/ck_tile/01_unified_attention/script/benchmark_bwd.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/sh -# TODO: run this script from CK root or build directory -EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" -VALID=0 - -for prec in "fp16" "bf16" ; do -for perm in 0 1 ; do -for hdim in 32 64 128 ; do - -nhead=$((2048 / $hdim)) # follow fav2 setup -$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 - -done -done -done diff --git a/example/ck_tile/01_unified_attention/script/benchmark_fwd_v3.sh b/example/ck_tile/01_unified_attention/script/benchmark_fwd_v3.sh deleted file mode 100755 index a3f7d68eb3a..00000000000 --- a/example/ck_tile/01_unified_attention/script/benchmark_fwd_v3.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/sh -# TODO: run this script from CK root or build directory -EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)" -VALID=0 - -for causal in 0 1 ; do -for prec in "fp16" "bf16" ; do -for hdim in 128 ; do -for perm in 0 ; do - -$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID - -$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID - -done -done -done -done - -# Padding benchmark comparisons for v3 (batch mode only) -# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== -prec="fp16" -base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" - -# baseline (no pad) -$EXE $base_v3_args - -# low pad (≈90–95% effective) -$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 - -# medium pad (≈60–75% effective) -$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 - -# high pad (≈30–40% effective) -$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_unified_attention/unified_attention_runner.hpp b/example/ck_tile/01_unified_attention/unified_attention_runner.hpp deleted file mode 100644 index 7da84b8a927..00000000000 --- a/example/ck_tile/01_unified_attention/unified_attention_runner.hpp +++ /dev/null @@ -1,1789 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/host.hpp" -#include "ck_tile/ref/naive_attention.hpp" -#include "unified_attention.hpp" -#include "utils.hpp" -#include "ck_tile/utility/json_dump.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API -#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()" -#endif - -enum class fwd_result -{ - success, - failure, - invalid_args, - no_instance, -}; - -// different threshold for different dtype -template -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-3; - double atol = 1e-3; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-5; - double atol = 1e-5; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string /*init_method*/) -{ - using TypeConfig = FmhaFwdTypeConfig; - using ODataType = typename TypeConfig::ODataType; - float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - double rtol = 0; - double atol = 16 * (o_dtype_max > 240 ? 2 : 1); - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-2; - double atol = 1.8e-1; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-2; - double atol = 1.8e-1; - return ck_tile::make_tuple(rtol, atol); -} - -int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) -{ - // If we have enough to almost fill the SMs, then just use 1 split - if(batch_nhead_mblocks >= 0.8f * num_SMs) - { - return 1; - } - max_splits = std::min({max_splits, num_SMs}); - float max_efficiency = 0.f; - std::vector efficiency; - efficiency.reserve(max_splits); - for(int num_splits = 1; num_splits <= max_splits; num_splits++) - { - float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if(eff > max_efficiency) - { - max_efficiency = eff; - } - efficiency.push_back(eff); - } - for(int num_splits = 1; num_splits <= max_splits; num_splits++) - { - if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) - { - // printf("num_splits chosen = %d\n", num_splits); - return num_splits; - } - } - return 1; -} - -int override_num_splits_if_necessary( - int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) -{ - (void)hdim_v; - int device; - auto status = hipGetDevice(&device); - if(status != hipSuccess) - { - return num_splits; - } - - hipDeviceProp_t props{}; - status = hipGetDeviceProperties(&props, device); - if(status != hipSuccess) - { - return num_splits; - } - - // tile size should match the generate.py - const int kM0 = 64; - - const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); - - if(num_splits < 1 && p_drop == 0.0f) - { - return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); - } - - return num_splits; -} - -template -fwd_result fmha_fwd_run(mode_enum mode, - ck_tile::index_t batch, - ck_tile::index_t nhead, - ck_tile::index_t nhead_k, - std::vector seqlen_qs, - std::vector seqlen_ks, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t seqlen_knew, - std::vector seqlen_qpads, - std::vector seqlen_kpads, - std::vector q_eff_lens_per_batch, - std::vector kv_eff_lens_per_batch, - ck_tile::index_t rotary_dim, - bool i_perm, - bool o_perm, - float scale_s, - float logits_soft_cap, - bool is_v_rowmajor, - bool lse, - ck_tile::index_t page_block_size, - bool use_cache_batch_idx, - std::string bias_str, - float p_drop, - uint64_t drop_seed, - uint64_t drop_offset, - bool drop_prefs, - std::string mask_str, - bool squant, - bool is_rotary_interleaved, - ck_tile::index_t num_splits, - std::string init_method, - uint32_t seed, - int do_validation, - const ck_tile::stream_config& stream_config, - std::optional json = std::nullopt) -{ - const std::string data_type = []() { - if constexpr(std::is_same_v) - return "fp32"; - else if constexpr(std::is_same_v) - return "fp16"; - else if constexpr(std::is_same_v) - return "bf16"; - else if constexpr(std::is_same_v) - return "fp8"; - else if constexpr(std::is_same_v) - return "bf8"; - else if constexpr(std::is_same_v) - return "fp8bf16"; - else if constexpr(std::is_same_v) - return "fp8fp32"; - else - static_assert(false); - }(); - - if(nhead_k < 0) - nhead_k = nhead; - if(nhead % nhead_k != 0) - { - std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; - return fwd_result::invalid_args; - } - - std::mt19937 random_engine(seed != 0 ? seed : std::random_device{}()); - auto next_seed = [&random_engine]() { return static_cast(random_engine()); }; - - if(hdim_v < 0) - hdim_v = hdim_q; - -#if !CK_TILE_FMHA_FWD_APPENDKV_API - if(seqlen_knew != 0) - { - std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option" - << std::endl; - seqlen_knew = 0; - } -#endif - if(seqlen_knew < 0) - { - seqlen_knew = randint(1, seqlen_qs[0], random_engine); - } - - if constexpr(!(std::is_same_v || - std::is_same_v)) - { - if(0 < rotary_dim) - { - std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl; - return fwd_result::invalid_args; - } - } -#if !CK_TILE_FMHA_FWD_APPENDKV_API - else if(0 < rotary_dim) - { - std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option" - << std::endl; - rotary_dim = 0; - } -#endif - // to use fmha_fwd_appendkv(), make sure it's in batch mode - const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); - if(need_append_kvcache && mode == mode_enum::group) - { - std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl; - mode = mode_enum::batch; - } - if(!(rotary_dim <= hdim_q)) - { - std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; - return fwd_result::invalid_args; - } - else if(!(rotary_dim % 16 == 0)) - { - std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl; - return fwd_result::invalid_args; - } - -#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ - CK_TILE_FMHA_FWD_PAGEDKV_API)) - if(0 < page_block_size) - { - std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" - << std::endl; - page_block_size = 0; - } -#endif - if(!(page_block_size % 128 == 0)) - { - std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" - << std::endl; - return fwd_result::invalid_args; - } - -#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API) - if(use_cache_batch_idx) - { - std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" - << std::endl; - use_cache_batch_idx = false; - } -#else - if(use_cache_batch_idx) - { - if(0 < page_block_size) - { - std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " - "'cache_batch_idx' option" - << std::endl; - use_cache_batch_idx = false; - } - else if(mode == mode_enum::group) - { - std::cerr << "group mode will not use cache_batch_idx. ignoring the " - "'cache_batch_idx' option" - << std::endl; - use_cache_batch_idx = false; - } - } -#endif - const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); - - // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv) - const bool has_group_padding = - (mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) || - (mode == mode_enum::group && (seqlen_kpads[0] >= 0)); - const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() || - !kv_eff_lens_per_batch.empty())); - const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); - const bool using_pagedkv = (0 < page_block_size); - const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; - if((using_appendkv || using_pagedkv || using_splitkv) && - (has_group_padding || has_batch_efflens)) - { - std::cerr << "Padding (physical or effective lengths) is not supported with " - "appendkv/splitkv/pagedkv pipelines" - << std::endl; - return fwd_result::invalid_args; - } - - std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = - generate_missing_seqlens(mode, - batch, - seqlen_qs, - seqlen_ks, - seqlen_kpads, - /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, - need_append_kvcache, - random_engine); - for(ck_tile::index_t wb = 0; wb < batch; ++wb) - { - if(seqlen_kpads[wb] > 0 && seqlen_kpads[wb] < seqlen_ks[wb]) - { - std::cerr << "kpad must be greater than or equal to seqlen for k" << std::endl; - return fwd_result::invalid_args; - } - } - // compute kvcache seqlen_k (before appending knew/vnew) - auto cache_seqlen_ks = seqlen_ks; - std::transform(cache_seqlen_ks.begin(), - cache_seqlen_ks.end(), - cache_seqlen_ks.begin(), - [&](auto seqlen_k) { return seqlen_k - seqlen_knew; }); - -#if 0 - std::cout << "seqlen_qs: " << seqlen_qs << std::endl; - std::cout << "seqlen_ks: " << seqlen_ks << std::endl; - std::cout << "seqlen_kpads: " << seqlen_kpads << std::endl; - std::cout << "cache_seqlen_ks: " << cache_seqlen_ks << std::endl; -#endif - - if(scale_s == .0f) - scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? - - bias_info bias = bias_info::decode(bias_str); - - mask_info mask = - mask_info::decode(mask_str, seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore - - if(p_drop < 0.0f || p_drop > 1.0f) - { - std::cerr << "The value of p_drop should be 0~1" << std::endl; - return fwd_result::invalid_args; - } - - bool s_randval = false; - if(p_drop > 0.0f && do_validation) - { - s_randval = true; - } - -#if !CK_TILE_FMHA_FWD_SPLITKV_API - if(num_splits != 1) - { - std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl; - num_splits = 1; - } -#endif - - const auto seqstart_q_host = to_seqstarts(seqlen_qs); - const auto seqstart_k_host = to_seqstarts(seqlen_ks); - const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); - - // Optional padded Q seqstarts (group-mode only) - std::vector seqstart_q_with_padding_host; - if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) - { - if(seqlen_qpads.size() < static_cast(batch)) - { - seqlen_qpads.resize(batch, seqlen_qpads.back()); - } - if(seqlen_qpads.size() == static_cast(batch)) - { - seqstart_q_with_padding_host = to_seqstarts( - ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); - } - } - - // Optional batch-mode cumulative seqlen overrides - std::vector cuq_cum, cukv_cum; - if(mode == mode_enum::batch) - { - auto calculate_cumulative = [&](std::vector& per_batch_vec, - std::vector& cum_vec) { - if(!per_batch_vec.empty() && per_batch_vec[0] != -1) - { - if(per_batch_vec.size() < static_cast(batch)) - { - per_batch_vec.resize(batch, per_batch_vec.back()); - } - cum_vec.resize(batch + 1); - cum_vec[0] = 0; - for(int i = 0; i < batch; ++i) - cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; - } - }; - - calculate_cumulative(q_eff_lens_per_batch, cuq_cum); - calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); - } - - using TypeConfig = FmhaFwdTypeConfig; - - using QDataType = typename TypeConfig::QDataType; - using KDataType = typename TypeConfig::KDataType; - using VDataType = typename TypeConfig::VDataType; - using BiasDataType = typename TypeConfig::BiasDataType; - using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; - using LSEDataType = typename TypeConfig::LSEDataType; - using SaccDataType = typename TypeConfig::SaccDataType; - using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; - using PDataType = typename TypeConfig::PDataType; - using OaccDataType = typename TypeConfig::OaccDataType; - using ODataType = typename TypeConfig::ODataType; - - // accumulation numbers for performance evaluation - std::size_t flop = 0, num_byte = 0; - auto max_seqlen_q = - std::numeric_limits::min(); // we will use max seqlen to decide grid size - auto max_seqlen_k = std::numeric_limits::min(); - { - for(ck_tile::index_t wb = 0; wb < batch; ++wb) - { - const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - - if(max_seqlen_q < real_seqlen_q) - { - max_seqlen_q = real_seqlen_q; - } - - if(max_seqlen_k < real_seqlen_k) - { - max_seqlen_k = real_seqlen_k; - } - - flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + - static_cast(2) * mask.get_unmaskarea() * hdim_v); - - num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + - sizeof(ODataType) * real_seqlen_q * hdim_v); - num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + - sizeof(VDataType) * hdim_v * real_seqlen_k); - } - } - - const ck_tile::index_t max_num_page_blocks = - (0 < page_block_size - ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size)) - : 0); - - // legalize num_splits according to other options - if(num_splits < 1) - { - num_splits = override_num_splits_if_necessary( - batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); - } - if(128 < num_splits) - { - std::cerr << "num_splits greater than 128 is not supported" << std::endl; - return fwd_result::invalid_args; - } -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API - if(0 < p_drop && (1 < num_splits || use_kvcache)) - { - std::cerr << "dropout is not supported by split-kv kernels. ignoring the 'p_drop' option" - << std::endl; - p_drop = 0.0f; - } -#endif - - static const auto get_lengths = [](bool permute, - ck_tile::index_t b /*batch*/, - ck_tile::index_t h /*nhead*/, - ck_tile::index_t s /*seqlen*/, - ck_tile::index_t d /*hdim*/) { - if(permute) - return std::array{b, h, s, d}; - else - return std::array{b, s, h, d}; - }; - - // host memory for storing all the tensor elements - const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen - const ck_tile::index_t shape_seqlen_q_lse = - (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); - // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical - const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch - ? seqlen_qs[0] - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() - : seqstart_q_with_padding_host.back())); - const ck_tile::index_t shape_seqlen_k = - (mode == mode_enum::batch ? seqlen_ks[0] - : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() - : seqstart_k_with_padding_host.back())); - - ck_tile::HostTensor q_host( - get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - ck_tile::HostTensor k_host( - 0 < page_block_size - ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) - : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); - /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode - ck_tile::HostTensor knew_host( - 0 < seqlen_knew - ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q) - : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); - ck_tile::HostTensor v_host( - 0 < page_block_size - ? (is_v_rowmajor - ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) - : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) - : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) - : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); - ck_tile::HostTensor vnew_host( - 0 < seqlen_knew - ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) - : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew)) - : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); - ck_tile::HostTensor bias_host( - bias.type == bias_enum::elementwise_bias - ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) - : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); - - ck_tile::HostTensor alibi_slope_host( - bias.type == bias_enum::alibi - ? (bias.rank_info == 0 ? std::array{1, nhead} - : std::array{batch, nhead}) - : std::array{1, 1}); - - auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin( - std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, next_seed()); - - ck_tile::HostTensor lse_acc_host( - 1 < num_splits || use_kvcache - ? std::array{shape_batch, nhead, num_splits, shape_seqlen_q} - : std::array{1, 1, 1, 1}); - ck_tile::HostTensor o_acc_host( - 1 < num_splits || use_kvcache ? std::array{shape_batch, - nhead, - num_splits, - shape_seqlen_q, - hdim_v} - : std::array{1, 1, 1, 1, 1}); - - // batch mode of lse data layout is [batch, nhead, seqlen_q] - // group mode of lse data layout is [nhead, total_seqlen_q] - ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} - : std::array{1, 1, 1} /* dummy shape for simplifying code */); - - ck_tile::HostTensor o_host( - get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); - - ck_tile::HostTensor randval_host( - p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) - : std::array{1, 1, 1, 1}); - - ck_tile::HostTensor block_table_host( - 0 < page_block_size ? std::array{batch, max_num_page_blocks / batch} - : std::array{1, 1}); - - ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx - ? std::array{batch} - : std::array{1}); - float max_o = 5.0; - if(init_method == "ui" || init_method == "0") - { - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(knew_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(vnew_host); - ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}( - bias_host); - } - else if(init_method == "ni") - { - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(k_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(knew_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(vnew_host); - ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}( - bias_host); - } - else if(init_method == "uf" || init_method == "1") - { - ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(q_host); - ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(k_host); - ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(knew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(v_host); - ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(vnew_host); - ck_tile::FillUniformDistribution{0.f, 1.f, next_seed()}(bias_host); - } - else if(init_method == "nf") - { - ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(q_host); - ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(k_host); - ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(knew_host); - ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(v_host); - ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(vnew_host); - ck_tile::FillNormalDistribution{0.f, 3.f, next_seed()}(bias_host); - } - else if(init_method == "tf" || init_method == "2") - { - ck_tile::FillTrigValue{}(q_host); - ck_tile::FillTrigValue{}(k_host); - ck_tile::FillTrigValue{}(knew_host); - ck_tile::FillTrigValue{}(v_host); - ck_tile::FillTrigValue{}(vnew_host); - ck_tile::FillTrigValue{}(bias_host); - } - if(bias.type == bias_enum::alibi) - { - auto slopes = ck_tile::get_alibi_slopes(nhead); - assert(slopes.size() == static_cast(nhead)); - if(bias.rank_info == 0) - { - // alibi in 1*h - std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); - } - else - { - // alibi in b*h - for(auto i_b = 0; i_b < batch; i_b++) - { - std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); - } - } - } - iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); - iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); - - ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() - ? 0 - : seqstart_q_with_padding_host.size() * - sizeof(int32_t)); - ck_tile::DeviceMem seqstart_k_padded_buf( - seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 - : cuq_cum.size() * sizeof(ck_tile::index_t)); - ck_tile::DeviceMem cu_seqlen_kv_buf( - cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); - ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || - 0 <= seqlen_kpads[0] - ? seqlen_ks.size() * sizeof(int32_t) - : 0); - ck_tile::DeviceMem cache_seqlen_k_buf( - need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); - ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); - ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); - ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); - - float scale_p = 1.f; - float scale_o = 1.f; - if(squant) - { - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float p_dtype_max = v_dtype_max; // assume p and v is the same type - // Q tensor - { - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - q_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - - float scale = q_dtype_max / max_value; - - q_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - self(idx) = ck_tile::type_convert(val * scale); - }); - scale_s = scale_s / scale; - } - - // K tensor - { - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - k_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - float scale = k_dtype_max / max_value; - k_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - self(idx) = ck_tile::type_convert(val * scale); - }); - scale_s = scale_s / scale; - } - - // V tensor - { - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - v_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - - float scale = k_dtype_max / max_value; - v_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - self(idx) = ck_tile::type_convert(val * scale); - }); - - scale_o = (1.0 / p_dtype_max) / scale; - } - - scale_p = p_dtype_max; - - if constexpr(std::is_same_v) - { - float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - scale_o = scale_o * o_dtype_max / max_o; - } - } - - q_buf.ToDevice(q_host.data()); - k_buf.ToDevice(k_host.data()); - v_buf.ToDevice(v_host.data()); - knew_buf.ToDevice(knew_host.data()); - vnew_buf.ToDevice(vnew_host.data()); - bias_buf.ToDevice(bias_host.data()); - seqstart_q.ToDevice(seqstart_q_host.data()); - // Keep logical starts in seqstart_k; pass padded K via separate pointer - seqstart_k.ToDevice(seqstart_k_host.data()); - seqstart_q_padded_buf.ToDevice( - seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); - seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr - : seqstart_k_with_padding_host.data()); - cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); - cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); - seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] - ? seqlen_ks.data() - : nullptr); - cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); - rotary_cos_buf.ToDevice(rotary_cos_host.data()); - rotary_sin_buf.ToDevice(rotary_sin_host.data()); - drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); - drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); - alibi_slope_buf.ToDevice(alibi_slope_host.data()); - block_table_buf.ToDevice(block_table_host.data()); - cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); - - // clang-format off - auto layout_str = [&](bool permute){ - if(permute) return std::string("bhsd"); - else return std::string("bshd"); - }; - auto io_layout = [&](bool iperm_, bool operm_) { - if(iperm_ == operm_) return layout_str(iperm_); - else return layout_str(iperm_) + std::string("-") + layout_str(operm_); - }; - // clang-format on - - std::cout << "[" << data_type << "|" << mode << "|" << io_layout(i_perm, o_perm) - << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] - << "/" << seqlen_ks[0] - << (seqlen_kpads[0] < 0 ? "" - : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) - << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias - << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << (is_v_rowmajor ? "r" : "c"); -#if CK_TILE_FMHA_FWD_APPENDKV_API - if(0 < rotary_dim) - { - std::cout << ", rotary_dim:" << rotary_dim << "(" - << (is_rotary_interleaved ? "inter" : "half") << ")"; - } -#endif -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API - if(1 < num_splits) - { - std::cout << ", num_splits:" << num_splits; - } - if(0 < page_block_size) - { - std::cout << ", page_block_size:" << page_block_size; - } - if(use_cache_batch_idx) - { - std::cout << ", cache_batch_idx:" << use_cache_batch_idx; - } -#endif - // Padding / effective length diagnostic logging - auto print_vec = [&](const char* label, const std::vector& v) { - if(v.empty()) - return; - std::cout << ", " << label << ":["; - for(std::size_t i = 0; i < v.size(); ++i) - { - if(i) - std::cout << ","; - std::cout << v[i]; - } - std::cout << "]"; - }; - - if(has_group_padding) - { - bool has_qpad = !seqstart_q_with_padding_host.empty(); - bool has_kpad = (seqlen_kpads[0] >= 0); - if(has_qpad) - { - print_vec("q_logical", seqlen_qs); - print_vec("q_padded", seqlen_qpads); - } - if(has_kpad) - { - print_vec("k_logical", seqlen_ks); - print_vec("k_padded", seqlen_kpads); - } - } - else if(has_batch_efflens) - { - // derive effective lengths from cumulative arrays if present - if(!cuq_cum.empty()) - { - std::vector eff_q(batch); - for(int b_i = 0; b_i < batch; ++b_i) - eff_q[b_i] = static_cast(cuq_cum[b_i + 1] - cuq_cum[b_i]); - print_vec("q_eff", eff_q); - } - if(!cukv_cum.empty()) - { - std::vector eff_kv(batch); - for(int b_i = 0; b_i < batch; ++b_i) - eff_kv[b_i] = static_cast(cukv_cum[b_i + 1] - cukv_cum[b_i]); - print_vec("kv_eff", eff_kv); - } - } - - std::cout << std::flush; - - const auto init_traits = [&](auto& traits) { - traits.hdim_q = hdim_q; - traits.hdim_v = hdim_v; - traits.data_type = data_type; - traits.is_v_rowmajor = is_v_rowmajor; - - if constexpr(std::is_same_v>) - { - traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved - : rope_enum::half_rotated) - : rope_enum::none); - } - else // fmha_fwd_traits or fmha_splitkv_traits - { - traits.is_group_mode = (mode == mode_enum::group); - traits.has_logits_soft_cap = 0.f < logits_soft_cap; - traits.mask_type = mask.type; - traits.bias_type = bias.type; - traits.has_lse = lse; - traits.do_fp8_static_quant = squant; - - if constexpr(std::is_same_v>) - { - traits.has_dropout = (p_drop > 0.0f); - } - else if constexpr(std::is_same_v>) - { - traits.use_pagedkv = (0 < page_block_size); - } - } - }; - - const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { - /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, - /// seqlen_k] in this example, hence both the 'batch_stride_bias' & - /// 'nhead_stride_bias' are 0. - // setup stride_* arguments - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size) - : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); - }(); - const ck_tile::index_t stride_vnew = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return i_perm ? seqlen_knew : nhead_k * seqlen_knew; - }(); - const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); - const ck_tile::index_t stride_randval = (max_seqlen_k); - const ck_tile::index_t stride_o_acc = (hdim_v); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = - (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q) - : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); - const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_v = [&]() { - if(is_v_rowmajor) - return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) - : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); - else - return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size) - : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); - }(); - const ck_tile::index_t nhead_stride_vnew = [&]() { - if(is_v_rowmajor) - return i_perm ? seqlen_knew * hdim_v : hdim_v; - else - return i_perm ? hdim_v * seqlen_knew : seqlen_knew; - }(); - const ck_tile::index_t nhead_stride_bias = - (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); - const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = - (0 < page_block_size ? (nhead_k * page_block_size * hdim_q) - : (nhead_k * shape_seqlen_k * hdim_q)); - const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); - const ck_tile::index_t batch_stride_v = - (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) - : (nhead_k * hdim_v * shape_seqlen_k)); - const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); - const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); - // setup split_stride_* arguments (only used in split-kv kernel) - const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); - const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); - - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.v_ptr = v_buf.GetDeviceBuffer(); - - args.batch = batch; - args.seqlen_q = shape_seqlen_q; // unused in group mode - args.hdim_q = hdim_q; - args.hdim_v = hdim_v; - args.nhead_q = nhead; - args.nhead_k = nhead_k; - - args.stride_q = stride_q; - args.stride_k = stride_k; - args.stride_v = stride_v; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.nhead_stride_v = nhead_stride_v; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.batch_stride_v = batch_stride_v; - - if constexpr(std::is_same_v>) - { - args.knew_ptr = knew_buf.GetDeviceBuffer(); - args.vnew_ptr = vnew_buf.GetDeviceBuffer(); - args.seqlen_knew = seqlen_knew; - - args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer(); - - args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr); - args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr); - args.rotary_dim = rotary_dim; - args.has_mask = (mask.type != mask_enum::no_mask); - - args.block_table_ptr = - (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); - args.batch_stride_block_table = batch_stride_block_table; - args.page_block_size = page_block_size; - - args.cache_batch_idx = - (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); - - args.stride_knew = stride_knew; - args.stride_vnew = stride_vnew; - args.nhead_stride_knew = nhead_stride_knew; - args.nhead_stride_vnew = nhead_stride_vnew; - args.batch_stride_knew = batch_stride_knew; - args.batch_stride_vnew = batch_stride_vnew; - } - else // fmha_fwd_args or fmha_fwd_splitkv_args - { - args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() - : bias_buf.GetDeviceBuffer(); - args.lse_ptr = lse_buf.GetDeviceBuffer(); - args.o_ptr = o_buf.GetDeviceBuffer(); - - args.seqstart_q_ptr = - (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); - args.seqstart_k_ptr = - (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] - ? seqlen_k_buf.GetDeviceBuffer() - : nullptr); - - args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) - args.max_seqlen_q = max_seqlen_q; - - args.scale_s = scale_s; - args.scale_p = scale_p; - args.scale_o = scale_o; - - args.logits_soft_cap = logits_soft_cap; - - args.stride_bias = - (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); - args.stride_o = stride_o; - args.nhead_stride_bias = nhead_stride_bias; - args.nhead_stride_lse = nhead_stride_lse; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_bias = batch_stride_bias; - args.batch_stride_lse = batch_stride_lse; - args.batch_stride_o = batch_stride_o; - - args.window_size_left = mask.left; - args.window_size_right = mask.right; - args.mask_type = static_cast(mask.type); - - if constexpr(std::is_same_v>) - { - args.rand_val_ptr = randval_buf.GetDeviceBuffer(); - - args.stride_randval = stride_randval; - args.nhead_stride_randval = nhead_stride_randval; - args.batch_stride_randval = batch_stride_randval; - - args.p_drop = p_drop; - args.s_randval = s_randval; - if(drop_prefs) - { - args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), - drop_offset_buf.GetDeviceBuffer()); - } - else - { - args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); - } - - // Group-mode: optional physical padded starts for Q/K - if(mode == mode_enum::group) - { - args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() - ? nullptr - : seqstart_q_padded_buf.GetDeviceBuffer()); - args.seqstart_padded_k_ptr = - (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); - } - - // Batch-mode: optional cumulative effective seqlen overrides - if(mode == mode_enum::batch) - { - args.cu_seqlen_q_ptr = cuq_cum.empty() - ? nullptr - : reinterpret_cast( - cu_seqlen_q_buf.GetDeviceBuffer()); - args.cu_seqlen_kv_ptr = cukv_cum.empty() - ? nullptr - : reinterpret_cast( - cu_seqlen_kv_buf.GetDeviceBuffer()); - } - } - else if constexpr(std::is_same_v>) - { - args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); - args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); - - args.block_table_ptr = - (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); - args.batch_stride_block_table = batch_stride_block_table; - args.page_block_size = page_block_size; - args.is_gappy = false; // use 'false' for flash-attention integration - - args.cache_batch_idx = - (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); - - args.num_splits = num_splits; - - args.stride_o_acc = stride_o_acc; - args.nhead_stride_lse_acc = nhead_stride_lse_acc; - args.nhead_stride_o_acc = nhead_stride_o_acc; - args.batch_stride_lse_acc = batch_stride_lse_acc; - args.batch_stride_o_acc = batch_stride_o_acc; - args.split_stride_lse_acc = split_stride_lse_acc; - args.split_stride_o_acc = split_stride_o_acc; - } - else if constexpr(std::is_same_v>) - { - args.block_table_ptr = - (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); - args.batch_stride_block_table = batch_stride_block_table; - args.page_block_size = page_block_size; - args.is_gappy = false; // use 'false' for flash-attention integration - - args.cache_batch_idx = - (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); - } - } - }; - - auto run_appendkv = [&](const ck_tile::stream_config& sc) { -#if CK_TILE_FMHA_FWD_APPENDKV_API - if(need_append_kvcache) - { - fmha_fwd_appendkv_traits fwd_appendkv_traits; - init_traits(fwd_appendkv_traits); - - fmha_fwd_appendkv_args fwd_appendkv_args; - init_args(fwd_appendkv_args); - - return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc); - } -#endif - return 0.0f; - }; - const float appendkv_ave_time = run_appendkv(stream_config); - if(appendkv_ave_time < 0.0f) - { - std::cout << ", not supported yet" << std::flush << std::endl; - return fwd_result::no_instance; - } - - auto run_fwd = [&](const ck_tile::stream_config& sc) { -#if CK_TILE_FMHA_FWD_PAGEDKV_API - if(1 == num_splits && use_kvcache) - { - fmha_fwd_pagedkv_traits fmha_pagedkv_traits; - init_traits(fmha_pagedkv_traits); - - fmha_fwd_pagedkv_args fmha_pagedkv_args; - init_args(fmha_pagedkv_args); - - const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc); -#if CK_TILE_FMHA_FWD_SPLITKV_API - // If there is no instance for these args, fallback to fmha_fwd_splitkv - if(ave_time >= 0.0f) - return ave_time; -#else - return ave_time; -#endif - } -#endif // CK_TILE_FMHA_FWD_PAGEDKV_API -#if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < num_splits || use_kvcache) - { - fmha_fwd_splitkv_traits fmha_splitkv_traits; - init_traits(fmha_splitkv_traits); - - fmha_fwd_splitkv_args fmha_splitkv_args; - init_args(fmha_splitkv_args); - - return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc); - } -#endif // CK_TILE_FMHA_FWD_SPLITKV_API - fmha_fwd_traits fmha_traits; - init_traits(fmha_traits); - - fmha_fwd_args fmha_args; - init_args(fmha_args); - - return fmha_fwd(fmha_traits, fmha_args, sc); - }; - const float fwd_ave_time = run_fwd(stream_config); - if(fwd_ave_time < 0.0f) - { - std::cout << ", not supported yet" << std::flush << std::endl; - return fwd_result::no_instance; - } - - const float ave_time = appendkv_ave_time + fwd_ave_time; - const float tflops = static_cast(flop) / 1.E9 / ave_time; - const float gb_per_sec = num_byte / 1.E6 / ave_time; - if(stream_config.time_kernel_) - { - std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " - << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) - << gb_per_sec << " GB/s" << std::flush; - } - - bool pass = true; - if(do_validation == 0) - { - std::cout << std::flush << std::endl; - } - else if(do_validation == 2) - { - // NOTE: use gpu to do validation - ck_tile::naive_attention_fwd_traits naive_t; - naive_t.q_type = data_type; - naive_t.k_type = data_type; - naive_t.v_type = data_type; - naive_t.o_type = data_type; - naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; - naive_t.variation = 0; // TODO? - naive_t.quant_algo = 0; - - ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); - - ck_tile::naive_attention_fwd_args naive_a; - naive_a.q_ptr = q_buf.GetDeviceBuffer(); - naive_a.k_ptr = k_buf.GetDeviceBuffer(); - naive_a.v_ptr = v_buf.GetDeviceBuffer(); - naive_a.o_ptr = o_naive_buf.GetDeviceBuffer(); - naive_a.scale_s = scale_s; - naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer - naive_a.page_table_ptr = - nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn) - naive_a.hdim = hdim_q; - naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different - naive_a.batch_q = batch; - naive_a.batch_kv = batch; - naive_a.batch_ratio_kv = 1; // batch_q / batch_kv - naive_a.seqlen_q = seqlen_qs[0]; - naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field - naive_a.nhead_q = nhead; - naive_a.nhead_kv = nhead_k; - naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv - naive_a.page_size = 0; // if paged, the seqlen-kv for each block - - ck_tile::stream_config naive_s{}; - - naive_attention_fwd(naive_t, naive_a, naive_s); - - auto o_naive_ref = o_naive_buf.ToHost(); - o_buf.FromDevice(o_host.data()); // TODO: ugly - - auto [rtol_, atol_] = get_elimit(init_method); - pass = ck_tile::check_err( - o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_); - std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; - } - else - { -#if CK_TILE_FMHA_FWD_APPENDKV_API - // When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times - // when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. - if(0 < rotary_dim && stream_config.time_kernel_) - { - const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0}; - q_buf.ToDevice(q_host.data()); - run_appendkv(stream_config2); - run_fwd(stream_config2); - } -#endif - o_buf.FromDevice(o_host.data()); - lse_buf.FromDevice(lse_host.data()); - randval_buf.FromDevice(randval_host.data()); - - constexpr bool supports_squant = std::is_same_v || - std::is_same_v || - std::is_same_v; - - auto p_compute_element_func = [&]() { - if constexpr(supports_squant) - return ck_tile::scales{scale_p}; - else - return ck_tile::identity{}; - }(); - - auto oacc_element_func = [&]() { - if constexpr(std::is_same_v && supports_squant) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o}); - else if constexpr(supports_squant) - return ck_tile::scales{scale_o}; - else - return ck_tile::identity{}; - }(); - - float p_undrop = 1.0 - p_drop; - uint8_t p_undrop_in_uint8_t = - uint8_t(std::floor(p_undrop * std::numeric_limits::max())); - float rp_undrop = 1.0 / p_undrop; - - for(ck_tile::index_t wb = 0; wb < batch; ++wb) - { - ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - if(mode == mode_enum::batch) - { - if(!cuq_cum.empty()) - { - real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; - } - if(!cukv_cum.empty()) - { - real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; - } - } - - // adjust matrix index according to the mode - const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); - const ck_tile::index_t cache_b_idx = - (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); - const ck_tile::index_t query_offset = - (mode == mode_enum::batch - ? 0 - : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] - : seqstart_q_with_padding_host[wb])); - const ck_tile::index_t key_offset = - (mode == mode_enum::batch - ? 0 - : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] - : seqstart_k_with_padding_host[wb])); - - ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); - ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); - ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); - ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); - - ck_tile::HostTensor s_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); - ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); - ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); - - ck_tile::index_t nr = nhead / nhead_k; - - // clang-format off - // permute - if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); - else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); - // clang-format on - -#if CK_TILE_FMHA_FWD_APPENDKV_API - // optionally apply RoPE to the q_host_ref - if(0 < rotary_dim) - { - decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); - - auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin( - rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); - - ck_tile::reference_batched_rotary_position_embedding( - q_host_ref, - rotary_cos_slice, - rotary_sin_slice, - is_rotary_interleaved, - q_host_ref_ro, - /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask); - - q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); - } -#endif -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API - if(0 < page_block_size) - { - // clang-format off - if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); }); - else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); }); - // clang-format on - } - else -#endif - { - // clang-format off - if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); - else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); - // clang-format on - } - -#if CK_TILE_FMHA_FWD_APPENDKV_API - // copy Knew to the end of K - if(0 < seqlen_knew) - { - ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); - // clang-format off - if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); }); - else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); }); - // clang-format on - - // optionally apply RoPE to the knew_host_ref - auto* real_knew_host_ref = &knew_host_ref; - std::optional knew_host_ref_ro; - if(0 < rotary_dim) - { - knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); - - auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin( - rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); - - ck_tile::reference_batched_rotary_position_embedding(knew_host_ref, - rotary_cos_slice, - rotary_sin_slice, - is_rotary_interleaved, - knew_host_ref_ro.value()); - - real_knew_host_ref = &knew_host_ref_ro.value(); - } - - (*real_knew_host_ref).ForEach([&](auto& self, auto i) { - k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); - }); - } -#endif -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API - if(0 < page_block_size) - { - if(is_v_rowmajor) - { - // clang-format off - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); }); - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); }); - // clang-format on - } - else - { - // clang-format off - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); }); - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); }); - // clang-format on - } - } - else -#endif - { - if(is_v_rowmajor) - { - // clang-format off - // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); - // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); - // clang-format on - } - else - { - // clang-format off - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); - // clang-format on - } - } - -#if CK_TILE_FMHA_FWD_APPENDKV_API - // copy Vnew to the end of V - if(0 < seqlen_knew) - { - ck_tile::HostTensor vnew_host_ref({nhead, hdim_v, seqlen_knew}); - if(is_v_rowmajor) - { - // clang-format off - if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); }); - else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); }); - // clang-format on - } - else - { - // clang-format off - if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); }); - else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); - // clang-format on - } - - vnew_host_ref.ForEach([&](auto& self, auto i) { - v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); - }); - } -#endif - - // reference - ck_tile:: - reference_batched_gemm( - q_host_ref, - k_host_ref, - s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale_s)); - - if(0.f < logits_soft_cap) - { - ck_tile::reference_unary_elementwise( - s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) { - return ck_tile::type_convert( - logits_soft_cap * - std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); - }); - } - - if(bias.type == bias_enum::elementwise_bias) - { - // elementwise bias - ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); - // clang-format off - if(i_perm) bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); - else bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); - // clang-format on - - // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, - // real_seqlen_k] - ck_tile::reference_batched_elementwise( - s_host_ref, bias_host_ref, s_host_ref); - } - else if(bias.type == bias_enum::alibi) - { - // alibi construct elementwise bias to verify - auto alibi_host = [&]() { - if(mask.type != mask_enum::no_mask) - { - return ck_tile::make_alibi_from_lr_mask( - 0, - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - static_cast(mask.type)); - } - else - { - return ck_tile::Alibi{ - 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; - } - }(); - - ck_tile::HostTensor alibi_bias_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); - auto i_b_slope = bias.rank_info == 0 ? 0 : wb; - for(auto i_h = 0; i_h < nhead; i_h++) - { - SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); - alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL - ? current_slope - : -current_slope; - for(auto i_r = 0; i_r < real_seqlen_q; i_r++) - { - for(auto i_c = 0; i_c < real_seqlen_k; i_c++) - { - SaccDataType pixel = 0; - alibi_host.update(pixel, i_r, i_c); - alibi_bias_host_ref(i_h, i_r, i_c) = pixel; - } - } - } - // [nhead, real_seqlen_q, real_seqlen_k] - ck_tile::reference_batched_elementwise( - s_host_ref, alibi_bias_host_ref, s_host_ref); - } - - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking( - s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, real_seqlen_q, real_seqlen_k)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - mask.type == mask_enum::mask_top_left)); - else - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - mask.type == mask_enum::mask_top_left)); - } - const ck_tile::HostTensor masked_s_host_ref = s_host_ref; - if(lse) - { - ck_tile:: - reference_batched_softmax( - s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); - } - else - { - ck_tile:: - reference_batched_softmax( - s_host_ref, p_host_ref, p_compute_element_func); - } - - if(p_drop > 0) - { - ck_tile::HostTensor randval_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); - ck_tile::reference_batched_dropout_randval( - randval_host_ref, wb, drop_seed, drop_offset); - ck_tile::reference_batched_dropout( - p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); - - ck_tile::HostTensor randval_host_result( - {nhead, real_seqlen_q, real_seqlen_k}); - randval_host_result.ForEach([&](auto& self, const auto& idx) { - self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); - }); - masked_s_host_ref.ForEach([&](const auto& self, const auto& idx) { - // Ignore all masked values in validation check - if(std::isinf(self(idx))) - { - randval_host_ref(idx) = 0; - randval_host_result(idx) = 0; - } - }); - bool cur_pass = ck_tile::check_err(randval_host_result, - randval_host_ref, - "DROPOUT RANDVAL Error: Incorrect results!"); - pass &= cur_pass; - if(!cur_pass) - { - break; - } - } - - ck_tile::reference_batched_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - oacc_element_func); - - ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); - // clang-format off - // permute - if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); - else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); - // clang-format on - - auto [rtol, atol] = get_elimit(init_method); - bool cur_pass = ck_tile::check_err(o_host_result, - o_host_ref, - std::string("OUT Error: Incorrect results!"), - rtol, - atol); - pass &= cur_pass; - if(!cur_pass) - { - std::cerr << "OUT mismatch found at batch: " << wb << std::endl - << "\tseqlen_q: " << real_seqlen_q << std::endl - << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; - - break; - } - - if(lse) - { - ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - const ck_tile::index_t query_offset_lse = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); - }); - - cur_pass = ck_tile::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); - - pass &= cur_pass; - if(!cur_pass) - { - std::cerr << "LSE mismatch found at batch: " << wb << std::endl - << "\tseqlen_q: " << real_seqlen_q << std::endl - << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; - - break; - } - } - } - - std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; - } - - if(json) - { - dump_fmha_fwd_json_results(*json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - squant, - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); - } - - return pass ? fwd_result::success : fwd_result::failure; -} diff --git a/example/ck_tile/01_unified_attention/utils.hpp b/example/ck_tile/01_unified_attention/utils.hpp deleted file mode 100644 index 7f44d871804..00000000000 --- a/example/ck_tile/01_unified_attention/utils.hpp +++ /dev/null @@ -1,244 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core/container/span.hpp" - -enum class mode_enum -{ - batch = 0, - group -}; - -std::ostream& operator<<(std::ostream& stream, mode_enum mode) -{ - return stream << (mode == mode_enum::batch ? "batch" : "group"); -} - -template -std::ostream& operator<<(std::ostream& os, const std::vector& v) -{ - using size_type = typename std::vector::size_type; - - os << "["; - for(size_type idx = 0; idx < v.size(); ++idx) - { - if(0 < idx) - { - os << ", "; - } - os << v[idx]; - } - return os << "]"; -} - -std::vector to_seqstarts(ck_tile::span seqlens) -{ - std::vector seqstarts = {0}; - for(int32_t seqlen : seqlens) - { - seqstarts.push_back(seqstarts.back() + seqlen); - } - assert(seqstarts.size() == seqlens.size() + 1); - return seqstarts; -} - -template -std::vector generate_seqlens(mode_enum mode, - unsigned count, - int32_t seqlen_avg, - int32_t seqlen_min, // if not negative, clamp min - int32_t seqlen_max, // if not negative, clamp max - RandomEngine& random_engine) -{ - assert(0 < count); - - seqlen_min = (0 < seqlen_min ? seqlen_min : 1); - seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits::max()); - assert(seqlen_min <= seqlen_max); - - std::vector seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max)); - - if(mode == mode_enum::group && 1 < count) - { - using size_type = std::vector::size_type; - - std::uniform_int_distribution idx_dist(0, count - 1); - auto next_idx = std::bind(idx_dist, std::ref(random_engine)); - - std::uniform_int_distribution step_dist(1, count - 1); - auto next_step = std::bind(step_dist, std::ref(random_engine)); - - for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) - { - const size_type to_decrease = next_idx(); - // make sure each elements of seqlens is in range [seqlen_min, seqlen_max] - if(seqlens[to_decrease] == seqlen_min) - { - continue; - } - - const size_type to_increase = (to_decrease + next_step()) % count; - - if(seqlens[to_increase] >= seqlen_max) - { - continue; - } - - --seqlens[to_decrease]; - ++seqlens[to_increase]; - } - } - - return seqlens; -} - -// return random integer generated uniformly in range [low, high] -template -auto randint(Int low, - Int high, - RandomEngine& random_engine) -> std::enable_if_t, Int> -{ - std::uniform_int_distribution dist(low, high); - return dist(random_engine); -} - -// return random integers generated uniformly in range [low, high] -template -auto randints(ForwardIterator first, - ForwardIterator last, - Int low, - Int high, - RandomEngine& random_engine) -> std::enable_if_t> -{ - std::uniform_int_distribution dist(low, high); - - std::generate(first, last, [&] { return dist(random_engine); }); -} - -/* - * generate missing values in *_val randomly when the number of values is smaller than batch - * example (assume batch=3) - * q_val=1,2,3 k_val=4,5,6 -> OK - * q_val=1,2,3 -> OK, k same as q - * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q - * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element - * q_val=1,2,3,4 -> OK, but ignore exceed one - * - * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q - * q_val=1,2 k_val=4 -> not OK, k must have same splits with q - */ -template -std::tuple, - std::vector, - std::vector> -generate_missing_seqlens(mode_enum mode, - ck_tile::index_t batch, - const std::vector& q_val, - const std::vector& k_val, - const std::vector& k_pad_val, - ck_tile::index_t seqlen_k_min, - bool need_append_kvcache, - RandomEngine& random_engine) -{ - if(mode == mode_enum::batch) - { - ck_tile::index_t q = q_val[0]; - ck_tile::index_t k = k_val[0]; - - auto s_q = std::vector(batch, q); - auto s_k = [&] { - const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); - std::vector seqlen_ks(batch, seqlen_k_max); - - if(1 < batch && need_append_kvcache) - { - // to keep the original s_k value, we always use seqlen_k_max in first batch - randints(std::next(seqlen_ks.begin()), - seqlen_ks.end(), - seqlen_k_min, - seqlen_k_max, - random_engine); - return seqlen_ks; - } - - return seqlen_ks; - }(); - auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding - - // s_k should be greater than or equal to seqlen_k_min if provided - if(s_k.back() < seqlen_k_min) - { - std::ostringstream msg; - msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() - << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; - throw std::runtime_error(msg.str()); - } - - return std::make_tuple(s_q, s_k, s_kpad); - } - else - { - std::vector s_q; - std::vector s_k; - std::vector s_kpad; - ck_tile::index_t idx = 0; - for(; idx < std::min(static_cast(q_val.size()), batch); ++idx) - { - ck_tile::index_t q = q_val[idx]; - ck_tile::index_t k = - k_val[std::min(idx, static_cast(k_val.size()) - 1)]; - ck_tile::index_t kp = - k_pad_val.empty() - ? -1 - : k_pad_val[std::min(idx, static_cast(k_pad_val.size()) - 1)]; - - s_q.push_back(q); - s_k.push_back(k < 0 ? q : k); - s_kpad.push_back(kp); - - // s_k should be greater than or equal to seqlen_k_min - if(s_k.back() < seqlen_k_min) - { - std::ostringstream msg; - msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() - << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; - throw std::runtime_error(msg.str()); - } - } - if(idx < batch) - { - auto rem_q = - generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine); - auto rem_k = generate_seqlens( - mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine); - - s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); - s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); - s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); - } - return std::make_tuple(s_q, s_k, s_kpad); - } -} - -template -std::enable_if_t> iota_shuffle(RandomAccessIterator first, - RandomAccessIterator last, - Int value, - RandomEngine& random_engine) -{ - std::iota(first, last, value); - std::shuffle(first, last, random_engine); -} diff --git a/example/ck_tile/01_unified_attention/CMakeLists.txt b/example/ck_tile/42_unified_attention/CMakeLists.txt similarity index 100% rename from example/ck_tile/01_unified_attention/CMakeLists.txt rename to example/ck_tile/42_unified_attention/CMakeLists.txt diff --git a/example/ck_tile/01_unified_attention/README.md b/example/ck_tile/42_unified_attention/README.md similarity index 85% rename from example/ck_tile/01_unified_attention/README.md rename to example/ck_tile/42_unified_attention/README.md index 2b872cb9b51..d58eff09d02 100644 --- a/example/ck_tile/01_unified_attention/README.md +++ b/example/ck_tile/42_unified_attention/README.md @@ -1,6 +1,6 @@ # fused multi-head attention -This folder contains example for fmha(fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. +This folder contains example for unified attention (fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. ## build ``` @@ -8,12 +8,12 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank ../script/cmake-ck-dev.sh ../ -make tile_example_fmha_fwd -j +make tile_example_unified_attention -j ``` -This will result in an executable `build/bin/tile_example_fmha_fwd` +This will result in an executable `build/bin/tile_example_unified_attention` ## kernel -The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. +The kernel template is `unified_attention.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. There are 2 template parameters for this kernel template. * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). @@ -23,7 +23,7 @@ There are 2 template parameters for this kernel template. To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. ## executable -`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change) +`tile_example_unified_attention` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_unified_attention -?` to list all the arguments. Below is an example of the output (may subject to change) ``` args: -v weather do CPU validation or not (default:1) @@ -88,14 +88,14 @@ args: -kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"") Comma-separated list of length 'b'. If empty, no override ``` -Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. -Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with +Example 1: `./bin/tile_example_unified_attention -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. +Example 2: `./bin/tile_example_unified_attention -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case ## Padding Examples -Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively. +Example 3 (Group mode with padding): `./bin/tile_example_unified_attention -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively. -Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively. +Example 4 (Batch mode with effective lengths): `./bin/tile_example_unified_attention -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively. ## support features Currently we are still in rapid development stage, so more features/optimizations will be coming soon. @@ -154,6 +154,6 @@ We support sequence padding and variable-length processing in both batch and gro Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. ## FP8 experimental support -As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. +As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_unified_attention`, on a gfx942 machine and ROCm 6.0+. Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/42_unified_attention/example_unified_attention.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/example_unified_attention.cpp rename to example/ck_tile/42_unified_attention/example_unified_attention.cpp diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_mask.cpp rename to example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_mask.cpp diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp rename to example/ck_tile/42_unified_attention/instances/unified_attention_d128_bf16_nmask.cpp diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_mask.cpp rename to example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_mask.cpp diff --git a/example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp b/example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp rename to example/ck_tile/42_unified_attention/instances/unified_attention_d128_fp16_nmask.cpp diff --git a/example/ck_tile/01_unified_attention/mask.hpp b/example/ck_tile/42_unified_attention/mask.hpp similarity index 100% rename from example/ck_tile/01_unified_attention/mask.hpp rename to example/ck_tile/42_unified_attention/mask.hpp diff --git a/example/ck_tile/01_unified_attention/misc/gamc.png b/example/ck_tile/42_unified_attention/misc/gamc.png similarity index 100% rename from example/ck_tile/01_unified_attention/misc/gamc.png rename to example/ck_tile/42_unified_attention/misc/gamc.png diff --git a/example/ck_tile/01_unified_attention/rotary.hpp b/example/ck_tile/42_unified_attention/rotary.hpp similarity index 100% rename from example/ck_tile/01_unified_attention/rotary.hpp rename to example/ck_tile/42_unified_attention/rotary.hpp diff --git a/example/ck_tile/01_unified_attention/script/benchmark_fwd.sh b/example/ck_tile/42_unified_attention/script/benchmark_fwd.sh similarity index 96% rename from example/ck_tile/01_unified_attention/script/benchmark_fwd.sh rename to example/ck_tile/42_unified_attention/script/benchmark_fwd.sh index 31ad8000392..3a3b9389002 100755 --- a/example/ck_tile/01_unified_attention/script/benchmark_fwd.sh +++ b/example/ck_tile/42_unified_attention/script/benchmark_fwd.sh @@ -1,6 +1,6 @@ #!/bin/sh # TODO: run this script from CK root or build directory -EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +EXE="$(find . -name tile_example_unified_attention -type f | head -n 1)" VALID=0 for prec in "fp16" "bf16" ; do diff --git a/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt similarity index 100% rename from example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt rename to example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx90a.txt diff --git a/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx942.txt similarity index 100% rename from example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx942.txt rename to example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx942.txt diff --git a/example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx950.txt similarity index 100% rename from example/ck_tile/01_unified_attention/script/fmha_bwd_known_fails_gfx950.txt rename to example/ck_tile/42_unified_attention/script/fmha_bwd_known_fails_gfx950.txt diff --git a/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt b/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt similarity index 100% rename from example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt rename to example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx90a.txt diff --git a/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx942.txt b/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx942.txt similarity index 100% rename from example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx942.txt rename to example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx942.txt diff --git a/example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx950.txt similarity index 100% rename from example/ck_tile/01_unified_attention/script/fmha_fwd_known_fails_gfx950.txt rename to example/ck_tile/42_unified_attention/script/fmha_fwd_known_fails_gfx950.txt diff --git a/example/ck_tile/01_unified_attention/script/run_full_test.sh b/example/ck_tile/42_unified_attention/script/run_full_test.sh similarity index 100% rename from example/ck_tile/01_unified_attention/script/run_full_test.sh rename to example/ck_tile/42_unified_attention/script/run_full_test.sh diff --git a/example/ck_tile/01_unified_attention/script/smoke_test_bwd.sh b/example/ck_tile/42_unified_attention/script/smoke_test_bwd.sh similarity index 100% rename from example/ck_tile/01_unified_attention/script/smoke_test_bwd.sh rename to example/ck_tile/42_unified_attention/script/smoke_test_bwd.sh diff --git a/example/ck_tile/01_unified_attention/script/smoke_test_fwd.sh b/example/ck_tile/42_unified_attention/script/smoke_test_fwd.sh similarity index 100% rename from example/ck_tile/01_unified_attention/script/smoke_test_fwd.sh rename to example/ck_tile/42_unified_attention/script/smoke_test_fwd.sh diff --git a/example/ck_tile/01_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp similarity index 100% rename from example/ck_tile/01_unified_attention/unified_attention.cpp rename to example/ck_tile/42_unified_attention/unified_attention.cpp diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp similarity index 100% rename from example/ck_tile/01_unified_attention/unified_attention.hpp rename to example/ck_tile/42_unified_attention/unified_attention.hpp diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp similarity index 100% rename from example/ck_tile/01_unified_attention/unified_attention_impl.hpp rename to example/ck_tile/42_unified_attention/unified_attention_impl.hpp diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 1cc44d3bc67..ca3fe67867e 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -4,7 +4,6 @@ include_directories(AFTER ${CMAKE_CURRENT_LIST_DIR} ) -add_subdirectory(01_unified_attention) add_subdirectory(01_fmha) add_subdirectory(02_layernorm2d) add_subdirectory(03_gemm) @@ -30,4 +29,4 @@ add_subdirectory(36_pooling) add_subdirectory(38_block_scale_gemm) add_subdirectory(40_streamk_gemm) add_subdirectory(41_batched_contraction) - +add_subdirectory(42_unified_attention) diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index f70819d9282..74693460ec1 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" #define ENABLE_ASM_MARKER 1 #if ENABLE_ASM_MARKER @@ -34,217 +35,6 @@ namespace ck_tile { -template -struct CoreLoopScheduler; - -template -struct CoreLoopScheduler -{ - template - CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, - ck_tile::number) - { - using namespace ck_tile; - - if constexpr(WaveGroup == 0) - { - if constexpr(Phase == 0) - { - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - }); - } - else if constexpr(Phase == 1) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 2) - { -#if !CK_TILE_DISABLE_PACKED_FP32 - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU -#endif - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU - }); - } - else if constexpr(Phase == 3) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - } - else - { - if constexpr(Phase == 0) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 1) - { - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - }); - } - else if constexpr(Phase == 2) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 3) - { -#if !CK_TILE_DISABLE_PACKED_FP32 - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU -#endif - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU - }); - } - } - } -}; - -template -struct CoreLoopScheduler -{ - template - CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, - ck_tile::number) - { - using namespace ck_tile; - - if constexpr(WaveGroup == 0) - { - if constexpr(Phase == 0) - { - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - }); - } - else if constexpr(Phase == 1) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 2) - { -#if !CK_TILE_DISABLE_PACKED_FP32 - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU -#endif - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU - }); - } - else if constexpr(Phase == 3) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - } - else - { - if constexpr(Phase == 0) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 1) - { - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - }); - } - else if constexpr(Phase == 2) - { - __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU - __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU - } - else if constexpr(Phase == 3) - { -#if !CK_TILE_DISABLE_PACKED_FP32 - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU -#endif - static_for<0, 8, 1>{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU - }); - } - } - } -}; - -namespace detail { -CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c) -{ -#if CK_TILE_DISABLE_PACKED_FP32 - return a * b + c; -#else - float result; - asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]" - : [result] "=v"(result) - : [a] "v"(a), [b] "s"(b), [c] "v"(c)); - return result; -#endif -} - -CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) -{ - float result; - asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]" - : [result] "=v"(result) - : [lhs] "v"(lhs), [rhs] "v"(rhs)); - return result; -} - -CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs) -{ - float result; - asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]" - : [result] "=v"(result) - : [lhs] "v"(lhs), [rhs] "v"(rhs)); - return result; -} - -CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b) -{ - fp16x2_t result; - asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]" - : [result] "=v"(result) - : [a] "v"(a), [b] "v"(b)); - return result; -} - -CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b) -{ - bf16x2_t result; - asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]" - : [result] "=v"(result) - : [a] "v"(a), [b] "v"(b)); - return result; -} - -CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) -{ - fp32x2_t result; - asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]" - : [result] "=v"(result) - : [lhs] "v"(lhs), [rhs] "v"(rhs)); - return result; -} -} // namespace detail - template struct UnifiedAttentionPipeline { @@ -377,23 +167,24 @@ struct UnifiedAttentionPipeline typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction> - CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - [[maybe_unused]] const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - [[maybe_unused]] const VElementFunction& v_element_func, - const index_t num_blocks, - const index_t num_blocks_start, - const void* block_tables_ptr, - index_t block_table_offset, - const index_t kv_page_size_in_blocks, - [[maybe_unused]] const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, - FmhaMask mask, - float scale_s, - void* smem_ptr) const + CK_TILE_DEVICE auto operator()( + const QDramBlockWindowTmp& q_dram_block_window_tmp, // kBlockM * kHeadDimPadded tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // kPageBlockSize * kHeadDimPadded tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // kHeadDimPadded * kPageBlockSize tile + [[maybe_unused]] const VElementFunction& v_element_func, + const index_t num_blocks, + const index_t num_blocks_start, + const void* block_tables_ptr, + index_t block_table_offset, + const index_t kv_page_size_in_blocks, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + void* smem_ptr) const { using namespace ck_tile; static_assert( @@ -1224,17 +1015,18 @@ struct UnifiedAttentionPipeline template - CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const index_t num_blocks, - const index_t num_blocks_start, - const void* block_tables_ptr, - index_t block_table_offset, - const index_t kv_page_size_in_blocks, - FmhaMask mask, - float scale_s, - void* smem_ptr) const + CK_TILE_DEVICE auto operator()( + const QDramBlockWindowTmp& q_dram_block_window_tmp, // kBlockM * kHeadDimPadded tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // kPageBlockSize * kHeadDimPadded tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // kHeadDimPadded * kPageBlockSize tile + const index_t num_blocks, + const index_t num_blocks_start, + const void* block_tables_ptr, + index_t block_table_offset, + const index_t kv_page_size_in_blocks, + FmhaMask mask, + float scale_s, + void* smem_ptr) const { using namespace ck_tile; From 6a62216c2485ea79b4ad470afa8965b99c2849a1 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 2 Jan 2026 14:20:56 +0200 Subject: [PATCH 84/88] Update example/ck_tile/42_unified_attention/README.md Co-authored-by: spolifroni-amd --- example/ck_tile/42_unified_attention/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/42_unified_attention/README.md b/example/ck_tile/42_unified_attention/README.md index d58eff09d02..03674f4d8e6 100644 --- a/example/ck_tile/42_unified_attention/README.md +++ b/example/ck_tile/42_unified_attention/README.md @@ -1,6 +1,6 @@ # fused multi-head attention -This folder contains example for unified attention (fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. +This folder contains examples for unified attention (fused multi-head attention) using the ck_tile tile-programming implementation. The examples demonstrate the usage of the tile-programming API, as well as the new approach to constructing kernel templates and instantiating them. ## build ``` From 788890ac837c77295702c0362db23012dc06bf92 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 2 Jan 2026 14:23:15 +0200 Subject: [PATCH 85/88] Update example/ck_tile/42_unified_attention/README.md Co-authored-by: spolifroni-amd --- example/ck_tile/42_unified_attention/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/42_unified_attention/README.md b/example/ck_tile/42_unified_attention/README.md index 03674f4d8e6..27cfdea23b4 100644 --- a/example/ck_tile/42_unified_attention/README.md +++ b/example/ck_tile/42_unified_attention/README.md @@ -17,7 +17,7 @@ The kernel template is `unified_attention.hpp`, this is the grid-wise op in old There are 2 template parameters for this kernel template. * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). -* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. +* `EpiloguePipeline` is the last stage of the pipeline. It modifies and stores the result. Post-fusion can be done at this stage though the example only returns the result. ## codegen To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. From efb7ada2a66099c384dd6080041be7ddfbe84ed6 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 2 Jan 2026 14:28:16 +0200 Subject: [PATCH 86/88] Update example/ck_tile/42_unified_attention/README.md --- example/ck_tile/42_unified_attention/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/42_unified_attention/README.md b/example/ck_tile/42_unified_attention/README.md index 27cfdea23b4..438ae1f9cea 100644 --- a/example/ck_tile/42_unified_attention/README.md +++ b/example/ck_tile/42_unified_attention/README.md @@ -3,7 +3,7 @@ This folder contains examples for unified attention (fused multi-head attention) using the ck_tile tile-programming implementation. The examples demonstrate the usage of the tile-programming API, as well as the new approach to constructing kernel templates and instantiating them. ## build -``` + # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank From a1973b472f50cce4483b97bea9d7dce2cbf70a21 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 2 Jan 2026 14:29:05 +0200 Subject: [PATCH 87/88] Update example/ck_tile/42_unified_attention/README.md --- example/ck_tile/42_unified_attention/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/example/ck_tile/42_unified_attention/README.md b/example/ck_tile/42_unified_attention/README.md index 438ae1f9cea..aa8a9968898 100644 --- a/example/ck_tile/42_unified_attention/README.md +++ b/example/ck_tile/42_unified_attention/README.md @@ -16,6 +16,7 @@ This will result in an executable `build/bin/tile_example_unified_attention` The kernel template is `unified_attention.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. There are 2 template parameters for this kernel template. + * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). * `EpiloguePipeline` is the last stage of the pipeline. It modifies and stores the result. Post-fusion can be done at this stage though the example only returns the result. From 74297016ea557763aa4f6cb07737f0d644fa6088 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 2 Jan 2026 14:29:33 +0200 Subject: [PATCH 88/88] Update example/ck_tile/42_unified_attention/README.md --- example/ck_tile/42_unified_attention/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/example/ck_tile/42_unified_attention/README.md b/example/ck_tile/42_unified_attention/README.md index aa8a9968898..bcc10f7e657 100644 --- a/example/ck_tile/42_unified_attention/README.md +++ b/example/ck_tile/42_unified_attention/README.md @@ -13,6 +13,7 @@ make tile_example_unified_attention -j This will result in an executable `build/bin/tile_example_unified_attention` ## kernel + The kernel template is `unified_attention.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. There are 2 template parameters for this kernel template.