Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions example/91_tile_program/fmha_fwd.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <cstring>
#include <ostream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
Expand All @@ -15,6 +16,8 @@

#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp"
#include "ck/tile_program/tile/tile_fmha_shape.hpp"

Expand All @@ -33,10 +36,14 @@ using PDataType = ck::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck::half_t;

// M0 N0 K0 N1 K1
// M0 N0 K0 N1 K1 K0L
// using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>;
// using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>;
using FmhaShape = ck::tile_program::TileFmhaShape<128, 128, 32, 128, 32>;
using FmhaBlockTile = ck::Sequence<128, 128, 32, 128, 32, 128>;
using FmhaBlockWarps = ck::Sequence<4, 1, 1>;
using FmhaWarpTile = ck::Sequence<32, 32, 16>;
using FmhaShape = ck::tile_program::
TileFmhaShape<FmhaBlockTile, FmhaBlockWarps, FmhaWarpTile, FmhaBlockWarps, FmhaWarpTile>;

using FmhaTilePartitioner = FmhaFwdTilePartitioner<FmhaShape>;
using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
Expand All @@ -49,7 +56,8 @@ using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QD
ODataType,
256, // BlockSize
FmhaShape>;
using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>;
// using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>;
using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVS<FmhaPipelineProblem>;

using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>;
using FmhaKernel = FmhaFwdKernel<FmhaTilePartitioner, FmhaPipeline, FmhaEpilogue>;
Expand Down Expand Up @@ -134,7 +142,7 @@ int main(int argc, char* argv[])
<< ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v
<< ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm
<< ", grid_size " << kGridSize.x << "x" << kGridSize.y << "x" << kGridSize.z
<< std::endl;
<< std::flush << std::endl;

constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize;
Expand Down
16 changes: 11 additions & 5 deletions example/91_tile_program/fmha_fwd_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]

#define C_LOG2E 1.44269504088896340736 // log2(e)
#define C_LOG2E 1.44269504088896340736 // log2(e)

template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
Expand Down Expand Up @@ -148,10 +148,16 @@ struct FmhaFwdKernel
Number<32>{},
Number<1>{});

auto q_dram_window =
make_tile_window(q_dram,
make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kK0>{}),
{i_m0, 0});
auto q_dram_window = make_tile_window(
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(Number<FmhaPipeline::kM0>{},
Number<FmhaPipeline::kK0BlockLength>{});
else
return make_tuple(Number<FmhaPipeline::kM0>{}, Number<FmhaPipeline::kK0>{});
}(),
{i_m0, 0});

