diff --git a/example/91_tile_program/CMakeLists.txt b/example/91_tile_program/CMakeLists.txt
index 27d3c67ad..7a352b482 100644
--- a/example/91_tile_program/CMakeLists.txt
+++ b/example/91_tile_program/CMakeLists.txt
@@ -6,3 +6,4 @@ add_example_executable(example_softmax softmax.cpp)
 add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp)
 add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp)
 add_example_executable(example_fmha_fwd fmha_fwd.cpp)
+add_example_executable(example_flash_attention_fwd flash_attention_fwd.cpp)
diff --git a/example/91_tile_program/flash_attention_fwd.cpp b/example/91_tile_program/flash_attention_fwd.cpp
new file mode 100644
index 000000000..5d5e599b3
--- /dev/null
+++ b/example/91_tile_program/flash_attention_fwd.cpp
@@ -0,0 +1,225 @@
+#include <cstring>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/cluster_descriptor.hpp"
+#include "ck/tensor/tensor_view.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/fill.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+
+#include "reference_batched_gemm.hpp"
+#include "reference_batched_softmax.hpp"
+#include "flash_attention_fwd.hpp"
+
+int main(int argc, char* argv[])
+{
+    using QDataType           = ck::half_t;
+    using KDataType           = ck::half_t;
+    using VDataType           = ck::half_t;
+    using SaccDataType        = float;
+    using SMPLComputeDataType = float;
+    using PDataType           = ck::half_t;
+    using OaccDataType        = float;
+    using ODataType           = ck::half_t;
+
+    ck::index_t Batch       = 64;
+    ck::index_t M0          = 4096;
+    ck::index_t N0          = 4096;
+    ck::index_t K0          = 128;
+    ck::index_t N1          = 128;
+    ck::index_t init_method = 1;
+    ck::index_t time_kernel = 0;
+
+    if(argc == 3)
+    {
+        init_method = std::stoi(argv[1]);
+        time_kernel = std::stoi(argv[2]);
+    }
+
+    if(argc == 8)
+    {
+        init_method = std::stoi(argv[1]);
+        time_kernel = std::stoi(argv[2]);
+        Batch       = std::stoi(argv[3]);
+        M0          = std::stoi(argv[4]);
+        N0          = std::stoi(argv[5]);
+        K0          = std::stoi(argv[6]);
+        N1          = std::stoi(argv[7]);
+    }
+
+    std::array<ck::index_t, 3> q_lengths{Batch, M0, K0};
+    std::array<ck::index_t, 3> q_strides{M0 * K0, K0, 1};
+
+    std::array<ck::index_t, 3> k_lengths{Batch, N0, K0};
+    std::array<ck::index_t, 3> k_strides{N0 * K0, K0, 1};
+
+    std::array<ck::index_t, 3> v_lengths{Batch, N1, N0};
+    std::array<ck::index_t, 3> v_strides{N1 * N0, N0, 1};
+
+    std::array<ck::index_t, 3> s_lengths{Batch, M0, N0};
+    std::array<ck::index_t, 3> s_strides{M0 * N0, N0, 1};
+
+    std::array<ck::index_t, 3> p_lengths{Batch, M0, N0};
+    std::array<ck::index_t, 3> p_strides{M0 * N0, N0, 1};
+
+    std::array<ck::index_t, 3> o_lengths{Batch, M0, N1};
+    std::array<ck::index_t, 3> o_strides{M0 * N1, N1, 1};
+
+    // host verify
+    Tensor<QDataType> q_host(q_lengths, q_strides);
+    Tensor<KDataType> k_host(k_lengths, k_strides);
+    Tensor<VDataType> v_host(v_lengths, v_strides);
+    Tensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
+    Tensor<PDataType> p_host_ref(p_lengths, p_strides);
+    Tensor<ODataType> o_host_ref(o_lengths, o_strides);
+    Tensor<ODataType> o_host_dev(o_lengths, o_strides);
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
+        ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
+        ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
+        break;
+    case 2:
+        ck::utils::FillUniformDistribution<QDataType>{-3.f, 3.f}(q_host);
+        ck::utils::FillUniformDistribution<KDataType>{-3.f, 3.f}(k_host);
+        ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
+        break;
+    case 3:
+        ck::utils::FillConstant<QDataType>{1.f}(q_host);
+        ck::utils::FillConstant<KDataType>{1.f}(k_host);
+        ck::utils::FillConstant<VDataType>{1.f}(v_host);
+        break;
+    case 4:
+        ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
+        ck::utils::FillConstant<KDataType>{1.f}(k_host);
+        ck::utils::FillConstant<VDataType>{1.f}(v_host);
+        break;
+    case 5:
+        ck::utils::FillConstant<QDataType>{1.f}(q_host);
+        ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
+        ck::utils::FillConstant<VDataType>{1.f}(v_host);
+        break;
+    case 6:
+        ck::utils::FillConstant<QDataType>{1.f}(q_host);
+        ck::utils::FillConstant<KDataType>{1.f}(k_host);
+        ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
+        break;
+    case 7:
+        ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
+        ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
+        ck::utils::FillConstant<VDataType>{1.f}(v_host);
+        break;
+    case 8:
+        ck::utils::FillConstant<QDataType>{1.f}(q_host);
+        ck::utils::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f}(k_host);
+        ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
+        break;
+    case 9:
+        ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
+        ck::utils::FillConstant<KDataType>{1.f}(k_host);
+        ck::utils::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f}(v_host);
+        break;
+    default:
+        ck::utils::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f}(q_host);
+        ck::utils::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f}(k_host);
+        ck::utils::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f}(v_host);
+    }
+
+    // reference
+    reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
+        q_host, k_host, s_host_ref);
+    reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref,
+                                                                                   p_host_ref);
+    reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
+        p_host_ref, v_host, o_host_ref);
+
+    DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
+    DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
+    DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize());
+    DeviceMem o_buf(sizeof(ODataType) * o_host_ref.GetElementSpaceSize());
+
+    q_buf.ToDevice(q_host.mData.data());
+    k_buf.ToDevice(k_host.mData.data());
+    v_buf.ToDevice(v_host.mData.data());
+
+    constexpr ck::index_t kM0PerBlock = 128;
+    constexpr ck::index_t kN0PerBlock = 128;
+    constexpr ck::index_t kK0PerBlock = 32;
+    constexpr ck::index_t kN1PerBlock = 128;
+    constexpr ck::index_t kK1PerBlock = 32;
+
+    constexpr ck::index_t kBlockSize = 256;
+    constexpr ck::index_t kHeadDim   = 128;
+
+    ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
+
+    std::cout << "grid size " << kGridSize << std::endl;
+
+    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
+    constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
+    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;
+
+    float ave_time = launch_kernel<kBlockSize, kBlockPerCu>(
+        StreamConfig{nullptr, static_cast<bool>(time_kernel)},
+        FlashAttentionFwd<QDataType,
+                          KDataType,
+                          VDataType,
+                          SaccDataType,
+                          SMPLComputeDataType,
+                          PDataType,
+                          OaccDataType,
+                          ODataType,
+                          kBlockSize,
+                          kM0PerBlock,
+                          kN0PerBlock,
+                          kK0PerBlock,
+                          kN1PerBlock,
+                          kK1PerBlock,
+                          kHeadDim>{},
+        kGridSize,
+        kBlockSize,
+        0,
+        static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
+        static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
+        static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
+        static_cast<ODataType*>(o_buf.GetDeviceBuffer()),
+        M0,
+        N0,
+        K0,
+        N1,
+        Batch,
+        K0,       // StrideQ
+        K0,       // StrideK
+        N0,       // StrideV
+        N1,       // StrideO
+        M0 * K0,  // BatchStrideQ
+        N0 * K0,  // BatchStrideK
+        N1 * N0,  // BatchStrideV
+        M0 * N1); // BatchStrideO
+
+    o_buf.FromDevice(o_host_dev.mData.data());
+
+    std::size_t flop =
+        std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
+    std::size_t num_btype =
+        sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
+        sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    return !ck::utils::check_err(o_host_dev, o_host_ref);
+}
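Note on the host-side check above: o_host_dev is verified against the naive three-stage attention pipeline (S = Q * K^T, row-wise softmax, O = P * V) computed by the reference_batched_* helpers. As a minimal sketch of what that reference computes, here is a self-contained scalar version for a single batch, assuming the same row-major Q[M0, K0], K[N0, K0], V[N1, N0] layouts as the example; the function name and plain-float types are illustrative, not the actual reference API:

#include <cmath>
#include <cstddef>
#include <vector>

// Naive attention for one batch: Q[M0, K0], K[N0, K0], V[N1, N0] -> O[M0, N1].
// Mirrors the S / P / O stages the example validates against.
std::vector<float> naive_attention(const std::vector<float>& q,
                                   const std::vector<float>& k,
                                   const std::vector<float>& v,
                                   std::size_t M0, std::size_t N0,
                                   std::size_t K0, std::size_t N1)
{
    std::vector<float> o(M0 * N1, 0.f);
    std::vector<float> s(N0);

    for(std::size_t m = 0; m < M0; ++m)
    {
        // S[m, n] = sum_k Q[m, k] * K[n, k]  (K is row-major [N0, K0], i.e. Q * K^T)
        float row_max = -INFINITY;
        for(std::size_t n = 0; n < N0; ++n)
        {
            float acc = 0.f;
            for(std::size_t kk = 0; kk < K0; ++kk)
                acc += q[m * K0 + kk] * k[n * K0 + kk];
            s[n]    = acc;
            row_max = std::fmax(row_max, acc);
        }

        // P[m, n] = exp(S[m, n] - rowmax) / rowsum  (numerically stable softmax)
        float row_sum = 0.f;
        for(std::size_t n = 0; n < N0; ++n)
        {
            s[n] = std::exp(s[n] - row_max);
            row_sum += s[n];
        }

        // O[m, n1] = sum_n P[m, n] * V[n1, n]  (V is row-major [N1, N0])
        for(std::size_t n1 = 0; n1 < N1; ++n1)
        {
            float acc = 0.f;
            for(std::size_t n = 0; n < N0; ++n)
                acc += s[n] * v[n1 * N0 + n];
            o[m * N1 + n1] = acc / row_sum;
        }
    }
    return o;
}

The flop count in the example follows directly from the two GEMM stages: 2 * Batch * M0 * N0 * K0 for Q * K^T plus 2 * Batch * M0 * N1 * N0 for P * V; softmax flops are conventionally excluded.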
diff --git a/example/91_tile_program/flash_attention_fwd.hpp b/example/91_tile_program/flash_attention_fwd.hpp
new file mode 100644
index 000000000..5e0d36319
--- /dev/null
+++ b/example/91_tile_program/flash_attention_fwd.hpp
@@ -0,0 +1,113 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_adaptor.hpp"
+
+#include "ck/tile_program/tile/tile_distribution.hpp"
+#include "ck/tile_program/tile/tile_elementwise.hpp"
+#include "ck/tile_program/tile/tile_gemm_shape.hpp"
+#include "ck/tile_program/warp_tile/warp_gemm.hpp"
+#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
+#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
+#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
+#include "ck/tile_program/block_tile/block_reduce.hpp"
+
+#include "flash_attention_fwd_impl.hpp"
+
+// S[M0, N0] = Q[M0, K0] * K[N0, K0]
+// P[M0, N0] = Softmax(S[M0, N0])
+// O[M0, N1] = P[M0, N0] * V[N1, N0]
+template <typename QDataType,
+          typename KDataType,
+          typename VDataType,
+          typename SaccDataType,
+          typename SMPLComputeDataType,
+          typename PDataType,
+          typename OaccDataType,
+          typename ODataType,
+          ck::index_t kBlockSize,
+          ck::index_t kM0PerBlock,
+          ck::index_t kN0PerBlock,
+          ck::index_t kK0PerBlock,
+          ck::index_t kN1PerBlock,
+          ck::index_t kK1PerBlock,
+          ck::index_t kHeadDim>
+struct FlashAttentionFwd
+{
+    __device__ void operator()(const QDataType* q_ptr,
+                               const KDataType* k_ptr,
+                               const VDataType* v_ptr,
+                               ODataType* o_ptr,
+                               const ck::index_t M0,
+                               const ck::index_t N0,
+                               const ck::index_t K0,
+                               const ck::index_t N1,
+                               const ck::index_t /* Batch */,
+                               const ck::index_t StrideQ,
+                               const ck::index_t StrideK,
+                               const ck::index_t StrideV,
+                               const ck::index_t StrideO,
+                               const ck::index_t BatchStrideQ,
+                               const ck::index_t BatchStrideK,
+                               const ck::index_t BatchStrideV,
+                               const ck::index_t BatchStrideO) const
+    {
+        using namespace ck;
+
+        // divide problem
+        const index_t num_tile_m0 = M0 / kM0PerBlock;
+        const index_t num_tile_n1 = N1 / kN1PerBlock;
+
+        const index_t id_block = get_block_id();
+
+        const auto f = [](index_t dividend, index_t divisor) {
+            index_t quotient = dividend / divisor;
+            index_t modulus  = dividend - quotient * divisor;
+
+            return ck::make_tuple(quotient, modulus);
+        };
+
+        const auto [itmp, id_tile_n]          = f(id_block, num_tile_n1);
+        const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
+
+        const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
+        const index_t iM0    = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
+        const index_t iN1    = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
+
+        const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
+                                                       KDataType,
+                                                       VDataType,
+                                                       SaccDataType,
+                                                       SMPLComputeDataType,
+                                                       PDataType,
+                                                       OaccDataType,
+                                                       ODataType,
+                                                       kBlockSize,
+                                                       kM0PerBlock,
+                                                       kN0PerBlock,
+                                                       kK0PerBlock,
+                                                       kN1PerBlock,
+                                                       kK1PerBlock,
+                                                       kHeadDim>{};
+
+        kernel_impl(q_ptr + iBatch * BatchStrideQ,
+                    k_ptr + iBatch * BatchStrideK,
+                    v_ptr + iBatch * BatchStrideV,
+                    o_ptr + iBatch * BatchStrideO,
+                    M0,
+                    N0,
+                    K0,
+                    N1,
+                    StrideQ,
+                    StrideK,
+                    StrideV,
+                    StrideO,
+                    iM0,
+                    iN1);
+    }
+};
diff --git a/example/91_tile_program/flash_attention_fwd_impl.hpp b/example/91_tile_program/flash_attention_fwd_impl.hpp
new file mode 100644
index 000000000..58bc26a27
--- /dev/null
+++ b/example/91_tile_program/flash_attention_fwd_impl.hpp
@@ -0,0 +1,373 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/tile/slice_tile.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck/tile_program/block_tile/block_reduce.hpp" + +// S[M0, N0] = Q[M0, K0] * K[N0, K0] +// P[M0, N0] = Softmax(S[M0, N0]) +// O[M0, N1] = P[M0, N0] * V[N1, N0] +template +struct FlashAttentionFwdImpl +{ + // block gemm0 pipeline + using BlockGemm0Problem = ck::tile_program::block::BlockGemmPipelineProblem< + QDataType, + KDataType, + SaccDataType, + kBlockSize, + ck::tile_program::TileGemmShape>; + + using BlockGemm0Policy = + ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy< + kHeadDim>; + + using BlockGemm0Pipeline = + ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2; + + // block gemm1 + using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1< + ck::tile_program::block::BlockGemmARegBSmemCRegProblem< + PDataType, + VDataType, + OaccDataType, + kBlockSize, + ck::tile_program::TileGemmShape>, + ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>; + + // 3d, with padding + __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + using namespace ck; + + // using BDataType = B1DataType; + + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kK1PerBlock; + constexpr index_t kPad = 1; + // 2% faster than use kK1 = 8 + constexpr index_t kK1 = 4; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number{}, Number<1>{}), + Number{}, + Number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return b_lds_block_desc; + } + + __device__ static constexpr auto MakeVDramTileDistribution() + { + using namespace ck; + using namespace ck::tile_program; + + using BDataType = VDataType; + + constexpr index_t kNPerBlock = kN1PerBlock; + constexpr index_t kKPerBlock = kK1PerBlock; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + + __device__ static constexpr ck::index_t GetStaticLdsSize() + { + using namespace ck; + + return math::max(BlockGemm0Pipeline::GetStaticLdsSize(), + 
static_cast(MakeVLdsBlockDescriptor().GetElementSpaceSize() * + sizeof(VDataType))); + } + + __device__ void operator()(const QDataType* q_ptr, + const KDataType* k_ptr, + const VDataType* v_ptr, + ODataType* o_ptr, + const ck::index_t M0, + const ck::index_t N0, + const ck::index_t K0, + const ck::index_t N1, + const ck::index_t StrideQ, + const ck::index_t StrideK, + const ck::index_t StrideV, + const ck::index_t StrideO, + const ck::index_t iM0, + const ck::index_t iN1) const + { + using namespace ck; + using namespace ck::tile_program; + using namespace ck::tile_program::block; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // allocate LDS + __shared__ char smem_ptr[GetStaticLdsSize()]; + + // Q/K/V DRAM and DRAM window + // FIXME: assume layout Q[M0, K0], K[N0, K0], V[N1, N0], O[M0, N1] + const auto q_dram = make_naive_tensor_view( + q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), Number<32>{}, Number<1>{}); + + const auto k_dram = make_naive_tensor_view( + k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), Number<32>{}, Number<1>{}); + + const auto v_dram = make_naive_tensor_view( + v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), Number<32>{}, Number<1>{}); + + auto q_dram_window = make_tile_window( + q_dram, make_tuple(Number{}, Number{}), {iM0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(Number{}, Number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(Number{}, Number{}), + {iN1, 0}, + MakeVDramTileDistribution()); + + // Q in Register + auto q_reg_tensor = make_static_distributed_tensor( + BlockGemm0Policy::template MakeARegBlockDescriptor()); + + // V LDS and LDS window + // V LDS occupies the same LDS allocation Q/K LDS + auto v_lds = make_tensor_view(reinterpret_cast(smem_ptr), + MakeVLdsBlockDescriptor()); + + auto v_lds_window = make_tile_window( + v_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM0 pipeline and Block GEMM1 + constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; + constexpr auto gemm1 = BlockGemm1{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SaccBlockTileType = + decltype(gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, nullptr)); + + using SBlockTileType = decltype(tile_elementwise_in( + type_convert, SaccBlockTileType{})); + + using PBlockTileType = decltype(tile_elementwise_in(type_convert, + SaccBlockTileType{})); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm1( + get_slice_tile( + PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), + v_dram_window)); + + // init Sacc, Oacc, M, L + auto s_acc = SaccBlockTileType{}; + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); + tile_elementwise_inout([](auto& e) { e = NumericLimits::Lowest(); }, + m); + tile_elementwise_inout([](auto& e) { e = 0; }, l); + + // loop over Column of S (J loop) + index_t iN0 = 0; + + // Cold Q_Reg_Cache + s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr); + do + { + // Hot Q_Reg_Cache + if(iN0 > 0) + { + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + } + // S{j} + const auto s = + tile_elementwise_in(type_convert, s_acc); + + // prefetch 
load v tile + const auto v_prefetch = load_tile(v_dram_window); + + // m_local = rowmax(S{j}) + auto m_local = block_tile_reduce( + s, Sequence<1>{}, f_max, NumericLimits::Lowest()); + + block_tile_reduce_sync(m_local, f_max); + + // m{j-1} + const auto m_old = m; + + // m{j} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + // Pcompute{j} + auto p_compute = + make_static_distributed_tensor(s.GetTileDistribution()); + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + + sweep_tile_span(p_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + sweep_tile_span(p_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); + }); + }); + + // rowsum(Pcompute{j}) + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum); + + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + // l{j}, Oacc{j} + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); + + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + // FIXME: this use different equation from FA v2 paper, + // but produce correct result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + store_tile(v_lds_window, v_prefetch); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + + // type cast Pcompute{j} into P{j} + const auto p = + tile_elementwise_in(type_convert, p_compute); + + // Oacc{j} + constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock; + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + Sequence<0, i_k1 * kK1PerBlock>{}, + Sequence{}), + v_lds_window); + block_sync_lds(); + store_tile(v_lds_window, v); + move_tile_window(v_dram_window, {0, kK1PerBlock}); + }); + } + // tail + { + block_sync_lds(); + gemm1(o_acc, + get_slice_tile(p, + Sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + Sequence{}), + v_lds_window); + block_sync_lds(); + } + // move tile windows + move_tile_window(k_dram_window, {kN0PerBlock, 0}); + iN0 += kN0PerBlock; + } while(iN0 < N0); + + // Oacc + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + const auto tmp = 1 / l[i_idx]; + + sweep_tile_span(o_spans[I1], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + o_acc(i_j_idx) *= tmp; + }); + }); + + // type cast Oacc into O + const auto o = tile_elementwise_in(type_convert, o_acc); + + // O DRAM and O DRAM window + auto o_dram = make_naive_tensor_view( + o_ptr, make_tuple(M0, N1), make_tuple(StrideO, 1), Number<32>{}, Number<1>{}); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(Number{}, Number{}), + {iM0, iN1}, + o.GetTileDistribution()); + + // store O + store_tile(o_dram_window, o); + } +}; diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index d8a7a6a33..aa3e92f51 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ 
-544,9 +544,9 @@ struct Merge_v2_magic_division : public BaseTransform using UpLengths = decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); - using LowLengthsMagicDivisor = decltype( - generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_divisor{}, - Number{})); + using LowLengthsMagicDivisor = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_divisor{}, + Number{})); LowLengths low_lengths_; LowLengthsMagicDivisor low_lengths_magic_divisor_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 89fcbca1a..19f126e66 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -151,8 +151,9 @@ struct DeviceImageToColumnImpl remove_cvref_t; using OutputGridDesc = remove_cvref_t; - using Block2ETileMap = remove_cvref_t(OutputGridDesc{}))>; + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + OutputGridDesc{}))>; using GridwiseImageToColumnKernel = GridwiseImageToColumn, - BlockGemmARegBSmemCRegV1DefaultPolicy>; + using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< + BlockGemmARegBSmemCRegProblem, + BlockGemmARegBSmemCRegV1DefaultPolicy>; __host__ __device__ static constexpr ck::index_t GetStaticLdsSize() { diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp index 217928dfa..2b01568ed 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp @@ -26,7 +26,6 @@ struct BlockGemmARegBSmemCRegProblem static constexpr index_t kBlockSize = kBlockSize_; }; - } // namespace block } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp index 5778c2111..ffc37db87 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp @@ -15,6 +15,7 @@ #include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp" namespace ck { namespace tile_program { @@ -26,10 +27,11 @@ namespace block { template struct BlockGemmARegBSmemCRegV1 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using BlockGemmPolicy = Policy; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -165,7 +167,6 @@ struct BlockGemmARegBSmemCRegV1 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; diff --git 
a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp new file mode 100644 index 000000000..17590a90e --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +struct BlockGemmARegBSmemCRegV1K8Policy +{ + template + __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() + { + using namespace ck::tile_program::warp; + + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp new file mode 100644 index 000000000..a096bfc2c --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCRegV2 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using Policy = BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + // Move this part into Policy? 
+ __host__ __device__ static constexpr ck::index_t GetStaticLdsSize() + { + return sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().GetElementSpaceSize(); + } + + template + __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + static_assert( + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + a_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeADramTileDistribution()); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + b_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(Number{}, Number{}), + {0, 0}, + b_copy_dram_window.GetTileDistribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_copy_reg_tensor, b_lds_gemm_window)){}; + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // block buffer write 0 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + // store_tile -> shuffle store tile + store_tile(a_copy_reg_tensor, a_block_tile_tmp); + // global read 1 + a_block_tile = load_tile(a_copy_dram_window); + + // LDS write 0 + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + // global read 1 + b_block_tile = load_tile(b_copy_dram_window); + } + + index_t iCounter = num_loop - 2; + + do + { + block_sync_lds(); + + // GEMM i + block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); + + 
block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_reg_tensor, a_block_tile_tmp); + // global read i + 2 + a_block_tile = load_tile(a_copy_dram_window); + + // LDS write i + 1 + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + // global read i + 2 + b_block_tile = load_tile(b_copy_dram_window); + + iCounter--; + + } while(iCounter > 0); + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); + + block_sync_lds(); + + // LDS write num_loop - 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_reg_tensor, a_block_tile_tmp); + + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_copy_reg_tensor, b_lds_gemm_window); + } + + return c_block_tile; + } + + template + __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem); + } +}; + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCRegV2< + Problem, + BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy> +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using Policy = BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t k_loops = Policy::AKDim / kKPerBlock; + + // Move this part into Policy? + __host__ __device__ static constexpr ck::index_t GetStaticLdsSize() + { + return sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().GetElementSpaceSize(); + } + + // Cold A Register Cache + template + __host__ __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + static_assert( + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + ignore = a_element_func; + ignore = b_element_func; + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. 
without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + a_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeADramTileDistribution()); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + b_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(Number{}, Number{}), + {0, 0}, + b_copy_dram_window.GetTileDistribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm( + get_slice_tile(a_copy_reg_tensor, Sequence<0, 0>{}, Sequence{}), + b_lds_gemm_window)){}; + + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + { + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + Sequence<0, 0>{}, + Sequence{}); + a_block_tile = load_tile(a_copy_dram_window); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + } + if constexpr(k_loops > 2) + { + static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (i_k0)*kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + Sequence<0, (i_k0 + 1) * kKPerBlock>{}, + Sequence{}); + a_block_tile = load_tile(a_copy_dram_window); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + }); + } + + // tail + { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (k_loops - 2) * kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + set_slice_tile(a_copy_reg_tensor, + a_block_tile, + Sequence<0, (k_loops - 1) * kKPerBlock>{}, + Sequence{}); + + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (k_loops - 1) * kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + } + + store_tile(a_reg_block_tensor_tmp, a_copy_reg_tensor); + + return c_block_tile; + } + + // Hot A Register Cache + template + __host__ __device__ auto 
operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + static_assert(is_same_v>, + "wrong!"); + + static_assert(kNPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kKPerBlock == BDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + ignore = b_element_func; + + // A tile in Reg,blockTensor + // This tensor distribution used to construct both distributed tensor for local buffer store + // and read. without buffer address info + constexpr auto a_reg_block_dstr = Policy::template MakeARegBlockDescriptor(); + + // A Reg tensor for store, also used for block GEMM + auto a_copy_reg_tensor = make_static_distributed_tensor(a_reg_block_dstr); + store_tile(a_copy_reg_tensor, a_reg_block_tensor_tmp); + + // B tile in LDS, blockWindow + BDataType* p_b_lds = + static_cast(static_cast(static_cast(p_smem))); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + // This tensor view used to construct both tile window for lds store and read, with buffer + // address info + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + b_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(Number{}, Number{}), + {0, 0}, + b_copy_dram_window.GetTileDistribution()); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(Number{}, Number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm( + get_slice_tile(a_copy_reg_tensor, Sequence<0, 0>{}, Sequence{}), + b_lds_gemm_window)){}; + + auto b_block_tile = load_tile(b_copy_dram_window); + { + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + } + if constexpr(k_loops > 2) + { + static_for<0, k_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (i_k0)*kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + store_tile(b_copy_lds_window, b_block_tile); + b_block_tile = load_tile(b_copy_dram_window); + }); + } + + // tail + { + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (k_loops - 2) * kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + + block_sync_lds(); + + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + block_gemm(c_block_tile, + get_slice_tile(a_copy_reg_tensor, + Sequence<0, (k_loops - 1) * kKPerBlock>{}, + Sequence{}), + b_copy_lds_window); + } + + return c_block_tile; + } + + template + __device__ auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + 
[](const BDataType& b) { return b; }, + a_reg_block_tensor_tmp, + p_smem); + } + + template + __device__ auto operator()(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const ARegBlockTensorTmp& a_reg_block_tensor_tmp, + void* p_smem) const + { + return operator()( + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + a_reg_block_tensor_tmp, + p_smem); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp new file mode 100644 index 000000000..e54d3a4f5 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// NOTE: Assume A is K-Major +struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy +{ + template + __host__ __device__ static constexpr auto MakeARegBlockDescriptor() + { + constexpr auto blockgemm = GetBlockGemm(); + using BlockGemm = ck::remove_cvref_t; + + return policy_impl::make_a_reg_block_descriptor(); + } + + template + __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() + { + return policy_impl::make_b_lds_block_descriptor_3d_pad(); + } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + constexpr auto blockgemm = GetBlockGemm(); + using BlockGemm = ck::remove_cvref_t; + + return policy_impl::make_a_dram_tile_distribution_skip_lds(); + } + + template + __host__ __device__ static constexpr auto MakeBDramTileDistribution() + { + return policy_impl::make_b_dram_tile_distribution(); + } + + template + __host__ __device__ static constexpr auto GetBlockGemm() + { + using BlockGemmPolicy = BlockGemmARegBSmemCRegV1K8Policy; + + return BlockGemmARegBSmemCRegV1{}; + } +}; + +template +struct BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPersistentQRegCachePolicy + : BlockGemmPipelineAGmemBGmemCRegV2SkipALdsPolicy +{ + static constexpr index_t AKDim = AKDim_; + + template + __host__ __device__ static constexpr auto MakeARegBlockDescriptor() + { + using namespace ck; + + constexpr auto blockgemm = GetBlockGemm(); + using BlockGemm = ck::remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = AKDim; + + constexpr auto config = + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto a_block_dstr = 
make_static_tile_distribution(a_block_dstr_encode); + + return a_block_dstr; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp b/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp new file mode 100644 index 000000000..d57ad3e3a --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp @@ -0,0 +1,193 @@ +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { +namespace policy_impl { +// 3d + padding +template +__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad() +{ + using namespace ck; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number<8>{}), + make_tuple(Number<(kMPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto a_lds_block_desc = + transform_tensor_descriptor(a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return a_lds_block_desc; +} + +// 3d + padding +template +__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad() +{ + using namespace ck; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number<8>{}), + make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto b_lds_block_desc = + transform_tensor_descriptor(b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return b_lds_block_desc; +} + +template +__host__ __device__ static constexpr auto make_a_reg_block_descriptor() +{ + using namespace ck; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename 
WG::AWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + return a_block_dstr; +} + +template +__host__ __device__ static constexpr auto make_a_dram_tile_distribution() +{ + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); +} + +template +__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds() +{ + constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K2 = + WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane; + // // 16 / sizeof(ADataType); + constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1>>, + Tuple, Sequence<1, 2>>, + Sequence<2, 1, 2>, + Sequence<0, 0, 2>>{}); +} + +template +__host__ __device__ static constexpr auto make_b_dram_tile_distribution() +{ + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); +} + +template +__host__ __device__ static constexpr auto get_block_gemm() +{ + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; + + return BlockGemmASmemBSmemCRegV1{}; +} + +} // namespace policy_impl +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/slice_tile.hpp b/include/ck/tile_program/tile/slice_tile.hpp index e7999f26a..78aa1208a 100644 --- a/include/ck/tile_program/tile/slice_tile.hpp +++ b/include/ck/tile_program/tile/slice_tile.hpp @@ -49,17 +49,18 @@ __host__ __device__ constexpr auto set_slice_tile(DstStaticDistributedTensor_& d Sequence slice_ends) { using DstDistribution = decltype(DstStaticDistributedTensor_::GetTileDistribution()); + // using SrcDistribution = decltype(SrcStaticDistributedTensor_::GetTileDistribution()); constexpr auto sliced_dstr_yidx_ylen = 
        detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
 
-    constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>();
+    // constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>();
 
     constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>();
     constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>();
 
-    static_assert(is_same_v<remove_cvref_t<decltype(sliced_dstr)>, SrcDistribution>, "wrong!");
+    // static_assert(is_same_v<remove_cvref_t<decltype(sliced_dstr)>, SrcDistribution>, "wrong!");
 
-    dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer());
+    dst_tile.SetYSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer());
 }
 
 } // namespace tile_program
diff --git a/include/ck/tile_program/tile/static_distributed_tensor.hpp b/include/ck/tile_program/tile/static_distributed_tensor.hpp
index 57b3c418c..19f420aa4 100644
--- a/include/ck/tile_program/tile/static_distributed_tensor.hpp
+++ b/include/ck/tile_program/tile/static_distributed_tensor.hpp
@@ -59,7 +59,7 @@ struct StaticDistributedTensor
     template <index_t... YSliceOrigins, index_t... YSliceLengths>
     __host__ __device__ auto GetYSlicedThreadData(Sequence<YSliceOrigins...>,
-                                                  Sequence<YSliceLengths...>) const
+                                                   Sequence<YSliceLengths...>) const
     {
         static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                           sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
diff --git a/include/ck/tile_program/tile/store_tile.hpp b/include/ck/tile_program/tile/store_tile.hpp
index 7aba746f7..974e8d082 100644
--- a/include/ck/tile_program/tile/store_tile.hpp
+++ b/include/ck/tile_program/tile/store_tile.hpp
@@ -12,3 +12,4 @@
 #include "ck/tile_program/tile/tile_distribution.hpp"
 #include "ck/tile_program/tile/store_tile_impl_static_distribution.hpp"
 #include "ck/tile_program/tile/store_tile_impl_static_lengths.hpp"
+#include "ck/tile_program/tile/store_tile_impl_static_tensor.hpp"
diff --git a/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp b/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp
new file mode 100644
index 000000000..4407df6bd
--- /dev/null
+++ b/include/ck/tile_program/tile/store_tile_impl_static_tensor.hpp
@@ -0,0 +1,28 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_adaptor.hpp"
+#include "ck/tensor_description/tensor_space_filling_curve.hpp"
+
+#include "ck/tile_program/tile/tile_distribution.hpp"
+#include "ck/tile_program/tile/tile_window.hpp"
+
+namespace ck {
+namespace tile_program {
+
+template <typename DataType_, typename DstTileDistribution_, typename SrcTileDistribution_>
+__device__ void
+store_tile(StaticDistributedTensor<DataType_, DstTileDistribution_>& dst_dstr_tensor,
+           const StaticDistributedTensor<DataType_, SrcTileDistribution_>& src_dstr_tensor)
+{
+    // static_assert(DstTileDistribution_ == SrcTileDistribution_);
+    dst_dstr_tensor.GetThreadBuffer() = src_dstr_tensor.GetThreadBuffer();
+}
+
+} // namespace tile_program
+} // namespace ck
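Reviewer note on the block-to-tile mapping in flash_attention_fwd.hpp: kGridSize is Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock), and operator() peels the flat block id apart with two divmods, N1 tiles varying fastest, then M0 tiles, then batch. A host-side sketch of the same arithmetic, runnable stand-alone (the names mirror the kernel, but the program itself is illustrative):

#include <cstdio>

// Mirrors f(dividend, divisor) -> (quotient, modulus) in FlashAttentionFwd:
// id_block = (id_tile_batch * num_tile_m0 + id_tile_m) * num_tile_n1 + id_tile_n.
int main()
{
    const int Batch = 64, M0 = 4096, N1 = 128;      // example defaults
    const int kM0PerBlock = 128, kN1PerBlock = 128; // block tile sizes

    const int num_tile_m0 = M0 / kM0PerBlock; // 32 M-tiles per batch
    const int num_tile_n1 = N1 / kN1PerBlock; // 1 N1-tile per batch

    const int grid_size = Batch * num_tile_m0 * num_tile_n1; // 2048, as the example prints

    for(int id_block = 0; id_block < grid_size; id_block += 500)
    {
        const int itmp          = id_block / num_tile_n1;
        const int id_tile_n     = id_block - itmp * num_tile_n1;
        const int id_tile_batch = itmp / num_tile_m0;
        const int id_tile_m     = itmp - id_tile_batch * num_tile_m0;

        // Block id_block owns the O tile starting at row id_tile_m * kM0PerBlock,
        // column id_tile_n * kN1PerBlock, of batch id_tile_batch.
        std::printf("block %4d -> batch %2d, iM0 %4d, iN1 %d\n",
                    id_block,
                    id_tile_batch,
                    id_tile_m * kM0PerBlock,
                    id_tile_n * kN1PerBlock);
    }
    return 0;
}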
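Reviewer note on the FIXME in flash_attention_fwd_impl.hpp: the J loop keeps an unnormalized accumulator and rescales it by exp(m_old - m_new) whenever the running row max grows, then divides by the running sum l once in the epilogue sweep. That is algebraically the same softmax as the FA v2 formulation, just with the 1/l factor deferred to the end, so the result is correct. A scalar model of the recurrence, assuming a single S row and a one-column V (all names here are illustrative; the kernel accumulates the p * v term through gemm1 rather than inline):

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar model of the kernel's per-row state (m, l, o_acc) as it walks tiles
// of one S row. x[j] plays the role of S[m, j]; w[j] of V for a 1-wide O.
float online_softmax_dot(const std::vector<float>& x, const std::vector<float>& w)
{
    float m = -INFINITY; // running row max, like the kernel's `m`
    float l = 0.f;       // running sum of exp(x - m), like the kernel's `l`
    float o = 0.f;       // unnormalized accumulator, like `o_acc`

    for(std::size_t j = 0; j < x.size(); ++j)
    {
        const float m_old = m;
        m                 = std::fmax(m_old, x[j]);
        const float tmp   = std::exp(m_old - m); // rescale factor for stale state

        l = tmp * l + std::exp(x[j] - m); // rowsum of the rescaled P tile folded in
        o = tmp * o + std::exp(x[j] - m) * w[j];
    }
    return o / l; // single division at the end, as in the kernel's epilogue
}

By induction, after the last step o = sum_j exp(x_j - m_final) * w_j and l = sum_j exp(x_j - m_final), so o / l is exactly the softmax-weighted dot product; deferring the division is what lets the kernel avoid renormalizing o_acc on every tile.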