From 14a7de08f6ad1f4b2360dc94c2da1f932e181869 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 18 Oct 2023 19:13:45 +0000 Subject: [PATCH 01/10] Temp save. basic programming scheme determined --- .../gemm_softmax_gemm_impl.hpp | 44 ++-- .../block_gemm_areg_bsmem_creg_v1.hpp | 2 + ...emm_areg_bsmem_creg_v1_iteratek_policy.hpp | 33 +++ ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 220 ++++++++++++++++++ ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 57 +++++ ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 192 +++++++++++++++ include/ck/tile_program/tile/store_tile.hpp | 1 + .../tile/store_tile_impl_static_tensor.hpp | 30 +++ 8 files changed, 566 insertions(+), 13 deletions(-) create mode 100644 include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp create mode 100644 include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp create mode 100644 include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp create mode 100644 include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp create mode 100644 include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp diff --git a/example/91_tile_program/gemm_softmax_gemm_impl.hpp b/example/91_tile_program/gemm_softmax_gemm_impl.hpp index ec664fafe..3dcced652 100644 --- a/example/91_tile_program/gemm_softmax_gemm_impl.hpp +++ b/example/91_tile_program/gemm_softmax_gemm_impl.hpp @@ -14,6 +14,7 @@ #include "ck/tile_program/tile/slice_tile.hpp" #include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp" #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck/tile_program/block_tile/block_reduce.hpp" @@ -37,15 +38,15 @@ template struct GemmSoftmaxGemmImpl { - // block gemm0 pipeline - using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2< - ck::tile_program::block::BlockGemmPipelineProblem< - QDataType, - KDataType, - SaccDataType, - kBlockSize, - ck::tile_program::TileGemmShape>, - ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>; + // block gemm0 + using BlockGemm0Problem = ck::tile_program::block::BlockGemmPipelineProblem< + QDataType, + KDataType, + SaccDataType, + kBlockSize, + ck::tile_program::TileGemmShape>; + using BlockGemm0Policy = BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy; + using BlockGemm0 = decltype(BlockGemm0Policy::GetBlockGemm()); // block gemm1 using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1< @@ -162,7 +163,8 @@ struct GemmSoftmaxGemmImpl { using namespace ck; - return math::max(BlockGemm0Pipeline::GetStaticLdsSize(), + return math::max(sizeof(KDataType) * + Policy::template MakeBLdsBlockDescriptor().GetElementSpaceSize(), static_cast(MakeVLdsBlockDescriptor().GetElementSpaceSize() * sizeof(VDataType))); } @@ -204,10 +206,16 @@ struct GemmSoftmaxGemmImpl v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), Number<32>{}, Number<1>{}); auto q_dram_window = make_tile_window( - q_dram, make_tuple(Number{}, Number{}), {iM0, 0}); + q_dram, + make_tuple(Number{}, Number{}), + {iM0, 0}, + BlockGemm0Policy::MakeADramTileDistribution()); auto k_dram_window = make_tile_window( - k_dram, make_tuple(Number{}, Number{}), {0, 0}); + 
k_dram, + make_tuple(Number{}, Number{}), + {0, 0}, + BlockGemm0Policy::MakeBDramTileDistribution()); auto v_dram_window = make_tile_window(v_dram, @@ -224,7 +232,7 @@ struct GemmSoftmaxGemmImpl v_lds, make_tuple(Number{}, Number{}), {0, 0}); // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm0 = BlockGemm0{}; constexpr auto gemm1 = BlockGemm1{}; // reduction function for softmax @@ -267,6 +275,16 @@ struct GemmSoftmaxGemmImpl // Sacc{j} = Q * K{j} const auto s_acc = gemm0_pipeline(q_dram_window, k_dram_window, K0 / kK0PerBlock, smem_ptr); + + // Expand gemm0 pipeline + if (iN0 == 0) + { + /* Cold Q_Reg_Cache */ + } + else + { + /* Hot Q_Reg_Cache */ + } // S{j} const auto s = diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp index 7f5cb3ff1..dd5decd74 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp @@ -14,6 +14,7 @@ #include "ck/tile_program/tile/tile_gemm_shape.hpp" #include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp" namespace ck { namespace tile_program { @@ -45,6 +46,7 @@ struct BlockGemmARegBSmemCRegV1 using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using BlockGemmPolicy= Policy; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp new file mode 100644 index 000000000..17590a90e --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +struct BlockGemmARegBSmemCRegV1K8Policy +{ + template + __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() + { + using namespace ck::tile_program::warp; + + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp new file mode 100644 index 000000000..3d51067a4 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
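+//
+// Summary of the pipeline below: this "askiplds" variant of the AGmem/BGmem/CReg v2
+// pipeline loads the A tile from global memory straight into a register-distributed
+// tensor whose distribution is derived from the warp-GEMM layout, so A needs no LDS
+// staging, while B is still double-buffered through LDS; only B therefore contributes
+// to GetStaticLdsSize().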
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCRegV2 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using Policy = BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + // Move this part into Policy? + __host__ __device__ static constexpr ck::index_t GetStaticLdsSize() + { + return + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().GetElementSpaceSize(); + } + + template + __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + static_assert( + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store and read. 
without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + a_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeADramTileDistribution()); + + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = + make_static_distributed_tensor(a_reg_block_dstr); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + b_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(Number{}, Number{}), + {0, 0}, + b_copy_dram_window.GetTileDistribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_copy_reg_tensor, b_lds_gemm_window)){}; + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // block buffer write 0 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + // store_tile -> shuffle store tile + store_tile(a_copy_reg_tensor, a_block_tile_tmp); + // global read 1 + a_block_tile = load_tile(a_copy_dram_window); + + // LDS write 0 + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + // global read 1 + b_block_tile = load_tile(b_copy_dram_window); + } + + index_t iCounter = num_loop - 2; + + do + { + block_sync_lds(); + + // GEMM i + block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_reg_tensor, a_block_tile_tmp); + // global read i + 2 + a_block_tile = load_tile(a_copy_dram_window); + + // LDS write i + 1 + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + // global read i + 2 + b_block_tile = load_tile(b_copy_dram_window); + + iCounter--; + + } while(iCounter > 0); + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); + + block_sync_lds(); + + // LDS write num_loop - 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + 
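+            // note: for A this store_tile is a register-to-register copy into the
+            // persistent A register tensor (skip-A-LDS path), not an LDS write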
store_tile(a_copy_reg_tensor, a_block_tile_tmp); + + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); + } + + return c_block_tile; + } + + template + __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp new file mode 100644 index 000000000..12412e111 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// NOTE: Assume A is K-Major +struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy +{ + template + __host__ __device__ static constexpr auto MakeARegBlockDescriptor() + { + constexpr auto blockgemm = GetBlockGemm(); + using BlockGemm = ck::remove_cvref_t; + + return policy_impl::MakeARegBlockDescriptor(); + } + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + return policy_impl::MakeBLdsBlockDescriptor(); + } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + constexpr auto blockgemm = GetBlockGemm(); + using BlockGemm = ck::remove_cvref_t; + + return policy_impl::MakeADramTileDistribution_ASkipLDS(); + } + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + return policy_impl::MakeADramTileDistribution(); + } + + template + __host__ __device__ static constexpr auto GetBlockGemm() + { + using BlockGemmPolicy = BlockGemmARegBSmemCRegV1K8Policy; + + return BlockGemmARegBSmemCRegV1{}; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp new file mode 100644 index 000000000..c7edda066 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp @@ -0,0 +1,192 @@ +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block 
{ +namespace policy_impl{ + // 3d + padding + template + __host__ __device__ static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number<8>{}), + make_tuple(Number<(kMPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return a_lds_block_desc; + } + + // 3d + padding + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + using namespace ck; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number<8>{}), + make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return b_lds_block_desc; + } + + template + __host__ __device__ static constexpr auto MakeARegBlockDescriptor() + { + using namespace ck; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + return a_block_dstr; + } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution_ASkipLDS() + { + constexpr auto config = 
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K2 = WG::kK/ WG::WarpGemmAttribute::Impl::kABKLane; //WG::WarpGemmAttribute::Impl::kABKPerLane; // 16 / sizeof(ADataType); + constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1>>, + Tuple, Sequence<1, 2>>, + Sequence<2, 1, 2>, + Sequence<0, 0, 2>>{}); + } + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + + template + __host__ __device__ static constexpr auto GetBlockGemm() + { + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } + +} // namespace policy_impl +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/store_tile.hpp b/include/ck/tile_program/tile/store_tile.hpp index 7aba746f7..974e8d082 100644 --- a/include/ck/tile_program/tile/store_tile.hpp +++ b/include/ck/tile_program/tile/store_tile.hpp @@ -12,3 +12,4 @@ #include "ck/tile_program/tile/tile_distribution.hpp" #include "ck/tile_program/tile/store_tile_impl_static_distribution.hpp" #include "ck/tile_program/tile/store_tile_impl_static_lengths.hpp" +#include "ck/tile_program/tile/store_tile_impl_static_tensor.hpp" diff --git a/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp b/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp new file mode 100644 index 000000000..45b63b46e --- /dev/null +++ b/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
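+//
+// store_tile overload that copies one static distributed tensor into another, i.e. a
+// register-to-register copy of the thread-local buffers; it implicitly assumes both
+// tensors use the same tile distribution (see the commented-out static_assert below).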
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_window.hpp" + +namespace ck { +namespace tile_program { + +template +__device__ void +store_tile(StaticDistributedTensor& dst_dstr_tensor, + const StaticDistributedTensor& src_dstr_tensor) +{ + // static_assert(DstTileDistribution_==SrcTileDistribution_); + dst_dstr_tensor.GetThreadBuffer() = src_dstr_tensor.GetThreadBuffer(); +} + +} // namespace tile_program +} // namespace ck From c3ebd87c82786b73d1062143bc089b3d8ab2ef8e Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 19 Oct 2023 19:04:42 +0000 Subject: [PATCH 02/10] temp save, debug correctness --- .../batched_gemm_softmax_gemm.cpp | 129 ++++++- .../batched_gemm_softmax_gemm.hpp | 2 + .../gemm_softmax_gemm_impl.hpp | 241 +++++++++++-- .../multi_index_transform.hpp | 6 +- .../impl/device_image_to_column_impl.hpp | 5 +- .../block_gemm_areg_bsmem_creg_v1.hpp | 10 +- ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 21 +- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 46 ++- ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 339 +++++++++--------- include/ck/tile_program/tile/slice_tile.hpp | 7 +- .../tile/static_distributed_tensor.hpp | 2 +- .../tile/store_tile_impl_static_tensor.hpp | 4 +- 12 files changed, 566 insertions(+), 246 deletions(-) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.cpp b/example/91_tile_program/batched_gemm_softmax_gemm.cpp index a57ddf461..7cb4fd3ac 100644 --- a/example/91_tile_program/batched_gemm_softmax_gemm.cpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm.cpp @@ -29,18 +29,28 @@ int main(int argc, char* argv[]) using ODataType = ck::half_t; ck::index_t Batch = 16; - ck::index_t M0 = 3328; + ck::index_t M0 = 4096; ck::index_t N0 = 4096; ck::index_t K0 = 128; ck::index_t N1 = 128; + ck::index_t init_method = 1; + ck::index_t time_kernel = 0; - if(argc == 6) + if(argc == 3) { - Batch = std::stoi(argv[1]); - M0 = std::stoi(argv[2]); - N0 = std::stoi(argv[3]); - K0 = std::stoi(argv[4]); - N1 = std::stoi(argv[5]); + init_method = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + + if(argc == 8) + { + init_method = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); } std::array q_lengths{Batch, M0, K0}; @@ -70,6 +80,105 @@ int main(int argc, char* argv[]) Tensor o_host_ref(o_lengths, o_strides); Tensor o_host_dev(o_lengths, o_strides); + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + break; + case 2: + ck::utils::FillUniformDistribution{-3.f, 3.f}(q_host); + ck::utils::FillUniformDistribution{-3.f, 3.f}(k_host); + ck::utils::FillUniformDistribution{-3.f, 3.f}(v_host); + break; + case 3: + ck::utils::FillConstant{1.f}(q_host); + ck::utils::FillConstant{1.f}(k_host); + ck::utils::FillConstant{1.f}(v_host); + break; + case 4: + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck::utils::FillConstant{1.f}(k_host); + 
ck::utils::FillConstant{1.f}(v_host); + break; + case 5: + ck::utils::FillConstant{1.f}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck::utils::FillConstant{1.f}(v_host); + break; + case 6: + ck::utils::FillConstant{1.f}(q_host); + ck::utils::FillConstant{1.f}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + break; + case 7: + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck::utils::FillConstant{1.f}(v_host); + break; + case 8: + ck::utils::FillConstant{1.f}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + break; + case 9: + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck::utils::FillConstant{1.f}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + break; + default: + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); + } +#if 0 + for (int im = 0; im < M0; im++) + { + for (int ik = 0; ik < K0; ik++) + { + printf("%04x ",*(reinterpret_cast(&(q_host(0, im, ik))))); + if (ik % 4 == 3) + { + printf("|"); + } + } + printf("\n"); + } +#endif + +#if 0 + for (int in = 0; in < N0; in++) + { + for (int ik = 0; ik < K0; ik++) + { + printf("%04x ",*(reinterpret_cast(&(k_host(0, in, ik))))); + if (ik % 8 == 7) + { + printf("|"); + } + } + printf("\n"); + } +#endif + +#if 1 + for (int in = 0; in < N1; in++) + { + for (int ik = 0; ik < N0; ik++) + { + printf("%04x ",*(reinterpret_cast(&(v_host(0, in, ik))))); + if (ik % 8 == 7) + { + printf("|"); + } + } + printf("\n"); + } +#endif + +/* #if 0 ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); @@ -79,6 +188,7 @@ int main(int argc, char* argv[]) ck::utils::FillUniformDistribution{-3.f, 3.f}(k_host); ck::utils::FillUniformDistribution{-3.f, 3.f}(v_host); #endif +*/ // reference reference_batched_gemm( @@ -104,6 +214,8 @@ int main(int argc, char* argv[]) constexpr ck::index_t kK1PerBlock = 32; constexpr ck::index_t kBlockSize = 256; + constexpr ck::index_t kHeadDim = 128; + ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock); std::cout << "grid size " << kGridSize << std::endl; @@ -113,7 +225,7 @@ int main(int argc, char* argv[]) constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; float ave_time = - launch_kernel(StreamConfig{nullptr, true}, + launch_kernel(StreamConfig{nullptr, static_cast(time_kernel)}, BatchedGemmSoftmaxGemm>; - using BlockGemm0Policy = BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy; - using BlockGemm0 = decltype(BlockGemm0Policy::GetBlockGemm()); + QDataType, + KDataType, + SaccDataType, + kBlockSize, + ck::tile_program::TileGemmShape>; + using BlockGemm0Policy = + ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy; + using BlockGemm0 = decltype(BlockGemm0Policy::GetBlockGemm()); // block gemm1 using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1< @@ -163,8 +165,10 @@ struct GemmSoftmaxGemmImpl { using namespace ck; - return math::max(sizeof(KDataType) * - Policy::template MakeBLdsBlockDescriptor().GetElementSpaceSize(), + return math::max(static_cast( + BlockGemm0Policy::template MakeBLdsBlockDescriptor() + 
.GetElementSpaceSize() * + sizeof(KDataType)), static_cast(MakeVLdsBlockDescriptor().GetElementSpaceSize() * sizeof(VDataType))); } @@ -205,17 +209,28 @@ struct GemmSoftmaxGemmImpl const auto v_dram = make_naive_tensor_view( v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), Number<32>{}, Number<1>{}); - auto q_dram_window = make_tile_window( - q_dram, - make_tuple(Number{}, Number{}), - {iM0, 0}, - BlockGemm0Policy::MakeADramTileDistribution()); - - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(Number{}, Number{}), - {0, 0}, - BlockGemm0Policy::MakeBDramTileDistribution()); + auto q_dram_window = + make_tile_window(q_dram, + make_tuple(Number{}, Number{}), + {iM0, 0}, + BlockGemm0Policy::MakeADramTileDistribution()); + + // Q in Register + auto q_reg = make_static_distributed_tensor( + BlockGemm0Policy::template MakeARegBlockDescriptor()); + + auto k_dram_window = + make_tile_window(k_dram, + make_tuple(Number{}, Number{}), + {0, 0}, + BlockGemm0Policy::MakeBDramTileDistribution()); + + // K LDS and LDS window + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + BlockGemm0Policy::MakeBLdsBlockDescriptor()); + auto k_lds_window = make_tile_window( + k_lds, make_tuple(Number{}, Number{}), {0, 0}); auto v_dram_window = make_tile_window(v_dram, @@ -232,16 +247,17 @@ struct GemmSoftmaxGemmImpl v_lds, make_tuple(Number{}, Number{}), {0, 0}); // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0 = BlockGemm0{}; - constexpr auto gemm1 = BlockGemm1{}; + constexpr auto gemm0 = BlockGemm0{}; + constexpr auto gemm1 = BlockGemm1{}; // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; // infer Sacc, S, P, M, L, Oacc type - using SaccBlockTileType = - decltype(gemm0_pipeline(q_dram_window, k_dram_window, 0, nullptr)); + using SaccBlockTileType = decltype(gemm0( + get_slice_tile(q_reg, Sequence<0, 0>{}, Sequence{}), + k_lds_window)); using SBlockTileType = decltype(tile_elementwise_in( type_convert, SaccBlockTileType{})); @@ -257,7 +273,8 @@ struct GemmSoftmaxGemmImpl PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), v_dram_window)); - // init Oacc, M, L + // init Sacc, Oacc, M, L + auto s_acc = SaccBlockTileType{}; auto o_acc = OaccBlockTileType{}; auto m = MLBlockTileType{}; auto l = MLBlockTileType{}; @@ -268,24 +285,159 @@ struct GemmSoftmaxGemmImpl tile_elementwise_inout([](auto& e) { e = 0; }, l); // loop over Column of S (J loop) - index_t iN0 = 0; + index_t iN0 = 0; + constexpr index_t k0_loops = kHeadDim / kK0PerBlock; + + // Cold Q_Reg_Cache + auto q_block_tile = load_tile(q_dram_window); +#if 0 + printf("Blockid: %02d, Tid: %03d, q_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<3>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<4>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<5>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<6>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<7>{}]))) + ); +#endif + auto k_block_tile = load_tile(k_dram_window); +#if 0 + printf("Blockid: %02d, Tid: %03d, k_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x|\n", + 
get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) + ); +#endif + { + move_tile_window(q_dram_window, {0, kK0PerBlock}); + move_tile_window(k_dram_window, {0, kK0PerBlock}); + + tile_elementwise_inout([](auto& s) { s = 0; }, s_acc); + + set_slice_tile( + q_reg, q_block_tile, Sequence<0, 0>{}, Sequence{}); + q_block_tile = load_tile(q_dram_window); + + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); + } + + static_for<0, k0_loops - 3, 1>{}([&](auto i_k0) { + block_sync_lds(); + + gemm0(s_acc, + get_slice_tile(q_reg, + Sequence<0, (i_k0)*kK0PerBlock>{}, + Sequence{}), + k_lds_window); + + block_sync_lds(); + + move_tile_window(q_dram_window, {0, kK0PerBlock}); + move_tile_window(k_dram_window, {0, kK0PerBlock}); + + set_slice_tile(q_reg, + q_block_tile, + Sequence<0, (i_k0 + 1) * kK0PerBlock>{}, + Sequence{}); + q_block_tile = load_tile(q_dram_window); + + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); + }); + + // tail + { + block_sync_lds(); + + gemm0(s_acc, + get_slice_tile(q_reg, + Sequence<0, (k0_loops - 2) * kK0PerBlock>{}, + Sequence{}), + k_lds_window); + + block_sync_lds(); + + set_slice_tile(q_reg, + q_block_tile, + Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, + Sequence{}); + + store_tile(k_lds_window, k_block_tile); + + block_sync_lds(); + + gemm0(s_acc, + get_slice_tile(q_reg, + Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, + Sequence{}), + k_lds_window); + } do { - // Sacc{j} = Q * K{j} - const auto s_acc = - gemm0_pipeline(q_dram_window, k_dram_window, K0 / kK0PerBlock, smem_ptr); - - // Expand gemm0 pipeline - if (iN0 == 0) + // Hot Q_Reg_Cache + if(iN0 > 0) { - /* Cold Q_Reg_Cache */ - } - else - { - /* Hot Q_Reg_Cache */ - } + k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0PerBlock}); + + tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); + + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); + } + + static_for<0, k0_loops - 3, 1>{}([&](auto i_k0) { + block_sync_lds(); + + gemm0(s_acc, + get_slice_tile(q_reg, + Sequence<0, (i_k0)*kK0PerBlock>{}, + Sequence{}), + k_lds_window); + + block_sync_lds(); + + move_tile_window(k_dram_window, {0, kK0PerBlock}); + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); + }); + + // tail + { + block_sync_lds(); + + gemm0(s_acc, + get_slice_tile(q_reg, + Sequence<0, (k0_loops - 2) * kK0PerBlock>{}, + Sequence{}), + k_lds_window); + + block_sync_lds(); + + store_tile(k_lds_window, k_block_tile); + + block_sync_lds(); + + gemm0(s_acc, + get_slice_tile(q_reg, + Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, + Sequence{}), + k_lds_window); + } + } // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); @@ -293,6 +445,18 @@ struct GemmSoftmaxGemmImpl // prefetch load v tile const auto v_prefetch = load_tile(v_dram_window); + printf("Blockid: %02d, Tid: %03d, v_thread_buf: %04x %04x %04x %04x %04x %04x %04x 
%04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<3>{}]))), + *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<4>{}]))), + *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<5>{}]))), + *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<6>{}]))), + *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<7>{}]))) + ); + // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, Sequence<1>{}, f_max, NumericLimits::Lowest()); @@ -385,7 +549,6 @@ struct GemmSoftmaxGemmImpl // move tile windows move_tile_window(k_dram_window, {kN0PerBlock, 0}); iN0 += kN0PerBlock; - } while(iN0 < N0); // Oacc diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index d8a7a6a33..aa3e92f51 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -544,9 +544,9 @@ struct Merge_v2_magic_division : public BaseTransform using UpLengths = decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); - using LowLengthsMagicDivisor = decltype( - generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_divisor{}, - Number{})); + using LowLengthsMagicDivisor = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_divisor{}, + Number{})); LowLengths low_lengths_; LowLengthsMagicDivisor low_lengths_magic_divisor_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 89fcbca1a..19f126e66 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -151,8 +151,9 @@ struct DeviceImageToColumnImpl remove_cvref_t; using OutputGridDesc = remove_cvref_t; - using Block2ETileMap = remove_cvref_t(OutputGridDesc{}))>; + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + OutputGridDesc{}))>; using GridwiseImageToColumnKernel = GridwiseImageToColumn struct BlockGemmARegBSmemCRegV1 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using BlockGemmPolicy= Policy; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using BlockGemmPolicy = Policy; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 3d51067a4..45abd297a 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -30,7 +30,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using Policy = BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy; + using Policy = 
BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -41,9 +41,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV2().GetElementSpaceSize(); + return sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().GetElementSpaceSize(); } template (); // B tile in LDS, blockWindow - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem))); + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - // This tensor view used to construct both tile window for lds store and read, with buffer address info + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); // A DRAM tile window for load @@ -87,10 +88,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV2()); - // A Reg tensor for store, also used for block GEMM - auto a_copy_reg_tensor = - make_static_distributed_tensor(a_reg_block_dstr); + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); // B DRAM tile window for load auto b_copy_dram_window = diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 12412e111..aeb8a3278 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -17,7 +17,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy __host__ __device__ static constexpr auto MakeARegBlockDescriptor() { constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = ck::remove_cvref_t; + using BlockGemm = ck::remove_cvref_t; return policy_impl::MakeARegBlockDescriptor(); } @@ -32,7 +32,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy __host__ __device__ static constexpr auto MakeADramTileDistribution() { constexpr auto blockgemm = GetBlockGemm(); - using BlockGemm = ck::remove_cvref_t; + using BlockGemm = ck::remove_cvref_t; return policy_impl::MakeADramTileDistribution_ASkipLDS(); } @@ -52,6 +52,48 @@ struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy } }; +struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy + : BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy +{ + template + __host__ __device__ static constexpr auto MakeARegBlockDescriptor() + { + using namespace ck; + + constexpr auto blockgemm = GetBlockGemm(); + using BlockGemm = ck::remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = kHeadDim; + + constexpr auto config = + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto a_block_dstr = 
make_static_tile_distribution(a_block_dstr_encode); + + return a_block_dstr; + } +}; + } // namespace block } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp index c7edda066..cfc0b0138 100644 --- a/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ b/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp @@ -11,180 +11,181 @@ namespace ck { namespace tile_program { namespace block { -namespace policy_impl{ - // 3d + padding - template - __host__ __device__ static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number<8>{}), - make_tuple(Number<(kMPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), - Number<8>{}, - Number<1>{}); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return a_lds_block_desc; - } - - // 3d + padding - template - __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() - { - using namespace ck; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number<8>{}), - make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), - Number<8>{}, - Number<1>{}); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return b_lds_block_desc; - } - - template - __host__ __device__ static constexpr auto MakeARegBlockDescriptor() - { - using namespace ck; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template At<1>(); - constexpr index_t NWarp = config.template At<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding< - Sequence, - Tuple, Sequence>, - Tuple>, - Tuple>, - Sequence<1, 2>, - Sequence<0, 0>>{}; - - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); - - return a_block_dstr; - } - - template - __host__ __device__ static constexpr auto MakeADramTileDistribution() - { - using ADataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = 
Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); - } - - template - __host__ __device__ static constexpr auto MakeADramTileDistribution_ASkipLDS() - { - constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template At<1>(); - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t K2 = WG::kK/ WG::WarpGemmAttribute::Impl::kABKLane; //WG::WarpGemmAttribute::Impl::kABKPerLane; // 16 / sizeof(ADataType); - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); +namespace policy_impl { +// 3d + padding +template +__host__ __device__ static constexpr auto MakeALdsBlockDescriptor() +{ + using namespace ck; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number<8>{}), + make_tuple(Number<(kMPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto a_lds_block_desc = + transform_tensor_descriptor(a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1>>, - Tuple, Sequence<1, 2>>, - Sequence<2, 1, 2>, - Sequence<0, 0, 2>>{}); - } + return a_lds_block_desc; +} + +// 3d + padding +template +__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() +{ + using namespace ck; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number<8>{}), + make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto b_lds_block_desc = + transform_tensor_descriptor(b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return b_lds_block_desc; +} + +template +__host__ __device__ static constexpr auto MakeARegBlockDescriptor() +{ + using namespace ck; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + 
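+    // config = (warp-level GEMM, MWarp, NWarp): MWarp/NWarp are the number of warps
+    // covering the block tile along the M and N dimensions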
constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + return a_block_dstr; +} + +template +__host__ __device__ static constexpr auto MakeADramTileDistribution() +{ + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); +} + +template +__host__ __device__ static constexpr auto MakeADramTileDistribution_ASkipLDS() +{ + constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K2 = + WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane; + // // 16 / sizeof(ADataType); + constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1>>, + Tuple, Sequence<1, 2>>, + Sequence<2, 1, 2>, + Sequence<0, 0, 2>>{}); +} + +template +__host__ __device__ static constexpr auto MakeBDramTileDistribution() +{ + using BDataType = remove_cvref_t; - template - __host__ __device__ static constexpr auto MakeBDramTileDistribution() - { - using BDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, 
Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); - } - - template - __host__ __device__ static constexpr auto GetBlockGemm() - { - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; - - return BlockGemmASmemBSmemCRegV1{}; - } + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); +} + +template +__host__ __device__ static constexpr auto GetBlockGemm() +{ + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; + + return BlockGemmASmemBSmemCRegV1{}; +} } // namespace policy_impl } // namespace block diff --git a/include/ck/tile_program/tile/slice_tile.hpp b/include/ck/tile_program/tile/slice_tile.hpp index e7999f26a..78aa1208a 100644 --- a/include/ck/tile_program/tile/slice_tile.hpp +++ b/include/ck/tile_program/tile/slice_tile.hpp @@ -49,17 +49,18 @@ __host__ __device__ constexpr auto set_slice_tile(DstStaticDistributedTensor_& d Sequence slice_ends) { using DstDistribution = decltype(DstStaticDistributedTensor_::GetTileDistribution()); + // using SrcDistribution = decltype(SrcStaticDistributedTensor_::GetTileDistribution()); constexpr auto sliced_dstr_yidx_ylen = detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); - constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); + // constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); - static_assert(is_same_v, "wrong!"); + // static_assert(is_same_v, "wrong!"); - dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer()); + dst_tile.SetYSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer()); } } // namespace tile_program diff --git a/include/ck/tile_program/tile/static_distributed_tensor.hpp b/include/ck/tile_program/tile/static_distributed_tensor.hpp index 57b3c418c..19f420aa4 100644 --- a/include/ck/tile_program/tile/static_distributed_tensor.hpp +++ b/include/ck/tile_program/tile/static_distributed_tensor.hpp @@ -59,7 +59,7 @@ struct StaticDistributedTensor template __host__ __device__ auto GetYSlicedThreadData(Sequence, - Sequence) const + Sequence) const { static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY && sizeof...(YSliceLengths) == StaticTileDistribution::NDimY, diff --git a/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp b/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp index 45b63b46e..4407df6bd 100644 --- a/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp +++ b/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp @@ -15,9 +15,7 @@ namespace ck { namespace tile_program { -template +template __device__ void store_tile(StaticDistributedTensor& dst_dstr_tensor, const StaticDistributedTensor& src_dstr_tensor) From c96627dad04d4b1b2b397b59750ca089aceb10d9 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 20 Oct 2023 12:21:47 +0000 Subject: [PATCH 03/10] tempsave, debug print --- .../batched_gemm_softmax_gemm.cpp | 15 +-- .../gemm_softmax_gemm_impl.hpp | 94 +++++++++++++++---- 2 files changed, 87 insertions(+), 22 deletions(-) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.cpp b/example/91_tile_program/batched_gemm_softmax_gemm.cpp index 7cb4fd3ac..5f1de81e8 100644 --- 
a/example/91_tile_program/batched_gemm_softmax_gemm.cpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm.cpp @@ -133,13 +133,14 @@ int main(int argc, char* argv[]) ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); } -#if 0 +#if 1 + std::cout<<"Print Q matrix"<(&(q_host(0, im, ik))))); - if (ik % 4 == 3) + if (ik % 8 == 7) { printf("|"); } @@ -149,6 +150,7 @@ int main(int argc, char* argv[]) #endif #if 0 + std::cout<<"Print K matrix"<()); // Q in Register - auto q_reg = make_static_distributed_tensor( + auto q_reg_tensor = make_static_distributed_tensor( BlockGemm0Policy::template MakeARegBlockDescriptor()); auto k_dram_window = @@ -256,7 +256,7 @@ struct GemmSoftmaxGemmImpl // infer Sacc, S, P, M, L, Oacc type using SaccBlockTileType = decltype(gemm0( - get_slice_tile(q_reg, Sequence<0, 0>{}, Sequence{}), + get_slice_tile(q_reg_tensor, Sequence<0, 0>{}, Sequence{}), k_lds_window)); using SBlockTileType = decltype(tile_elementwise_in( @@ -291,7 +291,7 @@ struct GemmSoftmaxGemmImpl // Cold Q_Reg_Cache auto q_block_tile = load_tile(q_dram_window); #if 0 - printf("Blockid: %02d, Tid: %03d, q_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x|\n", + printf("Blockid: %02d, Tid: %03d, q_thread_buf(0-7): %04x %04x %04x %04x %04x %04x %04x %04x|\n", get_block_1d_id(), get_thread_local_1d_id(), *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<0>{}]))), *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<1>{}]))), @@ -302,6 +302,18 @@ struct GemmSoftmaxGemmImpl *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<6>{}]))), *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<7>{}]))) ); + + printf("Blockid: %02d, Tid: %03d, q_thread_buf(8-15): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<8>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<9>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<10>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<11>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<12>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<13>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<14>{}]))), + *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<15>{}]))) + ); #endif auto k_block_tile = load_tile(k_dram_window); #if 0 @@ -324,18 +336,42 @@ struct GemmSoftmaxGemmImpl tile_elementwise_inout([](auto& s) { s = 0; }, s_acc); set_slice_tile( - q_reg, q_block_tile, Sequence<0, 0>{}, Sequence{}); + q_reg_tensor, q_block_tile, Sequence<0, 0>{}, Sequence{}); q_block_tile = load_tile(q_dram_window); +#if 1 + printf("Blockid: %02d, Tid: %03d, q_thread_buf(0-7): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<3>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<4>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<5>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<6>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<7>{}]))) + ); + printf("Blockid: %02d, Tid: %03d, 
q_thread_buf(8-15): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<8>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<9>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<10>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<11>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<12>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<13>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<14>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<15>{}]))) + ); +#endif store_tile(k_lds_window, k_block_tile); k_block_tile = load_tile(k_dram_window); } - + if constexpr(k0_loops > 2){ static_for<0, k0_loops - 3, 1>{}([&](auto i_k0) { block_sync_lds(); gemm0(s_acc, - get_slice_tile(q_reg, + get_slice_tile(q_reg_tensor, Sequence<0, (i_k0)*kK0PerBlock>{}, Sequence{}), k_lds_window); @@ -345,7 +381,7 @@ struct GemmSoftmaxGemmImpl move_tile_window(q_dram_window, {0, kK0PerBlock}); move_tile_window(k_dram_window, {0, kK0PerBlock}); - set_slice_tile(q_reg, + set_slice_tile(q_reg_tensor, q_block_tile, Sequence<0, (i_k0 + 1) * kK0PerBlock>{}, Sequence{}); @@ -354,30 +390,55 @@ struct GemmSoftmaxGemmImpl store_tile(k_lds_window, k_block_tile); k_block_tile = load_tile(k_dram_window); }); + } // tail { block_sync_lds(); gemm0(s_acc, - get_slice_tile(q_reg, + get_slice_tile(q_reg_tensor, Sequence<0, (k0_loops - 2) * kK0PerBlock>{}, Sequence{}), k_lds_window); block_sync_lds(); - set_slice_tile(q_reg, + set_slice_tile(q_reg_tensor, q_block_tile, Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, Sequence{}); +#if 1 + printf("Blockid: %02d, Tid: %03d, q_thread_buf(16-23): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<16>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<17>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<18>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<19>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<20>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<21>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<22>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<23>{}]))) + ); + printf("Blockid: %02d, Tid: %03d, q_thread_buf(24-31): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<24>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<25>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<26>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<27>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<28>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<29>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<30>{}]))), + *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<31>{}]))) + ); +#endif store_tile(k_lds_window, k_block_tile); block_sync_lds(); gemm0(s_acc, - get_slice_tile(q_reg, + get_slice_tile(q_reg_tensor, Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, Sequence{}), k_lds_window); @@ -397,12 +458,12 @@ struct GemmSoftmaxGemmImpl store_tile(k_lds_window, k_block_tile); k_block_tile = load_tile(k_dram_window); } - + if 
constexpr(k0_loops > 2){ static_for<0, k0_loops - 3, 1>{}([&](auto i_k0) { block_sync_lds(); gemm0(s_acc, - get_slice_tile(q_reg, + get_slice_tile(q_reg_tensor, Sequence<0, (i_k0)*kK0PerBlock>{}, Sequence{}), k_lds_window); @@ -414,13 +475,14 @@ struct GemmSoftmaxGemmImpl store_tile(k_lds_window, k_block_tile); k_block_tile = load_tile(k_dram_window); }); + } // tail { block_sync_lds(); gemm0(s_acc, - get_slice_tile(q_reg, + get_slice_tile(q_reg_tensor, Sequence<0, (k0_loops - 2) * kK0PerBlock>{}, Sequence{}), k_lds_window); @@ -432,7 +494,7 @@ struct GemmSoftmaxGemmImpl block_sync_lds(); gemm0(s_acc, - get_slice_tile(q_reg, + get_slice_tile(q_reg_tensor, Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, Sequence{}), k_lds_window); @@ -444,7 +506,7 @@ struct GemmSoftmaxGemmImpl // prefetch load v tile const auto v_prefetch = load_tile(v_dram_window); - +#if 0 printf("Blockid: %02d, Tid: %03d, v_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x|\n", get_block_1d_id(), get_thread_local_1d_id(), *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<0>{}]))), @@ -456,7 +518,7 @@ struct GemmSoftmaxGemmImpl *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<6>{}]))), *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<7>{}]))) ); - +#endif // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, Sequence<1>{}, f_max, NumericLimits::Lowest()); From 4892b128dc4e5a74f9ba2331be43a5a6af66b695 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 25 Oct 2023 15:10:01 +0000 Subject: [PATCH 04/10] Sanity pass --- .../batched_gemm_softmax_gemm.cpp | 174 +++++----- .../gemm_softmax_gemm_impl.hpp | 311 +++++++++++------- .../block_gemm_areg_bsmem_creg_v1.hpp | 21 +- ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 2 +- .../tile_window_impl_static_distribution.hpp | 4 + 5 files changed, 310 insertions(+), 202 deletions(-) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.cpp b/example/91_tile_program/batched_gemm_softmax_gemm.cpp index 5f1de81e8..645d5aad0 100644 --- a/example/91_tile_program/batched_gemm_softmax_gemm.cpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm.cpp @@ -28,11 +28,11 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck::half_t; - ck::index_t Batch = 16; - ck::index_t M0 = 4096; - ck::index_t N0 = 4096; - ck::index_t K0 = 128; - ck::index_t N1 = 128; + ck::index_t Batch = 64; + ck::index_t M0 = 4096; + ck::index_t N0 = 4096; + ck::index_t K0 = 128; + ck::index_t N1 = 128; ck::index_t init_method = 1; ck::index_t time_kernel = 0; @@ -46,11 +46,11 @@ int main(int argc, char* argv[]) { init_method = std::stoi(argv[1]); time_kernel = std::stoi(argv[2]); - Batch = std::stoi(argv[3]); - M0 = std::stoi(argv[4]); - N0 = std::stoi(argv[5]); - K0 = std::stoi(argv[6]); - N1 = std::stoi(argv[7]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); } std::array q_lengths{Batch, M0, K0}; @@ -133,7 +133,7 @@ int main(int argc, char* argv[]) ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); } -#if 1 +#if 0 std::cout<<"Print Q matrix"<(&(k_host(0, in, ik))))); - if (ik % 8 == 7) - { - printf("|"); - } - } - printf("\n"); - } -#endif - #if 0 std::cout<<"Print V matrix"<{-3.f, 3.f}(q_host); - ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); - ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); -#else - 
ck::utils::FillUniformDistribution{-3.f, 3.f}(q_host); - ck::utils::FillUniformDistribution{-3.f, 3.f}(k_host); - ck::utils::FillUniformDistribution{-3.f, 3.f}(v_host); -#endif -*/ + /* + #if 0 + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-3.f, 3.f}(v_host); + #else + ck::utils::FillUniformDistribution{-3.f, 3.f}(q_host); + ck::utils::FillUniformDistribution{-3.f, 3.f}(k_host); + ck::utils::FillUniformDistribution{-3.f, 3.f}(v_host); + #endif + */ // reference reference_batched_gemm( @@ -211,59 +195,91 @@ int main(int argc, char* argv[]) v_buf.ToDevice(v_host.mData.data()); constexpr ck::index_t kM0PerBlock = 128; - constexpr ck::index_t kN0PerBlock = 32; + constexpr ck::index_t kN0PerBlock = 128; constexpr ck::index_t kK0PerBlock = 32; - constexpr ck::index_t kN1PerBlock = 64; + constexpr ck::index_t kN1PerBlock = 128; constexpr ck::index_t kK1PerBlock = 32; constexpr ck::index_t kBlockSize = 256; - constexpr ck::index_t kHeadDim = 64; + constexpr ck::index_t kHeadDim = 128; - ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock); + ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock); +#if 0 + std::cout<<"Print K matrix"<(&(k_host(0, in, ik))))); + if (ik % 8 == 7) + { + printf("|"); + } + } + printf("\n"); + } +#endif + +#if 0 + std::cout << "Print S matrix" << std::endl; + for(int im = 0; im < M0; im++) + { + for(int in = 0; in < N0; in++) + { + printf("%.0lf ", s_host_ref(0, im, in)); + if(in % 8 == 7) + { + printf("|"); + } + } + printf("\n"); + } +#endif std::cout << "grid size " << kGridSize << std::endl; constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize; constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; - float ave_time = - launch_kernel(StreamConfig{nullptr, static_cast(time_kernel)}, - BatchedGemmSoftmaxGemm{}, - kGridSize, - kBlockSize, - 0, - static_cast(q_buf.GetDeviceBuffer()), - static_cast(k_buf.GetDeviceBuffer()), - static_cast(v_buf.GetDeviceBuffer()), - static_cast(o_buf.GetDeviceBuffer()), - M0, - N0, - K0, - N1, - Batch, - K0, // StrideQ - K0, // StrideK - N0, // StrideV - N1, // StrideO - M0 * K0, // BatchStrideQ - N0 * K0, // BatchStrideK - N1 * N0, // BatchStrideV - M0 * N1); // BatchStrideO + float ave_time = launch_kernel( + StreamConfig{nullptr, static_cast(time_kernel)}, + BatchedGemmSoftmaxGemm{}, + kGridSize, + kBlockSize, + 0, + static_cast(q_buf.GetDeviceBuffer()), + static_cast(k_buf.GetDeviceBuffer()), + static_cast(v_buf.GetDeviceBuffer()), + static_cast(o_buf.GetDeviceBuffer()), + M0, + N0, + K0, + N1, + Batch, + K0, // StrideQ + K0, // StrideK + N0, // StrideV + N1, // StrideO + M0 * K0, // BatchStrideQ + N0 * K0, // BatchStrideK + N1 * N0, // BatchStrideV + M0 * N1); // BatchStrideO o_buf.FromDevice(o_host_dev.mData.data()); diff --git a/example/91_tile_program/gemm_softmax_gemm_impl.hpp b/example/91_tile_program/gemm_softmax_gemm_impl.hpp index e94b525a0..778cb46b1 100644 --- a/example/91_tile_program/gemm_softmax_gemm_impl.hpp +++ b/example/91_tile_program/gemm_softmax_gemm_impl.hpp @@ -290,34 +290,9 @@ struct GemmSoftmaxGemmImpl // Cold Q_Reg_Cache auto q_block_tile = load_tile(q_dram_window); -#if 0 - printf("Blockid: %02d, Tid: %03d, q_thread_buf(0-7): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - 
*(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<3>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<4>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<5>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<6>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<7>{}]))) - ); - - printf("Blockid: %02d, Tid: %03d, q_thread_buf(8-15): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<8>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<9>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<10>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<11>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<12>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<13>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<14>{}]))), - *(reinterpret_cast(&(q_block_tile.GetThreadBuffer()[Number<15>{}]))) - ); -#endif auto k_block_tile = load_tile(k_dram_window); #if 0 - printf("Blockid: %02d, Tid: %03d, k_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x|\n", + printf("Blockid: %02d, Tid: %03d, k_thread_buf(0-7): %04x %04x %04x %04x %04x %04x %04x %04x|\n", get_block_1d_id(), get_thread_local_1d_id(), *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), @@ -338,58 +313,61 @@ struct GemmSoftmaxGemmImpl set_slice_tile( q_reg_tensor, q_block_tile, Sequence<0, 0>{}, Sequence{}); q_block_tile = load_tile(q_dram_window); -#if 1 - printf("Blockid: %02d, Tid: %03d, q_thread_buf(0-7): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<3>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<4>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<5>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<6>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<7>{}]))) - ); - printf("Blockid: %02d, Tid: %03d, q_thread_buf(8-15): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); +#if 0 + printf("Blockid: %02d, Tid: %03d, k_thread_buf(8-15): %04x %04x %04x %04x %04x %04x %04x %04x|\n", get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<8>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<9>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<10>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<11>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<12>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<13>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<14>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<15>{}]))) + 
*(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) ); #endif - store_tile(k_lds_window, k_block_tile); - k_block_tile = load_tile(k_dram_window); } - if constexpr(k0_loops > 2){ - static_for<0, k0_loops - 3, 1>{}([&](auto i_k0) { - block_sync_lds(); + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); - gemm0(s_acc, - get_slice_tile(q_reg_tensor, - Sequence<0, (i_k0)*kK0PerBlock>{}, - Sequence{}), - k_lds_window); + gemm0(s_acc, + get_slice_tile(q_reg_tensor, + Sequence<0, (i_k0)*kK0PerBlock>{}, + Sequence{}), + k_lds_window); - block_sync_lds(); + block_sync_lds(); - move_tile_window(q_dram_window, {0, kK0PerBlock}); - move_tile_window(k_dram_window, {0, kK0PerBlock}); + move_tile_window(q_dram_window, {0, kK0PerBlock}); + move_tile_window(k_dram_window, {0, kK0PerBlock}); - set_slice_tile(q_reg_tensor, - q_block_tile, - Sequence<0, (i_k0 + 1) * kK0PerBlock>{}, - Sequence{}); - q_block_tile = load_tile(q_dram_window); + set_slice_tile(q_reg_tensor, + q_block_tile, + Sequence<0, (i_k0 + 1) * kK0PerBlock>{}, + Sequence{}); + q_block_tile = load_tile(q_dram_window); - store_tile(k_lds_window, k_block_tile); - k_block_tile = load_tile(k_dram_window); - }); + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); +#if 0 + printf("Blockid: %02d, Tid: %03d, k_thread_buf(16-31): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) + ); +#endif + }); } // tail @@ -408,31 +386,7 @@ struct GemmSoftmaxGemmImpl q_block_tile, Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, Sequence{}); -#if 1 - printf("Blockid: %02d, Tid: %03d, q_thread_buf(16-23): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<16>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<17>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<18>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<19>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<20>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<21>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<22>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<23>{}]))) - ); - printf("Blockid: %02d, Tid: %03d, q_thread_buf(24-31): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), 
get_thread_local_1d_id(), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<24>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<25>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<26>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<27>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<28>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<29>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<30>{}]))), - *(reinterpret_cast(&(q_reg_tensor.GetThreadBuffer()[Number<31>{}]))) - ); -#endif store_tile(k_lds_window, k_block_tile); block_sync_lds(); @@ -442,6 +396,59 @@ struct GemmSoftmaxGemmImpl Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, Sequence{}), k_lds_window); +#if 0 + printf("gemm:01, Blockid: %02d, Tid: %03d, s(0-7): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf " + "%.0lf %.0lf|\n", + get_block_1d_id(), + get_thread_local_1d_id(), + s_acc.GetThreadBuffer()[Number<0>{}], + s_acc.GetThreadBuffer()[Number<1>{}], + s_acc.GetThreadBuffer()[Number<2>{}], + s_acc.GetThreadBuffer()[Number<3>{}], + s_acc.GetThreadBuffer()[Number<4>{}], + s_acc.GetThreadBuffer()[Number<5>{}], + s_acc.GetThreadBuffer()[Number<6>{}], + s_acc.GetThreadBuffer()[Number<7>{}]); + + printf("gemm:01, Blockid: %02d, Tid: %03d, s(8-15): %.0lf %.0lf %.0lf %.0lf %.0lf " + "%.0lf %.0lf %.0lf|\n", + get_block_1d_id(), + get_thread_local_1d_id(), + s_acc.GetThreadBuffer()[Number<8 + 0>{}], + s_acc.GetThreadBuffer()[Number<8 + 1>{}], + s_acc.GetThreadBuffer()[Number<8 + 2>{}], + s_acc.GetThreadBuffer()[Number<8 + 3>{}], + s_acc.GetThreadBuffer()[Number<8 + 4>{}], + s_acc.GetThreadBuffer()[Number<8 + 5>{}], + s_acc.GetThreadBuffer()[Number<8 + 6>{}], + s_acc.GetThreadBuffer()[Number<8 + 7>{}]); + + printf("gemm:01, Blockid: %02d, Tid: %03d, s(16-23): %.0lf %.0lf %.0lf %.0lf %.0lf " + "%.0lf %.0lf %.0lf|\n", + get_block_1d_id(), + get_thread_local_1d_id(), + s_acc.GetThreadBuffer()[Number<8 + 8 + 0>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 1>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 2>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 3>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 4>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 5>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 6>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 7>{}]); + + printf("gemm:01, Blockid: %02d, Tid: %03d, s(24-31): %.0lf %.0lf %.0lf %.0lf %.0lf " + "%.0lf %.0lf %.0lf|\n", + get_block_1d_id(), + get_thread_local_1d_id(), + s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 0>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 1>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 2>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 3>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 4>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 5>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 6>{}], + s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 7>{}]); +#endif } do @@ -450,6 +457,19 @@ struct GemmSoftmaxGemmImpl if(iN0 > 0) { k_block_tile = load_tile(k_dram_window); +#if 0 + printf("iN0==1, Blockid: %02d, Tid: %03d, k_block_tile(0-7): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), + 
*(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) + ); +#endif { move_tile_window(k_dram_window, {0, kK0PerBlock}); @@ -457,30 +477,43 @@ struct GemmSoftmaxGemmImpl store_tile(k_lds_window, k_block_tile); k_block_tile = load_tile(k_dram_window); +#if 0 + printf("iN0==1, Blockid: %02d, Tid: %03d, k_block_tile(8-15): %04x %04x %04x %04x %04x %04x %04x %04x|\n", + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), + *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) + ); +#endif } - if constexpr(k0_loops > 2){ - static_for<0, k0_loops - 3, 1>{}([&](auto i_k0) { - block_sync_lds(); + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); - gemm0(s_acc, - get_slice_tile(q_reg_tensor, - Sequence<0, (i_k0)*kK0PerBlock>{}, - Sequence{}), - k_lds_window); + gemm0(s_acc, + get_slice_tile(q_reg_tensor, + Sequence<0, (i_k0)*kK0PerBlock>{}, + Sequence{}), + k_lds_window); - block_sync_lds(); + block_sync_lds(); - move_tile_window(k_dram_window, {0, kK0PerBlock}); + move_tile_window(k_dram_window, {0, kK0PerBlock}); - store_tile(k_lds_window, k_block_tile); - k_block_tile = load_tile(k_dram_window); - }); + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); + }); } // tail { block_sync_lds(); - gemm0(s_acc, get_slice_tile(q_reg_tensor, Sequence<0, (k0_loops - 2) * kK0PerBlock>{}, @@ -499,26 +532,64 @@ struct GemmSoftmaxGemmImpl Sequence{}), k_lds_window); } + + // asm volatile("s_endpgm" ::); } // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); - - // prefetch load v tile - const auto v_prefetch = load_tile(v_dram_window); #if 0 - printf("Blockid: %02d, Tid: %03d, v_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<3>{}]))), - *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<4>{}]))), - *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<5>{}]))), - *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<6>{}]))), - *(reinterpret_cast(&(v_prefetch.GetThreadBuffer()[Number<7>{}]))) + printf("Nloop:%02d, Blockid: %02d, Tid: %03d, s(0-7): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf|\n", + iN0, get_block_1d_id(), get_thread_local_1d_id(), + s.GetThreadBuffer()[Number<0>{}], + s.GetThreadBuffer()[Number<1>{}], + s.GetThreadBuffer()[Number<2>{}], + s.GetThreadBuffer()[Number<3>{}], + s.GetThreadBuffer()[Number<4>{}], + s.GetThreadBuffer()[Number<5>{}], + s.GetThreadBuffer()[Number<6>{}], + s.GetThreadBuffer()[Number<7>{}] + ); + + printf("Nloop:%02d, Blockid: 
%02d, Tid: %03d, s(8-15): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf|\n", + iN0, get_block_1d_id(), get_thread_local_1d_id(), + s.GetThreadBuffer()[Number<8+0>{}], + s.GetThreadBuffer()[Number<8+1>{}], + s.GetThreadBuffer()[Number<8+2>{}], + s.GetThreadBuffer()[Number<8+3>{}], + s.GetThreadBuffer()[Number<8+4>{}], + s.GetThreadBuffer()[Number<8+5>{}], + s.GetThreadBuffer()[Number<8+6>{}], + s.GetThreadBuffer()[Number<8+7>{}] + ); + + printf("Nloop:%02d, Blockid: %02d, Tid: %03d, s(16-23): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf|\n", + iN0, get_block_1d_id(), get_thread_local_1d_id(), + s.GetThreadBuffer()[Number<8+8+0>{}], + s.GetThreadBuffer()[Number<8+8+1>{}], + s.GetThreadBuffer()[Number<8+8+2>{}], + s.GetThreadBuffer()[Number<8+8+3>{}], + s.GetThreadBuffer()[Number<8+8+4>{}], + s.GetThreadBuffer()[Number<8+8+5>{}], + s.GetThreadBuffer()[Number<8+8+6>{}], + s.GetThreadBuffer()[Number<8+8+7>{}] + ); + + printf("Nloop:%02d, Blockid: %02d, Tid: %03d, s(24-31): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf|\n", + iN0, get_block_1d_id(), get_thread_local_1d_id(), + s.GetThreadBuffer()[Number<8+8+8+0>{}], + s.GetThreadBuffer()[Number<8+8+8+1>{}], + s.GetThreadBuffer()[Number<8+8+8+2>{}], + s.GetThreadBuffer()[Number<8+8+8+3>{}], + s.GetThreadBuffer()[Number<8+8+8+4>{}], + s.GetThreadBuffer()[Number<8+8+8+5>{}], + s.GetThreadBuffer()[Number<8+8+8+6>{}], + s.GetThreadBuffer()[Number<8+8+8+7>{}] ); #endif + // prefetch load v tile + const auto v_prefetch = load_tile(v_dram_window); + // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, Sequence<1>{}, f_max, NumericLimits::Lowest()); @@ -609,7 +680,7 @@ struct GemmSoftmaxGemmImpl block_sync_lds(); } // move tile windows - move_tile_window(k_dram_window, {kN0PerBlock, 0}); + move_tile_window(k_dram_window, {kN0PerBlock, -(k0_loops - 1) * kK0PerBlock}); iN0 += kN0PerBlock; } while(iN0 < N0); diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp index cc9637d28..79c694863 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp @@ -178,11 +178,28 @@ struct BlockGemmARegBSmemCRegV1 a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetYSlicedThreadData( merge_sequences(Sequence{}, a_warp_y_index_zeros), merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths)); - +#if 0 + printf("(m: %02d, k:%02d) Blockid: %02d, Tid: %03d, a_thread_buf: %04x %04x %04x %04x|\n", + mIter.value, kIter.value, + get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(a_warp_tensor.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(a_warp_tensor.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(a_warp_tensor.GetThreadBuffer()[Number<2>{}]))), + *(reinterpret_cast(&(a_warp_tensor.GetThreadBuffer()[Number<3>{}]))) + ); +#endif static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - +#if 0 + printf("(m: %02d, n: %02d, k:%02d) Blockid: %02d, Tid: %03d, b_thread_buf: %04x %04x %04x %04x|\n", + mIter.value, nIter.value, kIter.value, get_block_1d_id(), get_thread_local_1d_id(), + *(reinterpret_cast(&(b_warp_tensor.GetThreadBuffer()[Number<0>{}]))), + *(reinterpret_cast(&(b_warp_tensor.GetThreadBuffer()[Number<1>{}]))), + *(reinterpret_cast(&(b_warp_tensor.GetThreadBuffer()[Number<2>{}]))), + 
*(reinterpret_cast(&(b_warp_tensor.GetThreadBuffer()[Number<3>{}]))) + ); +#endif // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index aeb8a3278..846aed738 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -40,7 +40,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy template __host__ __device__ static constexpr auto MakeBDramTileDistribution() { - return policy_impl::MakeADramTileDistribution(); + return policy_impl::MakeBDramTileDistribution(); } template diff --git a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp index 053550692..690e92764 100644 --- a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp +++ b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp @@ -287,6 +287,8 @@ struct TileWindowWithStaticDistribution const vector_t vec_value = GetBottomTensorView().template GetVectorizedElements( bottom_tensor_thread_coord); + // printf("Blockid: %02d, Tid: %03d, K read to lds: %05d\n", get_block_1d_id(), + // get_thread_local_1d_id(), bottom_tensor_thread_coord.GetOffset()); const vector_type_t vec{vec_value}; @@ -366,6 +368,8 @@ struct TileWindowWithStaticDistribution // write into bottom tensor GetBottomTensorView().template SetVectorizedElements( bottom_tensor_thread_coord, vec_value); + // printf("Blockid: %02d, Tid: %03d, K write to lds: %05d\n", get_block_1d_id(), + // get_thread_local_1d_id(), bottom_tensor_thread_coord.GetOffset()); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) From 00cbe505d372a04f2e2e6c703de26378b5434d66 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 25 Oct 2023 15:34:28 +0000 Subject: [PATCH 05/10] Clean debug code --- .../gemm_softmax_gemm_impl.hpp | 168 +----------------- .../block_gemm_areg_bsmem_creg_v1.hpp | 20 +-- script/clang-format-overwrite.sh | 2 +- 3 files changed, 3 insertions(+), 187 deletions(-) diff --git a/example/91_tile_program/gemm_softmax_gemm_impl.hpp b/example/91_tile_program/gemm_softmax_gemm_impl.hpp index 778cb46b1..69cf42a35 100644 --- a/example/91_tile_program/gemm_softmax_gemm_impl.hpp +++ b/example/91_tile_program/gemm_softmax_gemm_impl.hpp @@ -291,19 +291,6 @@ struct GemmSoftmaxGemmImpl // Cold Q_Reg_Cache auto q_block_tile = load_tile(q_dram_window); auto k_block_tile = load_tile(k_dram_window); -#if 0 - printf("Blockid: %02d, Tid: %03d, k_thread_buf(0-7): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) - ); -#endif { 
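                // Prologue of the Q-in-registers / K-through-LDS pipeline, as I read it:
                // both DRAM windows advance by kK0PerBlock, the first Q tile is copied
                // into its slice of the persistent register cache, the first K tile is
                // written to LDS, and the next Q/K tiles are prefetched so the k0 loop
                // below always has one tile in flight.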
move_tile_window(q_dram_window, {0, kK0PerBlock}); move_tile_window(k_dram_window, {0, kK0PerBlock}); @@ -316,19 +303,6 @@ struct GemmSoftmaxGemmImpl store_tile(k_lds_window, k_block_tile); k_block_tile = load_tile(k_dram_window); -#if 0 - printf("Blockid: %02d, Tid: %03d, k_thread_buf(8-15): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) - ); -#endif } if constexpr(k0_loops > 2) { @@ -354,19 +328,6 @@ struct GemmSoftmaxGemmImpl store_tile(k_lds_window, k_block_tile); k_block_tile = load_tile(k_dram_window); -#if 0 - printf("Blockid: %02d, Tid: %03d, k_thread_buf(16-31): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) - ); -#endif }); } @@ -396,59 +357,6 @@ struct GemmSoftmaxGemmImpl Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, Sequence{}), k_lds_window); -#if 0 - printf("gemm:01, Blockid: %02d, Tid: %03d, s(0-7): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf " - "%.0lf %.0lf|\n", - get_block_1d_id(), - get_thread_local_1d_id(), - s_acc.GetThreadBuffer()[Number<0>{}], - s_acc.GetThreadBuffer()[Number<1>{}], - s_acc.GetThreadBuffer()[Number<2>{}], - s_acc.GetThreadBuffer()[Number<3>{}], - s_acc.GetThreadBuffer()[Number<4>{}], - s_acc.GetThreadBuffer()[Number<5>{}], - s_acc.GetThreadBuffer()[Number<6>{}], - s_acc.GetThreadBuffer()[Number<7>{}]); - - printf("gemm:01, Blockid: %02d, Tid: %03d, s(8-15): %.0lf %.0lf %.0lf %.0lf %.0lf " - "%.0lf %.0lf %.0lf|\n", - get_block_1d_id(), - get_thread_local_1d_id(), - s_acc.GetThreadBuffer()[Number<8 + 0>{}], - s_acc.GetThreadBuffer()[Number<8 + 1>{}], - s_acc.GetThreadBuffer()[Number<8 + 2>{}], - s_acc.GetThreadBuffer()[Number<8 + 3>{}], - s_acc.GetThreadBuffer()[Number<8 + 4>{}], - s_acc.GetThreadBuffer()[Number<8 + 5>{}], - s_acc.GetThreadBuffer()[Number<8 + 6>{}], - s_acc.GetThreadBuffer()[Number<8 + 7>{}]); - - printf("gemm:01, Blockid: %02d, Tid: %03d, s(16-23): %.0lf %.0lf %.0lf %.0lf %.0lf " - "%.0lf %.0lf %.0lf|\n", - get_block_1d_id(), - get_thread_local_1d_id(), - s_acc.GetThreadBuffer()[Number<8 + 8 + 0>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 1>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 2>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 3>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 4>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 5>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 6>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 7>{}]); - - 
printf("gemm:01, Blockid: %02d, Tid: %03d, s(24-31): %.0lf %.0lf %.0lf %.0lf %.0lf " - "%.0lf %.0lf %.0lf|\n", - get_block_1d_id(), - get_thread_local_1d_id(), - s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 0>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 1>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 2>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 3>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 4>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 5>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 6>{}], - s_acc.GetThreadBuffer()[Number<8 + 8 + 8 + 7>{}]); -#endif } do @@ -457,19 +365,6 @@ struct GemmSoftmaxGemmImpl if(iN0 > 0) { k_block_tile = load_tile(k_dram_window); -#if 0 - printf("iN0==1, Blockid: %02d, Tid: %03d, k_block_tile(0-7): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) - ); -#endif { move_tile_window(k_dram_window, {0, kK0PerBlock}); @@ -477,19 +372,6 @@ struct GemmSoftmaxGemmImpl store_tile(k_lds_window, k_block_tile); k_block_tile = load_tile(k_dram_window); -#if 0 - printf("iN0==1, Blockid: %02d, Tid: %03d, k_block_tile(8-15): %04x %04x %04x %04x %04x %04x %04x %04x|\n", - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<3>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<4>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<5>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<6>{}]))), - *(reinterpret_cast(&(k_block_tile.GetThreadBuffer()[Number<7>{}]))) - ); -#endif } if constexpr(k0_loops > 2) { @@ -538,55 +420,7 @@ struct GemmSoftmaxGemmImpl // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); -#if 0 - printf("Nloop:%02d, Blockid: %02d, Tid: %03d, s(0-7): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf|\n", - iN0, get_block_1d_id(), get_thread_local_1d_id(), - s.GetThreadBuffer()[Number<0>{}], - s.GetThreadBuffer()[Number<1>{}], - s.GetThreadBuffer()[Number<2>{}], - s.GetThreadBuffer()[Number<3>{}], - s.GetThreadBuffer()[Number<4>{}], - s.GetThreadBuffer()[Number<5>{}], - s.GetThreadBuffer()[Number<6>{}], - s.GetThreadBuffer()[Number<7>{}] - ); - - printf("Nloop:%02d, Blockid: %02d, Tid: %03d, s(8-15): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf|\n", - iN0, get_block_1d_id(), get_thread_local_1d_id(), - s.GetThreadBuffer()[Number<8+0>{}], - s.GetThreadBuffer()[Number<8+1>{}], - s.GetThreadBuffer()[Number<8+2>{}], - s.GetThreadBuffer()[Number<8+3>{}], - s.GetThreadBuffer()[Number<8+4>{}], - s.GetThreadBuffer()[Number<8+5>{}], - s.GetThreadBuffer()[Number<8+6>{}], - s.GetThreadBuffer()[Number<8+7>{}] - ); - - printf("Nloop:%02d, Blockid: %02d, Tid: %03d, s(16-23): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf|\n", - 
iN0, get_block_1d_id(), get_thread_local_1d_id(), - s.GetThreadBuffer()[Number<8+8+0>{}], - s.GetThreadBuffer()[Number<8+8+1>{}], - s.GetThreadBuffer()[Number<8+8+2>{}], - s.GetThreadBuffer()[Number<8+8+3>{}], - s.GetThreadBuffer()[Number<8+8+4>{}], - s.GetThreadBuffer()[Number<8+8+5>{}], - s.GetThreadBuffer()[Number<8+8+6>{}], - s.GetThreadBuffer()[Number<8+8+7>{}] - ); - - printf("Nloop:%02d, Blockid: %02d, Tid: %03d, s(24-31): %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf %.0lf|\n", - iN0, get_block_1d_id(), get_thread_local_1d_id(), - s.GetThreadBuffer()[Number<8+8+8+0>{}], - s.GetThreadBuffer()[Number<8+8+8+1>{}], - s.GetThreadBuffer()[Number<8+8+8+2>{}], - s.GetThreadBuffer()[Number<8+8+8+3>{}], - s.GetThreadBuffer()[Number<8+8+8+4>{}], - s.GetThreadBuffer()[Number<8+8+8+5>{}], - s.GetThreadBuffer()[Number<8+8+8+6>{}], - s.GetThreadBuffer()[Number<8+8+8+7>{}] - ); -#endif + // prefetch load v tile const auto v_prefetch = load_tile(v_dram_window); diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp index 79c694863..dba2f2c9f 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp @@ -178,28 +178,10 @@ struct BlockGemmARegBSmemCRegV1 a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetYSlicedThreadData( merge_sequences(Sequence{}, a_warp_y_index_zeros), merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths)); -#if 0 - printf("(m: %02d, k:%02d) Blockid: %02d, Tid: %03d, a_thread_buf: %04x %04x %04x %04x|\n", - mIter.value, kIter.value, - get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(a_warp_tensor.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(a_warp_tensor.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(a_warp_tensor.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(a_warp_tensor.GetThreadBuffer()[Number<3>{}]))) - ); -#endif + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); -#if 0 - printf("(m: %02d, n: %02d, k:%02d) Blockid: %02d, Tid: %03d, b_thread_buf: %04x %04x %04x %04x|\n", - mIter.value, nIter.value, kIter.value, get_block_1d_id(), get_thread_local_1d_id(), - *(reinterpret_cast(&(b_warp_tensor.GetThreadBuffer()[Number<0>{}]))), - *(reinterpret_cast(&(b_warp_tensor.GetThreadBuffer()[Number<1>{}]))), - *(reinterpret_cast(&(b_warp_tensor.GetThreadBuffer()[Number<2>{}]))), - *(reinterpret_cast(&(b_warp_tensor.GetThreadBuffer()[Number<3>{}]))) - ); -#endif // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index da83254f0..f8f28499e 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' +# find . 
-name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' From f1a92556d26a85a8530ff778313f57b1869eaefb Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 25 Oct 2023 15:38:11 +0000 Subject: [PATCH 06/10] recover unnecessary change --- script/clang-format-overwrite.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index f8f28499e..da83254f0 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -# find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' +#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' From 70b5e6359bbb70bdb61a791bef1431bdd4248652 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 26 Oct 2023 16:58:25 +0000 Subject: [PATCH 07/10] temp save, refactor --- .../batched_gemm_softmax_gemm.hpp | 3 +- .../flash_attention_fwd_impl.hpp | 423 ++++++++++++++++++ .../gemm_softmax_gemm_impl.hpp | 407 +++-------------- ..._pipeline_agmem_bgmem_creg_v2_askiplds.hpp | 328 ++++++++++++++ ...ne_agmem_bgmem_creg_v2_askiplds_policy.hpp | 15 +- ..._pipeline_agmem_bgmem_creg_policy_impl.hpp | 14 +- .../tile_window_impl_static_distribution.hpp | 4 - 7 files changed, 828 insertions(+), 366 deletions(-) create mode 100644 example/91_tile_program/flash_attention_fwd_impl.hpp diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.hpp b/example/91_tile_program/batched_gemm_softmax_gemm.hpp index 8dfe0bf95..2b485d7ac 100644 --- a/example/91_tile_program/batched_gemm_softmax_gemm.hpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm.hpp @@ -18,6 +18,7 @@ #include "ck/tile_program/block_tile/block_reduce.hpp" #include "gemm_softmax_gemm_impl.hpp" +#include "flash_attention_fwd_impl.hpp" // S[M0, N0] = Q[M0, K0] * K[N0, K0] // P[M0, N0] = Softmax(S[M0, N0]) @@ -79,7 +80,7 @@ struct BatchedGemmSoftmaxGemm const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock); const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock); - const auto kernel_impl = GemmSoftmaxGemmImpl +struct FlashAttentionFwdImpl +{ + // block gemm0 pipeline + using BlockGemm0Problem = ck::tile_program::block::BlockGemmPipelineProblem< + QDataType, + KDataType, + SaccDataType, + kBlockSize, + ck::tile_program::TileGemmShape>; + + using BlockGemm0Policy = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy; + + using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2< + 
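        // As the policy name above suggests, this pipeline keeps the A operand (Q)
        // in a persistent per-thread register cache and skips its LDS staging, so
        // only the B operand (K) goes through shared memory.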
BlockGemm0Problem, + BlockGemm0Policy>; + + // block gemm1 + using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1< + ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem< + PDataType, + VDataType, + OaccDataType, + kBlockSize, + ck::tile_program::TileGemmShape>, + ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>; + +#if 0 + // 2d + __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + using namespace ck; + + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kN0PerBlock; + + constexpr auto b_lds_desc = + make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{}); + + return b_lds_desc; + } +#elif 0 + // fake XOR + __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + using namespace ck; + + using BDataType = VDataType; + + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kN0PerBlock; + + constexpr auto b_lds_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( + make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number{}); + + constexpr index_t kK1 = 16 / sizeof(BDataType); + + constexpr auto b_lds_desc_d4_d5_d6 = transform_tensor_descriptor( + b_lds_desc_d1_d2_d3, + make_tuple(make_xor_transform(make_tuple(kNPerBlock / 2, kKPerBlock), kK1), + make_pass_through_transform(2)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + constexpr auto b_lds_desc_n_k = transform_tensor_descriptor( + b_lds_desc_d4_d5_d6, + make_tuple(make_merge_transform(make_tuple(kNPerBlock / 2, 2)), + make_pass_through_transform(kKPerBlock)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return b_lds_desc_n_k; + } +#else + // 3d, with padding + __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + using namespace ck; + + // using BDataType = B1DataType; + + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kK1PerBlock; + constexpr index_t kPad = 1; + constexpr index_t kK1 = 8; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number{}, Number<1>{}), + Number{}, + Number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return b_lds_block_desc; + } +#endif + + __device__ static constexpr auto MakeVDramTileDistribution() + { + using namespace ck; + using namespace ck::tile_program; + + using BDataType = VDataType; + + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kK1PerBlock; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + + __device__ static constexpr ck::index_t GetStaticLdsSize() + { + using namespace ck; + + return math::max(BlockGemm0Pipeline::GetStaticLdsSize(), + static_cast(MakeVLdsBlockDescriptor().GetElementSpaceSize() * + sizeof(VDataType))); + } + + __device__ void 
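    // operator() below is the fused kernel body for one (iM0, iN1) output tile:
    // gemm0 (Q against K) through the skip-A-LDS pipeline, the running softmax
    // update, then gemm1 (P against V); GetStaticLdsSize() above takes the max of
    // the two stages because the V tile reuses the LDS that gemm0 used for K.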
operator()(const QDataType* q_ptr, + const KDataType* k_ptr, + const VDataType* v_ptr, + ODataType* o_ptr, + const ck::index_t M0, + const ck::index_t N0, + const ck::index_t K0, + const ck::index_t N1, + const ck::index_t StrideQ, + const ck::index_t StrideK, + const ck::index_t StrideV, + const ck::index_t StrideO, + const ck::index_t iM0, + const ck::index_t iN1) const + { + using namespace ck; + using namespace ck::tile_program; + using namespace ck::tile_program::block; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // allocate LDS + __shared__ char smem_ptr[GetStaticLdsSize()]; + + // Q/K/V DRAM and DRAM window + // FIXME: assume layout Q[M0, K0], K[N0, K0], V[N1, N0], O[M0, N1] + const auto q_dram = make_naive_tensor_view( + q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), Number<32>{}, Number<1>{}); + + const auto k_dram = make_naive_tensor_view( + k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), Number<32>{}, Number<1>{}); + + const auto v_dram = make_naive_tensor_view( + v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), Number<32>{}, Number<1>{}); + + auto q_dram_window = + make_tile_window(q_dram, + make_tuple(Number{}, Number{}), + {iM0, 0}); + + auto k_dram_window = + make_tile_window(k_dram, + make_tuple(Number{}, Number{}), + {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(Number{}, Number{}), + {iN1, 0}, + MakeVDramTileDistribution()); + + // Q in Register + auto q_reg_tensor = make_static_distributed_tensor( + BlockGemm0Policy::template MakeARegBlockDescriptor()); + + // V LDS and LDS window + // V LDS occupies the same LDS allocation Q/K LDS + auto v_lds = make_tensor_view(reinterpret_cast(smem_ptr), + MakeVLdsBlockDescriptor()); + + auto v_lds_window = make_tile_window( + v_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SaccBlockTileType = decltype(gemm0_pipeline( + q_dram_window, k_dram_window, q_reg_tensor, nullptr)); + + using SBlockTileType = decltype(tile_elementwise_in( + type_convert, SaccBlockTileType{})); + + using PBlockTileType = decltype(tile_elementwise_in(type_convert, + SaccBlockTileType{})); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm1( + get_slice_tile( + PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), + v_dram_window)); + + // init Sacc, Oacc, M, L + auto s_acc = SaccBlockTileType{}; + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); + tile_elementwise_inout([](auto& e) { e = NumericLimits::Lowest(); }, + m); + tile_elementwise_inout([](auto& e) { e = 0; }, l); + + // loop over Column of S (J loop) + index_t iN0 = 0; + + // Cold Q_Reg_Cache + s_acc = gemm0_pipeline( + q_dram_window, k_dram_window, q_reg_tensor, smem_ptr); + do + { + // Hot Q_Reg_Cache + if(iN0 > 0) + { + s_acc = gemm0_pipeline( + k_dram_window, q_reg_tensor, smem_ptr); + } + // S{j} + const auto s = + tile_elementwise_in(type_convert, s_acc); + + // prefetch load v tile + const auto v_prefetch = load_tile(v_dram_window); + + // m_local = rowmax(S{j}) + 
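            // Sketch of the running-softmax recurrence the next statements implement,
            // using the same names as the code (my summary, written per column block j):
            //   m_j = max(m_{j-1}, rowmax(S_j))
            //   P_j = exp(S_j - m_j)                               // p_compute
            //   l_j = exp(m_{j-1} - m_j) * l_{j-1} + rowsum(P_j)
            //   O_j = exp(m_{j-1} - m_j) * O_{j-1} + P_j * V_j     // o_acc, via gemm1
            // and after the last block o_acc is scaled by 1 / l before the store.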
auto m_local = block_tile_reduce( + s, Sequence<1>{}, f_max, NumericLimits::Lowest()); + + block_tile_reduce_sync(m_local, f_max); + + // m{j-1} + const auto m_old = m; + + // m{j} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + // Pcompute{j} + auto p_compute = + make_static_distributed_tensor(s.GetTileDistribution()); + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); + }); + }); + + // rowsum(Pcompute{j}) + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum); + + // l{j}, Oacc{j} + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); + + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + // FIXME: this use different equation from FA v2 paper, + // but produce correct result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + store_tile(v_lds_window, v_prefetch); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + Sequence<0, i_k1 * kK1PerBlock>{}, + Sequence{}), + v_lds_window); + block_sync_lds(); + store_tile(v_lds_window, v); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + }); + } + // tail + { + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + Sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + Sequence{}), + v_lds_window); + block_sync_lds(); + } + // move tile windows + move_tile_window(k_dram_window, {kN0PerBlock, 0}); + iN0 += kN0PerBlock; + } while(iN0 < N0); + + // Oacc + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto tmp = 1 / l[i_idx]; + + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + + // type cast Oacc into O + const auto o = tile_elementwise_in(type_convert, o_acc); + + // O DRAM and O DRAM window + auto o_dram = make_naive_tensor_view( + o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), Number<32>{}, Number<1>{}); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(Number{}, Number{}), + {iM0, iN1}, + o.GetTileDistribution()); + + // store O + store_tile(o_dram_window, o); + } +}; diff --git a/example/91_tile_program/gemm_softmax_gemm_impl.hpp b/example/91_tile_program/gemm_softmax_gemm_impl.hpp index 69cf42a35..5c263732b 100644 --- a/example/91_tile_program/gemm_softmax_gemm_impl.hpp +++ b/example/91_tile_program/gemm_softmax_gemm_impl.hpp @@ -9,14 +9,14 @@ #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/tile_elementwise.hpp" #include 
"ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/tile/slice_tile.hpp" -#include "ck/tile_program/warp_tile/warp_gemm.hpp" -#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" -#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_problem.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck/tile_program/block_tile/block_reduce.hpp" // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -31,146 +31,36 @@ template + ck::index_t kN1PerBlock> struct GemmSoftmaxGemmImpl { - // block gemm0 - using BlockGemm0Problem = ck::tile_program::block::BlockGemmPipelineProblem< - QDataType, - KDataType, - SaccDataType, - kBlockSize, - ck::tile_program::TileGemmShape>; - using BlockGemm0Policy = - ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy; - using BlockGemm0 = decltype(BlockGemm0Policy::GetBlockGemm()); + // block gemm0 pipeline + using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2< + ck::tile_program::block::BlockGemmPipelineProblem< + QDataType, + KDataType, + SaccDataType, + kBlockSize, + ck::tile_program::TileGemmShape>, + ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>; // block gemm1 - using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1< - ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem< + using BlockGemm1 = ck::tile_program::block::BlockGemmARegBGmemCRegV1< + ck::tile_program::block::BlockGemmARegBGmemCRegProblem< PDataType, VDataType, OaccDataType, kBlockSize, - ck::tile_program::TileGemmShape>, - ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>; - -#if 0 - // 2d - __device__ static constexpr auto MakeVLdsBlockDescriptor() - { - using namespace ck; - - constexpr index_t kNPerBlock = kN1PerBlock; - constexpr index_t kKPerBlock = kN0PerBlock; - - constexpr auto b_lds_desc = - make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{}); - - return b_lds_desc; - } -#elif 0 - // fake XOR - __device__ static constexpr auto MakeVLdsBlockDescriptor() - { - using namespace ck; - - using BDataType = VDataType; - - constexpr index_t kNPerBlock = kN1PerBlock; - constexpr index_t kKPerBlock = kN0PerBlock; - - constexpr auto b_lds_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( - make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number{}); - - constexpr index_t kK1 = 16 / sizeof(BDataType); - - constexpr auto b_lds_desc_d4_d5_d6 = transform_tensor_descriptor( - b_lds_desc_d1_d2_d3, - make_tuple(make_xor_transform(make_tuple(kNPerBlock / 2, kKPerBlock), kK1), - make_pass_through_transform(2)), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - constexpr auto b_lds_desc_n_k = transform_tensor_descriptor( - b_lds_desc_d4_d5_d6, - make_tuple(make_merge_transform(make_tuple(kNPerBlock / 2, 2)), - 
make_pass_through_transform(kKPerBlock)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return b_lds_desc_n_k; - } -#else - // 3d, with padding - __device__ static constexpr auto MakeVLdsBlockDescriptor() - { - using namespace ck; - - // using BDataType = B1DataType; - - constexpr index_t kNPerBlock = kN1PerBlock; - constexpr index_t kKPerBlock = kK1PerBlock; - constexpr index_t kPad = 1; - constexpr index_t kK1 = 8; - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number{}), - make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number{}, Number<1>{}), - Number{}, - Number<1>{}); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return b_lds_block_desc; - } -#endif - - __device__ static constexpr auto MakeVDramTileDistribution() - { - using namespace ck; - using namespace ck::tile_program; - - using BDataType = VDataType; - - constexpr index_t kNPerBlock = kN1PerBlock; - constexpr index_t kKPerBlock = kK1PerBlock; - - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); - } + ck::tile_program::TileGemmShape>, + ck::tile_program::block::BlockGemmARegBGmemCRegV1DefaultPolicy>; __device__ static constexpr ck::index_t GetStaticLdsSize() { - using namespace ck; - - return math::max(static_cast( - BlockGemm0Policy::template MakeBLdsBlockDescriptor() - .GetElementSpaceSize() * - sizeof(KDataType)), - static_cast(MakeVLdsBlockDescriptor().GetElementSpaceSize() * - sizeof(VDataType))); + return ck::math::max(BlockGemm0Pipeline::GetStaticLdsSize(), + BlockGemm1::GetStaticLdsSize()); } __device__ void operator()(const QDataType* q_ptr, @@ -198,7 +88,7 @@ struct GemmSoftmaxGemmImpl // allocate LDS __shared__ char smem_ptr[GetStaticLdsSize()]; - // Q/K/V DRAM and DRAM window + // Q/K/V DRAM // FIXME: assume layout Q[M0, K0], K[N0, K0], V[N1, N0], O[M0, N1] const auto q_dram = make_naive_tensor_view( q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), Number<32>{}, Number<1>{}); @@ -209,55 +99,27 @@ struct GemmSoftmaxGemmImpl const auto v_dram = make_naive_tensor_view( v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), Number<32>{}, Number<1>{}); - auto q_dram_window = - make_tile_window(q_dram, - make_tuple(Number{}, Number{}), - {iM0, 0}, - BlockGemm0Policy::MakeADramTileDistribution()); - - // Q in Register - auto q_reg_tensor = make_static_distributed_tensor( - BlockGemm0Policy::template MakeARegBlockDescriptor()); - - auto k_dram_window = - make_tile_window(k_dram, - make_tuple(Number{}, Number{}), - {0, 0}, - BlockGemm0Policy::MakeBDramTileDistribution()); - - // K LDS and LDS window - auto k_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - BlockGemm0Policy::MakeBLdsBlockDescriptor()); - auto k_lds_window = make_tile_window( - k_lds, make_tuple(Number{}, Number{}), {0, 0}); - - auto v_dram_window = - make_tile_window(v_dram, - make_tuple(Number{}, Number{}), - {iN1, 0}, - 
MakeVDramTileDistribution()); - - // V LDS and LDS window - // V LDS occupies the same LDS allocation Q/K LDS - auto v_lds = make_tensor_view(reinterpret_cast(smem_ptr), - MakeVLdsBlockDescriptor()); - - auto v_lds_window = make_tile_window( - v_lds, make_tuple(Number{}, Number{}), {0, 0}); + // Q/K/V DRAM window + auto q_dram_window = make_tile_window( + q_dram, make_tuple(Number{}, Number{}), {iM0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(Number{}, Number{}), {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, make_tuple(Number{}, Number{}), {iN1, 0}); // Block GEMM0 pipeline and Block GEMM1 - constexpr auto gemm0 = BlockGemm0{}; - constexpr auto gemm1 = BlockGemm1{}; + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; // infer Sacc, S, P, M, L, Oacc type - using SaccBlockTileType = decltype(gemm0( - get_slice_tile(q_reg_tensor, Sequence<0, 0>{}, Sequence{}), - k_lds_window)); + using SaccBlockTileType = + decltype(gemm0_pipeline(q_dram_window, k_dram_window, 0, nullptr)); using SBlockTileType = decltype(tile_elementwise_in( type_convert, SaccBlockTileType{})); @@ -268,13 +130,9 @@ struct GemmSoftmaxGemmImpl using MLBlockTileType = decltype(block_tile_reduce( SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); - using OaccBlockTileType = decltype(gemm1( - get_slice_tile( - PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), - v_dram_window)); + using OaccBlockTileType = decltype(gemm1(PBlockTileType{}, v_dram_window, smem_ptr)); - // init Sacc, Oacc, M, L - auto s_acc = SaccBlockTileType{}; + // init Oacc, M, L auto o_acc = OaccBlockTileType{}; auto m = MLBlockTileType{}; auto l = MLBlockTileType{}; @@ -285,145 +143,18 @@ struct GemmSoftmaxGemmImpl tile_elementwise_inout([](auto& e) { e = 0; }, l); // loop over Column of S (J loop) - index_t iN0 = 0; - constexpr index_t k0_loops = kHeadDim / kK0PerBlock; - - // Cold Q_Reg_Cache - auto q_block_tile = load_tile(q_dram_window); - auto k_block_tile = load_tile(k_dram_window); - { - move_tile_window(q_dram_window, {0, kK0PerBlock}); - move_tile_window(k_dram_window, {0, kK0PerBlock}); - - tile_elementwise_inout([](auto& s) { s = 0; }, s_acc); - - set_slice_tile( - q_reg_tensor, q_block_tile, Sequence<0, 0>{}, Sequence{}); - q_block_tile = load_tile(q_dram_window); - - store_tile(k_lds_window, k_block_tile); - k_block_tile = load_tile(k_dram_window); - } - if constexpr(k0_loops > 2) - { - static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { - block_sync_lds(); - - gemm0(s_acc, - get_slice_tile(q_reg_tensor, - Sequence<0, (i_k0)*kK0PerBlock>{}, - Sequence{}), - k_lds_window); - - block_sync_lds(); - - move_tile_window(q_dram_window, {0, kK0PerBlock}); - move_tile_window(k_dram_window, {0, kK0PerBlock}); - - set_slice_tile(q_reg_tensor, - q_block_tile, - Sequence<0, (i_k0 + 1) * kK0PerBlock>{}, - Sequence{}); - q_block_tile = load_tile(q_dram_window); - - store_tile(k_lds_window, k_block_tile); - k_block_tile = load_tile(k_dram_window); - }); - } - - // tail - { - block_sync_lds(); - - gemm0(s_acc, - get_slice_tile(q_reg_tensor, - Sequence<0, (k0_loops - 2) * kK0PerBlock>{}, - Sequence{}), - k_lds_window); - - block_sync_lds(); - - set_slice_tile(q_reg_tensor, - q_block_tile, - Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, - Sequence{}); - - store_tile(k_lds_window, k_block_tile); - - block_sync_lds(); 
- - gemm0(s_acc, - get_slice_tile(q_reg_tensor, - Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, - Sequence{}), - k_lds_window); - } + index_t iN0 = 0; do { - // Hot Q_Reg_Cache - if(iN0 > 0) - { - k_block_tile = load_tile(k_dram_window); - { - move_tile_window(k_dram_window, {0, kK0PerBlock}); - - tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); - - store_tile(k_lds_window, k_block_tile); - k_block_tile = load_tile(k_dram_window); - } - if constexpr(k0_loops > 2) - { - static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { - block_sync_lds(); - - gemm0(s_acc, - get_slice_tile(q_reg_tensor, - Sequence<0, (i_k0)*kK0PerBlock>{}, - Sequence{}), - k_lds_window); - - block_sync_lds(); - - move_tile_window(k_dram_window, {0, kK0PerBlock}); - - store_tile(k_lds_window, k_block_tile); - k_block_tile = load_tile(k_dram_window); - }); - } - - // tail - { - block_sync_lds(); - gemm0(s_acc, - get_slice_tile(q_reg_tensor, - Sequence<0, (k0_loops - 2) * kK0PerBlock>{}, - Sequence{}), - k_lds_window); - - block_sync_lds(); - - store_tile(k_lds_window, k_block_tile); - - block_sync_lds(); - - gemm0(s_acc, - get_slice_tile(q_reg_tensor, - Sequence<0, (k0_loops - 1) * kK0PerBlock>{}, - Sequence{}), - k_lds_window); - } - - // asm volatile("s_endpgm" ::); - } + // Sacc{j} = Q * K{j} + const auto s_acc = + gemm0_pipeline(q_dram_window, k_dram_window, K0 / kK0PerBlock, smem_ptr); + // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); - // prefetch load v tile - const auto v_prefetch = load_tile(v_dram_window); - // m_local = rowmax(S{j}) auto m_local = block_tile_reduce( s, Sequence<1>{}, f_max, NumericLimits::Lowest()); @@ -477,48 +208,28 @@ struct GemmSoftmaxGemmImpl }); }); - block_sync_lds(); - store_tile(v_lds_window, v_prefetch); - move_tile_window(v_dram_window, {0, kK1PerBlock}); - // type cast Pcompute{j} into P{j} const auto p = tile_elementwise_in(type_convert, p_compute); - // Oacc{j} - constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; - - if constexpr(k1_loops > 1) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v - block_sync_lds(); - gemm1(o_acc, - get_slice_tile(p, - Sequence<0, i_k1 * kK1PerBlock>{}, - Sequence{}), - v_lds_window); - block_sync_lds(); - store_tile(v_lds_window, v); - move_tile_window(v_dram_window, {0, kK1PerBlock}); - }); - } - // tail - { - block_sync_lds(); - gemm1(o_acc, - get_slice_tile(p, - Sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - Sequence{}), - v_lds_window); - block_sync_lds(); - } - // move tile windows - move_tile_window(k_dram_window, {kN0PerBlock, -(k0_loops - 1) * kK0PerBlock}); + // wait for gemm0 pipeline to finish reading Lds + block_sync_lds(); + + // Block GEMM1: Oacc{j} += P{j} * V{j} + gemm1(o_acc, p, v_dram_window, smem_ptr); + + // move K/V tile windows for next iteration (J loop) + move_tile_window(k_dram_window, {kN0PerBlock, 0}); + move_tile_window(v_dram_window, {0, kN0PerBlock}); + + // wait for gemm1 to finish reading Lds, before next iteration (J loop) + block_sync_lds(); + iN0 += kN0PerBlock; + } while(iN0 < N0); - // Oacc + // O constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); sweep_tile_span(o_spans[I0], [&](auto idx0) { @@ -549,4 +260,4 @@ struct GemmSoftmaxGemmImpl // store O store_tile(o_dram_window, o); } -}; +}; \ No newline at end of file diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp 
b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 45abd297a..c49ca5285 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -214,6 +214,334 @@ struct BlockGemmPipelineAGmemBGmemCRegV2 +struct BlockGemmPipelineAGmemBGmemCRegV2> +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using Policy = BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t k_loops = Policy::AKDim / kKPerBlock; + + // Move this part into Policy? + __host__ __device__ static constexpr ck::index_t GetStaticLdsSize() + { + return sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().GetElementSpaceSize(); + } + + // Cold A Register Cache + template + __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + static_assert( + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + ignore = a_element_func; + ignore = b_element_func; + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. 
without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + a_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeADramTileDistribution()); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + b_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(Number{}, Number{}), + {0, 0}, + b_copy_dram_window.GetTileDistribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(get_slice_tile(a_copy_reg_tensor, + Sequence<0, 0>{}, + Sequence{}), b_lds_gemm_window)){}; + + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + { + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + set_slice_tile( + a_copy_reg_tensor, a_block_tile, Sequence<0, 0>{}, Sequence{}); + a_block_tile = load_tile(a_copy_dram_window); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + } + if constexpr(k_loops > 2) + { + static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (i_k0)*kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + Sequence<0, (i_k0 + 1) * kKPerBlock>{}, + Sequence{}); + a_block_tile = load_tile(a_copy_dram_window); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + }); + } + + // tail + { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (k_loops - 2) * kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + Sequence<0, (k_loops - 1) * kKPerBlock>{}, + Sequence{}); + + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (k_loops - 1) * kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + } + + store_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor); + + return c_block_tile; + } + + // Hot A Register Cache + template + __host__ __device__ auto 
operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + static_assert( + is_same_v>, + "wrong!"); + + static_assert( + kNPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kKPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + ignore = b_element_func; + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + store_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + b_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(Number{}, Number{}), + {0, 0}, + b_copy_dram_window.GetTileDistribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(get_slice_tile(a_copy_reg_tensor, + Sequence<0, 0>{}, + Sequence{}), b_lds_gemm_window)){}; + + auto b_block_tile = load_tile(b_copy_dram_window); + { + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + } + if constexpr(k_loops > 2) + { + static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (i_k0)*kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + }); + } + + // tail + { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (k_loops - 2) * kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (k_loops - 1) * kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + } + + return c_block_tile; + } + + template + __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + 
b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + a_reg_block_tensor_tmp, + p_smem); + } + + template + __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + return operator()( + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + a_reg_block_tensor_tmp, + p_smem); + } +}; + + } // namespace block } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 846aed738..ecc0df679 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -19,13 +19,13 @@ struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy constexpr auto blockgemm = GetBlockGemm(); using BlockGemm = ck::remove_cvref_t; - return policy_impl::MakeARegBlockDescriptor(); + return policy_impl::make_a_reg_block_descriptor(); } template __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() { - return policy_impl::MakeBLdsBlockDescriptor(); + return policy_impl::make_b_lds_block_descriptor(); } template @@ -34,13 +34,13 @@ struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy constexpr auto blockgemm = GetBlockGemm(); using BlockGemm = ck::remove_cvref_t; - return policy_impl::MakeADramTileDistribution_ASkipLDS(); + return policy_impl::make_a_dram_tile_distribution_skip_lds(); } template __host__ __device__ static constexpr auto MakeBDramTileDistribution() { - return policy_impl::MakeBDramTileDistribution(); + return policy_impl::make_b_dram_tile_distribution(); } template @@ -52,10 +52,13 @@ struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy } }; +template struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy : BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy { - template + static constexpr index_t AKDim = AKDim_; + + template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() { using namespace ck; @@ -64,7 +67,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy using BlockGemm = ck::remove_cvref_t; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = kHeadDim; + constexpr index_t kKPerBlock = AKDim; constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); diff --git a/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp index cfc0b0138..8d2ee8654 100644 --- a/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp +++ b/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp @@ -14,7 +14,7 @@ namespace block { namespace policy_impl { // 3d + padding template -__host__ __device__ static constexpr auto MakeALdsBlockDescriptor() +__host__ __device__ static constexpr auto make_a_lds_block_descriptor() { using namespace ck; @@ -39,7 +39,7 @@ __host__ __device__ static constexpr auto MakeALdsBlockDescriptor() // 3d + padding template -__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() +__host__ __device__ static constexpr auto 
make_b_lds_block_descriptor() { using namespace ck; @@ -63,7 +63,7 @@ __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() } template -__host__ __device__ static constexpr auto MakeARegBlockDescriptor() +__host__ __device__ static constexpr auto make_a_reg_block_descriptor() { using namespace ck; @@ -97,7 +97,7 @@ __host__ __device__ static constexpr auto MakeARegBlockDescriptor() } template -__host__ __device__ static constexpr auto MakeADramTileDistribution() +__host__ __device__ static constexpr auto make_a_dram_tile_distribution() { using ADataType = remove_cvref_t; @@ -123,7 +123,7 @@ __host__ __device__ static constexpr auto MakeADramTileDistribution() } template -__host__ __device__ static constexpr auto MakeADramTileDistribution_ASkipLDS() +__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds() { constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); @@ -154,7 +154,7 @@ __host__ __device__ static constexpr auto MakeADramTileDistribution_ASkipLDS() } template -__host__ __device__ static constexpr auto MakeBDramTileDistribution() +__host__ __device__ static constexpr auto make_b_dram_tile_distribution() { using BDataType = remove_cvref_t; @@ -180,7 +180,7 @@ __host__ __device__ static constexpr auto MakeBDramTileDistribution() } template -__host__ __device__ static constexpr auto GetBlockGemm() +__host__ __device__ static constexpr auto get_block_gemm() { using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; diff --git a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp index 690e92764..053550692 100644 --- a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp +++ b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp @@ -287,8 +287,6 @@ struct TileWindowWithStaticDistribution const vector_t vec_value = GetBottomTensorView().template GetVectorizedElements( bottom_tensor_thread_coord); - // printf("Blockid: %02d, Tid: %03d, K read to lds: %05d\n", get_block_1d_id(), - // get_thread_local_1d_id(), bottom_tensor_thread_coord.GetOffset()); const vector_type_t vec{vec_value}; @@ -368,8 +366,6 @@ struct TileWindowWithStaticDistribution // write into bottom tensor GetBottomTensorView().template SetVectorizedElements( bottom_tensor_thread_coord, vec_value); - // printf("Blockid: %02d, Tid: %03d, K write to lds: %05d\n", get_block_1d_id(), - // get_thread_local_1d_id(), bottom_tensor_thread_coord.GetOffset()); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) From 9901df5954a8f703ae35d85fcedd29cd5086ac3a Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 26 Oct 2023 17:32:03 +0000 Subject: [PATCH 08/10] fix a typo --- example/91_tile_program/flash_attention_fwd_impl.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/example/91_tile_program/flash_attention_fwd_impl.hpp b/example/91_tile_program/flash_attention_fwd_impl.hpp index 9fc89d0c7..d305f78e4 100644 --- a/example/91_tile_program/flash_attention_fwd_impl.hpp +++ b/example/91_tile_program/flash_attention_fwd_impl.hpp @@ -92,7 +92,6 @@ struct FlashAttentionFwdImpl return b_lds_block_desc; } -#endif __device__ static constexpr auto MakeVDramTileDistribution() { From 46f639a3812239a50586de7a7eb342c2facf1d29 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 1 Nov 2023 14:11:51 +0000 Subject: [PATCH 09/10] Change VLDS layout, ~2% performance gain --- example/91_tile_program/flash_attention_fwd_impl.hpp 
| 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/example/91_tile_program/flash_attention_fwd_impl.hpp b/example/91_tile_program/flash_attention_fwd_impl.hpp
index d305f78e4..70ab4c340 100644
--- a/example/91_tile_program/flash_attention_fwd_impl.hpp
+++ b/example/91_tile_program/flash_attention_fwd_impl.hpp
@@ -75,7 +75,8 @@ struct FlashAttentionFwdImpl
         constexpr index_t kNPerBlock = kN1PerBlock;
         constexpr index_t kKPerBlock = kK1PerBlock;
         constexpr index_t kPad       = 1;
-        constexpr index_t kK1        = 8;
+        // 2% faster than using kK1 = 8
+        constexpr index_t kK1        = 4;
 
         constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(Number<kKPerBlock / kK1>{}, Number<kNPerBlock>{}, Number<kK1>{}),

From 6194f0d387dd7719fcbdeabe5ec9013856684a44 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Mon, 6 Nov 2023 14:51:35 +0000
Subject: [PATCH 10/10] fix o_span

---
 example/91_tile_program/flash_attention_fwd_impl.hpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/example/91_tile_program/flash_attention_fwd_impl.hpp b/example/91_tile_program/flash_attention_fwd_impl.hpp
index 70ab4c340..58bc26a27 100644
--- a/example/91_tile_program/flash_attention_fwd_impl.hpp
+++ b/example/91_tile_program/flash_attention_fwd_impl.hpp
@@ -279,15 +279,16 @@ struct FlashAttentionFwdImpl
 
         block_tile_reduce_sync(rowsum_p, f_sum);
 
+        constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();
         // l{j}, Oacc{j}
-        sweep_tile_span(p_spans[I0], [&](auto idx0) {
+        sweep_tile_span(o_spans[I0], [&](auto idx0) {
             constexpr auto i_idx = make_tuple(idx0);
 
             const auto tmp = math::exp(m_old[i_idx] - m[i_idx]);
 
             l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
 
-            sweep_tile_span(p_spans[I1], [&](auto idx1) {
+            sweep_tile_span(o_spans[I1], [&](auto idx1) {
                 constexpr auto i_j_idx = make_tuple(idx0, idx1);
 
                 // FIXME: this use different equation from FA v2 paper,
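The FIXME kept as context in the hunk above asks whether the o_acc update matches the FlashAttention-2 paper. A short check in generic online-softmax notation (the symbols below are illustrative, not taken from this patch) suggests it is the standard deferred-normalization form rather than a wrong equation. Writing m_j for the running row-max, l_j for the running row-sum, and \tilde{O}_j for the unnormalized accumulator after key/value block j, the per-iteration step the kernel performs is

    \tilde{O}_j = \mathrm{diag}\big(e^{\,m_{j-1}-m_j}\big)\,\tilde{O}_{j-1} + e^{\,S_j-m_j}\,V_j,
    l_j = e^{\,m_{j-1}-m_j}\,l_{j-1} + \mathrm{rowsum}\big(e^{\,S_j-m_j}\big).

Unrolling over j = 1..J with m_0 = -\infty, l_0 = 0, \tilde{O}_0 = 0 telescopes the correction factors:

    \tilde{O}_J = \sum_{j=1}^{J} e^{\,S_j-m_J}\,V_j,
    l_J = \sum_{j=1}^{J} \mathrm{rowsum}\big(e^{\,S_j-m_J}\big),
    O = \mathrm{diag}(l_J)^{-1}\,\tilde{O}_J = \mathrm{softmax}(S)\,V.

So scaling o_acc by tmp = exp(m_old - m_new) inside the J loop and applying the single 1/l factor after the loop reproduces softmax(S)·V exactly; the recurrence differs from FlashAttention-1 only in deferring the division by l to the end, which is also how the FlashAttention-2 paper formulates it.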