From 53eab181e5348754c33057bd33182c9bf425fbc4 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Wed, 1 Nov 2023 07:26:47 -0400 Subject: [PATCH 1/2] unify q persistent in register --- example/91_tile_program/fmha_fwd.cpp | 16 +- example/91_tile_program/fmha_fwd_kernel.hpp | 16 +- .../block_gemm_areg_bgmem_creg_v1.hpp | 14 +- .../block_gemm_areg_bsmem_creg_v1.hpp | 4 +- ..._gemm_areg_bsmem_creg_v1_custom_policy.hpp | 49 +++ .../block_gemm_asmem_bsmem_creg_v1.hpp | 4 +- ...gemm_asmem_bsmem_creg_v1_custom_policy.hpp | 58 +++ .../block_fmha_pipeline_qkvs.hpp | 3 +- .../block_fmha_pipeline_qr_ks_vs.hpp | 347 ++++++++++++++++++ ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 280 ++++++++++++++ .../ck/tile_program/tile/tile_fmha_shape.hpp | 30 +- .../ck/tile_program/warp_tile/warp_gemm.hpp | 37 ++ .../warp_tile/warp_gemm_attribute_mfma.hpp | 84 +++++ 13 files changed, 911 insertions(+), 31 deletions(-) create mode 100644 include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp create mode 100644 include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp create mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp create mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index 96aa97f99..3345ad01b 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -1,4 +1,5 @@ #include +#include #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -15,6 +16,8 @@ #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" #include "ck/tile_program/tile/tile_fmha_shape.hpp" @@ -33,10 +36,14 @@ using PDataType = ck::half_t; // data type for A matrix of second gemm using OaccDataType = float; // data type for second gemm accumulation using ODataType = ck::half_t; -// M0 N0 K0 N1 K1 +// M0 N0 K0 N1 K1 K0L // using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>; // using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>; -using FmhaShape = ck::tile_program::TileFmhaShape<128, 128, 32, 128, 32>; +using FmhaBlockTile = ck::Sequence<128, 128, 32, 128, 32, 128>; +using FmhaBlockWarps = ck::Sequence<4, 1, 1>; +using FmhaWarpTile = ck::Sequence<32, 32, 16>; +using FmhaShape = ck::tile_program:: + TileFmhaShape; using FmhaTilePartitioner = FmhaFwdTilePartitioner; using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem; -using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS; +// using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS; +using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVS; using FmhaEpilogue = FmhaFwdEpilogue>; using FmhaKernel = FmhaFwdKernel; @@ -134,7 +142,7 @@ int main(int argc, char* argv[]) << ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v << ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm << ", grid_size " << kGridSize.x << "x" << kGridSize.y << "x" << kGridSize.z - << std::endl; + << std::flush << std::endl; constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize; diff --git a/example/91_tile_program/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha_fwd_kernel.hpp index 58282d4b2..018248d9b 100644 --- a/example/91_tile_program/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha_fwd_kernel.hpp @@ -11,7 +11,7 @@ // P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] -#define C_LOG2E 1.44269504088896340736 // log2(e) +#define C_LOG2E 1.44269504088896340736 // log2(e) template struct FmhaFwdKernel @@ -148,10 +148,16 @@ struct FmhaFwdKernel Number<32>{}, Number<1>{}); - auto q_dram_window = - make_tile_window(q_dram, - make_tuple(Number{}, Number{}), - {i_m0, 0}); + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(Number{}, + Number{}); + else + return make_tuple(Number{}, Number{}); + }(), + {i_m0, 0}); auto k_dram_window = make_tile_window( k_dram, make_tuple(Number{}, Number{}), {0, 0}); diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp index 67ec722dc..926a80fae 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp @@ -26,9 +26,11 @@ namespace block { // This will: // 1. Load B from global memory into shared memory and then // 2. Call BlockGemmARegSGmemCRegV1 -template +template struct BlockGemmARegBGmemCRegV1 { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -37,13 +39,9 @@ struct BlockGemmARegBGmemCRegV1 static constexpr index_t kBlockSize = Problem::kBlockSize; // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation - using BlockGemmARegBSmemCRegImpl = - BlockGemmARegBSmemCRegV1, - BlockGemmARegBSmemCRegV1DefaultPolicy>; + using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< + BlockGemmARegBSmemCRegProblem, + BlockGemmARegBSmemCRegV1DefaultPolicy>; __host__ __device__ static constexpr ck::index_t GetStaticLdsSize() { diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp index 5778c2111..548906098 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp @@ -23,9 +23,11 @@ namespace block { // A is block distributed tensor // B is block window on shared memory // C is block distributed tensor -template +template struct BlockGemmARegBSmemCRegV1 { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 000000000..1f8dce998 --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +template +struct BlockGemmARegBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::At(Number<0>{}); + static constexpr index_t kNWarps = BlockWarps::At(Number<1>{}); + static constexpr index_t kKWarps = BlockWarps::At(Number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() + { + using namespace ck::tile_program::warp; + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp index 166f887db..ff3c44db7 100644 --- a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp @@ -24,9 +24,11 @@ namespace block { // A is block window on shared memory // B is block window on shared memory // C is block distributed tensor -template +template struct BlockGemmASmemBSmemCRegV1 { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; diff --git a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 000000000..42eec8e5c --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// Default policy for BlockGemmASmemBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +template +struct BlockGemmASmemBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + static constexpr index_t BlockMWarps = BlockWarps::At(Number<0>{}); + static constexpr index_t BlockNWarps = BlockWarps::At(Number<1>{}); + static constexpr index_t BlockKWarps = BlockWarps::At(Number<2>{}); + + static constexpr index_t MPerWarp = WarpTile::At(Number<0>{}); + static constexpr index_t NPerWarp = WarpTile::At(Number<1>{}); + static constexpr index_t KPerWarp = WarpTile::At(Number<2>{}); + + static constexpr bool TranposeC = TranposeC_; + + using WarpGemm = WarpGemmMfma; + + template + __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() + { + using namespace ck::tile_program::warp; + return make_tuple(WarpGemm{}, BlockMWarps, BlockNWarps); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp index ca54eea74..934c5c90f 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp @@ -35,7 +35,8 @@ struct BlockFmhaPipelineQKVS using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + static constexpr bool kQLoadOnce = false; // if q load whole block length (hdim) at once static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp new file mode 100644 index 000000000..0771af019 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -0,0 +1,347 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/tile/slice_tile.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck/tile_program/block_tile/block_reduce.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVS +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q load whole block length (hdim) at once + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + float scale, + index_t num_total_loop, + index_t /*num_sub_loop_qk*/, // in this pipeline, the 1st gemm loop must be static + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = + make_tile_window(v_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.GetBottomTensorView(), + q_dram_block_window_tmp.GetWindowLengths(), + q_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); // persistent q register tile + + auto s_acc = decltype(gemm_0(get_slice_tile(tile_elementwise_in(q_element_func, q), + Sequence<0, 0>{}, + Sequence{}), + k_lds_window)){}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = + decltype(tile_elementwise_in(type_convert, s_acc)); + + using PBlockTileType = + decltype(tile_elementwise_in(type_convert, s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1( + get_slice_tile(PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), + v_lds_window)); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); + tile_elementwise_inout([](auto& e) { e = NumericLimits::Lowest(); }, + m); + tile_elementwise_inout([](auto& e) { e = 0; }, l); + + auto k_dram_block_window = k_dram_block_window_tmp; + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + v_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeVDramTileDistribution()); + + auto q_tile = tile_elementwise_in(q_element_func, q); + index_t i_total_loops = 0; + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + + tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); // Initialize C + + store_tile(k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write 0 + k_block_tile = load_tile(k_dram_window); // global read 1 + } + + // index_t i_k0_loops = num_sub_loop_qk - 2; + constexpr index_t k0_loops = kK0BlockLength / kK0; + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, i_k0 * kK0>{}, + Sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 2) * kK0>{}, + Sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 1) * kK0>{}, + Sequence{}), + k_lds_window); + } + + // STAGE 2, scale softmax + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + + const auto s = + tile_elementwise_in(type_convert, s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + Sequence<1>{}, + f_max, + NumericLimits::Lowest()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.GetTileDistribution()); // Pcompute{j} + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // STAGE 3, KV gemm + constexpr index_t k1_loops = kN0 / kK1; + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, Sequence<0, i_k1 * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + + i_total_loops++; + } while(i_total_loops < num_total_loop); + + // finally, O + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = 1 / l[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + float scale, + index_t num_total_loop, + index_t num_sub_loop_qk, + void* smem_ptr) const + { + return operator()( + q_dram_block_window_tmp, + [](const QDataType& x) { return x; }, + k_dram_block_window_tmp, + [](const KDataType& x) { return x; }, + v_dram_block_window_tmp, + [](const VDataType& x) { return x; }, + scale, + num_total_loop, + num_sub_loop_qk, + smem_ptr); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp new file mode 100644 index 000000000..84b475587 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +struct BlockFmhaPipelineQRKSVSDefaultPolicy +{ + template + __host__ __device__ static constexpr auto MakeQRegBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto q_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + + // 3d + padding + template + __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number<8>{}), + make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return k_lds_block_desc; + } + + // 3d + padding + template + __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kPad = 1; + constexpr index_t kK1 = 8; + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number{}, Number<1>{}), + Number{}, + Number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return v_lds_block_desc; + } + + template + __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() + { + return 0; + } + + template + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + constexpr index_t smem_size_gemm_0 = + GetSmemSizeQ() + sizeof(typename Problem::KDataType) * + MakeKLdsBlockDescriptor().GetElementSpaceSize(); + constexpr index_t smem_size_gemm_1 = + MakeVLdsBlockDescriptor().GetElementSpaceSize() * + sizeof(typename Problem::VDataType); + + // TODO: consider shuffle requirement + return math::max(smem_size_gemm_0, smem_size_gemm_1); + } + + template + __host__ __device__ static constexpr auto MakeQDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template At<1>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1>>, + Tuple, Sequence<1, 2>>, + Sequence<2, 1, 2>, + Sequence<0, 0, 2>>{}); + } + + template + __host__ __device__ static constexpr auto MakeKDramTileDistribution() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = 16 / sizeof(KDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; +#if 1 // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); +#else // coalesce reading for each warps + constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t N1 = kNPerBlock / (N2 * N0); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<1, 1>>{}); +#endif + } + + template + __device__ static constexpr auto MakeVDramTileDistribution() + { + using VDataType = remove_cvref_t; + ; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + + template + __host__ __device__ static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + // using WarpGemm = ck::tile_program::warp::WarpGemmMfma{}), + // Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<1>{}), + // Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<2>{}), true>; + + using WarpGemm = + warp::WarpGemmImpl>; + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + + return BlockGemmARegBSmemCRegV1{}; + } + + template + __host__ __device__ static constexpr auto GetKVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + // using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy; + using WarpGemm = ck::tile_program::warp::WarpGemmMfma< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/tile_fmha_shape.hpp b/include/ck/tile_program/tile/tile_fmha_shape.hpp index f26266e17..3d5450cd6 100644 --- a/include/ck/tile_program/tile/tile_fmha_shape.hpp +++ b/include/ck/tile_program/tile/tile_fmha_shape.hpp @@ -8,19 +8,27 @@ namespace ck { namespace tile_program { -template +template struct TileFmhaShape { - static constexpr index_t kM0 = kM0PerTile_; - static constexpr index_t kN0 = kN0PerTile_; - static constexpr index_t kK0 = kK0PerTile_; - static constexpr index_t kN1 = kN1PerTile_; - static constexpr index_t kK1 = kK1PerTile_; + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + + static constexpr index_t kM0 = BlockTile::At(Number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::At(Number<1>{}); // tile size along k seqlen + static constexpr index_t kK0 = BlockTile::At(Number<2>{}); // tile size along qk gemm unroll + static constexpr index_t kN1 = BlockTile::At(Number<3>{}); // tile size along v head_dim + static constexpr index_t kK1 = BlockTile::At(Number<4>{}); // tile size along kv gemm unroll + static constexpr index_t kK0BlockLength = + BlockTile::At(Number<5>{}); // total length of K0, used for pipeline that need load Q at + // once (or repeately load Q as a whole tile) }; } // namespace tile_program diff --git a/include/ck/tile_program/warp_tile/warp_gemm.hpp b/include/ck/tile_program/warp_tile/warp_gemm.hpp index 21ff0bfc7..2e938bfa2 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm.hpp @@ -22,9 +22,15 @@ using WarpGemmMfmaF16F16F32M16N16K16 = using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M16N16K32 = + WarpGemmImpl>; + using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl>; +template +struct WarpGemmMfmaDispatcher; + +// clang-format off +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +// clang-format on + +template +using WarpGemmMfma = + typename WarpGemmMfmaDispatcher:: + Type; + } // namespace warp } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp index f87f7c77f..ebbaee08e 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp @@ -287,6 +287,90 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution } }; +template +struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_V2 +{ + using Impl = remove_cvref_t; + + // swap A and B + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename vector_type_maker::type::type; + using BVecType = typename vector_type_maker::type::type; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK * kKIter; + + using AWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<2>, + Sequence<1>>; + + using BWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2>, + Sequence<1>>; + + using CWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2, 2>, + Sequence<0, 2>>; + + // c_vec += a_vec * b_vec + __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + const auto a_vector = typename vector_type_maker::type{a_vec}; + + const auto b_vector = typename vector_type_maker::type{b_vec}; + + // swap A and B, value and type + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + b_vector.template AsType()[iKIter], + a_vector.template AsType()[iKIter]); + }); + } + + // c_vec = a_vec * b_vec + __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + const auto a_vector = typename vector_type_maker::type{a_vec}; + const auto b_vector = typename vector_type_maker::type{b_vec}; + + constexpr auto I0 = Number<0>{}; + + // swap A and B, value and type + auto c_vec = Impl{}(b_vector.template AsType()[I0], + a_vector.template AsType()[I0]); + + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + b_vector.template AsType()[iKIter], + a_vector.template AsType()[iKIter]); + }); + + return c_vec; + } +}; + } // namespace warp } // namespace tile_program } // namespace ck From 5605132917758b9d3d381def422e9ea0219fe4fe Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 2 Nov 2023 01:47:32 -0400 Subject: [PATCH 2/2] add refactor warp_gemm dispatcher --- ...gemm_asmem_bsmem_creg_v1_custom_policy.hpp | 5 +- ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 12 ++--- .../ck/tile_program/warp_tile/warp_gemm.hpp | 31 ------------ .../warp_tile/warp_gemm_attribute_mfma.hpp | 2 +- .../warp_tile/warp_gemm_dispatcher.hpp | 47 +++++++++++++++++++ 5 files changed, 57 insertions(+), 40 deletions(-) create mode 100644 include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp diff --git a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index 42eec8e5c..d22d702cb 100644 --- a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -11,7 +11,7 @@ #include "ck/tile_program/tile/tile_distribution.hpp" #include "ck/tile_program/tile/tile_elementwise.hpp" #include "ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp" namespace ck { namespace tile_program { @@ -43,7 +43,8 @@ struct BlockGemmASmemBSmemCRegV1CustomPolicy static constexpr bool TranposeC = TranposeC_; - using WarpGemm = WarpGemmMfma; + using WarpGemm = ck::tile_program::warp:: + WarpGemmMfmaDispatcher; template __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp index 84b475587..d1d70e78c 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -11,7 +11,7 @@ #include "ck/tile_program/tile/tile_distribution.hpp" #include "ck/tile_program/tile/tile_elementwise.hpp" #include "ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp" #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" @@ -224,14 +224,14 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy TileGemmShape>; - // using WarpGemm = ck::tile_program::warp::WarpGemmMfma{}), // Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<1>{}), // Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<2>{}), true>; - using WarpGemm = - warp::WarpGemmImpl>; @@ -257,7 +257,7 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy Problem::BlockFmhaShape::kN1, Problem::BlockFmhaShape::kK1>>; // using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy; - using WarpGemm = ck::tile_program::warp::WarpGemmMfma< + using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher< typename Problem::PDataType, typename Problem::VDataType, typename Problem::OaccDataType, diff --git a/include/ck/tile_program/warp_tile/warp_gemm.hpp b/include/ck/tile_program/warp_tile/warp_gemm.hpp index 2e938bfa2..50a80a07c 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm.hpp @@ -41,37 +41,6 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; -template -struct WarpGemmMfmaDispatcher; - -// clang-format off -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; -// clang-format on - -template -using WarpGemmMfma = - typename WarpGemmMfmaDispatcher:: - Type; - } // namespace warp } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp index ebbaee08e..cad9bac68 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp @@ -288,7 +288,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution }; template -struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_V2 +struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB { using Impl = remove_cvref_t; diff --git a/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp b/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp new file mode 100644 index 000000000..42751ce13 --- /dev/null +++ b/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace warp { + +namespace impl { +template +struct WarpGemmMfmaDispatcher; + +// clang-format off +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +// clang-format on +} // namespace impl + +template +using WarpGemmMfmaDispatcher = typename impl:: + WarpGemmMfmaDispatcher::Type; + +} // namespace warp +} // namespace tile_program +} // namespace ck