auto k_dram_window = make_tile_window(
k_dram, make_tuple(Number<FmhaPipeline::kN0>{}, Number<FmhaPipeline::kK0>{}), {0, 0});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ namespace block {
// This will:
// 1. Load B from global memory into shared memory and then
// 2. Call BlockGemmARegBSmemCRegV1
template <typename Problem, typename Policy = BlockGemmARegBGmemCRegV1DefaultPolicy>
template <typename Problem_, typename Policy_ = BlockGemmARegBGmemCRegV1DefaultPolicy>
struct BlockGemmARegBGmemCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
Expand All @@ -37,13 +39,9 @@ struct BlockGemmARegBGmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;

// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using BlockGemmARegBSmemCRegImpl =
BlockGemmARegBSmemCRegV1<BlockGemmARegBSmemCRegProblem<ADataType,
BDataType,
CDataType,
kBlockSize,
BlockGemmShape>,
BlockGemmARegBSmemCRegV1DefaultPolicy>;
using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1<
BlockGemmARegBSmemCRegProblem<ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape>,
BlockGemmARegBSmemCRegV1DefaultPolicy>;

__host__ __device__ static constexpr ck::index_t GetStaticLdsSize()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ namespace block {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmARegBSmemCRegV1DefaultPolicy>
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"

namespace ck {
namespace tile_program {
namespace block {

// Custom block-GEMM policy (by its name, intended for BlockGemmARegBSmemCRegV1).
// Unlike a fixed policy, every choice is supplied by the caller as a template argument:
//   AType_/BType_/CType_ : A, B and C types; stored with cv/ref qualifiers removed
//   BlockWarps_          : compile-time sequence read through At(Number<i>{});
//                          entries 0, 1, 2 are taken as the M, N, K warp counts
//   WarpGemm_            : warp-level GEMM type handed back to the block GEMM
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpGemm_>
struct BlockGemmARegBSmemCRegV1CustomPolicy
{
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    using BlockWarps = remove_cvref_t<BlockWarps_>;
    using WarpGemm   = remove_cvref_t<WarpGemm_>;

    // Warp counts along M / N / K, unpacked from BlockWarps.
    static constexpr index_t kMWarps = BlockWarps::At(Number<0>{});
    static constexpr index_t kNWarps = BlockWarps::At(Number<1>{});
    static constexpr index_t kKWarps = BlockWarps::At(Number<2>{});

    // Returns the tuple (WarpGemm instance, kMWarps, kNWarps).
    // The Problem template parameter is accepted but not used by this body;
    // kKWarps is not part of the returned tuple.
    template <typename Problem>
    __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp()
    {
        // NOTE(review): nothing named in this body visibly requires the directive
        // below -- confirm before removing it.
        using namespace ck::tile_program::warp;
        return make_tuple(WarpGemm{}, kMWarps, kNWarps);
    }
};

} // namespace block
} // namespace tile_program
} // namespace ck
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ namespace block {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem, typename Policy = BlockGemmASmemBSmemCRegV1DefaultPolicy>
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct BlockGemmASmemBSmemCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp"

namespace ck {
namespace tile_program {
namespace block {

// Custom policy for BlockGemmASmemBSmemCRegV1.
// This policy is a class template: the data types, the block warp layout, the warp tile
// and the C-transpose flag are all fixed by template arguments, and the warp GEMM is
// selected from them through WarpGemmMfmaDispatcher.
//   AType_/BType_/CType_ : A, B and C types; stored with cv/ref qualifiers removed
//   BlockWarps_          : compile-time sequence; entries 0, 1, 2 are read as the
//                          M, N, K warp counts
//   WarpTile_            : compile-time sequence; entries 0, 1, 2 are read as the
//                          per-warp M, N, K extents
//   TranposeC_           : forwarded unchanged to WarpGemmMfmaDispatcher
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpTile_,
          bool TranposeC_>
struct BlockGemmASmemBSmemCRegV1CustomPolicy
{
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    using BlockWarps = remove_cvref_t<BlockWarps_>;
    using WarpTile   = remove_cvref_t<WarpTile_>;

    // Warp counts along M / N / K, unpacked from BlockWarps.
    static constexpr index_t BlockMWarps = BlockWarps::At(Number<0>{});
    static constexpr index_t BlockNWarps = BlockWarps::At(Number<1>{});
    static constexpr index_t BlockKWarps = BlockWarps::At(Number<2>{});

    // Per-warp tile extents along M / N / K, unpacked from WarpTile.
    static constexpr index_t MPerWarp = WarpTile::At(Number<0>{});
    static constexpr index_t NPerWarp = WarpTile::At(Number<1>{});
    static constexpr index_t KPerWarp = WarpTile::At(Number<2>{});

    // Spelling "Tranpose" (sic) is kept: the member name is visible to users of the policy.
    static constexpr bool TranposeC = TranposeC_;

    // Warp-level GEMM chosen by the dispatcher from the types, warp tile and transpose flag.
    using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher<AType,
                                                                    BType,
                                                                    CType,
                                                                    MPerWarp,
                                                                    NPerWarp,
                                                                    KPerWarp,
                                                                    TranposeC>;

    // Returns the tuple (WarpGemm instance, BlockMWarps, BlockNWarps).
    // The Problem template parameter is accepted but not used by this body;
    // BlockKWarps is not part of the returned tuple.
    template <typename Problem>
    __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp()
    {
        // NOTE(review): WarpGemm above is already spelled with its full namespace, so
        // this directive looks unnecessary here -- confirm before removing it.
        using namespace ck::tile_program::warp;
        return make_tuple(WarpGemm{}, BlockMWarps, BlockNWarps);
    }
};

} // namespace block
} // namespace tile_program
} // namespace ck
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ struct BlockFmhaPipelineQKVS
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;

using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr bool kQLoadOnce = false; // if q load whole block length (hdim) at once

static constexpr index_t kBlockSize = Problem::kBlockSize;

Expand Down
Loading