From 441de82c4bbc6790a4e6534ce00fcd90e65872e2 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 1 Dec 2023 01:13:17 -0600 Subject: [PATCH 01/45] async copy code --- example/91_tile_program/CMakeLists.txt | 13 + example/91_tile_program/fmha_fwd.cpp | 6 +- example/91_tile_program/fmha_fwd_kernel.hpp | 26 +- .../fmha_fwd_tile_partitioner.hpp | 9 +- include/ck/tensor/tensor_view.hpp | 28 +- .../multi_index_transform.hpp | 85 ++++ .../multi_index_transform_helper.hpp | 7 + .../tensor_descriptor_helper.hpp | 100 ++++ .../block_gemm_areg_bsmem_creg_v2.hpp | 236 ++++++++++ ..._gemm_areg_bsmem_creg_v2_custom_policy.hpp | 49 ++ ...gemm_areg_bsmem_creg_v2_default_policy.hpp | 58 +++ .../block_fmha_pipeline_qr_ks_vs.hpp | 288 ++++++++---- ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 437 +++++++++++++++--- include/ck/tile_program/tile/load_tile.hpp | 40 +- include/ck/tile_program/tile/slice_tile.hpp | 63 +-- .../slice_tile_impl_distributed_tensor.hpp | 73 +++ .../tile/slice_tile_impl_static_lengths.hpp | 42 ++ .../tile/static_distributed_tensor.hpp | 31 +- .../ck/tile_program/tile/tile_fmha_shape.hpp | 12 +- .../tile_window_impl_static_distribution.hpp | 130 +++++- include/ck/utility/amd_buffer_addressing.hpp | 263 ++++++++++- include/ck/utility/amd_inline_asm.hpp | 12 + .../ck/utility/buffer_view_impl_generic.hpp | 4 +- .../ck/utility/buffer_view_impl_global.hpp | 50 +- include/ck/utility/buffer_view_impl_lds.hpp | 4 +- include/ck/utility/buffer_view_impl_vgpr.hpp | 4 +- include/ck/utility/data_type.hpp | 18 +- include/ck/utility/integral_constant.hpp | 3 + include/ck/utility/static_buffer.hpp | 22 + 29 files changed, 1831 insertions(+), 282 deletions(-) create mode 100644 include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp create mode 100644 include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp create mode 100644 include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp create mode 100644 include/ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp create mode 100644 include/ck/tile_program/tile/slice_tile_impl_static_lengths.hpp diff --git a/example/91_tile_program/CMakeLists.txt b/example/91_tile_program/CMakeLists.txt index 27d3c67ad..9bc86eccb 100644 --- a/example/91_tile_program/CMakeLists.txt +++ b/example/91_tile_program/CMakeLists.txt @@ -6,3 +6,16 @@ add_example_executable(example_softmax softmax.cpp) add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp) add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp) add_example_executable(example_fmha_fwd fmha_fwd.cpp) + +# NOTE: this is dangerous since will change the whole kernel to flush denormals +# WIP with compiler team for an exp2 intrinsic..., then remove this +if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 true) +endif() + +if(FMHA_FWD_FAST_EXP2) +set_source_files_properties(fmha_fwd.cpp PROPERTIES COMPILE_OPTIONS "-DCK_FMHA_FWD_FAST_EXP2=1;-fgpu-flush-denormals-to-zero") +else() +set_source_files_properties(fmha_fwd.cpp PROPERTIES COMPILE_OPTIONS "-DCK_FMHA_FWD_FAST_EXP2=0") +endif() + diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index e34dd3f1f..f7b394ebc 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -51,12 +51,14 @@ using FmhaShapeHDim64 = ck::tile_program::TileFmhaShape; using FmhaShapeHDim128 = ck::tile_program::TileFmhaShape; using FmhaTilePartitionerHDim64 = FmhaFwdTilePartitioner; @@ -112,9 
+114,7 @@ float invoker_fmha_kernel(const void* q_ptr, dim3 kGridSize = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v); constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); - constexpr ck::index_t kWarpPerCu = 8; // 2 warps per SIMD - constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize; - constexpr ck::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + constexpr ck::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; constexpr bool is_v_rowmajor = ck::is_same_v; diff --git a/example/91_tile_program/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha_fwd_kernel.hpp index b447de1db..a8cd3fe61 100644 --- a/example/91_tile_program/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha_fwd_kernel.hpp @@ -16,10 +16,11 @@ template struct FmhaFwdKernel { - using TilePartitioner = ck::remove_cvref_t; - using FmhaPipeline = ck::remove_cvref_t; - using EpiloguePipeline = ck::remove_cvref_t; - static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize; + using TilePartitioner = ck::remove_cvref_t; + using FmhaPipeline = ck::remove_cvref_t; + using EpiloguePipeline = ck::remove_cvref_t; + static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; using QDataType = ck::remove_cvref_t; using KDataType = ck::remove_cvref_t; @@ -79,11 +80,18 @@ struct FmhaFwdKernel ck::index_t batch_stride_v, ck::index_t batch_stride_o) { - return Kargs{q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, - seqlen_k, hdim_q, hdim_v, scale, stride_q, - stride_k, stride_v, stride_o, nhead_stride_q, nhead_stride_k, - nhead_stride_v, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, - batch_stride_o}; + return Kargs + { + q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, seqlen_k, hdim_q, hdim_v, +#if CK_FMHA_FWD_FAST_EXP2 + static_cast(scale * C_LOG2E), +#else + scale, +#endif + stride_q, stride_k, stride_v, stride_o, nhead_stride_q, nhead_stride_k, + nhead_stride_v, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, + batch_stride_o + }; } __host__ static constexpr auto GridSize(ck::index_t batch_size_, diff --git a/example/91_tile_program/fmha_fwd_tile_partitioner.hpp b/example/91_tile_program/fmha_fwd_tile_partitioner.hpp index c01cadc4e..d60dce24a 100644 --- a/example/91_tile_program/fmha_fwd_tile_partitioner.hpp +++ b/example/91_tile_program/fmha_fwd_tile_partitioner.hpp @@ -24,7 +24,10 @@ struct FmhaFwdTilePartitioner ck::index_t hdim_v_) { // TODO: this may need tuning - return dim3((seqlen_q_ / kM0) * (hdim_v_ / kN1), batch_size_, nhead_); + return dim3(ck::math::integer_divide_ceil(seqlen_q_, kM0) * + ck::math::integer_divide_ceil(hdim_v_, kN1), + nhead_, + batch_size_); } __device__ auto operator()(ck::index_t /*seqlen_q*/, ck::index_t hdim_v) @@ -35,8 +38,8 @@ struct FmhaFwdTilePartitioner const index_t num_tile_n1 = hdim_v / kN1; const index_t i_block = blockIdx.x; - const index_t i_batch = blockIdx.y; - const index_t i_nhead = blockIdx.z; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; const auto f = [](index_t dividend, index_t divisor) { index_t quotient = dividend / divisor; diff --git a/include/ck/tensor/tensor_view.hpp b/include/ck/tensor/tensor_view.hpp index 0ecfcfa0a..da1c544a9 100644 --- a/include/ck/tensor/tensor_view.hpp +++ b/include/ck/tensor/tensor_view.hpp @@ -53,15 +53,39 @@ struct TensorView // X is vector of DataType. // "coord" is coordinate of DataType, not X. 
"coord" should be aligned to X template >::type, typename scalar_type>::type>, bool>::type = false> __host__ __device__ constexpr remove_cvref_t - GetVectorizedElements(const TensorCoord& coord) const + GetVectorizedElements(const TensorCoord& coord, bool_constant = {}) const { return buf_.template Get( coord.GetOffset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + } + + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::type, + typename scalar_type>::type>, + bool>::type = false> + __host__ __device__ void GetVectorizedElementsRaw(remove_cvref_t& dst, + const TensorCoord& coord) const + { + return buf_.template GetRaw(dst, coord.GetOffset()); + } + + template >::type, + typename scalar_type>::type>, + bool>::type = false> + __host__ __device__ constexpr void AsyncGetVectorizedElements(remove_cvref_t* smem, + const TensorCoord& coord) const + { + return buf_.template AsyncGet(smem, coord.GetOffset(), true /*not used*/); } // X is vector of DataType. diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 5c166b9b6..445004dc6 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -18,6 +18,7 @@ enum struct IndexTransformEnum UnMerge, Replicate, Xor, + Offset, }; template @@ -1401,4 +1402,88 @@ struct Xor : public BaseTransform<2, 2> } }; +template +struct Offset : public BaseTransform<1, 1> +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + OffsetLength offset_length_; + + __host__ __device__ constexpr Offset() = default; + + __host__ __device__ constexpr Offset(const LowLength& low_length, + const OffsetLength& offset_length) + : up_lengths_{make_tuple(low_length)}, offset_length_{offset_length} + { + } + + __host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::Offset; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] + offset_length_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Offset{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("offset_length_: "); + print(offset_length_); + + printf("}"); + } +}; + } // namespace ck diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index 649a36a04..7ce003670 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -119,4 +119,11 @@ __host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_leng return Xor{low_lengths, right_shift}; } +template +__host__ __device__ constexpr auto make_offset_transform(const LowLength& low_length, + const OffsetLength& offset_length) +{ + return Offset{low_length, offset_length}; +} + } // namespace ck diff --git a/include/ck/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp index c49cc91c0..54f9f80b4 100644 --- a/include/ck/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -83,6 +83,59 @@ make_naive_tensor_descriptor(const Tuple& lengths, GuaranteedVectorStrides>{transforms, element_space_size}; } +// tensor descriptor with offset, the offset will not be added into element space size +// only have an information of the starting offset, and will impact on offset calculation +template ::type = false> +__host__ __device__ constexpr auto +make_naive_tensor_descriptor_with_offset(const Tuple& lengths, + const Tuple& strides, + const Offset& offset, + Number = Number<-1>{}, + Number = Number<-1>{}) +{ + const auto desc_0 = [&]() { + const auto element_space_size = detail::calculate_element_space_size_impl( + lengths, strides, Number<0>{}, LongNumber<1>{}); + + const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = make_tuple(Sequence<1>{}); + + constexpr auto visible_dim_hidden_ids = Sequence<1>{}; + + using GuaranteedVectorLengths = + typename sequence_merge::type, + Sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, + Sequence>::type; + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; + }(); + + constexpr index_t N = sizeof...(Lengths); + + return transform_tensor_descriptor( + desc_0, + make_tuple(make_embed_transform(lengths, strides)), + make_tuple(Sequence<0>{}), + make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{})); +} + // Lengths... 
could be: // 1) index_t, which is known at run-time, or // 2) Number<>, which is known at compile-time @@ -123,6 +176,53 @@ make_naive_tensor_descriptor_packed(const Tuple& lengths, GuaranteedVectorStrides>{transforms, element_space_size}; } +template ::type = false> +__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed_with_offset( + const Tuple& lengths, + const Offset& offset, + Number = Number<-1>{}) +{ + const auto desc_0 = [&]() { + const auto element_space_size = + container_reduce(lengths, math::multiplies{}, LongNumber<1>{}); + + const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = make_tuple(Sequence<1>{}); + + constexpr auto visible_dim_hidden_ids = Sequence<1>{}; + + using GuaranteedVectorLengths = + typename sequence_merge::type, + Sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, Sequence<1>>::type; + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; + }(); + + constexpr index_t N = sizeof...(Lengths); + + return transform_tensor_descriptor( + desc_0, + make_tuple(make_unmerge_transform(lengths)), + make_tuple(Sequence<0>{}), + make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{})); +} + // Lengths... could be: // 1) index_t, which is known at run-time, or // 2) Number<>, which is known at compile-time diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp new file mode 100644 index 000000000..ad6193cc9 --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + __device__ void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.GetLengths()[Number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.GetLengths()[Number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = + make_static_distributed_tensor(a_block_dstr); + + a_block_tensor.GetThreadBuffer() = a_block_tensor_tmp.GetThreadBuffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + 
b_block_window_tmp.GetBottomTensorView(), + make_tuple(Number{}, Number{}), + b_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using Array will cause register spill + Array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + StaticallyIndexedArray, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert(is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths()); + constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetYSlicedThreadData( + merge_sequences(Sequence{}, a_warp_y_index_zeros), + merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetYSlicedThreadData( + merge_sequences(Sequence{}, c_warp_y_index_zeros), + merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.SetYSlicedThreadData( + merge_sequences(Sequence{}, c_warp_y_index_zeros), + merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.GetThreadBuffer()); + }); + }); + }); + } + + __device__ constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = 
detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + __device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp new file mode 100644 index 000000000..842b0ce38 --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +template +struct BlockGemmARegBSmemCRegV2CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::At(Number<0>{}); + static constexpr index_t kNWarps = BlockWarps::At(Number<1>{}); + static constexpr index_t kKWarps = BlockWarps::At(Number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() + { + using namespace ck::tile_program::warp; + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp new file mode 100644 index 000000000..f7306e67a --- /dev/null +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_default_policy.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// Default policy for BlockGemmARegBSmemCRegV2 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmARegBSmemCRegV2DefaultPolicy +{ + template + __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() + { + using namespace ck::tile_program::warp; + +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9224469e6..752b0cce1 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -40,7 +40,8 @@ struct BlockFmhaPipelineQRKSVS using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q load whole block length (hdim) at once - static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kBlockPerCu = BlockFmhaShape::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kM0 = BlockFmhaShape::kM0; static constexpr index_t kN0 = BlockFmhaShape::kN0; @@ -64,7 +65,7 @@ struct BlockFmhaPipelineQRKSVS operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction& k_element_func, + const KElementFunction& /*k_element_func*/, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, float scale, @@ -85,20 +86,46 @@ struct BlockFmhaPipelineQRKSVS kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], "wrong!"); + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + // K tile in LDS - KDataType* k_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); - auto k_lds = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_window = - make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + 
make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).GetLengths(), + {0, 0, 0}); + }, + Number{}); + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + auto k_lds_load = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), + Policy::template MakeKLdsLoadBlockDescriptor(i_buf).GetLengths(), + {0, 0}); + }, + Number{}); +#else + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().GetLengths(), + {0, 0}); +#endif // V tile in LDS auto v_lds = make_tensor_view( reinterpret_cast(smem_ptr), Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_window = - make_tile_window(v_lds, make_tuple(Number{}, Number{}), {0, 0}); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().GetLengths(), {0, 0}); // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); @@ -110,12 +137,13 @@ struct BlockFmhaPipelineQRKSVS q_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeQDramTileDistribution()); - auto q = load_tile(q_dram_window); // persistent q register tile + auto q = + load_tile(q_dram_window, integral_constant{}); // persistent q register tile + + __builtin_amdgcn_sched_barrier(0); - auto s_acc = decltype(gemm_0(get_slice_tile(tile_elementwise_in(q_element_func, q), - Sequence<0, 0>{}, - Sequence{}), - k_lds_window)){}; + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; // reduction function for softmax const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; @@ -125,15 +153,13 @@ struct BlockFmhaPipelineQRKSVS using SBlockTileType = decltype(tile_elementwise_in(type_convert, s_acc)); - using PBlockTileType = - decltype(tile_elementwise_in(type_convert, s_acc)); + // using PBlockTileType = + // decltype(tile_elementwise_in(type_convert, s_acc)); using MLBlockTileType = decltype(block_tile_reduce( SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); - using OaccBlockTileType = decltype(gemm_1( - get_slice_tile(PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), - v_lds_window)); + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); // init Oacc, M, L auto o_acc = OaccBlockTileType{}; @@ -152,73 +178,85 @@ struct BlockFmhaPipelineQRKSVS v_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeVDramTileDistribution()); - auto q_tile = tile_elementwise_in(q_element_func, q); - index_t i_total_loops = 0; + __builtin_amdgcn_sched_barrier(0); + auto k_dram_window = make_tile_window( + k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + // prefetch K tile + async_load_tile(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.GetNumAccess()); + auto q_tile = tile_elementwise_in(q_element_func, q); + __builtin_amdgcn_sched_barrier(0); + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; do { // STAGE 1, QK gemm - auto k_dram_window = make_tile_window( - k_dram_block_window.GetBottomTensorView(), - 
k_dram_block_window.GetWindowLengths(), - k_dram_block_window.GetWindowOrigin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load - - auto k_block_tile = load_tile(k_dram_window); + tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); // Initialize C + if constexpr(k0_loops > 1) { - move_tile_window(k_dram_window, {0, kK0}); - - tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); // Initialize C - - store_tile(k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write 0 - k_block_tile = load_tile(k_dram_window); // global read 1 - } - - // index_t i_k0_loops = num_sub_loop_qk - 2; - constexpr index_t k0_loops = kK0BlockLength / kK0; - - if constexpr(k0_loops > 2) - { - static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { - block_sync_lds(); + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile(k_lds_store(Number{})>{}), + k_dram_window); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.GetNumAccess()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); gemm_0(s_acc, get_slice_tile(q_tile, Sequence<0, i_k0 * kK0>{}, Sequence{}), - k_lds_window); - block_sync_lds(); - move_tile_window(k_dram_window, {0, kK0}); - - store_tile( - k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 - k_block_tile = load_tile(k_dram_window); // global read i + 2 +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[Number{})>{}]); + +#else + get_slice_tile(k_lds_load, + Sequence<(LdsSeq.At(Number{})) * kN0, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN0, kK0>{})); +#endif }); } - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile - { // tail - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - Sequence<0, (k0_loops - 2) * kK0>{}, - Sequence{}), - k_lds_window); - block_sync_lds(); + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); - store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); - block_sync_lds(); + async_load_fence(); + __builtin_amdgcn_s_barrier(); + auto v_buf = load_tile(v_dram_window); + __builtin_amdgcn_sched_barrier(0); + { // tail gemm_0(s_acc, get_slice_tile(q_tile, Sequence<0, (k0_loops - 1) * kK0>{}, Sequence{}), - k_lds_window); +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[Number{})>{}]); + +#else + get_slice_tile( + k_lds_load, + Sequence<(LdsSeq.At(Number{})) * kN0, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN0, kK0>{})); +#endif } + __builtin_amdgcn_sched_barrier(1); // STAGE 2, scale softmax +#if !CK_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif const auto s = tile_elementwise_in(type_convert, s_acc); // S{j} @@ -236,12 +274,51 @@ struct BlockFmhaPipelineQRKSVS auto p_compute = make_static_distributed_tensor( s.GetTileDistribution()); // Pcompute{j} + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + 
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile(v_dram_window); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * m[i_idx]; +#endif sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); +#else p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); +#endif }); }); @@ -253,8 +330,13 @@ struct BlockFmhaPipelineQRKSVS constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * m[i_idx]; + const auto tmp = math::exp2(scale * m_old[i_idx] - row_max); +#else const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); // FIXME: this use different equation from FA v2 paper, @@ -264,65 +346,75 @@ struct BlockFmhaPipelineQRKSVS }); }); - block_sync_lds(); - if constexpr(ck::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_distributed_tensor(v_shuffle_tmp, v_prefetch); - store_tile( - v_lds_window, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch - } - move_tile_window(v_dram_window, {0, kK1}); - const auto p = tile_elementwise_in(type_convert, p_compute); // STAGE 3, KV gemm - constexpr index_t k1_loops = kN0 / kK1; if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window); // load next v_buf + } block_sync_lds(); gemm_1(o_acc, get_slice_tile( p, Sequence<0, i_k1 * kK1>{}, Sequence{}), - v_lds_window); - block_sync_lds(); + get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{})); + if constexpr(ck::is_same_v) { auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_distributed_tensor(v_shuffle_tmp, v); - store_tile(v_lds_window, + shuffle_distributed_tensor(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch } else { store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v)); // store next v + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf } - move_tile_window(v_dram_window, {0, kK1}); + if constexpr(i_k1 
< k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); }); } - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = + make_tile_window(k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); + + if constexpr(k1_loops >= 2 && + LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + } // tail { block_sync_lds(); - gemm_1(o_acc, - get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), - v_lds_window); - block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), + get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{})); } } while(i_total_loops < num_total_loop); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp index 401fb770d..6c125899c 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -15,8 +15,13 @@ #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +// TODO: remove this +#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0 + namespace ck { namespace tile_program { namespace block { @@ -24,6 +29,58 @@ namespace block { // This pipeline is qkv all located in LDS struct BlockFmhaPipelineQRKSVSDefaultPolicy { + static constexpr index_t KLdsBuffers = 3; + static constexpr index_t VLdsBuffers = 3; + + template + struct LdsBufferSequence + { + static constexpr auto Make() + { + return transform_sequences( + [&](auto i) { + if(i < k_loops_) + return i % k_bufs_; + return (i - k_loops_) % v_bufs_; + }, + typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{}); + }; + + using type = remove_cvref_t; + }; + // clang-format off + template<> struct + LdsBufferSequence<3, 3, 4, 4> { using type = Sequence<1, 2, 0, 1, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 4, 2> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 2, 4> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 3, 3> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 2, 2> { using type = Sequence<1, 2, 1, 0>;}; + // clang-format on + + template + __host__ __device__ static constexpr auto GetLdsBufferSequence() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t kN0 = BlockFmhaShape::kN0; + constexpr index_t kK0 = BlockFmhaShape::kK0; + constexpr index_t kK1 = BlockFmhaShape::kK1; + constexpr index_t kK0BlockLength = 
BlockFmhaShape::kK0BlockLength; + + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + return typename LdsBufferSequence::type{}; + } + template __host__ __device__ static constexpr auto GetSmemKPackK() { @@ -40,9 +97,66 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy return 16 / sizeof(VDataType); } template - __host__ __device__ static constexpr auto GetTransposedVectorloadV() + __host__ __device__ static constexpr auto GetVectorloadV() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + __host__ __device__ static constexpr auto GetSingleSmemElementSpaceSize() { - return 4; // TODO: fix me + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = warpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (warpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return math::max(SingleKSize, SingleVSize); + } + + template + __host__ __device__ static constexpr auto GetVectorloadK() + { + using KDataType = remove_cvref_t; + return 4 / sizeof(KDataType); // TODO: this is for async copy } template @@ -77,7 +191,7 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy return q_block_dstr; } - // 3d + padding +#if 0 template __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() { @@ -100,32 +214,170 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy return k_lds_block_desc; } +#endif - // 3d + padding - template - __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() + template + __host__ __device__ static constexpr auto + MakeKLdsStoreBlockDescriptor(Number = Number<0>{}) { -#if 0 - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kPad = 1; - constexpr index_t kKPack = GetSmemKPackV(); + constexpr 
index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = + KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + warpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(Number{}, // n0 + Number{}, // n1 + Number{}, // n2 + Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number()>{}, + Number{}, + Number<1>{}); - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number{}), - make_tuple(Number<(kNPerBlock + kPad) * kKPack>{}, Number{}, Number<1>{}), - Number{}, + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + template + __host__ __device__ static constexpr auto + MakeKLdsLoadBlockDescriptor(Number = Number<0>{}) + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(Number{}, // n0 + Number{}, // n2 + Number{}, // n1 + Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number()>{}, + Number{}, Number<1>{}); - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - 
make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0, 2, 1>{}, Sequence<3, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - return v_lds_block_desc; + return k_lds_block_desc; + } #else + template + __host__ __device__ static constexpr auto MakeKLdsLoadBlockDescriptor() + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + // constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad); + // constexpr index_t SingleVSize = MakeVLdsBlockDescriptor().GetElementSpaceSize(); + constexpr index_t BufferSize = + GetSingleSmemElementSpaceSize(); // math::max(SingleKSize, SingleVSize); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(Number{}, // num_buffers + Number{}, // n0 + Number{}, // n2 + Number{}, // n1 + Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number{}, + Number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0, 1, 3, 2>{}, Sequence<4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return k_lds_block_desc; + } +#endif + + // 3d + padding + template + __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() + { using VDataType = remove_cvref_t; constexpr index_t Banks = 32; // TODO: need change based on arch constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); @@ -138,11 +390,13 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy static_assert(kKPerBlock % kKPack == 0); constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, + make_tuple(Number{}, + Number{}, Number{}, Number{}, Number{}), - make_tuple(Number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + make_tuple(Number()>{}, + Number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, Number{}, Number{}, Number<1>{}), @@ -152,13 +406,13 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy constexpr auto v_lds_block_desc = transform_tensor_descriptor( v_lds_block_desc_0, make_tuple( - make_merge_transform(make_tuple(Number{}, Number{})), + make_merge_transform(make_tuple( + Number{}, Number{}, Number{})), 
make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<1, 2>{}, Sequence<0, 3>{}), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); return v_lds_block_desc; -#endif } template @@ -170,15 +424,11 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy template __host__ __device__ static constexpr ck::index_t GetSmemSize() { - constexpr index_t smem_size_gemm_0 = - GetSmemSizeQ() + sizeof(typename Problem::KDataType) * - MakeKLdsBlockDescriptor().GetElementSpaceSize(); - constexpr index_t smem_size_gemm_1 = - MakeVLdsBlockDescriptor().GetElementSpaceSize() * - sizeof(typename Problem::VDataType); - - // TODO: consider shuffle requirement - return math::max(smem_size_gemm_0, smem_size_gemm_1); + // TODO: assume Q is in register + constexpr index_t single_smem_size = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return single_smem_size * KLdsBuffers; } template @@ -204,24 +454,23 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy Tuple, Sequence>, Tuple, Sequence<2, 1>>, Tuple, Sequence<1, 2>>, - Sequence<2, 1, 2>, + Sequence<1, 2, 2>, Sequence<0, 0, 2>>{}); } template __host__ __device__ static constexpr auto MakeKDramTileDistribution() { +#if 0 // coalesce reading for each blocks using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t K1 = 16 / sizeof(KDataType); constexpr index_t K0 = kKPerBlock / K1; constexpr index_t N2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t N0 = kNPerBlock / (N2 * N1); @@ -232,17 +481,34 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy Tuple, Sequence<2, 0>>, Sequence<1, 2>, Sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t N0 = kBlockSize / get_warp_size(); - constexpr index_t N1 = kNPerBlock / (N2 * N0); +#else + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KVector = GetVectorloadK(); // this is for global load + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; return make_static_tile_distribution( StaticTileDistributionEncoding, Tuple, Sequence>, Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, + Tuple, Sequence<1, 0>>, Sequence<1, 2>, - Sequence<1, 1>>{}); + Sequence<0, 1>>{}); #endif } @@ -258,7 +524,7 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy if constexpr(ck::is_same_v) { - constexpr index_t N1 = GetTransposedVectorloadV(); + constexpr index_t N1 = GetVectorloadV(); constexpr index_t N0 = kNPerBlock / N1; // P constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -267,17 +533,35 @@ 
struct BlockFmhaPipelineQRKSVSDefaultPolicy constexpr index_t kKPack = GetSmemKPackV(); static_assert(kKPack % K3 == 0); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1, 2>>, - Tuple, Sequence<1, 0, 2>>, - Sequence<2, 1>, - Sequence<3, 1>>{}); + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding< + Sequence<1>, + Tuple, Sequence>, + Tuple, Sequence<2, 1, 2>>, + Tuple, Sequence<1, 0, 2>>, + Sequence<2, 1>, + Sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding< + Sequence<1>, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<0, 2>>, + Sequence<2, 1>, + Sequence<3, 1>>{}); + } } else { @@ -307,7 +591,7 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t N1 = GetTransposedVectorloadV(); + constexpr index_t N1 = GetVectorloadV(); constexpr index_t N0 = kNPerBlock / N1; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
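The hunk above and the one below both split the V tile's k-dimension into K0*K1*K2*K3 and now fall back to a second decomposition whenever K2*N0 lanes no longer fit inside one warp. A minimal, self-contained sketch of the fast-path arithmetic follows; every value in it is an illustrative assumption, not taken from any shipped tile configuration:

    // hedged sketch: assumed wave64, 4-wave block, made-up tile/pack widths
    constexpr int warp_size  = 64;
    constexpr int kBlockSize = 256;
    constexpr int kNPerBlock = 128, kKPerBlock = 32;
    constexpr int N1 = 8, kKPack = 4, K3 = 2;       // assumed vector-load / LDS-pack widths
    constexpr int N0 = kNPerBlock / N1;             // 16
    constexpr int K2 = kKPack / K3;                 // 2
    static_assert(warp_size % (K2 * N0) == 0);      // 64 % 32 == 0 -> fast path applies
    constexpr int K1 = warp_size / (K2 * N0);       // 2
    constexpr int K0 = kBlockSize / warp_size;      // 4
    static_assert(kKPerBlock == K0 * K1 * K2 * K3); // 4 * 2 * 2 * 2 == 32

When the divisibility check fails, the new else-branch instead takes K1 = (K2*N0)/warp_size and shrinks K2 to K2/K1, which is what the added code paths above and below implement.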
@@ -315,16 +599,33 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy constexpr index_t kKPack = GetSmemKPackV(); static_assert(kKPack % K3 == 0); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1, 2>>, - Tuple, Sequence<1, 0, 2>>, - Sequence<1, 2>, - Sequence<1, 3>>{}); + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1, 2>>, + Tuple, Sequence<1, 0, 2>>, + Sequence<1, 2>, + Sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<0, 2>>, + Sequence<1, 2>, + Sequence<1, 3>>{}); + } } template @@ -345,13 +646,13 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy 2>>; using BlockGemmPolicy = - BlockGemmARegBSmemCRegV1CustomPolicy; - return BlockGemmARegBSmemCRegV1{}; + return BlockGemmARegBSmemCRegV2{}; } template @@ -375,12 +676,12 @@ struct BlockFmhaPipelineQRKSVSDefaultPolicy Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}), true>; using BlockGemmPolicy = - BlockGemmARegBSmemCRegV1CustomPolicy; - return BlockGemmARegBSmemCRegV1{}; + return BlockGemmARegBSmemCRegV2{}; } }; diff --git a/include/ck/tile_program/tile/load_tile.hpp b/include/ck/tile_program/tile/load_tile.hpp index b2ea61afa..987619043 100644 --- a/include/ck/tile_program/tile/load_tile.hpp +++ b/include/ck/tile_program/tile/load_tile.hpp @@ -19,13 +19,49 @@ namespace tile_program { template + index_t NumCoord, + bool use_inline_asm = false> __device__ auto load_tile(const TileWindowWithStaticDistribution& tile_window, + bool_constant = {}) +{ + return tile_window.Load(bool_constant{}); +} + +template +__device__ void load_tile(StaticDistributedTensor& dst_tensor, + const TileWindowWithStaticDistribution& tile_window) { - return tile_window.Load(); + tile_window.LoadRaw(dst_tensor); +} + +template +__device__ auto async_load_tile(LdsTileWindow_&& lds_tile, + const TileWindowWithStaticDistribution& tile_window) +{ + return tile_window.AsyncLoad(lds_tile); +} + +__device__ auto async_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } } // namespace tile_program diff --git a/include/ck/tile_program/tile/slice_tile.hpp b/include/ck/tile_program/tile/slice_tile.hpp index e7999f26a..7de77db49 100644 --- a/include/ck/tile_program/tile/slice_tile.hpp +++ b/include/ck/tile_program/tile/slice_tile.hpp @@ -3,64 +3,5 @@ #pragma once -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" -#include "ck/tensor_description/tensor_space_filling_curve.hpp" - -#include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" -#include "ck/tile_program/tile/tile_window.hpp" -#include 
"ck/tile_program/tile/static_distributed_tensor.hpp" - -namespace ck { -namespace tile_program { - -template -__host__ __device__ constexpr auto get_slice_tile(const StaticDistributedTensor_& tile, - Sequence slice_begins, - Sequence slice_ends) -{ - using Distribution = decltype(StaticDistributedTensor_::GetTileDistribution()); - using DataType = typename StaticDistributedTensor_::DataType; - - constexpr auto sliced_dstr_yidx_ylen = - detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends); - - constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); - constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); - constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); - - auto sliced_tensor = make_static_distributed_tensor(sliced_dstr); - - sliced_tensor.GetThreadBuffer() = tile.GetYSlicedThreadData(sliced_y_origins, sliced_y_lengths); - - return sliced_tensor; -} - -template -__host__ __device__ constexpr auto set_slice_tile(DstStaticDistributedTensor_& dst_tile, - const SrcStaticDistributedTensor_& src_tile, - Sequence slice_begins, - Sequence slice_ends) -{ - using DstDistribution = decltype(DstStaticDistributedTensor_::GetTileDistribution()); - - constexpr auto sliced_dstr_yidx_ylen = - detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); - - constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); - constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); - constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); - - static_assert(is_same_v, "wrong!"); - - dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer()); -} - -} // namespace tile_program -} // namespace ck +#include "ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp" +#include "ck/tile_program/tile/slice_tile_impl_static_lengths.hpp" diff --git a/include/ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp b/include/ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp new file mode 100644 index 000000000..2b31b449b --- /dev/null +++ b/include/ck/tile_program/tile/slice_tile_impl_distributed_tensor.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" +#include "ck/tile_program/tile/tile_window.hpp" +#include "ck/tile_program/tile/static_distributed_tensor.hpp" + +namespace ck { +namespace tile_program { + +template +__host__ __device__ constexpr auto +get_slice_tile(const StaticDistributedTensor& tile, + Sequence slice_begins, + Sequence slice_ends) +{ + using DataType = remove_cvref_t; + using Distribution = remove_cvref_t; + + constexpr auto sliced_dstr_yidx_ylen = + detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends); + + constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); + constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); + constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); + + auto sliced_tensor = make_static_distributed_tensor(sliced_dstr); + + sliced_tensor.GetThreadBuffer() = tile.GetYSlicedThreadData(sliced_y_origins, sliced_y_lengths); + + return sliced_tensor; +} + +template +__host__ __device__ constexpr auto +set_slice_tile(StaticDistributedTensor& dst_tile, + const StaticDistributedTensor& src_tile, + Sequence slice_begins, + Sequence slice_ends) +{ + using DstDistribution = remove_cvref_t; + + constexpr auto sliced_dstr_yidx_ylen = + detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); + + constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>(); + constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>(); + constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>(); + + static_assert(is_same_v, "wrong!"); + + dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer()); +} + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/slice_tile_impl_static_lengths.hpp b/include/ck/tile_program/tile/slice_tile_impl_static_lengths.hpp new file mode 100644 index 000000000..de7aa8f03 --- /dev/null +++ b/include/ck/tile_program/tile/slice_tile_impl_static_lengths.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" +#include "ck/tile_program/tile/tile_window.hpp" +#include "ck/tile_program/tile/tile_window_impl_static_lengths.hpp" + +namespace ck { +namespace tile_program { + +template +__host__ __device__ constexpr auto +get_slice_tile(const TileWindowWithStaticLengths& tile, + Sequence slice_begins, + Sequence slice_ends) +{ + using TileWindow = TileWindowWithStaticLengths; + // NOTE: This API will override the origin of the tile window! 
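+    // i.e. no data is touched here: we simply build a new TileWindowWithStaticLengths
+    // over the same bottom tensor view, using (slice_ends - slice_begins) as the new
+    // window lengths and slice_begins as the new window origin.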
+ static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds)); + static_assert(sizeof...(SliceBegins) == TileWindow::GetNumOfDimension()); + + constexpr auto slice_lengths = slice_ends - slice_begins; + + return make_tile_window(tile.GetBottomTensorView(), + sequence_to_tuple_of_number(slice_lengths), + to_multi_index(slice_begins)); +} + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/static_distributed_tensor.hpp b/include/ck/tile_program/tile/static_distributed_tensor.hpp index 57b3c418c..2e2f34124 100644 --- a/include/ck/tile_program/tile/static_distributed_tensor.hpp +++ b/include/ck/tile_program/tile/static_distributed_tensor.hpp @@ -46,11 +46,15 @@ struct StaticDistributedTensor return StaticTileDistribution::GetDistributedSpans(); } - __host__ __device__ void Initialize(const DataType& x) { thread_buf_.Initialize(x); } + __host__ __device__ void Initialize(const DataType& x) { thread_buf_.arr.Initialize(x); } - __host__ __device__ constexpr const auto& GetThreadBuffer() const { return thread_buf_; } + __host__ __device__ constexpr const auto& GetThreadBuffer() const { return thread_buf_.arr; } - __host__ __device__ constexpr auto& GetThreadBuffer() { return thread_buf_; } + __host__ __device__ constexpr auto& GetThreadBuffer() { return thread_buf_.arr; } + + __host__ __device__ constexpr const auto& GetThreadBufferRaw() const { return thread_buf_.vec; } + + __host__ __device__ constexpr auto& GetThreadBufferRaw() { return thread_buf_.vec; } __host__ __device__ static constexpr index_t GetThreadBufferSize() { @@ -59,7 +63,7 @@ struct StaticDistributedTensor template __host__ __device__ auto GetYSlicedThreadData(Sequence, - Sequence) const + Sequence) const { static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY && sizeof...(YSliceLengths) == StaticTileDistribution::NDimY, @@ -78,7 +82,7 @@ struct StaticDistributedTensor constexpr auto idx_ys = idx + Sequence{}; sliced_thread_data(Number{}) = - thread_buf_[Number{}]; + thread_buf_.arr[Number{}]; }); return sliced_thread_data; @@ -100,7 +104,7 @@ struct StaticDistributedTensor static_ford>{}([&](auto idx) { constexpr auto idx_ys = idx + Sequence{}; - thread_buf_(Number{}) = + thread_buf_.arr(Number{}) = sliced_thread_data[Number{}]; }); } @@ -114,7 +118,7 @@ struct StaticDistributedTensor constexpr auto y_idx = GetTileDistribution().GetYIndicesFromDistributedIndices(TileDistributedIndices{}); - return thread_buf_[Number{}]; + return thread_buf_.arr[Number{}]; } template @@ -126,20 +130,20 @@ struct StaticDistributedTensor constexpr auto y_idx = GetTileDistribution().GetYIndicesFromDistributedIndices(TileDistributedIndices{}); - return thread_buf_(Number{}); + return thread_buf_.arr(Number{}); } #if 0 template __host__ __device__ auto GetElementFromYsIndex(Sequence idx_ys) const { - return thread_buf_[Number{}]; + return thread_buf_.arr[Number{}]; } template __host__ __device__ void SetElementFromYsIndex(Sequence idx_ys, const DataType& v) { - thread_buf_(Number{}) = v; + thread_buf_.arr(Number{}) = v; } template __host__ __device__ auto GetElementFromTileDistributedIndices(TileDistributedIndices) const @@ -166,7 +170,12 @@ struct StaticDistributedTensor #endif // - StaticBuffer thread_buf_; + union _U + { + StaticBuffer arr; + vector_type vec{}; + } thread_buf_; + static_assert(sizeof(thread_buf_) == sizeof(DataType) * kThreadElementSpaceSize); }; template diff --git a/include/ck/tile_program/tile/tile_fmha_shape.hpp b/include/ck/tile_program/tile/tile_fmha_shape.hpp index 
c577a3943..adbc96d1c 100644 --- a/include/ck/tile_program/tile/tile_fmha_shape.hpp +++ b/include/ck/tile_program/tile/tile_fmha_shape.hpp @@ -5,6 +5,8 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/math.hpp" namespace ck { namespace tile_program { @@ -14,7 +16,8 @@ template + index_t kBlockPerCu_ = 2, // hint to occupancy + typename VLayout_ = ck::tensor_layout::gemm::RowMajor> struct TileFmhaShape { using BlockTile = remove_cvref_t; @@ -23,6 +26,12 @@ struct TileFmhaShape using Gemm1BlockWarps = remove_cvref_t; using Gemm1WarpTile = remove_cvref_t; + static constexpr index_t NumWarps = + reduce_on_sequence(Gemm0BlockWarps{}, math::multiplies{}, Number<1>{}); + + static_assert(NumWarps == + reduce_on_sequence(Gemm1BlockWarps{}, math::multiplies{}, Number<1>{})); + static constexpr index_t kM0 = BlockTile::At(Number<0>{}); // tile size along q seqlen static constexpr index_t kN0 = BlockTile::At(Number<1>{}); // tile size along k seqlen static constexpr index_t kK0 = BlockTile::At(Number<2>{}); // tile size along qk gemm unroll @@ -32,6 +41,7 @@ struct TileFmhaShape BlockTile::At(Number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) + static constexpr index_t kBlockPerCu = kBlockPerCu_; using VLayout = remove_cvref_t; // rowmajor : seqlen*hdim, colmajor : hdim*seqlen }; diff --git a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp index f67fa6a28..aaedad879 100644 --- a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp +++ b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp @@ -261,7 +261,64 @@ struct TileWindowWithStaticDistribution get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); } - __device__ auto Load() const + __device__ auto MakeLoadBuffer() const + { + return make_static_distributed_tensor(TileDstr{}); + } + + template + __device__ void LoadRaw(T& buf) const + { + using Traits = LoadStoreTraits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = Number{}; + + // data index [y0, y1, ...] 
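+                // iAccess walks the space-filling curve over the Y dimensions; "d" below is
+                // the linearized offset of this access within the per-thread buffer, and the
+                // raw vectorized load writes directly into that slot of the thread buffer's
+                // vector view (GetThreadBufferRaw).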
+ constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess); + + constexpr auto idx_ys = + generate_array([&](auto jj) { return idx_ys_start[jj]; }, Number{}); + + constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys); + // if constexpr(iCoordAccess == 1) + // Number{}.foo(); + GetBottomTensorView().template GetVectorizedElementsRaw( + buf.GetThreadBufferRaw().template AsType()( + Number{}), + bottom_tensor_thread_coord); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(Array{0}, idx_diff_ys); + + MoveWindowAdaptorAndBottomTensorThreadCoordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + __device__ constexpr auto GetNumAccess() const { return LoadStoreTraits::NumAccess; } + + template + __device__ auto Load(bool_constant = {}) const { using Traits = LoadStoreTraits; @@ -288,7 +345,7 @@ struct TileWindowWithStaticDistribution // read from bottom tensor const vector_t vec_value = GetBottomTensorView().template GetVectorizedElements( - bottom_tensor_thread_coord); + bottom_tensor_thread_coord, bool_constant{}); const vector_type_t vec{vec_value}; @@ -324,6 +381,75 @@ struct TileWindowWithStaticDistribution return dst_tensor; } + template + __device__ auto AsyncLoad(LdsTileWindow_&& lds_tile) const + { + using LdsTileWindow = remove_cvref_t; + // using LdsTensorView = typename LdsTileWindow::BottomTensorView; + using LdsDataType = typename LdsTileWindow::DataType; + // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc; + + // issues * warps * lanes + static_assert(LdsTileWindow::GetNumOfDimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.GetBottomTensorView().GetTensorDescriptor().CalculateOffset( + make_tuple(Number<0>{}, Number<0>{}, Number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.GetBottomTensorView().GetTensorDescriptor().CalculateOffset( + make_tuple(Number<0>{}, Number<1>{}, Number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.GetBottomTensorView().GetTensorDescriptor().CalculateOffset( + make_tuple(Number<1>{}, Number<0>{}, Number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(m0_init_value); // This should be wave independent + + using Traits = LoadStoreTraits; + + using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename vector_type_t::type; + using SFC_Ys = typename Traits::SFC_Ys; + + LdsDataType* smem = lds_tile.GetBottomTensorView().GetBufferView().p_data_; + + // loop over thread tensor space [y0, y1, ...] 
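+        // Every access below issues one "buffer_load ... lds" (async copy): data goes from
+        // global memory straight into LDS at the address currently held in m0, with no VGPR
+        // staging. m0 was primed per-wave above (size_per_buf + per-wave offset) and is
+        // advanced by size_per_issue between issues.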
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + // TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = Number{}; + + // read from bottom tensor + GetBottomTensorView().template AsyncGetVectorizedElements( + smem, bottom_tensor_thread_coord); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(Array{0}, idx_diff_ys); + + MoveWindowAdaptorAndBottomTensorThreadCoordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + m0_inc_with_memory(size_per_issue); + } + }); + }); + } + __device__ void Store(const StaticDistributedTensor& dstr_tensor) const { using Traits = LoadStoreTraits; diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index acdb35fcc..22044fa78 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -49,6 +49,127 @@ __device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave) return wave_buffer_resource.content; } +// TODO: glc/slc/... +template +struct buffer_load; + +template <> +struct buffer_load<16> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 16); + asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<8> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 8); + asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<4> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 4); + asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<2> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 2); + asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" + : "+v"(value) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<1> +{ + template + __device__ void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 1); + asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" + : "+v"(value) + : "v"(v_offset), "s"(res), 
"s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template +__device__ void buffer_load_fence(T& target, index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); + auto& buf = target.GetThreadBuffer(); + constexpr index_t buf_size = buf.Size(); + static_for<0, buf_size, 1>{}([&buf](auto i) { asm volatile("" : "+v"(buf(i)) : : "memory"); }); + // using type = typename remove_cvref_t::type; + // asm volatile("" : "+X"(target.GetThreadBufferRaw().template AsType()(Number<0>{})) : : + // "memory"); + + // asm volatile("s_waitcnt vmcnt(%1)" : "+X"(target) : "n"(cnt) : "memory"); + + // asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); + // asm volatile("" : "=X"(target)); +} + +__device__ void buffer_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + // buffer load i8 __device__ int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, @@ -286,6 +407,24 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); +__device__ void async_buffer_load_fp32(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0) +{ + asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) + : "memory"); +} + +__device__ void async_buffer_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + // memory coherency bit for buffer store/load instruction // check ISA manual for each GFX target // e.g. for @@ -402,9 +541,14 @@ amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, } } +#ifndef BUFFER_LOAD_USE_INLINEASM +#define BUFFER_LOAD_USE_INLINEASM 0 +#endif + template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool use_inline_asm = false> __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) @@ -420,7 +564,15 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) // fp32 + if constexpr(use_inline_asm) + { + using type = typename vector_type::type; + type tmp; + buffer_load{}( + tmp, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return tmp; + } + else if constexpr(is_same::value) // fp32 { if constexpr(N == 1) { @@ -540,6 +692,58 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w } } +template +__device__ void amd_buffer_load_raw_impl(typename vector_type::type& dst, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! 
not implemented"); +#if BUFFER_LOAD_USE_INLINEASM + using type = typename vector_type::type; + buffer_load{}( + dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + (void)dst; + (void)src_wave_buffer_resource; + (void)src_thread_addr_offset; + (void)src_wave_addr_offset; +#endif +} + +template +__device__ void amd_async_buffer_load_impl(T* smem, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t src_immediate_addr_offset = 0) +{ + static_assert( + (is_same::value && (N == 1)) || (is_same::value && (N == 2)) || + (is_same::value && (N == 2)) || (is_same::value && (N == 1)) || + (is_same::value && (N == 4)), + "wrong! not implemented"); + if constexpr(sizeof(T) * N == 4) + { + async_buffer_load_fp32(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset); + } +} + template __device__ void amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, @@ -1031,7 +1235,8 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type::typ // It is user's responsibility to make sure that is true. template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool use_inline_asm = false> __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, @@ -1050,12 +1255,12 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - return amd_buffer_load_impl( + return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(0); #endif @@ -1067,7 +1272,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, // It is user's responsibility to make sure that is true. template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + bool use_inline_asm> __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, @@ -1085,12 +1291,55 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(customized_value); } +template +__device__ void amd_buffer_load_raw(typename vector_type_maker::type::type& dst, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + amd_buffer_load_raw_impl( + dst, src_wave_buffer_resource, src_thread_addr_offset, 0); +} + +// unfortunately async copy can not make sure invalid data is zero inside LDS +// ... 
unless people manually write zero to LDS at the proper address. +// so not support invalid_element check for now. +// buffer_load OOB still working. +template +__device__ void amd_async_buffer_load_with_oob(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); +} + // buffer_store requires: // 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 43baa817d..2a43e2b57 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -367,5 +367,17 @@ __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, f #endif } +// TODO: we have "memory" clobber here because this inline asm is used for async copy +__device__ void m0_set_with_memory(index_t v) +{ + asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory"); +} + +// NOTE: this is an immediate value +__device__ void m0_inc_with_memory(index_t v) +{ + asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory"); +} + } // namespace ck #endif diff --git a/include/ck/utility/buffer_view_impl_generic.hpp b/include/ck/utility/buffer_view_impl_generic.hpp index 78c7b8e9a..1b88bf5c4 100644 --- a/include/ck/utility/buffer_view_impl_generic.hpp +++ b/include/ck/utility/buffer_view_impl_generic.hpp @@ -60,10 +60,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __device__ constexpr auto + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; diff --git a/include/ck/utility/buffer_view_impl_global.hpp b/include/ck/utility/buffer_view_impl_global.hpp index f8d716ca5..621509407 100644 --- a/include/ck/utility/buffer_view_impl_global.hpp +++ b/include/ck/utility/buffer_view_impl_global.hpp @@ -63,10 +63,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __device__ constexpr auto + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -90,14 +92,16 @@ struct BufferView, t_per_x, - Coherence>( + Coherence, + use_inline_asm>( p_data_, i, is_valid_element, buffer_size_); } else { return amd_buffer_load_invalid_element_return_customized_value, t_per_x, - Coherence>( + Coherence, + use_inline_asm>( p_data_, i, is_valid_element, buffer_size_, invalid_element_value_); } } @@ -129,6 +133,46 @@ struct BufferView>::type, + typename scalar_type>::type>::value, + bool>::type = false> + __device__ constexpr auto GetRaw(remove_cvref_t& dst, index_t i) const + { + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_load_raw, t_per_x, Coherence>(dst, p_data_, i, buffer_size_); + } + + // i is offset of T, not X. i should be aligned to X + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __device__ constexpr auto + AsyncGet(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + { + // X is vector of T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_async_buffer_load_with_oob, t_per_x, Coherence>( + smem, p_data_, i, buffer_size_); + } + // i is offset of T, not X. i should be aligned to X template >::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __device__ constexpr auto + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; diff --git a/include/ck/utility/buffer_view_impl_vgpr.hpp b/include/ck/utility/buffer_view_impl_vgpr.hpp index 4c3e94884..15bdf1354 100644 --- a/include/ck/utility/buffer_view_impl_vgpr.hpp +++ b/include/ck/utility/buffer_view_impl_vgpr.hpp @@ -60,10 +60,12 @@ struct BufferView>::type, typename scalar_type>::type>::value, bool>::type = false> - __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __device__ constexpr auto + Get(index_t i, bool is_valid_element, bool_constant = {}) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index dc1b59ac9..0bc50cda7 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -178,7 +178,7 @@ struct vector_type StaticallyIndexedArray d1x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -214,7 +214,7 @@ struct vector_type StaticallyIndexedArray d2x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -266,7 +266,7 @@ struct vector_type StaticallyIndexedArray d4x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -330,7 +330,7 @@ struct vector_type StaticallyIndexedArray d8x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -406,7 +406,7 @@ struct vector_type StaticallyIndexedArray d16x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -494,7 +494,7 @@ struct vector_type StaticallyIndexedArray d32x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ 
__device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -592,7 +592,7 @@ struct vector_type StaticallyIndexedArray d64x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -702,7 +702,7 @@ struct vector_type StaticallyIndexedArray d128x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -822,7 +822,7 @@ struct vector_type StaticallyIndexedArray d256x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __host__ __device__ constexpr vector_type() : data_{} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 04c07bc4c..9e8d3771b 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -50,4 +50,7 @@ __host__ __device__ constexpr auto operator%(integral_constant, integral_ return integral_constant{}; } +template +using bool_constant = integral_constant; + } // namespace ck diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 0ccebd476..85d69d1e1 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -73,6 +73,28 @@ struct StaticBuffer : public StaticallyIndexedArray, N> return vx.template AsType().template At<0>(); } + // Get a vector (type X) + // "is" is offset of S, not X. + // "is" should be aligned to X + template ::value, bool>::type = false> + __host__ __device__ constexpr remove_reference_t& GetAsTypeRaw(Number is) + { + using X = remove_cvref_t; + + constexpr index_t kSPerX = scalar_type::vector_size; + static_assert(Is % kSPerX == 0, "wrong! \"Is\" should be aligned to X"); + + using new_type = + StaticBuffer; + static_assert(sizeof(new_type) == sizeof(*this)); + + // auto & new_this = __builtin_bit_cast(new_type, *this); + + return __builtin_bit_cast(new_type, *this).operator()(Number{}); + } + // Set a vector (type X) // "is" is offset of S, not X. 
// "is" should be aligned to X From f90c80a929bd88001bd1bbe773daf5e87479a35a Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 1 Dec 2023 02:07:26 -0600 Subject: [PATCH 02/45] modify stream config --- example/91_tile_program/fmha_fwd.cpp | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index f7b394ebc..339356419 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -109,7 +109,8 @@ float invoker_fmha_kernel(const void* q_ptr, ck::index_t hdim_v, float scale, bool i_perm, - bool o_perm) + bool o_perm, + StreamConfig stream_config) { dim3 kGridSize = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v); constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); @@ -153,7 +154,7 @@ float invoker_fmha_kernel(const void* q_ptr, nhead * hdim_v * seqlen_k, // batch_stride_v nhead * seqlen_q * hdim_v); // batch_stride_o - float ave_time = launch_kernel(StreamConfig{nullptr, true}, + float ave_time = launch_kernel(stream_config, FmhaKernel{}, kGridSize, kBlockSize, @@ -162,6 +163,15 @@ float invoker_fmha_kernel(const void* q_ptr, return ave_time; } +static inline int env_get_int(const char* var_name, int default_int) +{ + char* v = getenv(var_name); + int r = default_int; + if(v) + r = atoi(v); + return r; +} + int main(int argc, char* argv[]) { int do_validation = 1; @@ -199,6 +209,11 @@ int main(int argc, char* argv[]) if(scale == .0f) scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + int stream_warmup = env_get_int("CK_WARMUP", 5); + int stream_repeat = env_get_int("CK_REPEAT", 20); + + StreamConfig stream_config {nullptr, true, 0, stream_warmup, stream_repeat}; + auto get_lengths = [&](bool permute, ck::index_t b /*batch*/, ck::index_t h /*nhead*/, @@ -258,7 +273,8 @@ int main(int argc, char* argv[]) hdim_v, scale, i_perm, - o_perm); + o_perm, + stream_config); else if(hdim_q == hdim_v && hdim_q == 128) ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), @@ -272,7 +288,8 @@ int main(int argc, char* argv[]) hdim_v, scale, i_perm, - o_perm); + o_perm, + stream_config); else { std::cout << "not support hdim, will not run" << std::endl; From 94e772308037df06988563103b5beb02885f4304 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 1 Dec 2023 04:14:21 -0600 Subject: [PATCH 03/45] mofidy some internal API --- .../block_fmha_pipeline_qr_ks_vs.hpp | 3 +- include/ck/tile_program/tile/load_tile.hpp | 24 ++++----- .../tile/static_distributed_tensor.hpp | 29 ++++------- .../tile_window_impl_static_distribution.hpp | 49 ------------------- include/ck/utility/amd_buffer_addressing.hpp | 39 +++++---------- 5 files changed, 32 insertions(+), 112 deletions(-) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 752b0cce1..4102d30ba 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -137,8 +137,7 @@ struct BlockFmhaPipelineQRKSVS q_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeQDramTileDistribution()); - auto q = - load_tile(q_dram_window, integral_constant{}); // persistent q register tile + auto q = load_tile_raw(q_dram_window); // persistent q register tile __builtin_amdgcn_sched_barrier(0); diff --git 
a/include/ck/tile_program/tile/load_tile.hpp b/include/ck/tile_program/tile/load_tile.hpp index 987619043..431fbd18c 100644 --- a/include/ck/tile_program/tile/load_tile.hpp +++ b/include/ck/tile_program/tile/load_tile.hpp @@ -19,30 +19,26 @@ namespace tile_program { template + index_t NumCoord> __device__ auto load_tile(const TileWindowWithStaticDistribution& tile_window, - bool_constant = {}) + NumCoord>& tile_window) { - return tile_window.Load(bool_constant{}); + return tile_window.Load(); } -template -__device__ void load_tile(StaticDistributedTensor& dst_tensor, - const TileWindowWithStaticDistribution& tile_window) +__device__ auto load_tile_raw(const TileWindowWithStaticDistribution& tile_window) { - tile_window.LoadRaw(dst_tensor); + return tile_window.Load(bool_constant{}); } template {}; sliced_thread_data(Number{}) = - thread_buf_.arr[Number{}]; + thread_buf_[Number{}]; }); return sliced_thread_data; @@ -104,7 +100,7 @@ struct StaticDistributedTensor static_ford>{}([&](auto idx) { constexpr auto idx_ys = idx + Sequence{}; - thread_buf_.arr(Number{}) = + thread_buf_(Number{}) = sliced_thread_data[Number{}]; }); } @@ -118,7 +114,7 @@ struct StaticDistributedTensor constexpr auto y_idx = GetTileDistribution().GetYIndicesFromDistributedIndices(TileDistributedIndices{}); - return thread_buf_.arr[Number{}]; + return thread_buf_[Number{}]; } template @@ -130,20 +126,20 @@ struct StaticDistributedTensor constexpr auto y_idx = GetTileDistribution().GetYIndicesFromDistributedIndices(TileDistributedIndices{}); - return thread_buf_.arr(Number{}); + return thread_buf_(Number{}); } #if 0 template __host__ __device__ auto GetElementFromYsIndex(Sequence idx_ys) const { - return thread_buf_.arr[Number{}]; + return thread_buf_[Number{}]; } template __host__ __device__ void SetElementFromYsIndex(Sequence idx_ys, const DataType& v) { - thread_buf_.arr(Number{}) = v; + thread_buf_(Number{}) = v; } template __host__ __device__ auto GetElementFromTileDistributedIndices(TileDistributedIndices) const @@ -170,12 +166,7 @@ struct StaticDistributedTensor #endif // - union _U - { - StaticBuffer arr; - vector_type vec{}; - } thread_buf_; - static_assert(sizeof(thread_buf_) == sizeof(DataType) * kThreadElementSpaceSize); + StaticBuffer thread_buf_; }; template diff --git a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp index aaedad879..6a778a207 100644 --- a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp +++ b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp @@ -266,55 +266,6 @@ struct TileWindowWithStaticDistribution return make_static_distributed_tensor(TileDstr{}); } - template - __device__ void LoadRaw(T& buf) const - { - using Traits = LoadStoreTraits; - - using vector_type_t = typename Traits::vector_type_t; - using vector_t = typename vector_type_t::type; - using SFC_Ys = typename Traits::SFC_Ys; - - constexpr auto tile_dstr = TileDstr{}; - - // loop over thread tensor space [y0, y1, ...] - static_for<0, NumCoord, 1>{}([&](auto iCoord) { - /// TODO: use structure binding (to be captured later) if compiled in C++20 - auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; - auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; - - static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = Number{}; - - // data index [y0, y1, ...] 
- constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess); - - constexpr auto idx_ys = - generate_array([&](auto jj) { return idx_ys_start[jj]; }, Number{}); - - constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys); - // if constexpr(iCoordAccess == 1) - // Number{}.foo(); - GetBottomTensorView().template GetVectorizedElementsRaw( - buf.GetThreadBufferRaw().template AsType()( - Number{}), - bottom_tensor_thread_coord); - - // move thread coordinate - if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) - { - constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess); - - constexpr auto idx_diff_ps_ys = - container_concat(Array{0}, idx_diff_ys); - - MoveWindowAdaptorAndBottomTensorThreadCoordinate( - window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - } - }); - }); - } - __device__ constexpr auto GetNumAccess() const { return LoadStoreTraits::NumAccess; } template diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 22044fa78..3e2c01455 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -148,23 +148,6 @@ struct buffer_load<1> } }; -template -__device__ void buffer_load_fence(T& target, index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); - auto& buf = target.GetThreadBuffer(); - constexpr index_t buf_size = buf.Size(); - static_for<0, buf_size, 1>{}([&buf](auto i) { asm volatile("" : "+v"(buf(i)) : : "memory"); }); - // using type = typename remove_cvref_t::type; - // asm volatile("" : "+X"(target.GetThreadBufferRaw().template AsType()(Number<0>{})) : : - // "memory"); - - // asm volatile("s_waitcnt vmcnt(%1)" : "+X"(target) : "n"(cnt) : "memory"); - - // asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); - // asm volatile("" : "=X"(target)); -} - __device__ void buffer_load_fence(index_t cnt = 0) { asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); @@ -407,12 +390,12 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -__device__ void async_buffer_load_fp32(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0) +__device__ void async_buffer_load_dword(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0) { asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" : "=r"(smem) /*dummy dependency for smem*/ @@ -736,11 +719,11 @@ __device__ void amd_async_buffer_load_impl(T* smem, "wrong! 
not implemented"); if constexpr(sizeof(T) * N == 4) { - async_buffer_load_fp32(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset); + async_buffer_load_dword(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset); } } From 3ab658e240f0f9661dbccd679e08dbc8b1dbca63 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 1 Dec 2023 04:37:46 -0600 Subject: [PATCH 04/45] remove some useleff code --- .../tile_window_impl_static_distribution.hpp | 5 ----- include/ck/utility/data_type.hpp | 18 +++++++-------- include/ck/utility/static_buffer.hpp | 22 ------------------- 3 files changed, 9 insertions(+), 36 deletions(-) diff --git a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp index 6a778a207..523877dc0 100644 --- a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp +++ b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp @@ -261,11 +261,6 @@ struct TileWindowWithStaticDistribution get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); } - __device__ auto MakeLoadBuffer() const - { - return make_static_distributed_tensor(TileDstr{}); - } - __device__ constexpr auto GetNumAccess() const { return LoadStoreTraits::NumAccess; } template diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 0bc50cda7..dc1b59ac9 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -178,7 +178,7 @@ struct vector_type StaticallyIndexedArray d1x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -214,7 +214,7 @@ struct vector_type StaticallyIndexedArray d2x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -266,7 +266,7 @@ struct vector_type StaticallyIndexedArray d4x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -330,7 +330,7 @@ struct vector_type StaticallyIndexedArray d8x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -406,7 +406,7 @@ struct vector_type StaticallyIndexedArray d16x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -494,7 +494,7 @@ struct vector_type StaticallyIndexedArray d32x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -592,7 +592,7 @@ struct vector_type StaticallyIndexedArray d64x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -702,7 +702,7 @@ struct vector_type StaticallyIndexedArray d128x1_; } data_; - 
__host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} @@ -822,7 +822,7 @@ struct vector_type StaticallyIndexedArray d256x1_; } data_; - __host__ __device__ constexpr vector_type() : data_{} {} + __host__ __device__ constexpr vector_type() : data_{type{0}} {} __host__ __device__ constexpr vector_type(type v) : data_{v} {} diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 85d69d1e1..0ccebd476 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -73,28 +73,6 @@ struct StaticBuffer : public StaticallyIndexedArray, N> return vx.template AsType().template At<0>(); } - // Get a vector (type X) - // "is" is offset of S, not X. - // "is" should be aligned to X - template ::value, bool>::type = false> - __host__ __device__ constexpr remove_reference_t& GetAsTypeRaw(Number is) - { - using X = remove_cvref_t; - - constexpr index_t kSPerX = scalar_type::vector_size; - static_assert(Is % kSPerX == 0, "wrong! \"Is\" should be aligned to X"); - - using new_type = - StaticBuffer; - static_assert(sizeof(new_type) == sizeof(*this)); - - // auto & new_this = __builtin_bit_cast(new_type, *this); - - return __builtin_bit_cast(new_type, *this).operator()(Number{}); - } - // Set a vector (type X) // "is" is offset of S, not X. // "is" should be aligned to X From 3380eeb819d601372b3eb45920851136552ddea2 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 4 Dec 2023 02:38:23 -0600 Subject: [PATCH 05/45] support MQA/GQA --- example/91_tile_program/arg_parser.hpp | 171 ++++++++++++++++++ example/91_tile_program/fmha_fwd.cpp | 183 +++++++++++++------- example/91_tile_program/fmha_fwd_kernel.hpp | 12 +- library/include/ck/library/utility/fill.hpp | 41 +++++ 4 files changed, 340 insertions(+), 67 deletions(-) create mode 100644 example/91_tile_program/arg_parser.hpp diff --git a/example/91_tile_program/arg_parser.hpp b/example/91_tile_program/arg_parser.hpp new file mode 100644 index 000000000..501b92788 --- /dev/null +++ b/example/91_tile_program/arg_parser.hpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +/* + * arg parser for + * -[key0]=[value0] -[key1]=[value1] ... 
+ */ +class ArgParser +{ + public: + class Arg + { + public: + std::string name; + std::string value; + std::string help_text; + }; + + ArgParser() {} + ArgParser& insert(const std::string& _name, + const std::string& _default_value, + const std::string& _help_text) + { + Arg in; + in.name = _name; + in.value = _default_value; + in.help_text = _help_text; + + if(input_map.count(_name) != 0) + { + printf("arg:%s already exist\n", _name.c_str()); + } + else + { + input_map[_name] = in; + keys.push_back(_name); + } + return *this; + } + void print() + { + printf("args:\n"); + for(auto& key : keys) + { + auto value = input_map[key]; + std::vector help_text_lines; + size_t pos = 0; + for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) + { + help_text_lines.push_back(std::string(value.help_text.begin() + pos, + value.help_text.begin() + next_pos++)); + pos = next_pos; + next_pos = value.help_text.find('\n', pos); + } + help_text_lines.push_back( + std::string(value.help_text.begin() + pos, value.help_text.end())); + + std::string default_value = std::string("(default:") + value.value + std::string(")"); + + std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key + << std::setw(4) << " " << help_text_lines[0] << " " << default_value + << std::endl; + + for(auto help_next_line = std::next(help_text_lines.begin()); + help_next_line != help_text_lines.end(); + ++help_next_line) + { + std::cout << std::setw(17) << " " << *help_next_line << std::endl; + } + } + } + bool parse(int argc, char* argv[], int start_index = 1) + { + if(argc < start_index) + { + printf("not enough args\n"); + return false; + } + for(int i = start_index; i < argc; i++) + { + char* cur_arg = argv[i]; + if(cur_arg[0] != '-') + { + printf("illegal input\n"); + print(); + return false; + } + else + { + std::string text(cur_arg + 1); + if(text == "?") + { + print(); + return false; + } + auto pos = text.find('='); + if(pos == std::string::npos) + { + printf("arg should be [key]=[value] pair, here:%s\n", text.c_str()); + return false; + } + if(pos >= (text.size() - 1)) + { + printf("cant find value after \"=\", here:%s\n", text.c_str()); + return false; + } + auto key = text.substr(0, pos); + auto value = text.substr(pos + 1); + if(input_map.count(key) == 0) + { + printf("no such arg:%s\n", key.c_str()); + return false; + } + input_map[key].value = value; + } + } + return true; + } + + std::string get_str(const std::string& name) const + { + std::string value = input_map.at(name).value; + return value; + } + + int get_int(const std::string& name) const + { + int value = atoi(input_map.at(name).value.c_str()); + return value; + } + + uint32_t get_uint32(const std::string& name) const + { + uint32_t value = strtoul(input_map.at(name).value.c_str(), nullptr, 10); + return value; + } + + uint64_t get_uint64(const std::string& name) const + { + uint64_t value = strtoull(input_map.at(name).value.c_str(), nullptr, 10); + return value; + } + + float get_float(const std::string& name) const + { + double value = atof(input_map.at(name).value.c_str()); + return static_cast(value); + } + + double get_double(const std::string& name) const + { + double value = atof(input_map.at(name).value.c_str()); + return value; + } + + private: + std::unordered_map input_map; + std::vector keys; +}; diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index 339356419..159ea27e4 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ 
b/example/91_tile_program/fmha_fwd.cpp
@@ -26,6 +26,8 @@
 #include "fmha_fwd_kernel.hpp"
 #include "fmha_fwd_tile_partitioner.hpp"
 #include "fmha_fwd_epilogue.hpp"
+#include "arg_parser.hpp"
+#include
 using QDataType = ck::half_t;
@@ -103,6 +105,7 @@ float invoker_fmha_kernel(const void* q_ptr,
                           void* o_ptr,
                           ck::index_t batch,
                           ck::index_t nhead,
+                          ck::index_t nhead_k,
                           ck::index_t seqlen_q,
                           ck::index_t seqlen_k,
                           ck::index_t hdim_q,
@@ -120,6 +123,7 @@ float invoker_fmha_kernel(const void* q_ptr,
     constexpr bool is_v_rowmajor = ck::is_same_v;
+    assert(nhead % nhead_k == 0);
     // batch * nhead * seqlen * hdim or batch * seqlen * nhead * hdim
     auto kargs = FmhaKernel::MakeKargs(
         q_ptr,
@@ -130,14 +134,15 @@ float invoker_fmha_kernel(const void* q_ptr,
         seqlen_k, // seqlen_k
         hdim_q,   // hdim_q
         hdim_v,   // hdim_v
+        nhead / nhead_k,
         scale,
-        i_perm ? hdim_q : nhead * hdim_q,   // stride_q
-        i_perm ? hdim_q : nhead * hdim_q,   // stride_k
+        i_perm ? hdim_q : nhead * hdim_q,   // stride_q
+        i_perm ? hdim_q : nhead_k * hdim_q, // stride_k
         [&]() {
             if constexpr(is_v_rowmajor)
-                return i_perm ? hdim_v : nhead * hdim_v;
+                return i_perm ? hdim_v : nhead_k * hdim_v;
             else
-                return i_perm ? seqlen_k : nhead * seqlen_k;
+                return i_perm ? seqlen_k : nhead_k * seqlen_k;
         }(),                                  // stride_v
         o_perm ? hdim_v : nhead * hdim_v,     // stride_o
         i_perm ? seqlen_q * hdim_q : hdim_q,  // nhead_stride_q
@@ -150,8 +155,8 @@ float invoker_fmha_kernel(const void* q_ptr,
         }(),                                  // nhead_stride_v
         o_perm ? seqlen_q * hdim_v : hdim_v,  // nhead_stride_o
         nhead * seqlen_q * hdim_q,            // batch_stride_q
-        nhead * seqlen_k * hdim_q,            // batch_stride_k
-        nhead * hdim_v * seqlen_k,            // batch_stride_v
+        nhead_k * seqlen_k * hdim_q,          // batch_stride_k
+        nhead_k * hdim_v * seqlen_k,          // batch_stride_v
         nhead * seqlen_q * hdim_v);           // batch_stride_o
     float ave_time = launch_kernel(stream_config,
@@ -172,49 +177,75 @@ static inline int env_get_int(const char* var_name, int default_int)
     return r;
 }
-int main(int argc, char* argv[])
+auto create_args(int argc, char* argv[])
 {
-    int do_validation    = 1;
-    ck::index_t batch    = 2;
-    ck::index_t nhead    = 8;
-    ck::index_t seqlen_q = 3328;
-    ck::index_t seqlen_k = 4096;
-    ck::index_t hdim_q   = 128;
-    ck::index_t hdim_v   = 128;
-
-    float scale = .0f;
+    ArgParser arg_parser;
+    arg_parser.insert("v", "1", "whether to do cpu validation or not")
+        .insert("b", "2", "batch size")
+        .insert("h", "8", "num of heads, for q")
+        .insert("h_k",
+                "0",
+                "num of heads, for k/v, 0 means equal to h\n"
+                "if not equal to h, then this is the GQA/MQA case")
+        .insert("s", "3328", "seqlen_q")
+        .insert("s_k", "0", "seqlen_k, 0 means equal to s")
+        .insert("d", "128", "head dim for q, k")
+        .insert("d_v", "0", "head dim for v, 0 means equal to d")
+        .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim_q)")
+        .insert("iperm",
+                "1",
+                "permute input\n"
+                "if true, will be b*h*s*d, else b*s*h*d")
+        .insert("operm", "1", "permute output")
+        .insert("init", "1", "init method. 
0:random int, 1:random float, 2:trig float"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} - bool i_perm = true; // if true, will be batch * nhead * seqlen * hdim - bool o_perm = true; // if false, will be batch * seqlen * nhead * hdim +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; - if(argc >= 2) - do_validation = std::stoi(argv[1]); + int do_validation = arg_parser.get_int("v"); + ck::index_t batch = arg_parser.get_int("b"); + ck::index_t nhead = arg_parser.get_int("h"); + ck::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k == 0) + nhead_k = nhead; - if(argc >= 8) + if(nhead % nhead_k != 0) { - batch = std::stoi(argv[2]); - nhead = std::stoi(argv[3]); - seqlen_q = std::stoi(argv[4]); - seqlen_k = std::stoi(argv[5]); - hdim_q = std::stoi(argv[6]); - hdim_v = std::stoi(argv[7]); + std::cout << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return -1; } - if(argc >= 9) - scale = std::stof(argv[8]); - if(argc >= 10) - i_perm = static_cast(std::stoi(argv[9])); - if(argc >= 11) - o_perm = static_cast(std::stoi(argv[10])); + ck::index_t seqlen_q = arg_parser.get_int("s"); + ck::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k == 0) + seqlen_k = seqlen_q; + ck::index_t hdim_q = arg_parser.get_int("d"); + ck::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v == 0) + hdim_v = hdim_q; + + int i_perm = arg_parser.get_int("iperm"); // if true, will be batch * nhead * seqlen * hdim + int o_perm = arg_parser.get_int("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); if(scale == .0f) scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + int init_method = arg_parser.get_int("init"); + int stream_warmup = env_get_int("CK_WARMUP", 5); int stream_repeat = env_get_int("CK_REPEAT", 20); - StreamConfig stream_config {nullptr, true, 0, stream_warmup, stream_repeat}; + StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat}; - auto get_lengths = [&](bool permute, + auto get_lengths = [&](int permute, ck::index_t b /*batch*/, ck::index_t h /*nhead*/, ck::index_t s /*seqlen*/, @@ -230,20 +261,29 @@ int main(int argc, char* argv[]) // host verify Tensor q_host(get_lengths(i_perm, batch, nhead, seqlen_q, hdim_q)); - Tensor k_host(get_lengths(i_perm, batch, nhead, seqlen_k, hdim_q)); - Tensor v_host(is_v_rowmajor ? get_lengths(i_perm, batch, nhead, seqlen_k, hdim_v) - : get_lengths(i_perm, batch, nhead, hdim_v, seqlen_k)); + Tensor k_host(get_lengths(i_perm, batch, nhead_k, seqlen_k, hdim_q)); + Tensor v_host(is_v_rowmajor ? 
get_lengths(i_perm, batch, nhead_k, seqlen_k, hdim_v)
+                                    : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_k));
     Tensor o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));
-#if 0
-    ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host);
-    ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host);
-    ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host);
-#else
-    ck::utils::FillUniformDistribution{0.f, 1.f}(q_host);
-    ck::utils::FillUniformDistribution{0.f, 1.f}(k_host);
-    ck::utils::FillUniformDistribution{-.5f, .5f}(v_host);
-#endif
+    if(init_method == 0)
+    {
+        ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host);
+        ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host);
+        ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host);
+    }
+    else if(init_method == 1)
+    {
+        ck::utils::FillUniformDistribution{0.f, 1.f}(q_host);
+        ck::utils::FillUniformDistribution{0.f, 1.f}(k_host);
+        ck::utils::FillUniformDistribution{-.5f, .5f}(v_host);
+    }
+    else if(init_method == 2)
+    {
+        ck::utils::FillTrigValue{}(q_host);
+        ck::utils::FillTrigValue{}(k_host);
+        ck::utils::FillTrigValue{}(v_host);
+    }
     DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
     DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
@@ -254,10 +294,17 @@ int main(int argc, char* argv[])
     k_buf.ToDevice(k_host.mData.data());
     v_buf.ToDevice(v_host.mData.data());
-    std::cout << "batch:" << batch << ", nhead:" << nhead << ", seqlen_q:" << seqlen_q
-              << ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v
-              << ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm
-              << ", v:" << std::string(FmhaKernelHDim64::VLayout::name) << std::flush << std::endl;
+    // clang-format off
+    auto layout_str = [&](int permute){
+        if (permute) return std::string("bhsd");
+        else return std::string("bshd");
+    };
+    // clang-format on
+
+    std::cout << "b:" << batch << ", h:" << nhead << ", h_k:" << nhead_k << ", s:" << seqlen_q
+              << ", s_k:" << seqlen_k << ", d:" << hdim_q << ", d_v:" << hdim_v
+              << ", scale:" << scale << ", i:" << layout_str(i_perm) << ", o:" << layout_str(o_perm)
+              << ", v:" << std::string(FmhaKernelHDim64::VLayout::name)[0] << std::flush;
     float ave_time = 0;
     if(hdim_q == hdim_v && hdim_q == 64)
@@ -267,6 +314,7 @@
                                                    o_buf.GetDeviceBuffer(),
                                                    batch,
                                                    nhead,
+                                                   nhead_k,
                                                    seqlen_q,
                                                    seqlen_k,
                                                    hdim_q,
@@ -282,6 +330,7 @@
                                                    o_buf.GetDeviceBuffer(),
                                                    batch,
                                                    nhead,
+                                                   nhead_k,
                                                    seqlen_q,
                                                    seqlen_k,
                                                    hdim_q,
@@ -299,6 +348,7 @@
     std::size_t flop = std::size_t(2) * batch * nhead * seqlen_q * seqlen_k * hdim_q +
                        std::size_t(2) * batch * nhead * seqlen_q * hdim_v * seqlen_k;
+    // TODO: in the MQA/GQA case nhead_k is smaller than nhead; do we need to change this formula?
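    // One possible adjustment (a sketch only, not the settled choice): only the K/V terms would
    // shrink, since k_host/v_host above are allocated with nhead_k heads, e.g.
    //     sizeof(KDataType) * batch * nhead_k * seqlen_k * hdim_q +
    //     sizeof(VDataType) * batch * nhead_k * hdim_v * seqlen_k
    // while the Q and O terms keep nhead.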
std::size_t num_btype = sizeof(QDataType) * batch * nhead * seqlen_q * hdim_q + sizeof(KDataType) * batch * nhead * seqlen_k * hdim_q + sizeof(VDataType) * batch * nhead * hdim_v * seqlen_k + @@ -308,13 +358,14 @@ int main(int argc, char* argv[]) float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + std::cout << ", " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::flush << std::endl; if(do_validation) { Tensor q_host_ref({batch * nhead, seqlen_q, hdim_q}); - Tensor k_host_ref({batch * nhead, seqlen_k, hdim_q}); + Tensor k_host_ref( + {batch * nhead, seqlen_k, hdim_q}); // NOTE: expand nhead the same as q const auto v_lengths = std::array{batch * nhead, hdim_v, seqlen_k}; const auto v_strides = is_v_rowmajor ? std::array{hdim_v * seqlen_k, 1, hdim_v} @@ -326,24 +377,28 @@ int main(int argc, char* argv[]) Tensor s_host_ref({batch * nhead, seqlen_q, seqlen_k}); Tensor p_host_ref({batch * nhead, seqlen_q, seqlen_k}); + ck::index_t nr = nhead / nhead_k; + +#define EACH_R for(ck::index_t r = 0; r < nr; r++) // clang-format off // permute - if(i_perm) q_host.ForEach([&](auto& self, auto idx) { q_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); }); - else q_host.ForEach([&](auto& self, auto idx) { q_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); }); + if(i_perm) q_host.ForEach([&](auto& self, auto i) { q_host_ref(i[0] * nhead + i[1], i[2], i[3]) = self(i); }); + else q_host.ForEach([&](auto& self, auto i) { q_host_ref(i[0] * nhead + i[2], i[1], i[3]) = self(i); }); - if(i_perm) k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); }); - else k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); }); + if(i_perm) k_host.ForEach([&](auto& self, auto i) { EACH_R k_host_ref(i[0] * nhead + i[1] * nr + r, i[2], i[3]) = self(i); }); + else k_host.ForEach([&](auto& self, auto i) { EACH_R k_host_ref(i[0] * nhead + i[2] * nr + r, i[1], i[3]) = self(i); }); if constexpr (is_v_rowmajor) { // v_host :b, h, s, d, v_host_ref : batch*hdim*seq - if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[3], idx[2]) = self(idx); }); + if(i_perm) v_host.ForEach([&](auto& self, auto i) { EACH_R v_host_ref(i[0] * nhead + i[1] * nr + r, i[3], i[2]) = self(i); }); // v_host : b, s, h, d, v_host_ref : batch*hdim*seq - else v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[3], idx[1]) = self(idx); }); + else v_host.ForEach([&](auto& self, auto i) { EACH_R v_host_ref(i[0] * nhead + i[2] * nr + r, i[3], i[1]) = self(i); }); } else { - if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); }); - else v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); }); + if(i_perm) v_host.ForEach([&](auto& self, auto i) { EACH_R v_host_ref(i[0] * nhead + i[1] * nr + r, i[2], i[3]) = self(i); }); + else v_host.ForEach([&](auto& self, auto i) { EACH_R v_host_ref(i[0] * nhead + i[2] * nr + r, i[1], i[3]) = self(i); }); } +#undef EACH_R // reference reference_batched_gemm( @@ -357,8 +412,8 @@ int main(int argc, char* argv[]) p_host_ref, v_host_ref, o_host_ref); // permute - if(o_perm) o_host_result_ref.ForEach([&](auto& self, auto idx) { self(idx) = o_host_ref(idx[0] * nhead + 
idx[1], idx[2], idx[3]); }); - else o_host_result_ref.ForEach([&](auto& self, auto idx) { self(idx) = o_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]); }); + if(o_perm) o_host_result_ref.ForEach([&](auto& self, auto i) { self(i) = o_host_ref(i[0] * nhead + i[1], i[2], i[3]); }); + else o_host_result_ref.ForEach([&](auto& self, auto i) { self(i) = o_host_ref(i[0] * nhead + i[2], i[1], i[3]); }); // clang-format on o_buf.FromDevice(o_host.mData.data()); diff --git a/example/91_tile_program/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha_fwd_kernel.hpp index a8cd3fe61..01d8c5696 100644 --- a/example/91_tile_program/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha_fwd_kernel.hpp @@ -40,6 +40,9 @@ struct FmhaFwdKernel ck::index_t hdim_q; ck::index_t hdim_v; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck::index_t nhead_radio_qk; float scale; ck::index_t stride_q; @@ -66,6 +69,7 @@ struct FmhaFwdKernel ck::index_t seqlen_k, ck::index_t hdim_q, ck::index_t hdim_v, + ck::index_t nhead_radio_qk, float scale, ck::index_t stride_q, ck::index_t stride_k, @@ -82,7 +86,7 @@ struct FmhaFwdKernel { return Kargs { - q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, seqlen_k, hdim_q, hdim_v, + q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, seqlen_k, hdim_q, hdim_v, nhead_radio_qk, #if CK_FMHA_FWD_FAST_EXP2 static_cast(scale * C_LOG2E), #else @@ -129,9 +133,11 @@ struct FmhaFwdKernel const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + i_nhead * kargs.nhead_stride_q + i_batch * kargs.batch_stride_q; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - i_nhead * kargs.nhead_stride_k + i_batch * kargs.batch_stride_k; + (i_nhead / kargs.nhead_radio_qk) * kargs.nhead_stride_k + + i_batch * kargs.batch_stride_k; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - i_nhead * kargs.nhead_stride_v + i_batch * kargs.batch_stride_v; + (i_nhead / kargs.nhead_radio_qk) * kargs.nhead_stride_v + + i_batch * kargs.batch_stride_v; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o; diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index 4e075df43..fef974810 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -133,5 +133,46 @@ struct FillConstant } }; +template +struct FillTrigValue +{ + template + struct LinearTrigGen + { + int i{0}; + auto operator()() + { + float v = 0; + if constexpr(UseCos_) + { + v = cos(i); + } + else + { + v = sin(i); + } + if constexpr(UseAbs_) + v = abs(v); + i++; + return static_cast(v); + } + }; + template + void operator()(ForwardIter first, ForwardIter last) const + { + LinearTrigGen gen; + std::generate(first, last, gen); + } + + template + auto operator()(ForwardRange&& range) const -> std::void_t< + decltype(std::declval()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + } // namespace utils } // namespace ck From f810f415c2d05f5773b772084424740bdd374d81 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 5 Dec 2023 21:44:23 -0600 Subject: [PATCH 06/45] rename some code --- .../block_fmha_pipeline_qr_ks_vs.hpp | 20 +- ...k_fmha_pipeline_qr_ks_vs_custom_policy.hpp | 694 ++++++++++++++++++ ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 677 +---------------- include/ck/tile_program/tile/load_tile.hpp | 10 +- 
.../tile_window_impl_static_distribution.hpp | 5 +- 5 files changed, 713 insertions(+), 693 deletions(-) create mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 4102d30ba..3da246e3a 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -98,7 +98,7 @@ struct BlockFmhaPipelineQRKSVS Policy::template MakeKLdsStoreBlockDescriptor(i_buf).GetLengths(), {0, 0, 0}); }, - Number{}); + Number{}); #if K_LDS_LOAD_USE_OFFSET_TRANSFORM auto k_lds_load = generate_tuple( @@ -109,7 +109,7 @@ struct BlockFmhaPipelineQRKSVS Policy::template MakeKLdsLoadBlockDescriptor(i_buf).GetLengths(), {0, 0}); }, - Number{}); + Number{}); #else auto k_lds_Load_view = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); @@ -137,8 +137,9 @@ struct BlockFmhaPipelineQRKSVS q_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeQDramTileDistribution()); - auto q = load_tile_raw(q_dram_window); // persistent q register tile - + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = load_tile_raw(q_dram_window); __builtin_amdgcn_sched_barrier(0); using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); @@ -152,9 +153,6 @@ struct BlockFmhaPipelineQRKSVS using SBlockTileType = decltype(tile_elementwise_in(type_convert, s_acc)); - // using PBlockTileType = - // decltype(tile_elementwise_in(type_convert, s_acc)); - using MLBlockTileType = decltype(block_tile_reduce( SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); @@ -185,7 +183,7 @@ struct BlockFmhaPipelineQRKSVS Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load // prefetch K tile - async_load_tile(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -202,8 +200,8 @@ struct BlockFmhaPipelineQRKSVS if constexpr(k0_loops > 1) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - async_load_tile(k_lds_store(Number{})>{}), - k_dram_window); + async_load_tile_raw(k_lds_store(Number{})>{}), + k_dram_window); if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); @@ -401,7 +399,7 @@ struct BlockFmhaPipelineQRKSVS if constexpr(k1_loops >= 2 && LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) __builtin_amdgcn_s_barrier(); - async_load_tile(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); move_tile_window(k_dram_window, {0, kK0}); } // tail diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp new file mode 100644 index 000000000..c937b6e9e --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp @@ -0,0 +1,694 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp" +#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp" +#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +// TODO: remove this +#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0 + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVSCustomPolicy +{ + static constexpr index_t AsyncCopyK = AsyncCopyK_; + static constexpr index_t AsyncCopyV = AsyncCopyV_; // TODO: this not supported yet + + static constexpr index_t NumPrefetchK = NumPrefetchK_; + static constexpr index_t NumPrefetchV = NumPrefetchK_; + + template + struct LdsBufferSequence + { + static constexpr auto Make() + { + return transform_sequences( + [&](auto i) { + if(i < k_loops_) + return i % k_prefetches_; + return (i - k_loops_) % v_prefetches_; + }, + typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{}); + }; + + using type = remove_cvref_t; + }; + // clang-format off + template<> struct + LdsBufferSequence<3, 3, 4, 4> { using type = Sequence<1, 2, 0, 1, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 4, 2> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 2, 4> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 3, 3> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; + + template<> struct + LdsBufferSequence<3, 3, 2, 2> { using type = Sequence<1, 2, 1, 0>;}; + // clang-format on + + template + __host__ __device__ static constexpr auto GetLdsBufferSequence() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t kN0 = BlockFmhaShape::kN0; + constexpr index_t kK0 = BlockFmhaShape::kK0; + constexpr index_t kK1 = BlockFmhaShape::kK1; + constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + return typename LdsBufferSequence::type{}; + } + + template + __host__ __device__ static constexpr auto GetSmemKPackK() + { + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + __host__ __device__ static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + template + __host__ __device__ static constexpr auto GetVectorloadV() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! 
+ if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + __host__ __device__ static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = warpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (warpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return math::max(SingleKSize, SingleVSize); + } + + template + __host__ __device__ static constexpr auto GetVectorloadK() + { + using KDataType = remove_cvref_t; + return 4 / sizeof(KDataType); // TODO: this is for async copy + } + + template + __host__ __device__ static constexpr auto MakeQRegBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto q_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); + + return q_block_dstr; + } + +#if 0 + template + __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPack = GetSmemKPackV(); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number<(kNPerBlock + 1) * kKPack>{}, Number{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + 
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return k_lds_block_desc; + } +#endif + + template + __host__ __device__ static constexpr auto + MakeKLdsStoreBlockDescriptor(Number = Number<0>{}) + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = + KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + warpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(Number{}, // n0 + Number{}, // n1 + Number{}, // n2 + Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number()>{}, + Number{}, + Number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + template + __host__ __device__ static constexpr auto + MakeKLdsLoadBlockDescriptor(Number = Number<0>{}) + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(Number{}, // n0 + Number{}, // n2 + Number{}, // n1 
+ Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number()>{}, + Number{}, + Number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0, 2, 1>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return k_lds_block_desc; + } +#else + template + __host__ __device__ static constexpr auto MakeKLdsLoadBlockDescriptor() + { + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; // for async-copy, this pad is between warps + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + // constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad); + // constexpr index_t SingleVSize = MakeVLdsBlockDescriptor().GetElementSpaceSize(); + constexpr index_t BufferSize = + GetSingleSmemElementSpaceSize(); // math::max(SingleKSize, SingleVSize); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(Number{}, // num_buffers + Number{}, // n0 + Number{}, // n2 + Number{}, // n1 + Number{}, // k0 + Number{}), // k1 + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{}), + Number{}, + Number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0, 1, 3, 2>{}, Sequence<4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return k_lds_block_desc; + } +#endif + + // 3d + padding + template + __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() + { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}), + make_tuple(Number()>{}, + Number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + Number{}, + Number{}, + Number<1>{}), + Number{}, + Number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + 
v_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_merge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return v_lds_block_desc; + } + + template + __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() + { + return 0; + } + + template + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + // TODO: assume Q is in register + constexpr index_t single_smem_size = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return single_smem_size * NumPrefetchK; + } + + template + __host__ __device__ static constexpr auto MakeQDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template At<1>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1>>, + Tuple, Sequence<1, 2>>, + Sequence<1, 2, 2>, + Sequence<0, 0, 2>>{}); + } + + template + __host__ __device__ static constexpr auto MakeKDramTileDistribution() + { +#if 0 // coalesce reading for each blocks + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = 16 / sizeof(KDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); +#else + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KVector = GetVectorloadK(); // this is for global load + + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<1, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); +#endif + } + + template + __device__ static 
constexpr auto MakeVDramTileDistribution() + { + using VDataType = remove_cvref_t; + using VLayout = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(ck::is_same_v) + { + constexpr index_t N1 = GetVectorloadV(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding< + Sequence<1>, + Tuple, Sequence>, + Tuple, Sequence<2, 1, 2>>, + Tuple, Sequence<1, 0, 2>>, + Sequence<2, 1>, + Sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding< + Sequence<1>, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<0, 2>>, + Sequence<2, 1>, + Sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + } + + template + __host__ __device__ static constexpr auto MakeShuffledVRegBlockDescriptor() + { + // This descriptor only used when V layout is seqlen * hdim + using VLayout = remove_cvref_t; + static_assert(ck::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t N1 = GetVectorloadV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1, 2>>, + Tuple, Sequence<1, 0, 2>>, + Sequence<1, 2>, + Sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<0, 2>>, + Sequence<1, 2>, + Sequence<1, 3>>{}); + } + } + + template + __host__ __device__ static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< + warp::WarpGemmAttributeMfmaImplF16F16F32M32N32K8, + 2>>; + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + return BlockGemmARegBSmemCRegV2{}; + } + + template + __host__ __device__ static constexpr auto GetKVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + return BlockGemmARegBSmemCRegV2{}; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp index 6c125899c..87c0dcee0 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -3,687 +3,14 @@ #pragma once -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" - -#include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/tile_elementwise.hpp" -#include "ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp" -#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" - -// TODO: remove this -#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0 +#include 
"ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp" namespace ck { namespace tile_program { namespace block { // This pipeline is qkv all located in LDS -struct BlockFmhaPipelineQRKSVSDefaultPolicy -{ - static constexpr index_t KLdsBuffers = 3; - static constexpr index_t VLdsBuffers = 3; - - template - struct LdsBufferSequence - { - static constexpr auto Make() - { - return transform_sequences( - [&](auto i) { - if(i < k_loops_) - return i % k_bufs_; - return (i - k_loops_) % v_bufs_; - }, - typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{}); - }; - - using type = remove_cvref_t; - }; - // clang-format off - template<> struct - LdsBufferSequence<3, 3, 4, 4> { using type = Sequence<1, 2, 0, 1, 0, 1, 2, 0>; }; - - template<> struct - LdsBufferSequence<3, 3, 4, 2> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; - - template<> struct - LdsBufferSequence<3, 3, 2, 4> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; - - template<> struct - LdsBufferSequence<3, 3, 3, 3> { using type = Sequence<1, 2, 0, 1, 2, 0>; }; - - template<> struct - LdsBufferSequence<3, 3, 2, 2> { using type = Sequence<1, 2, 1, 0>;}; - // clang-format on - - template - __host__ __device__ static constexpr auto GetLdsBufferSequence() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t kN0 = BlockFmhaShape::kN0; - constexpr index_t kK0 = BlockFmhaShape::kK0; - constexpr index_t kK1 = BlockFmhaShape::kK1; - constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; - - constexpr index_t k0_loops = kK0BlockLength / kK0; - constexpr index_t k1_loops = kN0 / kK1; - - return typename LdsBufferSequence::type{}; - } - - template - __host__ __device__ static constexpr auto GetSmemKPackK() - { - // TODO: this is for 3d layout - using KDataType = remove_cvref_t; - return 16 / sizeof(KDataType); - } - - template - __host__ __device__ static constexpr auto GetSmemKPackV() - { - // TODO: this is for 3d layout - using VDataType = remove_cvref_t; - return 16 / sizeof(VDataType); - } - template - __host__ __device__ static constexpr auto GetVectorloadV() - { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - - // TODO: not correct! 
- if constexpr(total_pixels > 4) - return 4; - else - return 2; - } - - template - __host__ __device__ static constexpr auto GetSingleSmemElementSpaceSize() - { - // this function assume K/V can share smem - constexpr index_t SingleKSize = [&]() { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; - constexpr index_t warpSize = ck::get_warp_size(); - - constexpr index_t KPack = GetSmemKPackV(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load - constexpr index_t kPad = KPack; - - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); - constexpr index_t LanesPerK = kKPerBlock / KVector; - constexpr index_t LaneGroups = warpSize / LanesPerK; - constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - - return NumIssues * NumWarps * (warpSize * KVector + kPad); - }(); - - constexpr index_t SingleVSize = [&]() { - using VDataType = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kKPerBlock % kKPack == 0); - - return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); - }(); - - return math::max(SingleKSize, SingleVSize); - } - - template - __host__ __device__ static constexpr auto GetVectorloadK() - { - using KDataType = remove_cvref_t; - return 4 / sizeof(KDataType); // TODO: this is for async copy - } - - template - __host__ __device__ static constexpr auto MakeQRegBlockDescriptor() - { - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; - - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template At<1>(); - constexpr index_t NWarp = config.template At<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr auto q_block_outer_dstr_encoding = StaticTileDistributionEncoding< - Sequence, - Tuple, Sequence>, - Tuple>, - Tuple>, - Sequence<1, 2>, - Sequence<0, 0>>{}; - - constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - - constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - - return q_block_dstr; - } - -#if 0 - template - __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kKPack = GetSmemKPackV(); - - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number{}), - make_tuple(Number<(kNPerBlock + 1) * kKPack>{}, Number{}, Number<1>{}), - Number<8>{}, - Number<1>{}); - - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - 
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return k_lds_block_desc; - } -#endif - - template - __host__ __device__ static constexpr auto - MakeKLdsStoreBlockDescriptor(Number = Number<0>{}) - { - // K is always k-major, we use async-copy to load into LDS - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; - constexpr index_t warpSize = ck::get_warp_size(); - - constexpr index_t KPack = GetSmemKPackV(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load - constexpr index_t kPad = - KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed - - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); - constexpr index_t LanesPerK = - kKPerBlock / KVector; // how many lane (within a wave) to load K - constexpr index_t LaneGroups = - warpSize / - LanesPerK; // how many groups (within a wave), they may load different N, but same K - constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( - make_tuple(Number{}, // n0 - Number{}, // n1 - Number{}, // n2 - Number{}, // k0 - Number{}), // k1 - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number<1>{}), - Number()>{}, - Number{}, - Number<1>{}); - - // TODO this layout is hard coded, and will be used in async copy buffer view load - // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) - constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple(make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_merge_transform(make_tuple( - Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1, 3, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return k_lds_block_desc_issues_warps_lanes; - } - -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - template - __host__ __device__ static constexpr auto - MakeKLdsLoadBlockDescriptor(Number = Number<0>{}) - { - // K is always k-major, we use async-copy to load into LDS - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; - constexpr index_t warpSize = ck::get_warp_size(); - - constexpr index_t KPack = GetSmemKPackV(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load - constexpr index_t kPad = KPack; // for async-copy, this pad is between warps - - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); - constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave - constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( - make_tuple(Number{}, // n0 - Number{}, // n2 - Number{}, // n1 
- Number{}, // k0 - Number{}), // k1 - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number<1>{}), - Number()>{}, - Number{}, - Number<1>{}); - - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple( - make_merge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<0, 2, 1>{}, Sequence<3, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return k_lds_block_desc; - } -#else - template - __host__ __device__ static constexpr auto MakeKLdsLoadBlockDescriptor() - { - // K is always k-major, we use async-copy to load into LDS - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; - constexpr index_t warpSize = ck::get_warp_size(); - - constexpr index_t KPack = GetSmemKPackV(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load - constexpr index_t kPad = KPack; // for async-copy, this pad is between warps - - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); - constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave - constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - // constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad); - // constexpr index_t SingleVSize = MakeVLdsBlockDescriptor().GetElementSpaceSize(); - constexpr index_t BufferSize = - GetSingleSmemElementSpaceSize(); // math::max(SingleKSize, SingleVSize); - - constexpr auto k_lds_block_desc_0 = - make_naive_tensor_descriptor(make_tuple(Number{}, // num_buffers - Number{}, // n0 - Number{}, // n2 - Number{}, // n1 - Number{}, // k0 - Number{}), // k1 - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number<1>{}), - Number{}, - Number<1>{}); - - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<0, 1, 3, 2>{}, Sequence<4, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return k_lds_block_desc; - } -#endif - - // 3d + padding - template - __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() - { - using VDataType = remove_cvref_t; - constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kKPerBlock % kKPack == 0); - - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}), - make_tuple(Number()>{}, - Number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, - Number{}, - Number{}, - Number<1>{}), - Number{}, - Number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - 
v_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple( - Number{}, Number{}, Number{})), - make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return v_lds_block_desc; - } - - template - __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() - { - return 0; - } - - template - __host__ __device__ static constexpr ck::index_t GetSmemSize() - { - // TODO: assume Q is in register - constexpr index_t single_smem_size = - GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); - - return single_smem_size * KLdsBuffers; - } - - template - __host__ __device__ static constexpr auto MakeQDramTileDistribution() - { - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template At<1>(); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; - - constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1>>, - Tuple, Sequence<1, 2>>, - Sequence<1, 2, 2>, - Sequence<0, 0, 2>>{}); - } - - template - __host__ __device__ static constexpr auto MakeKDramTileDistribution() - { -#if 0 // coalesce reading for each blocks - using KDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t K1 = 16 / sizeof(KDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#else - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; - constexpr index_t warpSize = ck::get_warp_size(); - - constexpr index_t KVector = GetVectorloadK(); // this is for global load - - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); - constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave - constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - - constexpr index_t N0 = NumIssues; - constexpr index_t N1 = LaneGroups; - constexpr index_t N2 = NumWarps; - constexpr index_t K0 = LanesPerK; - constexpr index_t K1 = KVector; - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<1, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#endif - } - - template - __device__ static 
constexpr auto MakeVDramTileDistribution() - { - using VDataType = remove_cvref_t; - using VLayout = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - if constexpr(ck::is_same_v) - { - constexpr index_t N1 = GetVectorloadV(); - constexpr index_t N0 = kNPerBlock / N1; // P - - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - if constexpr(get_warp_size() % (K2 * N0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - StaticTileDistributionEncoding< - Sequence<1>, - Tuple, Sequence>, - Tuple, Sequence<2, 1, 2>>, - Tuple, Sequence<1, 0, 2>>, - Sequence<2, 1>, - Sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = kBlockSize / get_warp_size() / K1; - static_assert(kKPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - StaticTileDistributionEncoding< - Sequence<1>, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<0, 2>>, - Sequence<2, 1>, - Sequence<3, 1>>{}); - } - } - else - { - constexpr index_t K1 = 16 / sizeof(VDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); - } - } - - template - __host__ __device__ static constexpr auto MakeShuffledVRegBlockDescriptor() - { - // This descriptor only used when V layout is seqlen * hdim - using VLayout = remove_cvref_t; - static_assert(ck::is_same_v); - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr index_t N1 = GetVectorloadV(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - if constexpr(get_warp_size() % (K2 * N0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1, 2>>, - Tuple, Sequence<1, 0, 2>>, - Sequence<1, 2>, - Sequence<1, 3>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = kBlockSize / get_warp_size() / K1; - static_assert(kKPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<0, 2>>, - Sequence<1, 2>, - Sequence<1, 3>>{}); - } - } - - template - __host__ __device__ static constexpr auto GetQKBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - - using WarpGemm = warp::WarpGemmImpl< - warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< - warp::WarpGemmAttributeMfmaImplF16F16F32M32N32K8, - 2>>; - - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV2CustomPolicy; - - return BlockGemmARegBSmemCRegV2{}; - } - - template - __host__ __device__ static constexpr auto GetKVBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - - using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<0>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}), - true>; - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV2CustomPolicy; - return BlockGemmARegBSmemCRegV2{}; - } -}; +using BlockFmhaPipelineQRKSVSDefaultPolicy = BlockFmhaPipelineQRKSVSCustomPolicy; } // namespace block } // namespace tile_program diff --git a/include/ck/tile_program/tile/load_tile.hpp b/include/ck/tile_program/tile/load_tile.hpp index 431fbd18c..855800ae1 100644 --- a/include/ck/tile_program/tile/load_tile.hpp +++ b/include/ck/tile_program/tile/load_tile.hpp @@ -46,11 +46,11 @@ template -__device__ auto async_load_tile(LdsTileWindow_&& lds_tile, - const TileWindowWithStaticDistribution& tile_window) +__device__ auto async_load_tile_raw(LdsTileWindow_&& lds_tile, + const TileWindowWithStaticDistribution& tile_window) { return tile_window.AsyncLoad(lds_tile); } diff --git a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp index 523877dc0..5e67dd4ff 100644 --- a/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp +++ b/include/ck/tile_program/tile/tile_window_impl_static_distribution.hpp @@ -327,8 +327,9 @@ struct TileWindowWithStaticDistribution return dst_tensor; } - template - __device__ auto AsyncLoad(LdsTileWindow_&& lds_tile) const + // TODO: currently async load only implemented in inline asm + template + __device__ auto AsyncLoad(LdsTileWindow_&& lds_tile, bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; // using LdsTensorView = typename LdsTileWindow::BottomTensorView; From 149a242c0e60ef055fc088f82d970158e50218dd Mon Sep 17 00:00:00 2001 From: carlushuang Date: Wed, 6 Dec 2023 02:30:19 -0600 
Subject: [PATCH 07/45] separate async to different pipeline --- example/91_tile_program/fmha_fwd.cpp | 12 +- .../block_fmha_pipeline_qr_ks_vs.hpp | 255 ++++------ .../block_fmha_pipeline_qr_ks_vs_async.hpp | 457 ++++++++++++++++++ ...pipeline_qr_ks_vs_async_default_policy.hpp | 18 + ...k_fmha_pipeline_qr_ks_vs_custom_policy.hpp | 131 ++--- ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 3 +- .../ck/tile_program/tile/tile_elementwise.hpp | 23 + 7 files changed, 663 insertions(+), 236 deletions(-) create mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp create mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index 159ea27e4..f9c01911c 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -16,6 +16,7 @@ #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" @@ -87,11 +88,16 @@ using FmhaPipelineProblemHDim128 = ODataType, 256, // BlockSize FmhaShapeHDim128>; -// using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS; + using FmhaPipelineHDim64 = - ck::tile_program::block::BlockFmhaPipelineQRKSVS; + ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync; using FmhaPipelineHDim128 = - ck::tile_program::block::BlockFmhaPipelineQRKSVS; + ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync; + +// using FmhaPipelineHDim64 = +// ck::tile_program::block::BlockFmhaPipelineQRKSVS; +// using FmhaPipelineHDim128 = +// ck::tile_program::block::BlockFmhaPipelineQRKSVS; using FmhaEpilogue = FmhaFwdEpilogue>; using FmhaKernelHDim64 = FmhaFwdKernel; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 3da246e3a..e468f79b2 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -38,7 +38,7 @@ struct BlockFmhaPipelineQRKSVS using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; - static constexpr bool kQLoadOnce = true; // if q load whole block length (hdim) at once + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static constexpr index_t kBlockPerCu = BlockFmhaShape::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -65,7 +65,7 @@ struct BlockFmhaPipelineQRKSVS operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction& /*k_element_func*/, + const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, float scale, @@ -86,39 +86,13 @@ struct BlockFmhaPipelineQRKSVS kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], "wrong!"); - constexpr auto LdsSeq =
Policy::template GetLdsBufferSequence(); - // K tile in LDS - auto k_lds_ptr = reinterpret_cast(smem_ptr); - auto k_lds_store = generate_tuple( - [&](auto i_buf) { - return make_tile_window( - make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), - Policy::template MakeKLdsStoreBlockDescriptor(i_buf).GetLengths(), - {0, 0, 0}); - }, - Number{}); - -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - auto k_lds_load = generate_tuple( - [&](auto i_buf) { - return make_tile_window( - make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), - Policy::template MakeKLdsLoadBlockDescriptor(i_buf).GetLengths(), - {0, 0}); - }, - Number{}); -#else - auto k_lds_Load_view = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); - - auto k_lds_load = - make_tile_window(k_lds_Load_view, - Policy::template MakeKLdsLoadBlockDescriptor().GetLengths(), - {0, 0}); -#endif + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); // V tile in LDS auto v_lds = make_tensor_view( @@ -137,10 +111,7 @@ struct BlockFmhaPipelineQRKSVS q_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeQDramTileDistribution()); - // TODO: we use async Copy for K, which is inline asm - // a side effect is we have to use inline asm for q as well - auto q = load_tile_raw(q_dram_window); - __builtin_amdgcn_sched_barrier(0); + auto q = load_tile(q_dram_window); using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); auto s_acc = SaccBlockTileType{}; @@ -150,8 +121,7 @@ struct BlockFmhaPipelineQRKSVS const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = - decltype(tile_elementwise_in(type_convert, s_acc)); + using SBlockTileType = decltype(cast_tile(s_acc)); using MLBlockTileType = decltype(block_tile_reduce( SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); @@ -163,10 +133,9 @@ struct BlockFmhaPipelineQRKSVS auto m = MLBlockTileType{}; auto l = MLBlockTileType{}; - tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); - tile_elementwise_inout([](auto& e) { e = NumericLimits::Lowest(); }, - m); - tile_elementwise_inout([](auto& e) { e = 0; }, l); + clear_tile(o_acc); + set_tile(m, NumericLimits::Lowest()); + clear_tile(l); auto k_dram_block_window = k_dram_block_window_tmp; auto v_dram_window = @@ -175,88 +144,74 @@ struct BlockFmhaPipelineQRKSVS v_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeVDramTileDistribution()); - __builtin_amdgcn_sched_barrier(0); - auto k_dram_window = make_tile_window( - k_dram_block_window.GetBottomTensorView(), - k_dram_block_window.GetWindowLengths(), - k_dram_block_window.GetWindowOrigin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load - // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - __builtin_amdgcn_sched_barrier(0); - - buffer_load_fence(k_dram_window.GetNumAccess()); auto q_tile = tile_elementwise_in(q_element_func, q); - __builtin_amdgcn_sched_barrier(0); + + // prefetch K tile index_t i_total_loops = 0; constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k1_loops = kN0 / kK1; do { // STAGE 1, QK gemm - tile_elementwise_inout([](auto& c) { c 
= 0; }, s_acc); // Initialize C - if constexpr(k0_loops > 1) + auto k_dram_window = make_tile_window( + k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - async_load_tile_raw(k_lds_store(Number{})>{}), - k_dram_window); - if constexpr(i_k0 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - - async_load_fence(k_dram_window.GetNumAccess()); - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // Initialize C + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); gemm_0(s_acc, get_slice_tile(q_tile, Sequence<0, i_k0 * kK0>{}, Sequence{}), -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - k_lds_load[Number{})>{}]); + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); -#else - get_slice_tile(k_lds_load, - Sequence<(LdsSeq.At(Number{})) * kN0, 0>{}, - Sequence<(LdsSeq.At(Number{}) + 1) * kN0, kK0>{})); -#endif + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 }); } - // TODO: this to fix a bug when loop smaller than 2, - // the following fence/barrier will be scheduled inside 1st loop - if constexpr(k0_loops <= 2) - __builtin_amdgcn_sched_barrier(0); + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 2) * kK0>{}, + Sequence{}), + k_lds_window); + block_sync_lds(); - async_load_fence(); - __builtin_amdgcn_s_barrier(); + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); - auto v_buf = load_tile(v_dram_window); - __builtin_amdgcn_sched_barrier(0); - { // tail gemm_0(s_acc, get_slice_tile(q_tile, Sequence<0, (k0_loops - 1) * kK0>{}, Sequence{}), -#if K_LDS_LOAD_USE_OFFSET_TRANSFORM - k_lds_load[Number{})>{}]); - -#else - get_slice_tile( - k_lds_load, - Sequence<(LdsSeq.At(Number{})) * kN0, 0>{}, - Sequence<(LdsSeq.At(Number{}) + 1) * kN0, kK0>{})); -#endif + k_lds_window); } - __builtin_amdgcn_sched_barrier(1); // STAGE 2, scale softmax #if !CK_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif - const auto s = - tile_elementwise_in(type_convert, s_acc); // S{j} + const auto s = cast_tile(s_acc); // S{j} auto m_local = block_tile_reduce( s, Sequence<1>{}, @@ -271,38 +226,6 @@ struct BlockFmhaPipelineQRKSVS auto p_compute = make_static_distributed_tensor( s.GetTileDistribution()); // Pcompute{j} - __builtin_amdgcn_sched_barrier(0x7F); - // store & prefetch next v, after the max reduction - if constexpr(ck::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_distributed_tensor(v_shuffle_tmp, v_buf); - - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, - Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); - - store_tile( - v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // 
store the prefetch - } - else - { - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v_buf)); // store the prefetch - } - - if constexpr(k1_loops > 1) - { - move_tile_window( - v_dram_window, - {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile(v_dram_window); // load next v_buf - } - __builtin_amdgcn_sched_barrier(0); - constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -343,75 +266,63 @@ struct BlockFmhaPipelineQRKSVS }); }); - const auto p = - tile_elementwise_in(type_convert, p_compute); + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = cast_tile(p_compute); // STAGE 3, KV gemm if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) - { - v_buf = load_tile(v_dram_window); // load next v_buf - } + const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); gemm_1(o_acc, get_slice_tile( p, Sequence<0, i_k1 * kK1>{}, Sequence{}), - get_slice_tile( - v_lds_window, - Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, - Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{})); - + v_lds_window); + block_sync_lds(); if constexpr(ck::is_same_v) { auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_distributed_tensor(v_shuffle_tmp, v_buf); - auto v_lds_window_tmp = get_slice_tile( - v_lds_window, - Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, - Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, + shuffle_distributed_tensor(v_shuffle_tmp, v); + store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch } else { store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + tile_elementwise_in(v_element_func, v)); // store next v } - if constexpr(i_k1 < k1_loops - 1) - move_tile_window(v_dram_window, {0, kK1}); + move_tile_window(v_dram_window, {0, kK1}); }); } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); i_total_loops++; - if(i_total_loops < num_total_loop) - { - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = - make_tile_window(k_dram_block_window.GetBottomTensorView(), - k_dram_block_window.GetWindowLengths(), - k_dram_block_window.GetWindowOrigin(), - Policy::template MakeKDramTileDistribution()); - - if constexpr(k1_loops >= 2 && - LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) - __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - } // tail { block_sync_lds(); - gemm_1( - o_acc, - get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), - get_slice_tile( - v_lds_window, - Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, - Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{})); + gemm_1(o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, 
Sequence{}), + v_lds_window); + block_sync_lds(); } } while(i_total_loops < num_total_loop); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 000000000..af9aab48a --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,457 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/tile/slice_tile.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck/tile_program/block_tile/block_reduce.hpp" +#include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaPipelineQRKSVSAsync +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q load whole block length (hdim) at once + + static constexpr index_t kBlockPerCu = BlockFmhaShape::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + float scale, + index_t num_total_loop, + index_t /*num_sub_loop_qk*/, // in this pipeline, the 1st gemm loop must be static + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK1 == 
VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).GetLengths(), + {0, 0, 0}); + }, + Number{}); + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + auto k_lds_load = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), + Policy::template MakeKLdsLoadBlockDescriptor(i_buf).GetLengths(), + {0, 0}); + }, + Number{}); +#else + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().GetLengths(), + {0, 0}); +#endif + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().GetLengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.GetBottomTensorView(), + q_dram_block_window_tmp.GetWindowLengths(), + q_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeQDramTileDistribution()); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = load_tile_raw(q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, NumericLimits::Lowest()); + clear_tile(l); + + auto k_dram_block_window = k_dram_block_window_tmp; + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + v_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeVDramTileDistribution()); + + __builtin_amdgcn_sched_barrier(0); + auto k_dram_window = make_tile_window( + k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + // prefetch K tile + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.GetNumAccess()); + auto q_tile = tile_elementwise_in(q_element_func, q); + __builtin_amdgcn_sched_barrier(0); + index_t i_total_loops = 0; + constexpr index_t 
k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // Initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(Number{})>{}), + k_dram_window); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.GetNumAccess()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, i_k0 * kK0>{}, + Sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[Number{})>{}]); + +#else + get_slice_tile(k_lds_load, + Sequence<(LdsSeq.At(Number{})) * kN0, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN0, kK0>{})); +#endif + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + auto v_buf = load_tile(v_dram_window); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 1) * kK0>{}, + Sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[Number{})>{}]); + +#else + get_slice_tile( + k_lds_load, + Sequence<(LdsSeq.At(Number{})) * kN0, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN0, kK0>{})); +#endif + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale softmax +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + Sequence<1>{}, + f_max, + NumericLimits::Lowest()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.GetTileDistribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... 
+ v_buf = load_tile(v_dram_window); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * m[i_idx]; +#endif + sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * m[i_idx]; + const auto tmp = math::exp2(scale * m_old[i_idx] - row_max); +#else + const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, Sequence<0, i_k1 * kK1>{}, Sequence{}), + get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{})); + + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = + make_tile_window(k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); + + if constexpr(k1_loops >= 2 && + LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), + get_slice_tile( + v_lds_window, + Sequence<(LdsSeq.At(Number{})) * kN1, 0>{}, + Sequence<(LdsSeq.At(Number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // finally, O 
+ constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = 1 / l[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + float scale, + index_t num_total_loop, + index_t num_sub_loop_qk, + void* smem_ptr) const + { + return operator()( + q_dram_block_window_tmp, + [](const QDataType& x) { return x; }, + k_dram_block_window_tmp, + [](const KDataType& x) { return x; }, + v_dram_block_window_tmp, + [](const VDataType& x) { return x; }, + scale, + num_total_loop, + num_sub_loop_qk, + smem_ptr); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 000000000..a22f16253 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +using BlockFmhaPipelineQRKSVSAsyncDefaultPolicy = + BlockFmhaPipelineQRKSVSCustomPolicy; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp index c937b6e9e..711280588 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp @@ -121,21 +121,29 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy { // this function assume K/V can share smem constexpr index_t SingleKSize = [&]() { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; - constexpr index_t warpSize = ck::get_warp_size(); - - constexpr index_t KPack = GetSmemKPackV(); // this is for lds - constexpr index_t KVector = GetVectorloadK(); // this is for global load - constexpr index_t kPad = KPack; - - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); - constexpr index_t LanesPerK = kKPerBlock / KVector; - constexpr index_t LaneGroups = warpSize / LanesPerK; - constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - - return NumIssues * NumWarps * (warpSize * KVector + kPad); + if constexpr(!AsyncCopyK) + { + return MakeKLdsBlockDescriptor().GetElementSpaceSize(); + } + else + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t 
NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackV(); // this is for lds + constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(warpSize * KVector >= kKPerBlock && + warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = warpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (warpSize * KVector + kPad); + } }(); constexpr index_t SingleVSize = [&]() { @@ -195,7 +203,7 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy return q_block_dstr; } -#if 0 + // TODO: this is used for non async copy desc. unify in the future template __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() { @@ -218,7 +226,6 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy return k_lds_block_desc; } -#endif template __host__ __device__ static constexpr auto @@ -429,10 +436,11 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy __host__ __device__ static constexpr ck::index_t GetSmemSize() { // TODO: assume Q is in register + // TODO: assume K/V has same data type constexpr index_t single_smem_size = GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); - return single_smem_size * NumPrefetchK; + return single_smem_size * math::max(NumPrefetchK, NumPrefetchV); } template @@ -465,55 +473,58 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy template __host__ __device__ static constexpr auto MakeKDramTileDistribution() { -#if 0 // coalesce reading for each blocks - using KDataType = remove_cvref_t; + if constexpr(!AsyncCopyK) + { + using KDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - constexpr index_t K1 = 16 / sizeof(KDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); + constexpr index_t K1 = 16 / sizeof(KDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#else - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; - constexpr index_t warpSize = ck::get_warp_size(); + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + else + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = 
Problem::BlockFmhaShape::NumWarps; + constexpr index_t warpSize = ck::get_warp_size(); - constexpr index_t KVector = GetVectorloadK(); // this is for global load + constexpr index_t KVector = GetVectorloadK(); // this is for global load - static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); - constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave - constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave - constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); - static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); - constexpr index_t N0 = NumIssues; - constexpr index_t N1 = LaneGroups; - constexpr index_t N2 = NumWarps; - constexpr index_t K0 = LanesPerK; - constexpr index_t K1 = KVector; + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<1, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#endif + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<1, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } } template diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp index 87c0dcee0..df11a8e87 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -10,7 +10,8 @@ namespace tile_program { namespace block { // This pipeline is qkv all located in LDS -using BlockFmhaPipelineQRKSVSDefaultPolicy = BlockFmhaPipelineQRKSVSCustomPolicy; +using BlockFmhaPipelineQRKSVSDefaultPolicy = + BlockFmhaPipelineQRKSVSCustomPolicy; } // namespace block } // namespace tile_program diff --git a/include/ck/tile_program/tile/tile_elementwise.hpp b/include/ck/tile_program/tile/tile_elementwise.hpp index dedfc5961..a5d27ba7c 100644 --- a/include/ck/tile_program/tile/tile_elementwise.hpp +++ b/include/ck/tile_program/tile/tile_elementwise.hpp @@ -52,5 +52,28 @@ __device__ auto tile_elementwise_in(const InElementFunc& in_element_func, return out_dstr_tensor; } +template +__device__ void set_tile(DstrTensors& dstr_tensor, const T& value) +{ + tile_elementwise_inout( + [&value](auto& x) { + x = type_convert>(value); + }, + dstr_tensor); +} + +template +__device__ void clear_tile(DstrTensors& dstr_tensor) +{ + set_tile(dstr_tensor, 0); +} + +template +__device__ auto cast_tile(const SrcDstrTensors& src_tensor) +{ + return tile_elementwise_in(type_convert, + src_tensor); +} + } // namespace tile_program } // namespace ck From 72b5a47e4229bd9ea3be7a7afbf8908f3f504099 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 7 Dec 2023 00:04:53 -0600 Subject: [PATCH 08/45] rename radio->ratio --- 
example/91_tile_program/fmha_fwd_kernel.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/example/91_tile_program/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha_fwd_kernel.hpp index 01d8c5696..87faeeae4 100644 --- a/example/91_tile_program/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha_fwd_kernel.hpp @@ -42,7 +42,7 @@ struct FmhaFwdKernel // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case - ck::index_t nhead_radio_qk; + ck::index_t nhead_ratio_qk; float scale; ck::index_t stride_q; @@ -69,7 +69,7 @@ struct FmhaFwdKernel ck::index_t seqlen_k, ck::index_t hdim_q, ck::index_t hdim_v, - ck::index_t nhead_radio_qk, + ck::index_t nhead_ratio_qk, float scale, ck::index_t stride_q, ck::index_t stride_k, @@ -86,7 +86,7 @@ struct FmhaFwdKernel { return Kargs { - q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, seqlen_k, hdim_q, hdim_v, nhead_radio_qk, + q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, seqlen_k, hdim_q, hdim_v, nhead_ratio_qk, #if CK_FMHA_FWD_FAST_EXP2 static_cast(scale * C_LOG2E), #else @@ -133,10 +133,10 @@ struct FmhaFwdKernel const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + i_nhead * kargs.nhead_stride_q + i_batch * kargs.batch_stride_q; const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - (i_nhead / kargs.nhead_radio_qk) * kargs.nhead_stride_k + + (i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + i_batch * kargs.batch_stride_k; const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - (i_nhead / kargs.nhead_radio_qk) * kargs.nhead_stride_v + + (i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + i_batch * kargs.batch_stride_v; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o; From 8f3a9adef1dae4bee5e37a98e99a2cf356b53ccd Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 9 Dec 2023 20:56:20 -0600 Subject: [PATCH 09/45] merge feature/fmha-pad-support aece827 --- example/91_tile_program/fmha_fwd.cpp | 623 ++++++++++++------ example/91_tile_program/fmha_fwd_kernel.hpp | 563 +++++++++++++--- example/91_tile_program/fmha_utils.hpp | 70 ++ .../reference_batched_elementwise.hpp | 60 ++ .../reference_batched_gemm.hpp | 26 +- .../reference_batched_masking.hpp | 47 ++ example/91_tile_program/reference_gemm.hpp | 19 +- include/ck/host_utility/hip_check_error.hpp | 2 + include/ck/host_utility/io.hpp | 26 +- include/ck/host_utility/kernel_launch.hpp | 28 +- include/ck/tensor/tensor_view.hpp | 48 ++ .../block_masking_specialization.hpp | 104 +++ .../block_fmha_pipeline_problem.hpp | 16 +- .../block_fmha_pipeline_qr_ks_vs.hpp | 138 +++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 138 +++- ...k_fmha_pipeline_qr_ks_vs_custom_policy.hpp | 32 + include/ck/tile_program/tile/load_tile.hpp | 7 + include/ck/tile_program/tile/null_tensor.hpp | 14 + .../ck/tile_program/tile/null_tile_window.hpp | 83 +++ .../tile/static_distributed_tensor.hpp | 29 + .../ck/tile_program/tile/tile_elementwise.hpp | 30 +- .../ck/tile_program/tile/tile_fmha_shape.hpp | 4 +- .../ck/tile_program/tile/tile_fmha_traits.hpp | 24 + include/ck/tile_program/tile/tile_window.hpp | 1 + include/ck/utility/common_header.hpp | 1 + include/ck/utility/functional.hpp | 11 + include/ck/utility/static_switch.hpp | 63 ++ library/src/utility/device_memory.cpp | 27 +- 28 files changed, 1824 insertions(+), 410 deletions(-) create mode 100644 example/91_tile_program/fmha_utils.hpp create mode 100644 
example/91_tile_program/reference_batched_elementwise.hpp create mode 100644 example/91_tile_program/reference_batched_masking.hpp create mode 100644 include/ck/tile_program/block_tile/block_masking_specialization.hpp create mode 100644 include/ck/tile_program/tile/null_tensor.hpp create mode 100644 include/ck/tile_program/tile/null_tile_window.hpp create mode 100644 include/ck/tile_program/tile/tile_fmha_traits.hpp create mode 100644 include/ck/utility/static_switch.hpp diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index f9c01911c..1678a9905 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -1,5 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include #include +#include #include +#include #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -13,6 +19,7 @@ #include "ck/library/utility/fill.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp" @@ -20,19 +27,24 @@ #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" +#include "ck/tile_program/block_tile/block_masking_specialization.hpp" #include "ck/tile_program/tile/tile_fmha_shape.hpp" +#include "ck/tile_program/tile/tile_fmha_traits.hpp" +#include "reference_batched_elementwise.hpp" #include "reference_batched_gemm.hpp" +#include "reference_batched_masking.hpp" #include "reference_batched_softmax.hpp" #include "fmha_fwd_kernel.hpp" #include "fmha_fwd_tile_partitioner.hpp" #include "fmha_fwd_epilogue.hpp" +#include "fmha_utils.hpp" #include "arg_parser.hpp" -#include using QDataType = ck::half_t; using KDataType = ck::half_t; using VDataType = ck::half_t; +using BiasDataType = ck::half_t; using SaccDataType = float; // data type for first gemm accumulation using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck::half_t; // data type for A matrix of second gemm @@ -45,70 +57,95 @@ using ODataType = ck::half_t; using VLayout = ck::tensor_layout::gemm::RowMajor; // (bs, nhead) seqlen * hdim // using VLayout = ck::tensor_layout::gemm::ColumnMajor; // (bs, nhead) hdim * seqlen -using FmhaBlockTileHdim64 = ck::Sequence<128, 64, 32, 64, 32, 64>; -using FmhaBlockTileHdim128 = ck::Sequence<128, 128, 32, 128, 32, 128>; -using FmhaBlockWarps = ck::Sequence<4, 1, 1>; -using FmhaWarpTile = ck::Sequence<32, 32, 16>; -using FmhaShapeHDim64 = ck::tile_program::TileFmhaShape; -using FmhaShapeHDim128 = ck::tile_program::TileFmhaShape; - -using FmhaTilePartitionerHDim64 = FmhaFwdTilePartitioner; -using FmhaTilePartitionerHDim128 = FmhaFwdTilePartitioner; -using FmhaPipelineProblemHDim64 = - ck::tile_program::block::BlockFmhaPipelineProblem; -using FmhaPipelineProblemHDim128 = +template +struct FmhaBlockTile; + +template <> +struct FmhaBlockTile : ck::Sequence<128, 64, 32, 64, 32, 64> +{ +}; +template <> +struct FmhaBlockTile : ck::Sequence<128, 128, 32, 128, 32, 128> +{ +}; +using FmhaBlockWarps = ck::Sequence<4, 1, 1>; +using FmhaWarpTile = 
ck::Sequence<32, 32, 16>; + +template +struct FmhaShape; + +template <> +struct FmhaShape : ck::tile_program::TileFmhaShape, + FmhaBlockWarps, + FmhaWarpTile, + FmhaBlockWarps, + FmhaWarpTile, + VLayout> +{ +}; +template <> +struct FmhaShape + : ck::tile_program::TileFmhaShape, + FmhaBlockWarps, + FmhaWarpTile, + FmhaBlockWarps, + FmhaWarpTile, + VLayout> +{ +}; + +// using FmhaMask = ck::tile_program::block::MaskUpperTriangleFromTopLeftPredicate; +// using FmhaMask = ck::tile_program::block::MaskUpperTriangleFromBottomRightPredicate; +using FmhaMask = ck::tile_program::block::MaskDisabledPredicate; + +inline constexpr bool kM0NeedPadding = false; +inline constexpr bool kN0K1NeedPadding = false; +template +using FmhaTraits = ck::tile_program::TileFmhaTraits; + +template +using FmhaTilePartitioner = FmhaFwdTilePartitioner>; + +template +using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem; + /* BlockSize = */ 256, + FmhaShape, + kIsGroupMode, + FmhaMask, + FmhaTraits>; -using FmhaPipelineHDim64 = - ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync; -using FmhaPipelineHDim128 = - ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync; +template +using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync< + FmhaPipelineProblem>; -// using FmhaPipelineHDim64 = -// ck::tile_program::block::BlockFmhaPipelineQRKSVS; -// using FmhaPipelineHDim128 = -// ck::tile_program::block::BlockFmhaPipelineQRKSVS; +using FmhaEpilogue = FmhaFwdEpilogue>; -using FmhaEpilogue = FmhaFwdEpilogue>; -using FmhaKernelHDim64 = FmhaFwdKernel; -using FmhaKernelHDim128 = - FmhaFwdKernel; +template +using FmhaKernel = FmhaFwdKernel, + FmhaPipeline, + FmhaEpilogue>; -template +template float invoker_fmha_kernel(const void* q_ptr, const void* k_ptr, const void* v_ptr, + const void* bias_ptr, void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, ck::index_t batch, ck::index_t nhead, ck::index_t nhead_k, @@ -116,62 +153,117 @@ float invoker_fmha_kernel(const void* q_ptr, ck::index_t seqlen_k, ck::index_t hdim_q, ck::index_t hdim_v, + ck::index_t max_seqlen_q, float scale, bool i_perm, bool o_perm, StreamConfig stream_config) { - dim3 kGridSize = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v); - constexpr dim3 kBlockSize = FmhaKernel::BlockSize(); - - constexpr ck::index_t kBlockPerCu = FmhaKernel::kBlockPerCu; - constexpr bool is_v_rowmajor = - ck::is_same_v; + ck::is_same_v; assert(nhead % nhead_k == 0); - // batch * nhead * seqlen * hdim or batch * seqlen * nhead * hdim - auto kargs = FmhaKernel::MakeKargs( - q_ptr, - k_ptr, - v_ptr, - o_ptr, - seqlen_q, // seqlen_q - seqlen_k, // seqlen_k - hdim_q, // hdim_q - hdim_v, // hdim_v - nhead / nhead_k, - scale, - i_perm ? hdim_q : nhead * hdim_q, // stride_q - i_perm ? hdim_q : nhead_k * hdim_q, // stride_k - [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return i_perm ? seqlen_k : nhead_k * seqlen_k; - }(), // stride_v - o_perm ? hdim_v : nhead * hdim_v, // stride_o - i_perm ? seqlen_q * hdim_q : hdim_q, // nhead_stride_q - i_perm ? seqlen_k * hdim_q : hdim_q, // nhead_stride_k - [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * seqlen_k : seqlen_k; - }(), // nhead_stride_v - o_perm ? 
seqlen_q * hdim_v : hdim_v, // nhead_stride_o - nhead * seqlen_q * hdim_q, // batch_stride_q - nhead_k * seqlen_k * hdim_q, // batch_stride_k - nhead_k * hdim_v * seqlen_k, // batch_stride_v - nhead * seqlen_q * hdim_v); // batch_stride_o - - float ave_time = launch_kernel(stream_config, - FmhaKernel{}, - kGridSize, - kBlockSize, - 0, - kargs); // BatchStrideO - return ave_time; + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias' + /// are 0. + // setup stride_* arguments + const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck::index_t stride_v = [&]() { + if constexpr(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_k : nhead_k * seqlen_k; + }(); + const ck::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k); + const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q); + const ck::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q); + const ck::index_t nhead_stride_v = [&]() { + if constexpr(is_v_rowmajor) + return i_perm ? seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_k : seqlen_k; + }(); + const ck::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k); + const ck::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck::index_t batch_stride_q = (nhead * seqlen_q * hdim_q); + const ck::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q); + const ck::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k); + const ck::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k); + const ck::index_t batch_stride_o = (nhead * seqlen_q * hdim_v); + + const auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel_::kIsGroupMode) + { + return FmhaKernel_::MakeKargs(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + nhead / nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o); + } + else + { // create batch mode kernel arguments + return FmhaKernel_::MakeKargs(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + nhead / nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_o); + } + }(); + + const dim3 kGridSize = FmhaKernel_::GridSize(batch, nhead, max_seqlen_q, hdim_v); + constexpr dim3 kBlockSize = FmhaKernel_::BlockSize(); + + constexpr ck::index_t kBlockPerCu = FmhaKernel_::kBlockPerCu; + + return launch_kernel(stream_config, + FmhaKernel_{}, + kGridSize, + kBlockSize, + 0, + kargs); // BatchStrideO } static inline int env_get_int(const char* var_name, int default_int) @@ -187,6 +279,7 @@ auto create_args(int argc, char* argv[]) { ArgParser arg_parser; arg_parser.insert("v", "1", "weather do cpu validation or not") + .insert("mode", "0", "kernel mode. 
0:batch, 1:group") .insert("b", "2", "batch size") .insert("h", "8", "num of head, for q") .insert("h_k", @@ -203,6 +296,7 @@ auto create_args(int argc, char* argv[]) "permute input\n" "if true, will be b*h*s*d, else b*s*h*d") .insert("operm", "1", "permute output") + .insert("bias", "0", "add bias or not") .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float"); bool result = arg_parser.parse(argc, argv); @@ -216,6 +310,7 @@ int main(int argc, char* argv[]) return -1; int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); ck::index_t batch = arg_parser.get_int("b"); ck::index_t nhead = arg_parser.get_int("h"); ck::index_t nhead_k = arg_parser.get_int("h_k"); @@ -244,6 +339,8 @@ int main(int argc, char* argv[]) if(scale == .0f) scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + bool use_bias = arg_parser.get_uint32("bias"); + int init_method = arg_parser.get_int("init"); int stream_warmup = env_get_int("CK_WARMUP", 5); @@ -251,6 +348,36 @@ int main(int argc, char* argv[]) StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat}; + const std::vector seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const std::vector seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + using namespace ck::literals; + + flop += nhead * (2_uz * real_seqlen_q * real_seqlen_k * hdim_q + + 2_uz * real_seqlen_q * hdim_v * real_seqlen_k); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k + + sizeof(ODataType) * real_seqlen_q * hdim_v); + } + } + auto get_lengths = [&](int permute, ck::index_t b /*batch*/, ck::index_t h /*nhead*/, @@ -262,43 +389,61 @@ int main(int argc, char* argv[]) return std::array{b, s, h, d}; }; - constexpr bool is_v_rowmajor = - ck::is_same_v; - - // host verify - Tensor q_host(get_lengths(i_perm, batch, nhead, seqlen_q, hdim_q)); - Tensor k_host(get_lengths(i_perm, batch, nhead_k, seqlen_k, hdim_q)); - Tensor v_host(is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_k, hdim_v) - : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_k)); - Tensor o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v)); + constexpr bool is_v_rowmajor = ck::is_same_v; + + // host memory for storing all the tensor elements + const ck::index_t shape_batch = (mode == Mode::Batch ? batch : 1); + const ck::index_t shape_seqlen_q = (mode == Mode::Batch ? seqlen_q : seqstart_q_host.back()); + const ck::index_t shape_seqlen_k = (mode == Mode::Batch ? seqlen_k : seqstart_k_host.back()); + + Tensor q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + Tensor k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + Tensor v_host( + is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); + // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. 
if use_bias=false, the bias_host + // will not be used for verification at all (but will be copied to device anyway). + Tensor bias_host( + use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); if(init_method == 0) { ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(bias_host); } else if(init_method == 1) { ck::utils::FillUniformDistribution{0.f, 1.f}(q_host); ck::utils::FillUniformDistribution{0.f, 1.f}(k_host); ck::utils::FillUniformDistribution{-.5f, .5f}(v_host); + ck::utils::FillUniformDistribution{0.f, 1.f}(bias_host); } else if(init_method == 2) { ck::utils::FillTrigValue{}(q_host); ck::utils::FillTrigValue{}(k_host); ck::utils::FillTrigValue{}(v_host); + ck::utils::FillTrigValue{}(bias_host); } - DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize()); - DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize()); - DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize()); - DeviceMem o_buf(sizeof(ODataType) * o_host.GetElementSpaceSize()); - - q_buf.ToDevice(q_host.mData.data()); - k_buf.ToDevice(k_host.mData.data()); - v_buf.ToDevice(v_host.mData.data()); + DeviceMem q_buf(q_host.GetElementSpaceSizeInBytes()); + DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes()); + DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes()); + DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes()); + DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes()); + DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); // clang-format off auto layout_str = [&](int permute){ @@ -307,123 +452,171 @@ int main(int argc, char* argv[]) }; // clang-format on - std::cout << "b:" << batch << ", h:" << nhead << ", h_k:" << nhead_k << ", s:" << seqlen_q - << ", s_k:" << seqlen_k << ", d:" << hdim_q << ", d_v:" << hdim_v - << ", scale:" << scale << ", i:" << layout_str(i_perm) << ", o:" << layout_str(o_perm) - << ", v:" << std::string(FmhaKernelHDim64::VLayout::name)[0] << std::flush; + std::cout << "[" << mode << "] b:" << batch << ", h:" << nhead << ", h_k:" << nhead_k + << ", s:" << seqlen_q << ", s_k:" << seqlen_k << ", d:" << hdim_q + << ", d_v:" << hdim_v << ", scale:" << scale << ", i:" << layout_str(i_perm) + << ", o:" << layout_str(o_perm) << ", bias:" << use_bias + << ", v:" << std::string(VLayout::name)[0] << std::flush; float ave_time = 0; if(hdim_q == hdim_v && hdim_q == 64) - ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - scale, - i_perm, - o_perm, - stream_config); + { + BOOL_SWITCH_2(mode == Mode::Group, kIsGroupMode, use_bias, kHasBias, [&] { + using Kernel = FmhaKernel; + + ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + 
seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + batch, + nhead, + nhead_k, + shape_seqlen_q, + shape_seqlen_k, + hdim_q, + hdim_v, + max_seqlen_q, + scale, + i_perm, + o_perm, + stream_config); + }); + } else if(hdim_q == hdim_v && hdim_q == 128) - ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - scale, - i_perm, - o_perm, - stream_config); + { + BOOL_SWITCH_2(mode == Mode::Group, kIsGroupMode, use_bias, kHasBias, [&] { + using Kernel = FmhaKernel; + + ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + batch, + nhead, + nhead_k, + shape_seqlen_q, + shape_seqlen_k, + hdim_q, + hdim_v, + max_seqlen_q, + scale, + i_perm, + o_perm, + stream_config); + }); + } else { - std::cout << "not support hdim, will not run" << std::endl; + std::cerr << "not support hdim, will not run" << std::endl; return -1; } - std::size_t flop = std::size_t(2) * batch * nhead * seqlen_q * seqlen_k * hdim_q + - std::size_t(2) * batch * nhead * seqlen_q * hdim_v * seqlen_k; - - // TODO: MQA/GQA case nhead is smaller, do we need to change this formular? - std::size_t num_btype = sizeof(QDataType) * batch * nhead * seqlen_q * hdim_q + - sizeof(KDataType) * batch * nhead * seqlen_k * hdim_q + - sizeof(VDataType) * batch * nhead * hdim_v * seqlen_k + - sizeof(ODataType) * batch * nhead * seqlen_q * hdim_v; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << ", " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::flush << std::endl; + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush << std::endl; if(do_validation) { - Tensor q_host_ref({batch * nhead, seqlen_q, hdim_q}); - Tensor k_host_ref( - {batch * nhead, seqlen_k, hdim_q}); // NOTE: expand nhead the same as q - const auto v_lengths = std::array{batch * nhead, hdim_v, seqlen_k}; - const auto v_strides = is_v_rowmajor - ? 
std::array{hdim_v * seqlen_k, 1, hdim_v} - : std::array{hdim_v * seqlen_k, seqlen_k, 1}; - Tensor v_host_ref(v_lengths, v_strides); - Tensor o_host_ref({batch * nhead, seqlen_q, hdim_v}); - Tensor o_host_result_ref(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v)); - - Tensor s_host_ref({batch * nhead, seqlen_q, seqlen_k}); - Tensor p_host_ref({batch * nhead, seqlen_q, seqlen_k}); - - ck::index_t nr = nhead / nhead_k; - -#define EACH_R for(ck::index_t r = 0; r < nr; r++) - // clang-format off - // permute - if(i_perm) q_host.ForEach([&](auto& self, auto i) { q_host_ref(i[0] * nhead + i[1], i[2], i[3]) = self(i); }); - else q_host.ForEach([&](auto& self, auto i) { q_host_ref(i[0] * nhead + i[2], i[1], i[3]) = self(i); }); - - if(i_perm) k_host.ForEach([&](auto& self, auto i) { EACH_R k_host_ref(i[0] * nhead + i[1] * nr + r, i[2], i[3]) = self(i); }); - else k_host.ForEach([&](auto& self, auto i) { EACH_R k_host_ref(i[0] * nhead + i[2] * nr + r, i[1], i[3]) = self(i); }); - - if constexpr (is_v_rowmajor) { - // v_host :b, h, s, d, v_host_ref : batch*hdim*seq - if(i_perm) v_host.ForEach([&](auto& self, auto i) { EACH_R v_host_ref(i[0] * nhead + i[1] * nr + r, i[3], i[2]) = self(i); }); - // v_host : b, s, h, d, v_host_ref : batch*hdim*seq - else v_host.ForEach([&](auto& self, auto i) { EACH_R v_host_ref(i[0] * nhead + i[2] * nr + r, i[3], i[1]) = self(i); }); - } - else { - if(i_perm) v_host.ForEach([&](auto& self, auto i) { EACH_R v_host_ref(i[0] * nhead + i[1] * nr + r, i[2], i[3]) = self(i); }); - else v_host.ForEach([&](auto& self, auto i) { EACH_R v_host_ref(i[0] * nhead + i[2] * nr + r, i[1], i[3]) = self(i); }); + o_buf.FromDevice(o_host.data()); + + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck::index_t b = (mode == Mode::Batch ? wb : 0); + const ck::index_t query_offset = (mode == Mode::Batch ? 0 : seqstart_q_host[wb]); + const ck::index_t key_offset = (mode == Mode::Batch ? 0 : seqstart_k_host[wb]); + + const auto v_host_ref_lengths = + std::array{nhead, hdim_v, real_seqlen_k}; + const auto v_host_ref_strides = + is_v_rowmajor + ? 
std::array{hdim_v * real_seqlen_k, 1, hdim_v} + : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; + + Tensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + Tensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + Tensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); + Tensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + Tensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + Tensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + + ck::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + if constexpr (is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); + } + + // reference + reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, + ck::identity{}, ck::identity{}, + [&](SaccDataType x) { return scale * x; }); + + if(use_bias) + { + Tensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, real_seqlen_k] + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + + reference_batched_masking(s_host_ref); + reference_batched_softmax(s_host_ref, p_host_ref); + reference_batched_gemm(p_host_ref, v_host_ref, o_host_ref); + + Tensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + if(!ck::utils::check_err(o_host_result, o_host_ref)) + { + std::cerr << "mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + return -1; + } } -#undef EACH_R - - // reference - reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, - [](const QDataType& x) { return x; }, - [](const KDataType& x) { return x; }, - [&scale](const SaccDataType& x) { return scale * x; }); - reference_batched_softmax(s_host_ref, - p_host_ref); - reference_batched_gemm( - p_host_ref, v_host_ref, o_host_ref); - - // permute - if(o_perm) 
o_host_result_ref.ForEach([&](auto& self, auto i) { self(i) = o_host_ref(i[0] * nhead + i[1], i[2], i[3]); }); - else o_host_result_ref.ForEach([&](auto& self, auto i) { self(i) = o_host_ref(i[0] * nhead + i[2], i[1], i[3]); }); - // clang-format on - - o_buf.FromDevice(o_host.mData.data()); - return !ck::utils::check_err(o_host, o_host_result_ref); } else { diff --git a/example/91_tile_program/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha_fwd_kernel.hpp index 87faeeae4..ca8d0930b 100644 --- a/example/91_tile_program/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha_fwd_kernel.hpp @@ -3,15 +3,21 @@ #pragma once +#include + #include "ck/utility/common_header.hpp" #include "ck/tensor/tensor_view.hpp" #include "ck/tile_program/tile/tile_window.hpp" // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] // P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] +#ifndef C_LOG2E #define C_LOG2E 1.44269504088896340736 // log2(e) +#endif template struct FmhaFwdKernel @@ -22,19 +28,77 @@ struct FmhaFwdKernel static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; - using QDataType = ck::remove_cvref_t; - using KDataType = ck::remove_cvref_t; - using VDataType = ck::remove_cvref_t; - using ODataType = ck::remove_cvref_t; + using QDataType = ck::remove_cvref_t; + using KDataType = ck::remove_cvref_t; + using VDataType = ck::remove_cvref_t; + using BiasDataType = ck::remove_cvref_t; + using ODataType = ck::remove_cvref_t; using VLayout = ck::remove_cvref_t; - struct Kargs + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kM0NeedPadding = FmhaPipeline::kM0NeedPadding; + static constexpr bool kN0K1NeedPadding = FmhaPipeline::kN0K1NeedPadding; + static constexpr bool kHasBias = FmhaPipeline::kHasBias; + + using C0MatrixMask = ck::tile_program::block::C0MatrixMask_impl< + ck::remove_cvref_t>; + + private: + struct EmptyKargs + { + }; + + struct CommonKargs { - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - void* o_ptr; + __host__ constexpr CommonKargs(const void* q_ptr_, + const void* k_ptr_, + const void* v_ptr_, + void* o_ptr_, + ck::index_t seqlen_q_, + ck::index_t seqlen_k_, + ck::index_t hdim_q_, + ck::index_t hdim_v_, + ck::index_t nhead_ratio_qk_, + float scale_, + ck::index_t stride_q_, + ck::index_t stride_k_, + ck::index_t stride_v_, + ck::index_t stride_o_, + ck::index_t nhead_stride_q_, + ck::index_t nhead_stride_k_, + ck::index_t nhead_stride_v_, + ck::index_t nhead_stride_o_) + : q_ptr{reinterpret_cast(q_ptr_)}, + k_ptr{reinterpret_cast(k_ptr_)}, + v_ptr{reinterpret_cast(v_ptr_)}, + o_ptr{reinterpret_cast(o_ptr_)}, + seqlen_q{seqlen_q_}, + seqlen_k{seqlen_k_}, + hdim_q{hdim_q_}, + hdim_v{hdim_v_}, + nhead_ratio_qk{nhead_ratio_qk_}, +#if CK_FMHA_FWD_FAST_EXP2 + scale{static_cast(scale_ * C_LOG2E)}, +#else + scale{scale_}, +#endif + stride_q{stride_q_}, + stride_k{stride_k_}, + stride_v{stride_v_}, + stride_o{stride_o_}, + nhead_stride_q{nhead_stride_q_}, + nhead_stride_k{nhead_stride_k_}, + nhead_stride_v{nhead_stride_v_}, + nhead_stride_o{nhead_stride_o_} + { + } + + const QDataType* q_ptr; + const KDataType* k_ptr; + const VDataType* v_ptr; + ODataType* o_ptr; + ck::index_t seqlen_q; ck::index_t seqlen_k; ck::index_t 
hdim_q; @@ -54,6 +118,69 @@ struct FmhaFwdKernel ck::index_t nhead_stride_k; ck::index_t nhead_stride_v; ck::index_t nhead_stride_o; + }; + + struct CommonBiasKargs + { + const BiasDataType* bias_ptr = nullptr; + ck::index_t stride_bias = 0; + ck::index_t nhead_stride_bias = 0; + }; + + struct BatchModeBiasKargs : CommonBiasKargs + { + ck::index_t batch_stride_bias = 0; + }; + + struct BatchModeKargs : CommonKargs, + std::conditional_t + { + __host__ constexpr BatchModeKargs(const void* q_ptr_, + const void* k_ptr_, + const void* v_ptr_, + void* o_ptr_, + ck::index_t seqlen_q_, + ck::index_t seqlen_k_, + ck::index_t hdim_q_, + ck::index_t hdim_v_, + ck::index_t nhead_ratio_qk_, + float scale_, + ck::index_t stride_q_, + ck::index_t stride_k_, + ck::index_t stride_v_, + ck::index_t stride_o_, + ck::index_t nhead_stride_q_, + ck::index_t nhead_stride_k_, + ck::index_t nhead_stride_v_, + ck::index_t nhead_stride_o_, + ck::index_t batch_stride_q_, + ck::index_t batch_stride_k_, + ck::index_t batch_stride_v_, + ck::index_t batch_stride_o_) + : CommonKargs{q_ptr_, + k_ptr_, + v_ptr_, + o_ptr_, + seqlen_q_, + seqlen_k_, + hdim_q_, + hdim_v_, + nhead_ratio_qk_, + scale_, + stride_q_, + stride_k_, + stride_v_, + stride_o_, + nhead_stride_q_, + nhead_stride_k_, + nhead_stride_v_, + nhead_stride_o_}, + batch_stride_q{batch_stride_q_}, + batch_stride_k{batch_stride_k_}, + batch_stride_v{batch_stride_v_}, + batch_stride_o{batch_stride_o_} + { + } ck::index_t batch_stride_q; ck::index_t batch_stride_k; @@ -61,41 +188,156 @@ struct FmhaFwdKernel ck::index_t batch_stride_o; }; - __host__ static constexpr Kargs MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - void* o_ptr, - ck::index_t seqlen_q, - ck::index_t seqlen_k, - ck::index_t hdim_q, - ck::index_t hdim_v, - ck::index_t nhead_ratio_qk, - float scale, - ck::index_t stride_q, - ck::index_t stride_k, - ck::index_t stride_v, - ck::index_t stride_o, - ck::index_t nhead_stride_q, - ck::index_t nhead_stride_k, - ck::index_t nhead_stride_v, - ck::index_t nhead_stride_o, - ck::index_t batch_stride_q, - ck::index_t batch_stride_k, - ck::index_t batch_stride_v, - ck::index_t batch_stride_o) + struct GroupModeKargs : CommonKargs, std::conditional_t { - return Kargs + __host__ constexpr GroupModeKargs(const void* q_ptr_, + const void* k_ptr_, + const void* v_ptr_, + void* o_ptr_, + const void* seqstart_q_ptr_, + const void* seqstart_k_ptr_, + const void* seqlen_k_ptr_, + ck::index_t hdim_q_, + ck::index_t hdim_v_, + ck::index_t nhead_ratio_qk_, + float scale_, + ck::index_t stride_q_, + ck::index_t stride_k_, + ck::index_t stride_v_, + ck::index_t stride_o_, + ck::index_t nhead_stride_q_, + ck::index_t nhead_stride_k_, + ck::index_t nhead_stride_v_, + ck::index_t nhead_stride_o_) + : CommonKargs{q_ptr_, + k_ptr_, + v_ptr_, + o_ptr_, + -1 /* will be updated inside the kernel */, + -1 /* will be updated inside the kernel */, + hdim_q_, + hdim_v_, + nhead_ratio_qk_, + scale_, + stride_q_, + stride_k_, + stride_v_, + stride_o_, + nhead_stride_q_, + nhead_stride_k_, + nhead_stride_v_, + nhead_stride_o_}, + seqstart_q_ptr{reinterpret_cast(seqstart_q_ptr_)}, + seqstart_k_ptr{reinterpret_cast(seqstart_k_ptr_)}, + seqlen_k_ptr{reinterpret_cast(seqlen_k_ptr_)} { - q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, seqlen_k, hdim_q, hdim_v, nhead_ratio_qk, -#if CK_FMHA_FWD_FAST_EXP2 - static_cast(scale * C_LOG2E), -#else - scale, -#endif - stride_q, stride_k, stride_v, stride_o, nhead_stride_q, nhead_stride_k, - nhead_stride_v, nhead_stride_o, batch_stride_q, 
batch_stride_k, batch_stride_v, - batch_stride_o - }; + } + + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + public: + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + ck::index_t seqlen_q, + ck::index_t seqlen_k, + ck::index_t hdim_q, + ck::index_t hdim_v, + ck::index_t nhead_ratio_qk, + float scale, + ck::index_t stride_q, + ck::index_t stride_k, + ck::index_t stride_v, + ck::index_t stride_bias, + ck::index_t stride_o, + ck::index_t nhead_stride_q, + ck::index_t nhead_stride_k, + ck::index_t nhead_stride_v, + ck::index_t nhead_stride_bias, + ck::index_t nhead_stride_o, + ck::index_t batch_stride_q, + ck::index_t batch_stride_k, + ck::index_t batch_stride_v, + ck::index_t batch_stride_bias, + ck::index_t batch_stride_o) + { + Kargs kargs{q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, + seqlen_k, hdim_q, hdim_v, nhead_ratio_qk, scale, + stride_q, stride_k, stride_v, stride_o, nhead_stride_q, + nhead_stride_k, nhead_stride_v, nhead_stride_o, batch_stride_q, batch_stride_k, + batch_stride_v, batch_stride_o}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = reinterpret_cast(bias_ptr); + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck::index_t hdim_q, + ck::index_t hdim_v, + ck::index_t nhead_ratio_qk, + float scale, + ck::index_t stride_q, + ck::index_t stride_k, + ck::index_t stride_v, + ck::index_t stride_bias, + ck::index_t stride_o, + ck::index_t nhead_stride_q, + ck::index_t nhead_stride_k, + ck::index_t nhead_stride_v, + ck::index_t nhead_stride_bias, + ck::index_t nhead_stride_o) + { + Kargs kargs{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + nhead_ratio_qk, + scale, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = reinterpret_cast(bias_ptr); + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + + return kargs; } __host__ static constexpr auto GridSize(ck::index_t batch_size_, @@ -129,57 +371,163 @@ struct FmhaFwdKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(ck::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(kHasBias) + { + batch_offset_bias = query_start * kargs.stride_bias + key_start; + } + else + { + batch_offset_bias = 
key_start; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(kHasBias) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + - i_nhead * kargs.nhead_stride_q + i_batch * kargs.batch_stride_q; - const KDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + - (i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + - i_batch * kargs.batch_stride_k; - const VDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + - (i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + - i_batch * kargs.batch_stride_v; - ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + - i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o; + const QDataType* q_ptr = kargs.q_ptr + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + kargs.k_ptr + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + kargs.v_ptr + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = kargs.o_ptr + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; // Q/K/V DRAM and DRAM window - const auto q_dram = make_naive_tensor_view( - q_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_q), - make_tuple(kargs.stride_q, 1), - Number<32>{}, - Number<1>{}); - - const auto k_dram = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.seqlen_k, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - Number<32>{}, - Number<1>{}); + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + Number<32>{}, + Number<1>{}); + + return pad_tensor_view(q_dram_naive, + make_tuple(Number{}, Number<1>{}), + Sequence{}); + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + Number<32>{}, + Number<1>{}); + return pad_tensor_view(k_dram_naive, + make_tuple(Number{}, Number<1>{}), + Sequence{}); + }(); const auto v_dram = [&]() { if constexpr(ck::is_same_v) { - const auto v_dram_tmp = make_naive_tensor_view( + const auto v_dram_naive = make_naive_tensor_view( v_ptr, make_tuple(kargs.seqlen_k, kargs.hdim_v), make_tuple(kargs.stride_v, 1), Number<32>{}, Number<1>{}); - return transform_tensor_view( - v_dram_tmp, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - 
make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.seqlen_k), + make_pass_through_transform(kargs.hdim_v)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + /// FIXME: The return value of v_dram_naive.GetTensorDescriptor().GetLength() is + /// same as + /// v_dram_transposed.GetTensorDescriptor().GetLength(). Replace following + /// if-clause by pad_tensor_view() call after fixing this issue. + if constexpr(kN0K1NeedPadding) + { + const index_t pad_length = + FmhaPipeline::kK1 * + ck::math::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kK1) - + kargs.seqlen_k; + + return transform_tensor_view( + v_dram_transposed, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_right_pad_transform(kargs.seqlen_k, pad_length)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + return v_dram_transposed; + } } else { - return make_naive_tensor_view( + const auto v_dram_naive = make_naive_tensor_view( v_ptr, make_tuple(kargs.hdim_v, kargs.seqlen_k), make_tuple(kargs.stride_v, 1), Number<32>{}, Number<1>{}); + + return pad_tensor_view(v_dram_naive, + make_tuple(Number<1>{}, Number{}), + Sequence{}); } }(); @@ -201,22 +549,65 @@ struct FmhaFwdKernel make_tile_window(v_dram, make_tuple(Number{}, Number{}), {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables is not supported. Remove + /// following copy capture of the 'i_nhead' + /// if compiled in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(Number{}, Number{}); + if constexpr(kHasBias) + { + const BiasDataType* bias_ptr = + kargs.bias_ptr + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; - auto o_acc_tile = FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - kargs.scale, - kargs.seqlen_k / FmhaPipeline::kN0, - kargs.hdim_q / FmhaPipeline::kK0, - smem_ptr); + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + Number<32>{}, + Number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + Sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + C0MatrixMask casual_mask{kargs.seqlen_q, kargs.seqlen_k}; + + auto o_acc_tile = + FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + casual_mask, + kargs.scale, + ck::math::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0), + ck::math::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0), + smem_ptr); // O DRAM and O DRAM window - auto o_dram = make_naive_tensor_view( - o_ptr, - make_tuple(kargs.seqlen_q, kargs.hdim_v), - make_tuple(kargs.stride_o, 1), - Number<32>{}, - Number<1>{}); + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + Number<32>{}, + Number<1>{}); + + return pad_tensor_view(o_dram_naive, + make_tuple(Number{}, Number<1>{}), + Sequence{}); + }(); auto o_dram_window = make_tile_window(o_dram, diff --git a/example/91_tile_program/fmha_utils.hpp b/example/91_tile_program/fmha_utils.hpp new file mode 100644 index 
000000000..e6bdc4a49 --- /dev/null +++ b/example/91_tile_program/fmha_utils.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#pragma once + +enum class Mode : unsigned +{ + Batch, + Group +}; + +inline std::ostream& operator<<(std::ostream& stream, Mode mode) +{ + return stream << (mode == Mode::Batch ? "batch" : "group"); +} + +/// TODO: make sure result is valid for MaskUpperTriangleFromBottomRightPredicate +std::vector generate_seqstarts(Mode mode, + unsigned count, + int32_t seqlens_sum, + std::optional seed = std::nullopt) +{ + assert(0 < count); + + const std::vector seqlens = [&]() { + std::vector original_seqlens(count, seqlens_sum); + + if(mode == Mode::Group && 1 < count) + { + using size_type = std::vector::size_type; + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + if(original_seqlens[to_decrease] == 1) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + --original_seqlens[to_decrease]; + ++original_seqlens[to_increase]; + } + } + + return original_seqlens; + }(); + + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + return seqstarts; +} diff --git a/example/91_tile_program/reference_batched_elementwise.hpp b/example/91_tile_program/reference_batched_elementwise.hpp new file mode 100644 index 000000000..cf5beec2d --- /dev/null +++ b/example/91_tile_program/reference_batched_elementwise.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template > +void reference_batched_elementwise(const Tensor& a_b_m_n, + const Tensor& b_b_m_n, + Tensor& c_b_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const BinaryElementOp& binary_element_op = {}) +{ + const ck::index_t N = c_b_m_n.mDesc.GetLengths()[2]; + + const bool broadcast_a_dim_b = (a_b_m_n.GetLengths()[0] == 1); + const bool broadcast_a_dim_m = (a_b_m_n.GetLengths()[1] == 1); + const bool broadcast_a_dim_n = (a_b_m_n.GetLengths()[2] == 1); + + const bool broadcast_b_dim_b = (b_b_m_n.GetLengths()[0] == 1); + const bool broadcast_b_dim_m = (b_b_m_n.GetLengths()[1] == 1); + const bool broadcast_b_dim_n = (b_b_m_n.GetLengths()[2] == 1); + + auto f = [&](auto batch, auto m) { + for(ck::index_t n = 0; n < N; ++n) + { + AccDataType v_a{}; + { + ck::index_t i_b = (broadcast_a_dim_b ? 0 : batch); + ck::index_t i_m = (broadcast_a_dim_m ? 0 : m); + ck::index_t i_n = (broadcast_a_dim_n ? 0 : n); + + v_a = ck::type_convert(a_element_op(a_b_m_n(i_b, i_m, i_n))); + } + + AccDataType v_b{}; + { + ck::index_t i_b = (broadcast_b_dim_b ? 0 : batch); + ck::index_t i_m = (broadcast_b_dim_m ? 0 : m); + ck::index_t i_n = (broadcast_b_dim_n ? 
0 : n); + + v_b = ck::type_convert(b_element_op(b_b_m_n(i_b, i_m, i_n))); + } + + c_b_m_n(batch, m, n) = ck::type_convert(binary_element_op(v_a, v_b)); + } + }; + + make_ParallelTensorFunctor(f, c_b_m_n.mDesc.GetLengths()[0], c_b_m_n.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); +} diff --git a/example/91_tile_program/reference_batched_gemm.hpp b/example/91_tile_program/reference_batched_gemm.hpp index a29af3e30..f4e03fcc9 100644 --- a/example/91_tile_program/reference_batched_gemm.hpp +++ b/example/91_tile_program/reference_batched_gemm.hpp @@ -10,15 +10,15 @@ template + typename AElementOp = ck::identity, + typename BElementOp = ck::identity, + typename ACCElementOp = ck::identity> void reference_batched_gemm(const Tensor& a_b_m_k, const Tensor& b_b_n_k, Tensor& c_b_m_n, - const AElementOp& a_element_op, - const BElementOp& b_element_op, - const ACCElementOp& acc_element_op) + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const int N = b_b_n_k.mDesc.GetLengths()[1]; const int K = b_b_n_k.mDesc.GetLengths()[2]; @@ -43,17 +43,3 @@ void reference_batched_gemm(const Tensor& a_b_m_k, make_ParallelTensorFunctor(f, c_b_m_n.mDesc.GetLengths()[0], c_b_m_n.mDesc.GetLengths()[1])( std::thread::hardware_concurrency()); } - -template -void reference_batched_gemm(const Tensor& a_b_m_k, - const Tensor& b_b_n_k, - Tensor& c_b_m_n) -{ - reference_batched_gemm( - a_b_m_k, - b_b_n_k, - c_b_m_n, - [](const ADataType& x) { return x; }, - [](const BDataType& x) { return x; }, - [](const AccDataType& x) { return x; }); -} diff --git a/example/91_tile_program/reference_batched_masking.hpp b/example/91_tile_program/reference_batched_masking.hpp new file mode 100644 index 000000000..3351dcd4f --- /dev/null +++ b/example/91_tile_program/reference_batched_masking.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
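//
// Host-side counterpart of the device mask predicates: elements with n > m (top-left
// causal mask) or n > m - (M - N) (bottom-right causal mask) are overwritten with -inf
// before the softmax reference, mirroring what the kernel does on the S tile.
//
// A minimal sketch of the host verification chain built from these reference helpers
// (variable names follow fmha_fwd.cpp; template arguments are elided):
//
//   // S = scale * Q * K^T
//   reference_batched_gemm(q_host_ref, k_host_ref, s_host_ref,
//                          ck::identity{}, ck::identity{},
//                          [&](SaccDataType x) { return scale * x; });
//   // optional: S += bias, broadcast from [1, seqlen_q, seqlen_k]
//   reference_batched_elementwise(s_host_ref, bias_host_ref, s_host_ref);
//   // apply the mask, then P = softmax(S) and O = P * V
//   reference_batched_masking(s_host_ref);
//   reference_batched_softmax(s_host_ref, p_host_ref);
//   reference_batched_gemm(p_host_ref, v_host_ref, o_host_ref);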
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/tile_program/block_tile/block_masking_specialization.hpp" + +template +void reference_batched_masking(Tensor& c_b_m_n) +{ + const int M = c_b_m_n.mDesc.GetLengths()[1]; + const int N = c_b_m_n.mDesc.GetLengths()[2]; + + const int MNDiff = M - N; + + auto f = [&](auto batch) { + for(int n = 0; n < N; ++n) + { + for(int m = 0; m < M; ++m) + { + if constexpr(std::is_same_v< + MaskingType, + ck::tile_program::block::MaskUpperTriangleFromTopLeftPredicate>) + { + if(n > m) + { + c_b_m_n(batch, m, n) = -ck::NumericLimits::Infinity(); + } + } + else if constexpr(std::is_same_v) + { + if(n > m - MNDiff) + { + c_b_m_n(batch, m, n) = -ck::NumericLimits::Infinity(); + } + } + } + } + }; + + make_ParallelTensorFunctor(f, + c_b_m_n.mDesc.GetLengths()[0])(std::thread::hardware_concurrency()); +} diff --git a/example/91_tile_program/reference_gemm.hpp b/example/91_tile_program/reference_gemm.hpp index a558e5719..1972214b9 100644 --- a/example/91_tile_program/reference_gemm.hpp +++ b/example/91_tile_program/reference_gemm.hpp @@ -6,10 +6,19 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/host_tensor.hpp" -template +template void reference_gemm(const Tensor& a_m_k, const Tensor& b_n_k, - Tensor& c_m_n) + Tensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const int N = b_n_k.mDesc.GetLengths()[0]; const int K = b_n_k.mDesc.GetLengths()[1]; @@ -21,13 +30,13 @@ void reference_gemm(const Tensor& a_m_k, for(int k = 0; k < K; ++k) { - ADataType v_a = a_m_k(m, k); - BDataType v_b = b_n_k(n, k); + ADataType v_a = a_element_op(a_m_k(m, k)); + BDataType v_b = b_element_op(b_n_k(n, k)); v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); } - c_m_n(m, n) = ck::type_convert(v_acc); + c_m_n(m, n) = ck::type_convert(acc_element_op(v_acc)); } }; diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index 3e44faecb..ca1f91ad8 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -4,6 +4,8 @@ #pragma once #include +#include + #include // To be removed, which really does not tell the location of failed HIP functional call diff --git a/include/ck/host_utility/io.hpp b/include/ck/host_utility/io.hpp index 55734bab2..7578537be 100644 --- a/include/ck/host_utility/io.hpp +++ b/include/ck/host_utility/io.hpp @@ -13,15 +13,33 @@ template std::ostream& operator<<(std::ostream& os, const std::vector& v) { - std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); - return os; + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; } template std::ostream& operator<<(std::ostream& os, const std::array& v) { - std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); - return os; + os << "["; + for(std::size_t idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; } template diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index e0d6e32ec..a58eca557 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -54,11 +54,11 @@ float launch_and_time_kernel(const StreamConfig& stream_config, 
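// Timing model (illustrative summary): a hipEvent_t pair brackets nrepeat kernel
// launches and the per-launch average, total_time / nrepeat, is returned in
// milliseconds. With illustrative numbers, total_time = 12.6 ms over nrepeat = 10
// gives 1.26 ms per launch; fmha_fwd.cpp then reports TFLOPS as flop / 1.E9 / ave_time
// and GB/s as num_byte / 1.E6 / ave_time.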
#endif hipEvent_t start, stop; - hip_check_error(hipEventCreate(&start)); - hip_check_error(hipEventCreate(&stop)); + HIP_CHECK_ERROR(hipEventCreate(&start)); + HIP_CHECK_ERROR(hipEventCreate(&stop)); - hip_check_error(hipDeviceSynchronize()); - hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start, stream_config.stream_id_)); for(int i = 0; i < nrepeat; ++i) { @@ -66,12 +66,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config, hip_check_error(hipGetLastError()); } - hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); - hip_check_error(hipEventSynchronize(stop)); + HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_)); + HIP_CHECK_ERROR(hipEventSynchronize(stop)); float total_time = 0; - hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); return total_time / nrepeat; } @@ -125,11 +125,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #endif hipEvent_t start, stop; - hip_check_error(hipEventCreate(&start)); - hip_check_error(hipEventCreate(&stop)); + HIP_CHECK_ERROR(hipEventCreate(&start)); + HIP_CHECK_ERROR(hipEventCreate(&stop)); - hip_check_error(hipDeviceSynchronize()); - hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start, stream_config.stream_id_)); for(int i = 0; i < nrepeat; ++i) { @@ -138,12 +138,12 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, hip_check_error(hipGetLastError()); } - hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); - hip_check_error(hipEventSynchronize(stop)); + HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_)); + HIP_CHECK_ERROR(hipEventSynchronize(stop)); float total_time = 0; - hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); return total_time / nrepeat; } diff --git a/include/ck/tensor/tensor_view.hpp b/include/ck/tensor/tensor_view.hpp index da1c544a9..03a8fcadf 100644 --- a/include/ck/tensor/tensor_view.hpp +++ b/include/ck/tensor/tensor_view.hpp @@ -122,6 +122,11 @@ struct TensorView TensorDesc desc_; }; +// placeholder type if we want to opt-out a tile view parameter +struct NullTensorView +{ +}; + template @@ -192,4 +197,47 @@ __host__ __device__ constexpr auto transform_tensor_view(const OldTensorView& ol old_tensor_view.buf_, new_desc}; } +template + typename DoPads> // Sequence +__host__ __device__ constexpr auto +pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads) +{ + constexpr index_t num_dim = DoPads::Size(); + + static_assert(num_dim == TileLengths::Size() && num_dim == TensorView::GetNumOfDimension(), + "wrong! 
inconsistent # of dimensions"); + + // transforms + const auto transforms = generate_tuple( + [&](auto idim) { + const auto old_length = tensor_view.GetTensorDescriptor().GetLength(idim); + + const auto tile_length = tile_lengths[idim]; + + const auto new_length = + math::integer_divide_ceil(old_length, tile_length) * tile_length; + + const auto pad_length = new_length - old_length; + + constexpr bool DoPad = DoPads::At(idim); + + const auto transform = + conditional_expr(make_right_pad_transform(old_length, pad_length), + make_pass_through_transform(old_length)); + + return transform; + }, + Number{}); + + // lower dimension Id + const auto lower_dimss = + generate_tuple([&](auto idim) { return Sequence{}; }, Number{}); + + // upper dimension Id + const auto upper_dimss = lower_dimss; + + return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss); +} + } // namespace ck diff --git a/include/ck/tile_program/block_tile/block_masking_specialization.hpp b/include/ck/tile_program/block_tile/block_masking_specialization.hpp new file mode 100644 index 000000000..e9f067733 --- /dev/null +++ b/include/ck/tile_program/block_tile/block_masking_specialization.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tile_program { +namespace block { + +struct MaskDisabledPredicate +{ + __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const + { + return false; + }; + + __host__ __device__ constexpr bool + IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const + { + return false; + } +}; + +struct MaskUpperTriangleFromTopLeftPredicate +{ + __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; } + + __host__ __device__ constexpr bool + IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const + { + return operator()(m + m_tile - 1, n); + } +}; + +// eg: m = 3, n = 5 => offset = 2 +// so matrix(n > m + offset) = 0 +// 1 2 3 4 5 +// 1 * * * 0 0 +// 2 * * * * 0 +// 3 * * * * * +struct MaskUpperTriangleFromBottomRightPredicate +{ + __host__ __device__ void SetDiagonalOffset(const index_t diagonal_offset) + { + diagonal_offset_ = diagonal_offset; + } + __host__ __device__ constexpr bool operator()(index_t m, index_t n) const + { + return n > (m - diagonal_offset_); + } + + __host__ __device__ constexpr bool IsTileSkippable(index_t m_tile_orig, + index_t n_tile_orig, + index_t m_tile_size, + index_t /*n_tile_size*/) const + { + return operator()(m_tile_orig + m_tile_size - 1, n_tile_orig); + } + + private: + index_t diagonal_offset_; +}; + +// to track the points which need to be set to -inf on C0 +// Note: no need to reset M padding value, because they will not be stored out. 
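//
// A minimal usage sketch of C0MatrixMask_impl defined below (placeholder variable
// names; the real call sites are the FMHA kernel and pipelines): the mask is built
// from the per-problem sequence lengths and queried per tile and per element, e.g.
//
//   C0MatrixMask_impl<MaskUpperTriangleFromBottomRightPredicate> mask{seqlen_q, seqlen_k};
//   // with seqlen_q = 3, seqlen_k = 5 the diagonal offset is 3 - 5 = -2, so an element
//   // (m, n) is masked when n > m + 2, on top of the n >= seqlen_k out-of-bound check
//   if(mask.IsTileSkippable(m_tile_origin, n_tile_origin, kM0, kN0))
//   {
//       // the whole kM0 x kN0 tile is masked out, skip it
//   }
//   if(mask.IsMaskedElement(m, n))
//   {
//       // set S(m, n) = -Infinity before the softmax
//   }
//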
+template +struct C0MatrixMask_impl +{ + using MaskOutPredicate = MaskOutPredicate_; + + __host__ __device__ C0MatrixMask_impl(index_t MRaw, index_t NRaw) + : NRaw_(NRaw), predicate_(MaskOutPredicate{}) + { + if constexpr(std::is_same_v) + { + predicate_.SetDiagonalOffset(MRaw - NRaw); + } + } + + __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const + { + return n >= NRaw_; + } + + __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const + { + return predicate_(m, n) || IsNOutOfBound(n); + } + + __host__ __device__ constexpr bool + IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const + { + return predicate_.IsTileSkippable(m, n, m_tile, n_tile); + } + + private: + // index_t MRaw_; + index_t NRaw_; + MaskOutPredicate predicate_; +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp index 7804ff259..e0c19af03 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp @@ -15,11 +15,15 @@ template + typename BlockFmhaShape_, + bool kIsGroupMode_, + typename BlockFmhaMask_, + typename Traits_> struct BlockFmhaPipelineProblem { using QDataType = remove_cvref_t; @@ -27,12 +31,22 @@ struct BlockFmhaPipelineProblem using VDataType = remove_cvref_t; using SaccDataType = remove_cvref_t; using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; + using BlockFmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; static constexpr index_t kBlockSize = kBlockSize_; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kM0NeedPadding = Traits::kM0NeedPadding; + static constexpr bool kN0K1NeedPadding = Traits::kN0K1NeedPadding; + static constexpr bool kHasBias = Traits::kHasBias; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; } // namespace block diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index e468f79b2..5dedf714b 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -16,9 +16,14 @@ #include "ck/tile_program/tile/slice_tile.hpp" #include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck/tile_program/block_tile/block_masking_specialization.hpp" #include "ck/tile_program/block_tile/block_reduce.hpp" #include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" +#ifndef C_LOG2E +#define C_LOG2E 1.44269504088896340736 // log2(e) +#endif + namespace ck { namespace tile_program { namespace block { @@ -32,15 +37,17 @@ struct BlockFmhaPipelineQRKSVS using VDataType = remove_cvref_t; using SaccDataType = remove_cvref_t; using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using BlockFmhaMask = remove_cvref_t; using 
BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once - static constexpr index_t kBlockPerCu = BlockFmhaShape::kBlockPerCu; + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kM0 = BlockFmhaShape::kM0; @@ -50,6 +57,11 @@ struct BlockFmhaPipelineQRKSVS static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; + static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; + static constexpr bool kHasBias = Problem::kHasBias; + __host__ __device__ static constexpr ck::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -58,9 +70,12 @@ struct BlockFmhaPipelineQRKSVS template + typename VElementFunction, + typename BiasElementFunction, + typename CausalMask> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -68,6 +83,9 @@ struct BlockFmhaPipelineQRKSVS const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + CausalMask causal_mask, float scale, index_t num_total_loop, index_t /*num_sub_loop_qk*/, // in this pipeline, the 1st gemm loop must be static @@ -83,7 +101,9 @@ struct BlockFmhaPipelineQRKSVS kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], "wrong!"); // K tile in LDS @@ -144,7 +164,14 @@ struct BlockFmhaPipelineQRKSVS v_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeVDramTileDistribution()); - auto q_tile = tile_elementwise_in(q_element_func, q); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.GetBottomTensorView(), + bias_dram_block_window_tmp.GetWindowLengths(), + bias_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeBiasDramTileDistribution()); + + const auto q_origin = q_dram_window.GetWindowOrigin(); + auto q_tile = tile_elementwise_in(q_element_func, q); // prefetch K tile index_t i_total_loops = 0; @@ -152,6 +179,13 @@ struct BlockFmhaPipelineQRKSVS constexpr index_t k1_loops = kN0 / kK1; do { + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + if(causal_mask.IsTileSkippable( + q_origin.At(Number<0>{}), k_origin.At(Number<0>{}), kM0, kN0)) + { + continue; + } + // STAGE 1, QK gemm auto k_dram_window = make_tile_window( k_dram_block_window.GetBottomTensorView(), @@ -186,8 +220,9 @@ struct BlockFmhaPipelineQRKSVS }); } - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile - { // tail + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail block_sync_lds(); gemm_0(s_acc, 
get_slice_tile(q_tile, @@ -206,10 +241,39 @@ struct BlockFmhaPipelineQRKSVS k_lds_window); } - // STAGE 2, scale softmax + // STAGE 2, scale, add bias, mask, softmax + if constexpr(is_null_tile_window(bias_dram_window)) + { #if !CK_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); #endif + } + else + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + + C_LOG2E * type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kN0K1NeedPadding || + !is_same_v) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + + return causal_mask.IsMaskedElement(row, col); + }); + } const auto s = cast_tile(s_acc); // S{j} auto m_local = block_tile_reduce( @@ -235,7 +299,14 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_FMHA_FWD_FAST_EXP2 - p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + if constexpr(is_null_tile_window(bias_dram_window)) + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } + else + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - m[i_idx]); + } #else p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); #endif @@ -251,8 +322,17 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - auto row_max = scale * m[i_idx]; - const auto tmp = math::exp2(scale * m_old[i_idx] - row_max); + const auto tmp = [&]() { + if constexpr(is_null_tile_window(bias_dram_window)) + { + auto row_max = scale * m[i_idx]; + return math::exp2(scale * m_old[i_idx] - row_max); + } + else + { + return math::exp2(m_old[i_idx] - m[i_idx]); + } + }(); #else const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); #endif @@ -343,27 +423,33 @@ struct BlockFmhaPipelineQRKSVS template + typename VDramBlockWindowTmp, + typename BiasDramBlockWindowTmp, + typename CausalMask> __host__ __device__ auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + CausalMask causal_mask, float scale, index_t num_total_loop, index_t num_sub_loop_qk, void* smem_ptr) const { - return operator()( - q_dram_block_window_tmp, - [](const QDataType& x) { return x; }, - k_dram_block_window_tmp, - [](const KDataType& x) { return x; }, - v_dram_block_window_tmp, - [](const VDataType& x) { return x; }, - scale, - num_total_loop, - num_sub_loop_qk, - smem_ptr); + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + causal_mask, + scale, + num_total_loop, + num_sub_loop_qk, + smem_ptr); } 
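    // Note on the CK_FMHA_FWD_FAST_EXP2 path above (a summary of the math, not an extra
    // code path): exp(x) == exp2(x * log2(e)). Without bias the host folds C_LOG2E into
    // 'scale' once (see FmhaFwdKernel::MakeKargs), so the softmax evaluates
    //     p_compute(i, j) = math::exp2(scale * s[i, j] - row_max), row_max = scale * m[i]
    // With a bias tile, the bias term is multiplied by C_LOG2E when it is added, so
    // s[i, j] already carries the log2(e) factor and the pipeline uses
    //     p_compute(i, j) = math::exp2(s[i, j] - m[i])
    // Both forms are mathematically equivalent to the math::exp() path that is compiled
    // when CK_FMHA_FWD_FAST_EXP2 is off.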
}; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index af9aab48a..88a7864bf 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -16,9 +16,14 @@ #include "ck/tile_program/tile/slice_tile.hpp" #include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck/tile_program/block_tile/block_masking_specialization.hpp" #include "ck/tile_program/block_tile/block_reduce.hpp" #include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" +#ifndef C_LOG2E +#define C_LOG2E 1.44269504088896340736 // log2(e) +#endif + namespace ck { namespace tile_program { namespace block { @@ -32,15 +37,17 @@ struct BlockFmhaPipelineQRKSVSAsync using VDataType = remove_cvref_t; using SaccDataType = remove_cvref_t; using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using BlockFmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; - static constexpr bool kQLoadOnce = true; // if q load whole block length (hdim) at once + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once - static constexpr index_t kBlockPerCu = BlockFmhaShape::kBlockPerCu; + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kM0 = BlockFmhaShape::kM0; @@ -50,6 +57,11 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; + static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; + static constexpr bool kHasBias = Problem::kHasBias; + __host__ __device__ static constexpr ck::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -58,9 +70,12 @@ struct BlockFmhaPipelineQRKSVSAsync template + typename VElementFunction, + typename BiasElementFunction, + typename CausalMask> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -68,6 +83,9 @@ struct BlockFmhaPipelineQRKSVSAsync const KElementFunction& /*k_element_func*/, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + CausalMask causal_mask, float scale, index_t num_total_loop, index_t /*num_sub_loop_qk*/, // in this pipeline, the 1st gemm loop must be static @@ -83,7 +101,9 @@ struct BlockFmhaPipelineQRKSVSAsync kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == 
BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], "wrong!"); constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); @@ -180,19 +200,35 @@ struct BlockFmhaPipelineQRKSVSAsync k_dram_block_window.GetWindowOrigin(), Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load + + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.GetBottomTensorView(), + bias_dram_block_window_tmp.GetWindowLengths(), + bias_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeBiasDramTileDistribution()); + // prefetch K tile async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.GetWindowOrigin(); buffer_load_fence(k_dram_window.GetNumAccess()); auto q_tile = tile_elementwise_in(q_element_func, q); __builtin_amdgcn_sched_barrier(0); + index_t i_total_loops = 0; constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k1_loops = kN0 / kK1; do { + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + if(causal_mask.IsTileSkippable( + q_origin.At(Number<0>{}), k_origin.At(Number<0>{}), kM0, kN0)) + { + continue; + } + // STAGE 1, QK gemm clear_tile(s_acc); // Initialize C if constexpr(k0_loops > 1) @@ -229,7 +265,8 @@ struct BlockFmhaPipelineQRKSVSAsync async_load_fence(); __builtin_amdgcn_s_barrier(); - auto v_buf = load_tile(v_dram_window); + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window); __builtin_amdgcn_sched_barrier(0); { // tail gemm_0(s_acc, @@ -248,10 +285,39 @@ struct BlockFmhaPipelineQRKSVSAsync } __builtin_amdgcn_sched_barrier(1); - // STAGE 2, scale softmax + // STAGE 2, scale, add bias, mask, softmax + if constexpr(is_null_tile_window(bias_dram_window)) + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + else + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { #if !CK_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + + C_LOG2E * type_convert(bias_element_func(y)); #endif + }, + s_acc, + bias_tile); + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kN0K1NeedPadding || + !is_same_v) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + + return causal_mask.IsMaskedElement(row, col); + }); + } const auto s = cast_tile(s_acc); // S{j} auto m_local = block_tile_reduce( @@ -309,7 +375,14 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_FMHA_FWD_FAST_EXP2 - p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + if constexpr(is_null_tile_window(bias_dram_window)) + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } + else + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - m[i_idx]); + } #else p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); #endif @@ -325,8 +398,17 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - auto row_max = scale * m[i_idx]; - const auto tmp = math::exp2(scale * m_old[i_idx] - 
row_max); + const auto tmp = [&]() { + if constexpr(is_null_tile_window(bias_dram_window)) + { + auto row_max = scale * m[i_idx]; + return math::exp2(scale * m_old[i_idx] - row_max); + } + else + { + return math::exp2(m_old[i_idx] - m[i_idx]); + } + }(); #else const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); #endif @@ -428,27 +510,33 @@ struct BlockFmhaPipelineQRKSVSAsync template + typename VDramBlockWindowTmp, + typename BiasDramBlockWindowTmp, + typename CausalMask> __host__ __device__ auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + CausalMask causal_mask, float scale, index_t num_total_loop, index_t num_sub_loop_qk, void* smem_ptr) const { - return operator()( - q_dram_block_window_tmp, - [](const QDataType& x) { return x; }, - k_dram_block_window_tmp, - [](const KDataType& x) { return x; }, - v_dram_block_window_tmp, - [](const VDataType& x) { return x; }, - scale, - num_total_loop, - num_sub_loop_qk, - smem_ptr); + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + causal_mask, + scale, + num_total_loop, + num_sub_loop_qk, + smem_ptr); } }; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp index 711280588..b00028f8e 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp @@ -596,6 +596,38 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy } } + template + __host__ __device__ static constexpr auto MakeBiasDramTileDistribution() + { + constexpr index_t MPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t NPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template At<1>(); + constexpr index_t NWarp = config.template At<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<1, 2>, + Sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + return c_block_dstr; + } + template __host__ __device__ static constexpr auto MakeShuffledVRegBlockDescriptor() { diff --git a/include/ck/tile_program/tile/load_tile.hpp b/include/ck/tile_program/tile/load_tile.hpp index 855800ae1..1aae8ad3a 100644 --- a/include/ck/tile_program/tile/load_tile.hpp +++ b/include/ck/tile_program/tile/load_tile.hpp @@ -11,6 +11,7 @@ #include 
"ck/tile_program/tile/tile_distribution.hpp" #include "ck/tile_program/tile/tile_window.hpp" +#include "ck/tile_program/tile/null_tensor.hpp" #include "ck/tile_program/tile/static_distributed_tensor.hpp" namespace ck { @@ -60,5 +61,11 @@ __device__ auto async_load_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +template +__device__ auto load_tile(const NullTileWindow&) +{ + return NullTensor{}; +} + } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/tile/null_tensor.hpp b/include/ck/tile_program/tile/null_tensor.hpp new file mode 100644 index 000000000..50f1efa17 --- /dev/null +++ b/include/ck/tile_program/tile/null_tensor.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tile_program { + +struct NullTensor +{ +}; + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/null_tile_window.hpp b/include/ck/tile_program/tile/null_tile_window.hpp new file mode 100644 index 000000000..d959d7154 --- /dev/null +++ b/include/ck/tile_program/tile/null_tile_window.hpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor/tensor_view.hpp" +#include "ck/utility/common_header.hpp" +#include + +namespace ck { +namespace tile_program { + +// placeholder type if we want to opt-out a tile window parameter +template +struct NullTileWindow +{ + using BottomTensorView = NullTensorView; + using WindowLengths = remove_cvref_t; + + using BottomTensorIndex = Array; + + __device__ constexpr NullTileWindow() = default; + + __device__ constexpr NullTileWindow(const WindowLengths& window_lengths) + : window_lengths_{window_lengths} + { + } + + __device__ constexpr auto GetWindowLengths() const { return window_lengths_; } + + __device__ constexpr auto GetBottomTensorView() const { return NullTensorView{}; } + + __device__ constexpr auto GetWindowOrigin() const { return BottomTensorIndex{}; } + + WindowLengths window_lengths_; +}; + +// utility to check if this is a Null Tile Window +namespace impl { +template +struct IsNullTileWindow : public std::false_type +{ +}; + +template +struct IsNullTileWindow> : public std::true_type +{ +}; +} // namespace impl + +template +__device__ constexpr auto is_null_tile_window(const T&) +{ + return impl::IsNullTileWindow>::value; +} + +template +__device__ constexpr auto make_null_tile_window(const WindowLengths& window_lengths) +{ + static_assert(is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return NullTileWindow>{window_lengths}; +} + +template +__device__ constexpr auto +make_tile_window(NullTensorView, const WindowLengths& window_lengths, Ts&&...) +{ + static_assert(is_known_at_compile_time::value, + "wrong! 
lengths should be static"); + + return NullTileWindow>{window_lengths}; +} + +template +__device__ void move_tile_window(NullTileWindow&, + const typename NullTileWindow::BottomTensorIndex&) +{ +} + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/static_distributed_tensor.hpp b/include/ck/tile_program/tile/static_distributed_tensor.hpp index 19f420aa4..3e675cb5f 100644 --- a/include/ck/tile_program/tile/static_distributed_tensor.hpp +++ b/include/ck/tile_program/tile/static_distributed_tensor.hpp @@ -176,5 +176,34 @@ __host__ __device__ constexpr auto make_static_distributed_tensor(const StaticTi remove_cvref_t>{}; } +template +__host__ __device__ void +set_tile_if(StaticDistributedTensor& out_tensor, + DataType value, + XIndicesPredicate predicate) +{ + + StaticTileDistribution tile_distribution; + const auto partition_index = detail::get_partition_index(tile_distribution); + + constexpr auto out_spans = + StaticDistributedTensor::GetDistributedSpans(); + sweep_tile_span(out_spans[Number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto y_idx = tile_distribution.GetYIndicesFromDistributedIndices(i_j_idx); + + const auto coord = make_tensor_adaptor_coordinate( + tile_distribution.GetPsYs2XsAdaptor(), + container_concat(partition_index, to_array(y_idx))); + + if(predicate(coord.GetBottomIndex())) + { + out_tensor(i_j_idx) = value; + } + }); + }); +} + } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/tile/tile_elementwise.hpp b/include/ck/tile_program/tile/tile_elementwise.hpp index a5d27ba7c..69357a382 100644 --- a/include/ck/tile_program/tile/tile_elementwise.hpp +++ b/include/ck/tile_program/tile/tile_elementwise.hpp @@ -9,13 +9,17 @@ #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/null_tensor.hpp" #include "ck/tile_program/tile/static_distributed_tensor.hpp" namespace ck { namespace tile_program { // TODO: support tensors with different distribution -template +template , NullTensor>>...>>> __device__ void tile_elementwise_inout(const InOutElementFunc& inout_element_func, InOutDstrTensors&... inout_dstr_tensors) { @@ -29,7 +33,10 @@ __device__ void tile_elementwise_inout(const InOutElementFunc& inout_element_fun [&](auto i) { inout_element_func(inout_dstr_tensors.GetThreadBuffer().At(i)...); }); } -template +template >...>>> __device__ auto tile_elementwise_in(const InElementFunc& in_element_func, const InDstrTensors&... in_dstr_tensors) { @@ -75,5 +82,24 @@ __device__ auto cast_tile(const SrcDstrTensors& src_tensor) src_tensor); } +// no-op function for NullTensor arguments +template , NullTensor>...>>> +__device__ void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...) +{ +} + +// no-op function for NullTensor arguments +template , NullTensor>...>>> +__device__ auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...) 
+{ + return NullTensor{}; +} + } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/tile/tile_fmha_shape.hpp b/include/ck/tile_program/tile/tile_fmha_shape.hpp index adbc96d1c..c5ab151e6 100644 --- a/include/ck/tile_program/tile/tile_fmha_shape.hpp +++ b/include/ck/tile_program/tile/tile_fmha_shape.hpp @@ -16,8 +16,7 @@ template + typename VLayout_ = ck::tensor_layout::gemm::RowMajor> struct TileFmhaShape { using BlockTile = remove_cvref_t; @@ -41,7 +40,6 @@ struct TileFmhaShape BlockTile::At(Number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) - static constexpr index_t kBlockPerCu = kBlockPerCu_; using VLayout = remove_cvref_t; // rowmajor : seqlen*hdim, colmajor : hdim*seqlen }; diff --git a/include/ck/tile_program/tile/tile_fmha_traits.hpp b/include/ck/tile_program/tile/tile_fmha_traits.hpp new file mode 100644 index 000000000..ab52929ba --- /dev/null +++ b/include/ck/tile_program/tile/tile_fmha_traits.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +namespace ck { +namespace tile_program { + +template +struct TileFmhaTraits +{ + static constexpr bool kM0NeedPadding = kM0NeedPadding_; + static constexpr bool kN0K1NeedPadding = kN0K1NeedPadding_; + static constexpr bool kHasBias = kHasBias_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/tile/tile_window.hpp b/include/ck/tile_program/tile/tile_window.hpp index a07b52800..d3df06d24 100644 --- a/include/ck/tile_program/tile/tile_window.hpp +++ b/include/ck/tile_program/tile/tile_window.hpp @@ -7,6 +7,7 @@ #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor_coordinate.hpp" +#include "ck/tile_program/tile/null_tile_window.hpp" #include "ck/tile_program/tile/tile_distribution.hpp" #include "ck/tile_program/tile/tile_window_impl_static_distribution.hpp" #include "ck/tile_program/tile/tile_window_impl_static_lengths.hpp" diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 58ebac9d8..31e3f5140 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" #include "ck/utility/static_assert.hpp" +#include "ck/utility/static_switch.hpp" #include "ck/utility/remove_cvref.hpp" #include "ck/utility/is_static.hpp" #include "ck/utility/bit_cast.hpp" diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index 91797d240..e1bab6f59 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -3,6 +3,8 @@ #pragma once +#include + #include "ck/utility/integral_constant.hpp" #include "ck/utility/type.hpp" @@ -128,4 +130,13 @@ constexpr auto conditional_expr(X&& x, Y&& y) } } +struct identity +{ + template + __host__ __device__ constexpr T&& operator()(T&& arg) const noexcept + { + return std::forward(arg); + } +}; + } // namespace ck diff --git a/include/ck/utility/static_switch.hpp b/include/ck/utility/static_switch.hpp new file mode 100644 index 000000000..293487f72 --- /dev/null +++ b/include/ck/utility/static_switch.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#define BOOL_SWITCH_1(COND1, CONST_NAME1, ...) 
\ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + __VA_ARGS__(); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + __VA_ARGS__(); \ + } \ + }() + +#define BOOL_SWITCH_2(COND1, CONST_NAME1, COND2, CONST_NAME2, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH_1(COND2, CONST_NAME2, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH_1(COND2, CONST_NAME2, ##__VA_ARGS__); \ + } \ + }() + +#define BOOL_SWITCH_3(COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__); \ + } \ + }() + +#define BOOL_SWITCH_4( \ + COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH_3( \ + COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH_3( \ + COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ##__VA_ARGS__); \ + } \ + }() diff --git a/library/src/utility/device_memory.cpp b/library/src/utility/device_memory.cpp index 61b6326b5..c95619ecd 100644 --- a/library/src/utility/device_memory.cpp +++ b/library/src/utility/device_memory.cpp @@ -1,23 +1,25 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include + #include "ck/host_utility/hip_check_error.hpp" #include "ck/library/utility/device_memory.hpp" DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) { - hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); } void DeviceMem::Realloc(std::size_t mem_size) { if(mpDeviceBuf) { - hip_check_error(hipFree(mpDeviceBuf)); + HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); } mMemSize = mem_size; - hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); } void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; } @@ -28,7 +30,7 @@ void DeviceMem::ToDevice(const void* p) const { if(mpDeviceBuf) { - hip_check_error( + HIP_CHECK_ERROR( hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); } else @@ -39,14 +41,14 @@ void DeviceMem::ToDevice(const void* p) const void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const { - hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); } void DeviceMem::FromDevice(void* p) const { if(mpDeviceBuf) { - hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } else { @@ -56,14 +58,14 @@ void DeviceMem::FromDevice(void* p) const void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const { - hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); } void DeviceMem::SetZero() const { if(mpDeviceBuf) { - hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); + 
HIP_CHECK_ERROR(hipMemset(mpDeviceBuf, 0, mMemSize)); } } @@ -71,6 +73,13 @@ DeviceMem::~DeviceMem() { if(mpDeviceBuf) { - hip_check_error(hipFree(mpDeviceBuf)); + try + { + HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); + } + catch(std::runtime_error& re) + { + std::cerr << re.what() << std::endl; + } } } From 5999fd9b124822cf86f7b31c6bc8b597613067b6 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 10 Dec 2023 01:41:20 -0600 Subject: [PATCH 10/45] merge main --- .github/dependabot.yml | 6 + .gitignore | 1 - .readthedocs.yaml | 10 +- CHANGELOG.md | 1 + CITATION.cff | 4 +- CMakeLists.txt | 3 +- Dockerfile | 4 +- Jenkinsfile | 64 ++- README.md | 2 +- ...rouped_conv_fwd_scaleadd_scaleadd_relu.inc | 16 +- .../25_tensor_transforms/CMakeLists.txt | 4 + .../tensor_transform.cpp | 0 .../tensor_transform_using_wrapper.cpp | 31 +- dev-requirements.txt | 4 +- docs/conf.py | 27 +- docs/doxygen/Doxyfile | 6 +- docs/index.rst | 2 + docs/sphinx/_toc.yml.in | 6 +- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- docs/wrapper.rst | 54 ++ example/62_conv_fwd_activ/CMakeLists.txt | 2 + ...aleadd_scaleadd_relu_bcasted_bias_fp16.cpp | 294 +++++++++++ .../run_convnd_fwd_activ_example.inc | 2 +- example/64_tensor_transforms/CMakeLists.txt | 2 - .../batched_gemm_softmax_gemm.cpp | 11 + example/91_tile_program/fmha_fwd.cpp | 11 + example/91_tile_program/gemm.cpp | 7 + example/91_tile_program/gemm_gemm.cpp | 10 + .../reference_batched_softmax.hpp | 15 +- include/ck/ck.hpp | 3 + include/ck/config.h.default | 109 ++++ include/ck/host_utility/device_prop.hpp | 2 +- ...vice_gemm_xdl_cshuffle_lds_direct_load.hpp | 4 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 32 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 3 +- .../gpu/device/tensor_layout.hpp | 6 - ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 41 +- .../gridwise_gemm_pipeline_v4_direct_load.hpp | 147 +++++- .../threadwise_tensor_slice_transfer_v3r1.hpp | 41 +- .../transform_conv_fwd_to_gemm.hpp | 15 +- .../block_gemm_areg_bsmem_creg_problem.hpp | 1 - ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 43 +- ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 41 +- ...k_fmha_pipeline_qr_ks_vs_custom_policy.hpp | 22 +- ...ffle_distributed_tensor_impl_in_thread.hpp | 9 - .../ck/tile_program/warp_tile/warp_gemm.hpp | 40 ++ .../warp_gemm_attribute_mfma_impl.hpp | 78 ++- .../warp_tile/warp_gemm_dispatcher.hpp | 11 + include/ck/utility/amd_buffer_addressing.hpp | 10 + include/ck/utility/transpose_vectors.hpp | 131 ++++- include/ck/utility/tuple_helper.hpp | 12 + .../ck/wrapper/layout.hpp | 181 ++----- include/ck/wrapper/layout_utils.hpp | 321 ++++++++++++ .../device_operation_instance_factory.hpp | 6 +- ...rouped_convolution_forward_scaleadd_ab.hpp | 43 +- ...olution_forward_scaleadd_scaleadd_relu.hpp | 12 +- ...c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp | 1 + ...c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp | 1 + ...ect_load_f16_f16_f16_mk_nk_mn_instance.cpp | 16 +- ...ect_load_f32_f32_f32_km_kn_mn_instance.cpp | 3 +- ...ect_load_f32_f32_f32_km_nk_mn_instance.cpp | 3 +- ...ect_load_f32_f32_f32_mk_kn_mn_instance.cpp | 3 +- ...ect_load_f32_f32_f32_mk_nk_mn_instance.cpp | 5 +- ..._ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 73 +-- ...elu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 8 +- ...relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 8 +- ...relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 8 +- ...elu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp | 8 +- .../conv2d_fwd/conv2d_quantization_common.hpp | 6 +- profiler/src/profile_transpose.cpp | 85 ---- test/CMakeLists.txt 
| 1 + test/wrapper/CMakeLists.txt | 2 + test/wrapper/test_layout.cpp | 481 ++++++++++++++++++ 74 files changed, 2175 insertions(+), 504 deletions(-) create mode 100644 client_example/25_tensor_transforms/CMakeLists.txt rename {example/64_tensor_transforms => client_example/25_tensor_transforms}/tensor_transform.cpp (100%) rename {example/64_tensor_transforms => client_example/25_tensor_transforms}/tensor_transform_using_wrapper.cpp (74%) create mode 100644 docs/wrapper.rst create mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp delete mode 100644 example/64_tensor_transforms/CMakeLists.txt create mode 100644 include/ck/config.h.default rename example/64_tensor_transforms/tensor_transform_wrapper.hpp => include/ck/wrapper/layout.hpp (68%) create mode 100644 include/ck/wrapper/layout_utils.hpp delete mode 100644 profiler/src/profile_transpose.cpp create mode 100644 test/wrapper/CMakeLists.txt create mode 100644 test/wrapper/test_layout.cpp diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 276690bd4..0e0a252eb 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -10,3 +10,9 @@ updates: open-pull-requests-limit: 10 schedule: interval: "daily" + labels: + - "documentation" + - "dependencies" + - "ci:docs-only" + reviewers: + - "samjwu" diff --git a/.gitignore b/.gitignore index 7af066c82..340f11cbd 100644 --- a/.gitignore +++ b/.gitignore @@ -54,5 +54,4 @@ _images/ _static/ _templates/ _toc.yml -docBin/ _doxygen/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 5f50df252..9e6678abe 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,11 +3,6 @@ version: 2 -build: - os: ubuntu-22.04 - tools: - python: "3.8" - sphinx: configuration: docs/conf.py @@ -16,3 +11,8 @@ formats: [htmlzip, pdf, epub] python: install: - requirements: docs/sphinx/requirements.txt + +build: + os: ubuntu-22.04 + tools: + python: "3.8" diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e46a4ab4..3da22fc79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ None - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) - Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) - Support for Batched Gemm DL (#732) +- Introduce wrapper sublibrary (limited functionality) (#1071) ### Changes - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) diff --git a/CITATION.cff b/CITATION.cff index d35fe9e58..3813d6381 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -59,9 +59,9 @@ authors: family-names: Zhou - given-names: Jianfeng family-names: Yan -repository-code: 'https://github.com/ROCmSoftwarePlatform/composable_kernel' +repository-code: 'https://github.com/ROCm/composable_kernel' abstract: Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for Machine Learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel progarmming languages, like HIP C++. 
keywords: - 'CK, Composable Kernel, Tensor Coordinate Transformation' license: MIT -license-url: https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE +license-url: https://github.com/ROCm/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE diff --git a/CMakeLists.txt b/CMakeLists.txt index 68bbf88c6..e780c1565 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -373,7 +373,8 @@ include_directories(BEFORE SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") if(BUILD_DEV) - add_compile_options(-Werror -Weverything) + add_compile_options(-Werror) + add_compile_options(-Weverything) endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") diff --git a/Dockerfile b/Dockerfile index 7134e206c..87b4eb8e2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -111,7 +111,7 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ +RUN if [ "$compiler_version" != "" ] && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ @@ -119,7 +119,7 @@ RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; else echo "using the release compiler"; \ fi -RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ +RUN if [ "$compiler_version" != "" ] && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ diff --git a/Jenkinsfile b/Jenkinsfile index 91499e7eb..8f661e478 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -84,7 +84,7 @@ def build_compiler(){ compiler = '/opt/rocm/bin/hipcc' } else{ - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ compiler = "/llvm-project/build/bin/clang++" } else{ @@ -293,7 +293,7 @@ def buildHipClangJob(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -302,7 +302,7 @@ def buildHipClangJob(Map conf=[:]){ def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: 
"Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 5, unit: 'HOURS') { @@ -348,14 +348,14 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -479,7 +479,7 @@ def Build_CK(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -487,7 +487,7 @@ def Build_CK(Map conf=[:]){ def retimage def navi_node = 0 - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -553,7 +553,7 @@ def Build_CK(Map conf=[:]){ sh """#!/bin/bash rm -rf "${params.hipTensor_branch}".zip rm -rf hipTensor-"${params.hipTensor_branch}" - wget https://github.com/ROCmSoftwarePlatform/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip + wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip unzip -o "${params.hipTensor_branch}".zip """ dir("hipTensor-${params.hipTensor_branch}"){ @@ -605,7 +605,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) } @@ -657,7 +657,8 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION= 0 21 * * * % ROCMVERSION=5.7;COMPILER_VERSION=;COMPILER_COMMIT= - 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" + 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false + 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" pipeline { agent none @@ -679,15 +680,15 @@ pipeline { string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, amd-stg-open, or leave blank (default).') + description: 'Specify which version of compiler to use: release, amd-stg-open, amd-mainline-open, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', - description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.') + description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit (default), or use some specific commit of llvm-project branch.') string( name: 'BUILD_COMPILER', - defaultValue: 'hipcc', - description: 'Specify whether to build CK with hipcc (default) or with clang.') + defaultValue: 'clang', + description: 'Specify whether to build CK with hipcc or with clang (default).') booleanParam( name: "RUN_FULL_QA", defaultValue: false, @@ -767,8 +768,15 @@ pipeline { } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ + -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \ + -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -783,8 +791,12 @@ pipeline { } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. 
&& make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx908;gfx90a" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -799,8 +811,12 @@ pipeline { } agent{ label rocmnode("navi21") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1030" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -815,8 +831,12 @@ pipeline { } agent{ label rocmnode("navi32") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1101" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') diff --git a/README.md b/README.md index e5a20f143..7679607e6 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa 3. 
Clone CK source code from the GitHub repository and start the build: ```bash - git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git && \ + git clone https://github.com/ROCm/composable_kernel.git && \ cd composable_kernel && \ mkdir build && \ cd build diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc b/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc index c72c72971..e8f552952 100644 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc +++ b/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc @@ -16,6 +16,7 @@ using InLayout = ck::tensor_layout::convolution::NDHWGC; using WeiLayout = ck::tensor_layout::convolution::GKZYXC; using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using BiasLayout = ck::tensor_layout::convolution::G_K; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; @@ -64,6 +65,9 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() std::array out_lengths{G, N, K, Do, Ho, Wo}; std::array out_strides{ K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; std::array filter_strides{1, 1, 1}; std::array filter_dilations{1, 1, 1}; @@ -74,13 +78,13 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); SimpleDeviceMem d0(sizeof(std::tuple_element_t<0, DDataTypes>) * N * Do * Ho * Wo * G * K); - SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * N * Do * Ho * Wo * G * K); + SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NumDimSpatial, InLayout, WeiLayout, - ck::Tuple, + ck::Tuple, OutLayout, InDataType, WeiDataType, @@ -117,8 +121,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() in_strides, wei_lengths, wei_strides, - {out_lengths, out_lengths}, - {out_strides, out_strides}, + {out_lengths, bias_lengths}, + {out_strides, bias_strides}, out_lengths, out_strides, filter_strides, @@ -187,8 +191,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() in_strides, wei_lengths, wei_strides, - {out_lengths, out_lengths}, - {out_strides, out_strides}, + {out_lengths, bias_lengths}, + {out_strides, bias_strides}, out_lengths, out_strides, filter_strides, diff --git a/client_example/25_tensor_transforms/CMakeLists.txt b/client_example/25_tensor_transforms/CMakeLists.txt new file mode 100644 index 000000000..d1543fb0e --- /dev/null +++ b/client_example/25_tensor_transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(client_tensor_transform tensor_transform.cpp) +target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations) +add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) +target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) diff --git a/example/64_tensor_transforms/tensor_transform.cpp b/client_example/25_tensor_transforms/tensor_transform.cpp similarity index 100% rename 
from example/64_tensor_transforms/tensor_transform.cpp rename to client_example/25_tensor_transforms/tensor_transform.cpp diff --git a/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp b/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp similarity index 74% rename from example/64_tensor_transforms/tensor_transform_using_wrapper.cpp rename to client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp index df2449e99..de9fcde0b 100644 --- a/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp +++ b/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp @@ -9,7 +9,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/sequence.hpp" -#include "tensor_transform_wrapper.hpp" +#include "ck/wrapper/layout.hpp" using DataType = int; @@ -17,7 +17,7 @@ template void Print1d(const Layout& layout) { std::cout << "Print1d" << std::endl; - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++) { std::cout << layout(ck::make_tuple(w)) << " "; } @@ -28,9 +28,9 @@ template void Print2d(const Layout& layout) { std::cout << "Print2d" << std::endl; - for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++) + for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) { - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) { std::cout << layout(ck::make_tuple(h, w)) << " "; } @@ -43,15 +43,11 @@ template void Print3dCustom(const Layout& layout) { std::cout << "Print3dCustom" << std::endl; - for(ck::index_t d = 0; - d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout)); - d++) + for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++) { - for(ck::index_t h = 0; - h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout)); - h++) + for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++) { - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) { std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " "; } @@ -68,7 +64,7 @@ int main() // Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor) // (dims:4,8 strides:1,4) const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}); - const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8); + const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8); std::cout << "dims:4,8 strides:1,4" << std::endl; Print2d(layout_4x8_s1x4); using Cord1x1Type = ck::Tuple, ck::Number<1>>; @@ -77,10 +73,9 @@ int main() // Basic descriptor 0, 1, 8, 9, 16, 17, ... 
30, 31 (runtime descriptor) // dims:4,(2,4) strides:2,(1,8) - const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); - const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); - const auto layout_4x2x4_s2x1x8 = - ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; Print2d(layout_4x2x4_s2x1x8); @@ -92,7 +87,7 @@ int main() const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::make_tuple(ck::Number<2>{}, ck::Number<8>{})); static const auto layout_2x2x2x4_s1x4x2x8 = - ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); + ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; Print2d(layout_2x2x2x4_s1x4x2x8); @@ -108,7 +103,7 @@ int main() ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}), ck::Number<8>{}); static const auto layout_2x2x2x4_s1x4x2x8_nested = - ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); + ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; Print1d(layout_2x2x2x4_s1x4x2x8_nested); diff --git a/dev-requirements.txt b/dev-requirements.txt index 9e7b9f01e..d5d91f8c2 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ -ROCmSoftwarePlatform/rocm-recipes +ROCm/rocm-recipes RadeonOpenCompute/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build -danmar/cppcheck@2.9 \ No newline at end of file +danmar/cppcheck@2.9 diff --git a/docs/conf.py b/docs/conf.py index 0de590da1..e441ff1ce 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,23 +4,34 @@ # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -import subprocess +import re from rocm_docs import ROCmDocs +html_theme_options = {"flavor": "list"} -name = "Composable Kernel" -get_version = r'sed -n -e "s/^rocm_setup_version(.* \([0-9\.]\{1,\}\).*/\1/p" ../CMakeLists.txt' -version = subprocess.getoutput(get_version) -if len(version) > 0: - name = f"{name} {version}" +with open('../CMakeLists.txt', encoding='utf-8') as f: + match = re.search(r'.*set\(version ([0-9.]+)[^0-9.]+', f.read()) + if not match: + raise ValueError("VERSION not found!") + version_number = match[1] +left_nav_title = f"Composable Kernel {version_number} Documentation" + +# for PDF output on Read the Docs +project = "Composable Kernel Documentation" +author = "Advanced Micro Devices, Inc." +copyright = "Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved." 
+version = version_number +release = version_number external_toc_path = "./sphinx/_toc.yml" -docs_core = ROCmDocs(f"{name} Documentation") -docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/docBin/xml") +docs_core = ROCmDocs(left_nav_title) +docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") docs_core.setup() +external_projects_current_project = "composable_kernel" + mathjax3_config = { 'tex': { 'macros': { diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index 1084f94c8..fac9e138e 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -58,7 +58,7 @@ PROJECT_LOGO = # entered, it will be relative to the location where doxygen was started. If # left blank the current directory will be used. -OUTPUT_DIRECTORY = docBin +OUTPUT_DIRECTORY = . # If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- # directories (in 2 levels) under the output directory of each output format and @@ -778,7 +778,9 @@ WARN_LOGFILE = INPUT = ../../include/ck/tensor_operation/gpu/grid \ ../../include/ck/tensor_operation/gpu/block \ ../../include/ck/tensor_operation/gpu/thread \ - ../../library/include/ck/library/utility + ../../library/include/ck/library/utility \ + ../../include/ck/wrapper + # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/docs/index.rst b/docs/index.rst index 51c0c862a..8c4aaa2b3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -34,6 +34,7 @@ Current CK library are structured into 4 layers: * "Templated Tile Operators" layer * "Templated Kernel and Invoker" layer * "Instantiated Kernel and Invoker" layer +* "Wrapper for tensor transform operations" * "Client API" layer .. image:: data/ck_layer.png @@ -50,6 +51,7 @@ The following is a list of CK documents in the suggested reading order: tutorial_hello_world dockerhub + wrapper Supported_Primitives_Guide API_Reference_Guide Contributors_Guide diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 83dd1e7b1..c37ba29ce 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -5,6 +5,6 @@ defaults: maxdepth: 6 root: index subtrees: - - caption: About - entries: - - file: license +- caption: About + entries: + - file: license diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index c4ce8be79..0a65ffc81 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core>=0.20.0 +rocm-docs-core==0.30.1 sphinxcontrib-bibtex==2.6.1 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 585231595..01cb32e71 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.28.2 # via # pygithub # sphinx -rocm-docs-core==0.27.0 +rocm-docs-core==0.30.1 # via -r requirements.in six==1.16.0 # via diff --git a/docs/wrapper.rst b/docs/wrapper.rst new file mode 100644 index 000000000..64fb6a403 --- /dev/null +++ b/docs/wrapper.rst @@ -0,0 +1,54 @@ +=============== +Wrapper +=============== + +------------------------------------- +Description +------------------------------------- + +.. note:: + + The wrapper is under development and its functionality is limited. + + +CK provides a lightweight wrapper for more complex operations implemented in +the library. It allows indexing of nested layouts using a simple interface +(avoiding complex descriptor transformations). + +Example: + +.. 
code-block:: c + + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; + for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(h, w)) << " "; + } + std::cout << std::endl; + } + +Output:: + + dims:4,(2,4) strides:2,(1,8) + 0 1 8 9 16 17 24 25 + 2 3 10 11 18 19 26 27 + 4 5 12 13 20 21 28 29 + 6 7 14 15 22 23 30 31 + +------------------------------------- +Layout +------------------------------------- + +.. doxygenstruct:: ck::wrapper::Layout + +------------------------------------- +Layout helpers +------------------------------------- + +.. doxygenfile:: layout_utils.hpp diff --git a/example/62_conv_fwd_activ/CMakeLists.txt b/example/62_conv_fwd_activ/CMakeLists.txt index bb9560241..d1f26bbfe 100644 --- a/example/62_conv_fwd_activ/CMakeLists.txt +++ b/example/62_conv_fwd_activ/CMakeLists.txt @@ -42,6 +42,8 @@ foreach(gpu IN LISTS GPU_TARGETS) # ScaleAdd ScaleAdd Relu add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) + add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) set(target 1) endif() endforeach() diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp new file mode 100644 index 000000000..196636f8b --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
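
This new example (like the client example earlier in the patch) describes the per-channel bias as a GNKDHW tensor with lengths {G, 1, K, 1, 1, 1} and strides {K, 0, 1, 0, 0, 0}; the zero strides are what broadcast a G*K bias buffer across N and the spatial dimensions. A small standalone sketch of how zero strides collapse the offset computation (the offset() helper is hypothetical, for illustration only):

#include <array>
#include <cassert>
#include <cstdint>

// hypothetical helper: dot product of index and stride vectors
int64_t offset(const std::array<int64_t, 6>& idx, const std::array<int64_t, 6>& strides)
{
    int64_t off = 0;
    for(std::size_t i = 0; i < idx.size(); ++i)
        off += idx[i] * strides[i];
    return off;
}

int main()
{
    const int64_t K = 64;
    // bias in GNKDHW order, lengths {G,1,K,1,1,1}, strides {K,0,1,0,0,0}
    const std::array<int64_t, 6> bias_strides{K, 0, 1, 0, 0, 0};

    // two different output positions with the same (g, k) read the same
    // bias element, because N/D/H/W carry stride 0
    assert(offset({1, 0, 5, 0, 0, 0}, bias_strides) ==
           offset({1, 7, 5, 3, 2, 1}, bias_strides));
    return 0;
}
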
+ +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using BiasLayout = ck::tensor_layout::convolution::G_K; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const 
OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 2; + const ck::index_t G = out_g_n_k_wos_desc.GetLengths()[0]; + const ck::index_t K = out_g_n_k_wos_desc.GetLengths()[2]; + + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_g_k_lengths; + std::array bias_g_k_strides; + // Fill other lenghts than G,K with 1 and strides with 0 + bias_g_k_lengths.fill(1); + bias_g_k_strides.fill(0); + bias_g_k_lengths[0] = G; + bias_g_k_lengths[2] = K; + bias_g_k_strides[0] = K; // stride to G + bias_g_k_strides[2] = 1; // stride to K + const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides); + + // y = relu ( alpha1 * conv(x) + alpha2 * z + bias ) + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + std::array, NumDs> d_tensors = {Tensor(out_g_n_k_wos_desc), + Tensor(broadcasted_bias_desc)}; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + std::cout << "z_tensor: " << d_tensors[0].mDesc << std::endl; + std::cout << "bias_tensor: " << d_tensors[1].mDesc << std::endl; + + // Make sure that we allocated only G * K values for bias + assert(static_cast(d_tensors[1].mData.size()) == G * K); + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem z_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize()); + DeviceMem bias_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + z_buf.ToDevice(d_tensors[0].mData.data()); + bias_buf.ToDevice(d_tensors[1].mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + 
copy(conv_param.input_right_pads_, input_right_pads); + + const std::array ds = {z_buf.GetDeviceBuffer(), bias_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{ + e_g_n_k_wos_lengths, bias_g_k_lengths}, + std::array, NumDs>{ + e_g_n_k_wos_strides, bias_g_k_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops() + G * K + + conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + G * K * sizeof(OutDataType) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc b/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc index 7c20c0106..aa547c870 100644 --- a/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc +++ b/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc @@ -24,7 +24,7 @@ bool run_convnd_fwd_example(int argc, char* argv[]) // Following shapes are selected to avoid overflow. Expect inf in case of // size increase for some elementwise ops. 
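    // A hedged note on the brace-initializer below (field order assumed from ConvParam's
    // constructor): {num spatial dims, group count G, N, K, C, filter lengths, input lengths,
    // conv strides, dilations, left pads, right pads}.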
ck::utils::conv::ConvParam conv_param{ - 3, 1, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + 3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; if(argc == 1) { diff --git a/example/64_tensor_transforms/CMakeLists.txt b/example/64_tensor_transforms/CMakeLists.txt deleted file mode 100644 index 9d14a410e..000000000 --- a/example/64_tensor_transforms/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_example_executable(example_tensor_transform tensor_transform.cpp) -add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.cpp b/example/91_tile_program/batched_gemm_softmax_gemm.cpp index 96221679c..f785ffcf9 100644 --- a/example/91_tile_program/batched_gemm_softmax_gemm.cpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm.cpp @@ -19,6 +19,7 @@ int main(int argc, char* argv[]) { +#if 1 using QDataType = ck::half_t; using KDataType = ck::half_t; using VDataType = ck::half_t; @@ -27,6 +28,16 @@ int main(int argc, char* argv[]) using PDataType = ck::half_t; using OaccDataType = float; using ODataType = ck::half_t; +#else + using QDataType = ck::bhalf_t; + using KDataType = ck::bhalf_t; + using VDataType = ck::bhalf_t; + using SaccDataType = float; + using SMPLComputeDataType = float; + using PDataType = ck::bhalf_t; + using OaccDataType = float; + using ODataType = ck::bhalf_t; +#endif ck::index_t Batch = 16; ck::index_t M0 = 3328; diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index 1678a9905..4c1af6e79 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -41,6 +41,7 @@ #include "fmha_utils.hpp" #include "arg_parser.hpp" +#if 1 using QDataType = ck::half_t; using KDataType = ck::half_t; using VDataType = ck::half_t; @@ -50,6 +51,16 @@ using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck::half_t; // data type for A matrix of second gemm using OaccDataType = float; // data type for second gemm accumulation using ODataType = ck::half_t; +#else +using QDataType = ck::bhalf_t; +using KDataType = ck::bhalf_t; +using VDataType = ck::bhalf_t; +using SaccDataType = float; // data type for first gemm accumulation +using SMPLComputeDataType = float; // data type for reduction, softmax +using PDataType = ck::bhalf_t; // data type for A matrix of second gemm +using OaccDataType = float; // data type for second gemm accumulation +using ODataType = ck::bhalf_t; +#endif // M0 N0 K0 N1 K1 K0L // using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>; diff --git a/example/91_tile_program/gemm.cpp b/example/91_tile_program/gemm.cpp index b6809f214..67e8479ea 100644 --- a/example/91_tile_program/gemm.cpp +++ b/example/91_tile_program/gemm.cpp @@ -46,10 +46,17 @@ struct CElementFunction int main(int argc, char* argv[]) { +#if 1 using ADataType = ck::half_t; using BDataType = ck::half_t; using AccDataType = float; using CDataType = ck::half_t; +#else + using ADataType = ck::bhalf_t; + using BDataType = ck::bhalf_t; + using AccDataType = float; + using CDataType = ck::bhalf_t; +#endif using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; diff --git a/example/91_tile_program/gemm_gemm.cpp b/example/91_tile_program/gemm_gemm.cpp index 0e14576ac..ccbea2369 100644 --- a/example/91_tile_program/gemm_gemm.cpp +++ b/example/91_tile_program/gemm_gemm.cpp @@ 
-18,6 +18,7 @@ int main(int argc, char* argv[]) { +#if 1 using A0DataType = ck::half_t; using B0DataType = ck::half_t; using B1DataType = ck::half_t; @@ -25,6 +26,15 @@ int main(int argc, char* argv[]) using C0DataType = ck::half_t; using Acc1DataType = float; using C1DataType = ck::half_t; +#else + using A0DataType = ck::bhalf_t; + using B0DataType = ck::bhalf_t; + using B1DataType = ck::bhalf_t; + using Acc0DataType = float; + using C0DataType = ck::bhalf_t; + using Acc1DataType = float; + using C1DataType = ck::bhalf_t; +#endif ck::index_t M0 = 13312; ck::index_t N0 = 4096; diff --git a/example/91_tile_program/reference_batched_softmax.hpp b/example/91_tile_program/reference_batched_softmax.hpp index 344acfee7..a9fa3f103 100644 --- a/example/91_tile_program/reference_batched_softmax.hpp +++ b/example/91_tile_program/reference_batched_softmax.hpp @@ -6,28 +6,28 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/host_tensor.hpp" -template +template void reference_batched_softmax(const Tensor& a_b_m_n, Tensor& b_b_m_n) { const int N = a_b_m_n.mDesc.GetLengths()[2]; auto f = [&](auto batch, auto m) { - AccDataType v_max = ck::NumericLimits::Lowest(); + CompDataType v_max = ck::NumericLimits::Lowest(); // max for(int n = 0; n < N; ++n) { - const ADataType v_a = a_b_m_n(batch, m, n); + const CompDataType v_a = ck::type_convert(a_b_m_n(batch, m, n)); v_max = v_max < v_a ? v_a : v_max; } - AccDataType v_exp_sum = 0; + CompDataType v_exp_sum = 0; // sum for(int n = 0; n < N; ++n) { - const ADataType v_a = a_b_m_n(batch, m, n); + const CompDataType v_a = ck::type_convert(a_b_m_n(batch, m, n)); v_exp_sum += ck::math::exp(v_a - v_max); } @@ -35,9 +35,10 @@ void reference_batched_softmax(const Tensor& a_b_m_n, Tensor(a_b_m_n(batch, m, n)); - b_b_m_n(batch, m, n) = ck::math::exp(v_a - v_max) / v_exp_sum; + b_b_m_n(batch, m, n) = + ck::type_convert(ck::math::exp(v_a - v_max) / v_exp_sum); } }; diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index f2811dfd5..c7e279923 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -134,6 +134,9 @@ // inner product using V_DOT with DPP8 modifiers #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 +// LDS direct loads using inline assembly +#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1 + // set stochastic rounding as default for f8 conversions #define CK_USE_SR_F8_CONVERSION 1 diff --git a/include/ck/config.h.default b/include/ck/config.h.default new file mode 100644 index 000000000..dbf4a9597 --- /dev/null +++ b/include/ck/config.h.default @@ -0,0 +1,109 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_CONFIG_H_IN +#define CK_CONFIG_H_IN + +// clang-format off +// +// DataType supports in the current CK build +// +#ifndef DTYPES +/* #undef DTYPES */ +#endif +// if DTYPES is not defined, enable all datatypes in headerfiles +#ifndef CK_ENABLE_ALL_DTYPES +#define CK_ENABLE_ALL_DTYPES ON +#if defined(CK_ENABLE_ALL_DTYPES) +#ifndef CK_ENABLE_INT8 +#define CK_ENABLE_INT8 "ON" +#endif +#ifndef CK_ENABLE_FP8 +#define CK_ENABLE_FP8 "ON" +#endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif +#ifndef CK_ENABLE_FP16 +#define CK_ENABLE_FP16 "ON" +#endif +#ifndef CK_ENABLE_BF16 +#define CK_ENABLE_BF16 "ON" +#endif +#ifndef CK_ENABLE_FP32 +#define CK_ENABLE_FP32 "ON" +#endif +#ifndef CK_ENABLE_FP64 +#define CK_ENABLE_FP64 "ON" +#endif +#endif +#endif +// if DTYPES are selectively enabled +#ifndef CK_ENABLE_INT8 +/* #undef CK_ENABLE_INT8 */ +#endif + +#ifndef CK_ENABLE_FP8 +/* #undef CK_ENABLE_FP8 */ +#endif + +#ifndef CK_ENABLE_BF8 +/* #undef CK_ENABLE_BF8 */ +#endif + +#ifndef CK_ENABLE_FP16 +/* #undef CK_ENABLE_FP16 */ +#endif + +#ifndef CK_ENABLE_BF16 +/* #undef CK_ENABLE_BF16 */ +#endif + +#ifndef CK_ENABLE_FP32 +/* #undef CK_ENABLE_FP32 */ +#endif + +#ifndef CK_ENABLE_FP64 +/* #undef CK_ENABLE_FP64 */ +#endif + +// +// Legacy DL kernel supports in the current CK build +// by default DL kernels are turned OFF +// +#ifndef CK_ENABLE_DL_KERNELS +/* #undef CK_ENABLE_DL_KERNELS */ +#endif + +// +// Instances supports in the current CK build +// +#ifndef CK_ENABLE_INSTANCES_ONLY +/* #undef CK_ENABLE_INSTANCES_ONLY */ +#endif + +// clang-format on + +#endif // CK_CONFIG_H_IN diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index be2c2395f..e8dabc997 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -26,7 +26,7 @@ inline std::string get_device_name() } const std::string raw_name(props.gcnArchName); - // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 static std::map device_name_map = { {"Ellesmere", "gfx803"}, {"Baffin", "gfx803"}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp index f8264cefd..ac2e82672 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp @@ -380,7 +380,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_lengths, const std::array, NumDTensor>& ds_g_n_k_wos_strides) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], + return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); }, Number{}); @@ -569,7 +571,7 @@ struct 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // D desc ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); }); compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; @@ -916,8 +918,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v) { const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; @@ -925,6 +926,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { valid = false; } + + if constexpr(is_same_v) + { + // G and K must be the same + if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] || + arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2]) + { + valid = false; + } + } + else + { + // E and D must have the same shape + for(index_t d = 0; d < NDimSpatial + 3; d++) + { + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + valid = false; + } + } + } } else { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 80a5d0e97..0050a5b28 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -631,8 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v) { const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index b2d141fd6..ecc71ba2f 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -308,12 +308,6 @@ struct GNDHWK : public BaseTensorLayout static constexpr const char* name = "GNDHWK"; }; -// for output bias -struct GK : public BaseTensorLayout -{ - static constexpr const char* name = "GK"; -}; - // output tensor // packed NWGK/NHWGK/NDHWGK struct NWGK : public BaseTensorLayout diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 472496e22..b6a17e53a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -236,9 +236,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + - b_block_space_size_aligned * sizeof(BComputeDataType), - c_block_size * sizeof(CShuffleDataType)); + return math::max( + NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) + + NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType), + c_block_size * sizeof(CShuffleDataType)); } __host__ __device__ static auto @@ -491,6 +492,22 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad __device__ 
__host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + template + __device__ static auto AllocateBlockBuffers(void* p_shared, + int32_t num_elems, + int32_t offset_elems, + int32_t max_lds_align) + { + const int32_t single_buffer_offset = math::integer_least_multiple(num_elems, max_lds_align); + return generate_tuple( + [&](auto i) { + const int32_t local_offset = i * single_buffer_offset; + return make_dynamic_buffer( + static_cast(p_shared) + local_offset + offset_elems, num_elems); + }, + Number{}); + } + template ( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - - auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto a_block_buffers = AllocateBlockBuffers( + p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align); + const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage; + auto b_block_buffers = + AllocateBlockBuffers(p_shared, + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); @@ -645,13 +664,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, - a_block_buf, + a_block_buffers, a_block_slice_copy_step, b_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf, - b_block_buf, + b_block_buffers, b_block_slice_copy_step, blockwise_gemm, c_thread_buf, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp index 1c59f37a9..08d986d0d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp @@ -7,6 +7,20 @@ #include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +namespace lds_direct_load { + +__device__ void sched_barrier() +{ +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + // When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of + // `sched_barrier` is necessary to make sure no instructions that use the loaded memory + // are scheduled by the compiler before the `waitcnt` instruction. 
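+    // A mask of 0 is the most conservative setting: it asks the scheduler not to move any
+    // instruction across this point (per the AMDGPU sched_barrier intrinsic's mask semantics).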
+ __builtin_amdgcn_sched_barrier(0); +#endif +} + +} // namespace lds_direct_load + namespace ck { template @@ -17,7 +31,6 @@ template <> struct GridwiseGemmPipeline_v4<1> { static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } @@ -31,13 +44,13 @@ struct GridwiseGemmPipeline_v4<1> typename ABlockDesc, typename ABlockTransfer, typename AGridBuffer, - typename ABlockBuffer, + typename ABlockBuffers, typename ABlockTransferStep, typename BGridDesc, typename BBlockDesc, typename BBlockTransfer, typename BGridBuffer, - typename BBlockBuffer, + typename BBlockBuffers, typename BBlockTransferStep, typename BlockwiseGemm, typename CThreadBuffer> @@ -45,18 +58,22 @@ struct GridwiseGemmPipeline_v4<1> const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, + ABlockBuffers& a_block_bufs, const ABlockTransferStep& a_block_copy_step, const BGridDesc& b_grid_desc, const BBlockDesc& b_block_desc, BBlockTransfer& b_blockwise_copy, const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, + BBlockBuffers& b_block_bufs, const BBlockTransferStep& b_block_copy_step, const BlockwiseGemm& blockwise_gemm, CThreadBuffer& c_thread_buf, index_t num_loop) { + static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1); + auto& a_block_buf = a_block_bufs.At(I0); + auto& b_block_buf = b_block_bufs.At(I0); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); @@ -74,10 +91,12 @@ struct GridwiseGemmPipeline_v4<1> do { block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); @@ -92,10 +111,128 @@ struct GridwiseGemmPipeline_v4<1> // tail { block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } } }; +// 2-stages prefetch +template <> +struct GridwiseGemmPipeline_v4<2> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffers& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffers& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2); + auto& a_block_buf1 = a_block_bufs.At(I0); + auto& a_block_buf2 = a_block_bufs.At(I1); + auto& b_block_buf1 = b_block_bufs.At(I0); + auto& b_block_buf2 = b_block_bufs.At(I1); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1); + 
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf); + } + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index e670bb2d2..c8ce8d2bc 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -233,8 +233,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using src_vector_t = vector_type_maker_t; using dst_vector_t = vector_type_maker_t; - // get DstScalarPerVector # of read-only references to src vectors from - // src_thread_scratch_ +#if 0 + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ const auto src_vector_refs = generate_tie( [&](auto i) -> const src_vector_t& { // i increment corresponds to movement in DstVectorDim @@ -255,15 +256,37 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // do data transpose transpose_vectors{}( src_vector_refs, dst_vector_refs); +#else + StaticallyIndexedArray src_vectors; + StaticallyIndexedArray dst_vectors; + + // read DstScalarPerVector # of src vectors from src_thread_scratch_ + static_for<0, num_src_vector, 1>{}([&](auto i) { + // i increment corresponds to movement in DstVectorDim + src_vectors(i) = + src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }); + + // do data transpose + transpose_vectors{}(src_vectors, + dst_vectors); + + // write SrcScalarPerVector # dst vectors into dst_thread_scratch_ + static_for<0, num_dst_vector, 1>{}([&](auto i) { + // i increment corresponds to movement in SrcVectorDim + 
dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector) = dst_vectors[i]; + }); +#endif + }); + } + else + { + static_ford{}([&](auto idx) { + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); } - - static_ford{}([&](auto idx) { - // apply the src elementwise op and convert to DstData under the hood if needed - DstData dst_v; - src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]); - dst_thread_scratch_(idx) = dst_v; - }); } template diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 6f546f1d6..e2f75142d 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -522,22 +522,21 @@ struct TransformConvFwdToGemm // for output bias template || - is_same_v, + typename std::enable_if, bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */) + static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + const index_t KStride = c_g_n_k_wos_strides[2]; const index_t NHoWo = N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1)); + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); return out_gemmm_gemmn_desc; } diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp index 217928dfa..2b01568ed 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp @@ -26,7 +26,6 @@ struct BlockGemmARegBSmemCRegProblem static constexpr index_t kBlockSize = kBlockSize_; }; - } // namespace block } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 2a67ec2cb..8933eb2bf 100644 --- a/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -26,30 +26,41 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy { using namespace ck::tile_program::warp; + if constexpr(is_same_v && + is_same_v && + is_same_v) + { #if 0 - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); - 
constexpr index_t NumWarp = kBlockSize / get_warp_size(); + constexpr index_t NumWarp = kBlockSize / get_warp_size(); - // FIXME - if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && - kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) - { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif } - else + else if constexpr(is_same_v && + is_same_v && + is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); } -#else - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); -#endif } }; diff --git a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index 2666c42bc..a028ef68c 100644 --- a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -26,29 +26,40 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy { using namespace ck::tile_program::warp; + if constexpr(is_same_v && + is_same_v && + is_same_v) + { #if 0 - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); - constexpr index_t NumWarp = kBlockSize / get_warp_size(); + constexpr index_t NumWarp = kBlockSize / get_warp_size(); - if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && - kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) - { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); +#endif } - else + else if constexpr(is_same_v && + is_same_v && + is_same_v) { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); } -#else - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); -#endif } }; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp index b00028f8e..0fb576590 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp @@ -686,18 
+686,28 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy TileGemmShape>; - - using WarpGemm = warp::WarpGemmImpl< - warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< - warp::WarpGemmAttributeMfmaImplF16F16F32M32N32K8, - 2>>; + + constexpr auto warp_gemm = []() { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + }(); using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy; + decltype(warp_gemm)>; return BlockGemmARegBSmemCRegV2{}; } diff --git a/include/ck/tile_program/tile/shuffle_distributed_tensor_impl_in_thread.hpp b/include/ck/tile_program/tile/shuffle_distributed_tensor_impl_in_thread.hpp index a8758460d..b37a0ec54 100644 --- a/include/ck/tile_program/tile/shuffle_distributed_tensor_impl_in_thread.hpp +++ b/include/ck/tile_program/tile/shuffle_distributed_tensor_impl_in_thread.hpp @@ -99,15 +99,6 @@ __device__ void shuffle_distributed_tensor_impl_in_thread(OutTensor& out_tensor, StaticallyIndexedArray in_vectors; StaticallyIndexedArray out_vectors; -#if 0 - print(y_dim_out_to_in); - printf("\n"); - printf("y_dim_vec_in %d\n", y_dim_vec_in); - printf("y_dim_vec_out %d\n", y_dim_vec_out); - printf("num_vec_in %d\n", num_vec_in); - printf("num_vec_out %d\n", num_vec_out); -#endif - // loop over SFC and do transpose static_for<0, num_access, 1>{}([&](auto iAccess) { // data index [y0, y1, ...] in the order of input tensor diff --git a/include/ck/tile_program/warp_tile/warp_gemm.hpp b/include/ck/tile_program/warp_tile/warp_gemm.hpp index 50a80a07c..f08e24631 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm.hpp @@ -13,6 +13,7 @@ namespace ck { namespace tile_program { namespace warp { +// fp16 using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl>; @@ -41,6 +42,45 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; +using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = + WarpGemmImpl>; + +// bf16 +using WarpGemmMfmaBf16Bf16F32M32N32K8 = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M16N16K16 = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M32N32K16 = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M16N16K32 = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>; + +using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = + WarpGemmImpl>; + } // namespace warp } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp index 68c721bfc..72431c802 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp @@ -9,6 +9,7 @@ namespace ck { namespace tile_program { namespace warp { +// FP16 struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 { using ADataType = 
half_t; @@ -42,7 +43,6 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 // c_vec = a_vec * b_vec __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { - // FIXME: Is this correct? return __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0); } }; @@ -80,11 +80,85 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 // c_vec = a_vec * b_vec __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { - // FIXME: Is this correct? return __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, CVecType{0.f}, 0, 0, 0); } }; +// Bf16 +struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 +{ + using ADataType = bhalf_t; + using BDataType = bhalf_t; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 8; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + } + + // c_vec = a_vec * b_vec + __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, CVecType{0.f}, 0, 0, 0); + } +}; + +struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 +{ + using ADataType = bhalf_t; + using BDataType = bhalf_t; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 16; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 4; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + } + + // c_vec = a_vec * b_vec + __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, CVecType{0.f}, 0, 0, 0); + } +}; + } // namespace warp } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp b/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp index 42751ce13..68f2255b5 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp @@ -21,6 +21,7 @@ template struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; template<> struct 
WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; @@ -29,6 +30,16 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; + +// bf16 +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; // clang-format on } // namespace impl diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 3e2c01455..c184e9729 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1456,6 +1456,15 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + T* lds_ptr = lds_base_ptr + lds_offset; + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), + "s"(src_resource)); +#else // LDS pointer must be attributed with the LDS address space. 
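    // (address_space(3) is the LLVM/AMDGPU address-space number for LDS / workgroup-local
    // memory.)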
__attribute__((address_space(3))) uint32_t* lds_ptr = reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( @@ -1463,6 +1472,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); +#endif } } // namespace ck diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp index 503832ab4..980c99408 100644 --- a/include/ck/utility/transpose_vectors.hpp +++ b/include/ck/utility/transpose_vectors.hpp @@ -9,6 +9,7 @@ namespace ck { +#if 0 // debug // S: scalar type // NX: # of vector before transpose // NY: # of vector after transpose @@ -21,22 +22,6 @@ struct transpose_vectors; // transpose fp16 2x2 __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1) { -#if 0 - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - const vector_type vx0{x0}, vx1{x1}; - vector_type vy0, vy1; - - vy0.template AsType()(I0) = vx0.template AsType()[I0]; - vy0.template AsType()(I1) = vx1.template AsType()[I0]; - - vy1.template AsType()(I0) = vx0.template AsType()[I1]; - vy1.template AsType()(I1) = vx1.template AsType()[I1]; - - y0 = vy0.template AsType()[I0]; - y1 = vy1.template AsType()[I0]; -#else constexpr int32_t m0 = 0x05040100; constexpr int32_t m1 = 0x07060302; @@ -46,7 +31,6 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t // index is reversed because of little endianness (least significant bits first) y0 = bit_cast(__builtin_amdgcn_perm(bit_cast(x1), bit_cast(x0), m0)); y1 = bit_cast(__builtin_amdgcn_perm(bit_cast(x1), bit_cast(x0), m1)); -#endif } template @@ -148,6 +132,7 @@ __device__ void transpose_int8_4x4(const int8x4_t& x0, y3 = bit_cast(z3); } + template struct transpose_vectors { @@ -190,5 +175,117 @@ struct transpose_vectors }); } }; +#else + +// S: scalar type +// NX: # of vector before transpose +// NY: # of vector after transpose +// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data +template ::value, bool>::type = false> +struct transpose_vectors +{ + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = remove_cvref_t; + + using VX = vector_type; + using VY = vector_type; + + __device__ void operator()(const StaticallyIndexedArray& vx_tuple, + StaticallyIndexedArray& vy_tuple) + { + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + if constexpr(is_same_v || is_same_v) + { + static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!"); + + using S2 = typename vector_type::type; + + // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 2>{}([&](auto iy) { + static_for<0, NX, 2>{}([&](auto ix) { + // 2 16bitx2 data from vx_tuple to be transposed + const int32_t x_s2_0 = + bit_cast(vx_tuple[ix].template AsType()[iy / I2]); + const int32_t x_s2_1 = + bit_cast(vx_tuple[ix + I1].template AsType()[iy / I2]); + + constexpr int32_t m0 = 0x05040100; + constexpr int32_t m1 = 0x07060302; + + // transpose 2x2 16bit + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0); + const int32_t y_s2_1 = 
__builtin_amdgcn_perm(x_s2_1, x_s2_0, m1); + + // 2 16bitx2 data after transposed + vy_tuple(iy).template AsType()(ix / I2) = bit_cast(y_s2_0); + vy_tuple(iy + I1).template AsType()(ix / I2) = bit_cast(y_s2_1); + }); + }); + } + else if constexpr(is_same_v || is_same_v || is_same_v) + { + static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); + + using S4 = typename vector_type::type; + + // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 4>{}([&](auto iy) { + static_for<0, NX, 4>{}([&](auto ix) { + // 4 int8x4 data from vx_tuple + const int32_t x_s4_0 = + bit_cast(vx_tuple[ix].template AsType()[iy / I4]); + const int32_t x_s4_1 = + bit_cast(vx_tuple[ix + I1].template AsType()[iy / I4]); + const int32_t x_s4_2 = + bit_cast(vx_tuple[ix + I2].template AsType()[iy / I4]); + const int32_t x_s4_3 = + bit_cast(vx_tuple[ix + I3].template AsType()[iy / I4]); + + // transpose + int32_t t_s4_0, t_s4_1; + int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3; + + constexpr int32_t m0 = 0x05010400; + constexpr int32_t m1 = 0x05040100; + constexpr int32_t m2 = 0x07060302; + constexpr int32_t m3 = 0x07030602; + + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0); + t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0); + y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); + y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); + t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3); + t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3); + y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); + y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); + + // 4 int8x4 data from vy_tuple + vy_tuple(iy).template AsType()(ix / I4) = bit_cast(y_s4_0); + vy_tuple(iy + I1).template AsType()(ix / I4) = bit_cast(y_s4_1); + vy_tuple(iy + I2).template AsType()(ix / I4) = bit_cast(y_s4_2); + vy_tuple(iy + I3).template AsType()(ix / I4) = bit_cast(y_s4_3); + }); + }); + } + } +}; +#endif } // namespace ck diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 036771f31..3f7f7599c 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -166,6 +166,18 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple&) return (is_detected::value || ...); } +template +__host__ __device__ constexpr auto TupleDepth(const T&) +{ + return depth; +} + +template +__host__ __device__ constexpr auto TupleDepth(const Tuple&) +{ + return math::max(TupleDepth(Ts{})...); +} + } // namespace ck // Macro function diff --git a/example/64_tensor_transforms/tensor_transform_wrapper.hpp b/include/ck/wrapper/layout.hpp similarity index 68% rename from example/64_tensor_transforms/tensor_transform_wrapper.hpp rename to include/ck/wrapper/layout.hpp index 71cd6091f..b337d88a1 100644 --- a/example/64_tensor_transforms/tensor_transform_wrapper.hpp +++ b/include/ck/wrapper/layout.hpp @@ -3,27 +3,13 @@ #pragma once -#include "ck/ck.hpp" - -#include "ck/utility/number.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/utility/tuple_helper.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/utility/sequence_helper.hpp" -#include "ck/utility/is_detected.hpp" - -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include 
"ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/wrapper/layout_utils.hpp" namespace ck { -namespace tensor_transform_wrapper { +namespace wrapper { /** - * \brief Layout wrapper - * - * \details - * Layout wrapper that performs the tensor descriptor logic. + * \brief Layout wrapper that performs the tensor descriptor logic. * * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t * (dynamic layout). It is possible to pass nested shapes @@ -32,21 +18,19 @@ namespace tensor_transform_wrapper { * (dynamic layout). Stride tuple should be nested if shape tuple is * nested. */ -template > +template struct Layout { private: static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - template - using is_tuple = decltype(std::declval().IsTuple()); - // Generate packed (column-major) strides if not passed template __host__ __device__ constexpr static auto - GenerateColumnMajorPackedStrides(const Tuple& tuple) + GenerateColumnMajorPackedStrides(const Tuple& shape) { + const auto unrolled_shape = UnrollNestedTuple(shape); return generate_tuple( [&](auto i) { if constexpr(i.value == 0) @@ -56,10 +40,10 @@ struct Layout else { return TupleReduce([](auto x, auto y) { return x * y; }, - tuple); + unrolled_shape); } }, - Number::Size()>{}); + Number{}); } // Generate LowerDims in Compile-time for MergeTrasform using passed Type @@ -112,8 +96,8 @@ struct Layout // Example shape: (2, (2, 2)), 2, (2, 2) // Unrolled shape: 2, (2, 2), 2, (2, 2) template - __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple& shape, - const Tuple& idx) + __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple& shape, + const Tuple& idx) { if constexpr(!IsNestedTuple(Tuple{})) { @@ -125,7 +109,7 @@ struct Layout // Iterate over shape tuple elements: // 1. If corresponding idx element is tuple then return (will be unrolled) // 2. If no, pack in tuple. It will be restored during unroll. - auto unrolled_shape_via_idx = generate_tuple( + auto aligned_shape = generate_tuple( [&](auto i) { if constexpr(is_detected>>::value) @@ -140,8 +124,8 @@ struct Layout Number::Size()>{}); // Unroll and process next step - return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx), - UnrollNestedTuple<0, 1>(idx)); + return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape), + UnrollNestedTuple<0, 1>(idx)); } } @@ -150,27 +134,24 @@ struct Layout DescriptorToMerge& desc) { // Reverse each element in tuple - using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape))); - const auto merge_elems = ReversedUnrolledShape{}; - + const auto merge_elems = TupleReverse(UnrollNestedTuple(shape)); // Generate reverted indexes (column major traverse) - using MergeElemsSequence = - typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type; - const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); - const auto upper_dims = make_tuple(Sequence<0>{}); + using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type; + const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); + const auto upper_dims = make_tuple(Sequence<0>{}); // Merge to 1d return transform_tensor_descriptor( desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); } - // Merge nested shape dims + // Merge nested shape dims. Merge nested shape dims when idx is also nested. 
// Input desc shape: 2, 2, 2, 2, 2, 2 // Example idx: 1, 1, 1, 1 // Example shape: 2, (2, 2), 2, (2, 2) // Merged shape: 2, 4, 2, 4 template - __host__ __device__ constexpr static auto - MakeMerges(const Tuple& shape, const Tuple&, DescriptorToMerge& desc) + __host__ __device__ constexpr static auto CreateMergedDescriptor( + const Tuple& shape, const Tuple&, DescriptorToMerge& desc) { const auto transforms = generate_tuple( [&](auto i) { @@ -224,9 +205,9 @@ struct Layout static_assert(Tuple::Size() == Tuple::Size(), "Idx rank and Shape rank must be the same (except 1d)."); // Unroll while IdxDims is nested - const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx); + const auto aligned_shape = AlignShapeToIdx(shape, idx); // Transform correct form of shape - return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_); + return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_); } } @@ -234,26 +215,21 @@ struct Layout __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape, const LayoutStrides& strides) { - const auto unrolled_shape = UnrollNestedTuple(shape); - - if constexpr(ck::is_same_v>) - { - // If shape is packed - const auto column_major_packed_strides = - GenerateColumnMajorPackedStrides(unrolled_shape); - return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides); - } - else - { - const auto unrolled_strides = UnrollNestedTuple(strides); - static_assert(unrolled_shape.Size() == unrolled_strides.Size(), - "Size of strides and shape are not consistent."); - return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); - } + const auto unrolled_shape = UnrollNestedTuple(shape); + const auto unrolled_strides = UnrollNestedTuple(strides); + static_assert(unrolled_shape.Size() == unrolled_strides.Size(), + "Size of strides and shape are not consistent."); + return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); } public: - using NaiveDescriptorType = remove_cvref_t; + // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`. + using DeducedStrides = + std::conditional_t>, + remove_cvref_t, + Strides>; + using NaiveDescriptorType = + remove_cvref_t; /** * \brief Layout constructor. @@ -268,9 +244,9 @@ struct Layout // Construct if runtime mode if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) { - // Keep only shape, strides are not need for transforms shape_ = shape; - descriptor_ = MakeNaiveDescriptor(shape, strides); + strides_ = strides; + descriptor_ = MakeNaiveDescriptor(shape_, strides_); } } @@ -279,7 +255,8 @@ struct Layout if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) { shape_ = shape; - descriptor_ = MakeNaiveDescriptor(shape, Strides{}); + strides_ = GenerateColumnMajorPackedStrides(shape_); + descriptor_ = MakeNaiveDescriptor(shape_, strides_); } } @@ -338,7 +315,7 @@ struct Layout * * \return Calculated size. */ - __host__ __device__ constexpr index_t GetLength() const + __host__ __device__ constexpr index_t GetLengths() const { const auto unrolled_shape = UnrollNestedTuple(shape_); return TupleReduce([](auto x, auto y) { return x * y; }, @@ -346,80 +323,24 @@ struct Layout } /** - * \brief Dimension getter. + * \brief Shape getter. * - * \tparam IDim Dimension idx. - * \return Calculated size. + * \return Shape. 
*/ - template - __host__ __device__ constexpr auto Get() const - { - const auto elem = shape_.At(Number{}); - return elem; - } + __host__ __device__ constexpr Shape GetShape() const { return shape_; } + + /** + * \brief Strides getter. + * + * \return Strides. + */ + __host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; } private: NaiveDescriptorType descriptor_; Shape shape_; + DeducedStrides strides_; }; -// Layout helpers -// Length getter (product if tuple) -template -__host__ __device__ constexpr index_t size(const Layout& layout) -{ - return layout.template GetLength(); -} - -// Get shape size (product of dims if tuple) -template -__host__ __device__ constexpr index_t size(const Tuple& shape) -{ - using UnrolledShape = decltype(UnrollNestedTuple(shape)); - return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; }, - UnrolledShape{}); -} - -// Get dim size (could be returned from get function) -template -__host__ __device__ T constexpr size(const T& dim) -{ - return dim; -} - -// Get layout size (product of shapes) -template -__host__ __device__ constexpr index_t size(const Layout& layout) -{ - return layout.GetLength(); -} - -// Get shape element size -template -__host__ __device__ constexpr index_t size(const Tuple& shape) -{ - return size(shape.At(Number{})); -} - -// Dim getter (tuple if tuple) -template -__host__ __device__ constexpr auto get(const Layout& layout) -{ - return layout.template Get(); -} - -template -__host__ __device__ constexpr Layout make_layout(const Shape& shape, - const Strides& strides) -{ - return Layout(shape, strides); -} - -template -__host__ __device__ constexpr Layout make_layout(const Shape& shape) -{ - return Layout(shape); -} - -} // namespace tensor_transform_wrapper +} // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/layout_utils.hpp b/include/ck/wrapper/layout_utils.hpp new file mode 100644 index 000000000..fac8f3385 --- /dev/null +++ b/include/ck/wrapper/layout_utils.hpp @@ -0,0 +1,321 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/is_detected.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +namespace ck { +namespace wrapper { + +// Disable from doxygen docs generation +/// @cond +// forward declaration +template > +struct Layout; + +template +using is_tuple = decltype(std::declval().IsTuple()); +/// @endcond + +// make_* +/** + * \brief Make layout function. + * + * \tparam Shape Shape for layout. + * \tparam Strides Strides for layout. + * \return Constructed layout. + */ +template +__host__ __device__ constexpr Layout make_layout(const Shape& shape, + const Strides& strides) +{ + return Layout(shape, strides); +} + +/** + * \brief Make layout function with packed strides + * (column-major). + * + * \tparam Shape Shape for layout. + * \return Constructed layout. + */ +template +__host__ __device__ constexpr Layout make_layout(const Shape& shape) +{ + return Layout(shape); +} + +// Layout helpers +// get +/** + * \brief Get element from tuple (Shape/Strides/Idxs). + * + * \tparam idx Index to lookup. 
+ * \param tuple Tuple to lookup.
+ * \return Requested element.
+ */
+template 
+__host__ __device__ constexpr auto get(const Tuple& tuple)
+{
+    return tuple.At(Number{});
+}
+
+/**
+ * \brief Get sub layout.
+ *
+ * \tparam idx Index to lookup.
+ * \param layout Layout to create sub layout.
+ * \return Requested sub layout.
+ */
+template 
+__host__ __device__ constexpr auto get(const Layout& layout)
+{
+    const auto new_shape = get(layout.GetShape());
+    static_assert(is_detected::value,
+                  "Shape of sub layout must be tuple");
+    if constexpr(is_same_v>)
+    {
+        // If stride not passed, create without strides
+        return make_layout(new_shape);
+    }
+    else
+    {
+        const auto new_strides = get(layout.GetStrides());
+        static_assert(is_detected::value,
+                      "Strides of sub layout must be tuple");
+        return make_layout(new_shape, new_strides);
+    }
+}
+
+/**
+ * \brief Hierarchical get.
+ *
+ * \tparam Idxs Indexes to lookup.
+ * \param elem Element to lookup.
+ * \return Requested element.
+ */
+template 
+__host__ __device__ constexpr auto get(const T& elem)
+{
+    return get(get(elem));
+}
+
+// size
+/**
+ * \brief Length get (product if tuple).
+ *
+ * \tparam idx Index to lookup.
+ * \param layout Layout to get Shape.
+ * \return Requested length.
+ */
+template 
+__host__ __device__ constexpr index_t size(const Layout& layout)
+{
+    return layout.template GetLength();
+}
+
+/**
+ * \brief Shape size (product of dims).
+ *
+ * \param shape Shape to lookup.
+ * \return Requested size.
+ */
+template 
+__host__ __device__ constexpr index_t size(const Tuple& shape)
+{
+    const auto unrolled_shape = UnrollNestedTuple(shape);
+    return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
+                                                 unrolled_shape);
+}
+
+// Get dim size (could be returned from get function)
+/**
+ * \private
+ */
+template 
+__host__ __device__ T constexpr size(const T& dim)
+{
+    return dim;
+}
+
+/**
+ * \brief Layout size (product of dims).
+ *
+ * \param layout Layout to calculate shape size.
+ * \return Requested size.
+ */
+template 
+__host__ __device__ constexpr index_t size(const Layout& layout)
+{
+    return layout.GetLengths();
+}
+
+/**
+ * \brief Length get from tuple (product if tuple).
+ *
+ * \tparam idx Index to lookup.
+ * \param tuple Tuple to lookup.
+ * \return Requested length.
+ */
+template 
+__host__ __device__ constexpr index_t size(const Tuple& tuple)
+{
+    return size(tuple.At(Number{}));
+}
+
+/**
+ * \brief Hierarchical size.
+ *
+ * \tparam Idxs Indexes to lookup.
+ * \param elem Element to lookup.
+ * \return Requested size.
+ */
+template 
+__host__ __device__ constexpr auto size(const T& elem)
+{
+    return size(get(elem));
+}
+
+// rank
+/**
+ * \brief Get layout rank (num elements in shape).
+ *
+ * \param layout Layout to calculate rank.
+ * \return Requested rank.
+ */
+template 
+__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout& layout)
+{
+    return Shape::Size();
+}
+
+/**
+ * \brief Get tuple rank (num elements in tuple).
+ * Return 1 if scalar passed.
+ *
+ * \param tuple Tuple to calculate rank.
+ * \return Requested rank.
+ */
+template 
+__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple& tuple)
+{
+    return Tuple::Size();
+}
+
+/**
+ * \private
+ */
+template 
+__host__ __device__ constexpr index_t rank(const Number&)
+{
+    return 1;
+}
+
+/**
+ * \private
+ */
+__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
+
+/**
+ * \brief Hierarchical rank.
+ *
+ * \tparam Idxs Indexes to lookup.
+ * \param elem Element to lookup.
+ * \return Requested rank.
+ */
+template 
+__host__ __device__ constexpr auto rank(const T& elem)
+{
+    return rank(get(elem));
+}
+
+// depth
+/**
+ * \brief Get depth of the layout shape (return 0 if scalar).
+ *
+ * \param layout Layout to calculate depth.
+ * \return Requested depth.
+ */
+template 
+__host__ __device__ constexpr auto depth(const Layout& layout)
+{
+    return TupleDepth(layout.GetShape());
+}
+
+/**
+ * \brief Get depth of the tuple (return 0 if scalar).
+ *
+ * \param tuple Tuple to calculate depth.
+ * \return Requested depth.
+ */
+template 
+__host__ __device__ constexpr auto depth(const Tuple& tuple)
+{
+    return TupleDepth(tuple);
+}
+
+/**
+ * \private
+ */
+template 
+__host__ __device__ constexpr index_t depth(const Number&)
+{
+    return 0;
+}
+
+/**
+ * \private
+ */
+__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
+
+/**
+ * \brief Hierarchical depth.
+ *
+ * \tparam Idxs Indexes to lookup.
+ * \param elem Element to lookup.
+ * \return Requested depth.
+ */
+template 
+__host__ __device__ constexpr auto depth(const T& elem)
+{
+    return depth(get(elem));
+}
+
+/**
+ * \brief Get Layout strides.
+ *
+ * \param layout Layout to get strides.
+ * \return Requested strides.
+ */
+template 
+__host__ __device__ constexpr auto stride(const Layout& layout)
+{
+    return layout.GetStrides();
+}
+
+/**
+ * \brief Get Layout shape.
+ *
+ * \param layout Layout to get shape.
+ * \return Requested shape.
+ */
+template 
+__host__ __device__ constexpr auto shape(const Layout& layout)
+{
+    return layout.GetShape();
+}
+
+} // namespace wrapper
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
index 89b8b9667..dc47c7ec1 100644
--- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
+++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
@@ -86,9 +86,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK;
 using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
 
 //
-using GK = ck::tensor_layout::convolution::G_K;
-using GK_Tuple = ck::Tuple;
-using GK_GK_Tuple = ck::Tuple;
+using G_K = ck::tensor_layout::convolution::G_K;
+using GK_Tuple = ck::Tuple;
+using GK_GK_Tuple = ck::Tuple;
 
 // pointwise functor
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp
index 1bea403af..348bcaef8 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp
@@ -23,19 +23,20 @@ using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
 
 #ifdef CK_ENABLE_BF16
 // grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
-void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
-    std::vector,
-        NDHWGK,
-        ck::Tuple,
-        ck::Tuple,
-        ck::Tuple<>,
-        BF16,
-        ScaleAdd,
-        ScaleAdd,
-        PassThrough>>>& instances);
+// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347
+// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
+//     std::vector,
+//         NDHWGK,
+//         ck::Tuple,
+//
ck::Tuple, +// ck::Tuple<>, +// BF16, +// ScaleAdd, +// ScaleAdd, +// PassThrough>>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -151,13 +152,15 @@ struct DeviceOperationInstanceFactory> && - is_same_v> && - is_same_v && is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - op_ptrs); - } + // TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347 + // if constexpr(is_same_v> && + // is_same_v> && + // is_same_v && is_same_v) + // { + // add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + // op_ptrs); + // } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v> && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp index dc9f44dc8..efb626642 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp @@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, BF16, BF16, @@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, F16, F16, @@ -59,7 +59,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, F32, F32, @@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, int8_t, int8_t, @@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory> op_ptrs; if constexpr(NumDimSpatial == 3 && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + DLayouts::Size() == 2 && is_same_v, NDHWGK> && + is_same_v, G_K>) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp index 3c9e03b67..b3d1e925d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +using F8 = ck::f8_t; using F16 = ck::half_t; using F32 = float; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp index aab0af990..9c8099594 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +using F8 = ck::f8_t; using F16 = ck::half_t; using F32 = float; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp index 9c96e12c3..bb40237bf 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp @@ -35,7 +35,21 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + 
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp index fcfd766b0..94f75d0e0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp @@ -32,7 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> // 
clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp index 68c048880..0f4ebc350 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp @@ -32,7 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp index ef09478d1..d2bc9351b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp @@ -31,7 +31,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // 
##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp index aec542162..2c208c01f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp @@ -24,8 +24,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off @@ -34,7 +33,7 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances = // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index c7801f02c..d5b9da86c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -9,42 +9,43 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - std::vector, - NDHWGK, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - ScaleAdd, - ScaleAdd, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1S1P0>{}); -} +// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347 +// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +// std::vector, +// NDHWGK, +// ck::Tuple, +// ck::Tuple, +// ck::Tuple<>, +// BF16, +// ScaleAdd, +// ScaleAdd, +// PassThrough>>>& instances) +// { +// add_device_operation_instances( +// instances, +// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, +// NDHWGC, +// GKZYXC, +// NDHWGK, +// ConvFwdDefault>{}); +// add_device_operation_instances( +// instances, +// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, +// NDHWGC, +// GKZYXC, +// NDHWGK, +// ConvFwd1x1P0>{}); +// add_device_operation_instances( +// instances, +// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, +// NDHWGC, +// GKZYXC, +// NDHWGK, +// ConvFwd1x1S1P0>{}); +// } } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index c6627a482..7d2df94ad 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, BF16, BF16, @@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwdDefault>{}); add_device_operation_instances( @@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw 
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1P0>{}); add_device_operation_instances( @@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1S1P0>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index 627af24d7..8a09d0396 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, F16, F16, @@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwdDefault>{}); add_device_operation_instances( @@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1P0>{}); add_device_operation_instances( @@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1S1P0>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp index 1fd567e36..696695963 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, F32, F32, @@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwdDefault>{}); add_device_operation_instances( @@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1P0>{}); 
add_device_operation_instances( @@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1S1P0>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp index dae292891..2606f6942 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp @@ -12,7 +12,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, int8_t, int8_t, @@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwdDefault>{}); add_device_operation_instances( @@ -35,7 +35,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1P0>{}); add_device_operation_instances( @@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1S1P0>{}); } diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp index 711314985..d46fe090b 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp @@ -22,13 +22,13 @@ using S = ck::Sequence; using NHWGC = ck::tensor_layout::convolution::NHWGC; using GKYXC = ck::tensor_layout::convolution::GKYXC; using NHWGK = ck::tensor_layout::convolution::NHWGK; -using GK = ck::tensor_layout::convolution::G_K; +using G_K = ck::tensor_layout::convolution::G_K; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Relu = ck::tensor_operation::element_wise::Relu; using TanH = ck::tensor_operation::element_wise::TanH; -using GK_Tuple = ck::Tuple; -using GK_GK_Tuple = ck::Tuple; +using GK_Tuple = ck::Tuple; +using GK_GK_Tuple = ck::Tuple; using I32_Tuple = ck::Tuple; using F32_Tuple = ck::Tuple; using I32_F32_Tuple = ck::Tuple; diff --git a/profiler/src/profile_transpose.cpp b/profiler/src/profile_transpose.cpp deleted file mode 100644 index c239a520d..000000000 --- a/profiler/src/profile_transpose.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include -#include -#include - -#include "profiler/profile_transpose_impl.hpp" -#include "profiler_operation_registry.hpp" - -enum struct MatrixLayout -{ - NCDHW, // 0 - NCHWD, // 1 -}; - -enum struct DataType -{ - F32_F32_F32_F32_F32, // 0 - F16_F16_F16_F16_F16, // 1 -}; - -#define OP_NAME "transpose" -#define OP_DESC "Transpose" - -int profile_transpose(int argc, char* argv[]) -{ - if(argc != 15) - { - printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - // printf("arg3: matrix layout (NCDHW -> NDCHW);\n"); - printf("arg4: verification (0: no; 1: yes)\n"); - printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg6: print tensor value (0: no; 1: yes)\n"); - printf("arg7: time kernel (0=no, 1=yes)\n"); - printf("arg8 to 13: N, C, D, H, W\n"); - exit(1); - } - - const auto data_type = static_cast(std::stoi(argv[2])); - // const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const bool time_kernel = std::stoi(argv[6]); - std::vector lengths = std::stoi(argv[7]); - - /**const int N = std::stoi(argv[7]); - const int C = std::stoi(argv[8]); - const int D = std::stoi(argv[9]); - const int H = std::stoi(argv[10]); - const int W = std::stoi(argv[11]);**/ - - using F32 = float; - using F16 = ck::half_t; - - auto profile = [&](auto a_type, auto b_type) { - using ADataType = decltype(a_type); - using BDataType = decltype(b_type); - - bool pass = ck::profiler::profile_transpose_impl( - do_verification, init_method, do_log, time_kernel, lengths); - - return pass ? 0 : 1; - }; - - if(data_type == GemmDataType::F32_F32_F32_F32_F32) - { - return profile(F32{}, F32{}); - } - else if(data_type == GemmDataType::F16_F16_F16_F16_F16) - { - return profile(F16{}, F16{}); - } - else - { - std::cout << "this data_type & layout is not implemented" << std::endl; - - return 1; - } -} - -REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_transpose); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4aaa5fcfa..b325a3a7f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -149,6 +149,7 @@ add_subdirectory(batched_gemm_multi_d) add_subdirectory(grouped_convnd_bwd_data) add_subdirectory(conv_tensor_rearrange) add_subdirectory(transpose) +add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt new file mode 100644 index 000000000..e25ef176d --- /dev/null +++ b/test/wrapper/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_layout test_layout.cpp) +target_link_libraries(test_layout PRIVATE utility) diff --git a/test/wrapper/test_layout.cpp b/test/wrapper/test_layout.cpp new file mode 100644 index 000000000..7d09696fb --- /dev/null +++ b/test/wrapper/test_layout.cpp @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/wrapper/layout.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +class TestWrapperLayout : public ::testing::Test +{ + protected: + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + + template + void Run(Desc& desc, + Desc1d& desc_1d, + LayoutRuntime& layout_runtime, + LayoutCompiletime& layout_compiletime, + const std::vector& idxs) + { + // 1d check + EXPECT_EQ(desc_1d.GetLength(I0), ck::wrapper::size(layout_runtime)); + // Check layout compiletime and runtime result consistency + EXPECT_EQ(ck::wrapper::size(layout_runtime), ck::wrapper::size(layout_compiletime)); + + for(ck::index_t i = 0; i < desc_1d.GetLength(I0); i++) + { + const ck::index_t layout_runtime_offset_1d = layout_runtime(ck::make_tuple(i)); + const ck::index_t layout_compiletime_offset_1d = layout_compiletime(ck::make_tuple(i)); + const ck::index_t desc_offset_1d = desc_1d.CalculateOffset(ck::make_tuple(i)); + EXPECT_EQ(layout_runtime_offset_1d, desc_offset_1d); + EXPECT_EQ(layout_compiletime_offset_1d, layout_runtime_offset_1d); + } + // size(layout)-d check, don't check if access is hierarchical + if constexpr(!IsNestedTuple(Idxs{})) + { + ck::static_for<0, Idxs::Size(), 1>{}([&](auto d) { + EXPECT_EQ(desc.GetLength(ck::Number{}), ck::wrapper::size(layout_runtime)); + EXPECT_EQ(ck::wrapper::size(layout_runtime), + ck::wrapper::size(layout_compiletime)); + }); + } + for(const auto idx : idxs) + { + const ck::index_t layout_runtime_offset = layout_runtime(idx); + const ck::index_t layout_compiletime_offset = layout_compiletime(idx); + const ck::index_t desc_offset = + desc.CalculateOffset(UnrollNestedTuple(idx)); // Unroll if nested + EXPECT_EQ(layout_runtime_offset, desc_offset); + EXPECT_EQ(layout_runtime_offset, layout_compiletime_offset); + } + } +}; + +TEST_F(TestWrapperLayout, 2d) +{ + // dims:(4, 3) strides:(1, 4) + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s1 = 1; + constexpr ck::index_t s0 = 4; + const auto desc = + ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1))), + ck::make_tuple(ck::Sequence<1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto layout_runtime = ck::wrapper::make_layout(ck::make_tuple(d1, d0)); + const auto layout_compiletime = + ck::wrapper::make_layout(ck::make_tuple(ck::Number{}, ck::Number{})); + std::vector> idxs; + + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs.emplace_back(h, w); + } + } + + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs); +} + +TEST_F(TestWrapperLayout, 3d_nested) +{ + // dims:((2, 3), 4, 3) strides:((2, 4), 12, 48) + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s3 = 2; + constexpr ck::index_t s2 = 4; + constexpr ck::index_t s1 = 12; + constexpr ck::index_t s0 = 48; + const auto desc = ck::make_naive_tensor_descriptor( + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, 
ck::Number{}, ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))), + ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_3d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)), + ck::make_pass_through_transform(d1), + ck::make_pass_through_transform(d2)), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); + const auto layout_runtime = + ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), d1, d0), + ck::make_tuple(ck::make_tuple(s3, s2), s1, s0)); + const auto layout_compiletime = ck::wrapper::make_layout( + ck::make_tuple( + ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}, ck::Number{}), + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::Number{}, + ck::Number{})); + std::vector> idxs_3d; + + for(ck::index_t d = 0; d < d2 * d3; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_3d.emplace_back(d, h, w); + } + } + } + this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d); + + // Check also 4d iteration + std::vector, ck::index_t, ck::index_t>> idxs_4d; + + for(ck::index_t e = 0; e < d3; e++) + { + for(ck::index_t d = 0; d < d2; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_4d.emplace_back(ck::make_tuple(e, d), h, w); + } + } + } + } + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d); +} + +TEST_F(TestWrapperLayout, 2d_nested) +{ + // dims:((2, 3), (4, 3)) strides:((2, 4), (48, 12)) + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s3 = 2; + constexpr ck::index_t s2 = 4; + constexpr ck::index_t s1 = 48; + constexpr ck::index_t s0 = 12; + const auto desc = ck::make_naive_tensor_descriptor( + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))), + ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_2d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)), + ck::make_merge_transform(ck::make_tuple(d0, d1))), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); + const auto layout_runtime = + ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), ck::make_tuple(d1, d0)), + ck::make_tuple(ck::make_tuple(s3, s2), ck::make_tuple(s1, s0))); + const auto layout_compiletime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})), + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + std::vector> idxs_2d; + + for(ck::index_t h = 0; h < d2 * d3; h++) + { + for(ck::index_t w = 0; w < d0 * d1; w++) + { + idxs_2d.emplace_back(h, w); + } + } + this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d); + // Check 
also 4d iteration + std::vector, ck::Tuple>> + idxs_4d; + + for(ck::index_t e = 0; e < d3; e++) + { + for(ck::index_t d = 0; d < d2; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_4d.emplace_back(ck::make_tuple(e, d), ck::make_tuple(h, w)); + } + } + } + } + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d); +} + +TEST_F(TestWrapperLayout, 3d_double_nested) +{ + // dims:(((2, 2), 3), (4, 3)) strides:(((2, 4), 8), (96, 24)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s4 = 2; + constexpr ck::index_t s3 = 4; + constexpr ck::index_t s2 = 8; + constexpr ck::index_t s1 = 96; + constexpr ck::index_t s0 = 24; + const auto desc = ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number{}, + ck::Number{}, + ck::Number{}, + ck::Number{}, + ck::Number{}), + ck::make_tuple(ck::Number{}, + ck::Number{}, + ck::Number{}, + ck::Number{}, + ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3, d4))), + ck::make_tuple(ck::Sequence<4, 3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_3d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d3, d4)), + ck::make_pass_through_transform(d2), + ck::make_merge_transform(ck::make_tuple(d0, d1))), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<4, 3>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); + const auto desc_2d = transform_tensor_descriptor( + desc_3d, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3 * d4)), + ck::make_pass_through_transform(d1 * d0)), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); + const auto layout_runtime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)), + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, s3), s2), ck::make_tuple(s1, s0))); + const auto layout_compiletime = ck::wrapper::make_layout( + ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})), + ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + std::vector> idxs_2d; + + for(ck::index_t h = 0; h < d2 * d3 * d4; h++) + { + for(ck::index_t w = 0; w < d0 * d1; w++) + { + idxs_2d.emplace_back(h, w); + } + } + this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d); + // Check also 3d iteration + std::vector, ck::index_t>> idxs_3d; + + for(ck::index_t d = 0; d < d3 * d4; d++) + { + for(ck::index_t h = 0; h < d2; h++) + { + for(ck::index_t w = 0; w < d1 * d0; w++) + { + idxs_3d.emplace_back(ck::make_tuple(d, h), w); + } + } + } + this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d); + // Check also 5d iteration + std::vector, ck::index_t>, + ck::Tuple>> + idxs_5d; + + for(ck::index_t f = 0; f < d4; f++) + { + for(ck::index_t e = 0; e < d3; e++) + { + for(ck::index_t d = 0; d < d2; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_5d.emplace_back(ck::make_tuple(ck::make_tuple(f, e), d), + ck::make_tuple(h, w)); + } + } + } + } + } 
+ this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_5d); +} + +TEST(TestLayoutHelpers, SizeAndGet) +{ + // dims:(((2, 2), 3), (4, 3)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + const auto layout_runtime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0))); + const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + + // Size of layout + EXPECT_EQ(ck::wrapper::size(layout_runtime), d4 * d3 * d2 * d1 * d0); + EXPECT_EQ(ck::wrapper::size(layout_compiletime), d4 * d3 * d2 * d1 * d0); + + // Size of dims + EXPECT_EQ(ck::wrapper::size<0>(layout_runtime), d4 * d3 * d2); + EXPECT_EQ(ck::wrapper::size<0>(layout_compiletime), d4 * d3 * d2); + EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0); + EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0); + + // Access through new layout (using get with layout object) + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3); + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2); + + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d4); + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))), + d4); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d3); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))), + d3); + + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2); + + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_runtime)), d1); + EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_compiletime)), d1); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_runtime)), d0); + EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_compiletime)), d0); +} + +TEST(TestLayoutHelpers, DepthAndRank) +{ + // dims:(((2, 2), 3), (4, 3)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + const auto layout_runtime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0))); + const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + + EXPECT_EQ(ck::wrapper::depth(layout_runtime), 3); + EXPECT_EQ(ck::wrapper::depth(layout_compiletime), 3); + EXPECT_EQ(ck::wrapper::depth(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2); + // Check for integer + EXPECT_EQ(ck::wrapper::depth(d0), 0); + + EXPECT_EQ(ck::wrapper::rank(layout_runtime), 2); + EXPECT_EQ(ck::wrapper::rank(layout_compiletime), 2); + EXPECT_EQ(ck::wrapper::rank(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2); + // Check for integer + EXPECT_EQ(ck::wrapper::rank(d0), 1); +} + +TEST(TestLayoutHelpers, ShapeAndStrides) +{ + // dims:(((2, 2), 3), (4, 3)) + constexpr ck::index_t d4 = 
2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s4 = 2; + constexpr ck::index_t s3 = 4; + constexpr ck::index_t s2 = 8; + constexpr ck::index_t s1 = 96; + constexpr ck::index_t s0 = 24; + const auto shape_compiletime = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})); + const auto strides_compiletime = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})); + const auto shape_runtime = + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)); + const auto strides_runtime = + ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0)); + const auto layout_runtime = ck::wrapper::make_layout(shape_runtime, strides_runtime); + const auto layout_compiletime = + ck::wrapper::make_layout(shape_compiletime, strides_compiletime); + + constexpr bool check_compiletime_shape = + std::is_same_v::type, + decltype(shape(layout_compiletime))>; + constexpr bool check_compiletime_strides = + std::is_same_v::type, + decltype(stride(layout_compiletime))>; + constexpr bool check_runtime_shape = + std::is_same_v::type, + decltype(shape(layout_runtime))>; + constexpr bool check_runtime_strides = + std::is_same_v::type, + decltype(stride(layout_runtime))>; + EXPECT_TRUE(check_compiletime_shape); + EXPECT_TRUE(check_compiletime_strides); + EXPECT_TRUE(check_runtime_shape); + EXPECT_TRUE(check_runtime_strides); +} + +TEST(TestLayoutHelpers, Hierarchical) +{ + // dims:(((2, 2), 3), (4, 3)) + constexpr ck::index_t d4 = 2; + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + const auto runtime_shape = + ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)); + const auto layout_runtime = ck::wrapper::make_layout(runtime_shape); + const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + + EXPECT_EQ((ck::wrapper::rank<0, 0>(runtime_shape)), 2); + EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_runtime)), 2); + EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_compiletime)), 2); + + EXPECT_EQ((ck::wrapper::depth<0, 0>(runtime_shape)), 1); + EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_runtime)), 1); + EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_compiletime)), 1); + + EXPECT_EQ((ck::wrapper::size<0, 0>(runtime_shape)), d4 * d3); + EXPECT_EQ((ck::wrapper::size<0, 0>(layout_runtime)), d4 * d3); + EXPECT_EQ((ck::wrapper::size<0, 0>(layout_compiletime)), d4 * d3); + + EXPECT_EQ((ck::wrapper::get<0, 0, 0>(runtime_shape)), d4); +} From ced4670a10bac1b07d300aac5b9610c1ca4ed9b6 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 10 Dec 2023 23:02:03 -0600 Subject: [PATCH 11/45] add missing bf16 type --- example/91_tile_program/fmha_fwd.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index 4c1af6e79..2e993926c 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -55,6 +55,7 @@ using ODataType = ck::half_t; using QDataType = ck::bhalf_t; using KDataType = ck::bhalf_t; using VDataType = ck::bhalf_t; +using BiasDataType = ck::bhalf_t; using SaccDataType = 
float; // data type for first gemm accumulation using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck::bhalf_t; // data type for A matrix of second gemm From cfcc7e79ee2f21109b0c1ec5ec746e6745084be3 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Mon, 11 Dec 2023 01:44:32 -0500 Subject: [PATCH 12/45] Fix loop counter update logics --- .../block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index e468f79b2..8ce1e36b0 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -315,7 +315,6 @@ struct BlockFmhaPipelineQRKSVS } // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); - i_total_loops++; // tail { block_sync_lds(); @@ -324,7 +323,7 @@ struct BlockFmhaPipelineQRKSVS v_lds_window); block_sync_lds(); } - } while(i_total_loops < num_total_loop); + } while(++i_total_loops < num_total_loop); // finally, O constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); From 5a24af3815e04db4104f5881a1db9187726e48fb Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 1 Dec 2023 16:35:23 +0000 Subject: [PATCH 13/45] Disable exp() and log() overloading for half_t to support xformers C++ extension building --- include/ck/utility/math_v2.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 082fa7baa..594123097 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -312,11 +312,13 @@ inline __device__ float exp(float x) return __expf(x); } +/* template <> inline __device__ half_t exp(half_t x) { return hexp(x); }; +*/ template <> inline __device__ double exp(double x) @@ -346,11 +348,13 @@ inline __device__ T log(T x) return ck::type_convert(__logf(ck::type_convert(x))); }; +/* template <> inline __device__ half_t log(half_t x) { return hlog(x); }; +*/ template <> inline __device__ float log(float x) From d18039131aa762fc64b287fb6156fe4582f88bd3 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 1 Dec 2023 16:34:42 +0000 Subject: [PATCH 14/45] Add include/ck/config.h to support xformers c++ extension building --- include/ck/config.h | 109 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 include/ck/config.h diff --git a/include/ck/config.h b/include/ck/config.h new file mode 100644 index 000000000..dbf4a9597 --- /dev/null +++ b/include/ck/config.h @@ -0,0 +1,109 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_CONFIG_H_IN +#define CK_CONFIG_H_IN + +// clang-format off +// +// DataType supports in the current CK build +// +#ifndef DTYPES +/* #undef DTYPES */ +#endif +// if DTYPES is not defined, enable all datatypes in headerfiles +#ifndef CK_ENABLE_ALL_DTYPES +#define CK_ENABLE_ALL_DTYPES ON +#if defined(CK_ENABLE_ALL_DTYPES) +#ifndef CK_ENABLE_INT8 +#define CK_ENABLE_INT8 "ON" +#endif +#ifndef CK_ENABLE_FP8 +#define CK_ENABLE_FP8 "ON" +#endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif +#ifndef CK_ENABLE_FP16 +#define CK_ENABLE_FP16 "ON" +#endif +#ifndef CK_ENABLE_BF16 +#define CK_ENABLE_BF16 "ON" +#endif +#ifndef CK_ENABLE_FP32 +#define CK_ENABLE_FP32 "ON" +#endif +#ifndef CK_ENABLE_FP64 +#define CK_ENABLE_FP64 "ON" +#endif +#endif +#endif +// if DTYPES are selectively enabled +#ifndef CK_ENABLE_INT8 +/* #undef CK_ENABLE_INT8 */ +#endif + +#ifndef CK_ENABLE_FP8 +/* #undef CK_ENABLE_FP8 */ +#endif + +#ifndef CK_ENABLE_BF8 +/* #undef CK_ENABLE_BF8 */ +#endif + +#ifndef CK_ENABLE_FP16 +/* #undef CK_ENABLE_FP16 */ +#endif + +#ifndef CK_ENABLE_BF16 +/* #undef CK_ENABLE_BF16 */ +#endif + +#ifndef CK_ENABLE_FP32 +/* #undef CK_ENABLE_FP32 */ +#endif + +#ifndef CK_ENABLE_FP64 +/* #undef CK_ENABLE_FP64 */ +#endif + +// +// Legacy DL kernel supports in the current CK build +// by default DL kernels are turned OFF +// +#ifndef CK_ENABLE_DL_KERNELS +/* #undef CK_ENABLE_DL_KERNELS */ +#endif + +// +// Instances supports in the current CK build +// +#ifndef CK_ENABLE_INSTANCES_ONLY +/* #undef CK_ENABLE_INSTANCES_ONLY */ +#endif + +// clang-format on + +#endif // CK_CONFIG_H_IN From c1814f90e2dd5b0659c6e1ed577fb1bba596c126 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 11 Dec 2023 01:52:49 -0600 Subject: [PATCH 15/45] refactor mask in async copy pipeline --- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 60 +++++++++++++------ 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 88a7864bf..c4ff4efaf 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -207,25 +207,58 @@ struct BlockFmhaPipelineQRKSVSAsync bias_dram_block_window_tmp.GetWindowOrigin(), Policy::template MakeBiasDramTileDistribution()); + const auto q_origin = q_dram_window.GetWindowOrigin(); + auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool skip_tile = causal_mask.IsTileSkippable( + q_origin.At(Number<0>{}), k_origin.At(Number<0>{}), kM0, kN0); + // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); + if(!skip_tile) + { + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, 
kK0}); + } __builtin_amdgcn_sched_barrier(0); - const auto q_origin = q_dram_window.GetWindowOrigin(); - buffer_load_fence(k_dram_window.GetNumAccess()); + if constexpr(std::is_same::value) + buffer_load_fence(k_dram_window.GetNumAccess()); + else + buffer_load_fence(0); // unconditionally wait for q if this is a mask kernel auto q_tile = tile_elementwise_in(q_element_func, q); __builtin_amdgcn_sched_barrier(0); index_t i_total_loops = 0; constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k1_loops = kN0 / kK1; + auto prefetch_k = [&]() { + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = make_tile_window(k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); + + k_origin = k_dram_block_window.GetWindowOrigin(); + skip_tile = causal_mask.IsTileSkippable( + q_origin.At(Number<0>{}), k_origin.At(Number<0>{}), kM0, kN0); + if(!skip_tile) + { + if constexpr(k1_loops >= 2 && + LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + } + }; + + // main loop do { - const auto k_origin = k_dram_block_window.GetWindowOrigin(); - if(causal_mask.IsTileSkippable( - q_origin.At(Number<0>{}), k_origin.At(Number<0>{}), kM0, kN0)) + if(skip_tile) { + i_total_loops++; + if(i_total_loops < num_total_loop) + prefetch_k(); continue; } @@ -467,18 +500,7 @@ struct BlockFmhaPipelineQRKSVSAsync if(i_total_loops < num_total_loop) { // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = - make_tile_window(k_dram_block_window.GetBottomTensorView(), - k_dram_block_window.GetWindowLengths(), - k_dram_block_window.GetWindowOrigin(), - Policy::template MakeKDramTileDistribution()); - - if constexpr(k1_loops >= 2 && - LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) - __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); + prefetch_k(); } // tail { From 7fab8b068a004c0b43acaf8e569f494cc99edc53 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Mon, 11 Dec 2023 09:59:03 -0500 Subject: [PATCH 16/45] Make sure RNG data for MaskUpperTriangleFromBottomRightPredicate is valid --- example/91_tile_program/fmha_fwd.cpp | 5 +- example/91_tile_program/fmha_utils.hpp | 130 +++++++++++++++++++------ 2 files changed, 102 insertions(+), 33 deletions(-) diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index 2e993926c..ed0d4ce67 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -360,8 +360,9 @@ int main(int argc, char* argv[]) StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat}; - const std::vector seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); - const std::vector seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + const auto [seqlens_q, seqstart_q_host] = generate_seqlens_seqstarts_q(mode, batch, seqlen_q); + const std::vector seqstart_k_host = + generate_seqstarts_k(mode, batch, seqlen_k, seqlens_q, seqlen_q); // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; diff --git a/example/91_tile_program/fmha_utils.hpp b/example/91_tile_program/fmha_utils.hpp index e6bdc4a49..96f74a045 100644 --- a/example/91_tile_program/fmha_utils.hpp +++ 
b/example/91_tile_program/fmha_utils.hpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include #include #include #include @@ -8,6 +9,8 @@ #include #include +#include "ck/utility/span.hpp" + #pragma once enum class Mode : unsigned @@ -21,50 +24,115 @@ inline std::ostream& operator<<(std::ostream& stream, Mode mode) return stream << (mode == Mode::Batch ? "batch" : "group"); } -/// TODO: make sure result is valid for MaskUpperTriangleFromBottomRightPredicate -std::vector generate_seqstarts(Mode mode, - unsigned count, - int32_t seqlens_sum, - std::optional seed = std::nullopt) +inline std::vector to_seqstarts(ck::span seqlens) +{ + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + assert(seqstarts.size() == seqlens.size() + 1); + return seqstarts; +} + +inline std::vector generate_seqlens_q(Mode mode, + unsigned count, + int32_t seqlens_q_sum, + std::optional seed = std::nullopt) { assert(0 < count); - const std::vector seqlens = [&]() { - std::vector original_seqlens(count, seqlens_sum); + std::vector seqlens_q(count, seqlens_q_sum); - if(mode == Mode::Group && 1 < count) - { - using size_type = std::vector::size_type; + if(mode == Mode::Group && 1 < count) + { + using size_type = std::vector::size_type; - std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); - std::uniform_int_distribution idx_dist(0, count - 1); - auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); - std::uniform_int_distribution step_dist(1, count - 1); - auto next_step = std::bind(step_dist, std::ref(random_engine)); + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); - for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + for(unsigned repeat = seqlens_q_sum * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens_q is always greater than 0 + if(seqlens_q[to_decrease] == 1) { - const size_type to_decrease = next_idx(); - if(original_seqlens[to_decrease] == 1) - { - continue; - } + continue; + } - const size_type to_increase = (to_decrease + next_step()) % count; + const size_type to_increase = (to_decrease + next_step()) % count; - --original_seqlens[to_decrease]; - ++original_seqlens[to_increase]; - } + --seqlens_q[to_decrease]; + ++seqlens_q[to_increase]; } + } - return original_seqlens; - }(); + return seqlens_q; +} - std::vector seqstarts = {0}; - for(int32_t seqlen : seqlens) +inline std::tuple, std::vector> generate_seqlens_seqstarts_q( + Mode mode, unsigned count, int32_t seqlens_q_sum, std::optional seed = std::nullopt) +{ + const std::vector seqlens_q = generate_seqlens_q(mode, count, seqlens_q_sum, seed); + return std::tuple(seqlens_q, to_seqstarts(seqlens_q)); +} + +inline std::vector generate_seqlens_k(Mode mode, + unsigned count, + int32_t seqlens_k_sum, + ck::span seqlens_q, + int32_t seqlens_q_sum, + std::optional seed = std::nullopt) +{ + assert(0 < count); + assert(seqlens_q.size() == count); + + std::vector seqlens_k(count, seqlens_k_sum); + + if(mode == Mode::Group && 1 < count) { - seqstarts.push_back(seqstarts.back() + seqlen); + using size_type = 
std::vector::size_type; + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlens_k_sum * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens_k is always greater than 0 & greater than + // corresponding elements in seqlens_q + if(seqlens_k[to_decrease] == 1 || + (seqlens_q_sum < seqlens_k_sum && + seqlens_k[to_decrease] <= seqlens_q[to_decrease] + 1)) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + --seqlens_k[to_decrease]; + ++seqlens_k[to_increase]; + } } - return seqstarts; + + return seqlens_k; +} + +inline std::vector generate_seqstarts_k(Mode mode, + unsigned count, + int32_t seqlens_k_sum, + ck::span seqlens_q, + int32_t seqlens_q_sum, + std::optional seed = std::nullopt) +{ + return to_seqstarts( + generate_seqlens_k(mode, count, seqlens_k_sum, seqlens_q, seqlens_q_sum, seed)); } From dc9ba2eccb1abea7947e559f452806ea8ed9ce7e Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Mon, 11 Dec 2023 10:47:42 -0500 Subject: [PATCH 17/45] Use std::make_tuple() to construct temp std::tuple<> --- example/91_tile_program/fmha_utils.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/example/91_tile_program/fmha_utils.hpp b/example/91_tile_program/fmha_utils.hpp index 96f74a045..df885b7e3 100644 --- a/example/91_tile_program/fmha_utils.hpp +++ b/example/91_tile_program/fmha_utils.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -78,7 +79,7 @@ inline std::tuple, std::vector> generate_seqlens_s Mode mode, unsigned count, int32_t seqlens_q_sum, std::optional seed = std::nullopt) { const std::vector seqlens_q = generate_seqlens_q(mode, count, seqlens_q_sum, seed); - return std::tuple(seqlens_q, to_seqstarts(seqlens_q)); + return std::make_tuple(seqlens_q, to_seqstarts(seqlens_q)); } inline std::vector generate_seqlens_k(Mode mode, From b7e3f3b69e11d2ac2d83c2c2cfcf1f6ad5ecce90 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 13 Dec 2023 14:01:25 +0000 Subject: [PATCH 18/45] Add bhalf2_t, bhalf4_t inner_product --- include/ck/utility/inner_product.hpp | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index 65efaf388..c1943520c 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -152,6 +152,48 @@ __device__ void inner_product(const half8_t& a, const h c); } +template <> +__device__ void +inner_product(const bhalf2_t& a, const bhalf2_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const bhalf4_t& a, const bhalf4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + 
c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + template <> __device__ void inner_product(const int8_t& a, const int8_t& b, int32_t& c) { From d205bc5de70400f78078fc3c4c9de47f6139b998 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Wed, 20 Dec 2023 04:53:44 -0500 Subject: [PATCH 19/45] Choose constant according precision --- .../block_fmha_pipeline_qr_ks_vs.hpp | 8 ++------ .../block_fmha_pipeline_qr_ks_vs_async.hpp | 8 ++------ include/ck/utility/math.hpp | 20 +++++++++++++++++++ 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index ae74d269e..eb58c0808 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -20,10 +20,6 @@ #include "ck/tile_program/block_tile/block_reduce.hpp" #include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" -#ifndef C_LOG2E -#define C_LOG2E 1.44269504088896340736 // log2(e) -#endif - namespace ck { namespace tile_program { namespace block { @@ -253,10 +249,10 @@ struct BlockFmhaPipelineQRKSVS tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_FMHA_FWD_FAST_EXP2 - x = scale * x + type_convert(bias_element_func(y)); + x = scale * x + type_convert(bias_element_func(y)); #else x = scale * x + - C_LOG2E * type_convert(bias_element_func(y)); + math::log2e_v * type_convert(bias_element_func(y)); #endif }, s_acc, diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index c4ff4efaf..c53146fe5 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -20,10 +20,6 @@ #include "ck/tile_program/block_tile/block_reduce.hpp" #include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" -#ifndef C_LOG2E -#define C_LOG2E 1.44269504088896340736 // log2(e) -#endif - namespace ck { namespace tile_program { namespace block { @@ -330,10 +326,10 @@ struct BlockFmhaPipelineQRKSVSAsync tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_FMHA_FWD_FAST_EXP2 - x = scale * x + type_convert(bias_element_func(y)); + x = scale * x + type_convert(bias_element_func(y)); #else x = scale * x + - C_LOG2E * type_convert(bias_element_func(y)); + math::log2e_v * type_convert(bias_element_func(y)); #endif }, s_acc, diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index e654f7dfd..3095faf6c 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -251,5 +251,25 @@ __host__ __device__ constexpr bool is_power_of_two_integer(int32_t x) return x == (1 << integer_log2_floor(x)); } +#ifndef C_LOG2E +#define C_LOG2E 1.44269504088896340736 // log2(e) +#endif + +template +struct log2e; + +template <> +struct log2e { + static constexpr double value = C_LOG2E; +}; + +template <> +struct log2e { + static constexpr float value = C_LOG2E; +}; + +template +inline constexpr T log2e_v = log2e::value; + } // namespace math } // namespace ck From 7ddff7b540075bd98faec462f070d11db81b56d0 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Wed, 20 Dec 2023 10:11:47 +0000 Subject: 
[PATCH 20/45] Avoid inefficient instruction --- .../block_fmha_pipeline_qr_ks_vs.hpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index eb58c0808..117796636 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -197,6 +197,11 @@ struct BlockFmhaPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } + + __builtin_amdgcn_sched_barrier(0); + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0_loops > 2) { static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { @@ -216,9 +221,8 @@ struct BlockFmhaPipelineQRKSVS }); } - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile - { // tail + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail block_sync_lds(); gemm_0(s_acc, get_slice_tile(q_tile, @@ -246,13 +250,15 @@ struct BlockFmhaPipelineQRKSVS } else { + __builtin_amdgcn_sched_barrier(0); + tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_FMHA_FWD_FAST_EXP2 x = scale * x + type_convert(bias_element_func(y)); #else - x = scale * x + - math::log2e_v * type_convert(bias_element_func(y)); + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); #endif }, s_acc, From 6b888b6e1706c5acd90b932448e20991510bdd5d Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Wed, 20 Dec 2023 13:33:38 +0000 Subject: [PATCH 21/45] Remove sched_barrier() for non-bias mode --- .../block_fmha_pipeline_qr_ks_vs.hpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 117796636..f0b127b57 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -198,9 +198,17 @@ struct BlockFmhaPipelineQRKSVS k_block_tile = load_tile(k_dram_window); } - __builtin_amdgcn_sched_barrier(0); + if constexpr(!is_null_tile_window(bias_dram_window)) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - __builtin_amdgcn_sched_barrier(0); + if constexpr(!is_null_tile_window(bias_dram_window)) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } if constexpr(k0_loops > 2) { @@ -250,8 +258,6 @@ struct BlockFmhaPipelineQRKSVS } else { - __builtin_amdgcn_sched_barrier(0); - tile_elementwise_inout( [&](auto& x, const auto& y) { #if !CK_FMHA_FWD_FAST_EXP2 From 3913a4001fcc6e5b9cc4ebe244aada65847ca676 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 29 Dec 2023 20:31:52 +0800 Subject: [PATCH 22/45] WIP add generic masking (#59) * WIP add generic masking * now local is not correct * fix bug in local atn * support when a whole row is masked * fix a bug in local attn --- example/91_tile_program/fmha_fwd.cpp | 307 +++++++++++++----- example/91_tile_program/fmha_fwd_kernel.hpp | 270 
++++++--------- .../reference_batched_masking.hpp | 26 +- .../reference_batched_softmax.hpp | 5 +- .../tile_program/block_tile/block_masking.hpp | 208 ++++++++++++ .../block_masking_specialization.hpp | 104 ------ .../block_fmha_pipeline_problem.hpp | 4 +- .../block_fmha_pipeline_qr_ks_vs.hpp | 101 +++--- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 155 +++++---- .../ck/tile_program/tile/null_tile_window.hpp | 6 +- include/ck/utility/math.hpp | 10 +- include/ck/utility/static_switch.hpp | 6 +- 12 files changed, 687 insertions(+), 515 deletions(-) create mode 100644 include/ck/tile_program/block_tile/block_masking.hpp delete mode 100644 include/ck/tile_program/block_tile/block_masking_specialization.hpp diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp index ed0d4ce67..a8132027d 100644 --- a/example/91_tile_program/fmha_fwd.cpp +++ b/example/91_tile_program/fmha_fwd.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -27,7 +28,7 @@ #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" -#include "ck/tile_program/block_tile/block_masking_specialization.hpp" +#include "ck/tile_program/block_tile/block_masking.hpp" #include "ck/tile_program/tile/tile_fmha_shape.hpp" #include "ck/tile_program/tile/tile_fmha_traits.hpp" @@ -108,7 +109,7 @@ struct FmhaShape // using FmhaMask = ck::tile_program::block::MaskUpperTriangleFromTopLeftPredicate; // using FmhaMask = ck::tile_program::block::MaskUpperTriangleFromBottomRightPredicate; -using FmhaMask = ck::tile_program::block::MaskDisabledPredicate; +// using FmhaMask = ck::tile_program::block::MaskDisabledPredicate; inline constexpr bool kM0NeedPadding = false; inline constexpr bool kN0K1NeedPadding = false; @@ -121,7 +122,7 @@ using FmhaTraits = ck::tile_program::TileFmhaTraits using FmhaTilePartitioner = FmhaFwdTilePartitioner>; -template +template using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem>; -template +template using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync< - FmhaPipelineProblem>; + FmhaPipelineProblem>; using FmhaEpilogue = FmhaFwdEpilogue>; -template -using FmhaKernel = FmhaFwdKernel, - FmhaPipeline, - FmhaEpilogue>; +enum class mask_enum +{ + no_mask = 0, + causal_top_left, + causal_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck::index_t y, x; + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::causal_top_left) + os << "tl"; + else if(type == mask_enum::causal_bottom_right) + os << "br"; + else + { + os << "g(" << y << "/" << x << ")"; + } + } + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); +}; + +std::ostream& operator<<(std::ostream& os, const mask_info& mi) +{ + mi.serialize(os); + return os; +} + +mask_info decode_mask_info(std::string str, ck::index_t seqlen_q, ck::index_t seqlen_k) +{ + ck::index_t x_total = seqlen_k; + ck::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported 
value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.At(ck::Number<0>{}); + tmp.x = r.At(ck::Number<1>{}); + } + else if(t == "b") + { + auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.At(ck::Number<0>{}); + tmp.x = r.At(ck::Number<1>{}); + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + else + { + // should be 0, 1, 2 + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::causal_top_left) + { + tmp.y = seqlen_q; + tmp.x = 1; + } + else if(tmp.type == mask_enum::causal_bottom_right) + { + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + } + } + return tmp; +} template -float invoker_fmha_kernel(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck::index_t batch, - ck::index_t nhead, - ck::index_t nhead_k, - ck::index_t seqlen_q, - ck::index_t seqlen_k, - ck::index_t hdim_q, - ck::index_t hdim_v, - ck::index_t max_seqlen_q, - float scale, - bool i_perm, - bool o_perm, - StreamConfig stream_config) +float invoke_fmha_kernel(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck::index_t batch, + ck::index_t nhead, + ck::index_t nhead_k, + ck::index_t seqlen_q, + ck::index_t seqlen_k, + ck::index_t hdim_q, + ck::index_t hdim_v, + ck::index_t max_seqlen_q, + float scale, + bool i_perm, + bool o_perm, + ck::index_t mask_y, + ck::index_t mask_x, + StreamConfig stream_config) { constexpr bool is_v_rowmajor = ck::is_same_v; @@ -232,7 +327,9 @@ float invoker_fmha_kernel(const void* q_ptr, nhead_stride_k, nhead_stride_v, nhead_stride_bias, - nhead_stride_o); + nhead_stride_o, + mask_y, + mask_x); } else { // create batch mode kernel arguments @@ -261,7 +358,9 @@ float invoker_fmha_kernel(const void* q_ptr, batch_stride_k, batch_stride_v, batch_stride_bias, - batch_stride_o); + batch_stride_o, + mask_y, + mask_x); } }(); @@ -278,6 +377,50 @@ float invoker_fmha_kernel(const void* q_ptr, kargs); // BatchStrideO } +template +struct fmha_fwd_kernel_invoker +{ + static constexpr ck::index_t HDim = HDim_; + // these args are used to select kernel. + // args that may passed as karg shoule use operator() + Mode mode; + bool use_bias; + mask_info mask; + + fmha_fwd_kernel_invoker(Mode mode_, bool use_bias_, mask_info mask_) + : mode(mode_), use_bias(use_bias_), mask(mask_) + { + } + + template + float operator()(Args&&... 
args) + { + float ave_time; + BOOL_SWITCH_2(mode == Mode::Group, kIsGroupMode, use_bias, kHasBias, [&] { + if(mask.type == mask_enum::no_mask) + { + using FmhaMask = ck::tile_program::block::GenericAttentionMask; + using Kernel = FmhaFwdKernel, + FmhaPipeline, + FmhaEpilogue>; + ave_time = invoke_fmha_kernel(std::forward(args)...); + } + else + { + BOOL_SWITCH(mask.type == mask_enum::window_generic, kIsLocal, [&]() { + using FmhaMask = ck::tile_program::block::GenericAttentionMask; + using Kernel = + FmhaFwdKernel, + FmhaPipeline, + FmhaEpilogue>; + ave_time = invoke_fmha_kernel(std::forward(args)...); + }); + } + }); + return ave_time; + } +}; + static inline int env_get_int(const char* var_name, int default_int) { char* v = getenv(var_name); @@ -309,6 +452,12 @@ auto create_args(int argc, char* argv[]) "if true, will be b*h*s*d, else b*s*h*d") .insert("operm", "1", "permute output") .insert("bias", "0", "add bias or not") + .insert("mask", + "0", + "0: no mask, 1: top-left, 2:bottom-right\n" + "'t:l,r', top-left local-attn with left right size\n" + "'b:l,r', bottom-r local-attn with left right size\n" + "'g:y,x', generic attention mask coordinate with y/x size\n") .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float"); bool result = arg_parser.parse(argc, argv); @@ -353,6 +502,8 @@ int main(int argc, char* argv[]) bool use_bias = arg_parser.get_uint32("bias"); + mask_info mask = decode_mask_info(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + int init_method = arg_parser.get_int("init"); int stream_warmup = env_get_int("CK_WARMUP", 5); @@ -465,66 +616,44 @@ int main(int argc, char* argv[]) }; // clang-format on - std::cout << "[" << mode << "] b:" << batch << ", h:" << nhead << ", h_k:" << nhead_k - << ", s:" << seqlen_q << ", s_k:" << seqlen_k << ", d:" << hdim_q - << ", d_v:" << hdim_v << ", scale:" << scale << ", i:" << layout_str(i_perm) - << ", o:" << layout_str(o_perm) << ", bias:" << use_bias + std::cout << "[" << mode << "|" << layout_str(i_perm) << "|" << layout_str(o_perm) + << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" + << seqlen_k << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale + << ", bias:" << use_bias << ", mask:" << mask << ", v:" << std::string(VLayout::name)[0] << std::flush; +#define INVOKE_FMHA_KERNEL(hdim_) \ + fmha_fwd_kernel_invoker{mode, use_bias, mask}(q_buf.GetDeviceBuffer(), \ + k_buf.GetDeviceBuffer(), \ + v_buf.GetDeviceBuffer(), \ + bias_buf.GetDeviceBuffer(), \ + o_buf.GetDeviceBuffer(), \ + seqstart_q.GetDeviceBuffer(), \ + seqstart_k.GetDeviceBuffer(), \ + nullptr, \ + batch, \ + nhead, \ + nhead_k, \ + shape_seqlen_q, \ + shape_seqlen_k, \ + hdim_q, \ + hdim_v, \ + max_seqlen_q, \ + scale, \ + i_perm, \ + o_perm, \ + mask.y, \ + mask.x, \ + stream_config) + float ave_time = 0; if(hdim_q == hdim_v && hdim_q == 64) { - BOOL_SWITCH_2(mode == Mode::Group, kIsGroupMode, use_bias, kHasBias, [&] { - using Kernel = FmhaKernel; - - ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - nullptr, - batch, - nhead, - nhead_k, - shape_seqlen_q, - shape_seqlen_k, - hdim_q, - hdim_v, - max_seqlen_q, - scale, - i_perm, - o_perm, - stream_config); - }); + ave_time = INVOKE_FMHA_KERNEL(64); } else if(hdim_q == hdim_v && hdim_q == 128) { - BOOL_SWITCH_2(mode == Mode::Group, kIsGroupMode, use_bias, kHasBias, [&] { - 
using Kernel = FmhaKernel; - - ave_time = invoker_fmha_kernel(q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - nullptr, - batch, - nhead, - nhead_k, - shape_seqlen_q, - shape_seqlen_k, - hdim_q, - hdim_v, - max_seqlen_q, - scale, - i_perm, - o_perm, - stream_config); - }); + ave_time = INVOKE_FMHA_KERNEL(128); } else { @@ -609,10 +738,18 @@ int main(int argc, char* argv[]) s_host_ref, bias_host_ref, s_host_ref); } - reference_batched_masking(s_host_ref); + if(mask.type == mask_enum::no_mask) { + reference_batched_masking(s_host_ref, ck::tile_program::block::GenericAttentionMask{}); + } else if(mask.type == mask_enum::window_generic) { + reference_batched_masking(s_host_ref, + ck::tile_program::block::GenericAttentionMask{mask.y, mask.x, seqlen_q, seqlen_k}); + } else { + reference_batched_masking(s_host_ref, + ck::tile_program::block::GenericAttentionMask{mask.y, mask.x, seqlen_q, seqlen_k}); + } reference_batched_softmax(s_host_ref, p_host_ref); reference_batched_gemm(p_host_ref, v_host_ref, o_host_ref); - + Tensor o_host_result({nhead, real_seqlen_q, hdim_v}); // permute if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); diff --git a/example/91_tile_program/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha_fwd_kernel.hpp index ca8d0930b..390febd07 100644 --- a/example/91_tile_program/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha_fwd_kernel.hpp @@ -15,10 +15,6 @@ // P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] -#ifndef C_LOG2E -#define C_LOG2E 1.44269504088896340736 // log2(e) -#endif - template struct FmhaFwdKernel { @@ -40,60 +36,23 @@ struct FmhaFwdKernel static constexpr bool kM0NeedPadding = FmhaPipeline::kM0NeedPadding; static constexpr bool kN0K1NeedPadding = FmhaPipeline::kN0K1NeedPadding; static constexpr bool kHasBias = FmhaPipeline::kHasBias; + using FmhaMask = ck::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; - using C0MatrixMask = ck::tile_program::block::C0MatrixMask_impl< - ck::remove_cvref_t>; + // using C0MatrixMask = ck::tile_program::block::C0MatrixMask_impl< + // ck::remove_cvref_t>; private: + template // to avoid duplicated base class prblem, introduce an template arg struct EmptyKargs { }; + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. 
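The comments above describe the kernel-argument layout trick: each optional feature contributes its fields through a base class, and std::conditional_t swaps that base for an empty struct when the feature is compiled out, so disabled features add no bytes to the kernel argument block, while aggregate initialization (via a MakeKargs-style helper) fills the bases and members in order. A minimal, self-contained sketch of the pattern follows; the struct and member names are simplified stand-ins, not the actual FmhaFwdKernel types:

#include <cstdio>
#include <type_traits>

template <int> // dummy parameter so the two empty bases are distinct types
struct EmptyKargs
{
};

struct CommonBiasKargs
{
    const void* bias_ptr;
    int stride_bias;
};

struct CommonMaskKargs
{
    int mask_y, mask_x;
};

template <bool kHasBias, bool kHasMask>
struct BatchKargs : std::conditional_t<kHasBias, CommonBiasKargs, EmptyKargs<0>>,
                    std::conditional_t<kHasMask, CommonMaskKargs, EmptyKargs<1>>
{
    const void* q_ptr;
    int seqlen_q;
};

int main()
{
    // disabled features collapse into empty bases and add no storage
    std::printf("no bias/mask: %zu bytes\n", sizeof(BatchKargs<false, false>));
    std::printf("bias + mask : %zu bytes\n", sizeof(BatchKargs<true, true>));

    // aggregate initialization fills bases first, then members,
    // mirroring how a MakeKargs()-style helper builds the struct piece by piece
    BatchKargs<true, true> kargs{{nullptr, 0}, {128, 1}, nullptr, 128};
    return kargs.mask_y == 128 ? 0 : 1;
}
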
struct CommonKargs { - __host__ constexpr CommonKargs(const void* q_ptr_, - const void* k_ptr_, - const void* v_ptr_, - void* o_ptr_, - ck::index_t seqlen_q_, - ck::index_t seqlen_k_, - ck::index_t hdim_q_, - ck::index_t hdim_v_, - ck::index_t nhead_ratio_qk_, - float scale_, - ck::index_t stride_q_, - ck::index_t stride_k_, - ck::index_t stride_v_, - ck::index_t stride_o_, - ck::index_t nhead_stride_q_, - ck::index_t nhead_stride_k_, - ck::index_t nhead_stride_v_, - ck::index_t nhead_stride_o_) - : q_ptr{reinterpret_cast(q_ptr_)}, - k_ptr{reinterpret_cast(k_ptr_)}, - v_ptr{reinterpret_cast(v_ptr_)}, - o_ptr{reinterpret_cast(o_ptr_)}, - seqlen_q{seqlen_q_}, - seqlen_k{seqlen_k_}, - hdim_q{hdim_q_}, - hdim_v{hdim_v_}, - nhead_ratio_qk{nhead_ratio_qk_}, -#if CK_FMHA_FWD_FAST_EXP2 - scale{static_cast(scale_ * C_LOG2E)}, -#else - scale{scale_}, -#endif - stride_q{stride_q_}, - stride_k{stride_k_}, - stride_v{stride_v_}, - stride_o{stride_o_}, - nhead_stride_q{nhead_stride_q_}, - nhead_stride_k{nhead_stride_k_}, - nhead_stride_v{nhead_stride_v_}, - nhead_stride_o{nhead_stride_o_} - { - } - const QDataType* q_ptr; const KDataType* k_ptr; const VDataType* v_ptr; @@ -132,107 +91,25 @@ struct FmhaFwdKernel ck::index_t batch_stride_bias = 0; }; - struct BatchModeKargs : CommonKargs, - std::conditional_t + struct MaskKargs { - __host__ constexpr BatchModeKargs(const void* q_ptr_, - const void* k_ptr_, - const void* v_ptr_, - void* o_ptr_, - ck::index_t seqlen_q_, - ck::index_t seqlen_k_, - ck::index_t hdim_q_, - ck::index_t hdim_v_, - ck::index_t nhead_ratio_qk_, - float scale_, - ck::index_t stride_q_, - ck::index_t stride_k_, - ck::index_t stride_v_, - ck::index_t stride_o_, - ck::index_t nhead_stride_q_, - ck::index_t nhead_stride_k_, - ck::index_t nhead_stride_v_, - ck::index_t nhead_stride_o_, - ck::index_t batch_stride_q_, - ck::index_t batch_stride_k_, - ck::index_t batch_stride_v_, - ck::index_t batch_stride_o_) - : CommonKargs{q_ptr_, - k_ptr_, - v_ptr_, - o_ptr_, - seqlen_q_, - seqlen_k_, - hdim_q_, - hdim_v_, - nhead_ratio_qk_, - scale_, - stride_q_, - stride_k_, - stride_v_, - stride_o_, - nhead_stride_q_, - nhead_stride_k_, - nhead_stride_v_, - nhead_stride_o_}, - batch_stride_q{batch_stride_q_}, - batch_stride_k{batch_stride_k_}, - batch_stride_v{batch_stride_v_}, - batch_stride_o{batch_stride_o_} - { - } + ck::index_t mask_y, mask_x; + }; + struct BatchModeKargs : CommonKargs, + std::conditional_t>, + std::conditional_t> + { ck::index_t batch_stride_q; ck::index_t batch_stride_k; ck::index_t batch_stride_v; ck::index_t batch_stride_o; }; - struct GroupModeKargs : CommonKargs, std::conditional_t + struct GroupModeKargs : CommonKargs, + std::conditional_t>, + std::conditional_t> { - __host__ constexpr GroupModeKargs(const void* q_ptr_, - const void* k_ptr_, - const void* v_ptr_, - void* o_ptr_, - const void* seqstart_q_ptr_, - const void* seqstart_k_ptr_, - const void* seqlen_k_ptr_, - ck::index_t hdim_q_, - ck::index_t hdim_v_, - ck::index_t nhead_ratio_qk_, - float scale_, - ck::index_t stride_q_, - ck::index_t stride_k_, - ck::index_t stride_v_, - ck::index_t stride_o_, - ck::index_t nhead_stride_q_, - ck::index_t nhead_stride_k_, - ck::index_t nhead_stride_v_, - ck::index_t nhead_stride_o_) - : CommonKargs{q_ptr_, - k_ptr_, - v_ptr_, - o_ptr_, - -1 /* will be updated inside the kernel */, - -1 /* will be updated inside the kernel */, - hdim_q_, - hdim_v_, - nhead_ratio_qk_, - scale_, - stride_q_, - stride_k_, - stride_v_, - stride_o_, - nhead_stride_q_, - nhead_stride_k_, - 
nhead_stride_v_, - nhead_stride_o_}, - seqstart_q_ptr{reinterpret_cast(seqstart_q_ptr_)}, - seqstart_k_ptr{reinterpret_cast(seqstart_k_ptr_)}, - seqlen_k_ptr{reinterpret_cast(seqlen_k_ptr_)} - { - } - const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; @@ -267,13 +144,38 @@ struct FmhaFwdKernel ck::index_t batch_stride_k, ck::index_t batch_stride_v, ck::index_t batch_stride_bias, - ck::index_t batch_stride_o) + ck::index_t batch_stride_o, + ck::index_t mask_y, + ck::index_t mask_x) { - Kargs kargs{q_ptr, k_ptr, v_ptr, o_ptr, seqlen_q, - seqlen_k, hdim_q, hdim_v, nhead_ratio_qk, scale, - stride_q, stride_k, stride_v, stride_o, nhead_stride_q, - nhead_stride_k, nhead_stride_v, nhead_stride_o, batch_stride_q, batch_stride_k, - batch_stride_v, batch_stride_o}; + Kargs kargs{{reinterpret_cast(q_ptr), + reinterpret_cast(k_ptr), + reinterpret_cast(v_ptr), + reinterpret_cast(o_ptr), + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + nhead_ratio_qk, +#if CK_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck::math::log2e_v<>), +#else + scale, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; if constexpr(kHasBias) { @@ -283,6 +185,12 @@ struct FmhaFwdKernel kargs.batch_stride_bias = batch_stride_bias; } + if constexpr(kHasMask) + { + kargs.mask_y = mask_y; + kargs.mask_x = mask_x; + } + return kargs; } @@ -308,27 +216,37 @@ struct FmhaFwdKernel ck::index_t nhead_stride_k, ck::index_t nhead_stride_v, ck::index_t nhead_stride_bias, - ck::index_t nhead_stride_o) + ck::index_t nhead_stride_o, + ck::index_t mask_y, + ck::index_t mask_x) { - Kargs kargs{q_ptr, - k_ptr, - v_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - nhead_ratio_qk, - scale, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o}; + Kargs kargs{{reinterpret_cast(q_ptr), + reinterpret_cast(k_ptr), + reinterpret_cast(v_ptr), + reinterpret_cast(o_ptr), + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + nhead_ratio_qk, +#if CK_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck::math::log2e_v<>), +#else + scale, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; if constexpr(kHasBias) { @@ -336,6 +254,11 @@ struct FmhaFwdKernel kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; } + if constexpr(kHasMask) + { + kargs.mask_y = mask_y; + kargs.mask_x = mask_x; + } return kargs; } @@ -582,17 +505,22 @@ struct FmhaFwdKernel } }(); - C0MatrixMask casual_mask{kargs.seqlen_q, kargs.seqlen_k}; + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k}; + else + return FmhaMask{0, 0, kargs.seqlen_q, kargs.seqlen_k}; + }(); auto o_acc_tile = FmhaPipeline{}(q_dram_window, k_dram_window, v_dram_window, bias_dram_window, - casual_mask, + mask, kargs.scale, - ck::math::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0), - ck::math::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0), + // 
ck::math::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0), + // ck::math::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0), smem_ptr); // O DRAM and O DRAM window diff --git a/example/91_tile_program/reference_batched_masking.hpp b/example/91_tile_program/reference_batched_masking.hpp index 3351dcd4f..5fc54457d 100644 --- a/example/91_tile_program/reference_batched_masking.hpp +++ b/example/91_tile_program/reference_batched_masking.hpp @@ -5,39 +5,21 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/host_tensor.hpp" -#include "ck/tile_program/block_tile/block_masking_specialization.hpp" +#include "ck/tile_program/block_tile/block_masking.hpp" template -void reference_batched_masking(Tensor& c_b_m_n) +void reference_batched_masking(Tensor& c_b_m_n, const MaskingType& mask) { const int M = c_b_m_n.mDesc.GetLengths()[1]; const int N = c_b_m_n.mDesc.GetLengths()[2]; - const int MNDiff = M - N; - auto f = [&](auto batch) { for(int n = 0; n < N; ++n) { for(int m = 0; m < M; ++m) { - if constexpr(std::is_same_v< - MaskingType, - ck::tile_program::block::MaskUpperTriangleFromTopLeftPredicate>) - { - if(n > m) - { - c_b_m_n(batch, m, n) = -ck::NumericLimits::Infinity(); - } - } - else if constexpr(std::is_same_v) - { - if(n > m - MNDiff) - { - c_b_m_n(batch, m, n) = -ck::NumericLimits::Infinity(); - } - } + if(mask.IsOutOfBound(m, n)) + c_b_m_n(batch, m, n) = -ck::NumericLimits::Infinity(); } } }; diff --git a/example/91_tile_program/reference_batched_softmax.hpp b/example/91_tile_program/reference_batched_softmax.hpp index a9fa3f103..e707db576 100644 --- a/example/91_tile_program/reference_batched_softmax.hpp +++ b/example/91_tile_program/reference_batched_softmax.hpp @@ -32,13 +32,16 @@ void reference_batched_softmax(const Tensor& a_b_m_n, Tensor(a_b_m_n(batch, m, n)); b_b_m_n(batch, m, n) = - ck::type_convert(ck::math::exp(v_a - v_max) / v_exp_sum); + ck::type_convert(ck::math::exp(v_a - v_max) * inv_sum); } }; diff --git a/include/ck/tile_program/block_tile/block_masking.hpp b/include/ck/tile_program/block_tile/block_masking.hpp new file mode 100644 index 000000000..da074b80c --- /dev/null +++ b/include/ck/tile_program/block_tile/block_masking.hpp @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tile_program { +namespace block { + +// clang-format off +/* Generic Attention Mask Coordinate + use x(horizontal axis), y(vertical axis) to describe mask. 
+ top-left corner is origin + + x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask) + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + 1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + 1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + l=7,-1/r=0(tl) l=7,-1/r=0(br) + + x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2 + 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 + 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 + * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 + * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1 + l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl) + l=4/r=0(br) l=4/r=2(br) l=4/r=4(br) + + x=4/y=-1 x=6/y=-1 x=8/y=-1 + * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 + * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 + * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 + * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1 + * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1 + + x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r) + * * * * * * * * 1 * * * * * * * + * * * * * * * * 1 1 * * 1 * * * + * * * * * * * * 1 1 1 * 1 1 * * + 1 * * * * * * * 1 1 1 1 1 1 1 * + 1 1 * * * * * * 1 1 1 1 1 1 1 1 + + Validations: + x + y > 1 (x + y >= 2) + + Note: + y = seq_q, x = 1 -> top-left + y = seq_q, x = seq_k - seq_q + 1 -> bottom-right + y < seq_q, x < seq_k -> local-attn + y = seq_q, x = seq_k -> no mask + +*/ +// clang-format on +template +struct GenericAttentionMask +{ + static constexpr bool IsMasking = IsMasking_; // false will disable masking + static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask, + // else only upper-right could have mask + + __host__ __device__ GenericAttentionMask() : y(0), x(0), y_total(0), x_total(0) {} + + __host__ __device__ + GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + { + } + template + __host__ __device__ GenericAttentionMask(const MaskCoordinates& mask_coord) + : y(mask_coord.At(Number<0>{})), + x(mask_coord.At(Number<1>{})), + y_total(mask_coord.At(Number<2>{})), + x_total(mask_coord.At(Number<3>{})) + { + } + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (like k-seqlen loopover) + // TODO: x_end still could be negative, so end-start could be negative(need check) + template + __host__ __device__ constexpr auto + GetTileRangeAlongX(index_t i_y, Number, Number) const + { + if constexpr(!IsMasking) + { + return ck::make_tuple(0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + if constexpr(IsLocal) + { + index_t tmp = math::max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + } + else + { + return 0; + } + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... 
in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = math::min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + return ck::make_tuple(x_start, x_end); + } + } + + // per-pixel check for out-of-bound; if true, the value needs to be masked (like -INF) + __host__ __device__ constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + { + return false; + } + else + { + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; + index_t x_end = i_y + x; + + if constexpr(IsLocal) + { + return i_x < x_start || i_x >= x_end; + } + else + { + return i_x >= x_end; + } + } + } + + // if the current tile is at the edge, a per-pixel mask check is needed; + // otherwise no per-pixel check is required + // Attention! assume the index passed to this function is within the range of GetTileRangeAlongX() + // can be used as a fast path to decide whether to do the per-pixel check or not + template + __host__ __device__ constexpr auto + IsEdgeTile(index_t i_y, index_t i_x, Number, Number) const + { + if constexpr(IsLocal) + { + // check whether the top-right corner crosses x or the bottom-left corner crosses y + bool top_right_edge = (i_x + XTile) > (x + i_y); + bool bottom_left_edge = (i_y + YTile) > (y + i_x); + return top_right_edge || bottom_left_edge; + } + else + { + // only need to check top-right corner > x + bool top_right_edge = (i_x + XTile) > (x + i_y); + return top_right_edge; + } + } + + private: + index_t y, x; + index_t y_total, x_total; +}; + +} // namespace block +} // namespace tile_program + +// TODO: prefer to use this function in host code +// can convert from the FA style left/right window sizes to our generic coordinate +// if left_size < 0 && right_size == 0, it is a normal causal mask +// local attention means left_size >= 0 or right_size >= 0 +__host__ constexpr auto +make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, + index_t right_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + index_t x = 0, y = 0; + + if(is_top_left) + { + if(left_size < 0) + left_size = y_total - 1; + if(right_size < 0) + right_size = x_total - 1; + + x = 1 + right_size; + y = left_size + 1; + } + else + { + if(left_size < 0) + left_size = x_total - 1; + if(right_size < 0) + right_size = y_total - 1; + + x = x_total - y_total + 1 + right_size; + y = y_total - x_total + 1 + left_size; + } + + return ck::make_tuple(y, x, y_total, x_total); +} +} // namespace ck diff --git a/include/ck/tile_program/block_tile/block_masking_specialization.hpp b/include/ck/tile_program/block_tile/block_masking_specialization.hpp deleted file mode 100644 index e9f067733..000000000 --- a/include/ck/tile_program/block_tile/block_masking_specialization.hpp +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -namespace ck { -namespace tile_program { -namespace block { - -struct MaskDisabledPredicate -{ - __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const - { - return false; - }; - - __host__ __device__ constexpr bool - IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const - { - return false; - } -}; - -struct MaskUpperTriangleFromTopLeftPredicate -{ - __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; } - - __host__ __device__ constexpr bool - IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const - { - return operator()(m + m_tile - 1, n); - } -}; - -// eg: m = 3, n = 5 => offset = 2 -// so matrix(n > m + offset) = 0 -// 1 2 3 4 5 -// 1 * * * 0 0 -// 2 * * * * 0 -// 3 * * * * * -struct MaskUpperTriangleFromBottomRightPredicate -{ - __host__ __device__ void SetDiagonalOffset(const index_t diagonal_offset) - { - diagonal_offset_ = diagonal_offset; - } - __host__ __device__ constexpr bool operator()(index_t m, index_t n) const - { - return n > (m - diagonal_offset_); - } - - __host__ __device__ constexpr bool IsTileSkippable(index_t m_tile_orig, - index_t n_tile_orig, - index_t m_tile_size, - index_t /*n_tile_size*/) const - { - return operator()(m_tile_orig + m_tile_size - 1, n_tile_orig); - } - - private: - index_t diagonal_offset_; -}; - -// to track the points which need to be set to -inf on C0 -// Note: no need to reset M padding value, because they will not be stored out. -template -struct C0MatrixMask_impl -{ - using MaskOutPredicate = MaskOutPredicate_; - - __host__ __device__ C0MatrixMask_impl(index_t MRaw, index_t NRaw) - : NRaw_(NRaw), predicate_(MaskOutPredicate{}) - { - if constexpr(std::is_same_v) - { - predicate_.SetDiagonalOffset(MRaw - NRaw); - } - } - - __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const - { - return n >= NRaw_; - } - - __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const - { - return predicate_(m, n) || IsNOutOfBound(n); - } - - __host__ __device__ constexpr bool - IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const - { - return predicate_.IsTileSkippable(m, n, m_tile, n_tile); - } - - private: - // index_t MRaw_; - index_t NRaw_; - MaskOutPredicate predicate_; -}; - -} // namespace block -} // namespace tile_program -} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp index e0c19af03..ed42ae937 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp @@ -22,7 +22,7 @@ template struct BlockFmhaPipelineProblem { @@ -36,7 +36,7 @@ struct BlockFmhaPipelineProblem using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; - using BlockFmhaMask = remove_cvref_t; + using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; static constexpr index_t kBlockSize = kBlockSize_; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index f0b127b57..854b1631e 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ 
-16,7 +16,6 @@ #include "ck/tile_program/tile/slice_tile.hpp" #include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck/tile_program/block_tile/block_masking_specialization.hpp" #include "ck/tile_program/block_tile/block_reduce.hpp" #include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" @@ -37,7 +36,7 @@ struct BlockFmhaPipelineQRKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; - using BlockFmhaMask = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -70,8 +69,7 @@ struct BlockFmhaPipelineQRKSVS typename QElementFunction, typename KElementFunction, typename VElementFunction, - typename BiasElementFunction, - typename CausalMask> + typename BiasElementFunction> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -81,10 +79,8 @@ struct BlockFmhaPipelineQRKSVS const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, - CausalMask causal_mask, + FmhaMask mask, float scale, - index_t num_total_loop, - index_t /*num_sub_loop_qk*/, // in this pipeline, the 1st gemm loop must be static void* smem_ptr) const { static_assert( @@ -153,21 +149,41 @@ struct BlockFmhaPipelineQRKSVS set_tile(m, NumericLimits::Lowest()); clear_tile(l); - auto k_dram_block_window = k_dram_block_window_tmp; - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), - v_dram_block_window_tmp.GetWindowLengths(), - v_dram_block_window_tmp.GetWindowOrigin(), - Policy::template MakeVDramTileDistribution()); + const auto q_origin = q_dram_window.GetWindowOrigin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.At(Number<0>{}), Number{}, Number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.GetBottomTensorView(), + k_dram_block_window_tmp.GetWindowLengths(), + {seqlen_k_start, 0}); - auto bias_dram_window = make_tile_window( + const auto bias_origin = bias_dram_block_window_tmp.GetWindowOrigin(); + auto bias_dram_window = make_tile_window( bias_dram_block_window_tmp.GetBottomTensorView(), bias_dram_block_window_tmp.GetWindowLengths(), - bias_dram_block_window_tmp.GetWindowOrigin(), + {bias_origin.At(Number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - const auto q_origin = q_dram_window.GetWindowOrigin(); - auto q_tile = tile_elementwise_in(q_element_func, q); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + {0, seqlen_k_start}, // TODO: hdim split? 
+ Policy::template MakeVDramTileDistribution()); + + auto q_tile = tile_elementwise_in(q_element_func, q); // prefetch K tile index_t i_total_loops = 0; @@ -175,13 +191,6 @@ struct BlockFmhaPipelineQRKSVS constexpr index_t k1_loops = kN0 / kK1; do { - const auto k_origin = k_dram_block_window.GetWindowOrigin(); - if(causal_mask.IsTileSkippable( - q_origin.At(Number<0>{}), k_origin.At(Number<0>{}), kM0, kN0)) - { - continue; - } - // STAGE 1, QK gemm auto k_dram_window = make_tile_window( k_dram_block_window.GetBottomTensorView(), @@ -271,16 +280,22 @@ struct BlockFmhaPipelineQRKSVS bias_tile); } move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kN0K1NeedPadding || - !is_same_v) + if constexpr(kN0K1NeedPadding || FmhaMask::IsMasking) { - set_tile_if( - s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { - const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); - const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); - - return causal_mask.IsMaskedElement(row, col); - }); + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), + k_origin.At(Number<0>{}), + Number{}, + Number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } } const auto s = cast_tile(s_acc); // S{j} @@ -418,7 +433,14 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - const auto tmp = 1 / l[i_idx]; + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 
0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); o_acc(i_j_idx) *= tmp; @@ -431,17 +453,14 @@ struct BlockFmhaPipelineQRKSVS template + typename BiasDramBlockWindowTmp> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - CausalMask causal_mask, + FmhaMask mask, float scale, - index_t num_total_loop, - index_t num_sub_loop_qk, void* smem_ptr) const { return operator()(q_dram_block_window_tmp, @@ -452,10 +471,8 @@ struct BlockFmhaPipelineQRKSVS identity{}, bias_dram_block_window_tmp, identity{}, - causal_mask, + mask, scale, - num_total_loop, - num_sub_loop_qk, smem_ptr); } }; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index c53146fe5..695e2bdcb 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -16,7 +16,6 @@ #include "ck/tile_program/tile/slice_tile.hpp" #include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" -#include "ck/tile_program/block_tile/block_masking_specialization.hpp" #include "ck/tile_program/block_tile/block_reduce.hpp" #include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" @@ -37,7 +36,7 @@ struct BlockFmhaPipelineQRKSVSAsync using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; - using BlockFmhaMask = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -70,8 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsync typename QElementFunction, typename KElementFunction, typename VElementFunction, - typename BiasElementFunction, - typename CausalMask> + typename BiasElementFunction> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -81,10 +79,8 @@ struct BlockFmhaPipelineQRKSVSAsync const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, - CausalMask causal_mask, + FmhaMask mask, float scale, - index_t num_total_loop, - index_t /*num_sub_loop_qk*/, // in this pipeline, the 1st gemm loop must be static void* smem_ptr) const { static_assert( @@ -182,82 +178,64 @@ struct BlockFmhaPipelineQRKSVSAsync set_tile(m, NumericLimits::Lowest()); clear_tile(l); - auto k_dram_block_window = k_dram_block_window_tmp; - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), - v_dram_block_window_tmp.GetWindowLengths(), - v_dram_block_window_tmp.GetWindowOrigin(), - Policy::template MakeVDramTileDistribution()); - __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.GetWindowOrigin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.At(Number<0>{}), Number{}, Number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // 
check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.GetBottomTensorView(), + k_dram_block_window_tmp.GetWindowLengths(), + {seqlen_k_start, 0}); + auto k_dram_window = make_tile_window( k_dram_block_window.GetBottomTensorView(), k_dram_block_window.GetWindowLengths(), k_dram_block_window.GetWindowOrigin(), Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load - - auto bias_dram_window = make_tile_window( + const auto bias_origin = bias_dram_block_window_tmp.GetWindowOrigin(); + auto bias_dram_window = make_tile_window( bias_dram_block_window_tmp.GetBottomTensorView(), bias_dram_block_window_tmp.GetWindowLengths(), - bias_dram_block_window_tmp.GetWindowOrigin(), + {bias_origin.At(Number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - const auto q_origin = q_dram_window.GetWindowOrigin(); - auto k_origin = k_dram_block_window.GetWindowOrigin(); - bool skip_tile = causal_mask.IsTileSkippable( - q_origin.At(Number<0>{}), k_origin.At(Number<0>{}), kM0, kN0); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); // prefetch K tile - if(!skip_tile) - { - async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - } + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); - if constexpr(std::is_same::value) - buffer_load_fence(k_dram_window.GetNumAccess()); - else - buffer_load_fence(0); // unconditionally wait for q if this is a mask kernel + buffer_load_fence(k_dram_window.GetNumAccess()); + auto q_tile = tile_elementwise_in(q_element_func, q); __builtin_amdgcn_sched_barrier(0); index_t i_total_loops = 0; constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k1_loops = kN0 / kK1; - auto prefetch_k = [&]() { - move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = make_tile_window(k_dram_block_window.GetBottomTensorView(), - k_dram_block_window.GetWindowLengths(), - k_dram_block_window.GetWindowOrigin(), - Policy::template MakeKDramTileDistribution()); - - k_origin = k_dram_block_window.GetWindowOrigin(); - skip_tile = causal_mask.IsTileSkippable( - q_origin.At(Number<0>{}), k_origin.At(Number<0>{}), kM0, kN0); - if(!skip_tile) - { - if constexpr(k1_loops >= 2 && - LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) - __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - } - }; // main loop do { - if(skip_tile) - { - i_total_loops++; - if(i_total_loops < num_total_loop) - prefetch_k(); - continue; - } - // STAGE 1, QK gemm clear_tile(s_acc); // Initialize C if constexpr(k0_loops > 1) @@ -328,24 +306,30 @@ struct BlockFmhaPipelineQRKSVSAsync #if !CK_FMHA_FWD_FAST_EXP2 x = scale * x + type_convert(bias_element_func(y)); #else - x = scale * x + - math::log2e_v * type_convert(bias_element_func(y)); + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); #endif }, s_acc, bias_tile); } move_tile_window(bias_dram_window, {0, kN0}); - if 
constexpr(kN0K1NeedPadding || - !is_same_v) + if constexpr(kN0K1NeedPadding || FmhaMask::IsMasking) { - set_tile_if( - s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { - const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); - const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); - - return causal_mask.IsMaskedElement(row, col); - }); + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), + k_origin.At(Number<0>{}), + Number{}, + Number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } } const auto s = cast_tile(s_acc); // S{j} @@ -496,7 +480,18 @@ struct BlockFmhaPipelineQRKSVSAsync if(i_total_loops < num_total_loop) { // move K tile windows - prefetch_k(); + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = + make_tile_window(k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); + + if constexpr(k1_loops >= 2 && + LdsSeq.At(Number<0>{}) == LdsSeq.At(Number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.At(Number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); } // tail { @@ -516,7 +511,14 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - const auto tmp = 1 / l[i_idx]; + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); o_acc(i_j_idx) *= tmp; @@ -529,17 +531,14 @@ struct BlockFmhaPipelineQRKSVSAsync template + typename BiasDramBlockWindowTmp> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - CausalMask causal_mask, + FmhaMask mask, float scale, - index_t num_total_loop, - index_t num_sub_loop_qk, void* smem_ptr) const { return operator()(q_dram_block_window_tmp, @@ -550,10 +549,8 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, bias_dram_block_window_tmp, identity{}, - causal_mask, + mask, scale, - num_total_loop, - num_sub_loop_qk, smem_ptr); } }; diff --git a/include/ck/tile_program/tile/null_tile_window.hpp b/include/ck/tile_program/tile/null_tile_window.hpp index d959d7154..2d873bcfc 100644 --- a/include/ck/tile_program/tile/null_tile_window.hpp +++ b/include/ck/tile_program/tile/null_tile_window.hpp @@ -64,8 +64,10 @@ __device__ constexpr auto make_null_tile_window(const WindowLengths& window_leng } template -__device__ constexpr auto -make_tile_window(NullTensorView, const WindowLengths& window_lengths, Ts&&...) +__device__ constexpr auto make_tile_window(NullTensorView, + const WindowLengths& window_lengths, + const MultiIndex& /*origin*/, + Ts&&...) { static_assert(is_known_at_compile_time::value, "wrong! 
lengths should be static"); diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index 3095faf6c..c4039bbcb 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -259,13 +259,15 @@ template struct log2e; template <> -struct log2e { - static constexpr double value = C_LOG2E; +struct log2e +{ + static constexpr double value = C_LOG2E; }; template <> -struct log2e { - static constexpr float value = C_LOG2E; +struct log2e +{ + static constexpr float value = C_LOG2E; }; template diff --git a/include/ck/utility/static_switch.hpp b/include/ck/utility/static_switch.hpp index 293487f72..9ddfed6a0 100644 --- a/include/ck/utility/static_switch.hpp +++ b/include/ck/utility/static_switch.hpp @@ -3,7 +3,7 @@ #pragma once -#define BOOL_SWITCH_1(COND1, CONST_NAME1, ...) \ +#define BOOL_SWITCH(COND1, CONST_NAME1, ...) \ [&] { \ if(COND1) \ { \ @@ -22,12 +22,12 @@ if(COND1) \ { \ constexpr bool CONST_NAME1 = true; \ - BOOL_SWITCH_1(COND2, CONST_NAME2, ##__VA_ARGS__); \ + BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__); \ } \ else \ { \ constexpr bool CONST_NAME1 = false; \ - BOOL_SWITCH_1(COND2, CONST_NAME2, ##__VA_ARGS__); \ + BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__); \ } \ }() From afea7392d59cbd71247336483f5cf190c0929866 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Wed, 3 Jan 2024 08:58:09 +0000 Subject: [PATCH 23/45] add __device__ to make_generic_attention_mask_coordinates_from_lr_window --- include/ck/tile_program/block_tile/block_masking.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck/tile_program/block_tile/block_masking.hpp b/include/ck/tile_program/block_tile/block_masking.hpp index da074b80c..0bb44c844 100644 --- a/include/ck/tile_program/block_tile/block_masking.hpp +++ b/include/ck/tile_program/block_tile/block_masking.hpp @@ -173,7 +173,7 @@ struct GenericAttentionMask // can convert from the FA style left/right to our generic coordinate // if left_size < 0 && right_size = 0, it is normal causal mask // local is left_size >=0 or right_size >=0 -__host__ constexpr auto +__host__ __device__ constexpr auto make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t right_size, index_t y_total, From 33a2ee1735806ee54248875c45bfd21fb64f6875 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Fri, 5 Jan 2024 18:28:51 +0000 Subject: [PATCH 24/45] Fix in block_masking.hpp --- include/ck/tile_program/block_tile/block_masking.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck/tile_program/block_tile/block_masking.hpp b/include/ck/tile_program/block_tile/block_masking.hpp index 0bb44c844..238f63f8b 100644 --- a/include/ck/tile_program/block_tile/block_masking.hpp +++ b/include/ck/tile_program/block_tile/block_masking.hpp @@ -119,13 +119,13 @@ struct GenericAttentionMask { if constexpr(!IsMasking) { - return false; + return i_x >= x_total; } else { // no need to do min/max here, since i_x will never be < 0 or >= x_total index_t x_start = -y + i_y + 1; - index_t x_end = i_y + x; + index_t x_end = math::min(i_y + x, x_total); if constexpr(IsLocal) { @@ -156,7 +156,7 @@ struct GenericAttentionMask else { // only need to check top-right corner > x - bool top_right_edge = (i_x + XTile) > (x + i_y); + bool top_right_edge = (i_x + XTile) > math::min(x + i_y, x_total); return top_right_edge; } } From b556a44dd9046b31db8fe528afe87151c6ac9ecf Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 5 Jan 2024 22:57:58 +0800 Subject: [PATCH 25/45] Re-organize example directories 
(#60) * Re-organize example directories * Move reference operation into sub-folder * Move mask types into dedicated files * Separate utility interface & implementation * Resume pipeline changes in fmha_fwd.cpp * Rename folder 'fmha_fwd' to 'fmha' * Move more function to utils.* * Remove 'fmha_fwd_kernel_invoker.hpp' * Re-format files * Move Kargs types into dedicated file * Fix formating * Fix compilation errors * Avoid instantiating unused types * Extract configurable codes * Add missing include directive * Instantiate template functions outside fmha_fwd.cpp * Separate implementation files * Merge config files * Merge duplicated code * Remove no-longer used file * Unify enum name * Extract no_mask kernel * Further separate template specializations * Use file(GLOB) to get file list * Include needed config file only once * Remove debug message * Add comment to explain template specializations * Move impl files under 'kernels' sub-folder * Only include *.inc in *.inc files * Add extra type arg to control selected kernel * Add kernel specializations for bf16 * Switch kernel according to cmdline options * Re-order type parameters * Reduce loop indent level * Instantiate launch_kernel() * Rename source files * Remove duplicated codes * Remove more duplicated codes * Clean up codes * Rename 'FmhaMaskType' to 'FmhaMasks' * Remove no-longer used include directive * Move template declarations into dedicated header * use python codegen * modify validation logic * format print and add smoke_test script * modify bf16 elimit, add benchmark script --------- Co-authored-by: carlushuang --- CMakeLists.txt | 2 + example/91_tile_program/CMakeLists.txt | 31 +- .../batched_gemm_softmax_gemm/CMakeLists.txt | 1 + .../batched_gemm_softmax_gemm.cpp | 4 +- .../batched_gemm_softmax_gemm.hpp | 2 +- .../{ => common}/arg_parser.hpp | 0 example/91_tile_program/fmha/CMakeLists.txt | 33 + example/91_tile_program/fmha/fmha_fwd.cpp | 494 +++++++++++ example/91_tile_program/fmha/fmha_fwd.hpp | 281 +++++++ .../{ => fmha}/fmha_fwd_epilogue.hpp | 0 .../{ => fmha}/fmha_fwd_kernel.hpp | 73 +- .../{ => fmha}/fmha_fwd_tile_partitioner.hpp | 0 .../91_tile_program/fmha/generate_kernels.py | 113 +++ example/91_tile_program/fmha/mask.hpp | 109 +++ .../91_tile_program/fmha/script/benchmark.sh | 25 + .../91_tile_program/fmha/script/smoke_test.sh | 23 + .../{fmha_utils.hpp => fmha/utils.hpp} | 72 +- example/91_tile_program/fmha_fwd.cpp | 775 ------------------ example/91_tile_program/gemm/CMakeLists.txt | 1 + example/91_tile_program/{ => gemm}/gemm.cpp | 2 +- example/91_tile_program/{ => gemm}/gemm.hpp | 0 .../91_tile_program/gemm_gemm/CMakeLists.txt | 1 + .../{ => gemm_gemm}/gemm_gemm.cpp | 2 +- .../{ => gemm_gemm}/gemm_gemm.hpp | 0 .../gemm_softmax_gemm/CMakeLists.txt | 1 + .../gemm_softmax_gemm.cpp | 4 +- .../gemm_softmax_gemm.hpp | 0 .../gemm_softmax_gemm_impl.hpp | 0 example/91_tile_program/im2col/CMakeLists.txt | 1 + .../91_tile_program/{ => im2col}/im2col.cpp | 50 +- example/91_tile_program/reduce/CMakeLists.txt | 1 + .../91_tile_program/{ => reduce}/reduce.cpp | 22 +- .../91_tile_program/{ => reduce}/reduce.hpp | 0 .../reference_batched_elementwise.hpp | 0 .../reference_batched_gemm.hpp | 0 .../reference_batched_masking.hpp | 0 .../reference_batched_softmax.hpp | 0 .../{ => reference}/reference_gemm.hpp | 0 .../reference/reference_im2col.hpp | 57 ++ .../reference/reference_reduce.hpp | 28 + .../{ => reference}/reference_softmax.hpp | 0 .../91_tile_program/softmax/CMakeLists.txt | 1 + .../91_tile_program/{ => softmax}/softmax.cpp 
| 2 +- .../91_tile_program/{ => softmax}/softmax.hpp | 0 .../tile_program/block_tile/block_masking.hpp | 4 + .../tile/static_distributed_tensor.hpp | 1 + include/ck/utility/buffer_view_declare.hpp | 4 +- 47 files changed, 1278 insertions(+), 942 deletions(-) create mode 100644 example/91_tile_program/batched_gemm_softmax_gemm/CMakeLists.txt rename example/91_tile_program/{ => batched_gemm_softmax_gemm}/batched_gemm_softmax_gemm.cpp (98%) rename example/91_tile_program/{ => batched_gemm_softmax_gemm}/batched_gemm_softmax_gemm.hpp (98%) rename example/91_tile_program/{ => common}/arg_parser.hpp (100%) create mode 100644 example/91_tile_program/fmha/CMakeLists.txt create mode 100644 example/91_tile_program/fmha/fmha_fwd.cpp create mode 100644 example/91_tile_program/fmha/fmha_fwd.hpp rename example/91_tile_program/{ => fmha}/fmha_fwd_epilogue.hpp (100%) rename example/91_tile_program/{ => fmha}/fmha_fwd_kernel.hpp (92%) rename example/91_tile_program/{ => fmha}/fmha_fwd_tile_partitioner.hpp (100%) create mode 100644 example/91_tile_program/fmha/generate_kernels.py create mode 100644 example/91_tile_program/fmha/mask.hpp create mode 100644 example/91_tile_program/fmha/script/benchmark.sh create mode 100644 example/91_tile_program/fmha/script/smoke_test.sh rename example/91_tile_program/{fmha_utils.hpp => fmha/utils.hpp} (62%) delete mode 100644 example/91_tile_program/fmha_fwd.cpp create mode 100644 example/91_tile_program/gemm/CMakeLists.txt rename example/91_tile_program/{ => gemm}/gemm.cpp (99%) rename example/91_tile_program/{ => gemm}/gemm.hpp (100%) create mode 100644 example/91_tile_program/gemm_gemm/CMakeLists.txt rename example/91_tile_program/{ => gemm_gemm}/gemm_gemm.cpp (99%) rename example/91_tile_program/{ => gemm_gemm}/gemm_gemm.hpp (100%) create mode 100644 example/91_tile_program/gemm_softmax_gemm/CMakeLists.txt rename example/91_tile_program/{ => gemm_softmax_gemm}/gemm_softmax_gemm.cpp (98%) rename example/91_tile_program/{ => gemm_softmax_gemm}/gemm_softmax_gemm.hpp (100%) rename example/91_tile_program/{ => gemm_softmax_gemm}/gemm_softmax_gemm_impl.hpp (100%) create mode 100644 example/91_tile_program/im2col/CMakeLists.txt rename example/91_tile_program/{ => im2col}/im2col.cpp (88%) create mode 100644 example/91_tile_program/reduce/CMakeLists.txt rename example/91_tile_program/{ => reduce}/reduce.cpp (83%) rename example/91_tile_program/{ => reduce}/reduce.hpp (100%) rename example/91_tile_program/{ => reference}/reference_batched_elementwise.hpp (100%) rename example/91_tile_program/{ => reference}/reference_batched_gemm.hpp (100%) rename example/91_tile_program/{ => reference}/reference_batched_masking.hpp (100%) rename example/91_tile_program/{ => reference}/reference_batched_softmax.hpp (100%) rename example/91_tile_program/{ => reference}/reference_gemm.hpp (100%) create mode 100644 example/91_tile_program/reference/reference_im2col.hpp create mode 100644 example/91_tile_program/reference/reference_reduce.hpp rename example/91_tile_program/{ => reference}/reference_softmax.hpp (100%) create mode 100644 example/91_tile_program/softmax/CMakeLists.txt rename example/91_tile_program/{ => softmax}/softmax.cpp (98%) rename example/91_tile_program/{ => softmax}/softmax.hpp (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index e780c1565..d132a3e4e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,6 +21,8 @@ set(version 1.1.0) # Check support for CUDA/HIP in Cmake project(composable_kernel VERSION ${version}) +find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) 
+ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") if (DTYPES) diff --git a/example/91_tile_program/CMakeLists.txt b/example/91_tile_program/CMakeLists.txt index 9bc86eccb..123eabe3e 100644 --- a/example/91_tile_program/CMakeLists.txt +++ b/example/91_tile_program/CMakeLists.txt @@ -1,21 +1,12 @@ -add_example_executable(example_im2col im2col.cpp) -add_example_executable(example_gemm gemm.cpp) -add_example_executable(example_gemm_gemm gemm_gemm.cpp) -add_example_executable(example_reduce reduce.cpp) -add_example_executable(example_softmax softmax.cpp) -add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp) -add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp) -add_example_executable(example_fmha_fwd fmha_fwd.cpp) - -# NOTE: this is dangerous since will change the whole kernel to flush denormals -# WIP with compiler team for an exp2 intrinsic..., then remove this -if(NOT DEFINED FMHA_FWD_FAST_EXP2) - set(FMHA_FWD_FAST_EXP2 true) -endif() - -if(FMHA_FWD_FAST_EXP2) -set_source_files_properties(fmha_fwd.cpp PROPERTIES COMPILE_OPTIONS "-DCK_FMHA_FWD_FAST_EXP2=1;-fgpu-flush-denormals-to-zero") -else() -set_source_files_properties(fmha_fwd.cpp PROPERTIES COMPILE_OPTIONS "-DCK_FMHA_FWD_FAST_EXP2=0") -endif() +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) +add_subdirectory(batched_gemm_softmax_gemm) +add_subdirectory(fmha) +add_subdirectory(gemm) +add_subdirectory(gemm_gemm) +add_subdirectory(gemm_softmax_gemm) +add_subdirectory(im2col) +add_subdirectory(reduce) +add_subdirectory(softmax) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm/CMakeLists.txt b/example/91_tile_program/batched_gemm_softmax_gemm/CMakeLists.txt new file mode 100644 index 000000000..69fae2e10 --- /dev/null +++ b/example/91_tile_program/batched_gemm_softmax_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_batched_gemm_softmax_gemm batched_gemm_softmax_gemm.cpp) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.cpp b/example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.cpp similarity index 98% rename from example/91_tile_program/batched_gemm_softmax_gemm.cpp rename to example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.cpp index f785ffcf9..8c6cbbf45 100644 --- a/example/91_tile_program/batched_gemm_softmax_gemm.cpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.cpp @@ -13,8 +13,8 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_batched_gemm.hpp" -#include "reference_batched_softmax.hpp" +#include "reference/reference_batched_gemm.hpp" +#include "reference/reference_batched_softmax.hpp" #include "batched_gemm_softmax_gemm.hpp" int main(int argc, char* argv[]) diff --git a/example/91_tile_program/batched_gemm_softmax_gemm.hpp b/example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.hpp similarity index 98% rename from example/91_tile_program/batched_gemm_softmax_gemm.hpp rename to example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.hpp index f396222fe..179440a89 100644 --- a/example/91_tile_program/batched_gemm_softmax_gemm.hpp +++ b/example/91_tile_program/batched_gemm_softmax_gemm/batched_gemm_softmax_gemm.hpp @@ -17,7 +17,7 @@ #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck/tile_program/block_tile/block_reduce.hpp" -#include "gemm_softmax_gemm_impl.hpp" +#include 
"gemm_softmax_gemm/gemm_softmax_gemm_impl.hpp" // S[M0, N0] = Q[M0, K0] * K[N0, K0] // P[M0, N0] = Softmax(S[M0, N0]) diff --git a/example/91_tile_program/arg_parser.hpp b/example/91_tile_program/common/arg_parser.hpp similarity index 100% rename from example/91_tile_program/arg_parser.hpp rename to example/91_tile_program/common/arg_parser.hpp diff --git a/example/91_tile_program/fmha/CMakeLists.txt b/example/91_tile_program/fmha/CMakeLists.txt new file mode 100644 index 000000000..c10947e64 --- /dev/null +++ b/example/91_tile_program/fmha/CMakeLists.txt @@ -0,0 +1,33 @@ +# generate a list of kernels, but not actually emit files at config stage +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate_kernels.py + --list_kernels ${CMAKE_CURRENT_BINARY_DIR}/kernel_list.txt +) + +# NOTE: for cmake, the FMHA_FWD_GEN_KERNELS files must be in the same directory +# as current cmake list, otherwise will not figure out the dependency properly +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/kernel_list.txt FMHA_FWD_GEN_KERNELS) + +add_custom_command( + OUTPUT ${FMHA_FWD_GEN_KERNELS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate_kernels.py + --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +add_example_executable(example_fmha_fwd fmha_fwd.cpp) +target_include_directories(example_fmha_fwd PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(example_fmha_fwd PRIVATE ${FMHA_FWD_GEN_KERNELS}) + +# NOTE: this is dangerous since will change the whole kernel to flush denormals +# WIP with compiler team for an exp2 intrinsic..., then remove this +if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 true) +endif() + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +# ... because they are auto-generated +if(FMHA_FWD_FAST_EXP2) +target_compile_options(example_fmha_fwd PRIVATE "-Wno-undefined-func-template;-DCK_FMHA_FWD_FAST_EXP2=1;-fgpu-flush-denormals-to-zero") +else() +target_compile_options(example_fmha_fwd PRIVATE "-Wno-undefined-func-template;-DCK_FMHA_FWD_FAST_EXP2=0") +endif() diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp new file mode 100644 index 000000000..a306d319c --- /dev/null +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -0,0 +1,494 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor/tensor_view.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/common_header.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "common/arg_parser.hpp" +#include "fmha_fwd.hpp" +#include "mask.hpp" +#include "reference/reference_batched_elementwise.hpp" +#include "reference/reference_batched_gemm.hpp" +#include "reference/reference_batched_masking.hpp" +#include "reference/reference_batched_softmax.hpp" +#include "utils.hpp" + +auto create_args(int argc, char* argv[]) +{ + ArgParser arg_parser; + arg_parser.insert("v", "1", "whether to do cpu validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "0", + "num of head, for k/v, 0 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", "3328", "seqlen_q") + .insert("s_k", "0", "seqlen_k, 0 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "0", "head dim for v, 0 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(seqlen)") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", "0", "add bias or not") + .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("mask", + "0", + "0: no mask, 1: top-left, 2:bottom-right\n" + "'t:l,r', top-left local-attn with left right size\n" + "'b:l,r', bottom-r local-attn with left right size\n" + "'g:y,x', generic attention mask coordinate with y/x size\n") + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +struct fmha_fwd_kernel_invoker +{ + static constexpr ck::index_t HDim = HDim_; + using DataType = DataType_; + // these args are used to select the kernel. + // args that may be passed as kargs should go through operator() + mode_enum mode; + bool use_bias; + mask_info mask; + + fmha_fwd_kernel_invoker(mode_enum mode_, bool use_bias_, mask_info mask_) + : mode(mode_), use_bias(use_bias_), mask(mask_) + { + } + + template + float operator()(const StreamConfig& stream, Args&&... 
args) + { + float ave_time; + BOOL_SWITCH_2(mode == mode_enum::group, kIsGroupMode, use_bias, kHasBias, [&] { + if(mask.type == mask_enum::no_mask) + { + using FmhaMask = FmhaMasks::NoMask; + using Kernel = + FmhaFwdKernelSelector; + + auto [kargs, grids] = + fmha_fwd_create_kargs_and_grids(std::forward(args)...); + ave_time = fmha_fwd_run(stream, kargs, grids); + } + else + { + BOOL_SWITCH(mask.type == mask_enum::window_generic, kIsLocal, [&]() { + using FmhaMask = ck::tile_program::block::GenericAttentionMask; + using Kernel = + FmhaFwdKernelSelector; + + auto [kargs, grids] = + fmha_fwd_create_kargs_and_grids(std::forward(args)...); + ave_time = fmha_fwd_run(stream, kargs, grids); + }); + } + }); + return ave_time; + } +}; + +// different threshold for different dtype +template +auto get_elimit(int /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(int init_method) +{ + if(init_method == 0) + { + double rtol = 1e-2; + double atol = 1e-2; + return ck::make_tuple(rtol, atol); + } + else + { + double rtol = 3e-3; + double atol = 3e-3; + return ck::make_tuple(rtol, atol); + } +} + +template +bool run(const ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck::index_t batch = arg_parser.get_int("b"); + ck::index_t nhead = arg_parser.get_int("h"); + ck::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k == 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck::index_t seqlen_q = arg_parser.get_int("s"); + ck::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k == 0) + seqlen_k = seqlen_q; + ck::index_t hdim_q = arg_parser.get_int("d"); + ck::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v == 0) + hdim_v = hdim_q; + + int i_perm = arg_parser.get_int("iperm"); // if true, will be batch * nhead * seqlen * hdim + int o_perm = arg_parser.get_int("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); + if(scale == .0f) + scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); // TODO: q ? v ? 
+ + bool use_bias = arg_parser.get_uint32("bias"); + + mask_info mask = decode_mask_info(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + + int stream_warmup = env_get_int("CK_WARMUP", 5); + int stream_repeat = env_get_int("CK_REPEAT", 20); + + StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat}; + + const auto [seqlens_q, seqstart_q_host] = generate_seqlens_seqstarts_q(mode, batch, seqlen_q); + const std::vector seqstart_k_host = + generate_seqstarts_k(mode, batch, seqlen_k, seqlens_q, seqlen_q); + + using TypeConfig = FmhaFwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + using namespace ck::literals; + + flop += nhead * (2_uz * real_seqlen_q * real_seqlen_k * hdim_q + + 2_uz * real_seqlen_q * hdim_v * real_seqlen_k); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k + + sizeof(ODataType) * real_seqlen_q * hdim_v); + } + } + + auto get_lengths = [&](int permute, + ck::index_t b /*batch*/, + ck::index_t h /*nhead*/, + ck::index_t s /*seqlen*/, + ck::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + constexpr bool is_v_rowmajor = ck::is_same_v; + + // host memory for storing all the tensor elements + const ck::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + + Tensor q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + Tensor k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + Tensor v_host( + is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); + // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host + // will not be used for verification at all (but will be copied to device anyway). + Tensor bias_host( + use_bias ? 
get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + + if(init_method == 0) + { + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(bias_host); + } + else if(init_method == 1) + { + ck::utils::FillUniformDistribution{0.f, 1.f}(q_host); + ck::utils::FillUniformDistribution{0.f, 1.f}(k_host); + ck::utils::FillUniformDistribution{-.5f, .5f}(v_host); + ck::utils::FillUniformDistribution{0.f, 1.f}(bias_host); + } + else if(init_method == 2) + { + ck::utils::FillTrigValue{}(q_host); + ck::utils::FillTrigValue{}(k_host); + ck::utils::FillTrigValue{}(v_host); + ck::utils::FillTrigValue{}(bias_host); + } + + DeviceMem q_buf(q_host.GetElementSpaceSizeInBytes()); + DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes()); + DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes()); + DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes()); + DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes()); + DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + + // clang-format off + auto layout_str = [&](int permute){ + if (permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](int iperm_, int operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias + << ", mask:" << mask << ", v:" << std::string(VLayout::name)[0] << std::flush; + +#define INVOKE_FMHA_KERNEL(hdim_) \ + fmha_fwd_kernel_invoker{mode, use_bias, mask}(stream_config, \ + q_buf.GetDeviceBuffer(), \ + k_buf.GetDeviceBuffer(), \ + v_buf.GetDeviceBuffer(), \ + bias_buf.GetDeviceBuffer(), \ + o_buf.GetDeviceBuffer(), \ + seqstart_q.GetDeviceBuffer(), \ + seqstart_k.GetDeviceBuffer(), \ + nullptr, \ + batch, \ + nhead, \ + nhead_k, \ + shape_seqlen_q, \ + shape_seqlen_k, \ + hdim_q, \ + hdim_v, \ + max_seqlen_q, \ + scale, \ + i_perm, \ + o_perm, \ + mask.y, \ + mask.x) + + float ave_time = 0; + if(hdim_q == hdim_v && hdim_q == 64) + { + ave_time = INVOKE_FMHA_KERNEL(64); + } + else if(hdim_q == hdim_v && hdim_q == 128) + { + ave_time = INVOKE_FMHA_KERNEL(128); + } + else + { + std::cerr << "not support hdim, will not run" << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::endl; + return true; + } 
+ + o_buf.FromDevice(o_host.data()); + + bool pass = true; + + for(ck::index_t wb = 0; wb < batch; ++wb) + { + const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + const auto v_host_ref_lengths = std::array{nhead, hdim_v, real_seqlen_k}; + const auto v_host_ref_strides = + is_v_rowmajor ? std::array{hdim_v * real_seqlen_k, 1, hdim_v} + : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; + + Tensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + Tensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + Tensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); + Tensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + Tensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + Tensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + + ck::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + if constexpr (is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); + } + + // reference + reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, + ck::identity{}, ck::identity{}, + [&](SaccDataType x) { return scale * x; }); + + if(use_bias) + { + Tensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, real_seqlen_k] + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) { + reference_batched_masking(s_host_ref, FmhaMasks::NoMask{}); + } else if(mask.type == mask_enum::window_generic) { + reference_batched_masking(s_host_ref, + FmhaMasks::GenericMask{mask.y, mask.x, seqlen_q, seqlen_k}); + } else { + reference_batched_masking(s_host_ref, + FmhaMasks::CausalMask{mask.y, mask.x, seqlen_q, seqlen_k}); + } + reference_batched_softmax(s_host_ref, p_host_ref); + reference_batched_gemm(p_host_ref, v_host_ref, o_host_ref); + + Tensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // permute + 
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck::utils::check_err( + o_host_result, o_host_ref, std::string("Error: Incorrect results!"), rtol, atol); + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/91_tile_program/fmha/fmha_fwd.hpp b/example/91_tile_program/fmha/fmha_fwd.hpp new file mode 100644 index 000000000..8a3abd301 --- /dev/null +++ b/example/91_tile_program/fmha/fmha_fwd.hpp @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +#include "ck/tile_program/block_tile/block_masking.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" +#include "ck/tile_program/tile/tile_fmha_shape.hpp" +#include "ck/tile_program/tile/tile_fmha_traits.hpp" + +#include "fmha_fwd_epilogue.hpp" +#include "fmha_fwd_kernel.hpp" +#include "fmha_fwd_tile_partitioner.hpp" + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck::half_t; + using KDataType = ck::half_t; + using VDataType = ck::half_t; + using BiasDataType = ck::half_t; + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck::bhalf_t; + using KDataType = ck::bhalf_t; + using VDataType = ck::bhalf_t; + using BiasDataType = ck::bhalf_t; + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::bhalf_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::bhalf_t; +}; + +// default settings for FmhaFwdKernelSelector<> type alias +using VLayout = ck::tensor_layout::gemm::RowMajor; // (bs, nhead) seqlen * hdim +// using VLayout = ck::tensor_layout::gemm::ColumnMajor; // (bs, nhead) hdim * seqlen + +struct FmhaMasks 
+{ + using NoMask = ck::tile_program::block::GenericAttentionMask; + using GenericMask = ck::tile_program::block::GenericAttentionMask; + using CausalMask = ck::tile_program::block::GenericAttentionMask; +}; + +inline constexpr bool kM0NeedPadding = false; +inline constexpr bool kN0K1NeedPadding = false; + +template +struct FmhaBlockTile; + +template <> +struct FmhaBlockTile : ck::Sequence<128, 64, 32, 64, 32, 64> +{ +}; +template <> +struct FmhaBlockTile : ck::Sequence<128, 128, 32, 128, 32, 128> +{ +}; +using FmhaBlockWarps = ck::Sequence<4, 1, 1>; +using FmhaWarpTile = ck::Sequence<32, 32, 16>; + +template +struct FmhaShape; + +template <> +struct FmhaShape : ck::tile_program::TileFmhaShape, + FmhaBlockWarps, + FmhaWarpTile, + FmhaBlockWarps, + FmhaWarpTile, + VLayout> +{ +}; + +template <> +struct FmhaShape + : ck::tile_program::TileFmhaShape, + FmhaBlockWarps, + FmhaWarpTile, + FmhaBlockWarps, + FmhaWarpTile, + VLayout> +{ +}; + +template +using FmhaTraits = ck::tile_program::TileFmhaTraits; + +template +using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + /* BlockSize = */ 256, + FmhaShape, + kIsGroupMode, + FmhaMask, + FmhaTraits>; + +template +using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync< + FmhaPipelineProblem>; + +template +using FmhaEpilogue = + FmhaFwdEpilogue::OaccDataType, + typename FmhaFwdTypeConfig::ODataType>>; + +template +using FmhaFwdKernelSelector = + FmhaFwdKernel>, + FmhaPipeline, + FmhaEpilogue>; + +// Kernel API +template +auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck::index_t batch, + ck::index_t nhead, + ck::index_t nhead_k, + ck::index_t seqlen_q, + ck::index_t seqlen_k, + ck::index_t hdim_q, + ck::index_t hdim_v, + ck::index_t max_seqlen_q, + float scale, + bool i_perm, + bool o_perm, + ck::index_t mask_y, + ck::index_t mask_x) +{ + constexpr bool is_v_rowmajor = + ck::is_same_v; + + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias' + /// are 0. + // setup stride_* arguments + const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck::index_t stride_v = [&]() { + if constexpr(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_k : nhead_k * seqlen_k; + }(); + const ck::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k); + const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q); + const ck::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q); + const ck::index_t nhead_stride_v = [&]() { + if constexpr(is_v_rowmajor) + return i_perm ? seqlen_k * hdim_v : hdim_v; + else + return i_perm ? 
hdim_v * seqlen_k : seqlen_k; + }(); + const ck::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k); + const ck::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck::index_t batch_stride_q = (nhead * seqlen_q * hdim_q); + const ck::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q); + const ck::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k); + const ck::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k); + const ck::index_t batch_stride_o = (nhead * seqlen_q * hdim_v); + + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + nhead / nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + mask_y, + mask_x); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + nhead / nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_o, + mask_y, + mask_x); + } + }(); + + dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v); + return ck::make_tuple(kargs, grids); +} + +// will instantiate this function across different source file +template +float fmha_fwd_run(const StreamConfig&, typename FmhaKernel::Kargs, dim3); + +#define FMHA_FWD_KERNEL_DEFINE(KERNEL_) \ + template <> \ + float fmha_fwd_run( \ + const StreamConfig& stream, typename KERNEL_::Kargs kargs, dim3 grids) \ + { \ + constexpr dim3 blocks = KERNEL_::BlockSize(); \ + constexpr ck::index_t kBlockPerCu = KERNEL_::kBlockPerCu; \ + return launch_kernel(stream, KERNEL_{}, grids, blocks, 0, kargs); \ + } diff --git a/example/91_tile_program/fmha_fwd_epilogue.hpp b/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp similarity index 100% rename from example/91_tile_program/fmha_fwd_epilogue.hpp rename to example/91_tile_program/fmha/fmha_fwd_epilogue.hpp diff --git a/example/91_tile_program/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp similarity index 92% rename from example/91_tile_program/fmha_fwd_kernel.hpp rename to example/91_tile_program/fmha/fmha_fwd_kernel.hpp index 390febd07..09865ecab 100644 --- a/example/91_tile_program/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -39,24 +39,20 @@ struct FmhaFwdKernel using FmhaMask = ck::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; - // using C0MatrixMask = ck::tile_program::block::C0MatrixMask_impl< - // ck::remove_cvref_t>; - - private: template // to avoid duplicated base class prblem, introduce an template arg - struct EmptyKargs + struct FmhaFwdEmptyKargs { }; // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. 
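(A rough, standalone C++17 sketch of the conditional-base-class trick that the comment above refers to and that the kargs structs below rely on: optional argument groups are mixed in as base classes and collapse to distinct empty bases when disabled, and the bases are filled in declaration order by aggregate initialization, which is why a MakeKargs()-style factory is used instead of constructors. The names ExampleKargs, CommonPart, BiasPart and EmptyPart are made up for illustration and are not part of this patch.)

#include <cstdio>
#include <cstdint>
#include <type_traits>

template <int> struct EmptyPart {};   // one distinct empty base per disabled optional group
struct CommonPart { const void* q_ptr; std::int32_t seqlen_q; };
struct BiasPart   { const void* bias_ptr = nullptr; std::int32_t stride_bias = 0; };

template <bool kHasBias>
struct ExampleKargs : CommonPart,
                      std::conditional_t<kHasBias, BiasPart, EmptyPart<0>>
{
    std::int32_t batch_stride_q; // mode-specific members live in the most-derived aggregate
};

int main()
{
    // aggregate initialization fills the bases first, in declaration order, then the members
    ExampleKargs<false> kargs{{/*q_ptr*/ nullptr, /*seqlen_q*/ 128}, {}, /*batch_stride_q*/ 0};
    (void)kargs;
    std::printf("no bias: %zu bytes, with bias: %zu bytes\n",
                sizeof(ExampleKargs<false>), sizeof(ExampleKargs<true>));
}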
- struct CommonKargs + struct FmhaFwdCommonKargs { - const QDataType* q_ptr; - const KDataType* k_ptr; - const VDataType* v_ptr; - ODataType* o_ptr; + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; ck::index_t seqlen_q; ck::index_t seqlen_k; @@ -79,26 +75,27 @@ struct FmhaFwdKernel ck::index_t nhead_stride_o; }; - struct CommonBiasKargs + struct FmhaFwdCommonBiasKargs { - const BiasDataType* bias_ptr = nullptr; + const void* bias_ptr = nullptr; ck::index_t stride_bias = 0; ck::index_t nhead_stride_bias = 0; }; - struct BatchModeBiasKargs : CommonBiasKargs + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs { ck::index_t batch_stride_bias = 0; }; - struct MaskKargs + struct FmhaFwdMaskKargs { ck::index_t mask_y, mask_x; }; - struct BatchModeKargs : CommonKargs, - std::conditional_t>, - std::conditional_t> + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t> { ck::index_t batch_stride_q; ck::index_t batch_stride_k; @@ -106,17 +103,17 @@ struct FmhaFwdKernel ck::index_t batch_stride_o; }; - struct GroupModeKargs : CommonKargs, - std::conditional_t>, - std::conditional_t> + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; }; - public: - using Kargs = std::conditional_t; + using Kargs = std::conditional_t; template __host__ static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -148,10 +145,10 @@ struct FmhaFwdKernel ck::index_t mask_y, ck::index_t mask_x) { - Kargs kargs{{reinterpret_cast(q_ptr), - reinterpret_cast(k_ptr), - reinterpret_cast(v_ptr), - reinterpret_cast(o_ptr), + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, seqlen_q, seqlen_k, hdim_q, @@ -179,7 +176,7 @@ struct FmhaFwdKernel if constexpr(kHasBias) { - kargs.bias_ptr = reinterpret_cast(bias_ptr); + kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; kargs.batch_stride_bias = batch_stride_bias; @@ -220,10 +217,10 @@ struct FmhaFwdKernel ck::index_t mask_y, ck::index_t mask_x) { - Kargs kargs{{reinterpret_cast(q_ptr), - reinterpret_cast(k_ptr), - reinterpret_cast(v_ptr), - reinterpret_cast(o_ptr), + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, -1, // seqlen will be updated by another pointer -1, // hdim_q, @@ -250,7 +247,7 @@ struct FmhaFwdKernel if constexpr(kHasBias) { - kargs.bias_ptr = reinterpret_cast(bias_ptr); + kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; } @@ -360,18 +357,19 @@ struct FmhaFwdKernel } // for simplicity, batch stride we just modify the pointer - const QDataType* q_ptr = kargs.q_ptr + + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + static_cast(i_nhead) * kargs.nhead_stride_q + batch_offset_q; const KDataType* k_ptr = - kargs.k_ptr + + reinterpret_cast(kargs.k_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + batch_offset_k; const VDataType* v_ptr = - kargs.v_ptr + + reinterpret_cast(kargs.v_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + batch_offset_v; - ODataType* o_ptr = kargs.o_ptr + static_cast(i_nhead) * kargs.nhead_stride_o + + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + batch_offset_o; // Q/K/V DRAM and DRAM window @@ -481,7 +479,8 @@ struct FmhaFwdKernel if constexpr(kHasBias) { const BiasDataType* bias_ptr = - kargs.bias_ptr + 
static_cast(i_nhead_) * kargs.nhead_stride_bias + + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + batch_offset_bias; const auto bias_dram = [&]() { diff --git a/example/91_tile_program/fmha_fwd_tile_partitioner.hpp b/example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp similarity index 100% rename from example/91_tile_program/fmha_fwd_tile_partitioner.hpp rename to example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp diff --git a/example/91_tile_program/fmha/generate_kernels.py b/example/91_tile_program/fmha/generate_kernels.py new file mode 100644 index 000000000..60556745a --- /dev/null +++ b/example/91_tile_program/fmha/generate_kernels.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +import itertools +from pathlib import Path +from typing import List, Optional + +DTYPE_MAP = { + "fp16": "ck::half_t", + "bf16": "ck::bhalf_t", +} + +MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +HDIMS = [64, 128] +MASKS = ["no", "causal", "generic"] +DIRECTIONS = ["fwd"] +GEN_DIR = "" + +KERNEL_IMPL_TEMPLATE_FWD = """// auto generated by generate_kernels.py +#include "fmha_fwd.hpp" + +using kernel_0 = FmhaFwdKernelSelector<{HDIM}, {DTYPE}, {MODE}, {MASK}, true>; +FMHA_FWD_KERNEL_DEFINE(kernel_0) +using kernel_1 = FmhaFwdKernelSelector<{HDIM}, {DTYPE}, {MODE}, {MASK}, false>; +FMHA_FWD_KERNEL_DEFINE(kernel_1) +""" + +class Kernel: + def __init__(self, + direction: str, + hdim: int, + dtype: str, + mode : str, + mask : str): + self.direction = direction + self.hdim = hdim + self.dtype = dtype + self.mode = mode + self.mask = mask + + @property + def template(self) -> str: + if self.direction == "fwd": + return KERNEL_IMPL_TEMPLATE_FWD.format( + HDIM=self.hdim, DTYPE=DTYPE_MAP[self.dtype], + MODE=MODE_MAP[self.mode], MASK=MASK_MAP[self.mask]) + + @property + def filename(self) -> str: + return f"fmha_{self.direction}_hdim{self.hdim}_{self.dtype}_{self.mode}_{self.mask}_mask.cpp" + +def get_all_kernels() -> List[Kernel]: + for direction, hdim, dtype, mode, mask in itertools.product(DIRECTIONS, HDIMS, DTYPE_MAP.keys(), MODE_MAP.keys(), MASK_MAP.keys()): + yield Kernel(direction=direction, hdim=hdim, dtype=dtype, mode=mode, mask=mask) + +def write_single_kernel(kernel: Kernel, autogen_dir: Path) -> None: + credit = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n +""" + (autogen_dir / kernel.filename).write_text(credit + kernel.template) + +def write_kernels(output_dir: Optional[str]) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + for kernel in get_all_kernels(): + write_single_kernel(kernel, output_dir) + +def list_kernels(to_file: Optional[str]) -> None: + assert to_file is not None + file_path = Path(to_file) + with file_path.open('a') as f: + for kernel in get_all_kernels(): + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate_kernels", + description="gen kernels for CK fmha kernel instances", + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="Where to generate the kernels " + " will default to the current directory ", + ) + parser.add_argument( + "-l", + "--list_kernels", + required=False, + help="list all the kernels to a file" + ) + args = parser.parse_args() + if args.list_kernels is not None: + list_kernels(args.list_kernels) + else: + write_kernels(args.output_dir) diff --git a/example/91_tile_program/fmha/mask.hpp b/example/91_tile_program/fmha/mask.hpp new file mode 100644 index 000000000..7e3a3fedd --- /dev/null +++ b/example/91_tile_program/fmha/mask.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tile_program/block_tile/block_masking.hpp" + +enum class mask_enum +{ + no_mask = 0, + causal_top_left, + causal_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck::index_t y, x; + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::causal_top_left) + os << "tl"; + else if(type == mask_enum::causal_bottom_right) + os << "br"; + else + { + os << "g(" << y << "/" << x << ")"; + } + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); +}; + +std::ostream& operator<<(std::ostream& os, const mask_info& mi) +{ + mi.serialize(os); + return os; +} + +mask_info decode_mask_info(std::string str, ck::index_t seqlen_q, ck::index_t seqlen_k) +{ + ck::index_t x_total = seqlen_k; + ck::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.At(ck::Number<0>{}); + tmp.x = r.At(ck::Number<1>{}); + } + else if(t == "b") + { + auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.At(ck::Number<0>{}); + tmp.x = r.At(ck::Number<1>{}); + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + else + { + // should be 0, 1, 2 + tmp.type = 
static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::causal_top_left) + { + tmp.y = seqlen_q; + tmp.x = 1; + } + else if(tmp.type == mask_enum::causal_bottom_right) + { + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + } + } + return tmp; +} diff --git a/example/91_tile_program/fmha/script/benchmark.sh b/example/91_tile_program/fmha/script/benchmark.sh new file mode 100644 index 000000000..b3b089652 --- /dev/null +++ b/example/91_tile_program/fmha/script/benchmark.sh @@ -0,0 +1,25 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/example_fmha_fwd +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do + +$EXE -prec=$prec -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -v=$VALID + +$EXE -prec=$prec -b=32 -h=32 -d=64 -s=512 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=32 -d=64 -s=1024 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=32 -d=64 -s=2048 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=32 -d=64 -s=4096 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=32 -d=64 -s=8192 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=32 -d=64 -s=16384 -iperm=$perm -operm=$perm -v=$VALID + +done +done diff --git a/example/91_tile_program/fmha/script/smoke_test.sh b/example/91_tile_program/fmha/script/smoke_test.sh new file mode 100644 index 000000000..863e1aeb5 --- /dev/null +++ b/example/91_tile_program/fmha/script/smoke_test.sh @@ -0,0 +1,23 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/example_fmha_fwd + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 128 64 ; do +for bias in 0 1 ; do + +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -iperm=$perm -operm=$perm -v=1 +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=256 -bias=$bias -iperm=$perm -operm=$perm -v=1 +$EXE -prec=$prec -b=2 -h=2 -h_k=1 -d=$hdim -s=512 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -v=1 +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=256 -s_k=512 -bias=$bias -iperm=$perm -operm=$perm -v=1 +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -mask=1 -v=1 +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -iperm=$perm -operm=$perm -mask=2 -v=1 +$EXE -prec=$prec -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -iperm=$perm -operm=$perm -mask=g:128,32 -v=1 + + +done +done +done +done diff --git a/example/91_tile_program/fmha_utils.hpp b/example/91_tile_program/fmha/utils.hpp similarity index 62% rename from example/91_tile_program/fmha_utils.hpp rename to example/91_tile_program/fmha/utils.hpp index df885b7e3..960e59a8f 100644 --- a/example/91_tile_program/fmha_utils.hpp +++ b/example/91_tile_program/fmha/utils.hpp @@ -1,31 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
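(Worked example of the mask grammar decoded by decode_mask_info in mask.hpp above, matching the flags used in smoke_test.sh: "-mask=g:128,32" yields a window_generic mask with y=128, x=32; "-mask=1", i.e. top-left causal, with -s=1024 -s_k=256 yields y=seqlen_q=1024, x=1; and the "t:l,r" / "b:l,r" forms go through make_generic_attention_mask_coordinates_from_lr_window to convert left/right window sizes into the same y/x coordinates.)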
-#include -#include +#pragma once + +#include #include #include -#include #include #include #include #include "ck/utility/span.hpp" -#pragma once - -enum class Mode : unsigned +enum class mode_enum { - Batch, - Group + batch = 0, + group }; -inline std::ostream& operator<<(std::ostream& stream, Mode mode) +std::ostream& operator<<(std::ostream& stream, mode_enum mode) { - return stream << (mode == Mode::Batch ? "batch" : "group"); + return stream << (mode == mode_enum::batch ? "batch" : "group"); } -inline std::vector to_seqstarts(ck::span seqlens) +std::vector to_seqstarts(ck::span seqlens) { std::vector seqstarts = {0}; for(int32_t seqlen : seqlens) @@ -36,16 +34,16 @@ inline std::vector to_seqstarts(ck::span seqlens) return seqstarts; } -inline std::vector generate_seqlens_q(Mode mode, - unsigned count, - int32_t seqlens_q_sum, - std::optional seed = std::nullopt) +std::vector generate_seqlens_q(mode_enum mode, + unsigned count, + int32_t seqlens_q_sum, + std::optional seed = std::nullopt) { assert(0 < count); std::vector seqlens_q(count, seqlens_q_sum); - if(mode == Mode::Group && 1 < count) + if(mode == mode_enum::group && 1 < count) { using size_type = std::vector::size_type; @@ -75,26 +73,29 @@ inline std::vector generate_seqlens_q(Mode mode, return seqlens_q; } -inline std::tuple, std::vector> generate_seqlens_seqstarts_q( - Mode mode, unsigned count, int32_t seqlens_q_sum, std::optional seed = std::nullopt) +std::tuple, std::vector> +generate_seqlens_seqstarts_q(mode_enum mode, + unsigned count, + int32_t seqlens_q_sum, + std::optional seed = std::nullopt) { const std::vector seqlens_q = generate_seqlens_q(mode, count, seqlens_q_sum, seed); return std::make_tuple(seqlens_q, to_seqstarts(seqlens_q)); } -inline std::vector generate_seqlens_k(Mode mode, - unsigned count, - int32_t seqlens_k_sum, - ck::span seqlens_q, - int32_t seqlens_q_sum, - std::optional seed = std::nullopt) +std::vector generate_seqlens_k(mode_enum mode, + unsigned count, + int32_t seqlens_k_sum, + ck::span seqlens_q, + int32_t seqlens_q_sum, + std::optional seed = std::nullopt) { assert(0 < count); assert(seqlens_q.size() == count); std::vector seqlens_k(count, seqlens_k_sum); - if(mode == Mode::Group && 1 < count) + if(mode == mode_enum::group && 1 < count) { using size_type = std::vector::size_type; @@ -127,13 +128,22 @@ inline std::vector generate_seqlens_k(Mode mode, return seqlens_k; } -inline std::vector generate_seqstarts_k(Mode mode, - unsigned count, - int32_t seqlens_k_sum, - ck::span seqlens_q, - int32_t seqlens_q_sum, - std::optional seed = std::nullopt) +std::vector generate_seqstarts_k(mode_enum mode, + unsigned count, + int32_t seqlens_k_sum, + ck::span seqlens_q, + int32_t seqlens_q_sum, + std::optional seed = std::nullopt) { return to_seqstarts( generate_seqlens_k(mode, count, seqlens_k_sum, seqlens_q, seqlens_q_sum, seed)); } + +int env_get_int(const char* var_name, int default_int) +{ + char* v = getenv(var_name); + int r = default_int; + if(v) + r = atoi(v); + return r; +} diff --git a/example/91_tile_program/fmha_fwd.cpp b/example/91_tile_program/fmha_fwd.cpp deleted file mode 100644 index a8132027d..000000000 --- a/example/91_tile_program/fmha_fwd.cpp +++ /dev/null @@ -1,775 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
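(A standalone sketch of the seqstart bookkeeping used by the group-mode helpers above: to_seqstarts() turns per-batch sequence lengths into exclusive prefix sums with a leading zero, which is also how the validation loop recovers each batch's real_seqlen_q / real_seqlen_k. The snippet below mirrors that logic with std::vector in place of ck::span; it is an illustration, not part of the patch.)

#include <cstdio>
#include <cstdint>
#include <vector>

// same logic as to_seqstarts() in utils.hpp, minus the ck::span dependency
std::vector<std::int32_t> to_seqstarts_sketch(const std::vector<std::int32_t>& seqlens)
{
    std::vector<std::int32_t> seqstarts = {0};
    for(std::int32_t seqlen : seqlens)
        seqstarts.push_back(seqstarts.back() + seqlen);
    return seqstarts;
}

int main()
{
    // seqlens {3, 5, 2} -> seqstarts {0, 3, 8, 10}; batch wb spans [seqstarts[wb], seqstarts[wb+1])
    for(std::int32_t s : to_seqstarts_sketch({3, 5, 2}))
        std::printf("%d ", static_cast<int>(s));
    std::printf("\n");
}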
- -#include -#include -#include -#include -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/tensor/tensor_view.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" - -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" -#include "ck/tile_program/block_tile/block_masking.hpp" -#include "ck/tile_program/tile/tile_fmha_shape.hpp" -#include "ck/tile_program/tile/tile_fmha_traits.hpp" - -#include "reference_batched_elementwise.hpp" -#include "reference_batched_gemm.hpp" -#include "reference_batched_masking.hpp" -#include "reference_batched_softmax.hpp" -#include "fmha_fwd_kernel.hpp" -#include "fmha_fwd_tile_partitioner.hpp" -#include "fmha_fwd_epilogue.hpp" -#include "fmha_utils.hpp" -#include "arg_parser.hpp" - -#if 1 -using QDataType = ck::half_t; -using KDataType = ck::half_t; -using VDataType = ck::half_t; -using BiasDataType = ck::half_t; -using SaccDataType = float; // data type for first gemm accumulation -using SMPLComputeDataType = float; // data type for reduction, softmax -using PDataType = ck::half_t; // data type for A matrix of second gemm -using OaccDataType = float; // data type for second gemm accumulation -using ODataType = ck::half_t; -#else -using QDataType = ck::bhalf_t; -using KDataType = ck::bhalf_t; -using VDataType = ck::bhalf_t; -using BiasDataType = ck::bhalf_t; -using SaccDataType = float; // data type for first gemm accumulation -using SMPLComputeDataType = float; // data type for reduction, softmax -using PDataType = ck::bhalf_t; // data type for A matrix of second gemm -using OaccDataType = float; // data type for second gemm accumulation -using ODataType = ck::bhalf_t; -#endif - -// M0 N0 K0 N1 K1 K0L -// using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>; -// using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>; -using VLayout = ck::tensor_layout::gemm::RowMajor; // (bs, nhead) seqlen * hdim -// using VLayout = ck::tensor_layout::gemm::ColumnMajor; // (bs, nhead) hdim * seqlen - -template -struct FmhaBlockTile; - -template <> -struct FmhaBlockTile : ck::Sequence<128, 64, 32, 64, 32, 64> -{ -}; -template <> -struct FmhaBlockTile : ck::Sequence<128, 128, 32, 128, 32, 128> -{ -}; -using FmhaBlockWarps = ck::Sequence<4, 1, 1>; -using FmhaWarpTile = ck::Sequence<32, 32, 16>; - -template -struct FmhaShape; - -template <> -struct FmhaShape : ck::tile_program::TileFmhaShape, - FmhaBlockWarps, - FmhaWarpTile, - FmhaBlockWarps, - FmhaWarpTile, - VLayout> -{ -}; -template <> -struct FmhaShape - : ck::tile_program::TileFmhaShape, - FmhaBlockWarps, - FmhaWarpTile, - FmhaBlockWarps, - FmhaWarpTile, - VLayout> -{ -}; - 
-// using FmhaMask = ck::tile_program::block::MaskUpperTriangleFromTopLeftPredicate; -// using FmhaMask = ck::tile_program::block::MaskUpperTriangleFromBottomRightPredicate; -// using FmhaMask = ck::tile_program::block::MaskDisabledPredicate; - -inline constexpr bool kM0NeedPadding = false; -inline constexpr bool kN0K1NeedPadding = false; -template -using FmhaTraits = ck::tile_program::TileFmhaTraits; - -template -using FmhaTilePartitioner = FmhaFwdTilePartitioner>; - -template -using FmhaPipelineProblem = - ck::tile_program::block::BlockFmhaPipelineProblem, - kIsGroupMode, - FmhaMask, - FmhaTraits>; - -template -using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync< - FmhaPipelineProblem>; - -using FmhaEpilogue = FmhaFwdEpilogue>; - -enum class mask_enum -{ - no_mask = 0, - causal_top_left, - causal_bottom_right, - window_generic, -}; - -struct mask_info -{ - mask_enum type; - ck::index_t y, x; - void serialize(std::ostream& os) const - { - if(type == mask_enum::no_mask) - os << "n"; - else if(type == mask_enum::causal_top_left) - os << "tl"; - else if(type == mask_enum::causal_bottom_right) - os << "br"; - else - { - os << "g(" << y << "/" << x << ")"; - } - } - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); -}; - -std::ostream& operator<<(std::ostream& os, const mask_info& mi) -{ - mi.serialize(os); - return os; -} - -mask_info decode_mask_info(std::string str, ck::index_t seqlen_q, ck::index_t seqlen_k) -{ - ck::index_t x_total = seqlen_k; - ck::index_t y_total = seqlen_q; - mask_info tmp; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string t = str.substr(0, found_0); - std::string v = str.substr(found_0 + 1); - auto found_1 = v.find(","); - if(found_1 == std::string::npos) - { - printf("not supported value %s, %s\n", v.c_str(), str.c_str()); - assert(0); - } - tmp.type = mask_enum::window_generic; - ck::index_t v0 = atoi(v.substr(0, found_1).c_str()); - ck::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); - // TODO: some validation - if(t == "t") - { - auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, true); - tmp.y = r.At(ck::Number<0>{}); - tmp.x = r.At(ck::Number<1>{}); - } - else if(t == "b") - { - auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, false); - tmp.y = r.At(ck::Number<0>{}); - tmp.x = r.At(ck::Number<1>{}); - } - else if(t == "g") - { - tmp.y = v0; - tmp.x = v1; - } - else - { - printf("not supported type %s, %s\n", t.c_str(), str.c_str()); - assert(0); - } - } - else - { - // should be 0, 1, 2 - tmp.type = static_cast(atoi(str.c_str())); - if(tmp.type == mask_enum::causal_top_left) - { - tmp.y = seqlen_q; - tmp.x = 1; - } - else if(tmp.type == mask_enum::causal_bottom_right) - { - tmp.y = seqlen_q; - tmp.x = seqlen_k - seqlen_q + 1; - } - } - return tmp; -} - -template -float invoke_fmha_kernel(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* bias_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck::index_t batch, - ck::index_t nhead, - ck::index_t nhead_k, - ck::index_t seqlen_q, - ck::index_t seqlen_k, - ck::index_t hdim_q, - ck::index_t hdim_v, - ck::index_t max_seqlen_q, - float scale, - bool i_perm, - bool o_perm, - ck::index_t mask_y, - ck::index_t mask_x, - StreamConfig stream_config) -{ - constexpr bool is_v_rowmajor = - ck::is_same_v; - - assert(nhead % nhead_k == 0); - /// NOTE: we broadcast bias from 
[1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, - /// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias' - /// are 0. - // setup stride_* arguments - const ck::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck::index_t stride_v = [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return i_perm ? seqlen_k : nhead_k * seqlen_k; - }(); - const ck::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k); - const ck::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - // setup nhead_stride_* arguments - const ck::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q); - const ck::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q); - const ck::index_t nhead_stride_v = [&]() { - if constexpr(is_v_rowmajor) - return i_perm ? seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * seqlen_k : seqlen_k; - }(); - const ck::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k); - const ck::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v); - // setup batch_stride_* arguments - const ck::index_t batch_stride_q = (nhead * seqlen_q * hdim_q); - const ck::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q); - const ck::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k); - const ck::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k); - const ck::index_t batch_stride_o = (nhead * seqlen_q * hdim_v); - - const auto kargs = [&] { - // create group mode kernel arguments - if constexpr(FmhaKernel_::kIsGroupMode) - { - return FmhaKernel_::MakeKargs(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - nhead / nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - mask_y, - mask_x); - } - else - { // create batch mode kernel arguments - return FmhaKernel_::MakeKargs(q_ptr, - k_ptr, - v_ptr, - bias_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - nhead / nhead_k, - scale, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - mask_y, - mask_x); - } - }(); - - const dim3 kGridSize = FmhaKernel_::GridSize(batch, nhead, max_seqlen_q, hdim_v); - constexpr dim3 kBlockSize = FmhaKernel_::BlockSize(); - - constexpr ck::index_t kBlockPerCu = FmhaKernel_::kBlockPerCu; - - return launch_kernel(stream_config, - FmhaKernel_{}, - kGridSize, - kBlockSize, - 0, - kargs); // BatchStrideO -} - -template -struct fmha_fwd_kernel_invoker -{ - static constexpr ck::index_t HDim = HDim_; - // these args are used to select kernel. - // args that may passed as karg shoule use operator() - Mode mode; - bool use_bias; - mask_info mask; - - fmha_fwd_kernel_invoker(Mode mode_, bool use_bias_, mask_info mask_) - : mode(mode_), use_bias(use_bias_), mask(mask_) - { - } - - template - float operator()(Args&&... 
args) - { - float ave_time; - BOOL_SWITCH_2(mode == Mode::Group, kIsGroupMode, use_bias, kHasBias, [&] { - if(mask.type == mask_enum::no_mask) - { - using FmhaMask = ck::tile_program::block::GenericAttentionMask; - using Kernel = FmhaFwdKernel, - FmhaPipeline, - FmhaEpilogue>; - ave_time = invoke_fmha_kernel(std::forward(args)...); - } - else - { - BOOL_SWITCH(mask.type == mask_enum::window_generic, kIsLocal, [&]() { - using FmhaMask = ck::tile_program::block::GenericAttentionMask; - using Kernel = - FmhaFwdKernel, - FmhaPipeline, - FmhaEpilogue>; - ave_time = invoke_fmha_kernel(std::forward(args)...); - }); - } - }); - return ave_time; - } -}; - -static inline int env_get_int(const char* var_name, int default_int) -{ - char* v = getenv(var_name); - int r = default_int; - if(v) - r = atoi(v); - return r; -} - -auto create_args(int argc, char* argv[]) -{ - ArgParser arg_parser; - arg_parser.insert("v", "1", "weather do cpu validation or not") - .insert("mode", "0", "kernel mode. 0:batch, 1:group") - .insert("b", "2", "batch size") - .insert("h", "8", "num of head, for q") - .insert("h_k", - "0", - "num of head, for k/v, 0 means equal to h\n" - "if not equal to h, then this is GQA/MQA case") - .insert("s", "3328", "seqlen_q") - .insert("s_k", "0", "seqlen_k, 0 means equal to s") - .insert("d", "128", "head dim for q, k") - .insert("d_v", "0", "head dim for v, 0 means equal to d") - .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(seqlen)") - .insert("iperm", - "1", - "permute input\n" - "if true, will be b*h*s*d, else b*s*h*d") - .insert("operm", "1", "permute output") - .insert("bias", "0", "add bias or not") - .insert("mask", - "0", - "0: no mask, 1: top-left, 2:bottom-right\n" - "'t:l,r', top-left local-attn with left right size\n" - "'b:l,r', bottom-r local-attn with left right size\n" - "'g:y,x', generic attention mask coordinate with y/x size\n") - .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - int do_validation = arg_parser.get_int("v"); - auto mode = static_cast(arg_parser.get_uint32("mode")); - ck::index_t batch = arg_parser.get_int("b"); - ck::index_t nhead = arg_parser.get_int("h"); - ck::index_t nhead_k = arg_parser.get_int("h_k"); - if(nhead_k == 0) - nhead_k = nhead; - - if(nhead % nhead_k != 0) - { - std::cout << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; - return -1; - } - - ck::index_t seqlen_q = arg_parser.get_int("s"); - ck::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k == 0) - seqlen_k = seqlen_q; - ck::index_t hdim_q = arg_parser.get_int("d"); - ck::index_t hdim_v = arg_parser.get_int("d_v"); - if(hdim_v == 0) - hdim_v = hdim_q; - - int i_perm = arg_parser.get_int("iperm"); // if true, will be batch * nhead * seqlen * hdim - int o_perm = arg_parser.get_int("operm"); // if false, will be batch * seqlen * nhead * hdim - - float scale = arg_parser.get_float("scale"); - if(scale == .0f) - scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); // TODO: q ? v ? 
- - bool use_bias = arg_parser.get_uint32("bias"); - - mask_info mask = decode_mask_info(arg_parser.get_str("mask"), seqlen_q, seqlen_k); - - int init_method = arg_parser.get_int("init"); - - int stream_warmup = env_get_int("CK_WARMUP", 5); - int stream_repeat = env_get_int("CK_REPEAT", 20); - - StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat}; - - const auto [seqlens_q, seqstart_q_host] = generate_seqlens_seqstarts_q(mode, batch, seqlen_q); - const std::vector seqstart_k_host = - generate_seqstarts_k(mode, batch, seqlen_k, seqlens_q, seqlen_q); - - // accumulation numbers for performance evaluation - std::size_t flop = 0, num_byte = 0; - auto max_seqlen_q = - std::numeric_limits::min(); // we will use max seqlen to decide grid size - { - for(ck::index_t wb = 0; wb < batch; ++wb) - { - const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - - if(max_seqlen_q < real_seqlen_q) - { - max_seqlen_q = real_seqlen_q; - } - - using namespace ck::literals; - - flop += nhead * (2_uz * real_seqlen_q * real_seqlen_k * hdim_q + - 2_uz * real_seqlen_q * hdim_v * real_seqlen_k); - - num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + - sizeof(KDataType) * real_seqlen_k * hdim_q + - sizeof(VDataType) * hdim_v * real_seqlen_k + - sizeof(ODataType) * real_seqlen_q * hdim_v); - } - } - - auto get_lengths = [&](int permute, - ck::index_t b /*batch*/, - ck::index_t h /*nhead*/, - ck::index_t s /*seqlen*/, - ck::index_t d /*hdim*/) { - if(permute) - return std::array{b, h, s, d}; - else - return std::array{b, s, h, d}; - }; - - constexpr bool is_v_rowmajor = ck::is_same_v; - - // host memory for storing all the tensor elements - const ck::index_t shape_batch = (mode == Mode::Batch ? batch : 1); - const ck::index_t shape_seqlen_q = (mode == Mode::Batch ? seqlen_q : seqstart_q_host.back()); - const ck::index_t shape_seqlen_k = (mode == Mode::Batch ? seqlen_k : seqstart_k_host.back()); - - Tensor q_host(get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - Tensor k_host(get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); - Tensor v_host( - is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) - : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); - // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host - // will not be used for verification at all (but will be copied to device anyway). - Tensor bias_host( - use_bias ? 
get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) - : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); - Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); - - if(init_method == 0) - { - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(bias_host); - } - else if(init_method == 1) - { - ck::utils::FillUniformDistribution{0.f, 1.f}(q_host); - ck::utils::FillUniformDistribution{0.f, 1.f}(k_host); - ck::utils::FillUniformDistribution{-.5f, .5f}(v_host); - ck::utils::FillUniformDistribution{0.f, 1.f}(bias_host); - } - else if(init_method == 2) - { - ck::utils::FillTrigValue{}(q_host); - ck::utils::FillTrigValue{}(k_host); - ck::utils::FillTrigValue{}(v_host); - ck::utils::FillTrigValue{}(bias_host); - } - - DeviceMem q_buf(q_host.GetElementSpaceSizeInBytes()); - DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes()); - DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes()); - DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes()); - DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes()); - DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); - DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); - - q_buf.ToDevice(q_host.data()); - k_buf.ToDevice(k_host.data()); - v_buf.ToDevice(v_host.data()); - bias_buf.ToDevice(bias_host.data()); - seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqstart_k_host.data()); - - // clang-format off - auto layout_str = [&](int permute){ - if (permute) return std::string("bhsd"); - else return std::string("bshd"); - }; - // clang-format on - - std::cout << "[" << mode << "|" << layout_str(i_perm) << "|" << layout_str(o_perm) - << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" - << seqlen_k << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale - << ", bias:" << use_bias << ", mask:" << mask - << ", v:" << std::string(VLayout::name)[0] << std::flush; - -#define INVOKE_FMHA_KERNEL(hdim_) \ - fmha_fwd_kernel_invoker{mode, use_bias, mask}(q_buf.GetDeviceBuffer(), \ - k_buf.GetDeviceBuffer(), \ - v_buf.GetDeviceBuffer(), \ - bias_buf.GetDeviceBuffer(), \ - o_buf.GetDeviceBuffer(), \ - seqstart_q.GetDeviceBuffer(), \ - seqstart_k.GetDeviceBuffer(), \ - nullptr, \ - batch, \ - nhead, \ - nhead_k, \ - shape_seqlen_q, \ - shape_seqlen_k, \ - hdim_q, \ - hdim_v, \ - max_seqlen_q, \ - scale, \ - i_perm, \ - o_perm, \ - mask.y, \ - mask.x, \ - stream_config) - - float ave_time = 0; - if(hdim_q == hdim_v && hdim_q == 64) - { - ave_time = INVOKE_FMHA_KERNEL(64); - } - else if(hdim_q == hdim_v && hdim_q == 128) - { - ave_time = INVOKE_FMHA_KERNEL(128); - } - else - { - std::cerr << "not support hdim, will not run" << std::endl; - return -1; - } - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " - << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush << std::endl; - - if(do_validation) - { - o_buf.FromDevice(o_host.data()); - - for(ck::index_t wb = 0; wb < batch; ++wb) - { - const ck::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - - // 
adjust matrix index according to the mode - const ck::index_t b = (mode == Mode::Batch ? wb : 0); - const ck::index_t query_offset = (mode == Mode::Batch ? 0 : seqstart_q_host[wb]); - const ck::index_t key_offset = (mode == Mode::Batch ? 0 : seqstart_k_host[wb]); - - const auto v_host_ref_lengths = - std::array{nhead, hdim_v, real_seqlen_k}; - const auto v_host_ref_strides = - is_v_rowmajor - ? std::array{hdim_v * real_seqlen_k, 1, hdim_v} - : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; - - Tensor q_host_ref({nhead, real_seqlen_q, hdim_q}); - Tensor k_host_ref({nhead, real_seqlen_k, hdim_q}); - Tensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); - Tensor o_host_ref({nhead, real_seqlen_q, hdim_v}); - - Tensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); - Tensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); - - ck::index_t nr = nhead / nhead_k; - - // clang-format off - // permute - if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); - else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); - - if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); - else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); - - if constexpr (is_v_rowmajor) { - // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); - // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); - } - else { - if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); - else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); - } - - // reference - reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, - ck::identity{}, ck::identity{}, - [&](SaccDataType x) { return scale * x; }); - - if(use_bias) - { - Tensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); - if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); - else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); - - // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, real_seqlen_k] - reference_batched_elementwise( - s_host_ref, bias_host_ref, s_host_ref); - } - - if(mask.type == mask_enum::no_mask) { - reference_batched_masking(s_host_ref, ck::tile_program::block::GenericAttentionMask{}); - } else if(mask.type == mask_enum::window_generic) { - reference_batched_masking(s_host_ref, - ck::tile_program::block::GenericAttentionMask{mask.y, mask.x, seqlen_q, seqlen_k}); - } else { - reference_batched_masking(s_host_ref, - ck::tile_program::block::GenericAttentionMask{mask.y, mask.x, seqlen_q, seqlen_k}); - } - reference_batched_softmax(s_host_ref, p_host_ref); - reference_batched_gemm(p_host_ref, v_host_ref, o_host_ref); - - Tensor o_host_result({nhead, real_seqlen_q, hdim_v}); - // permute - if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); - else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + 
query_offset, idx[0], idx[2]); }); - // clang-format on - - if(!ck::utils::check_err(o_host_result, o_host_ref)) - { - std::cerr << "mismatch found at batch: " << wb << std::endl - << "\tseqlen_q: " << real_seqlen_q << std::endl - << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; - - return -1; - } - } - } - else - { - return 0; - } -} diff --git a/example/91_tile_program/gemm/CMakeLists.txt b/example/91_tile_program/gemm/CMakeLists.txt new file mode 100644 index 000000000..a6e8f1ef5 --- /dev/null +++ b/example/91_tile_program/gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm gemm.cpp) diff --git a/example/91_tile_program/gemm.cpp b/example/91_tile_program/gemm/gemm.cpp similarity index 99% rename from example/91_tile_program/gemm.cpp rename to example/91_tile_program/gemm/gemm.cpp index 67e8479ea..7e8ad59c1 100644 --- a/example/91_tile_program/gemm.cpp +++ b/example/91_tile_program/gemm/gemm.cpp @@ -13,7 +13,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_gemm.hpp" +#include "reference/reference_gemm.hpp" #include "gemm.hpp" // elementwise lambda diff --git a/example/91_tile_program/gemm.hpp b/example/91_tile_program/gemm/gemm.hpp similarity index 100% rename from example/91_tile_program/gemm.hpp rename to example/91_tile_program/gemm/gemm.hpp diff --git a/example/91_tile_program/gemm_gemm/CMakeLists.txt b/example/91_tile_program/gemm_gemm/CMakeLists.txt new file mode 100644 index 000000000..0034ade28 --- /dev/null +++ b/example/91_tile_program/gemm_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_gemm gemm_gemm.cpp) diff --git a/example/91_tile_program/gemm_gemm.cpp b/example/91_tile_program/gemm_gemm/gemm_gemm.cpp similarity index 99% rename from example/91_tile_program/gemm_gemm.cpp rename to example/91_tile_program/gemm_gemm/gemm_gemm.cpp index ccbea2369..e65eaf40a 100644 --- a/example/91_tile_program/gemm_gemm.cpp +++ b/example/91_tile_program/gemm_gemm/gemm_gemm.cpp @@ -13,7 +13,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_gemm.hpp" +#include "reference/reference_gemm.hpp" #include "gemm_gemm.hpp" int main(int argc, char* argv[]) diff --git a/example/91_tile_program/gemm_gemm.hpp b/example/91_tile_program/gemm_gemm/gemm_gemm.hpp similarity index 100% rename from example/91_tile_program/gemm_gemm.hpp rename to example/91_tile_program/gemm_gemm/gemm_gemm.hpp diff --git a/example/91_tile_program/gemm_softmax_gemm/CMakeLists.txt b/example/91_tile_program/gemm_softmax_gemm/CMakeLists.txt new file mode 100644 index 000000000..8ce4b41fd --- /dev/null +++ b/example/91_tile_program/gemm_softmax_gemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_softmax_gemm gemm_softmax_gemm.cpp) diff --git a/example/91_tile_program/gemm_softmax_gemm.cpp b/example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.cpp similarity index 98% rename from example/91_tile_program/gemm_softmax_gemm.cpp rename to example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.cpp index b887b8ab9..4dddfaa75 100644 --- a/example/91_tile_program/gemm_softmax_gemm.cpp +++ b/example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.cpp @@ -13,8 +13,8 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_gemm.hpp" -#include 
"reference_softmax.hpp" +#include "reference/reference_gemm.hpp" +#include "reference/reference_softmax.hpp" #include "gemm_softmax_gemm.hpp" int main(int argc, char* argv[]) diff --git a/example/91_tile_program/gemm_softmax_gemm.hpp b/example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.hpp similarity index 100% rename from example/91_tile_program/gemm_softmax_gemm.hpp rename to example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm.hpp diff --git a/example/91_tile_program/gemm_softmax_gemm_impl.hpp b/example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm_impl.hpp similarity index 100% rename from example/91_tile_program/gemm_softmax_gemm_impl.hpp rename to example/91_tile_program/gemm_softmax_gemm/gemm_softmax_gemm_impl.hpp diff --git a/example/91_tile_program/im2col/CMakeLists.txt b/example/91_tile_program/im2col/CMakeLists.txt new file mode 100644 index 000000000..7a72732bc --- /dev/null +++ b/example/91_tile_program/im2col/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_im2col im2col.cpp) diff --git a/example/91_tile_program/im2col.cpp b/example/91_tile_program/im2col/im2col.cpp similarity index 88% rename from example/91_tile_program/im2col.cpp rename to example/91_tile_program/im2col/im2col.cpp index 83a8ba55f..d0744cfef 100644 --- a/example/91_tile_program/im2col.cpp +++ b/example/91_tile_program/im2col/im2col.cpp @@ -24,55 +24,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -template -void reference_im2col(Tensor& in_mtx_host_ref, - const Tensor& in_host, - int /*N*/, - int /*K*/, - int C, - int /*Y*/, - int X, - int Hi, - int Wi, - int Ho, - int Wo, - int ConvStrideH, - int ConvStrideW, - int ConvDilationH, - int ConvDilationW, - int InLeftPadH, - int InLeftPadW, - int /*InRightPadH*/, - int /*InRightPadW*/) -{ - int GemmM = in_mtx_host_ref.GetLengths()[0]; - int GemmK = in_mtx_host_ref.GetLengths()[1]; - - for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m) - { - int mtmp = gemm_m; - int n = mtmp / (Ho * Wo); - mtmp -= n * Ho * Wo; - int ho = mtmp / Wo; - int wo = mtmp - ho * Wo; - - for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k) - { - int ktmp = gemm_k; - int y = ktmp / (X * C); - ktmp -= y * X * C; - int x = ktmp / C; - int c = ktmp - x * C; - - int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH; - int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW; - - bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi); - - in_mtx_host_ref(gemm_m, gemm_k) = inbound ? 
in_host(n, hi, wi, c) : 0; - } - } -} +#include "reference/reference_im2col.hpp" template -void reference_reduce(const Tensor& a_m_n, Tensor& b_m) -{ - auto f = [&](auto m) { - const int N = a_m_n.mDesc.GetLengths()[1]; - - AccDataType v_acc = 0; - - for(int n = 0; n < N; ++n) - { - const ADataType v_a = a_m_n(m, n); - - v_acc += v_a; - } - - b_m(m) = ck::type_convert(v_acc); - }; - - make_ParallelTensorFunctor(f, b_m.mDesc.GetLengths()[0])(std::thread::hardware_concurrency()); -} - int main(int argc, char* argv[]) { using ADataType = ck::half_t; diff --git a/example/91_tile_program/reduce.hpp b/example/91_tile_program/reduce/reduce.hpp similarity index 100% rename from example/91_tile_program/reduce.hpp rename to example/91_tile_program/reduce/reduce.hpp diff --git a/example/91_tile_program/reference_batched_elementwise.hpp b/example/91_tile_program/reference/reference_batched_elementwise.hpp similarity index 100% rename from example/91_tile_program/reference_batched_elementwise.hpp rename to example/91_tile_program/reference/reference_batched_elementwise.hpp diff --git a/example/91_tile_program/reference_batched_gemm.hpp b/example/91_tile_program/reference/reference_batched_gemm.hpp similarity index 100% rename from example/91_tile_program/reference_batched_gemm.hpp rename to example/91_tile_program/reference/reference_batched_gemm.hpp diff --git a/example/91_tile_program/reference_batched_masking.hpp b/example/91_tile_program/reference/reference_batched_masking.hpp similarity index 100% rename from example/91_tile_program/reference_batched_masking.hpp rename to example/91_tile_program/reference/reference_batched_masking.hpp diff --git a/example/91_tile_program/reference_batched_softmax.hpp b/example/91_tile_program/reference/reference_batched_softmax.hpp similarity index 100% rename from example/91_tile_program/reference_batched_softmax.hpp rename to example/91_tile_program/reference/reference_batched_softmax.hpp diff --git a/example/91_tile_program/reference_gemm.hpp b/example/91_tile_program/reference/reference_gemm.hpp similarity index 100% rename from example/91_tile_program/reference_gemm.hpp rename to example/91_tile_program/reference/reference_gemm.hpp diff --git a/example/91_tile_program/reference/reference_im2col.hpp b/example/91_tile_program/reference/reference_im2col.hpp new file mode 100644 index 000000000..44ecab29f --- /dev/null +++ b/example/91_tile_program/reference/reference_im2col.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
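(Worked example, with made-up sizes, for the reference_im2col index mapping that moves into its own header below: with C=4, X=3, Ho=Wo=8, unit stride and dilation, and a left padding of 1, gemm_m = 17 decomposes into n=0, ho=2, wo=1 and gemm_k = 26 into y=2, x=0, c=2, so hi = 2*1 + 2*1 - 1 = 3 and wi = 0*1 + 1*1 - 1 = 0; the element read is in_host(0, 3, 0, 2), and any (hi, wi) falling outside [0, Hi) x [0, Wi) is written as 0 instead.)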
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template +void reference_im2col(Tensor& in_mtx_host_ref, + const Tensor& in_host, + int /*N*/, + int /*K*/, + int C, + int /*Y*/, + int X, + int Hi, + int Wi, + int Ho, + int Wo, + int ConvStrideH, + int ConvStrideW, + int ConvDilationH, + int ConvDilationW, + int InLeftPadH, + int InLeftPadW, + int /*InRightPadH*/, + int /*InRightPadW*/) +{ + int GemmM = in_mtx_host_ref.GetLengths()[0]; + int GemmK = in_mtx_host_ref.GetLengths()[1]; + + for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m) + { + int mtmp = gemm_m; + int n = mtmp / (Ho * Wo); + mtmp -= n * Ho * Wo; + int ho = mtmp / Wo; + int wo = mtmp - ho * Wo; + + for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k) + { + int ktmp = gemm_k; + int y = ktmp / (X * C); + ktmp -= y * X * C; + int x = ktmp / C; + int c = ktmp - x * C; + + int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH; + int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW; + + bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi); + + in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0; + } + } +} diff --git a/example/91_tile_program/reference/reference_reduce.hpp b/example/91_tile_program/reference/reference_reduce.hpp new file mode 100644 index 000000000..a4e0941f3 --- /dev/null +++ b/example/91_tile_program/reference/reference_reduce.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template +void reference_reduce(const Tensor& a_m_n, Tensor& b_m) +{ + auto f = [&](auto m) { + const int N = a_m_n.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + + for(int n = 0; n < N; ++n) + { + const ADataType v_a = a_m_n(m, n); + + v_acc += v_a; + } + + b_m(m) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f, b_m.mDesc.GetLengths()[0])(std::thread::hardware_concurrency()); +} diff --git a/example/91_tile_program/reference_softmax.hpp b/example/91_tile_program/reference/reference_softmax.hpp similarity index 100% rename from example/91_tile_program/reference_softmax.hpp rename to example/91_tile_program/reference/reference_softmax.hpp diff --git a/example/91_tile_program/softmax/CMakeLists.txt b/example/91_tile_program/softmax/CMakeLists.txt new file mode 100644 index 000000000..da580fbff --- /dev/null +++ b/example/91_tile_program/softmax/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_softmax softmax.cpp) diff --git a/example/91_tile_program/softmax.cpp b/example/91_tile_program/softmax/softmax.cpp similarity index 98% rename from example/91_tile_program/softmax.cpp rename to example/91_tile_program/softmax/softmax.cpp index f78d609f2..93d1279d8 100644 --- a/example/91_tile_program/softmax.cpp +++ b/example/91_tile_program/softmax/softmax.cpp @@ -13,7 +13,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "reference_softmax.hpp" +#include "reference/reference_softmax.hpp" #include "softmax.hpp" int main(int argc, char* argv[]) diff --git a/example/91_tile_program/softmax.hpp b/example/91_tile_program/softmax/softmax.hpp similarity index 100% rename from example/91_tile_program/softmax.hpp rename to example/91_tile_program/softmax/softmax.hpp diff --git a/include/ck/tile_program/block_tile/block_masking.hpp b/include/ck/tile_program/block_tile/block_masking.hpp index 
238f63f8b..a9ed67bf6 100644 --- a/include/ck/tile_program/block_tile/block_masking.hpp +++ b/include/ck/tile_program/block_tile/block_masking.hpp @@ -3,6 +3,10 @@ #pragma once +#include "ck/ck.hpp" +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" + namespace ck { namespace tile_program { namespace block { diff --git a/include/ck/tile_program/tile/static_distributed_tensor.hpp b/include/ck/tile_program/tile/static_distributed_tensor.hpp index 3e675cb5f..b9b532eca 100644 --- a/include/ck/tile_program/tile/static_distributed_tensor.hpp +++ b/include/ck/tile_program/tile/static_distributed_tensor.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tile_program/tile/static_tile_distribution_helper.hpp" #include "ck/tile_program/tile/tile_distribution.hpp" namespace ck { diff --git a/include/ck/utility/buffer_view_declare.hpp b/include/ck/utility/buffer_view_declare.hpp index 747f1ab63..42b31954d 100644 --- a/include/ck/utility/buffer_view_declare.hpp +++ b/include/ck/utility/buffer_view_declare.hpp @@ -5,9 +5,11 @@ #pragma once #include "ck/ck.hpp" +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/amd_buffer_addressing.hpp" +#include "ck/utility/c_style_pointer_cast.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/enable_if.hpp" -#include "ck/utility/c_style_pointer_cast.hpp" namespace ck { From 0c1bf348e082a7fc4d9c6c8ebecfa32ad832ce24 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 5 Jan 2024 15:07:57 +0000 Subject: [PATCH 26/45] modify bench script --- .../91_tile_program/fmha/script/benchmark.sh | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/example/91_tile_program/fmha/script/benchmark.sh b/example/91_tile_program/fmha/script/benchmark.sh index b3b089652..41449ca3b 100644 --- a/example/91_tile_program/fmha/script/benchmark.sh +++ b/example/91_tile_program/fmha/script/benchmark.sh @@ -7,19 +7,19 @@ VALID=0 for prec in "fp16" "bf16" ; do for perm in 0 1 ; do -$EXE -prec=$prec -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=32 -h=32 -d=64 -s=512 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=16 -h=32 -d=64 -s=1024 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=8 -h=32 -d=64 -s=2048 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=4 -h=32 -d=64 -s=4096 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=2 -h=32 -d=64 -s=8192 -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=32 -d=64 -s=16384 
-iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=32 -h=32 -d=64 -s=512 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=32 -d=64 -s=1024 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=32 -d=64 -s=2048 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=32 -d=64 -s=4096 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=32 -d=64 -s=8192 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=32 -d=64 -s=16384 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 done done From 65c8f98270a78bb048ef9e44ffef7da32ba008ef Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Sat, 6 Jan 2024 10:46:56 +0800 Subject: [PATCH 27/45] Fix inconsistent mask creation logics (#63) --- example/91_tile_program/fmha/fmha_fwd.cpp | 6 +++--- example/91_tile_program/fmha/fmha_fwd_kernel.hpp | 2 +- include/ck/tile_program/block_tile/block_masking.hpp | 5 ++++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index a306d319c..7877de504 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -436,13 +436,13 @@ bool run(const ArgParser& arg_parser) } if(mask.type == mask_enum::no_mask) { - reference_batched_masking(s_host_ref, FmhaMasks::NoMask{}); + reference_batched_masking(s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); } else if(mask.type == mask_enum::window_generic) { reference_batched_masking(s_host_ref, - FmhaMasks::GenericMask{mask.y, mask.x, seqlen_q, seqlen_k}); + FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); } else { reference_batched_masking(s_host_ref, - FmhaMasks::CausalMask{mask.y, mask.x, seqlen_q, seqlen_k}); + FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); } reference_batched_softmax(s_host_ref, p_host_ref); reference_batched_gemm(p_host_ref, v_host_ref, o_host_ref); diff --git a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp index 09865ecab..3c7f8d270 100644 --- a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -508,7 +508,7 @@ struct FmhaFwdKernel if constexpr(kHasMask) return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k}; else - return FmhaMask{0, 0, kargs.seqlen_q, kargs.seqlen_k}; + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); auto o_acc_tile = diff --git a/include/ck/tile_program/block_tile/block_masking.hpp b/include/ck/tile_program/block_tile/block_masking.hpp index a9ed67bf6..c40c85bce 100644 --- a/include/ck/tile_program/block_tile/block_masking.hpp +++ b/include/ck/tile_program/block_tile/block_masking.hpp @@ -65,7 +65,10 @@ struct GenericAttentionMask static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask, // else only upper-right could have mask - __host__ __device__ GenericAttentionMask() : y(0), x(0), y_total(0), x_total(0) {} + __host__ __device__ GenericAttentionMask(index_t y_total_, index_t x_total_) + : GenericAttentionMask(0, 0, y_total_, x_total_) + { + } __host__ __device__ GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) From 539f9677e047da576f67810f7833dd983df3c1f8 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 6 Jan 2024 12:15:10 +0000 Subject: [PATCH 28/45] support non-broadcast in block reduce sync --- .../tile_program/block_tile/block_reduce.hpp | 60 
++++++++++--------- .../block_fmha_pipeline_qr_ks_vs.hpp | 4 +- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 4 +- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/include/ck/tile_program/block_tile/block_reduce.hpp b/include/ck/tile_program/block_tile/block_reduce.hpp index 1cba690c7..08a8f8a42 100644 --- a/include/ck/tile_program/block_tile/block_reduce.hpp +++ b/include/ck/tile_program/block_tile/block_reduce.hpp @@ -14,9 +14,10 @@ namespace tile_program { namespace block { // synchronize reduce result (cross lane reduction and broadcast on replicated dimension) -template +template __device__ void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, - const ReduceFunc& reduce_func) + const ReduceFunc& reduce_func, + bool_constant = {}) { using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using DstrEncode = typename Dstr::DstrEncode; @@ -67,40 +68,43 @@ __device__ void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, } }); - // cross-lane broadcast for replication - // only broadcast on R dimension correspond to lane - // (lane id maps to this R dimension) - static_for<0, NDimR, 1>{}([&](auto idim_r) { - // FIXME: nasty to use does_p_own_r_ - if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) - { - const index_t r_id = rs_idx[idim_r]; + if constexpr(WithBroadcast) + { + // cross-lane broadcast for replication + // only broadcast on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + const index_t r_id = rs_idx[idim_r]; - constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; - constexpr index_t lid_over_rid_derivative = - DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; - static_assert(math::is_power_of_two_integer(r_length), - "wrong! only support power of 2 reduction"); + static_assert(math::is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); - constexpr index_t nstage = math::integer_log2_floor(r_length); + constexpr index_t nstage = math::integer_log2_floor(r_length); - // broadcast sweep backward - static_for<0, nstage, 1>{}([&](auto istage) { - // do I hold reduced data? - const bool do_i_hold_reduced_data = r_id < (1 << istage); + // broadcast sweep backward + static_for<0, nstage, 1>{}([&](auto istage) { + // do I hold reduced data? + const bool do_i_hold_reduced_data = r_id < (1 << istage); - constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); + constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); - // pull data from remote lane - const auto v_remote = warp_shuffle_up(v_local, lid_delta); + // pull data from remote lane + const auto v_remote = warp_shuffle_up(v_local, lid_delta); - // decide whether to update local data with remote data - v_local = do_i_hold_reduced_data ? v_local : v_remote; - }); - } - }); + // decide whether to update local data with remote data + v_local = do_i_hold_reduced_data ? 
v_local : v_remote; + }); + } + }); + } acc_tensor.GetThreadBuffer()(i) = v_local; }); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 854b1631e..2cb074dcc 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -304,7 +304,7 @@ struct BlockFmhaPipelineQRKSVS Sequence<1>{}, f_max, NumericLimits::Lowest()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} tile_elementwise_inout( @@ -339,7 +339,7 @@ struct BlockFmhaPipelineQRKSVS auto rowsum_p = block_tile_reduce( p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - block_tile_reduce_sync(rowsum_p, f_sum); + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); // l{j}, Oacc{j} constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 695e2bdcb..9e3476531 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -338,7 +338,7 @@ struct BlockFmhaPipelineQRKSVSAsync Sequence<1>{}, f_max, NumericLimits::Lowest()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} tile_elementwise_inout( @@ -405,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto rowsum_p = block_tile_reduce( p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - block_tile_reduce_sync(rowsum_p, f_sum); + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); // l{j}, Oacc{j} constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { From f188b802dd23f92169573acbd589d8ab767bb81c Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Thu, 11 Jan 2024 03:49:13 -0500 Subject: [PATCH 29/45] Fix wrong data type used for bias tensor --- example/91_tile_program/fmha/fmha_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index 7877de504..978f6f498 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -256,7 +256,7 @@ bool run(const ArgParser& arg_parser) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host // will not be used for verification at all (but will be copied to device anyway). - Tensor bias_host( + Tensor bias_host( use_bias ? 
get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); From 1787c230afabc7ed7ffdb8a34454480bbc9da0fe Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Thu, 11 Jan 2024 21:20:49 +0800 Subject: [PATCH 30/45] Flexible head dimension (#66) * Support more head dim values * Change the head dim check logic according to config * Add missing include directive * Use fewer occupancy to prevent register spills * Rename closure object * Make expr more readable --- example/91_tile_program/fmha/fmha_fwd.cpp | 17 +++- example/91_tile_program/fmha/fmha_fwd.hpp | 18 +++- .../91_tile_program/fmha/fmha_fwd_kernel.hpp | 89 +++++++++++++------ .../fmha/fmha_fwd_tile_partitioner.hpp | 2 +- .../91_tile_program/fmha/generate_kernels.py | 2 +- .../91_tile_program/fmha/script/benchmark.sh | 7 ++ .../tensor_space_filling_curve.hpp | 4 + .../block_fmha_pipeline_problem.hpp | 5 ++ .../block_fmha_pipeline_qr_ks_vs.hpp | 4 + .../block_fmha_pipeline_qr_ks_vs_async.hpp | 3 + .../ck/tile_program/tile/tile_fmha_shape.hpp | 1 + .../ck/tile_program/tile/tile_fmha_traits.hpp | 2 + include/ck/utility/magic_division.hpp | 2 +- include/ck/utility/static_buffer.hpp | 2 + include/ck/utility/tuple.hpp | 4 +- 15 files changed, 128 insertions(+), 34 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index 978f6f498..4d27ddf3b 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -339,12 +340,22 @@ bool run(const ArgParser& arg_parser) mask.y, \ mask.x) - float ave_time = 0; - if(hdim_q == hdim_v && hdim_q == 64) + float ave_time = 0; + const auto check_hdims = [](ck::index_t hdim_q_, ck::index_t hdim_v_, ck::index_t threshold) { + const auto compare = + std::conditional_t, std::equal<>>{}; + return compare(hdim_q_, threshold) && compare(hdim_v_, threshold); + }; + + if(check_hdims(hdim_q, hdim_v, 32)) + { + ave_time = INVOKE_FMHA_KERNEL(32); + } + else if(check_hdims(hdim_q, hdim_v, 64)) { ave_time = INVOKE_FMHA_KERNEL(64); } - else if(hdim_q == hdim_v && hdim_q == 128) + else if(check_hdims(hdim_q, hdim_v, 128)) { ave_time = INVOKE_FMHA_KERNEL(128); } diff --git a/example/91_tile_program/fmha/fmha_fwd.hpp b/example/91_tile_program/fmha/fmha_fwd.hpp index 8a3abd301..68e1a9a12 100644 --- a/example/91_tile_program/fmha/fmha_fwd.hpp +++ b/example/91_tile_program/fmha/fmha_fwd.hpp @@ -63,10 +63,15 @@ struct FmhaMasks inline constexpr bool kM0NeedPadding = false; inline constexpr bool kN0K1NeedPadding = false; +inline constexpr bool kK0N1NeedPadding = false; template struct FmhaBlockTile; +template <> +struct FmhaBlockTile : ck::Sequence<128, 64, 16, 32, 32, 32> +{ +}; template <> struct FmhaBlockTile : ck::Sequence<128, 64, 32, 64, 32, 64> { @@ -81,6 +86,16 @@ using FmhaWarpTile = ck::Sequence<32, 32, 16>; template struct FmhaShape; +template <> +struct FmhaShape : ck::tile_program::TileFmhaShape, + ck::Sequence<2, 1, 1>, + FmhaWarpTile, + ck::Sequence<2, 1, 1>, + FmhaWarpTile, + VLayout> +{ +}; + template <> struct FmhaShape : ck::tile_program::TileFmhaShape, FmhaBlockWarps, @@ -105,6 +120,7 @@ struct FmhaShape template using FmhaTraits = ck::tile_program::TileFmhaTraits; @@ -119,7 +135,7 @@ using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::PDataType, 
typename FmhaFwdTypeConfig::OaccDataType, typename FmhaFwdTypeConfig::ODataType, - /* BlockSize = */ 256, + /* BlockSize = */ HDim == 32 ? 128 : 256, FmhaShape, kIsGroupMode, FmhaMask, diff --git a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp index 3c7f8d270..a326ca1aa 100644 --- a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -35,6 +35,7 @@ struct FmhaFwdKernel static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; static constexpr bool kM0NeedPadding = FmhaPipeline::kM0NeedPadding; static constexpr bool kN0K1NeedPadding = FmhaPipeline::kN0K1NeedPadding; + static constexpr bool kK0N1NeedPadding = FmhaPipeline::kK0N1NeedPadding; static constexpr bool kHasBias = FmhaPipeline::kHasBias; using FmhaMask = ck::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -380,10 +381,20 @@ struct FmhaFwdKernel make_tuple(kargs.stride_q, 1), Number<32>{}, Number<1>{}); - - return pad_tensor_view(q_dram_naive, - make_tuple(Number{}, Number<1>{}), - Sequence{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); + } }(); const auto k_dram = [&]() { const auto k_dram_naive = make_naive_tensor_view( @@ -393,9 +404,10 @@ struct FmhaFwdKernel Number<32>{}, Number<1>{}); - return pad_tensor_view(k_dram_naive, - make_tuple(Number{}, Number<1>{}), - Sequence{}); + return pad_tensor_view( + k_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); }(); const auto v_dram = [&]() { if constexpr(ck::is_same_v) @@ -418,19 +430,44 @@ struct FmhaFwdKernel /// same as /// v_dram_transposed.GetTensorDescriptor().GetLength(). Replace following /// if-clause by pad_tensor_view() call after fixing this issue. 
- if constexpr(kN0K1NeedPadding) + if constexpr(kK0N1NeedPadding || kN0K1NeedPadding) { - const index_t pad_length = - FmhaPipeline::kK1 * - ck::math::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kK1) - - kargs.seqlen_k; - - return transform_tensor_view( - v_dram_transposed, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_right_pad_transform(kargs.seqlen_k, pad_length)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto transform_n1 = [&] { + if constexpr(kK0N1NeedPadding) + { + const index_t n1_pad_length = + FmhaPipeline::kN1 * + ck::math::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1) - + kargs.hdim_v; + + return make_right_pad_transform(kargs.hdim_v, n1_pad_length); + } + else + { + return make_pass_through_transform(kargs.hdim_v); + } + }(); + + const auto transform_k1 = [&] { + if constexpr(kN0K1NeedPadding) + { + const index_t k1_pad_length = + FmhaPipeline::kK1 * ck::math::integer_divide_ceil( + kargs.seqlen_k, FmhaPipeline::kK1) - + kargs.seqlen_k; + + return make_right_pad_transform(kargs.seqlen_k, k1_pad_length); + } + else + { + return make_pass_through_transform(kargs.seqlen_k); + } + }(); + + return transform_tensor_view(v_dram_transposed, + make_tuple(transform_n1, transform_k1), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); } else { @@ -446,9 +483,10 @@ struct FmhaFwdKernel Number<32>{}, Number<1>{}); - return pad_tensor_view(v_dram_naive, - make_tuple(Number<1>{}, Number{}), - Sequence{}); + return pad_tensor_view( + v_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); } }(); @@ -531,9 +569,10 @@ struct FmhaFwdKernel Number<32>{}, Number<1>{}); - return pad_tensor_view(o_dram_naive, - make_tuple(Number{}, Number<1>{}), - Sequence{}); + return pad_tensor_view( + o_dram_naive, + make_tuple(Number{}, Number{}), + Sequence{}); }(); auto o_dram_window = diff --git a/example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp b/example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp index d60dce24a..b6194716f 100644 --- a/example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_tile_partitioner.hpp @@ -35,7 +35,7 @@ struct FmhaFwdTilePartitioner using namespace ck; // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = hdim_v / kN1; + const index_t num_tile_n1 = ck::math::integer_divide_ceil(hdim_v, kN1); const index_t i_block = blockIdx.x; const index_t i_nhead = blockIdx.y; diff --git a/example/91_tile_program/fmha/generate_kernels.py b/example/91_tile_program/fmha/generate_kernels.py index 60556745a..a7b8ebda5 100644 --- a/example/91_tile_program/fmha/generate_kernels.py +++ b/example/91_tile_program/fmha/generate_kernels.py @@ -23,7 +23,7 @@ "group" : "true" } -HDIMS = [64, 128] +HDIMS = [32, 64, 128] MASKS = ["no", "causal", "generic"] DIRECTIONS = ["fwd"] GEN_DIR = "" diff --git a/example/91_tile_program/fmha/script/benchmark.sh b/example/91_tile_program/fmha/script/benchmark.sh index 41449ca3b..f245691b3 100644 --- a/example/91_tile_program/fmha/script/benchmark.sh +++ b/example/91_tile_program/fmha/script/benchmark.sh @@ -21,5 +21,12 @@ $EXE -prec=$prec -b=4 -h=32 -d=64 -s=4096 -iperm=$perm -operm=$perm -v=$VALID $EXE -prec=$prec -b=2 -h=32 -d=64 -s=8192 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 $EXE -prec=$prec -b=1 -h=32 -d=64 -s=16384 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=32 -h=32 -d=32 -s=512 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 
+$EXE -prec=$prec -b=16 -h=32 -d=32 -s=1024 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=32 -d=32 -s=2048 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=32 -d=32 -s=4096 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=32 -d=32 -s=8192 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=32 -d=32 -s=16384 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 + done done diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index 8fc0b03a5..d4e35ee82 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -20,6 +20,10 @@ template // # of scalars per access in each dimension struct SpaceFillingCurve { + static constexpr index_t TensorSize = + reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}); + static_assert(0 < TensorSize, "SpaceFillingCurve should be used to access a non-empty tensor"); + static constexpr index_t nDim = TensorLengths::Size(); using Index = MultiIndex; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp index ed42ae937..e0c00939f 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/ck.hpp" +#include "ck/utility/get_id.hpp" #include "ck/utility/type.hpp" namespace ck { @@ -39,12 +40,16 @@ struct BlockFmhaPipelineProblem using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; + static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0, + "kBlockSize should be divisible by get_warp_size()"); + static constexpr index_t kBlockSize = kBlockSize_; static constexpr bool kIsGroupMode = kIsGroupMode_; // attributes from traits static constexpr bool kM0NeedPadding = Traits::kM0NeedPadding; static constexpr bool kN0K1NeedPadding = Traits::kN0K1NeedPadding; + static constexpr bool kK0N1NeedPadding = Traits::kK0N1NeedPadding; static constexpr bool kHasBias = Traits::kHasBias; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 2cb074dcc..be5f1c590 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -55,6 +55,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; + static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; static constexpr bool kHasBias = Problem::kHasBias; __host__ __device__ static constexpr ck::index_t GetSmemSize() @@ -189,6 +190,9 @@ struct BlockFmhaPipelineQRKSVS index_t i_total_loops = 0; constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); do { // STAGE 1, QK gemm diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp 
b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 9e3476531..f3c02e8f9 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -55,6 +55,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; + static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; static constexpr bool kHasBias = Problem::kHasBias; __host__ __device__ static constexpr ck::index_t GetSmemSize() @@ -233,6 +234,8 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr index_t k0_loops = kK0BlockLength / kK0; constexpr index_t k1_loops = kN0 / kK1; + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); // main loop do { diff --git a/include/ck/tile_program/tile/tile_fmha_shape.hpp b/include/ck/tile_program/tile/tile_fmha_shape.hpp index c5ab151e6..acedc2d07 100644 --- a/include/ck/tile_program/tile/tile_fmha_shape.hpp +++ b/include/ck/tile_program/tile/tile_fmha_shape.hpp @@ -39,6 +39,7 @@ struct TileFmhaShape static constexpr index_t kK0BlockLength = BlockTile::At(Number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) + static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0"); using VLayout = remove_cvref_t; // rowmajor : seqlen*hdim, colmajor : hdim*seqlen }; diff --git a/include/ck/tile_program/tile/tile_fmha_traits.hpp b/include/ck/tile_program/tile/tile_fmha_traits.hpp index ab52929ba..697c485d1 100644 --- a/include/ck/tile_program/tile/tile_fmha_traits.hpp +++ b/include/ck/tile_program/tile/tile_fmha_traits.hpp @@ -10,12 +10,14 @@ namespace tile_program { template struct TileFmhaTraits { static constexpr bool kM0NeedPadding = kM0NeedPadding_; static constexpr bool kN0K1NeedPadding = kN0K1NeedPadding_; + static constexpr bool kK0N1NeedPadding = kK0N1NeedPadding_; static constexpr bool kHasBias = kHasBias_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index b8e7380f0..e2f12216f 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -45,7 +45,7 @@ struct MagicDivision32BitRange } // integral_constant - template + template > __host__ __device__ static constexpr auto CalculateMagicNumbers(integral_constant) { diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 0ccebd476..f79e774a8 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -14,6 +14,8 @@ template // TODO remove this bool, no longer needed struct StaticBuffer : public StaticallyIndexedArray, N> { + static_assert(0 < N, "StaticBuffer should not be empty"); + using S = remove_cvref_t; using type = S; using base = StaticallyIndexedArray; diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index eb2995872..cf8b0229d 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -171,14 +171,14 @@ struct Tuple : detail::TupleImpl + template > __host__ __device__ constexpr const auto& operator[](Number i) const { return At(i); } // write access - template + template > __host__ __device__ constexpr auto& operator()(Number i) { return At(i); From 
cd4c0600f37288f09736d910378efeb18a8c4142 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Fri, 12 Jan 2024 01:31:45 -0500 Subject: [PATCH 31/45] Fix complation error --- example/91_tile_program/fmha/fmha_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index 4d27ddf3b..6216ea939 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -343,7 +343,7 @@ bool run(const ArgParser& arg_parser) float ave_time = 0; const auto check_hdims = [](ck::index_t hdim_q_, ck::index_t hdim_v_, ck::index_t threshold) { const auto compare = - std::conditional_t, std::equal<>>{}; + std::conditional_t, std::equal_to<>>{}; return compare(hdim_q_, threshold) && compare(hdim_v_, threshold); }; From bf427ceb2afdb0e54958c918cf1c63ff07bac460 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Fri, 12 Jan 2024 03:49:54 -0500 Subject: [PATCH 32/45] Extract distributed indices convertion logics as function Extract logic as new helper function: get_x_indices_from_distributed_indices() --- .../tile/static_distributed_tensor.hpp | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/include/ck/tile_program/tile/static_distributed_tensor.hpp b/include/ck/tile_program/tile/static_distributed_tensor.hpp index b9b532eca..d0a0355be 100644 --- a/include/ck/tile_program/tile/static_distributed_tensor.hpp +++ b/include/ck/tile_program/tile/static_distributed_tensor.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor_coordinate.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" @@ -177,30 +178,40 @@ __host__ __device__ constexpr auto make_static_distributed_tensor(const StaticTi remove_cvref_t>{}; } +// get X indices from tuple of TileDistributedIndex<> +template +__host__ __device__ constexpr auto +get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices) +{ + const auto partition_index = detail::get_partition_index(tile_distribution); + constexpr auto y_indices = + tile_distribution.GetYIndicesFromDistributedIndices(distributed_indices); + + const auto x_coord = make_tensor_adaptor_coordinate( + tile_distribution.GetPsYs2XsAdaptor(), + container_concat(partition_index, to_array(y_indices))); + + return x_coord.GetBottomIndex(); +} + template __host__ __device__ void set_tile_if(StaticDistributedTensor& out_tensor, DataType value, XIndicesPredicate predicate) { - - StaticTileDistribution tile_distribution; - const auto partition_index = detail::get_partition_index(tile_distribution); - constexpr auto out_spans = StaticDistributedTensor::GetDistributedSpans(); sweep_tile_span(out_spans[Number<0>{}], [&](auto idx0) { sweep_tile_span(out_spans[Number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto y_idx = tile_distribution.GetYIndicesFromDistributedIndices(i_j_idx); - - const auto coord = make_tensor_adaptor_coordinate( - tile_distribution.GetPsYs2XsAdaptor(), - container_concat(partition_index, to_array(y_idx))); + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{}, + distributed_indices); - if(predicate(coord.GetBottomIndex())) + 
if(predicate(x_indices)) { - out_tensor(i_j_idx) = value; + out_tensor(distributed_indices) = value; } }); }); From 6cbea7d6036d2d265db0a5669a9b1a8d88567bd1 Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 18 Jan 2024 21:24:14 +0800 Subject: [PATCH 33/45] Flash attention fwd store LSE (#65) * add lse parameters to kernel * add store lse in kernel * add lse host ref and check result * add parameter to control store lse or not * fix kernel template kStoreLSE value * move lse store to pipeline * fix output err info * fix output err info2 * change lse 4 dim to 3 dim * mask storing lse in example * remove divide in kernel * remove pointer in function reference_batched_softmax * set LSE is template parameter in FmhaFwdKernelSelector * remove parameter stride_lse * fix bug for using nullopt in function reference_batched_softmax --------- Co-authored-by: letaoqin --- example/91_tile_program/fmha/fmha_fwd.cpp | 155 ++++++++++++------ example/91_tile_program/fmha/fmha_fwd.hpp | 41 ++++- .../91_tile_program/fmha/fmha_fwd_kernel.hpp | 76 ++++++++- .../91_tile_program/fmha/generate_kernels.py | 23 ++- .../reference/reference_batched_softmax.hpp | 11 +- .../block_fmha_pipeline_problem.hpp | 3 + .../block_fmha_pipeline_qkvs.hpp | 2 +- .../block_fmha_pipeline_qr_ks_vs.hpp | 39 ++++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 43 ++++- ...k_fmha_pipeline_qr_ks_vs_custom_policy.hpp | 2 +- .../ck/tile_program/tile/tile_fmha_traits.hpp | 2 + 11 files changed, 324 insertions(+), 73 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index 6216ea939..e41478f42 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -63,6 +63,7 @@ auto create_args(int argc, char* argv[]) "'t:l,r', top-left local-attn with left right size\n" "'b:l,r', bottom-r local-attn with left right size\n" "'g:y,x', generic attention mask coordinate with y/x size\n") + .insert("lse", "0", "0 not store lse, 1 store lse") .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float"); bool result = arg_parser.parse(argc, argv); @@ -79,9 +80,10 @@ struct fmha_fwd_kernel_invoker mode_enum mode; bool use_bias; mask_info mask; + bool store_lse; - fmha_fwd_kernel_invoker(mode_enum mode_, bool use_bias_, mask_info mask_) - : mode(mode_), use_bias(use_bias_), mask(mask_) + fmha_fwd_kernel_invoker(mode_enum mode_, bool use_bias_, mask_info mask_, bool store_lse_) + : mode(mode_), use_bias(use_bias_), mask(mask_), store_lse(store_lse_) { } @@ -89,30 +91,40 @@ struct fmha_fwd_kernel_invoker float operator()(const StreamConfig& stream, Args&&... 
args) { float ave_time; - BOOL_SWITCH_2(mode == mode_enum::group, kIsGroupMode, use_bias, kHasBias, [&] { - if(mask.type == mask_enum::no_mask) - { - using FmhaMask = FmhaMasks::NoMask; - using Kernel = - FmhaFwdKernelSelector; - - auto [kargs, grids] = - fmha_fwd_create_kargs_and_grids(std::forward(args)...); - ave_time = fmha_fwd_run(stream, kargs, grids); - } - else - { - BOOL_SWITCH(mask.type == mask_enum::window_generic, kIsLocal, [&]() { - using FmhaMask = ck::tile_program::block::GenericAttentionMask; - using Kernel = - FmhaFwdKernelSelector; + BOOL_SWITCH_3( + mode == mode_enum::group, kIsGroupMode, use_bias, kHasBias, store_lse, kStoreLSE, [&] { + if(mask.type == mask_enum::no_mask) + { + using FmhaMask = FmhaMasks::NoMask; + using Kernel = FmhaFwdKernelSelector; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(std::forward(args)...); ave_time = fmha_fwd_run(stream, kargs, grids); - }); - } - }); + } + else + { + BOOL_SWITCH(mask.type == mask_enum::window_generic, kIsLocal, [&]() { + using FmhaMask = + ck::tile_program::block::GenericAttentionMask; + using Kernel = FmhaFwdKernelSelector; + + auto [kargs, grids] = + fmha_fwd_create_kargs_and_grids(std::forward(args)...); + ave_time = fmha_fwd_run(stream, kargs, grids); + }); + } + }); return ave_time; } }; @@ -178,6 +190,8 @@ bool run(const ArgParser& arg_parser) bool use_bias = arg_parser.get_uint32("bias"); + bool store_lse = arg_parser.get_uint32("lse"); + mask_info mask = decode_mask_info(arg_parser.get_str("mask"), seqlen_q, seqlen_k); int init_method = arg_parser.get_int("init"); @@ -197,6 +211,7 @@ bool run(const ArgParser& arg_parser) using KDataType = typename TypeConfig::KDataType; using VDataType = typename TypeConfig::VDataType; using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; using SaccDataType = typename TypeConfig::SaccDataType; using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; using PDataType = typename TypeConfig::PDataType; @@ -260,6 +275,11 @@ bool run(const ArgParser& arg_parser) Tensor bias_host( use_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] + Tensor lse_host( + store_lse ? 
std::array{shape_batch, nhead, shape_seqlen_q} + : std::array{1, 1, 1} /* dummy shape for simplifying code */); + Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); if(init_method == 0) @@ -288,6 +308,7 @@ bool run(const ArgParser& arg_parser) DeviceMem k_buf(k_host.GetElementSpaceSizeInBytes()); DeviceMem v_buf(v_host.GetElementSpaceSizeInBytes()); DeviceMem bias_buf(bias_host.GetElementSpaceSizeInBytes()); + DeviceMem lse_buf(lse_host.GetElementSpaceSizeInBytes()); DeviceMem o_buf(o_host.GetElementSpaceSizeInBytes()); DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); @@ -314,31 +335,34 @@ bool run(const ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias - << ", mask:" << mask << ", v:" << std::string(VLayout::name)[0] << std::flush; - -#define INVOKE_FMHA_KERNEL(hdim_) \ - fmha_fwd_kernel_invoker{mode, use_bias, mask}(stream_config, \ - q_buf.GetDeviceBuffer(), \ - k_buf.GetDeviceBuffer(), \ - v_buf.GetDeviceBuffer(), \ - bias_buf.GetDeviceBuffer(), \ - o_buf.GetDeviceBuffer(), \ - seqstart_q.GetDeviceBuffer(), \ - seqstart_k.GetDeviceBuffer(), \ - nullptr, \ - batch, \ - nhead, \ - nhead_k, \ - shape_seqlen_q, \ - shape_seqlen_k, \ - hdim_q, \ - hdim_v, \ - max_seqlen_q, \ - scale, \ - i_perm, \ - o_perm, \ - mask.y, \ - mask.x) + << ", lse:" << store_lse << ", mask:" << mask + << ", v:" << std::string(VLayout::name)[0] << std::flush; + +#define INVOKE_FMHA_KERNEL(hdim_) \ + fmha_fwd_kernel_invoker{mode, use_bias, mask, store_lse}( \ + stream_config, \ + q_buf.GetDeviceBuffer(), \ + k_buf.GetDeviceBuffer(), \ + v_buf.GetDeviceBuffer(), \ + bias_buf.GetDeviceBuffer(), \ + lse_buf.GetDeviceBuffer(), \ + o_buf.GetDeviceBuffer(), \ + seqstart_q.GetDeviceBuffer(), \ + seqstart_k.GetDeviceBuffer(), \ + nullptr, \ + batch, \ + nhead, \ + nhead_k, \ + shape_seqlen_q, \ + shape_seqlen_k, \ + hdim_q, \ + hdim_v, \ + max_seqlen_q, \ + scale, \ + i_perm, \ + o_perm, \ + mask.y, \ + mask.x) float ave_time = 0; const auto check_hdims = [](ck::index_t hdim_q_, ck::index_t hdim_v_, ck::index_t threshold) { @@ -380,6 +404,7 @@ bool run(const ArgParser& arg_parser) } o_buf.FromDevice(o_host.data()); + lse_buf.FromDevice(lse_host.data()); bool pass = true; @@ -405,6 +430,7 @@ bool run(const ArgParser& arg_parser) Tensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); Tensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + Tensor lse_host_ref({nhead, real_seqlen_q}); ck::index_t nr = nhead / nhead_k; @@ -455,7 +481,13 @@ bool run(const ArgParser& arg_parser) reference_batched_masking(s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); } - reference_batched_softmax(s_host_ref, p_host_ref); + if(store_lse){ + reference_batched_softmax(s_host_ref, p_host_ref, lse_host_ref); + } + else{ + reference_batched_softmax(s_host_ref, p_host_ref); + } + reference_batched_gemm(p_host_ref, v_host_ref, o_host_ref); Tensor o_host_result({nhead, real_seqlen_q, hdim_v}); @@ -466,11 +498,11 @@ bool run(const ArgParser& arg_parser) auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck::utils::check_err( - o_host_result, o_host_ref, std::string("Error: Incorrect results!"), rtol, atol); + o_host_result, o_host_ref, std::string("O Error: Incorrect 
results!"), rtol, atol); pass &= cur_pass; if(!cur_pass) { - std::cerr << "mismatch found at batch: " << wb << std::endl + std::cerr << "O mismatch found at batch: " << wb << std::endl << "\tseqlen_q: " << real_seqlen_q << std::endl << "\tseqlen_k: " << real_seqlen_k << std::endl << "\tseqstart_q: " << seqstart_q_host << std::endl @@ -478,6 +510,29 @@ bool run(const ArgParser& arg_parser) break; } + + if(store_lse) + { + Tensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b, idx[0], idx[1] + query_offset); + }); + + bool lse_pass = ck::utils::check_err( + lse_host_result, lse_host_ref, "LSE Error: Incorrect results!", rtol, atol); + + pass &= lse_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } } std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; diff --git a/example/91_tile_program/fmha/fmha_fwd.hpp b/example/91_tile_program/fmha/fmha_fwd.hpp index 68e1a9a12..8021bbe56 100644 --- a/example/91_tile_program/fmha/fmha_fwd.hpp +++ b/example/91_tile_program/fmha/fmha_fwd.hpp @@ -29,6 +29,7 @@ struct FmhaFwdTypeConfig using KDataType = ck::half_t; using VDataType = ck::half_t; using BiasDataType = ck::half_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using SaccDataType = float; // data type for first gemm accumulation using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck::half_t; // data type for A matrix of second gemm @@ -43,6 +44,7 @@ struct FmhaFwdTypeConfig using KDataType = ck::bhalf_t; using VDataType = ck::bhalf_t; using BiasDataType = ck::bhalf_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using SaccDataType = float; // data type for first gemm accumulation using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck::bhalf_t; // data type for A matrix of second gemm @@ -117,14 +119,20 @@ struct FmhaShape { }; -template +template using FmhaTraits = ck::tile_program::TileFmhaTraits; -template +template using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, @@ -132,6 +140,7 @@ using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::SaccDataType, typename FmhaFwdTypeConfig::SMPLComputeDataType, typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, typename FmhaFwdTypeConfig::ODataType, @@ -139,21 +148,31 @@ using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem< FmhaShape, kIsGroupMode, FmhaMask, - FmhaTraits>; + FmhaTraits>; -template +template using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync< - FmhaPipelineProblem>; + FmhaPipelineProblem>; template using FmhaEpilogue = FmhaFwdEpilogue::OaccDataType, typename FmhaFwdTypeConfig::ODataType>>; -template +template using FmhaFwdKernelSelector = FmhaFwdKernel>, - FmhaPipeline, + FmhaPipeline, FmhaEpilogue>; // Kernel API @@ -162,6 +181,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, const void* k_ptr, const void* 
v_ptr, const void* bias_ptr, + void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, @@ -208,12 +228,14 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, return i_perm ? hdim_v * seqlen_k : seqlen_k; }(); const ck::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k); + const ck::index_t nhead_stride_lse = (seqlen_q * 1); const ck::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck::index_t batch_stride_q = (nhead * seqlen_q * hdim_q); const ck::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q); const ck::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k); const ck::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k); + const ck::index_t batch_stride_lse = (nhead * seqlen_q * 1); const ck::index_t batch_stride_o = (nhead * seqlen_q * hdim_v); auto kargs = [&] { @@ -224,6 +246,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, k_ptr, v_ptr, bias_ptr, + lse_ptr, o_ptr, seqstart_q_ptr, seqstart_k_ptr, @@ -241,6 +264,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, nhead_stride_k, nhead_stride_v, nhead_stride_bias, + nhead_stride_lse, nhead_stride_o, mask_y, mask_x); @@ -251,6 +275,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, k_ptr, v_ptr, bias_ptr, + lse_ptr, o_ptr, seqlen_q, seqlen_k, @@ -267,11 +292,13 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, nhead_stride_k, nhead_stride_v, nhead_stride_bias, + nhead_stride_lse, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_bias, + batch_stride_lse, batch_stride_o, mask_y, mask_x); diff --git a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp index a326ca1aa..b49fde465 100644 --- a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -28,6 +28,7 @@ struct FmhaFwdKernel using KDataType = ck::remove_cvref_t; using VDataType = ck::remove_cvref_t; using BiasDataType = ck::remove_cvref_t; + using LSEDataType = ck::remove_cvref_t; using ODataType = ck::remove_cvref_t; using VLayout = ck::remove_cvref_t; @@ -37,6 +38,7 @@ struct FmhaFwdKernel static constexpr bool kN0K1NeedPadding = FmhaPipeline::kN0K1NeedPadding; static constexpr bool kK0N1NeedPadding = FmhaPipeline::kK0N1NeedPadding; static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; using FmhaMask = ck::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -93,10 +95,22 @@ struct FmhaFwdKernel ck::index_t mask_y, mask_x; }; + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck::index_t nhead_stride_lse = 0; + }; + + struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs + { + ck::index_t batch_stride_lse = 0; + }; + struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck::index_t batch_stride_q; ck::index_t batch_stride_k; @@ -107,7 +121,8 @@ struct FmhaFwdKernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -121,6 +136,7 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* lse_ptr, void* o_ptr, ck::index_t seqlen_q, ck::index_t seqlen_k, @@ -137,11 +153,13 @@ struct FmhaFwdKernel ck::index_t 
nhead_stride_k, ck::index_t nhead_stride_v, ck::index_t nhead_stride_bias, + ck::index_t nhead_stride_lse, ck::index_t nhead_stride_o, ck::index_t batch_stride_q, ck::index_t batch_stride_k, ck::index_t batch_stride_v, ck::index_t batch_stride_bias, + ck::index_t batch_stride_lse, ck::index_t batch_stride_o, ck::index_t mask_y, ck::index_t mask_x) @@ -170,6 +188,7 @@ struct FmhaFwdKernel nhead_stride_o}, // args for common karg {}, // placeholder for bias {}, // placeholder for mask + {}, // placeholder for lse batch_stride_q, batch_stride_k, batch_stride_v, @@ -188,6 +207,12 @@ struct FmhaFwdKernel kargs.mask_y = mask_y; kargs.mask_x = mask_x; } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } return kargs; } @@ -197,6 +222,7 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, @@ -214,6 +240,7 @@ struct FmhaFwdKernel ck::index_t nhead_stride_k, ck::index_t nhead_stride_v, ck::index_t nhead_stride_bias, + ck::index_t nhead_stride_lse, ck::index_t nhead_stride_o, ck::index_t mask_y, ck::index_t mask_x) @@ -242,6 +269,7 @@ struct FmhaFwdKernel nhead_stride_o}, // args for common karg {}, // placeholder for bias {}, // placeholder for mask + {}, // placeholder for lse reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -257,6 +285,11 @@ struct FmhaFwdKernel kargs.mask_y = mask_y; kargs.mask_x = mask_x; } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } return kargs; } @@ -296,6 +329,7 @@ struct FmhaFwdKernel long_index_t batch_offset_k = 0; long_index_t batch_offset_v = 0; long_index_t batch_offset_bias = 0; + long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) @@ -322,6 +356,10 @@ struct FmhaFwdKernel { batch_offset_bias = key_start; } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode @@ -354,6 +392,10 @@ struct FmhaFwdKernel { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } @@ -542,6 +584,35 @@ struct FmhaFwdKernel } }(); + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(Number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = + make_naive_tensor_view(lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + Number<1>{}, + Number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, Sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + FmhaMask mask = [&]() { if constexpr(kHasMask) return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k}; @@ -554,6 +625,7 @@ struct FmhaFwdKernel k_dram_window, v_dram_window, bias_dram_window, + lse_dram_window, mask, kargs.scale, // ck::math::integer_divide_ceil(kargs.seqlen_k, 
FmhaPipeline::kN0), diff --git a/example/91_tile_program/fmha/generate_kernels.py b/example/91_tile_program/fmha/generate_kernels.py index a7b8ebda5..eb71b2829 100644 --- a/example/91_tile_program/fmha/generate_kernels.py +++ b/example/91_tile_program/fmha/generate_kernels.py @@ -23,17 +23,23 @@ "group" : "true" } +LSE_MAP = { + "no" : "false", + "store" : "true" +} + HDIMS = [32, 64, 128] MASKS = ["no", "causal", "generic"] DIRECTIONS = ["fwd"] GEN_DIR = "" +LSES=["no","store"] KERNEL_IMPL_TEMPLATE_FWD = """// auto generated by generate_kernels.py #include "fmha_fwd.hpp" -using kernel_0 = FmhaFwdKernelSelector<{HDIM}, {DTYPE}, {MODE}, {MASK}, true>; +using kernel_0 = FmhaFwdKernelSelector<{HDIM}, {DTYPE}, {MODE}, {MASK}, true, {LSE}>; FMHA_FWD_KERNEL_DEFINE(kernel_0) -using kernel_1 = FmhaFwdKernelSelector<{HDIM}, {DTYPE}, {MODE}, {MASK}, false>; +using kernel_1 = FmhaFwdKernelSelector<{HDIM}, {DTYPE}, {MODE}, {MASK}, false, {LSE}>; FMHA_FWD_KERNEL_DEFINE(kernel_1) """ @@ -43,27 +49,30 @@ def __init__(self, hdim: int, dtype: str, mode : str, - mask : str): + mask : str, + lse:str): self.direction = direction self.hdim = hdim self.dtype = dtype self.mode = mode self.mask = mask + self.lse = lse @property def template(self) -> str: if self.direction == "fwd": return KERNEL_IMPL_TEMPLATE_FWD.format( HDIM=self.hdim, DTYPE=DTYPE_MAP[self.dtype], - MODE=MODE_MAP[self.mode], MASK=MASK_MAP[self.mask]) + MODE=MODE_MAP[self.mode], MASK=MASK_MAP[self.mask], + LSE=LSE_MAP[self.lse]) @property def filename(self) -> str: - return f"fmha_{self.direction}_hdim{self.hdim}_{self.dtype}_{self.mode}_{self.mask}_mask.cpp" + return f"fmha_{self.direction}_hdim{self.hdim}_{self.dtype}_{self.mode}_{self.mask}_mask_{self.lse}_lse.cpp" def get_all_kernels() -> List[Kernel]: - for direction, hdim, dtype, mode, mask in itertools.product(DIRECTIONS, HDIMS, DTYPE_MAP.keys(), MODE_MAP.keys(), MASK_MAP.keys()): - yield Kernel(direction=direction, hdim=hdim, dtype=dtype, mode=mode, mask=mask) + for direction, hdim, dtype, mode, mask, lse in itertools.product(DIRECTIONS, HDIMS, DTYPE_MAP.keys(), MODE_MAP.keys(), MASK_MAP.keys(), LSE_MAP.keys()): + yield Kernel(direction=direction, hdim=hdim, dtype=dtype, mode=mode, mask=mask, lse=lse) def write_single_kernel(kernel: Kernel, autogen_dir: Path) -> None: credit = """// SPDX-License-Identifier: MIT diff --git a/example/91_tile_program/reference/reference_batched_softmax.hpp b/example/91_tile_program/reference/reference_batched_softmax.hpp index e707db576..0f5447cff 100644 --- a/example/91_tile_program/reference/reference_batched_softmax.hpp +++ b/example/91_tile_program/reference/reference_batched_softmax.hpp @@ -3,11 +3,15 @@ #pragma once +#include #include "ck/utility/common_header.hpp" #include "ck/library/utility/host_tensor.hpp" template -void reference_batched_softmax(const Tensor& a_b_m_n, Tensor& b_b_m_n) +void reference_batched_softmax( + const Tensor& a_b_m_n, + Tensor& b_b_m_n, + std::optional>> lse_b_m = std::nullopt) { const int N = a_b_m_n.mDesc.GetLengths()[2]; @@ -43,6 +47,11 @@ void reference_batched_softmax(const Tensor& a_b_m_n, Tensor(ck::math::exp(v_a - v_max) * inv_sum); } + // lse + if(lse_b_m) + { + lse_b_m->get()(batch, m) = v_max + ck::math::log(v_exp_sum); + } }; make_ParallelTensorFunctor(f, b_b_m_n.mDesc.GetLengths()[0], b_b_m_n.mDesc.GetLengths()[1])( diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp index e0c00939f..1a379f32e 
100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp @@ -17,6 +17,7 @@ template ; using SMPLComputeDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; @@ -51,6 +53,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kN0K1NeedPadding = Traits::kN0K1NeedPadding; static constexpr bool kK0N1NeedPadding = Traits::kK0N1NeedPadding; static constexpr bool kHasBias = Traits::kHasBias; + static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp index 934c5c90f..aa6642393 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp @@ -314,7 +314,7 @@ struct BlockFmhaPipelineQKVS }); }); - return o_acc; + return ck::make_tuple(o_acc, m, l); } template ; using SMPLComputeDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; @@ -57,6 +58,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; __host__ __device__ static constexpr ck::index_t GetSmemSize() { @@ -67,10 +69,12 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, - typename BiasElementFunction> + typename BiasElementFunction, + typename LSEElementFunction> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -80,6 +84,8 @@ struct BlockFmhaPipelineQRKSVS const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, FmhaMask mask, float scale, void* smem_ptr) const @@ -432,6 +438,31 @@ struct BlockFmhaPipelineQRKSVS } } while(++i_total_loops < num_total_loop); + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + constexpr auto lse_spans = decltype(lse)::GetDistributedSpans(); + sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(is_null_tile_window(bias_dram_window)) + { + lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + // finally, O constexpr auto o_spans = 
decltype(o_acc)::GetDistributedSpans(); @@ -457,12 +488,14 @@ struct BlockFmhaPipelineQRKSVS template + typename BiasDramBlockWindowTmp, + typename LSEDramBlockWindowTmp> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale, void* smem_ptr) const @@ -475,6 +508,8 @@ struct BlockFmhaPipelineQRKSVS identity{}, bias_dram_block_window_tmp, identity{}, + lse_dram_block_window_tmp, + identity{}, mask, scale, smem_ptr); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index f3c02e8f9..c527ab342 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -33,6 +33,7 @@ struct BlockFmhaPipelineQRKSVSAsync using SaccDataType = remove_cvref_t; using SMPLComputeDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; @@ -57,6 +58,11 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + +#if CK_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / math::log2e_v; +#endif __host__ __device__ static constexpr ck::index_t GetSmemSize() { @@ -67,10 +73,12 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, - typename BiasElementFunction> + typename BiasElementFunction, + typename LSEElementFunction> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -80,6 +88,8 @@ struct BlockFmhaPipelineQRKSVSAsync const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, FmhaMask mask, float scale, void* smem_ptr) const @@ -509,6 +519,31 @@ struct BlockFmhaPipelineQRKSVSAsync } } while(i_total_loops < num_total_loop); + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + constexpr auto lse_spans = decltype(lse)::GetDistributedSpans(); + sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(is_null_tile_window(bias_dram_window)) + { + lse(i_idx) = m_[i_idx] * scale * R_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * R_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, 
tile_elementwise_in(lse_element_func, lse)); + } + // finally, O constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); @@ -534,12 +569,14 @@ struct BlockFmhaPipelineQRKSVSAsync template + typename BiasDramBlockWindowTmp, + typename LSEDramBlockWindowTmp> __host__ __device__ auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, float scale, void* smem_ptr) const @@ -552,6 +589,8 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, bias_dram_block_window_tmp, identity{}, + lse_dram_block_window_tmp, + identity{}, mask, scale, smem_ptr); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp index 0fb576590..04a3e8f1e 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp @@ -686,7 +686,7 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy TileGemmShape>; - + constexpr auto warp_gemm = []() { if constexpr(is_same_v && is_same_v && diff --git a/include/ck/tile_program/tile/tile_fmha_traits.hpp b/include/ck/tile_program/tile/tile_fmha_traits.hpp index 697c485d1..33fbcafee 100644 --- a/include/ck/tile_program/tile/tile_fmha_traits.hpp +++ b/include/ck/tile_program/tile/tile_fmha_traits.hpp @@ -12,6 +12,7 @@ template struct TileFmhaTraits { @@ -19,6 +20,7 @@ struct TileFmhaTraits static constexpr bool kN0K1NeedPadding = kN0K1NeedPadding_; static constexpr bool kK0N1NeedPadding = kK0N1NeedPadding_; static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLSE = kStoreLSE_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; From 73166db6920afac53189098acf4774f9fa929143 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 19 Jan 2024 19:29:57 +0800 Subject: [PATCH 34/45] Support head dim = 256 for fMHA (#70) * Add new fmha pipeline: BlockFmhaPipelineQSKSVS * Update tile size for hdim=256 * Use BlockFmhaPipelineQRKSVS for hdim=256 * Revert "Update tile size for hdim=256" This reverts commit 8cd70c2d1b683d5c7ecef10ae3f432574c584fa4. 
* Remove FMHA_FWD_SUPPORT_HDIM_256 option and run hdim=256 tests --- example/91_tile_program/fmha/CMakeLists.txt | 15 +- example/91_tile_program/fmha/fmha_fwd.cpp | 4 + example/91_tile_program/fmha/fmha_fwd.hpp | 16 + .../91_tile_program/fmha/generate_kernels.py | 2 +- .../91_tile_program/fmha/script/benchmark.sh | 28 +- .../91_tile_program/fmha/script/smoke_test.sh | 2 +- .../block_gemm_asmem_bsmem_creg_v1.hpp | 150 +---- ...gemm_asmem_bsmem_creg_v1_custom_policy.hpp | 22 +- .../block_fmha_pipeline_qkvs.hpp | 348 ------------ ...lock_fmha_pipeline_qkvs_default_policy.hpp | 264 --------- .../block_fmha_pipeline_qr_ks_vs.hpp | 1 + .../block_fmha_pipeline_qr_ks_vs_async.hpp | 1 + ...pipeline_qr_ks_vs_async_default_policy.hpp | 8 +- ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 8 +- .../block_fmha_pipeline_qs_ks_vs.hpp | 532 ++++++++++++++++++ ..._fmha_pipeline_qs_ks_vs_default_policy.hpp | 22 + ..._fmha_pipeline_qx_ks_vs_custom_policy.hpp} | 275 ++++++--- 17 files changed, 827 insertions(+), 871 deletions(-) delete mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp delete mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp create mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp create mode 100644 include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp rename include/ck/tile_program/block_tile_pipeline/{block_fmha_pipeline_qr_ks_vs_custom_policy.hpp => block_fmha_pipeline_qx_ks_vs_custom_policy.hpp} (86%) diff --git a/example/91_tile_program/fmha/CMakeLists.txt b/example/91_tile_program/fmha/CMakeLists.txt index c10947e64..d881d2f2e 100644 --- a/example/91_tile_program/fmha/CMakeLists.txt +++ b/example/91_tile_program/fmha/CMakeLists.txt @@ -14,9 +14,10 @@ add_custom_command( --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) -add_example_executable(example_fmha_fwd fmha_fwd.cpp) -target_include_directories(example_fmha_fwd PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(example_fmha_fwd PRIVATE ${FMHA_FWD_GEN_KERNELS}) +set(EXAMPLE_FMHA_FWD "example_fmha_fwd") +add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_KERNELS}) # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this @@ -24,10 +25,14 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2) set(FMHA_FWD_FAST_EXP2 true) endif() +set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... 
because they are auto-generated if(FMHA_FWD_FAST_EXP2) -target_compile_options(example_fmha_fwd PRIVATE "-Wno-undefined-func-template;-DCK_FMHA_FWD_FAST_EXP2=1;-fgpu-flush-denormals-to-zero") +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() -target_compile_options(example_fmha_fwd PRIVATE "-Wno-undefined-func-template;-DCK_FMHA_FWD_FAST_EXP2=0") +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=0) endif() + +target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index e41478f42..f9813bcdf 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -383,6 +383,10 @@ bool run(const ArgParser& arg_parser) { ave_time = INVOKE_FMHA_KERNEL(128); } + else if(check_hdims(hdim_q, hdim_v, 256)) + { + ave_time = INVOKE_FMHA_KERNEL(256); + } else { std::cerr << "not support hdim, will not run" << std::endl; diff --git a/example/91_tile_program/fmha/fmha_fwd.hpp b/example/91_tile_program/fmha/fmha_fwd.hpp index 8021bbe56..43dac7967 100644 --- a/example/91_tile_program/fmha/fmha_fwd.hpp +++ b/example/91_tile_program/fmha/fmha_fwd.hpp @@ -12,6 +12,7 @@ #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck/tile_program/tile/tile_fmha_shape.hpp" #include "ck/tile_program/tile/tile_fmha_traits.hpp" @@ -82,6 +83,10 @@ template <> struct FmhaBlockTile : ck::Sequence<128, 128, 32, 128, 32, 128> { }; +template <> +struct FmhaBlockTile : ck::Sequence<128, 128, 32, 256, 32, 256> +{ +}; using FmhaBlockWarps = ck::Sequence<4, 1, 1>; using FmhaWarpTile = ck::Sequence<32, 32, 16>; @@ -119,6 +124,17 @@ struct FmhaShape { }; +template <> +struct FmhaShape + : ck::tile_program::TileFmhaShape, + FmhaBlockWarps, + FmhaWarpTile, + FmhaBlockWarps, + FmhaWarpTile, + VLayout> +{ +}; + template using FmhaTraits = ck::tile_program::TileFmhaTraits - __device__ auto operator()(const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + __device__ constexpr auto MakeCBlockTile() const { - static_assert(is_same_v && - is_same_v, - "wrong!"); - - constexpr index_t MPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.GetWindowLengths()[Number<0>{}]; - constexpr index_t KPerBlock = ABlockWindowTmp{}.GetWindowLengths()[Number<1>{}]; - - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -203,86 +191,7 @@ struct BlockGemmASmemBSmemCRegV1 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - - constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; - constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / 
KIterPerWarp; - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() % NWarp; - - // construct A-warp-window - auto a_warp_window_tmp = make_tile_window( - a_block_window_tmp.GetBottomTensorView(), - make_tuple(Number{}, Number{}), - a_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - -#if 0 // FIXME: using Array will cause register spill - Array, MIterPerWarp> a_warp_windows{ - {a_warp_window_tmp}}; - - for(index_t mIter = 0; mIter < MIterPerWarp; mIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - StaticallyIndexedArray, - MIterPerWarp> - a_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window_tmp.GetBottomTensorView(), - make_tuple(Number{}, Number{}), - b_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iNWarp * WG::kN, 0}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - -#if 0 // FIXME: using Array will cause register spill - Array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; - - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - StaticallyIndexedArray, - NIterPerWarp> - b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - static_assert(is_same_v, "wrong!"); - - // Construct C-Block-Tensor constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding< Sequence<>, Tuple, Sequence>, @@ -297,51 +206,16 @@ struct BlockGemmASmemBSmemCRegV1 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - // warp GEMM - if constexpr(KIterPerWarp == 0) - { - // c = a * b - c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor); - } - else - { - // c += a * b - c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetYSlicedThreadData( - merge_sequences(Sequence{}, 
c_warp_y_index_zeros), - merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths)); - - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - } - - // write C warp tensor into C block tensor - c_block_tensor.SetYSlicedThreadData( - merge_sequences(Sequence{}, c_warp_y_index_zeros), - merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.GetThreadBuffer()); - }); - }); - }); - + // C = A * B + template + __device__ auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); return c_block_tensor; } }; diff --git a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index d22d702cb..dfe545bd3 100644 --- a/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -23,34 +23,26 @@ template + typename WarpGemm_> struct BlockGemmASmemBSmemCRegV1CustomPolicy { using AType = remove_cvref_t; using BType = remove_cvref_t; using CType = remove_cvref_t; - using BlockWarps = remove_cvref_t; - using WarpTile = remove_cvref_t; - static constexpr index_t BlockMWarps = BlockWarps::At(Number<0>{}); - static constexpr index_t BlockNWarps = BlockWarps::At(Number<1>{}); - static constexpr index_t BlockKWarps = BlockWarps::At(Number<2>{}); + using BlockWarps = remove_cvref_t; - static constexpr index_t MPerWarp = WarpTile::At(Number<0>{}); - static constexpr index_t NPerWarp = WarpTile::At(Number<1>{}); - static constexpr index_t KPerWarp = WarpTile::At(Number<2>{}); + static constexpr index_t kMWarps = BlockWarps::At(Number<0>{}); + static constexpr index_t kNWarps = BlockWarps::At(Number<1>{}); + static constexpr index_t kKWarps = BlockWarps::At(Number<2>{}); - static constexpr bool TranposeC = TranposeC_; - - using WarpGemm = ck::tile_program::warp:: - WarpGemmMfmaDispatcher; + using WarpGemm = remove_cvref_t; template __host__ __device__ static constexpr auto GetWarpGemmMWarpNWarp() { using namespace ck::tile_program::warp; - return make_tuple(WarpGemm{}, BlockMWarps, BlockNWarps); + return make_tuple(WarpGemm{}, kMWarps, kNWarps); } }; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp deleted file mode 100644 index aa6642393..000000000 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp +++ /dev/null @@ -1,348 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" - -#include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/load_tile.hpp" -#include "ck/tile_program/tile/store_tile.hpp" -#include "ck/tile_program/tile/tile_elementwise.hpp" -#include "ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/tile/slice_tile.hpp" -#include "ck/tile_program/warp_tile/warp_gemm.hpp" -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp" -#include "ck/tile_program/block_tile/block_reduce.hpp" - -namespace ck { -namespace tile_program { -namespace block { - -// This pipeline is qkv all located in LDS -template -struct BlockFmhaPipelineQKVS -{ - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - - using BlockFmhaShape = remove_cvref_t; - static constexpr bool kQLoadOnce = false; // if q load whole block length (hdim) at once - - static constexpr index_t kBlockSize = Problem::kBlockSize; - - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kN1 = BlockFmhaShape::kN1; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - - __host__ __device__ static constexpr ck::index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - __host__ __device__ auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, - float scale, - index_t num_total_loop, - index_t num_sub_loop_qk, - void* smem_ptr) const - { - static_assert( - is_same_v> && - is_same_v> && - is_same_v>, - "wrong!"); - - static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && - kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && - kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && - kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], - "wrong!"); - - // Q tile in LDS - auto q_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_window = - make_tile_window(q_lds, make_tuple(Number{}, Number{}), {0, 0}); - - // K tile in LDS - KDataType* k_lds_ptr = static_cast(static_cast( - static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); - auto k_lds = make_tensor_view( - k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); - auto k_lds_window = - make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); - - // V tile in LDS - auto v_lds = make_tensor_view( - reinterpret_cast(smem_ptr), - Policy::template MakeVLdsBlockDescriptor()); - auto v_lds_window = - make_tile_window(v_lds, make_tuple(Number{}, Number{}), {0, 0}); - - // Block GEMM - constexpr auto gemm_0 = 
Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - - auto s_acc = decltype(gemm_0(q_lds_window, k_lds_window)){}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = - decltype(tile_elementwise_in(type_convert, s_acc)); - - using PBlockTileType = - decltype(tile_elementwise_in(type_convert, s_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); - - using OaccBlockTileType = decltype(gemm_1( - get_slice_tile(PBlockTileType{}, Sequence<0, 0>{}, Sequence{}), - v_lds_window)); - - // init Oacc, M, L - auto o_acc = OaccBlockTileType{}; - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; - - tile_elementwise_inout([](auto& e) { e = 0; }, o_acc); - tile_elementwise_inout([](auto& e) { e = NumericLimits::Lowest(); }, - m); - tile_elementwise_inout([](auto& e) { e = 0; }, l); - - auto k_dram_block_window = k_dram_block_window_tmp; - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), - v_dram_block_window_tmp.GetWindowLengths(), - v_dram_block_window_tmp.GetWindowOrigin(), - Policy::template MakeVDramTileDistribution()); - - index_t i_total_loops = 0; - do - { - // STAGE 1, QK gemm - auto q_dram_window = make_tile_window( - q_dram_block_window_tmp.GetBottomTensorView(), - q_dram_block_window_tmp.GetWindowLengths(), - q_dram_block_window_tmp.GetWindowOrigin(), - Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for - // load - - auto k_dram_window = make_tile_window( - k_dram_block_window.GetBottomTensorView(), - k_dram_block_window.GetWindowLengths(), - k_dram_block_window.GetWindowOrigin(), - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load - - auto q_block_tile = load_tile(q_dram_window); // prefetch, global read 0 - auto k_block_tile = load_tile(k_dram_window); - { - move_tile_window(q_dram_window, {0, kK0}); // move to 1 - move_tile_window(k_dram_window, {0, kK0}); - - tile_elementwise_inout([](auto& c) { c = 0; }, s_acc); // Initialize C - - store_tile(q_lds_window, - tile_elementwise_in(q_element_func, q_block_tile)); // LDS write 0 - q_block_tile = load_tile(q_dram_window); // global read 1 - store_tile(k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write 0 - k_block_tile = load_tile(k_dram_window); // global read 1 - } - - index_t i_k0_loops = num_sub_loop_qk - 2; - do - { - block_sync_lds(); - gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM i - block_sync_lds(); - - move_tile_window(q_dram_window, {0, kK0}); // move to i + 2 - move_tile_window(k_dram_window, {0, kK0}); - - store_tile(q_lds_window, - tile_elementwise_in(q_element_func, q_block_tile)); // LDS write i + 1 - q_block_tile = load_tile(q_dram_window); // global read i + 2 - store_tile(k_lds_window, - tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 - k_block_tile = load_tile(k_dram_window); // global read i + 2 - - i_k0_loops--; - } while(i_k0_loops > 0); - - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile - - { // tail - block_sync_lds(); - gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM num_loop - 2 - block_sync_lds(); - - store_tile( - q_lds_window, - tile_elementwise_in(q_element_func, q_block_tile)); // LDS write num_loop - 1 - 
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); - block_sync_lds(); - - gemm_0(s_acc, q_lds_window, k_lds_window); // GEMM num_loop - 1 - } - - // STAGE 2, scale softmax - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); - - const auto s = - tile_elementwise_in(type_convert, s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - Sequence<1>{}, - f_max, - NumericLimits::Lowest()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s.GetTileDistribution()); // Pcompute{j} - - constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); - sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync(rowsum_p, f_sum); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); - sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? - o_acc(i_j_idx) *= tmp; - }); - }); - - block_sync_lds(); - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch - move_tile_window(v_dram_window, {0, kK1}); - - const auto p = - tile_elementwise_in(type_convert, p_compute); - - // STAGE 3, KV gemm - constexpr index_t k1_loops = kN0 / kK1; - if constexpr(k1_loops > 1) - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, Sequence<0, i_k1 * kK1>{}, Sequence{}), - v_lds_window); - block_sync_lds(); - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v)); // store next v - move_tile_window(v_dram_window, {0, kK1}); - }); - } - // tail - { - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), - v_lds_window); - block_sync_lds(); - } - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - - i_total_loops++; - } while(i_total_loops < num_total_loop); - - // finally, O - constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); - - sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = 1 / l[i_idx]; - sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - - return ck::make_tuple(o_acc, m, l); - } - - template - __host__ __device__ auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - float scale, - index_t num_total_loop, - index_t 
num_sub_loop_qk, - void* smem_ptr) const - { - return operator()( - q_dram_block_window_tmp, - [](const QDataType& x) { return x; }, - k_dram_block_window_tmp, - [](const KDataType& x) { return x; }, - v_dram_block_window_tmp, - [](const VDataType& x) { return x; }, - scale, - num_total_loop, - num_sub_loop_qk, - smem_ptr); - } -}; - -} // namespace block -} // namespace tile_program -} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp deleted file mode 100644 index 833e0787d..000000000 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp +++ /dev/null @@ -1,264 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" - -#include "ck/tile_program/tile/tile_distribution.hpp" -#include "ck/tile_program/tile/tile_elementwise.hpp" -#include "ck/tile_program/tile/tile_gemm_shape.hpp" -#include "ck/tile_program/warp_tile/warp_gemm.hpp" -#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" -#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp" -#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp" - -namespace ck { -namespace tile_program { -namespace block { - -// This pipeline is qkv all located in LDS -struct BlockFmhaPipelineQKVSDefaultPolicy -{ - // 3d + padding - template - __host__ __device__ static constexpr auto MakeQLdsBlockDescriptor() - { - using namespace ck; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number<8>{}), - make_tuple(Number<(kMPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), - Number<8>{}, - Number<1>{}); - - constexpr auto q_lds_block_desc = transform_tensor_descriptor( - q_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return q_lds_block_desc; - } - - // 3d + padding - template - __host__ __device__ static constexpr auto MakeKLdsBlockDescriptor() - { - using namespace ck; - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number<8>{}), - make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}), - Number<8>{}, - Number<1>{}); - - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return k_lds_block_desc; - } - - // 3d + padding - template - __host__ __device__ static constexpr auto MakeVLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr 
index_t kPad = 1; - constexpr index_t kK1 = 8; - - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, Number{}), - make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number{}, Number<1>{}), - Number{}, - Number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(Number{}, Number{}))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return v_lds_block_desc; - } - - template - __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() - { - constexpr index_t lds_alignment = 16; // optional - constexpr index_t q_smem_size = - ck::math::integer_divide_ceil( - sizeof(typename Problem::QDataType) * - MakeQLdsBlockDescriptor().GetElementSpaceSize(), - lds_alignment) * - lds_alignment; - return q_smem_size; - } - - template - __host__ __device__ static constexpr ck::index_t GetSmemSize() - { - constexpr index_t smem_size_gemm_0 = - GetSmemSizeQ() + sizeof(typename Problem::KDataType) * - MakeKLdsBlockDescriptor().GetElementSpaceSize(); - constexpr index_t smem_size_gemm_1 = - MakeVLdsBlockDescriptor().GetElementSpaceSize() * - sizeof(typename Problem::VDataType); - - // TODO: consider shuffle requirement - return math::max(smem_size_gemm_0, smem_size_gemm_1); - } - - template - __host__ __device__ static constexpr auto MakeQDramTileDistribution() - { - using QDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. 
TODO: change this - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M1 = kMPerBlock / (M2 * M0); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<1, 1>>{}); -#endif - } - - template - __host__ __device__ static constexpr auto MakeKDramTileDistribution() - { - using KDataType = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t K1 = 16 / sizeof(KDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t N0 = kBlockSize / get_warp_size(); - constexpr index_t N1 = kNPerBlock / (N2 * N0); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<1, 1>>{}); -#endif - } - - template - __device__ static constexpr auto MakeVDramTileDistribution() - { - using VDataType = remove_cvref_t; - ; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - - constexpr index_t K1 = 16 / sizeof(VDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<1, 2>>, - Tuple, Sequence<2, 0>>, - Sequence<1, 2>, - Sequence<0, 1>>{}); - } - - template - __host__ __device__ static constexpr auto GetQKBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy; - - return BlockGemmASmemBSmemCRegV1{}; - } - - template - __host__ __device__ static constexpr auto GetKVBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy; - - return BlockGemmARegBSmemCRegV1{}; - } -}; - -} // namespace block -} // namespace tile_program -} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 650964b2f..6c294d856 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ 
b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -42,6 +42,7 @@ struct BlockFmhaPipelineQRKSVS using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index c527ab342..17c35a4c5 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -42,6 +42,7 @@ struct BlockFmhaPipelineQRKSVSAsync using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp index a22f16253..29bcde926 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp @@ -3,7 +3,7 @@ #pragma once -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" namespace ck { namespace tile_program { @@ -11,7 +11,11 @@ namespace block { // This pipeline is qkv all located in LDS using BlockFmhaPipelineQRKSVSAsyncDefaultPolicy = - BlockFmhaPipelineQRKSVSCustomPolicy; + BlockFmhaPipelineQXKSVSCustomPolicy; } // namespace block } // namespace tile_program diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp index df11a8e87..28d7ba2da 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -3,7 +3,7 @@ #pragma once -#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" namespace ck { namespace tile_program { @@ -11,7 +11,11 @@ namespace block { // This pipeline is qkv all located in LDS using BlockFmhaPipelineQRKSVSDefaultPolicy = - BlockFmhaPipelineQRKSVSCustomPolicy; + BlockFmhaPipelineQXKSVSCustomPolicy; } // namespace block } // namespace tile_program diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp new file mode 100644 index 000000000..5ce44dd2f --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -0,0 +1,532 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced 
Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/tile/slice_tile.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" +#include "ck/tile_program/block_tile/block_reduce.hpp" +#include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQSKSVS +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = false; + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; + static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; + static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() + { + return Policy::template GetSmemSizeQ(); + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + FmhaMask mask, + float scale, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == 
KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + // Q tile in LDS + auto q_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().GetLengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, NumericLimits::Lowest()); + clear_tile(l); + + const auto q_origin = q_dram_block_window_tmp.GetWindowOrigin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.At(Number<0>{}), Number{}, Number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.GetBottomTensorView(), + k_dram_block_window_tmp.GetWindowLengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.GetWindowOrigin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.GetBottomTensorView(), + bias_dram_block_window_tmp.GetWindowLengths(), + {bias_origin.At(Number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + {0, seqlen_k_start}, // TODO: hdim split? 
+ Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto q_dram_window = + make_tile_window(q_dram_block_window_tmp.GetBottomTensorView(), + q_dram_block_window_tmp.GetWindowLengths(), + q_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeQDramTileDistribution()); + + auto k_dram_window = + make_tile_window(k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); + + auto q_block_tile = load_tile(q_dram_window); + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {0, kK0}); + + clear_tile(s_acc); // Initialize C + + store_tile(q_lds_window, tile_elementwise_in(q_element_func, q_block_tile)); + q_block_tile = load_tile(q_dram_window); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(!is_null_tile_window(bias_dram_window)) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(!is_null_tile_window(bias_dram_window)) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto) { + block_sync_lds(); + gemm_0(s_acc, q_lds_window, k_lds_window); + block_sync_lds(); + + move_tile_window(q_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + q_lds_window, + tile_elementwise_in(q_element_func, q_block_tile)); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, q_lds_window, k_lds_window); + block_sync_lds(); + + store_tile(q_lds_window, tile_elementwise_in(q_element_func, q_block_tile)); + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, q_lds_window, k_lds_window); + } + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(is_null_tile_window(bias_dram_window)) + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + else + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert(bias_element_func(y)); +#else + x = scale * x + math::log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kN0K1NeedPadding || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), + k_origin.At(Number<0>{}), + Number{}, + Number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + 
tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + Sequence<1>{}, + f_max, + NumericLimits::Lowest()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.GetTileDistribution()); // Pcompute{j} + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * m[i_idx]; +#endif + sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(is_null_tile_window(bias_dram_window)) + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } + else + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - m[i_idx]); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(is_null_tile_window(bias_dram_window)) + { + auto row_max = scale * m[i_idx]; + return math::exp2(scale * m_old[i_idx] - row_max); + } + else + { + return math::exp2(m_old[i_idx] - m[i_idx]); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? 
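+ // Note on the FIXME above: rescaling o_acc by exp(m_old - m_new) each iteration and
+ // deferring the division by l to the epilogue follows the online-softmax identity
+ // O_j = exp(m_{j-1} - m_j) * O_{j-1} + P_j * V_j, with O = O_last / l_last,
+ // which is algebraically equivalent to applying a per-step 1/l normalization.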
+ o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, Sequence<0, i_k1 * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + constexpr auto lse_spans = decltype(lse)::GetDistributedSpans(); + sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(is_null_tile_window(bias_dram_window)) + { + lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 
0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale, + void* smem_ptr) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + mask, + scale, + smem_ptr); + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp new file mode 100644 index 000000000..8fdf2c0b1 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +using BlockFmhaPipelineQSKSVSDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp similarity index 86% rename from include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp rename to include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 04a3e8f1e..73c466c4c 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_custom_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -17,6 +17,8 @@ #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp" +#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" // TODO: remove this @@ -26,16 +28,208 @@ namespace ck { namespace tile_program { namespace block { +template +struct BlockFmhaPipelineQXCustomPolicy; + +template <> +struct BlockFmhaPipelineQXCustomPolicy +{ + static constexpr bool QLoadOnce = true; + + template + __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() + { + return 0; + } + + template + __host__ __device__ static constexpr auto MakeQDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = 
remove_cvref_t())>; + constexpr index_t MWarp = config.template At<1>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; + + constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t K0 = kKPerBlock / (K1 * K2); + + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<2, 1>>, + Tuple, Sequence<1, 2>>, + Sequence<1, 2, 2>, + Sequence<0, 0, 2>>{}); + } + + template + __host__ __device__ static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + return BlockGemmARegBSmemCRegV2{}; + } +}; + +template <> +struct BlockFmhaPipelineQXCustomPolicy +{ + static constexpr bool QLoadOnce = false; + + template + __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() + { + constexpr index_t lds_alignment = 16; // optional + constexpr index_t q_smem_size = + ck::math::integer_divide_ceil( + sizeof(typename Problem::QDataType) * + MakeQLdsBlockDescriptor().GetElementSpaceSize(), + lds_alignment) * + lds_alignment; + return q_smem_size; + } + + template + __host__ __device__ static constexpr auto MakeQDramTileDistribution() + { + using QDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. 
TODO: change this + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + StaticTileDistributionEncoding, + Tuple, Sequence>, + Tuple, Sequence<1, 2>>, + Tuple, Sequence<2, 0>>, + Sequence<1, 2>, + Sequence<0, 1>>{}); + } + + // 3d + padding + template + __host__ __device__ static constexpr auto MakeQLdsBlockDescriptor() + { + using QDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPack = 16 / sizeof(QDataType); + + constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number<(kMPerBlock + 1) * kKPack>{}, Number{}, Number<1>{}), + Number<8>{}, + Number<1>{}); + + constexpr auto q_lds_block_desc = transform_tensor_descriptor( + q_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return q_lds_block_desc; + } + + template + __host__ __device__ static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } +}; + // This pipeline is qkv all located in LDS -template -struct BlockFmhaPipelineQRKSVSCustomPolicy +template +struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy { - static constexpr index_t AsyncCopyK = AsyncCopyK_; - static constexpr index_t AsyncCopyV = AsyncCopyV_; // TODO: this not supported yet + static constexpr bool AsyncCopyK = AsyncCopyK_; + static constexpr bool AsyncCopyV = AsyncCopyV_; // TODO: this not supported yet static constexpr index_t NumPrefetchK = NumPrefetchK_; static constexpr index_t NumPrefetchV = NumPrefetchK_; + using QXPolicy = BlockFmhaPipelineQXCustomPolicy; + template struct LdsBufferSequence { @@ -426,12 +620,6 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy return v_lds_block_desc; } - template - __host__ __device__ static constexpr ck::index_t GetSmemSizeQ() - { - return 0; - } - template __host__ __device__ static constexpr ck::index_t GetSmemSize() { @@ -440,34 +628,8 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy constexpr index_t single_smem_size = GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); - return single_smem_size * math::max(NumPrefetchK, NumPrefetchV); - } - - template - __host__ __device__ static constexpr auto MakeQDramTileDistribution() - { - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template At<1>(); - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength; - - constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K1 = 
WG::WarpGemmAttribute::Impl::kABKLane; - constexpr index_t K0 = kKPerBlock / (K1 * K2); - - constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t M1 = MWarp; - constexpr index_t M0 = kMPerBlock / (M2 * M1); - - return make_static_tile_distribution( - StaticTileDistributionEncoding, - Tuple, Sequence>, - Tuple, Sequence<2, 1>>, - Tuple, Sequence<1, 2>>, - Sequence<1, 2, 2>, - Sequence<0, 0, 2>>{}); + return QXPolicy::template GetSmemSizeQ() + + single_smem_size * math::max(NumPrefetchK, NumPrefetchV); } template @@ -675,43 +837,6 @@ struct BlockFmhaPipelineQRKSVSCustomPolicy } } - template - __host__ __device__ static constexpr auto GetQKBlockGemm() - { - using BlockGemmProblem = - BlockGemmPipelineProblem>; - - constexpr auto warp_gemm = []() { - if constexpr(is_same_v && - is_same_v && - is_same_v) - { - return warp::WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; - } - else if constexpr(is_same_v && - is_same_v && - is_same_v) - { - return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; - } - }(); - - using BlockGemmPolicy = - BlockGemmARegBSmemCRegV2CustomPolicy; - - return BlockGemmARegBSmemCRegV2{}; - } - template __host__ __device__ static constexpr auto GetKVBlockGemm() { From bcb6592123122816aa44161f58bfe21788b46a49 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 23 Jan 2024 20:55:01 +0800 Subject: [PATCH 35/45] F8 enablement (#71) * compute correct * improve perf, but seems pipeline has duplicated ISA * refactor generate_kernel * remove duplicated GetBlockQKGemm * finialize a more generic codegen * refactor into autogen API * fix some comment * Use occupancy=1 for hdim=256 * support hdim=256 * modify some comment * we no longer need to change target inside a file * update bench script * add readme * modify --------- Co-authored-by: Po Yen, Chen --- CMakeLists.txt | 2 +- example/91_tile_program/common/arg_parser.hpp | 11 + example/91_tile_program/fmha/CMakeLists.txt | 14 +- example/91_tile_program/fmha/README.md | 90 ++++ example/91_tile_program/fmha/fmha_fwd.cpp | 201 +++----- example/91_tile_program/fmha/fmha_fwd.hpp | 268 +++++----- .../fmha/fmha_fwd_epilogue.hpp | 3 +- .../91_tile_program/fmha/fmha_fwd_kernel.hpp | 84 +++- example/91_tile_program/fmha/generate.py | 398 +++++++++++++++ .../91_tile_program/fmha/generate_kernels.py | 122 ----- example/91_tile_program/fmha/mask.hpp | 127 +++-- example/91_tile_program/fmha/misc/gamc.png | Bin 0 -> 30073 bytes .../91_tile_program/fmha/script/benchmark.sh | 15 +- .../block_fmha_pipeline_problem.hpp | 14 +- .../block_fmha_pipeline_qr_ks_vs.hpp | 1 + .../block_fmha_pipeline_qr_ks_vs_async.hpp | 1 + .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 458 ++++++++++++++++++ ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 60 ++- .../ck/tile_program/tile/tile_elementwise.hpp | 59 ++- .../ck/tile_program/tile/tile_fmha_shape.hpp | 8 +- .../ck/tile_program/warp_tile/warp_gemm.hpp | 25 + .../warp_tile/warp_gemm_attribute_mfma.hpp | 93 +++- .../warp_gemm_attribute_mfma_impl.hpp | 82 ++++ .../warp_tile/warp_gemm_dispatcher.hpp | 11 + include/ck/utility/amd_buffer_addressing.hpp | 61 ++- script/cmake-ck-dev.sh | 8 +- script/cmake-ck-release.sh | 8 +- 27 files changed, 1685 insertions(+), 539 deletions(-) create mode 100644 example/91_tile_program/fmha/README.md create mode 100644 example/91_tile_program/fmha/generate.py delete mode 100644 example/91_tile_program/fmha/generate_kernels.py create mode 100644 example/91_tile_program/fmha/misc/gamc.png create mode 100644 
include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index d132a3e4e..5eefd68b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,7 +21,7 @@ set(version 1.1.0) # Check support for CUDA/HIP in Cmake project(composable_kernel VERSION ${version}) -find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) +find_package(Python3 3.7 COMPONENTS Interpreter REQUIRED) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") diff --git a/example/91_tile_program/common/arg_parser.hpp b/example/91_tile_program/common/arg_parser.hpp index 501b92788..58155d078 100644 --- a/example/91_tile_program/common/arg_parser.hpp +++ b/example/91_tile_program/common/arg_parser.hpp @@ -153,6 +153,17 @@ class ArgParser return value; } + bool get_bool(const std::string& name) const + { + auto v = input_map.at(name).value; + if(v.compare("t") == 0 || v.compare("true") == 0) + return true; + if(v.compare("f") == 0 || v.compare("false") == 0) + return false; + int value = atoi(v.c_str()); + return value == 0 ? false : true; + } + float get_float(const std::string& name) const { double value = atof(input_map.at(name).value.c_str()); diff --git a/example/91_tile_program/fmha/CMakeLists.txt b/example/91_tile_program/fmha/CMakeLists.txt index d881d2f2e..aaa73fcfd 100644 --- a/example/91_tile_program/fmha/CMakeLists.txt +++ b/example/91_tile_program/fmha/CMakeLists.txt @@ -1,23 +1,23 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate_kernels.py - --list_kernels ${CMAKE_CURRENT_BINARY_DIR}/kernel_list.txt + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt ) -# NOTE: for cmake, the FMHA_FWD_GEN_KERNELS files must be in the same directory +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory # as current cmake list, otherwise will not figure out the dependency properly -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/kernel_list.txt FMHA_FWD_GEN_KERNELS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) add_custom_command( - OUTPUT ${FMHA_FWD_GEN_KERNELS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate_kernels.py + OUTPUT ${FMHA_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) set(EXAMPLE_FMHA_FWD "example_fmha_fwd") add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_KERNELS}) +target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this diff --git a/example/91_tile_program/fmha/README.md b/example/91_tile_program/fmha/README.md new file mode 100644 index 000000000..b5dde4043 --- /dev/null +++ b/example/91_tile_program/fmha/README.md @@ -0,0 +1,90 @@ +# fused multi-head attention + +This folder contains example for fmha(fused multi-head attention) using ck tile-programming implementation. It is a good example to demostrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. 
+
+## build
+```
+# in the root of ck
+mkdir build && cd build
+sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this with gfx90a, gfx942...
+make example_fmha_fwd -j
+```
+This will result in an executable `build/bin/example_fmha_fwd`
+
+## kernel
+The kernel template is `fmha_fwd_kernel.hpp`; this is the gridwise op in old ck's terminology. We put it here deliberately, to demonstrate that one can construct a kernel from various internal components of ck. We may still provide an implementation of the kernel template under ck's include path in the future.
+
+There are 3 template parameters for this kernel template.
+* `TilePartitioner` maps a workgroup to its corresponding tile; `fmha_fwd_tile_partitioner.hpp` in this folder serves this purpose.
+* `FmhaPipeline` is one of the block tile pipelines (under `include/ck/tile_program/block_tile_pipeline`) and is the performance-critical component. We spent a lot of effort optimizing and experimenting with this pipeline, and may still work out more performant pipelines and add them to that folder. One only needs to swap in a different pipeline type to benefit from a different performant implementation (stay tuned for updated pipelines).
+* `EpiloguePipeline` modifies and stores the result in the last phase. A lot of post-fusion is usually done at this stage, so we abstract this concept as well. Currently we do little in the epilogue stage, but we leave room for future extensions.
+
+## codegen
+To speed up compilation, we instantiate the kernels into separate files, so the build can benefit from the parallelism of the cmake/make system. This is achieved by the `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in the `FMHA_FWD_KERNEL_BODY` variable.
+
+## executable
+`example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/example_fmha_fwd -?` to list all supported args:
+```
+args:
+          -v    whether to do cpu validation or not (default:1)
+       -mode    kernel mode. 0:batch, 1:group (default:0)
+          -b    batch size (default:2)
+          -h    num of head, for q (default:8)
+        -h_k    num of head, for k/v, 0 means equal to h (default:0)
+                if not equal to h, then this is GQA/MQA case
+          -s    seqlen_q (default:3328)
+        -s_k    seqlen_k, 0 means equal to s (default:0)
+          -d    head dim for q, k (default:128)
+        -d_v    head dim for v, 0 means equal to d (default:0)
+      -scale    scale factor. 0 means equal to 1/sqrt(seqlen) (default:0)
+      -iperm    permute input (default:1)
+                if true, will be b*h*s*d, else b*s*h*d
+      -operm    permute output (default:1)
+       -bias    add bias or not (default:0)
+       -prec    data type. fp16/bf16/fp8/bf8 (default:fp16)
+       -mask    0: no mask, 1: top-left, 2:bottom-right (default:0)
+                't:l,r', top-left local-attn with left right size
+                'b:l,r', bottom-r local-attn with left right size
+                'g:y,x', generic attention mask coordinate with y/x size
+    -vlayout    r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
+        -lse    0 not store lse, 1 store lse (default:0)
+       -init    init method. 0:random int, 1:random float, 2:trig float (default:1)
+```
+Example: `./bin/example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run an fp16 fmha case with batch=1, nhead=16, sequence length=16384, hdim=128.
+
+## support features
+We are still in a rapid development stage, so more features/optimizations will be coming soon.
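+
+A few illustrative invocations of the features described below (the sizes are arbitrary; any valid values work):
+```
+# group (variable-seqlen) mode with GQA: 8 query heads sharing 2 k/v heads
+./bin/example_fmha_fwd -mode=1 -b=4 -h=8 -h_k=2 -s=1024 -d=64
+# causal top-left mask, store lse, col-major v layout
+./bin/example_fmha_fwd -b=2 -h=16 -s=4096 -d=128 -mask=1 -lse=1 -vlayout=c
+```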
+
+### hdim
+Currently we support hdim `32/64/128/256` for `fp16`/`bf16`, of which `64`/`128` are better optimized. We may consider optimizing other hdim values if there are more requests. We also have experimental support for arbitrary hdim (even odd numbers); one can change the return value of `get_pad()` inside `generate.py` to achieve this. (Note: we may change the method or optimize arbitrary hdim support in the future.)
+
+### group/batch mode
+Currently we support both batch and group mode by setting `-mode` to `0` or `1`; in group mode each batch can have a different seqlen.
+
+### MQA/GQA
+By setting `-h` (nhead for q) and `-h_k` (nhead for k/v) to different numbers, you can achieve MQA/GQA. Please note that `h % h_k == 0` must hold when you set different numbers.
+
+### input/output permute, and `b*s*3*h*d`
+If you look at the kernel arguments inside `fmha_fwd_kernel.hpp`, we support providing arbitrary strides for the seqlen (stride_q/k/v), nhead, and batch dimensions of the q/k/v matrices, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permutes. The `-iperm=0/1`, `-operm=0/1` flags are a convenient way to exercise this through the executable. We didn't provide a cmd line arg to test the `b*s*3*h*d` layout which is used by default by torch/FA, but it's trivial to achieve this by setting the proper `stride_q/k/v` value of `3*h*d`.
+
+### attention bias
+Attention bias is supported with a `b*h*s*s` layout and bias values in float.
+
+### lse
+For training kernels, the "log sum exp" needs to be stored in the forward pass and used in the backward pass. We support this by setting `-lse=1`.
+
+### vlayout
+We support the v matrix in both row-major (`seqlen*hdim`) and col-major (`hdim*seqlen`) layouts. Since the accumulation (reduce) dimension for V is along `seqlen`, and the current AMD mfma layout expects each thread to hold contiguous registers of pixels along the reduce dimension, it is easier to support the col-major V layout. However, col-major is not necessarily faster than row-major; many factors affect overall performance. We still provide `-vlayout=r/c` here to switch/test between the two layouts.
+
+### generic attention mask coordinate
+We unify the mask expression into a generic attention mask coordinate, providing a uniform approach to describe causal top-left, causal bottom-right, and local attention.
+![](misc/gamc.png)
+
+(more description to be added)
+
+### dropout
+TBD
+
+## FP8 experimental support
+As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have experimental support for fp8 fmha kernels. You can evaluate the performance by passing `-prec=fp8` to `example_fmha_fwd` on a gfx940/941/942 machine with rocm 6.0+. Currently, unless you explicitly set `-v=0` (which disables cpu verification), it will print out an error as large as `0.05`. We are still working on tuning the kernel performance as well as the precision, so stay tuned for updated performance (pipelines).
+Currently we only support `-vlayout=c` for fp8, which is `hdim*seqlen` for the V matrix. Row-major V matrix support will come later.
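+
+## example: calling the generated API from C++
+Below is a minimal host-side sketch (not part of the build) of how one might drive the dispatcher generated by `generate.py`, i.e. the `fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const StreamConfig&)` entry point declared in `fmha_fwd.hpp`. The function name, tensor sizes and the `StreamConfig` handling are placeholders; device buffer allocation is assumed to be done elsewhere (see `fmha_fwd.cpp` for a complete example).
+```
+#include <cmath>
+#include "fmha_fwd.hpp"
+
+// batch-mode, fp16, row-major V, no mask/bias/lse; q/k/v/o are device pointers
+float run_fp16_batch(const void* q, const void* k, const void* v, void* o,
+                     ck::index_t batch, ck::index_t nhead, ck::index_t seqlen,
+                     ck::index_t hdim, const StreamConfig& stream)
+{
+    fmha_fwd_traits traits{static_cast<int>(hdim), "fp16",
+                           /*is_group_mode=*/false, /*is_v_rowmajor=*/true,
+                           mask_enum::no_mask,
+                           /*has_bias=*/false, /*has_lse=*/false};
+
+    fmha_fwd_args args{q, k, v,
+                       /*bias_ptr=*/nullptr, /*lse_ptr=*/nullptr, o,
+                       /*seqstart_q_ptr=*/nullptr, /*seqstart_k_ptr=*/nullptr,
+                       /*seqlen_k_ptr=*/nullptr,
+                       batch, nhead, /*nhead_k=*/nhead,
+                       /*seqlen_q=*/seqlen, /*seqlen_k=*/seqlen,
+                       /*hdim_q=*/hdim, /*hdim_v=*/hdim,
+                       /*max_seqlen_q=*/seqlen,
+                       /*scale=*/1.0f / std::sqrt(static_cast<float>(hdim)),
+                       /*descale_qk=*/1.0f, /*descale_sv=*/1.0f,
+                       /*i_perm=*/true, /*o_perm=*/true,
+                       /*mask_y=*/0, /*mask_x=*/0};
+
+    // fmha_fwd() returns the measured average time (as reported by launch_kernel),
+    // or a negative value if no generated kernel matches the requested traits
+    return fmha_fwd(traits, args, stream);
+}
+```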
diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index f9813bcdf..058e5d76b 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -50,19 +50,23 @@ auto create_args(int argc, char* argv[]) .insert("d", "128", "head dim for q, k") .insert("d_v", "0", "head dim for v, 0 means equal to d") .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(seqlen)") + .insert("descale_q", "1", "scale factor for fp8 quantization") + .insert("descale_k", "1", "scale factor for fp8 quantization") + .insert("descale_v", "1", "scale factor for fp8 quantization") .insert("iperm", "1", "permute input\n" "if true, will be b*h*s*d, else b*s*h*d") .insert("operm", "1", "permute output") .insert("bias", "0", "add bias or not") - .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("mask", "0", "0: no mask, 1: top-left, 2:bottom-right\n" "'t:l,r', top-left local-attn with left right size\n" "'b:l,r', bottom-r local-attn with left right size\n" "'g:y,x', generic attention mask coordinate with y/x size\n") + .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float"); @@ -70,65 +74,6 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template -struct fmha_fwd_kernel_invoker -{ - static constexpr ck::index_t HDim = HDim_; - using DataType = DataType_; - // these args are used to select kernel. - // args that may passed as karg shoule use operator() - mode_enum mode; - bool use_bias; - mask_info mask; - bool store_lse; - - fmha_fwd_kernel_invoker(mode_enum mode_, bool use_bias_, mask_info mask_, bool store_lse_) - : mode(mode_), use_bias(use_bias_), mask(mask_), store_lse(store_lse_) - { - } - - template - float operator()(const StreamConfig& stream, Args&&... 
args) - { - float ave_time; - BOOL_SWITCH_3( - mode == mode_enum::group, kIsGroupMode, use_bias, kHasBias, store_lse, kStoreLSE, [&] { - if(mask.type == mask_enum::no_mask) - { - using FmhaMask = FmhaMasks::NoMask; - using Kernel = FmhaFwdKernelSelector; - - auto [kargs, grids] = - fmha_fwd_create_kargs_and_grids(std::forward(args)...); - ave_time = fmha_fwd_run(stream, kargs, grids); - } - else - { - BOOL_SWITCH(mask.type == mask_enum::window_generic, kIsLocal, [&]() { - using FmhaMask = - ck::tile_program::block::GenericAttentionMask; - using Kernel = FmhaFwdKernelSelector; - - auto [kargs, grids] = - fmha_fwd_create_kargs_and_grids(std::forward(args)...); - ave_time = fmha_fwd_run(stream, kargs, grids); - }); - } - }); - return ave_time; - } -}; - // different threshold for different dtype template auto get_elimit(int /*init_method*/) @@ -158,11 +103,12 @@ auto get_elimit(int init_method) template bool run(const ArgParser& arg_parser) { - int do_validation = arg_parser.get_int("v"); - auto mode = static_cast(arg_parser.get_uint32("mode")); - ck::index_t batch = arg_parser.get_int("b"); - ck::index_t nhead = arg_parser.get_int("h"); - ck::index_t nhead_k = arg_parser.get_int("h_k"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck::index_t batch = arg_parser.get_int("b"); + ck::index_t nhead = arg_parser.get_int("h"); + ck::index_t nhead_k = arg_parser.get_int("h_k"); if(nhead_k == 0) nhead_k = nhead; @@ -181,18 +127,22 @@ bool run(const ArgParser& arg_parser) if(hdim_v == 0) hdim_v = hdim_q; - int i_perm = arg_parser.get_int("iperm"); // if true, will be batch * nhead * seqlen * hdim - int o_perm = arg_parser.get_int("operm"); // if false, will be batch * seqlen * nhead * hdim + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim float scale = arg_parser.get_float("scale"); if(scale == .0f) scale = 1.0 / ck::math::sqrt(static_cast(hdim_q)); // TODO: q ? v ? - bool use_bias = arg_parser.get_uint32("bias"); + float descale_q = arg_parser.get_float("descale_q"); + float descale_k = arg_parser.get_float("descale_k"); + float descale_v = arg_parser.get_float("descale_v"); - bool store_lse = arg_parser.get_uint32("lse"); + std::string vlayout = arg_parser.get_str("vlayout"); + bool use_bias = arg_parser.get_uint32("bias"); + bool lse = arg_parser.get_uint32("lse"); - mask_info mask = decode_mask_info(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); int init_method = arg_parser.get_int("init"); @@ -245,7 +195,7 @@ bool run(const ArgParser& arg_parser) } } - auto get_lengths = [&](int permute, + auto get_lengths = [&](bool permute, ck::index_t b /*batch*/, ck::index_t h /*nhead*/, ck::index_t s /*seqlen*/, @@ -256,7 +206,7 @@ bool run(const ArgParser& arg_parser) return std::array{b, s, h, d}; }; - constexpr bool is_v_rowmajor = ck::is_same_v; + bool is_v_rowmajor = vlayout == std::string("r"); // host memory for storing all the tensor elements const ck::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); @@ -277,8 +227,8 @@ bool run(const ArgParser& arg_parser) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] Tensor lse_host( - store_lse ? 
std::array{shape_batch, nhead, shape_seqlen_q} - : std::array{1, 1, 1} /* dummy shape for simplifying code */); + lse ? std::array{shape_batch, nhead, shape_seqlen_q} + : std::array{1, 1, 1} /* dummy shape for simplifying code */); Tensor o_host(get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); @@ -321,11 +271,11 @@ bool run(const ArgParser& arg_parser) seqstart_k.ToDevice(seqstart_k_host.data()); // clang-format off - auto layout_str = [&](int permute){ + auto layout_str = [&](bool permute){ if (permute) return std::string("bhsd"); else return std::string("bshd"); }; - auto io_layout = [&](int iperm_, int operm_) { + auto io_layout = [&](bool iperm_, bool operm_) { if (iperm_ == operm_) return layout_str(iperm_); else return layout_str(iperm_) + std::string("-") + layout_str(operm_); }; @@ -335,61 +285,40 @@ bool run(const ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias - << ", lse:" << store_lse << ", mask:" << mask - << ", v:" << std::string(VLayout::name)[0] << std::flush; - -#define INVOKE_FMHA_KERNEL(hdim_) \ - fmha_fwd_kernel_invoker{mode, use_bias, mask, store_lse}( \ - stream_config, \ - q_buf.GetDeviceBuffer(), \ - k_buf.GetDeviceBuffer(), \ - v_buf.GetDeviceBuffer(), \ - bias_buf.GetDeviceBuffer(), \ - lse_buf.GetDeviceBuffer(), \ - o_buf.GetDeviceBuffer(), \ - seqstart_q.GetDeviceBuffer(), \ - seqstart_k.GetDeviceBuffer(), \ - nullptr, \ - batch, \ - nhead, \ - nhead_k, \ - shape_seqlen_q, \ - shape_seqlen_k, \ - hdim_q, \ - hdim_v, \ - max_seqlen_q, \ - scale, \ - i_perm, \ - o_perm, \ - mask.y, \ - mask.x) - - float ave_time = 0; - const auto check_hdims = [](ck::index_t hdim_q_, ck::index_t hdim_v_, ck::index_t threshold) { - const auto compare = - std::conditional_t, std::equal_to<>>{}; - return compare(hdim_q_, threshold) && compare(hdim_v_, threshold); - }; - - if(check_hdims(hdim_q, hdim_v, 32)) + << ", lse:" << lse << ", mask:" << mask << ", v:" << vlayout << std::flush; + + auto fmha_traits = fmha_fwd_traits{ + hdim_q, data_type, mode == mode_enum::group, is_v_rowmajor, mask.type, use_bias, lse}; + auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + batch, + nhead, + nhead_k, + shape_seqlen_q, + shape_seqlen_k, + hdim_q, + hdim_v, + max_seqlen_q, + scale, + descale_q * descale_k, + descale_v, + i_perm, + o_perm, + mask.y, + mask.x}; + + float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); + + if(ave_time < 0) { - ave_time = INVOKE_FMHA_KERNEL(32); - } - else if(check_hdims(hdim_q, hdim_v, 64)) - { - ave_time = INVOKE_FMHA_KERNEL(64); - } - else if(check_hdims(hdim_q, hdim_v, 128)) - { - ave_time = INVOKE_FMHA_KERNEL(128); - } - else if(check_hdims(hdim_q, hdim_v, 256)) - { - ave_time = INVOKE_FMHA_KERNEL(256); - } - else - { - std::cerr << "not support hdim, will not run" << std::endl; + std::cout << ", not supported yet" << std::flush << std::endl; return false; } @@ -403,7 +332,7 @@ bool run(const ArgParser& arg_parser) if(!do_validation) { - std::cout << std::endl; + std::cout << std::flush << std::endl; return true; } @@ -446,7 +375,7 @@ bool run(const ArgParser& arg_parser) if(i_perm) 
k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); - if constexpr (is_v_rowmajor) { + if (is_v_rowmajor) { // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] @@ -485,7 +414,7 @@ bool run(const ArgParser& arg_parser) reference_batched_masking(s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); } - if(store_lse){ + if(lse){ reference_batched_softmax(s_host_ref, p_host_ref, lse_host_ref); } else{ @@ -515,7 +444,7 @@ bool run(const ArgParser& arg_parser) break; } - if(store_lse) + if(lse) { Tensor lse_host_result({nhead, real_seqlen_q}); lse_host_result.ForEach([&](auto& self, auto idx) { @@ -559,6 +488,10 @@ int main(int argc, char* argv[]) { return run(arg_parser) ? 0 : -2; } + else if(data_type == "fp8") + { + return run(arg_parser) ? 0 : -2; + } return -3; } diff --git a/example/91_tile_program/fmha/fmha_fwd.hpp b/example/91_tile_program/fmha/fmha_fwd.hpp index 43dac7967..a6db9439c 100644 --- a/example/91_tile_program/fmha/fmha_fwd.hpp +++ b/example/91_tile_program/fmha/fmha_fwd.hpp @@ -11,6 +11,7 @@ #include "ck/tile_program/block_tile/block_masking.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" #include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck/tile_program/tile/tile_fmha_shape.hpp" @@ -19,6 +20,7 @@ #include "fmha_fwd_epilogue.hpp" #include "fmha_fwd_kernel.hpp" #include "fmha_fwd_tile_partitioner.hpp" +#include "mask.hpp" template struct FmhaFwdTypeConfig; @@ -53,145 +55,44 @@ struct FmhaFwdTypeConfig using ODataType = ck::bhalf_t; }; -// default settings for FmhaFwdKernelSelector<> type alias -using VLayout = ck::tensor_layout::gemm::RowMajor; // (bs, nhead) seqlen * hdim -// using VLayout = ck::tensor_layout::gemm::ColumnMajor; // (bs, nhead) hdim * seqlen - -struct FmhaMasks -{ - using NoMask = ck::tile_program::block::GenericAttentionMask; - using GenericMask = ck::tile_program::block::GenericAttentionMask; - using CausalMask = ck::tile_program::block::GenericAttentionMask; -}; - -inline constexpr bool kM0NeedPadding = false; -inline constexpr bool kN0K1NeedPadding = false; -inline constexpr bool kK0N1NeedPadding = false; - -template -struct FmhaBlockTile; - -template <> -struct FmhaBlockTile : ck::Sequence<128, 64, 16, 32, 32, 32> -{ -}; -template <> -struct FmhaBlockTile : ck::Sequence<128, 64, 32, 64, 32, 64> -{ -}; -template <> -struct FmhaBlockTile : ck::Sequence<128, 128, 32, 128, 32, 128> -{ -}; -template <> -struct FmhaBlockTile : ck::Sequence<128, 128, 32, 256, 32, 256> -{ -}; -using FmhaBlockWarps = ck::Sequence<4, 1, 1>; -using FmhaWarpTile = ck::Sequence<32, 32, 16>; - -template -struct FmhaShape; - -template <> -struct FmhaShape : ck::tile_program::TileFmhaShape, - ck::Sequence<2, 1, 1>, - FmhaWarpTile, - ck::Sequence<2, 1, 1>, - FmhaWarpTile, - VLayout> -{ -}; - template <> -struct FmhaShape : ck::tile_program::TileFmhaShape, - FmhaBlockWarps, - FmhaWarpTile, - 
FmhaBlockWarps, - FmhaWarpTile, - VLayout> +struct FmhaFwdTypeConfig { + using QDataType = ck::f8_t; + using KDataType = ck::f8_t; + using VDataType = ck::f8_t; + using BiasDataType = float; // TODO: fix me + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::f8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::f8_t; }; template <> -struct FmhaShape - : ck::tile_program::TileFmhaShape, - FmhaBlockWarps, - FmhaWarpTile, - FmhaBlockWarps, - FmhaWarpTile, - VLayout> +struct FmhaFwdTypeConfig { + using QDataType = ck::bf8_t; + using KDataType = ck::bf8_t; + using VDataType = ck::bf8_t; + using BiasDataType = ck::bf8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck::bf8_t; }; -template <> -struct FmhaShape - : ck::tile_program::TileFmhaShape, - FmhaBlockWarps, - FmhaWarpTile, - FmhaBlockWarps, - FmhaWarpTile, - VLayout> +struct FmhaMasks { + using NoMask = ck::tile_program::block::GenericAttentionMask; + using GenericMask = ck::tile_program::block::GenericAttentionMask; + using CausalMask = ck::tile_program::block::GenericAttentionMask; }; -template -using FmhaTraits = ck::tile_program::TileFmhaTraits; - -template -using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - /* BlockSize = */ HDim == 32 ? 
128 : 256, - FmhaShape, - kIsGroupMode, - FmhaMask, - FmhaTraits>; - -template -using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync< - FmhaPipelineProblem>; - -template -using FmhaEpilogue = - FmhaFwdEpilogue::OaccDataType, - typename FmhaFwdTypeConfig::ODataType>>; - -template -using FmhaFwdKernelSelector = - FmhaFwdKernel>, - FmhaPipeline, - FmhaEpilogue>; - -// Kernel API +// internal API, don't use this directly template auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, const void* k_ptr, @@ -211,6 +112,8 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, ck::index_t hdim_v, ck::index_t max_seqlen_q, float scale, + float descale_qk, + float descale_sv, bool i_perm, bool o_perm, ck::index_t mask_y, @@ -283,7 +186,9 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, nhead_stride_lse, nhead_stride_o, mask_y, - mask_x); + mask_x, + descale_qk, + descale_sv); } else { // create batch mode kernel arguments @@ -317,7 +222,9 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, batch_stride_lse, batch_stride_o, mask_y, - mask_x); + mask_x, + descale_qk, + descale_sv); } }(); @@ -325,16 +232,95 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr, return ck::make_tuple(kargs, grids); } -// will instantiate this function across different source file +// This is the args from caller to underneath API, different from the kernel +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + void* lse_ptr; + void* o_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck::index_t batch; + ck::index_t nhead; + ck::index_t nhead_k; + ck::index_t seqlen_q; + ck::index_t seqlen_k; + ck::index_t hdim_q; + ck::index_t hdim_v; + ck::index_t max_seqlen_q; + float scale; + float descale_qk; + float descale_sv; + bool i_perm; + bool o_perm; + ck::index_t mask_y; + ck::index_t mask_x; +}; + template -float fmha_fwd_run(const StreamConfig&, typename FmhaKernel::Kargs, dim3); +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + return fmha_fwd_create_kargs_and_grids(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.batch, + args.nhead, + args.nhead_k, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.max_seqlen_q, + args.scale, + args.descale_qk, + args.descale_sv, + args.i_perm, + args.o_perm, + args.mask_y, + args.mask_x); +} -#define FMHA_FWD_KERNEL_DEFINE(KERNEL_) \ - template <> \ - float fmha_fwd_run( \ - const StreamConfig& stream, typename KERNEL_::Kargs kargs, dim3 grids) \ - { \ - constexpr dim3 blocks = KERNEL_::BlockSize(); \ - constexpr ck::index_t kBlockPerCu = KERNEL_::kBlockPerCu; \ - return launch_kernel(stream, KERNEL_{}, grids, blocks, 0, kargs); \ - } +// this is internal API, will be generated across different files to speedup compile +template +struct fmha_fwd_traits_ +{ + static constexpr ck::index_t HDim = HDim_; + using DataType = ck::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + using FmhaMask = ck::remove_cvref_t; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLse = kStoreLse_; +}; + +template +float fmha_fwd_(const StreamConfig&, fmha_fwd_args); + +// This is the public API, will be generated by script +struct fmha_fwd_traits +{ + int hdim; + std::string data_type; + bool 
is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bool has_bias; + bool has_lse; +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const StreamConfig&); diff --git a/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp b/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp index 94f4b17e5..6c5e6e861 100644 --- a/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_epilogue.hpp @@ -29,7 +29,6 @@ struct FmhaFwdEpilogue using namespace ck; using namespace ck::tile_program; - const auto o = tile_elementwise_in(type_convert, o_acc_tile); - store_tile(o_dram_window_tmp, o); + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); } }; diff --git a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp index b49fde465..06f944b24 100644 --- a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -24,12 +24,13 @@ struct FmhaFwdKernel static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; - using QDataType = ck::remove_cvref_t; - using KDataType = ck::remove_cvref_t; - using VDataType = ck::remove_cvref_t; - using BiasDataType = ck::remove_cvref_t; - using LSEDataType = ck::remove_cvref_t; - using ODataType = ck::remove_cvref_t; + using QDataType = ck::remove_cvref_t; + using KDataType = ck::remove_cvref_t; + using VDataType = ck::remove_cvref_t; + using BiasDataType = ck::remove_cvref_t; + using LSEDataType = ck::remove_cvref_t; + using ODataType = ck::remove_cvref_t; + static constexpr bool kIsFp8 = FmhaPipeline::kIsFp8; using VLayout = ck::remove_cvref_t; @@ -95,6 +96,13 @@ struct FmhaFwdKernel ck::index_t mask_y, mask_x; }; + struct FmhaFwdFP8Kargs + { + float descale_qk; // q*k + float descale_sv; // s*v + // float * o_amax_ptr; + }; + struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -110,7 +118,8 @@ struct FmhaFwdKernel : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck::index_t batch_stride_q; ck::index_t batch_stride_k; @@ -122,7 +131,8 @@ struct FmhaFwdKernel : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -162,7 +172,9 @@ struct FmhaFwdKernel ck::index_t batch_stride_lse, ck::index_t batch_stride_o, ck::index_t mask_y, - ck::index_t mask_x) + ck::index_t mask_x, + float descale_qk, + float descale_sv) { Kargs kargs{{q_ptr, k_ptr, @@ -189,6 +201,7 @@ struct FmhaFwdKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for fp8 args batch_stride_q, batch_stride_k, batch_stride_v, @@ -201,7 +214,6 @@ struct FmhaFwdKernel kargs.nhead_stride_bias = nhead_stride_bias; kargs.batch_stride_bias = batch_stride_bias; } - if constexpr(kHasMask) { kargs.mask_y = mask_y; @@ -213,6 +225,11 @@ struct FmhaFwdKernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } + if constexpr(kIsFp8) + { + kargs.descale_qk = descale_qk; + kargs.descale_sv = descale_sv; + } return kargs; } @@ -243,7 +260,9 @@ struct FmhaFwdKernel ck::index_t nhead_stride_lse, ck::index_t nhead_stride_o, ck::index_t mask_y, - ck::index_t mask_x) + ck::index_t mask_x, + float descale_qk, + float descale_sv) { Kargs kargs{{q_ptr, k_ptr, @@ -270,6 +289,7 @@ struct FmhaFwdKernel {}, 
// placeholder for bias {}, // placeholder for mask {}, // placeholder for lse + {}, // placeholder for fp8 args reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -290,6 +310,11 @@ struct FmhaFwdKernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } + if constexpr(kIsFp8) + { + kargs.descale_qk = descale_qk; + kargs.descale_sv = descale_sv; + } return kargs; } @@ -620,17 +645,32 @@ struct FmhaFwdKernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); - auto o_acc_tile = - FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - lse_dram_window, - mask, - kargs.scale, - // ck::math::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0), - // ck::math::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0), - smem_ptr); + auto o_acc_tile = [&]() { + if constexpr(kIsFp8) + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + kargs.scale, + kargs.descale_qk, + kargs.descale_sv, + smem_ptr); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + lse_dram_window, + mask, + kargs.scale, + smem_ptr); + } + }(); // O DRAM and O DRAM window auto o_dram = [&]() { diff --git a/example/91_tile_program/fmha/generate.py b/example/91_tile_program/fmha/generate.py new file mode 100644 index 000000000..d3e05ff12 --- /dev/null +++ b/example/91_tile_program/fmha/generate.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +import itertools +from pathlib import Path +from typing import List, Optional, Tuple +from dataclasses import dataclass +import copy + +DTYPE_MAP = { + "fp16": "ck::half_t", + "bf16": "ck::bhalf_t", + "fp8" : "ck::f8_t" +} + +MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck::tile_program::block::BlockFmhaPipelineQRKSVS", + "qr_fp8" : "ck::tile_program::block::BlockFmhaPipelineQRKSVSFp8", + "qr_async" : "ck::tile_program::block::BlockFmhaPipelineQRKSVSAsync", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} + +MASKS = ["no", "causal", "generic"] +DIRECTIONS = ["fwd"] +GEN_DIR = "" # in Cmake, have to generate files in same folder + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck::Sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps_{F_idx} = ck::Sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile_{F_idx} = ck::Sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck::tile_program::TileFmhaShape; + +using fmha_trait_{F_idx} = ck::tile_program::TileFmhaTraits<{F_m0pad}, + {F_m0k1pad}, + {F_k0n1pad}, + {F_bias}, + {F_lse}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck::tile_program::block::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + FmhaFwdEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType>>; + +using fmha_kernel_{F_idx} = + FmhaFwdKernel, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}>; + +template<> +float fmha_fwd_(const StreamConfig& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck::index_t kBlockPerCu = k_::kBlockPerCu; + return launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const StreamConfig& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ + switch (t.hdim){{ +{F_hdim_case} + default: + break; + }} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" case {F_hdim}: {{ +{F_inner_dispatch} + }} + break; +""" +MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::causal_top_left || t.mask_type == mask_enum::causal_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + vlayout : str + mask : str + bias : str # true/false + lse : str # + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.vlayout}-{self.mask}-{self.bias}-{self.lse}' + +class FmhaFwdApiPool: + def __init__(self): + self.pool = dict() + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check 
duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for hdim in self.pool[dtype].keys(): + traits=self.pool[dtype][hdim] + inners=str() + for j, trait in enumerate(traits): + if0 = 'if' if j == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if0, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask], + F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_hdim=hdim, F_inner_dispatch=inners) + if1 = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if1, F_dtype=dtype, F_hdim_case=per_hdim_case) + + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along qk seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm : int # number of warps along q seqlen (block warps) + F_rn : int # number of warps along k seqlen(not used) + F_rk : int # number of warps along gemm-k(not used) + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn0}x{self.F_bk1}x{self.F_bk0blen}" +\ + f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" + +@dataclass +class FmhaFwdKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaFwdTileSize + F_vlayout : str # row/col + F_m0pad : str # true/false + F_m0k1pad : str # + F_k0n1pad : str # + F_bias : str # true/false + F_lse : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_pipeline : str # value from PIPIELINE_MAP + + @property + def template(self) -> str: + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_vlayout], + F_m0pad = BOOL_MAP[self.F_m0pad], + F_m0k1pad = BOOL_MAP[self.F_m0k1pad], + F_k0n1pad = BOOL_MAP[self.F_k0n1pad], + F_bias = BOOL_MAP[self.F_bias], + F_lse = BOOL_MAP[self.F_lse], + F_occupancy = self.F_tile.F_occupancy , + F_mask = MASK_MAP[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + # TODO: we 
don't encode idx here + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f"_v{self.F_vlayout[0]}" +\ + f"_p{BOOL_MAP[self.F_m0pad][0]}{BOOL_MAP[self.F_m0k1pad][0]}{BOOL_MAP[self.F_k0n1pad][0]}_{BOOL_MAP[self.F_bias][0]}" +\ + f"_m{self.F_mask[0]}_l{BOOL_MAP[self.F_lse][0]}_{self.F_pipeline}" + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait(hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + vlayout=self.F_vlayout, + mask=self.F_mask, + bias=self.F_bias, + lse=self.F_lse) + +# TODO: design a more practical way to do it +# this is current supported tile size. +def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: + if direction == 'fwd': + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, 2), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, 3), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, 2), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, 1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, 2) + } + else: + return None + else: + return None + +def get_blobs() -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_vlayout(dtype, hdim): + if dtype in ['fp16', 'bf16']: + return 'row' + elif dtype in ['fp8', 'bf8']: + return 'col' + else: + assert Fasle + def get_pipeline(dtype, hdim): + if dtype in ['fp16', 'bf16']: + if hdim == 256: + return 'qr' + else: + return 'qr_async' + elif dtype in ['fp8', 'bf8']: + return 'qr_fp8' + else: + assert Fasle + def get_pad(dtype, hdim): + return 'f' + + gen = list() + api_pool = FmhaFwdApiPool() + + for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): + d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + tile = d[hdim_str] + hdim = int(hdim_str) + if dtype in ['fp8', 'bf8'] and lse == "t": + continue + k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, F_vlayout=get_vlayout(dtype, hdim), + F_m0pad=get_pad(dtype, hdim), F_m0k1pad=get_pad(dtype, hdim), F_k0n1pad=get_pad(dtype, hdim), + F_bias=bias, F_lse=lse, F_mask=mask, F_mode=mode, F_pipeline=get_pipeline(dtype, hdim)) + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir: Optional[str]) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + api_pool, kernels = get_blobs() + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_api(api_pool, output_dir) + +# list all the files that will be generated +def list_blobs(output_file: Optional[str]) -> None: + assert output_file is not None + file_path = 
Path(output_file) + with file_path.open('a') as f: + _, kernels = get_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen api for CK fmha kernel", + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory" + ) + parser.add_argument( + "-l", + "--list_blobs", + required=False, + help="list all the kernels to a file" + ) + args = parser.parse_args() + if args.list_blobs is not None: + list_blobs(args.list_blobs) + else: + write_blobs(args.output_dir) diff --git a/example/91_tile_program/fmha/generate_kernels.py b/example/91_tile_program/fmha/generate_kernels.py deleted file mode 100644 index 538589ce4..000000000 --- a/example/91_tile_program/fmha/generate_kernels.py +++ /dev/null @@ -1,122 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation - -import argparse -import itertools -from pathlib import Path -from typing import List, Optional - -DTYPE_MAP = { - "fp16": "ck::half_t", - "bf16": "ck::bhalf_t", -} - -MASK_MAP = { - "no" : "FmhaMasks::NoMask", - "causal" : "FmhaMasks::CausalMask", - "generic" : "FmhaMasks::GenericMask" -} - -MODE_MAP = { - "batch" : "false", - "group" : "true" -} - -LSE_MAP = { - "no" : "false", - "store" : "true" -} - -HDIMS = [32, 64, 128, 256] -MASKS = ["no", "causal", "generic"] -DIRECTIONS = ["fwd"] -GEN_DIR = "" -LSES=["no","store"] - -KERNEL_IMPL_TEMPLATE_FWD = """// auto generated by generate_kernels.py -#include "fmha_fwd.hpp" - -using kernel_0 = FmhaFwdKernelSelector<{HDIM}, {DTYPE}, {MODE}, {MASK}, true, {LSE}>; -FMHA_FWD_KERNEL_DEFINE(kernel_0) -using kernel_1 = FmhaFwdKernelSelector<{HDIM}, {DTYPE}, {MODE}, {MASK}, false, {LSE}>; -FMHA_FWD_KERNEL_DEFINE(kernel_1) -""" - -class Kernel: - def __init__(self, - direction: str, - hdim: int, - dtype: str, - mode : str, - mask : str, - lse:str): - self.direction = direction - self.hdim = hdim - self.dtype = dtype - self.mode = mode - self.mask = mask - self.lse = lse - - @property - def template(self) -> str: - if self.direction == "fwd": - return KERNEL_IMPL_TEMPLATE_FWD.format( - HDIM=self.hdim, DTYPE=DTYPE_MAP[self.dtype], - MODE=MODE_MAP[self.mode], MASK=MASK_MAP[self.mask], - LSE=LSE_MAP[self.lse]) - - @property - def filename(self) -> str: - return f"fmha_{self.direction}_hdim{self.hdim}_{self.dtype}_{self.mode}_{self.mask}_mask_{self.lse}_lse.cpp" - -def get_all_kernels() -> List[Kernel]: - for direction, hdim, dtype, mode, mask, lse in itertools.product(DIRECTIONS, HDIMS, DTYPE_MAP.keys(), MODE_MAP.keys(), MASK_MAP.keys(), LSE_MAP.keys()): - yield Kernel(direction=direction, hdim=hdim, dtype=dtype, mode=mode, mask=mask, lse=lse) - -def write_single_kernel(kernel: Kernel, autogen_dir: Path) -> None: - credit = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n -""" - (autogen_dir / kernel.filename).write_text(credit + kernel.template) - -def write_kernels(output_dir: Optional[str]) -> None: - if output_dir is None: - output_dir = Path(__file__).parent - else: - output_dir = Path(output_dir) / GEN_DIR - - output_dir.mkdir(parents=True, exist_ok=True) - for kernel in get_all_kernels(): - write_single_kernel(kernel, output_dir) - -def list_kernels(to_file: Optional[str]) -> None: - assert to_file is not None - file_path = Path(to_file) - with file_path.open('a') as f: - for kernel in get_all_kernels(): - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog="generate_kernels", - description="gen kernels for CK fmha kernel instances", - ) - parser.add_argument( - "-o", - "--output_dir", - required=False, - help="Where to generate the kernels " - " will default to the current directory ", - ) - parser.add_argument( - "-l", - "--list_kernels", - required=False, - help="list all the kernels to a file" - ) - args = parser.parse_args() - if args.list_kernels is not None: - list_kernels(args.list_kernels) - else: - write_kernels(args.output_dir) diff --git a/example/91_tile_program/fmha/mask.hpp b/example/91_tile_program/fmha/mask.hpp index 7e3a3fedd..e64df0ba0 100644 --- a/example/91_tile_program/fmha/mask.hpp +++ b/example/91_tile_program/fmha/mask.hpp @@ -35,75 +35,74 @@ struct mask_info os << "g(" << y << "/" << x << ")"; } } - - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); -}; - -std::ostream& operator<<(std::ostream& os, const mask_info& mi) -{ - mi.serialize(os); - return os; -} - -mask_info decode_mask_info(std::string str, ck::index_t seqlen_q, ck::index_t seqlen_k) -{ - ck::index_t x_total = seqlen_k; - ck::index_t y_total = seqlen_q; - mask_info tmp; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) + static mask_info decode(std::string str, ck::index_t seqlen_q, ck::index_t seqlen_k) { - std::string t = str.substr(0, found_0); - std::string v = str.substr(found_0 + 1); - auto found_1 = v.find(","); - if(found_1 == std::string::npos) - { - printf("not supported value %s, %s\n", v.c_str(), str.c_str()); - assert(0); - } - tmp.type = mask_enum::window_generic; - ck::index_t v0 = atoi(v.substr(0, found_1).c_str()); - ck::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); - // TODO: some validation - if(t == "t") + ck::index_t x_total = seqlen_k; + ck::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) { - auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, true); - tmp.y = r.At(ck::Number<0>{}); - tmp.x = r.At(ck::Number<1>{}); - } - else if(t == "b") - { - auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, false); - tmp.y = r.At(ck::Number<0>{}); - tmp.x = r.At(ck::Number<1>{}); - } - else if(t == "g") - { - tmp.y = v0; - tmp.x = v1; + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, 
x_total, true); + tmp.y = r.At(ck::Number<0>{}); + tmp.x = r.At(ck::Number<1>{}); + } + else if(t == "b") + { + auto r = ck::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.At(ck::Number<0>{}); + tmp.x = r.At(ck::Number<1>{}); + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } } else { - printf("not supported type %s, %s\n", t.c_str(), str.c_str()); - assert(0); - } - } - else - { - // should be 0, 1, 2 - tmp.type = static_cast(atoi(str.c_str())); - if(tmp.type == mask_enum::causal_top_left) - { - tmp.y = seqlen_q; - tmp.x = 1; - } - else if(tmp.type == mask_enum::causal_bottom_right) - { - tmp.y = seqlen_q; - tmp.x = seqlen_k - seqlen_q + 1; + // should be 0, 1, 2 + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::causal_top_left) + { + tmp.y = seqlen_q; + tmp.x = 1; + } + else if(tmp.type == mask_enum::causal_bottom_right) + { + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + } } + return tmp; } - return tmp; + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); +}; + +inline std::ostream& operator<<(std::ostream& os, const mask_info& mi) +{ + mi.serialize(os); + return os; } diff --git a/example/91_tile_program/fmha/misc/gamc.png b/example/91_tile_program/fmha/misc/gamc.png new file mode 100644 index 0000000000000000000000000000000000000000..2c96951f30f99ff0345706c4f3031ab2a623ce1a GIT binary patch literal 30073 zcmeFZ2UOI_w=QY|6&g@UO;$vb0u2ZV2q;L--6TPB5+o=R8w3?;6c8jQ0hQQ9$r4%y zn;;-4&}3;uk~B#|gGhSSfHTfH^Z(y-*E;XMyVg6i)~s1W*RRsvwX5p;zFkjrw3MkR zm?;h%I6$SUat(gqz~Srz2M!4xI|Tl6lpw`@;DE;!)oWMvya#7fv~P+I-P?{}aWh#u zbqu4KOqcZeq0M8_Wbp~S5~_y3tU|(t!i=nzuWR1k{-t70JnBMOWE{i58yi}M`sBXB zBNPw%9L(U=oni%jHgR9nPd}647PuY#tFjrs zD)@qX2C~m_7US8WxArRkrU)Tt?4_ezPI7vFO#RDh1c6|{7Ubwg>2Dyx*y37#VsesK z$l+;jH-d25?Bk2{)EENZW^QQuikQ$?BOT&p-8LFIfbhZ_2sxy0Xx=3D6F%!TR6=c(YwEcNzo zes6J;@!t-;_lErmx-x-6fta(dB*dl4lY3-H&}(AIY2?!1hW{-KuXru_!%;-G<6r{K zhGx$n9MtYO{XP?mT&XF-7lXMvy*{t`sY;1thPQA6eruBu>lm)*%5bk*y&$rVJ+q?_ z*Qj=yf@ew53?l1!`HwoPrLDv#kgYf8gS3bBwLzwQTaxQJiVh}FNG>(|BtUvwv zXV6q=qu||*>#rwh#0z`B6qRPOKR!f^%z4~*;Svpb znhUUxF^-Wq>oo6Msm0HWYmC#_A)ys9V>zRr+S7T2rqGp4E7)=U6N_6EeR^GzD-@~f z4Hd3&PuXt~C-KMfHj>41Wn1As@Q*eFZxfp%sDAgin^|UZ@JB^1+RB89QmLwAhYT)V zuT}`8mH0!T@VbzoOoX`Ltfkn`{6x*-TG_XCo@xp!bCYWC+Ia3@ClpV;j+bUhkI@(S zBJ|0a`tDELA8FUJWWsRkcepxGzGo9gKN*gx3yu4Xl|<&k`Hz1*|M-J&g0)Aw&@F8T zBA2>rL&c!t&sx0ZJ?MO*Qruj!$D6(vJ2rcZgRB+V+KcXE%USMfQkTmk#@=Owu)eKU zZ9gCQl+yQ@a+=t{`5Vb;4x6Lc_n#cp2*Wp2D{w65U9Vqh!VMYPdlBI1mRLtYyezpj zCgrZ{1D0i~7bVG838W%T@Q^mzg-L`Xf?!(udgtmYjz#SOC6|4BgI-L*Y`!%H zEXo4ore4L{H)@TteB_4KpRU@K|13S^_mdNE%P}*YfDC-f-ldNhU}JuYGswGuc^S8r zCPbt@<;&f+TCfrT*=zIx=Yv-PG{KGI#?5y|sDrPv8;6Ola^T*=*ZK%6`fiiSpKXl# z6;9myId7Vi^@NW8x|^>4o7bnw%W;7L(xu;q&{_k<`ax}!MFtqX=J?1L&tK;%2BR2` z{SgA|Tbr(0>7D3yGb2C*Gtpup=8%=FAhIMVm^oGGI ztc^Exr19garuB)Xk_|!zncQUo^3`6US*w$&wHc>gUW-nwoBSTr{{4!@cgk#^&taGz ze(86n2s2;*vaZ2~^ql{*Wj%M4f*l(Q{rZScMbXurmfiSFMiE1#8`b6_kG1%uOTeWZxP7m-cIPmR zAjhmMqxLc&&|sA#E*UV5Eh z^zcs#-jk+MtT2M;)oI(sO!3} zy{A@z_+ZnsNrlc#8;LNv-XVAdzO=$*F4*u<0yErw4K^TBB-|eSO1;GeDu-F+kzl0S zc@qOgPN{onb7Pig0*AgM$J&~#gwU1lt!vx!BLO?zJ3c?k=k)wOmJK<+VO{VFo|gZv zJgb7qLgb2Tz|J>m!nQPNeQj5b+WC5z?J7R6cMs^VWuS566YGl;gLls*${5^LoA8Ck ziY=?n6(=G1#Xnq_5oe4jf){KDO3++pFvUFY_CL0ja~e96@?5rgO|AHZ20v}ZgpV@E9_gr{dsIUe2Bgr#(|6q%bK(BA;gHIVVgs z&b2pl?T#=L^<)mtz@yPV5SpDAYufJShHp_1$gkY+!{1<=q0+sTf(%$1Z7!1i@olbh 
zggW3ysrja3`;dg(PMw4+zSet@)A>OUO~B7ZA?p$U<%j@#&%uoDIK`+Gr?*AewVU)(QD>AGXXoR9NX)Z=0D_rU_aIp(PFOVn?H;0{TK~i zs9n|?Je0@bKeN$Rnpotw*!V%v+tI1{PNoxq5oHNBec~|?WZCA&4^?^*Vq;>JFqk`DdJc8fmL)a>zA&=66|mXZ=U>ym zs*aBs$|#98)nGzJ)^aq%OyD2HAQcU;a(!W5q80R(9={r9WOKfIZF}j(f1$Jrh72o9 z?ED;daaMoAB^g5{p z>itiXv6w~g{LCbpk@Ze|WnJpv<>otf0UqalmM5R~F0D3(I_l-#QGwV(l^FX)(kM_T zb2ta$pjj3C>rPKBbq16gzZB#B&?8yi^(q?ZJ*R16riUOLEpy)M=|Y92A^>Nkw(Es`D|NN8;utYQ)=dOQG;h4L(F@tWu z%WA(`IjiPJpZrx(jK2qj4awh)Lj`QAdIVLhl<8@Wspc;-uj@6&zs0pZDgU+RTI|&! zfg!y6)^)i*?tS$!UpcIWKg#)urDL8vq1+D9dtK2Ing>=O%V|lo6IN;h8YTNt<{7eJ zAjZRB&utdSRl*q`CBP}8^Y>jY!|ay0ZHhh}s@OOPS936GjOdeKhWEY5utQ4;TnF zI0Hq0LL)5arYCCWz zw&|&9V@45~4bd6H6#-|tnkmmKgiV84CM_zZ3^Oo&3}jfNEO=wlz$PNI?Dt|KpZB`k z_!yT5gYYNOL9Cf*bBz zavlDyfN0o@VzXo=T;iw>f1@n6-)7IfKqoj6icXd}{o>gAEHiG&a+hTU)mdQ=jPlV*0GU#pQ&emXES z{h}L@Ey1WvURlRQ4`0}1L!HccXj#EaYMg0)?4#D_`wHjizO z_b6gNuf}J{6R7YOP*r{?cK~TP^}s#R772rSpkumch~%!YFfMvXbRDT*7Erx1J($KE zSW=le&yax|SccxAmg6}V3C zidht>tBUGmSlkWUfDS z>i5Sr9ir7__Z(*YzGk~J8TQ4g^=G*ASA%eQWG%^E+`EU8k`8`J3&b+$K+ZK;E~1*w z8KUFWrK~d5-BU?^sTy@Z&QA{*Ad*HEa|+C>@v~Y!XK)??$C8=+sKer79AJ!U;|lPi zBcMVCE?2c=P^#CdRVwel+Ux~xaK8js(_te?d&A+-^ro++8#Rlw3ILy5M8968)J2z@{hHIKA$EimYAH)qB4jy8OMR zUHALsJft~4aR0mLVp3=W2Hkg&%*Vy}rmx2X3Oq+@q`vsv^mlw@Esh-owQCqeiu8g; zR1tqR@+}+;GhOw^&}64uVVgTnQjR@05wYrfNlCqN?~{Q>BcZ*jU%bkEXh|r9Hxvzg zj?UgH^mVqd3~yPw<}jH)DSX$e7mLRwAAwju4v8@@wcK4fdrL%%KPba&Ve?z3?{Jcd zwSWDUa$^Bv&LwT*(4| zR(P$XHCe*7EUrj|sCHWg5+#wTxp(yXEhxpx>!~r#i189x3ar|VipI5+6q7>w-|0l6 z|B72xqO1_dx=zrVo@35mf8LLUCSLS#A%3ylpedsNXGD`H6x2|l8Dp7uL*H&$#Vk_I zXNbL~LAX#gs=vza#p7qlnL@A|WTHqhcCYMSFH|rdhG⪼0mN|Vz2CejK)`kn5K(< z`5$`>cO3=+EeEIbS3KYC8U?y?{{OyVRAic+rnT}CzQPTId^&_WX#5B43gseBnmxkH zG|G+DzVF<9Jr!U;sFKKU1%$DWp)&j@*vhy1I%i10xU{SrE!)OVo;CO z@ZdR*o@(`+2C@p?p0GbFOE-!}dLT5!ZH#VRbQ!;ZeWq3M_*LagEA@EmigGe1=0q)Z zj7h9xQ`VzBUZzC%f-0t4cWTr7N55lvN@;hY+-os5DrBFR^~41^H-#(N+10~a_;uoJ za_LZVXt}k;;)gGV%iAuq@S@KiDEE}W`Suo@a>z>=1Fw)e+YOqd+NsKgF0PZ**UL3O zn6^L38ZfIi_pw84TZyiHzL6jC@}MEH!#>p9G(~nqZo%uUb^C*php);JwY1WFM_0MF z%-`b$_aI6&2Re1Q>uqntt>8Rb*J~dr#d5c94>@Zg+b4P~7vRN?DY_r+)06hxf>Pjj)A;#-DiuEVek6nbYfpKyS>njiygMqZ?`vzhzd6 z{xq|S8to%n(~zl}UaR%-60aoOSjZjnn->#mE`z3}qDCIJb(n+^k+&SmhwQVndD_#J&Z@@a zZc&^YGjg?HX)nefDUC$#kAMwvK^?Qz((Y)*-PWE|;tQ zK_OIxYt0D5H%e9h8Rby%#M$6pDSZMmTBcT(=*ZHQ@i%81hTZnQ?58)^__FSQ?j=|#qG#G*L)5#=@MOgM)3--Y z)R^wER57e{;7r#&N_gjkYF@hP1OOVixEr&pREjwB_MCSzOX@%c{ofYujCI^u14fsN zxy&aShOEW1V!JSucIV$4w_l1+|IRkywc8FI&Gpo47FSGO=y_-?rFjk4N$K~s@n|IU zG4RG_;=df9*809}F~?Lvs_H##p7%rdX1n*B>3T5|(%l6~|4CEHqbW9>!a>8RDt@NE z$tN$;n}}b^Fr(#7zdB`>g7@ZmKmY0TmcN9p_c`y=4CqRH%q(ir#d30_*lA0!F@uCN z--IeHJdn1#$Ppr_h=r*?Vt*-!24i_)BcXAz7KJZp=v%jKwYYpq{`s`b!qKU*#j`Bv zzGG^)S4D@uI2PNUo%-Y`Ir-|Yo=64>{gy&4#&LLbC`7O`^f2(!>y)rGam+ZQo7V$h zo1e6DROuYX+O*X+ky<-fW4XC9nD+U`#mzCe)wdL(9+||Ijv^Wt?95XkmQYLvzKd^U zqb6WyOM7Q(w66lODAVndH|ka^x0>rRTV`wE4qrt}2RzM`kX`gtaF5WRgH;dX}bu)v3)=gPr>1qwJxDW*rW4kNovc zgpuOlm~W=hF>ry>TA%b9E=`oXi{I`p+Mb+}+nD~$zsOV*Ml@DK55krkl>-Ml-i`iD zzZ1j8)P&7)5Z>(ofvR-XP=$O2G!x6L)SgYO#xEXz-JIaQT1fM*iaViWI=uIS{iP2@ zR}=%Rq(g42V(2BCnm0DmHkRdn_%V!drSJ?(7Br789aWycC-pAoODfy`K-DBoW>`sB zX=?{0T?L~A&lUf$^mTR8dT)GiY=b9XC~el{Y-4v%YoVQ_pskGyV=)~iRl4XW6HO9m{0Y__ zQ;aua^=y8!W!})PeCr`@rkyeTlA!E!#0+kuXopxAutBr486@O)W^>dY{}Y!~DdfM^ zC%gF_>xK8>AtC8?#I1>e4~B1lSlPu3)ormDJ3rBQGU$y5)WTD%bj!~}4rUzzi4C=Q zD?V}NCYjIA&8cOYohrgU%yGtn-%C2a734=dzv{o4@SdG}GjKqqY_ZupaQNAp<)s<& zHN%yJ^L^=#DagE+WjjC9?6NgQi~PoD0&cBV*8XRFnHD0dueNM|MD_gjluBj5ZFKF_ zE1M%{0^TnAPyc7UJByhV84Ov{&M+Ilz3Uyo?2F_g@iDVp)8^beMprE*PnUvi5kTEA zNX!x`6p5i)9vEliF>1DS2+a66xCSNDTmi`mwH90eOx$%E>!$#( 
zC@0|+yb$YhlHr~q1wIaNBh7GS9L3t%>~*h64B6@e6E@qf}2U>&rIQ(tJrIm*(ywYpG7?PYgF8^A@es@ zu~R$#0XugO?KN9)MhPH?zmV=<^(ke(!kyo%$%xgeSE0lbL``%-$W?S@*krm@@E&~8 zIvf!V5MGu2h5#~Q_2A$CQY=6`lz5^h^q|Oty-kXA$$y-UU-Yh>iHaG!R3msWL;j(H z)gQ2^-wQA)eKIQ2JWKd+n>QA4;!Y)iM+NTnCpzPlJZgauBK75e6Kw_{JfKqq{;^UO zj2j6H28?@{CMn82hLh0nee}ez5JWV+kb9`&mN;Mo-M^5U0aXGUmDJkCaB_OnEYbbb zG`c%E^laeJE;PGMjXefXKa<~~*_vG#kvciP1Ir^(sIHWP8M_m(M0jQCvY`%e&i`JJ z8tw|PGW-`RR3yOBml%J8d@&)bMFR|iiyxNz-wCUNLhI6yvnpg*C*9u=@Ol10FF_V6 z>IfkYdm-bw3Q<J=(1g)$1c>$keN&R5)g*@`<~e@1xUQTT6C^3&6yED7kS z8W_a-QLFeel2T#*L5|nRJwrpqNv01oeMRMQ5)S#5M=YQuN<32>0M2J$M?1q?4nnNK ziDWlI!HWbtnX#m^3Saqpjs^5Vic_>2wFQ(!j*}R=XGo7Hy-+0*ZH^{E;m2lelJ(pU zV4>jDvKzy2dFo|NFcQJ1<}3>|0~S0(8BvF!O{ZgDAxOY|K}S>oL)FU?De0`Y07E`z z5Sf4k%n4M<7!-VH>14=vk%*-Efp}6`*QRqlg$>aSfL5El$Lb`=8gjM(^tT=+ro#s# zR!2xORs5te)L#2&T0nb_710JtX5<~)EeRNvJ;EU9H*h<9IP~mCmQ?dWnzZNm?8{$W zt!+I#!~Pdg_NWxL=EC<&%Wdttguns`0|WbJqhxR#Y~hz?nj!Io=zYeeT|WdmqmS21 zSj9S5qH!604Z?uHvknZ2A%Wd$Nx#5uGexg_FDYitcO&8ve%Q+>Ij2!5-1Pv&`gN;d z6KO60MYB%{#lBnUYwI}(!l>|ATEN*yTS@IlQCG>ZDl=f#VPO;dyN(K`TA*Za(p#N= zO*)9WhNA;cB%>KcB9{;Rzhl&0y?`+}YqPGt?VFOJnMy^&iS}d($w4dSh=~Izw9>gW z;3{O!S2G{JgK5aTSR zHy%hNA>R8EqXxu+_hr0?YU`cnrtt^hXF${fA?Wu4kh50#jQi?O01n5%`7+G*wXfi> zs|Xd$zH}6*hi9{&?E{J=DRIB{8KV3?O#b3Zahg0KXkYbz<$;wP$OPC|96vbJ=nu$? zKide|qJsOH6x3{q{7h8$`Uf~8n=EB*r!9$tv07>o9QH0rGyF}s`8@FwL@djM$ zJnL<`_6tUea$LI|4WaZ(gE{>Pm8X;_oE#y6duAiT&OF`$x0$8*6T1OU#$+e;Z6$Tm zwFx_#Zd?<6l7uhLS=O%P{sY9(qtOButJOb;ns`KA48{9q9a_5R_%ycdZ2R=SIn$K)wn9!6&G`Pc^YW zr0O)7!j39G7R5|$eet{OZ*vw5!HVPDCDYWoGdL^56HM;9omZvSLb?`poFSs9Ul?97 zJ`U;UE_bV(!|Z<-;R7CSqx*KvYiD~SZTlm($T|D#g3X|_*bTGOZp|JFH!sR>L?}zt zuSjUl#ts0CSnpQCfZLQdnDyu^cy#P@8OT(#WRSDo49p-J2?74{kHhsgm$KjxcapDiN6*c?O_LTYHpl zhS^yFZrb_W@m;=g`@bYmRRsaUrcqmV^GDB74$Ri-&?rs7YN&aE&&K4`&e~*qGNX`L zhQRBequ?%NJ3HjMy}0n9qts)teKnr$xXL!l6IwgAR`M^w1<+bZHf~u5UE~Ug^hB2P;%k;&(c&9 z;S#C1t4-}FwU>@}-5hmo-;OghqJBRMq{V;-fE>=e+#3FOU@z#9BE!GYrk=D(==}2C z%-Iih^X7WT=$6CmbR;QA+vwWa_w0zQH&amjUCbKkv=^|V=z%oVc;=|S)tKNpatPus z(4_9leMCP1Z3fsD4t%TM-xzMkKm1eh17xd!#%@2DXdR+s1K*^b;n?}nY|JlOyD?u= z&0jr6v8y-R#;oDn`WTA&t(2{7n7PeONLMQK3jq;*62zzXi6InReRR0s+mj1BV?u9N zhaB3}@r~f_%Mc8`OLbP#%mO;haNT z=;Fh2xX~q~ss>8B>ZcEJ-U?UUJ14)j0qvsDW|{ps79Y4N`so8JiN?>QrmsFGHQCIy zcVIU!0b3YWujo}GQaP{#BZaB@&i?f>^9oQIwPia$Vy~MOl-O|u!PKt<=xGQ!o&6bb zt&Nw^YyYm3Kuv*cF{r{;IVRq@_lO*sJ-Wb&~ zOLD93-3_9^#6q)>^jl2S{`1b2J-4A9wUIB9pCehiJHOa2e|=bN#+&@_zKp%RjKpYc zK-tQkc6`%R#gK`UhYZu*^%5p2`IIUZds(ANvfh|MzZO97Ll79Xx5aqrY~m$9Cm|Fv zyLf<_4=r(d$g8dCN^!E)yr2>ZocUJqKFKh-eX`n!&%bgQ5^ITwd2}5nEy6!M=4pxpMr=_a94lR$ukFNMlDY(`v^$I_dGD zw{#${5^6#tNw5z%fKMiHk;JgLNp&D)wMro9^LvZ4EF)juM0^+~x!fa?tKHRzA_}*Eb!fpI=V0gU7IgIiM2+m<&Vz$BCcR^H@l(pz7shkjbM0B# z&>UHSYk<`i?%Qm%3xz;_c|!sJPohFGPxLM6vK8^W1z9)1H%B&&tRK8XrGp z^jt}J2lFJ2bOfmJ1zo~pwtZ&9ym_~EW|p28V;0X~mC z93@kH+yU}(lu}`_4*DjKBvoKI6@uF){^iMlC&Sm8TjVSHu4e3GzgC;ZYn6E!gC z`*B-p^XV5aKY7>SBxa|E`1$HuF}?sysp8a1!)l?sn+;`$cUZKP zSBe z-d^4!PQX5MS4wo?q5}k?P36W%%)}R4t5+rvItjgNR`&>| zvWEuPD%^$M^6+PA=RWQ>jB`IKCp3H#-poWpOKz;+giFT%p(6m|+)GpgF7MP3`}H00 zHyl3JY)PWDrlXgWsxON}MnRQzz)1xoJ6U?1yJ*$-cu2CpQALLG(N|N}NfMoxZhSIw^$S;UBWM<$ zsjX-6r`lXWpeo0i6lT|{(Wo$f<(gB8v#>oVLI&U<8p>(ASlpccl+p&2Azzg3+8dZTGX)k*p7*bQ`WX#(y=fex)-)n@kog%ybus`zw$< zLtd@7U_2{s*57GOf$g{5-6{%bgoE~31V%c@zyA0iL;Mk*MsW7EF*ouyam$&tiHA#I zV}@JwSD~oRb0K13Z2V9DfHhhL^H*DA4mjPn<>{5w!{1vQo~4|Up18O5(?iU;@{3{Jtp!Z5Rsj6(qSMk#{fFb)X{AAbbcUm(=2F z+!Xr1!c5Bf?b|b7HsE2e#6b`>6JS6UPbOR=w0CYwu_$GM7+T@f+67-gO{F*ozE8Pz zx7d*VVDY&uk~FQDTOwljbQ@q4L~v+D~)HP2HB(mQjJwV`T{i$EXZEYkN_t zx!V$1^SlyN^2zR+#Wd^VUFtww4lWbXS9uIC;}Qh3YH< 
z>>wuMwpRg!%e+#!(4;15Ur$^2^}6HRM`OWp-C7;xhoml zqBjd-x2{mi%tKCzKF!PPp4tFcM|H2)ME`-XKo_0+O6B#mpN$Jvtp?gLFFDhKP=2O_ z-+_mSFtm-p`;UspNv6n$R;@Rxcla#z+Tp0!bho!nGbYHqb}51OIs)s3iiJr zV#Bw?)G!weqzqlR1ce?SjSz7mUD37sn(MpW6&~feELZi_hEYHMc_AX>##MTMPnXTG3FF+_~#{#&WBqP~W0^LZff3sf7M&4V;~v^#fai zd$Q2SsfgR*<+gQagj1ygreR30Xfxow6g7K zMZ4f+HpkAA$w=Cxx}SN`LID$@i#3I>|K$i%M!$qFvmHGt_B`{I;S%Dj*|&BzD>J34 zbqc*p>>HtW2vwsID*;;9Pezr4E2=dun3D!Knx@}}4EKuBO2fqYIIHnJ2a?2HEuEdw z>bZ@HiY8Q8Rc}FpzJ#>to9l@SuO}ulNPO7Gf>s%x0{e!umY!FZm)>gY{0eg zL()gAJrLk{16G6v!Q&{o(C`_RW~QTecw;i-&OQq5Po&dByfW>@CF|OH+*7@l;%LZo zH-GM_BJHVYVu?rV#`jUrohnA=j~&mEh)IZjj=;T^%yB;>zf3vh6zte&y4-0q9ONjy zOr_T4JcC(M+e+iYEsyt1?Xqub70l>^E~8_bAYs#viY^f;n%mG`2Q^YOUu^ z|Mw8a3C+#p6#IKYdZgZG6pp2?Tt05FS3olUwsAyf3-5XBsrrnFiR_^r_dnz+wWR;n zVje8aKcYu%Z_V66dr3BWh^Ys0_B%yR=na&yGKI&?FXb&GY2mKou{{n4wjpL>i5hy{ zq4Cu{9rmU820MoQnWrkDap!ej_I6|w@jv9AExX1%OF$l;zIlL{OXna=9 zu)u<1VNxLWttJO<(d~z~%14Vlj&~+*MYFOBjA2_Qx!@FPRn}$x8fB-idObh zRDi^1yU_^v;n=``2)28aPsn}TjP{N1W)(jnE6g5$*4JGdaRKG+h!jcXZS;?1|JbHjk4Nd1M9OWBk4JYGg;E3KEPRrZNC|#%$J`3Q3 z#gBmc#Vz8FM(&lyuh#IcENshu%Hl707niiQRVe}#T&t!lo+t^dMm zQ5e0)F3h%BLxU{Goz>t%F$6v!6!fq2y*n=JYXff0xas!B&I+Kd;acnWbgqXb>RpiK zcF46z(BS8GYIOujHQe;H;b_OXxtz; zcaW)uIG{ehlVfcV#Xk}!0W$CK`s^t}!M^!^Itc|X)6~Z7L^eG2g%(o1x ze5t22>IToHKub)sK%#R&IR{EmXRg&?R=yG$HX}Eof=>+aEjq248rf%iZZAB&#O9II z;tM6YDI`7nUsoKX!9)TT=ZRcBj%%*2241ke_=f;q6IB6b?S)m-ZW}f_!q$jMI4Oyd zY>nx_IqV%)+de_=;8wtymI&RjjX3@}(*hdX2D5bqwsd3U_S_wUJ0)!mm{#FNSQt@$ zKt_>qHN|L3$7gNWt5oRuLc>quOrPW+e0ACvkf8ccEN#Nq#t)F3_$X8<(HaUoVHn{K zuDg*)kNcJ}lFb?K^Z^^n3rRl=N&g@WK{K|_^Hp0g0?DG(AF$TvBuAIL57tTNL6q(! z6osR-U86I79jcpQ#%2Mv?I5`+>mOhZd;^u-g#eJdgr#*?ed`nbA*D2LzERiutnEQh z;aqNUa?<*Q?RmE!{nYGXsmEun z^x8qN958E)v4)@FhuZl9lGL)yBE%9HCO5kw?eW-K4#!!W{>B_6NOD2k58-Rh-~=vp z5D=2soAtXj$EYUmrwRc!HNB&5x-_kv127r-#*^uaLcNojmQ3HGbc`{pbMsSXE$1;a^k>A%lk>_9dU zt7BL*>!&%7lZj6~0}z>waMh*&|Ll zADz;~s5YAhG9M-jLDMyp7?54^H4_f3{o>e*Uqo$8)Gobx1f((T=hw9Cvag4Df6=^| zHoN5P!9Ph^xCn{64LR^9XUmsZBEcT9Wmq~ke{{L;z2_&U>^hT7O4fHy%@fC6wz`f#AiN_35(gxOwQ9=Q8)`m<6@6fZL!mT9&#Gk#{=;>4St`80G*8<~Vrx?3>H#O00t-C~;p z9ho-EiAOCogKpBF<8&Zdz^~pMYcjNxW-7nrc}-~=r1AcHj2+Bqx+v1A*Br`oOIo5@tLyd_?O2a%j_xP zpCfO}QRKzolH;DAO+GjAQ(c~mYU0!SY>;<`X-`N#4__3|BO=62dU&rd?b&_{3)@?3 za$9M^=0jSxviD!0+Rle`PhZqrL|uPQJGheP!jK6{`g zPC1GpTxat>8An^+6xBR$LgW9!WVXBQk+RAl)!iUUoUiB!8C}el^4F@V}s%U2>4s#sfV+4yQTFQvTtS z;GXs|InNq%CRY{hbtf=L@j}W}OYjiwgxM|MmH9WH{Hmfvwj?=2b$oqArmGabsb(MX ze@*S5KHI@y@RD9JlVSZLvx2E`>?^V3J)&vKtaQR^Klt-*SjP{Y3aO`qyC%kv>~Ri7Gh|0r|TMtSTv$cjOCi@5HEiytlmr4>W)@Exli#yiXFlvfEU5x1pYDyGW? zit-b=!#;H!tzxWWzFYTs!G}ZesPa@@Re}d@J+tnO#M!U9=Z0&X-i9q145f{nJG0p+ zA{cC)LlK&JQ}Y$^hcLcQ+Biecw2I9|otGRabZp_<#JcS@*PYR{y6-_kkK{H#n52Or zG#2d)26OoT=#XC6W#B8{*0Bn$0~PWIYL#D)Gr3ne*a%kJ6N2YCOOr{jK>aO;B^DyS zQ|~msDjeEU;3!MY+)-V>MK3eiO5`ntyJj96Df@8j{buDO&Lnw)SX$TiOu$c&-6JN= zeD6Qyv?N$0^c<^jkk{{0GjQosCzKx(*|B`_7%b!V?3IA;o$>geX|A>A$iI?Q_TOXb zH8&3t9FeS*dnzLng}X&#vyHuSQu7QkNqZf%MHGq-$ba#g^%+Sb0pvWrQ0}8DnAIzj zauB-CUE5*?I;d0{d89fF!Rreo<&!X$gx2o#%)Z&E-xNFWLhbvAA^-l~Z|&33yr&zs zLII^*{HUerd)707?~IDr=mNMVNiQ?Sg$Uw7nu{O*!rP5P7Zdm38j$*8stF+8$86Pk zrIt%XBQXFiS6{mYK|H2F8y>N-BLkZ)f;V7!ZD*0ud}FR`f`rxze~6PYKmj>4Blg>B zf)RLl;24xn=tT@zxGu^d2Q2OFTKrLmHWFHDl_21o4*^eqkl+^KVtx-540tiGK=$2T zM*WKs%q_NmeK?v>#{LWSklzD&v+(G?d(^2?b7oFzrN-Kp7s_gvrM<|@?|RbHRn0eb zkT+8DW`gsGTCaw!buNItzB2u`hA|E4s#DV33ewYefnIBsT|F~lW%8Zx{PN9D zH6sgx?4OF0REIOCD1Tt9Bp7o+y3z?BiuXBpzsauqSQDO3eh}eT?KJH`KfYWdB^Jo? 
z=gj=W0GpoSSi=0O>$QXISTwQ`ncSGk3uIc=?E_X%mClNDSe}I1A20K|C+W#9Pg53} z(L&d7)DWalw_~->v;_8Ng!Uf^E$hFfQ+ppy0hGgRLDpSJfC%lH2i&TamuH-y2yhDz z2{54w_CK(~!+TBDc`bQ7e|oBoltO*UWLE!Oe2YvjNyZm7TvD-FyjU3DxmZ0t9tGh9pzjTTC+r{t zz(>0__VeJ)l;rev5AiTbbRWy8J9u@j*>i>9-_M)RsIKOF0 zB}(+J$L3qu7T!LucXi1b2MJ-+QG5BXd`F2ed}!dIGb+}m%F=Z`HowZWNOkd$8Xw5} z^`r||ow*A19ww_sOhAyk%S3?W>h5V3vCrsY@S;i8J)9msVGh%JJ9;?kJb*J6v=u7N zhANt07%7t0**|$x2TnTRP>`EFgHx)!2{(U!!wc_&nRGy;Elts!K|K9yHNsuPC;J9q zjTs>GV74xa%}Xw!j{Jb$@$HR@hJB)c{z}S#YS0&JICX4smI=Pr$6C{b;vkBT!9POZ z8t_XGfPj$h6&%Q##9%j1Q&7ZE(KUe;1eG}zt&o`{d`fzdLn1J2mrET$A{tFuSs0~+QG zI8k4^93W+a|MbVbrFr6yS};!f>y}eSk}}oPQHwxo@Sdp?xa0g8Z91)(NWkqL58yXZ>Zr#6xuf5(+kq9(nVS2*9`<1m&`iid@_Y8-@Q-( zZ(y^4gU`PjBu|l+*-wRvg5I(8WMs`p{Km~UR8Pm0zJ9+zTxAQmY$P=8((TW?d+qs0 z(_IiJnq8DHWix0(N@SUN#Vt6j*~mtEz5Osg`Iw1X$5G`sZ%%I%;oXxF03Pr=uIBj! zOJefIZQoaBsT@M&{(&EdXewR)pW?1O9?G_DSC(X}>_W1nkTtuaQi?{N#=b<1E&IMq zmO_?fuZ)mHG?o~$8?x_P3{A{f%FfuDknOt$PtWsy@B6&ppWpn&?>F~7=Y3u0bzbLn z9>;my8Fc3PmBW~7cy*0#R=y}=kmlXw-y)38W`cz zcJD$ui%L1F5r_`=&3sIAqd6M?UuQ&c0c+of{H$J1!9+>%^pAh?U4TNC{2xyCU)>vW z2dG$m36NGK_=^Y79eU5OwdN3h(2a_8Sv=XkSG$1w%x(m>~*f zEu~3dEQi@s67!8&@e-}OS&vOBc5gZMJMe>2X60galHZswp9CPkb@j4@5G=B_@3SS6 z1%oCs2T+XEKN{p3jXfAaj;phD0fA;RbxX<=ORCt^Ik%u9=3mfE=k*V+3Fy$JCGw~ zw~mEiskHVdk7U}S3{_nE0N(_|Kmkwc%R=2X7d~6mb)nif!!Vfx&sU9s*Qee`M2eq1 zpr3>w&kzB7W^?<)Zn8{(&5uc5kw}f-`bMkET5IW>l`h|lma@9AO=-G$PtfowX?Tav za?Wi}P%@6@;Lf;YG`}* zX;VC7j?>-gjzEqjQ!z2kA|2!&cSHWU>rS@hQZxSz5cXgFWv`$zQI!FqSSR=TNke=> z3SKMBWi&)&LIj)Yt|?ID?&UnEsf@R_DC0cO4YY_dnz5w8v8NaFc+%{n8@&!IDiRxE4ibU+Ic z5Kd_32u%PYBx%P<+9N3g9AfEZdB@N4fjv56r7t>QBX3c&&U72onozUSxK8ITCm={l z;kjTeWNxboItN zt&4f0-4bdb>wG2vId;HlFyoIT%1mleR$bLtBMVp#9onA46DR&S%l}E^sxBB0pqg70 zO@y!A&=)PMQmOO3BK@`|u1!qmmfJE;{o+9P*LxZjC?)L-6El!#o2Yw>3%IbUx4S4Bm9VOy zTPdtxjTJaPtq(!r??=_Mz|eTZalc!$FH}Kneoq`SUyx15pnD-a5P^(9b|#yj7reHq)e@ov{F z=W=w^YT}7(j4kqJvH!Y^nOU(_)Ti^D9xylhxP-1HvT|vvHY1^sua#^ZQ4kiBpY^X^ z5PaOi7zbwD^eDYxmQ1B!X}l%vd0R^xH6tUt#$Q)BdJx$h zhTbpX-53WKo##Cgb;ZD0#Jh6UHLT`vHFTCa2}&| z(fwO6g*jh>;YkCgI4!3io+lsV;i)B6o;@L*tU&Vh6zt|wX6hh?01%-$DT-#YWyYcK zjhsOjSf_{uoYj$d{f?XIy|;|N!~sHzVsV466Gb9V(>2ig2m$ruzg95-Qc)x{OpCUj zF}zlxg;s(Eq%oGzEkEvQh#z$5qStI@p8&{A2Q{!VrzmQ{vyFP zH}NYS!Yryk_q(`m7D5InGbCmiOSvl9q@ttThae;2z+A$w9&bI611 zqttCO?UTfKxV%8CHc!L!?;G+EzlP1cYHm6fv`PNJ<14x)ans#Mfkb?Kr$6s7#^&J5 zSzH;cZhWOk1}X*wpM`?i{;HT~m~ik&9h|Owb1d_Fnss-qqo#|cO57=MA6?EXh!AZD zRsYSbCQz!{S2r~AT?#$wB>H*f(;iF}`NGXpnp?$PxuS;SrV>rSY_bEVW53~-950jf z=44h=wc6MgDd7npWrs2KP200$6q={8&*9b)IBtGrM1#|V-LcVDyGG1HRRSxWk*+um zZt7yWZq9<=K=O4wWkhbl5mMk5!{sLz8Su23+7Db67Nup;MXzGl^z^d2*J8U&XV_~r z1+u!)&Sd8Atk{>%5@!jbasmdvkk3FiF%TdlC;bwsX~AhQvZ-v89d=Gk5uJ77+o#T> zZP8cb39X=_>E}Gxgd9~^CGi^ZRoebdE4s+9$4iSeA9l7AyX6K9Qo4>y3Su3Zqnr1svS3Q^e+SN8c3*b zE*a^KU@-37;3QMNosBI#M_96(&r>&}=(gXZ;>`Z?&SZfDNf$HJQX<^7Z_(%d+SCy} z+~jJNwLsd7N;<@vcW#OOu2_Z%cX(w4bEpfXa33S)2e7j_{2~Qg)AxQ33*^~h{s})U zeRI0F*HxV{W=D9-lD75O`P;=qwsOYC$PXo>tuw0|%a_V2eeCJId_8LPG)!?_g(DL% z_j1qK)3D;MGZ%kdE!bWXdVW#w`><@*+L%WO(?FMVm@WgHb%%GMm|Cx9OMKX%EJWZwU@95C(|9(T@7hM zPox&Vk{d8mt*xbc?>bmcRSXgnNdwVuoiceP`Eip49|#h0vp;I6nw0ry&XEs>7FAjf zn6j?}F|qdA4Y`~*{Yk?YGw=)`ujfaeFn>DJ$+QnE!CYt2e}Ka$N`5PKuN6gYPKY`$&Tg+I!z1_v&Cw){7}x~HoD02gO2j@l>I8Q zRW^2mX6((}Kmp7oN#4+XrAW~JH1vAbZO3O(Zf!@Z~(}gQzmRpEkE%0Fx`_e95_BAa3ObqysgF0 z5H{E)d|poJve%cVT_Jf@`2!>6!u8Q)eX;sZKl*y8T>cnx)phyqASo;|BS0XSFD#R@ z3@N2q*CA!aH+b!D+76LJ^9ctA0oe;a&=btmn}GZz-jsnqQ9`foVgtd@iCnsko|-Kw z6KWIk4*jf{t0nhzk^T2|86u$LI!Z?zvU9ovw826EFw-2KLh_H)C?LtfnUq z(=<>H@$<{)1$vzqNz3OqQ%He(*+)R$oTth#!Qm*H`2n8+=tiqX0kw{|Vp@;Qo~fbR zbLz)vu#NZSeL2u5MD#|A}RV 
zX#L~Z6C5qK&_}#jIpc-2cp5B1m-Y!E^kbyW2rBkouz^w%n*RwrW=5Mfl?0Z@cb|T- zcMl*idUi4z5mVVEDbyeI<%;Soa~XIV?L=xG0mWbv*+CrQNi_mC*ciA&@J~huVbduY zyQC<}_{m0I@n*bL2J7CH$w)eJzgIItu?vOu2L zgxJvImx+mm=X|-heq6);paK`2J%9x2u?Cg&lUOQo%SnnFQ5$99S zUfc8tR%6fi%~Gk%u}}+N;b7auA9b3iR)w7QS|N9SwI1&wSia5*%IVvlBFO?4=0+lA zW2YNUFac!>E?(zRrPbA5b4xjWS#t(w?gUW2BFzuZ)vI^r@{{UIUwjRdYM+rj1*y<* zpRFFOoBi2kOM+_-->pXVnMMV}M?TiOh(L@pESWW(<|N1d4Et=?rInL!WZ!$!PzcbR zspi$5B)v}yzR+c2LncB(wdfd$;2&>rF9%o`&@6gZ)Ve!v9Hk8X%%dtv3VJKiuUEJqrZlM?Qs zW*sD+)ZQvFx1aU6xed*vy}&!<-(5M3-`@hHuc_2cWSTf#%_6duSHD*Bdb3`*zX{Nu zOGHUR+l-Y5=DJ<H26-_*CX=& z?m{)er#GZCl)XkflE`{yHflJHg6q`8l?P~?R}~vnUEQ+E{o{Pb>)O!Vt>73CkzV|~ z)t1s&m*bR*NVR>QJA2NyA(1iQNQo=qq*SM{6~1G%>gU+cx+J$bUA%8OZ7yEv%0*&R zScFfH?AYVUMk8o`Ro!(cO9E`aG#W`c$bI8I)wT=6-(4HXgaiIMZ1JfdfOtddzwcVkhH&^{_#2CD%7usUhEoZr4q`2io@qB{+ z%^{8tS6w05kQ-KEwc_T#ofZ)bS7~vkIYDQG8@rDJg{arfi%PnJQcfL*;?9tQ3oRa*25yv(+&I*!q;LXO5w9CYXOUq7bxP6s z-{JP%Ba)8mzYq~jkdb^`t@^9WuT+Vc;#@Yq~v zZ4Vcax!IEzDI78_IvjH-@h%d2a1v<+xj9L_o1_7eVKRmo8S@(S(8hB9<5#dfzT!{Z zE$pH?-fIg2F$4j?uCZXg_FF>&nr|NiFRGRK5BFCDQq6DN7L`8^20Y3KQTbN;mA>9L zUVgH=;N3)Pbu||V{_X*?QNAoY{qly?}<$g_%UY+s@@T&!_ zXxmrYaHCat0zMg1btylfF@(xQ9WC8(|dQ?Tfc7;iI{cVuhik6_t80IVr`*t%_~Im4@#6(2At?^zl-=9xv#uG z9bM#n@=7b4$I6I~AsgBo!%3y4EtoYer8^LvS!Oe|{K=_|77>2c+t7$(>1UCak1DZW zv-|b0EPQ`(xwyhDSGT^ID2jd2?QnYg>r%lxlSr}Sju@Pev&-(-1tv6L6(P3n+Z%^oOR^`HQdZ(Q~ExLXPOo^6e+Nd221r;Yf8f zZ?5dQ@noj*Ck@~sCConZ7t=0%FT`#SVYu$zkWVy*WIasD|D1)HSXy|AnN|4p>(oXQ zL1wTwqVPgQ==jpgLR3%$8RdZCrtwEFvLE83gs#yCSZr=mjCDi@F_km#>S+^x$7;u_ zn|EvK`!qe2?ukT9 zH!P)xf*yg7&HD2{UdJ;4pPSo^wA>zOdr3S!=GNn%>(L6k3Y!qx2R1&qK2w-M}FF zry1F9aSagt@K}U|8NcAYxs>iHOwNSQ+SFum*bnv^KLa<^R{FKKR;QgNlf+|Tb+@1j zi~?tx|J#y?(!T;BP4s*=0X?1X(6Tq6*aH(9`FD++Ce?Pu* uCMMxplu>>d__WlV9uVdQpWT3+L#|9QWlatdsU#6!wHtS?f4Fu(@P7cv2xV;m literal 0 HcmV?d00001 diff --git a/example/91_tile_program/fmha/script/benchmark.sh b/example/91_tile_program/fmha/script/benchmark.sh index 61daae483..a8f3a8202 100644 --- a/example/91_tile_program/fmha/script/benchmark.sh +++ b/example/91_tile_program/fmha/script/benchmark.sh @@ -6,14 +6,15 @@ VALID=0 for prec in "fp16" "bf16" ; do for perm in 0 1 ; do -for hdim in 128 64 256 ; do +for hdim in 64 128 256 ; do -$EXE -prec=$prec -b=32 -h=16 -d=$hdim -s=512 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=16 -h=16 -d=$hdim -s=1024 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=8 -h=16 -d=$hdim -s=2048 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=4 -h=16 -d=$hdim -s=4096 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=2 -h=16 -d=$hdim -s=8192 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 -$EXE -prec=$prec -b=1 -h=16 -d=$hdim -s=16384 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -v=$VALID ; sleep 3 done done diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp 
b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp index 1a379f32e..56f2ba9d9 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp @@ -6,6 +6,7 @@ #include "ck/ck.hpp" #include "ck/utility/get_id.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/data_type.hpp" namespace ck { namespace tile_program { @@ -21,7 +22,6 @@ template ; using Traits = remove_cvref_t; - static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0, - "kBlockSize should be divisible by get_warp_size()"); - - static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; // attributes from traits @@ -55,6 +52,13 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasBias = Traits::kHasBias; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr bool kIsFp8 = + (is_same_v || is_same_v)&&( + is_same_v || + is_same_v)&&(is_same_v || + is_same_v)&&is_same_v && + is_same_v; }; } // namespace block diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 6c294d856..0458d07ec 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -43,6 +43,7 @@ struct BlockFmhaPipelineQRKSVS using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 17c35a4c5..e98c19649 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -43,6 +43,7 @@ struct BlockFmhaPipelineQRKSVSAsync using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp new file mode 100644 index 000000000..266b45198 --- /dev/null +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -0,0 +1,458 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
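+//
+// Descriptive summary (inferred from the code below): this is the FP8 variant of the
+// QR/KS/VS forward-attention pipeline. Q is loaded once into registers (kQLoadOnce),
+// while K and V tiles are staged through LDS. The extra descale_qk and descale_sv
+// arguments rescale the fp8 GEMM results (descale_qk folded into the softmax scale,
+// descale_sv applied to the final O normalization); LSE storage is not supported here.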
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#include "ck/tile_program/tile/tile_distribution.hpp" +#include "ck/tile_program/tile/load_tile.hpp" +#include "ck/tile_program/tile/store_tile.hpp" +#include "ck/tile_program/tile/tile_elementwise.hpp" +#include "ck/tile_program/tile/tile_gemm_shape.hpp" +#include "ck/tile_program/tile/slice_tile.hpp" +#include "ck/tile_program/warp_tile/warp_gemm.hpp" +#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck/tile_program/block_tile/block_reduce.hpp" +#include "ck/tile_program/tile/shuffle_distributed_tensor.hpp" + +namespace ck { +namespace tile_program { +namespace block { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVSFp8 +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static constexpr bool kIsFp8 = Problem::kIsFp8; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; + static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; + static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + __host__ __device__ static constexpr ck::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + __host__ __device__ auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported + FmhaMask mask, + float scale, + float descale_qk, + float descale_sv, + void* smem_ptr) const + { + static_assert( + is_same_v> && + is_same_v> && + is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kN0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK0 == KDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kN1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] && + kK1 == VDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<0>{}] 
&& + kN0 == BiasDramBlockWindowTmp{}.GetWindowLengths()[Number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(Number{}, Number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().GetLengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.GetBottomTensorView(), + q_dram_block_window_tmp.GetWindowLengths(), + q_dram_block_window_tmp.GetWindowOrigin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, NumericLimits::Lowest()); + clear_tile(l); + + const auto q_origin = q_dram_window.GetWindowOrigin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.At(Number<0>{}), Number{}, Number{}); + + const auto num_total_loop = math::integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.GetBottomTensorView(), + k_dram_block_window_tmp.GetWindowLengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.GetWindowOrigin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.GetBottomTensorView(), + bias_dram_block_window_tmp.GetWindowLengths(), + {bias_origin.At(Number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.GetBottomTensorView(), + v_dram_block_window_tmp.GetWindowLengths(), + {0, seqlen_k_start}, // TODO: hdim split? 
+ Policy::template MakeVDramTileDistribution()); + + // auto q_tile = tile_elementwise_in(q_element_func, q); + auto q_tile = q; + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + + scale = scale * descale_qk; + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.GetBottomTensorView(), + k_dram_block_window.GetWindowLengths(), + k_dram_block_window.GetWindowOrigin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // Initialize C + store_tile(k_lds_window, k_block_tile); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(!is_null_tile_window(bias_dram_window)) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(!is_null_tile_window(bias_dram_window)) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, i_k0 * kK0>{}, + Sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile(k_lds_window, + k_block_tile); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 2) * kK0>{}, + Sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, k_block_tile); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + Sequence<0, (k0_loops - 1) * kK0>{}, + Sequence{}), + k_lds_window); + } + + // STAGE 2, scale, add bias, mask, softmax + if constexpr(is_null_tile_window(bias_dram_window)) + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } + else + { + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_FMHA_FWD_FAST_EXP2 + x = scale * x + type_convert((y)); +#else + x = scale * x + + math::log2e_v * type_convert((y)); +#endif + }, + s_acc, + bias_tile); + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kN0K1NeedPadding || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.GetWindowOrigin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), + k_origin.At(Number<0>{}), + Number{}, + Number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -NumericLimits::Infinity(), [&](auto tile_idx) { + const auto row = q_origin.At(Number<0>{}) + tile_idx.At(Number<0>{}); + const auto col = k_origin.At(Number<0>{}) + tile_idx.At(Number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + Sequence<1>{}, + f_max, + NumericLimits::Lowest()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = 
make_static_distributed_tensor( + s.GetTileDistribution()); // Pcompute{j} + + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); + sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + auto row_max = scale * m[i_idx]; +#endif + sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_FMHA_FWD_FAST_EXP2 + if constexpr(is_null_tile_window(bias_dram_window)) + { + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + } + else + { + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - m[i_idx]); + } +#else + p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, Sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(is_null_tile_window(bias_dram_window)) + { + auto row_max = scale * m[i_idx]; + return math::exp2(scale * m_old[i_idx] - row_max); + } + else + { + return math::exp2(m_old[i_idx] - m[i_idx]); + } + }(); +#else + const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v_prefetch); + store_tile(v_lds_window, + v_shuffle_tmp); // store the prefetch + } + else + { + store_tile(v_lds_window, + v_prefetch); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = cast_tile(p_compute); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, Sequence<0, i_k1 * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(ck::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_distributed_tensor(v_shuffle_tmp, v); + store_tile(v_lds_window, v_shuffle_tmp); + } + else + { + store_tile(v_lds_window, v); + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, Sequence<0, (k1_loops - 1) * kK1>{}, Sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // finally, O + constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans(); + + sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 
0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + tmp = tmp * descale_sv; + sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } +}; + +} // namespace block +} // namespace tile_program +} // namespace ck diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 73c466c4c..4bb59d79f 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -94,6 +94,17 @@ struct BlockFmhaPipelineQXCustomPolicy { return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; } + else if constexpr(Problem::kIsFp8) + { + constexpr index_t swizzle_factor = 4; // TODO: hard coded here + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::QDataType, + typename Problem::KDataType>, + 2, + swizzle_factor>>{}; + } }(); using BlockGemmPolicy = @@ -201,6 +212,17 @@ struct BlockFmhaPipelineQXCustomPolicy { return warp::WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; } + else if constexpr(Problem::kIsFp8) + { + constexpr index_t swizzle_factor = 4; // TODO: hard coded here + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::QDataType, + typename Problem::KDataType>, + 2, + swizzle_factor>>{}; + } }(); using BlockGemmPolicy = @@ -747,6 +769,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, @@ -849,14 +872,35 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy>; - using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<0>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}), - true>; + auto warp_gemm = [&]() { + if constexpr(Problem::kIsFp8) + { + return warp::WarpGemmImpl< + warp::WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution< + warp::WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< + typename Problem::PDataType, + typename Problem::VDataType>, + 2>>{}; + // return + // warp::WarpGemmImpl>>{}; + } + else + { + return ck::tile_program::warp::WarpGemmMfmaDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::At(Number<2>{}), + true>{}; + } + }(); + + using WarpGemm = remove_cvref_t; + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy +__device__ auto cast_tile_pk_fp8x4(const InDstrTensors& in_dstr_tensors) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // This API is designed to use the _pk_ serious of function + constexpr auto in_tile_dstr = InDstrTensors::GetTileDistribution(); + + constexpr index_t thread_buffer_size = InDstrTensors::GetThreadBufferSize(); + static_assert(thread_buffer_size 
% 4 == 0); + constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4; + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wuninitialized" + // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and + // will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA + // so we prepare an uninitialized variable purposely, and turn off the warning + int dummy_old; + static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) { + uint32_t x = + __builtin_amdgcn_cvt_pk_fp8_f32(in_dstr_tensors.GetThreadBuffer()[Number<4 * i + 0>{}], + in_dstr_tensors.GetThreadBuffer()[Number<4 * i + 1>{}], + dummy_old, + false); // false -> WORD0 + + uint32_t y = + __builtin_amdgcn_cvt_pk_fp8_f32(in_dstr_tensors.GetThreadBuffer()[Number<4 * i + 2>{}], + in_dstr_tensors.GetThreadBuffer()[Number<4 * i + 3>{}], + dummy_old, + false); // false -> WORD0 + + constexpr int32_t m0 = 0x05040100; + using vec_t = typename vector_type::type; + + vec_t d = bit_cast(__builtin_amdgcn_perm(y, x, m0)); + out_dstr_tensor.GetThreadBuffer().template SetAsType(Number<4 * i>{}, d); + }); +#pragma clang diagnostic pop + + return out_dstr_tensor; +#else + // fallback + return tile_elementwise_in(type_convert, + in_dstr_tensors); +#endif +} + template __device__ auto cast_tile(const SrcDstrTensors& src_tensor) { - return tile_elementwise_in(type_convert, - src_tensor); + if constexpr((ck::is_same_v || + ck::is_same_v)&&ck::is_same_v && + (SrcDstrTensors::GetThreadBufferSize() % 4 == 0)) + { + return cast_tile_pk_fp8x4(src_tensor); + } + else + return tile_elementwise_in(type_convert, + src_tensor); } // no-op function for NullTensor arguments diff --git a/include/ck/tile_program/tile/tile_fmha_shape.hpp b/include/ck/tile_program/tile/tile_fmha_shape.hpp index acedc2d07..88d3c0a2b 100644 --- a/include/ck/tile_program/tile/tile_fmha_shape.hpp +++ b/include/ck/tile_program/tile/tile_fmha_shape.hpp @@ -16,7 +16,7 @@ template + bool IsVLayoutRowMajor_> struct TileFmhaShape { using BlockTile = remove_cvref_t; @@ -41,7 +41,11 @@ struct TileFmhaShape // once (or repeately load Q as a whole tile) static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0"); - using VLayout = remove_cvref_t; // rowmajor : seqlen*hdim, colmajor : hdim*seqlen + // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen + static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; + using VLayout = std::conditional_t; }; } // namespace tile_program diff --git a/include/ck/tile_program/warp_tile/warp_gemm.hpp b/include/ck/tile_program/warp_tile/warp_gemm.hpp index f08e24631..112b9d622 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm.hpp @@ -81,6 +81,31 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; +// fp8 +using WarpGemmMfma_f32_32x32x16_fp8_fp8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_fp8_bf8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_bf8_fp8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_bf8_bf8 = + WarpGemmImpl>; + +using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl< + 
WarpGemmAtrributeMfmaTransposedCDistribution>; + +using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl< + WarpGemmAtrributeMfmaTransposedCDistribution>; + } // namespace warp } // namespace tile_program } // namespace ck diff --git a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp index 6e98e9115..85cabac37 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp @@ -207,6 +207,67 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution } }; +template +struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::BVecType; + using BVecType = typename Impl::AVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK; + + using AWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, Sequence>, + Tuple>, + Tuple>, + Sequence<2>, + Sequence<1>>; + + using BWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2>, + Sequence<1>>; + + using CWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2, 2>, + Sequence<0, 2>>; + + // c_vec += a_vec * b_vec + __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + Impl{}(c_vec, b_vec, a_vec); + } + + // c_vec = a_vec * b_vec + __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + // swap A and B + return Impl{}(b_vec, a_vec); + } +}; + template struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution { @@ -287,7 +348,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution } }; -template +template struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB { using Impl = remove_cvref_t; @@ -301,9 +362,10 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB using BVecType = typename vector_type_maker::type::type; using CVecType = typename Impl::CVecType; - static constexpr index_t kM = Impl::kN; - static constexpr index_t kN = Impl::kM; - static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together using AWarpDstrEncoding = StaticTileDistributionEncoding< Sequence<>, @@ -312,7 +374,7 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB Tuple>, Sequence<2>, Sequence<1>>; - +#if 0 using BWarpDstrEncoding = StaticTileDistributionEncoding< Sequence<>, Tuple>, Sequence<2, 2>, Sequence<0, 2>>; +#else + using BWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2>, + Sequence<1>>; + using CWarpDstrEncoding = StaticTileDistributionEncoding< + Sequence<>, + Tuple, + Sequence>, + Tuple>, + Tuple>, + Sequence<2, 2>, + Sequence<0, 2>>; +#endif // c_vec += a_vec * b_vec __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { diff --git 
a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp index 72431c802..0a2badda6 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma_impl.hpp @@ -159,6 +159,88 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 } }; +// FP8 +template +struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base +{ + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = typename vector_type::type; + using BVecType = typename vector_type::type; + using CVecType = typename vector_type::type; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + __device__ void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); +#else + vector_type a_(a_vec); + vector_type b_(b_vec); + + static_for<0, 8, 1>{}([&](auto k) { + float a_f32 = type_convert(a_.template AsType()[Number{}]); + float b_f32 = type_convert(b_.template AsType()[Number{}]); + + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); + }); +#endif + } + + // c_vec = a_vec * b_vec + __device__ CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + else if constexpr(is_same_v && is_same_v) + return __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0); + } +}; + +using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + } // namespace warp } // namespace tile_program } // namespace ck diff --git 
a/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp b/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp index 68f2255b5..2c9e3089d 100644 --- a/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp +++ b/include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp @@ -40,6 +40,17 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; + +// fp8 +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; + // clang-format on } // namespace impl diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index c184e9729..5220d8dd3 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -65,7 +65,10 @@ struct buffer_load<16> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 16); + // using dummy_vector = vector_type; + // using dummy_vector = StaticallyIndexedArray; asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" + // : "+v"(reinterpret_cast(value)) : "+v"(value) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory"); @@ -84,8 +87,10 @@ struct buffer_load<8> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 8); + // using dummy_vector = vector_type; + using dummy_vector = float __attribute__((ext_vector_type(2))); asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" - : "+v"(value) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory"); } @@ -103,8 +108,10 @@ struct buffer_load<4> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); + // using dummy_vector = vector_type; + using dummy_vector = float __attribute__((ext_vector_type(1))); asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" - : "+v"(value) + : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory"); } @@ -596,6 +603,36 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w return tmp.AsType()(Number<0>{}); } + else if constexpr(N == 16) + { + vector_type tmp; + + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + tmp.AsType()(Number<2>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + 
src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(float), + static_cast(coherence)); + + tmp.AsType()(Number<3>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(float), + static_cast(coherence)); + + return tmp.AsType()(Number<0>{}); + } } else if constexpr(is_same::value) // fp16 { @@ -712,19 +749,13 @@ __device__ void amd_async_buffer_load_impl(T* smem, index_t src_wave_addr_offset, index_t src_immediate_addr_offset = 0) { - static_assert( - (is_same::value && (N == 1)) || (is_same::value && (N == 2)) || - (is_same::value && (N == 2)) || (is_same::value && (N == 1)) || - (is_same::value && (N == 4)), - "wrong! not implemented"); - if constexpr(sizeof(T) * N == 4) - { - async_buffer_load_dword(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset); - } + static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + + async_buffer_load_dword(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset); } template diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 51d6f7a30..b85ed239c 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -5,13 +5,19 @@ rm -rf CMakeFiles MY_PROJECT_SOURCE=$1 +if [ $# -ge 2 ] ; then + GPU_TARGETS=$2 +else + GPU_TARGETS="gfx908;gfx90a;gfx940" +fi + cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ --D GPU_TARGETS="gfx908;gfx90a;gfx940" \ +-D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE} diff --git a/script/cmake-ck-release.sh b/script/cmake-ck-release.sh index 787eabbf9..25ccb5c79 100755 --- a/script/cmake-ck-release.sh +++ b/script/cmake-ck-release.sh @@ -5,13 +5,19 @@ rm -rf CMakeFiles MY_PROJECT_SOURCE=$1 +if [ $# -ge 2 ] ; then + GPU_TARGETS=$2 +else + GPU_TARGETS="gfx908;gfx90a;gfx940" +fi + cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_FLAGS="-O3" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=OFF \ --D GPU_TARGETS="gfx908;gfx90a;gfx940" \ +-D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE} From 97997b678511c6ac96b769a457f7a095a361dea1 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Tue, 23 Jan 2024 18:56:43 +0000 Subject: [PATCH 36/45] Fix wrong arg order of transform_tensor_view() --- .../91_tile_program/fmha/fmha_fwd_kernel.hpp | 55 ++----------------- 1 file changed, 6 insertions(+), 49 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp index b49fde465..f8a100f50 100644 --- a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -463,58 +463,15 @@ struct FmhaFwdKernel const auto v_dram_transposed = transform_tensor_view(v_dram_naive, - make_tuple(make_pass_through_transform(kargs.seqlen_k), - make_pass_through_transform(kargs.hdim_v)), + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - /// FIXME: The return value of v_dram_naive.GetTensorDescriptor().GetLength() is - /// same as 
- /// v_dram_transposed.GetTensorDescriptor().GetLength(). Replace following - /// if-clause by pad_tensor_view() call after fixing this issue. - if constexpr(kK0N1NeedPadding || kN0K1NeedPadding) - { - const auto transform_n1 = [&] { - if constexpr(kK0N1NeedPadding) - { - const index_t n1_pad_length = - FmhaPipeline::kN1 * - ck::math::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1) - - kargs.hdim_v; - - return make_right_pad_transform(kargs.hdim_v, n1_pad_length); - } - else - { - return make_pass_through_transform(kargs.hdim_v); - } - }(); - - const auto transform_k1 = [&] { - if constexpr(kN0K1NeedPadding) - { - const index_t k1_pad_length = - FmhaPipeline::kK1 * ck::math::integer_divide_ceil( - kargs.seqlen_k, FmhaPipeline::kK1) - - kargs.seqlen_k; - - return make_right_pad_transform(kargs.seqlen_k, k1_pad_length); - } - else - { - return make_pass_through_transform(kargs.seqlen_k); - } - }(); - - return transform_tensor_view(v_dram_transposed, - make_tuple(transform_n1, transform_k1), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - return v_dram_transposed; - } + return pad_tensor_view( + v_dram_transposed, + make_tuple(Number{}, Number{}), + Sequence{}); } else { From 5b6b5df6a8ec602b09dc78d4e5b90b9433a23529 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 29 Jan 2024 11:52:17 +0800 Subject: [PATCH 37/45] Validate m values before use them (#75) * Validate m values before use it * Ensure closure return type is same as param * Replace Lowest() to Infinity() calls * Fix format * Update all the pipelines * Only validate m if FmhaMask::IsMasking is true --- example/91_tile_program/fmha/CMakeLists.txt | 3 ++ .../block_fmha_pipeline_qr_ks_vs.hpp | 29 ++++++++++++----- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 31 ++++++++++++------ .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 32 +++++++++++++------ .../block_fmha_pipeline_qs_ks_vs.hpp | 30 ++++++++++++----- 5 files changed, 91 insertions(+), 34 deletions(-) diff --git a/example/91_tile_program/fmha/CMakeLists.txt b/example/91_tile_program/fmha/CMakeLists.txt index aaa73fcfd..a2255a025 100644 --- a/example/91_tile_program/fmha/CMakeLists.txt +++ b/example/91_tile_program/fmha/CMakeLists.txt @@ -35,4 +35,7 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_FMHA_FWD_FAST_EXP2=0) endif() +# Allow comparing floating points directly in order to check sentinel values +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) + target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 0458d07ec..ae4ae8ce5 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -155,7 +155,7 @@ struct BlockFmhaPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, NumericLimits::Lowest()); + set_tile(m, -NumericLimits::Infinity()); clear_tile(l); const auto q_origin = q_dram_window.GetWindowOrigin(); @@ -315,7 +315,7 @@ struct BlockFmhaPipelineQRKSVS s, Sequence<1>{}, f_max, - NumericLimits::Lowest()); // m_local = rowmax(S{j}) + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} @@ -325,11 +325,24 @@ struct 
BlockFmhaPipelineQRKSVS auto p_compute = make_static_distributed_tensor( s.GetTileDistribution()); // Pcompute{j} + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + if constexpr(FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - auto row_max = scale * m[i_idx]; + auto row_max = scale * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -340,10 +353,10 @@ struct BlockFmhaPipelineQRKSVS } else { - p_compute(i_j_idx) = math::exp2(s[i_j_idx] - m[i_idx]); + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } #else - p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif }); }); @@ -360,16 +373,16 @@ struct BlockFmhaPipelineQRKSVS const auto tmp = [&]() { if constexpr(is_null_tile_window(bias_dram_window)) { - auto row_max = scale * m[i_idx]; + auto row_max = scale * get_validated_m(m[i_idx]); return math::exp2(scale * m_old[i_idx] - row_max); } else { - return math::exp2(m_old[i_idx] - m[i_idx]); + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } }(); #else - const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); #endif l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index e98c19649..ea85bce59 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -43,7 +43,7 @@ struct BlockFmhaPipelineQRKSVSAsync using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); - static constexpr bool kIsFp8 = Problem::kIsFp8; + static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -188,7 +188,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, NumericLimits::Lowest()); + set_tile(m, -NumericLimits::Infinity()); clear_tile(l); __builtin_amdgcn_sched_barrier(0); @@ -352,7 +352,7 @@ struct BlockFmhaPipelineQRKSVSAsync s, Sequence<1>{}, f_max, - NumericLimits::Lowest()); // m_local = rowmax(S{j}) + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} @@ -394,11 +394,24 @@ struct BlockFmhaPipelineQRKSVSAsync } __builtin_amdgcn_sched_barrier(0); + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + if constexpr(FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? 
type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - auto row_max = scale * m[i_idx]; + auto row_max = scale * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -409,10 +422,10 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - p_compute(i_j_idx) = math::exp2(s[i_j_idx] - m[i_idx]); + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } #else - p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif }); }); @@ -429,16 +442,16 @@ struct BlockFmhaPipelineQRKSVSAsync const auto tmp = [&]() { if constexpr(is_null_tile_window(bias_dram_window)) { - auto row_max = scale * m[i_idx]; + auto row_max = scale * get_validated_m(m[i_idx]); return math::exp2(scale * m_old[i_idx] - row_max); } else { - return math::exp2(m_old[i_idx] - m[i_idx]); + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } }(); #else - const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); #endif l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 266b45198..27ec89c17 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -42,7 +42,8 @@ struct BlockFmhaPipelineQRKSVSFp8 using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once - static constexpr bool kIsFp8 = Problem::kIsFp8; + static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -146,7 +147,7 @@ struct BlockFmhaPipelineQRKSVSFp8 auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, NumericLimits::Lowest()); + set_tile(m, -NumericLimits::Infinity()); clear_tile(l); const auto q_origin = q_dram_window.GetWindowOrigin(); @@ -308,7 +309,7 @@ struct BlockFmhaPipelineQRKSVSFp8 s, Sequence<1>{}, f_max, - NumericLimits::Lowest()); // m_local = rowmax(S{j}) + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} @@ -318,11 +319,24 @@ struct BlockFmhaPipelineQRKSVSFp8 auto p_compute = make_static_distributed_tensor( s.GetTileDistribution()); // Pcompute{j} + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + if constexpr(FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? 
type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - auto row_max = scale * m[i_idx]; + auto row_max = scale * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -333,10 +347,10 @@ struct BlockFmhaPipelineQRKSVSFp8 } else { - p_compute(i_j_idx) = math::exp2(s[i_j_idx] - m[i_idx]); + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } #else - p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif }); }); @@ -353,16 +367,16 @@ struct BlockFmhaPipelineQRKSVSFp8 const auto tmp = [&]() { if constexpr(is_null_tile_window(bias_dram_window)) { - auto row_max = scale * m[i_idx]; + auto row_max = scale * get_validated_m(m[i_idx]); return math::exp2(scale * m_old[i_idx] - row_max); } else { - return math::exp2(m_old[i_idx] - m[i_idx]); + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } }(); #else - const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); #endif l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 5ce44dd2f..9b84099c3 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -43,6 +43,7 @@ struct BlockFmhaPipelineQSKSVS using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = false; static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kIsFp8 = Problem::kIsFp8; static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -158,7 +159,7 @@ struct BlockFmhaPipelineQSKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, NumericLimits::Lowest()); + set_tile(m, -NumericLimits::Infinity()); clear_tile(l); const auto q_origin = q_dram_block_window_tmp.GetWindowOrigin(); @@ -324,7 +325,7 @@ struct BlockFmhaPipelineQSKSVS s, Sequence<1>{}, f_max, - NumericLimits::Lowest()); // m_local = rowmax(S{j}) + -NumericLimits::Infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} @@ -334,11 +335,24 @@ struct BlockFmhaPipelineQSKSVS auto p_compute = make_static_distributed_tensor( s.GetTileDistribution()); // Pcompute{j} + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + if constexpr(FmhaMask::IsMasking) + { + return raw_m == -NumericLimits::Infinity() + ? 
type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + constexpr auto p_spans = decltype(p_compute)::GetDistributedSpans(); sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - auto row_max = scale * m[i_idx]; + auto row_max = scale * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -349,10 +363,10 @@ struct BlockFmhaPipelineQSKSVS } else { - p_compute(i_j_idx) = math::exp2(s[i_j_idx] - m[i_idx]); + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } #else - p_compute(i_j_idx) = math::exp(s[i_j_idx] - m[i_idx]); + p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif }); }); @@ -369,16 +383,16 @@ struct BlockFmhaPipelineQSKSVS const auto tmp = [&]() { if constexpr(is_null_tile_window(bias_dram_window)) { - auto row_max = scale * m[i_idx]; + auto row_max = scale * get_validated_m(m[i_idx]); return math::exp2(scale * m_old[i_idx] - row_max); } else { - return math::exp2(m_old[i_idx] - m[i_idx]); + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } }(); #else - const auto tmp = math::exp(m_old[i_idx] - m[i_idx]); + const auto tmp = math::exp(m_old[i_idx] - get_validated_m(m[i_idx])); #endif l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) { From 9a302e68c1df1c03441c0b00ba0bd5240142c746 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 29 Jan 2024 11:53:48 +0800 Subject: [PATCH 38/45] Rename & separate TileFmhaTraits<> padding flags for better comprehension (#77) * Rename TileFmhaTraits<> padding flags * Rename FmhaFwdKernel padding attributes * Separate kPadHeadDimQV into kPadHeadDimQ/V --- .../91_tile_program/fmha/fmha_fwd_kernel.hpp | 33 ++++++++++--------- example/91_tile_program/fmha/generate.py | 30 +++++++++-------- .../block_fmha_pipeline_problem.hpp | 13 ++++---- .../block_fmha_pipeline_qr_ks_vs.hpp | 15 +++++---- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 15 +++++---- .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 15 +++++---- .../block_fmha_pipeline_qs_ks_vs.hpp | 15 +++++---- .../ck/tile_program/tile/tile_fmha_traits.hpp | 20 ++++++----- 8 files changed, 84 insertions(+), 72 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp index a0fac129b..4b2c6d09f 100644 --- a/example/91_tile_program/fmha/fmha_fwd_kernel.hpp +++ b/example/91_tile_program/fmha/fmha_fwd_kernel.hpp @@ -34,14 +34,15 @@ struct FmhaFwdKernel using VLayout = ck::remove_cvref_t; - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; - static constexpr bool kM0NeedPadding = FmhaPipeline::kM0NeedPadding; - static constexpr bool kN0K1NeedPadding = FmhaPipeline::kN0K1NeedPadding; - static constexpr bool kK0N1NeedPadding = FmhaPipeline::kK0N1NeedPadding; - static constexpr bool kHasBias = FmhaPipeline::kHasBias; - static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; - using FmhaMask = ck::remove_cvref_t; - static constexpr bool kHasMask = FmhaMask::IsMasking; + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasBias = FmhaPipeline::kHasBias; + 
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + using FmhaMask = ck::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; template // to avoid duplicated base class prblem, introduce an template arg struct FmhaFwdEmptyKargs @@ -453,14 +454,14 @@ struct FmhaFwdKernel return pad_tensor_view( q_dram_naive, make_tuple(Number{}, Number{}), - Sequence{}); + Sequence{}); } else { return pad_tensor_view( q_dram_naive, make_tuple(Number{}, Number{}), - Sequence{}); + Sequence{}); } }(); const auto k_dram = [&]() { @@ -474,7 +475,7 @@ struct FmhaFwdKernel return pad_tensor_view( k_dram_naive, make_tuple(Number{}, Number{}), - Sequence{}); + Sequence{}); }(); const auto v_dram = [&]() { if constexpr(ck::is_same_v) @@ -496,7 +497,7 @@ struct FmhaFwdKernel return pad_tensor_view( v_dram_transposed, make_tuple(Number{}, Number{}), - Sequence{}); + Sequence{}); } else { @@ -510,7 +511,7 @@ struct FmhaFwdKernel return pad_tensor_view( v_dram_naive, make_tuple(Number{}, Number{}), - Sequence{}); + Sequence{}); } }(); @@ -555,7 +556,7 @@ struct FmhaFwdKernel return pad_tensor_view(bias_dram_naive, bias_dram_window_lengths, - Sequence{}); + Sequence{}); }(); return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); @@ -584,7 +585,7 @@ struct FmhaFwdKernel Number<1>{}); return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, Sequence{}); + lse_dram_naive, lse_dram_window_lengths, Sequence{}); }(); return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); @@ -641,7 +642,7 @@ struct FmhaFwdKernel return pad_tensor_view( o_dram_naive, make_tuple(Number{}, Number{}), - Sequence{}); + Sequence{}); }(); auto o_dram_window = diff --git a/example/91_tile_program/fmha/generate.py b/example/91_tile_program/fmha/generate.py index d3e05ff12..f4594639b 100644 --- a/example/91_tile_program/fmha/generate.py +++ b/example/91_tile_program/fmha/generate.py @@ -66,9 +66,10 @@ fmha_warp_tile_{F_idx}, {F_vlayout}>; -using fmha_trait_{F_idx} = ck::tile_program::TileFmhaTraits<{F_m0pad}, - {F_m0k1pad}, - {F_k0n1pad}, +using fmha_trait_{F_idx} = ck::tile_program::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, {F_bias}, {F_lse}, {F_occupancy}>; @@ -224,9 +225,10 @@ class FmhaFwdKernel: F_dtype : str # data type F_tile : FmhaFwdTileSize F_vlayout : str # row/col - F_m0pad : str # true/false - F_m0k1pad : str # - F_k0n1pad : str # + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # F_bias : str # true/false F_lse : str # F_mask : str # value from MASK_MAP @@ -253,9 +255,10 @@ def template(self) -> str: F_wn = self.F_tile.F_wn, F_wk = self.F_tile.F_wk, F_vlayout = LAYOUT_MAP[self.F_vlayout], - F_m0pad = BOOL_MAP[self.F_m0pad], - F_m0k1pad = BOOL_MAP[self.F_m0k1pad], - F_k0n1pad = BOOL_MAP[self.F_k0n1pad], + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], F_bias = BOOL_MAP[self.F_bias], F_lse = BOOL_MAP[self.F_lse], F_occupancy = self.F_tile.F_occupancy , @@ -267,8 +270,8 @@ def template(self) -> str: def name(self) -> str: # TODO: we don't encode idx here return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f"_v{self.F_vlayout[0]}" +\ - f"_p{BOOL_MAP[self.F_m0pad][0]}{BOOL_MAP[self.F_m0k1pad][0]}{BOOL_MAP[self.F_k0n1pad][0]}_{BOOL_MAP[self.F_bias][0]}" +\ - f"_m{self.F_mask[0]}_l{BOOL_MAP[self.F_lse][0]}_{self.F_pipeline}" + 
f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ + f"_{BOOL_MAP[self.F_bias][0]}_m{self.F_mask[0]}_l{BOOL_MAP[self.F_lse][0]}_{self.F_pipeline}" @property def filename(self) -> str: @@ -339,8 +342,9 @@ def get_pad(dtype, hdim): if dtype in ['fp8', 'bf8'] and lse == "t": continue k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, F_vlayout=get_vlayout(dtype, hdim), - F_m0pad=get_pad(dtype, hdim), F_m0k1pad=get_pad(dtype, hdim), F_k0n1pad=get_pad(dtype, hdim), - F_bias=bias, F_lse=lse, F_mask=mask, F_mode=mode, F_pipeline=get_pipeline(dtype, hdim)) + F_spad=get_pad(dtype, hdim), F_skpad=get_pad(dtype, hdim), F_dpad=get_pad(dtype, hdim), + F_dvpad=get_pad(dtype, hdim), F_bias=bias, F_lse=lse, F_mask=mask, F_mode=mode, + F_pipeline=get_pipeline(dtype, hdim)) api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp index 56f2ba9d9..a00163fe9 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp @@ -46,12 +46,13 @@ struct BlockFmhaPipelineProblem static constexpr bool kIsGroupMode = kIsGroupMode_; // attributes from traits - static constexpr bool kM0NeedPadding = Traits::kM0NeedPadding; - static constexpr bool kN0K1NeedPadding = Traits::kN0K1NeedPadding; - static constexpr bool kK0N1NeedPadding = Traits::kK0N1NeedPadding; - static constexpr bool kHasBias = Traits::kHasBias; - static constexpr bool kStoreLSE = Traits::kStoreLSE; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasBias = Traits::kHasBias; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr bool kIsFp8 = (is_same_v || is_same_v)&&( is_same_v || diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index ae4ae8ce5..2e6d6f7e5 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -55,12 +55,13 @@ struct BlockFmhaPipelineQRKSVS static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; - static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; - static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; - static constexpr bool kHasBias = Problem::kHasBias; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = 
Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; __host__ __device__ static constexpr ck::index_t GetSmemSize() { @@ -292,7 +293,7 @@ struct BlockFmhaPipelineQRKSVS bias_tile); } move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kN0K1NeedPadding || FmhaMask::IsMasking) + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { const auto k_origin = k_dram_block_window.GetWindowOrigin(); bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index ea85bce59..cf46276fa 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -55,12 +55,13 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; - static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; - static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; - static constexpr bool kHasBias = Problem::kHasBias; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; #if CK_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / math::log2e_v; @@ -329,7 +330,7 @@ struct BlockFmhaPipelineQRKSVSAsync bias_tile); } move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kN0K1NeedPadding || FmhaMask::IsMasking) + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { const auto k_origin = k_dram_block_window.GetWindowOrigin(); bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 27ec89c17..715597dc1 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -55,12 +55,13 @@ struct BlockFmhaPipelineQRKSVSFp8 static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; - static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; - static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; - static constexpr bool kHasBias = Problem::kHasBias; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static 
constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; __host__ __device__ static constexpr ck::index_t GetSmemSize() { @@ -286,7 +287,7 @@ struct BlockFmhaPipelineQRKSVSFp8 bias_tile); } move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kN0K1NeedPadding || FmhaMask::IsMasking) + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { const auto k_origin = k_dram_block_window.GetWindowOrigin(); bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 9b84099c3..c3ebf042d 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -55,12 +55,13 @@ struct BlockFmhaPipelineQSKSVS static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kM0NeedPadding = Problem::kM0NeedPadding; - static constexpr bool kN0K1NeedPadding = Problem::kN0K1NeedPadding; - static constexpr bool kK0N1NeedPadding = Problem::kK0N1NeedPadding; - static constexpr bool kHasBias = Problem::kHasBias; - static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasBias = Problem::kHasBias; + static constexpr bool kStoreLSE = Problem::kStoreLSE; __host__ __device__ static constexpr ck::index_t GetSmemSize() { @@ -302,7 +303,7 @@ struct BlockFmhaPipelineQSKSVS bias_tile); } move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kN0K1NeedPadding || FmhaMask::IsMasking) + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { const auto k_origin = k_dram_block_window.GetWindowOrigin(); bool need_perpixel_check = mask.IsEdgeTile(q_origin.At(Number<0>{}), diff --git a/include/ck/tile_program/tile/tile_fmha_traits.hpp b/include/ck/tile_program/tile/tile_fmha_traits.hpp index 33fbcafee..2bced1feb 100644 --- a/include/ck/tile_program/tile/tile_fmha_traits.hpp +++ b/include/ck/tile_program/tile/tile_fmha_traits.hpp @@ -8,20 +8,22 @@ namespace ck { namespace tile_program { -template struct TileFmhaTraits { - static constexpr bool kM0NeedPadding = kM0NeedPadding_; - static constexpr bool kN0K1NeedPadding = kN0K1NeedPadding_; - static constexpr bool kK0N1NeedPadding = kK0N1NeedPadding_; - static constexpr bool kHasBias = kHasBias_; - static constexpr bool kStoreLSE = kStoreLSE_; - static constexpr index_t kBlockPerCu = kBlockPerCu_; + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; }; } // namespace tile_program From 52b621ecf3533514031670dd99b6f2059832baaa Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 30 Jan 2024 14:52:27 +0800 Subject: [PATCH 39/45] Feature fMHA generic mask / bias issues (#80) * Allow 
setting random seed for uniform filler * Add seed option to fmha example * Add normal distribution tensor filler * Align tensor init method with xformers * Fix variable type * Validate m if we have bias tensor * Add comment to explain why we validate m under bias mode * Remove blank line * Support no seed scenario for fillers * Do not apply random seed if user set seed=0 --- example/91_tile_program/fmha/fmha_fwd.cpp | 29 +++++---- .../block_fmha_pipeline_qr_ks_vs.hpp | 42 ++++++------- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 38 ++++++------ .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 36 +++++------ .../block_fmha_pipeline_qs_ks_vs.hpp | 42 ++++++------- library/include/ck/library/utility/fill.hpp | 60 ++++++++++++++++++- 6 files changed, 160 insertions(+), 87 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index 058e5d76b..b22e310d8 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -68,7 +68,11 @@ auto create_args(int argc, char* argv[]) "'g:y,x', generic attention mask coordinate with y/x size\n") .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") - .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float"); + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") + .insert("seed", + "0", + "random seed used for initializing input tensors. 0 to use " + "non-deterministic random number as seed"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -144,7 +148,12 @@ bool run(const ArgParser& arg_parser) mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); - int init_method = arg_parser.get_int("init"); + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } int stream_warmup = env_get_int("CK_WARMUP", 5); int stream_repeat = env_get_int("CK_REPEAT", 20); @@ -234,17 +243,17 @@ bool run(const ArgParser& arg_parser) if(init_method == 0) { - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(q_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(k_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(v_host); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(bias_host); + ck::utils::FillNormalDistributionIntegerValue{0.f, 1.f, seed}(q_host); + ck::utils::FillNormalDistributionIntegerValue{0.f, 1.f, seed}(k_host); + ck::utils::FillNormalDistributionIntegerValue{0.f, 1.f, seed}(v_host); + ck::utils::FillNormalDistributionIntegerValue{0.f, 1.f, seed}(bias_host); } else if(init_method == 1) { - ck::utils::FillUniformDistribution{0.f, 1.f}(q_host); - ck::utils::FillUniformDistribution{0.f, 1.f}(k_host); - ck::utils::FillUniformDistribution{-.5f, .5f}(v_host); - ck::utils::FillUniformDistribution{0.f, 1.f}(bias_host); + ck::utils::FillNormalDistribution{0.f, 1.f, seed}(q_host); + ck::utils::FillNormalDistribution{0.f, 1.f, seed}(k_host); + ck::utils::FillNormalDistribution{0.f, 1.f, seed}(v_host); + ck::utils::FillNormalDistribution{0.f, 1.f, seed}(bias_host); } else if(init_method == 2) { diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 2e6d6f7e5..7e0819d3f 100644 --- 
a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -220,13 +220,13 @@ struct BlockFmhaPipelineQRKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(!is_null_tile_window(bias_dram_window)) + if constexpr(!kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(!is_null_tile_window(bias_dram_window)) + if constexpr(!kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -272,13 +272,7 @@ struct BlockFmhaPipelineQRKSVS } // STAGE 2, scale, add bias, mask, softmax - if constexpr(is_null_tile_window(bias_dram_window)) - { -#if !CK_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); -#endif - } - else + if constexpr(kHasBias) { tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -292,6 +286,12 @@ struct BlockFmhaPipelineQRKSVS s_acc, bias_tile); } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -327,7 +327,9 @@ struct BlockFmhaPipelineQRKSVS s.GetTileDistribution()); // Pcompute{j} static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - if constexpr(FmhaMask::IsMasking) + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) { return raw_m == -NumericLimits::Infinity() ? type_convert(0.f) @@ -348,13 +350,13 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_FMHA_FWD_FAST_EXP2 - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); } #else p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -372,14 +374,14 @@ struct BlockFmhaPipelineQRKSVS constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - auto row_max = scale * get_validated_m(m[i_idx]); - return math::exp2(scale * m_old[i_idx] - row_max); + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } else { - return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); } }(); #else @@ -463,13 +465,13 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); + lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); } else { - lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); + lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); } #else lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); diff --git 
a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index cf46276fa..b8a488ab9 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -309,13 +309,7 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); // STAGE 2, scale, add bias, mask, softmax - if constexpr(is_null_tile_window(bias_dram_window)) - { -#if !CK_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); -#endif - } - else + if constexpr(kHasBias) { tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -329,6 +323,12 @@ struct BlockFmhaPipelineQRKSVSAsync s_acc, bias_tile); } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -396,7 +396,9 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(0); static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - if constexpr(FmhaMask::IsMasking) + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) { return raw_m == -NumericLimits::Infinity() ? type_convert(0.f) @@ -417,13 +419,13 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_FMHA_FWD_FAST_EXP2 - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); } #else p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -441,14 +443,14 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - auto row_max = scale * get_validated_m(m[i_idx]); - return math::exp2(scale * m_old[i_idx] - row_max); + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } else { - return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); } }(); #else @@ -544,13 +546,13 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - lse(i_idx) = m_[i_idx] * scale * R_LOG2E + math::log(l_[i_idx]); + lse(i_idx) = m_[i_idx] * R_LOG2E + math::log(l_[i_idx]); } else { - lse(i_idx) = m_[i_idx] * R_LOG2E + math::log(l_[i_idx]); + lse(i_idx) = m_[i_idx] * scale * R_LOG2E + math::log(l_[i_idx]); } #else lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 715597dc1..4eba3cbb0 100644 --- 
a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -215,13 +215,13 @@ struct BlockFmhaPipelineQRKSVSFp8 k_block_tile = load_tile(k_dram_window); } - if constexpr(!is_null_tile_window(bias_dram_window)) + if constexpr(!kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(!is_null_tile_window(bias_dram_window)) + if constexpr(!kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -266,13 +266,7 @@ struct BlockFmhaPipelineQRKSVSFp8 } // STAGE 2, scale, add bias, mask, softmax - if constexpr(is_null_tile_window(bias_dram_window)) - { -#if !CK_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); -#endif - } - else + if constexpr(kHasBias) { tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -286,6 +280,12 @@ struct BlockFmhaPipelineQRKSVSFp8 s_acc, bias_tile); } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -321,7 +321,9 @@ struct BlockFmhaPipelineQRKSVSFp8 s.GetTileDistribution()); // Pcompute{j} static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - if constexpr(FmhaMask::IsMasking) + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) { return raw_m == -NumericLimits::Infinity() ? type_convert(0.f) @@ -342,13 +344,13 @@ struct BlockFmhaPipelineQRKSVSFp8 sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_FMHA_FWD_FAST_EXP2 - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); } #else p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -366,14 +368,14 @@ struct BlockFmhaPipelineQRKSVSFp8 constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - auto row_max = scale * get_validated_m(m[i_idx]); - return math::exp2(scale * m_old[i_idx] - row_max); + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } else { - return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); } }(); #else diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp index c3ebf042d..37caa1c9e 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -234,13 +234,13 @@ struct BlockFmhaPipelineQSKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(!is_null_tile_window(bias_dram_window)) + if constexpr(!kHasBias) { __builtin_amdgcn_sched_barrier( 0); // 
prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(!is_null_tile_window(bias_dram_window)) + if constexpr(!kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -282,13 +282,7 @@ struct BlockFmhaPipelineQSKSVS } // STAGE 2, scale, add bias, mask, softmax - if constexpr(is_null_tile_window(bias_dram_window)) - { -#if !CK_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); -#endif - } - else + if constexpr(kHasBias) { tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -302,6 +296,12 @@ struct BlockFmhaPipelineQSKSVS s_acc, bias_tile); } + else + { +#if !CK_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc); +#endif + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -337,7 +337,9 @@ struct BlockFmhaPipelineQSKSVS s.GetTileDistribution()); // Pcompute{j} static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - if constexpr(FmhaMask::IsMasking) + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(kHasBias || FmhaMask::IsMasking) { return raw_m == -NumericLimits::Infinity() ? type_convert(0.f) @@ -358,13 +360,13 @@ struct BlockFmhaPipelineQSKSVS sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_FMHA_FWD_FAST_EXP2 - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); + p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - p_compute(i_j_idx) = math::exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = math::exp2(scale * s[i_j_idx] - row_max); } #else p_compute(i_j_idx) = math::exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -382,14 +384,14 @@ struct BlockFmhaPipelineQSKSVS constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - auto row_max = scale * get_validated_m(m[i_idx]); - return math::exp2(scale * m_old[i_idx] - row_max); + return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } else { - return math::exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + auto row_max = scale * get_validated_m(m[i_idx]); + return math::exp2(scale * m_old[i_idx] - row_max); } }(); #else @@ -473,13 +475,13 @@ struct BlockFmhaPipelineQSKSVS sweep_tile_span(lse_spans[Number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_FMHA_FWD_FAST_EXP2 - if constexpr(is_null_tile_window(bias_dram_window)) + if constexpr(kHasBias) { - lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); + lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); } else { - lse(i_idx) = m_[i_idx] / C_LOG2E + math::log(l_[i_idx]); + lse(i_idx) = m_[i_idx] * scale / C_LOG2E + math::log(l_[i_idx]); } #else lse(i_idx) = m_[i_idx] + math::log(l_[i_idx]); diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index fef974810..d852e9c91 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -20,11 +21,12 @@ struct FillUniformDistribution { float a_{-5.f}; float b_{5.f}; + 
std::optional seed_{std::nullopt}; template void operator()(ForwardIter first, ForwardIter last) const { - std::mt19937 gen(11939); + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); } @@ -40,6 +42,32 @@ struct FillUniformDistribution } }; +template +struct FillNormalDistribution +{ + float mean_{0.f}; + float variance_{1.f}; + std::optional seed_{std::nullopt}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::normal_distribution dis(mean_, std::sqrt(variance_)); + std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + // Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below. // However this produces segfaults in std::mt19937 which look like inifite loop. // template @@ -64,11 +92,12 @@ struct FillUniformDistributionIntegerValue { float a_{-5.f}; float b_{5.f}; + std::optional seed_{std::nullopt}; template void operator()(ForwardIter first, ForwardIter last) const { - std::mt19937 gen(11939); + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate( first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); @@ -85,6 +114,33 @@ struct FillUniformDistributionIntegerValue } }; +template +struct FillNormalDistributionIntegerValue +{ + float mean_{0.f}; + float variance_{1.f}; + std::optional seed_{std::nullopt}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(seed_.has_value() ? 
*seed_ : std::random_device{}()); + std::normal_distribution dis(mean_, std::sqrt(variance_)); + std::generate( + first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + template struct FillMonotonicSeq { From 1bed0e7d2247bba49578e23160e1f79cc31e468d Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Tue, 30 Jan 2024 07:19:20 +0000 Subject: [PATCH 40/45] Fallback changes for init=0 --- example/91_tile_program/fmha/fmha_fwd.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index b22e310d8..4798bd454 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -243,10 +243,10 @@ bool run(const ArgParser& arg_parser) if(init_method == 0) { - ck::utils::FillNormalDistributionIntegerValue{0.f, 1.f, seed}(q_host); - ck::utils::FillNormalDistributionIntegerValue{0.f, 1.f, seed}(k_host); - ck::utils::FillNormalDistributionIntegerValue{0.f, 1.f, seed}(v_host); - ck::utils::FillNormalDistributionIntegerValue{0.f, 1.f, seed}(bias_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); } else if(init_method == 1) { From 0d231a7809ea39f96c80e0d3aed73164ad321913 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Tue, 30 Jan 2024 13:53:25 +0000 Subject: [PATCH 41/45] Add back sched_barrier() in pipeline --- .../block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 4 ++-- .../block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp | 4 ++-- .../block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 7e0819d3f..c1c48a58d 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -220,13 +220,13 @@ struct BlockFmhaPipelineQRKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(!kHasBias) + if constexpr(kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(!kHasBias) + if constexpr(kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 4eba3cbb0..3e74e4058 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -215,13 +215,13 @@ struct BlockFmhaPipelineQRKSVSFp8 k_block_tile = load_tile(k_dram_window); } - if constexpr(!kHasBias) + if constexpr(kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent 
from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(!kHasBias) + if constexpr(kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 37caa1c9e..3c94597aa 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -234,13 +234,13 @@ struct BlockFmhaPipelineQSKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(!kHasBias) + if constexpr(kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(!kHasBias) + if constexpr(kHasBias) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads From e914fa299651d2e669d382adfc7427b30b627631 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 30 Jan 2024 22:25:28 +0800 Subject: [PATCH 42/45] Fix README.md wording (#78) --- example/91_tile_program/fmha/README.md | 22 +++++++++++----------- example/91_tile_program/fmha/fmha_fwd.cpp | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/example/91_tile_program/fmha/README.md b/example/91_tile_program/fmha/README.md index b5dde4043..8b0d2521a 100644 --- a/example/91_tile_program/fmha/README.md +++ b/example/91_tile_program/fmha/README.md @@ -1,6 +1,6 @@ # fused multi-head attention -This folder contains example for fmha(fused multi-head attention) using ck tile-programming implementation. It is a good example to demostrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. +This folder contains example for fmha(fused multi-head attention) using ck tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. ## build ``` @@ -12,21 +12,21 @@ make example_fmha_fwd -j This will result in an executable `build/bin/example_fmha_fwd` ## kernel -The kernel template is `fmha_fwd_kernel.hpp`, this is the gridwise op in old ck's terminology. We put it here purposely, to demostrate one can construct a kernel by using various internal component from ck. We may still have an implementation under ck's include path (in the future) for the kernel template. +The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck. We may still have an implementation under ck's include path (in the future) for the kernel template. There are 3 template parameters for this kernel template. * `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose. -* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed we did a lot of optimization and trials to optimize the pipeline, and may still workout more performance pipeline and update into that folder. 
People only need to replace this pipeline type and would be able to enjoy the benifit of different performant implementations (stay tuned for updated pipeline(s)). -* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage, but leave the room for furture possible support. +* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck/tile_program/block_tile_pipeline`) which is a performance-critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline, and may still work out more performant pipelines and add them to that folder. People only need to replace this pipeline type to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). +* `EpiloguePipeline` modifies and stores out the result in the last phase. People usually do a lot of post-fusion at this stage, so we also abstract this concept. Currently we don't do much at the epilogue stage, but leave room for possible future support. ## codegen -To speed up compile time, we instantiate the kernels into seperate file. In this way we can benifit from parallel building from cmake/make system. This is achieved by `generate.py` script. Beside, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. +To speed up compile time, we instantiate the kernels into separate files. In this way we can benefit from parallel building with the CMake/Make system. This is achieved by the `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, as described in the `FMHA_FWD_KERNEL_BODY` variable. ## executable `example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/example_fmha_fwd -?` to list all supported args ``` args: - -v weather do cpu validation or not (default:1) + -v whether to do CPU validation or not (default:1) -mode kernel mode. 0:batch, 1:group (default:0) -b batch size (default:2) -h num of head, for q (default:8) @@ -53,10 +53,10 @@ args: Example: `./bin/example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. ## support features -Currently we are still in rapid development stage, so more features/optimizations will comming soon. +Currently we are still in a rapid development stage, so more features/optimizations will be coming soon. ### hdim -Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. We may consider optimize other hdim performance if have more request. We also have an experimental support for arbitrary hdim(even odd number), one can change the return value of `get_pad()` inside `generate.py` to achieve this. (Note: we may change the method or optimize arbitraty hdim support in the future) +Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. We may consider optimizing other hdim performance if there are more requests. We also have experimental support for arbitrary hdim (even odd numbers); one can change the return value of `get_pad()` inside `generate.py` to achieve this.
(Note: we may change the method or optimize arbitrary hdim support in the future) ### group/batch mode Currently we support both batch and group mode, by setting `-mode` = `0` or `1`, where in group mode we support each batch can have different seqlen @@ -65,7 +65,7 @@ Currently we support both batch and group mode, by setting `-mode` = `0` or `1`, By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers. ### input/output permute, and `b*s*3*h*d` -If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prividing arbitraty stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't privide a cmd line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trival to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`. +If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test the `b*s*3*h*d` layout which is used by default by torch/FA, but it's trivial to achieve this if one sets the proper `stride_q/k/v` value to `3*h*d` (see the stride sketch below). ### attention bias Attention bias is supported with the layout of `b*h*s*s` and bias value in float number. @@ -74,7 +74,7 @@ Attention bias is supported with the layout of `b*h*s*s` and bias value in float For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` ### vlayout -We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimention for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimention, it's more easy to support col-major V layout. However the performance of col-major is not necessarily fasater than row-major, there are many factors that may affect the overall performance. We still privide the `-vlayout=r/c` here to switch/test between different layout. +We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, and current AMD's mfma layout expects each thread to have contiguous registers holding pixels along the reduce dimension, it's easier to support a col-major V layout. However, the performance of col-major is not necessarily faster than row-major; there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` option here to switch/test between different layouts. ### generic attention mask coordinate We unify the mask expression into generic attention mask coordinate, providing an uniformed approach to describe causal top-left, causal bottom-right, local attention.
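As a side note, the following is a minimal host-side sketch (not part of the patch; the helper names are illustrative assumptions, not existing CK APIs) of how the packed `b*s*3*h*d` layout mentioned in the permute section above maps onto the seqlen/nhead/batch stride arguments:
```
// Illustrative sketch only: element strides for a packed b*s*3*h*d QKV buffer,
// following the README hint that stride_q/k/v can simply be set to 3*h*d.
#include <cstdint>

struct QkvPackedStrides // hypothetical helper, not a CK type
{
    std::int64_t stride;       // step between consecutive seqlen positions
    std::int64_t nhead_stride; // step between consecutive heads
    std::int64_t batch_stride; // step between consecutive batches
};

inline QkvPackedStrides make_qkv_packed_strides(std::int64_t s, std::int64_t h, std::int64_t d)
{
    // q, k and v share the same strides; their base pointers differ by 0, h*d and 2*h*d elements.
    return QkvPackedStrides{3 * h * d, d, s * 3 * h * d};
}
```
With these values the ordinary seqlen/nhead/batch stride arguments of the kernel cover the packed layout without a dedicated code path.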
@@ -86,5 +86,5 @@ We unify the mask expression into generic attention mask coordinate, providing a TBD ## FP8 experimental support -As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `example_fmha_fwd`, on a gfx940/941/942 machine and rocm 6.0+. Currently if you not explicitely setting `-v=0`(which will disable cpu verification), it will printout an error as much as `0.05`. We are still WIP to tune the kernel performance as well as the precision, so stay tuned for the updated performance(pipeline) +As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have experimental support for fp8 fmha kernels; you can evaluate the performance by passing the arg `-prec=fp8` to `example_fmha_fwd` on a gfx940/941/942 machine with ROCm 6.0+. Currently, unless you explicitly set `-v=0` (which disables CPU verification), it will report an error as large as `0.05`. We are still working on tuning the kernel performance as well as the precision, so stay tuned for the updated performance (pipeline) Currently we only support `-vlayout=c` for fp8, which is `hdim*seqlen` for V matrix. row major for V matrix support will come later. diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index 4798bd454..d5334d78f 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -37,7 +37,7 @@ auto create_args(int argc, char* argv[]) { ArgParser arg_parser; - arg_parser.insert("v", "1", "weather do cpu validation or not") + arg_parser.insert("v", "1", "whether to do CPU validation or not") .insert("mode", "0", "kernel mode. 0:batch, 1:group") .insert("b", "2", "batch size") .insert("h", "8", "num of head, for q") From d1adca3aced43ce8fb6f3b2fc8b852e71cf23f94 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Wed, 31 Jan 2024 10:27:44 +0000 Subject: [PATCH 43/45] restore init dist --- example/91_tile_program/fmha/fmha_fwd.cpp | 10 +++++----- library/include/ck/library/utility/fill.hpp | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index d5334d78f..8582691d6 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -70,7 +70,7 @@ auto create_args(int argc, char* argv[]) .insert("lse", "0", "0 not store lse, 1 store lse") .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") .insert("seed", - "0", + "11939", "random seed used for initializing input tensors. 
0 to use " "non-deterministic random number as seed"); @@ -250,10 +250,10 @@ bool run(const ArgParser& arg_parser) } else if(init_method == 1) { - ck::utils::FillNormalDistribution{0.f, 1.f, seed}(q_host); - ck::utils::FillNormalDistribution{0.f, 1.f, seed}(k_host); - ck::utils::FillNormalDistribution{0.f, 1.f, seed}(v_host); - ck::utils::FillNormalDistribution{0.f, 1.f, seed}(bias_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck::utils::FillUniformDistribution{0.f, 1.f, seed}(bias_host); } else if(init_method == 2) { diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index d852e9c91..e6666fa51 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -21,7 +21,7 @@ struct FillUniformDistribution { float a_{-5.f}; float b_{5.f}; - std::optional seed_{std::nullopt}; + std::optional seed_{11939}; template void operator()(ForwardIter first, ForwardIter last) const @@ -47,7 +47,7 @@ struct FillNormalDistribution { float mean_{0.f}; float variance_{1.f}; - std::optional seed_{std::nullopt}; + std::optional seed_{11939}; template void operator()(ForwardIter first, ForwardIter last) const @@ -92,7 +92,7 @@ struct FillUniformDistributionIntegerValue { float a_{-5.f}; float b_{5.f}; - std::optional seed_{std::nullopt}; + std::optional seed_{11939}; template void operator()(ForwardIter first, ForwardIter last) const @@ -119,7 +119,7 @@ struct FillNormalDistributionIntegerValue { float mean_{0.f}; float variance_{1.f}; - std::optional seed_{std::nullopt}; + std::optional seed_{11939}; template void operator()(ForwardIter first, ForwardIter last) const From eb53e235c76e3da0374214221e94c45419b90bec Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Thu, 1 Feb 2024 13:25:14 +0800 Subject: [PATCH 44/45] Check padding boundary in GenericAttentionMask<>::IsEdgeTile() (#81) * Check padding edge in IsEdgeTile() * Rename variables in IsEdgeTile() * Rename template parameter & fix compilation error * Only check right boundary for now * Add "i_" prefix for indicating index variables --- .../tile_program/block_tile/block_masking.hpp | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/include/ck/tile_program/block_tile/block_masking.hpp b/include/ck/tile_program/block_tile/block_masking.hpp index c40c85bce..1e01310d8 100644 --- a/include/ck/tile_program/block_tile/block_masking.hpp +++ b/include/ck/tile_program/block_tile/block_masking.hpp @@ -149,21 +149,30 @@ struct GenericAttentionMask // otherwise no need to check per-pixel // Attention! 
assume the idex passed in this function is with in range of GetTileRangeAlongX() // can be used as a fast-path to decide if do per-pixel check or not - template + template __host__ __device__ constexpr auto - IsEdgeTile(index_t i_y, index_t i_x, Number, Number) const + IsEdgeTile(index_t i_tile_top, index_t i_tile_left, Number, Number) const { if constexpr(IsLocal) { // check top-right corner > x or left-borrom corner < x - bool top_right_edge = (i_x + XTile) > (x + i_y); - bool bottom_left_edge = (i_y + YTile) > (y + i_x); - return top_right_edge || bottom_left_edge; + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + index_t x_end = math::min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > (i_tile_top + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; } else { // only need to check top-right corner > x - bool top_right_edge = (i_x + XTile) > math::min(x + i_y, x_total); + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = math::min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > x_end; return top_right_edge; } } From 3bda955fe6ca92cdd29691783ebb772ac13c857c Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Sun, 4 Feb 2024 14:59:58 +0800 Subject: [PATCH 45/45] Allow infinity reference value while checking LSE (#83) * Validate v_max before use in reference_batched_softmax() * Dump LSE even if nothing to do in pipeline * Allow infinity reference value for check_err() * Allow infinity reference value while checking LSE * Rename variable * Generate independent seqlens_k in group mode seqlens_k is no longer greater than corresponding element in seqlens_q * Use std::isinf() to check -INF value * Remove check for NaN * Do not clear buffer before use --- example/91_tile_program/fmha/fmha_fwd.cpp | 77 +++++++++++------ example/91_tile_program/fmha/utils.hpp | 86 +++---------------- .../reference/reference_batched_softmax.hpp | 12 ++- .../block_fmha_pipeline_qr_ks_vs.hpp | 9 ++ .../block_fmha_pipeline_qr_ks_vs_async.hpp | 9 ++ .../block_fmha_pipeline_qs_ks_vs.hpp | 9 ++ .../include/ck/library/utility/check_err.hpp | 80 ++++++++++++----- 7 files changed, 162 insertions(+), 120 deletions(-) diff --git a/example/91_tile_program/fmha/fmha_fwd.cpp b/example/91_tile_program/fmha/fmha_fwd.cpp index 8582691d6..1de17cfa8 100644 --- a/example/91_tile_program/fmha/fmha_fwd.cpp +++ b/example/91_tile_program/fmha/fmha_fwd.cpp @@ -143,8 +143,8 @@ bool run(const ArgParser& arg_parser) float descale_v = arg_parser.get_float("descale_v"); std::string vlayout = arg_parser.get_str("vlayout"); - bool use_bias = arg_parser.get_uint32("bias"); - bool lse = arg_parser.get_uint32("lse"); + bool use_bias = arg_parser.get_bool("bias"); + bool lse = arg_parser.get_bool("lse"); mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); @@ -160,9 +160,8 @@ bool run(const ArgParser& arg_parser) StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat}; - const auto [seqlens_q, seqstart_q_host] = generate_seqlens_seqstarts_q(mode, batch, seqlen_q); - const std::vector seqstart_k_host = - generate_seqstarts_k(mode, batch, seqlen_k, seqlens_q, seqlen_q); + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); using 
TypeConfig = FmhaFwdTypeConfig; @@ -394,45 +393,67 @@ bool run(const ArgParser& arg_parser) if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); } + // clang-format on // reference reference_batched_gemm( - q_host_ref, k_host_ref, s_host_ref, - ck::identity{}, ck::identity{}, + q_host_ref, + k_host_ref, + s_host_ref, + ck::identity{}, + ck::identity{}, [&](SaccDataType x) { return scale * x; }); if(use_bias) { Tensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off if(i_perm) bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); else bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); - - // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, real_seqlen_k] - reference_batched_elementwise( + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + reference_batched_elementwise( s_host_ref, bias_host_ref, s_host_ref); } - if(mask.type == mask_enum::no_mask) { - reference_batched_masking(s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); - } else if(mask.type == mask_enum::window_generic) { - reference_batched_masking(s_host_ref, - FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); - } else { - reference_batched_masking(s_host_ref, - FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + if(mask.type == mask_enum::no_mask) + { + reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + reference_batched_masking( + s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); + } + else + { + reference_batched_masking( + s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k}); } - if(lse){ - reference_batched_softmax(s_host_ref, p_host_ref, lse_host_ref); + if(lse) + { + reference_batched_softmax( + s_host_ref, p_host_ref, lse_host_ref); } - else{ - reference_batched_softmax(s_host_ref, p_host_ref); + else + { + reference_batched_softmax( + s_host_ref, p_host_ref); } - - reference_batched_gemm(p_host_ref, v_host_ref, o_host_ref); + + reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref); Tensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off // permute if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); @@ -460,8 +481,12 @@ bool run(const ArgParser& arg_parser) self(idx) = lse_host(b, idx[0], idx[1] + query_offset); }); - bool lse_pass = ck::utils::check_err( - lse_host_result, lse_host_ref, "LSE Error: Incorrect results!", rtol, atol); + bool lse_pass = ck::utils::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); pass &= lse_pass; if(!cur_pass) diff --git a/example/91_tile_program/fmha/utils.hpp b/example/91_tile_program/fmha/utils.hpp index 960e59a8f..5a8ef1042 100644 --- a/example/91_tile_program/fmha/utils.hpp +++ b/example/91_tile_program/fmha/utils.hpp @@ -34,14 +34,14 @@ std::vector 
to_seqstarts(ck::span seqlens) return seqstarts; } -std::vector generate_seqlens_q(mode_enum mode, - unsigned count, - int32_t seqlens_q_sum, - std::optional seed = std::nullopt) +std::vector generate_seqlens(mode_enum mode, + unsigned count, + int32_t seqlens_sum, + std::optional seed = std::nullopt) { assert(0 < count); - std::vector seqlens_q(count, seqlens_q_sum); + std::vector seqlens(count, seqlens_sum); if(mode == mode_enum::group && 1 < count) { @@ -54,89 +54,31 @@ std::vector generate_seqlens_q(mode_enum mode, std::uniform_int_distribution step_dist(1, count - 1); auto next_step = std::bind(step_dist, std::ref(random_engine)); - for(unsigned repeat = seqlens_q_sum * (count / 2); 0 < repeat; --repeat) + for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) { const size_type to_decrease = next_idx(); - // make sure each elements of seqlens_q is always greater than 0 - if(seqlens_q[to_decrease] == 1) + // make sure each elements of seqlens is always greater than 0 + if(seqlens[to_decrease] == 1) { continue; } const size_type to_increase = (to_decrease + next_step()) % count; - --seqlens_q[to_decrease]; - ++seqlens_q[to_increase]; + --seqlens[to_decrease]; + ++seqlens[to_increase]; } } - return seqlens_q; -} - -std::tuple, std::vector> -generate_seqlens_seqstarts_q(mode_enum mode, - unsigned count, - int32_t seqlens_q_sum, - std::optional seed = std::nullopt) -{ - const std::vector seqlens_q = generate_seqlens_q(mode, count, seqlens_q_sum, seed); - return std::make_tuple(seqlens_q, to_seqstarts(seqlens_q)); + return seqlens; } -std::vector generate_seqlens_k(mode_enum mode, +std::vector generate_seqstarts(mode_enum mode, unsigned count, - int32_t seqlens_k_sum, - ck::span seqlens_q, - int32_t seqlens_q_sum, + int32_t seqlens_sum, std::optional seed = std::nullopt) { - assert(0 < count); - assert(seqlens_q.size() == count); - - std::vector seqlens_k(count, seqlens_k_sum); - - if(mode == mode_enum::group && 1 < count) - { - using size_type = std::vector::size_type; - - std::mt19937 random_engine(seed.has_value() ? 
*seed : std::random_device{}()); - std::uniform_int_distribution idx_dist(0, count - 1); - auto next_idx = std::bind(idx_dist, std::ref(random_engine)); - - std::uniform_int_distribution step_dist(1, count - 1); - auto next_step = std::bind(step_dist, std::ref(random_engine)); - - for(unsigned repeat = seqlens_k_sum * (count / 2); 0 < repeat; --repeat) - { - const size_type to_decrease = next_idx(); - // make sure each elements of seqlens_k is always greater than 0 & greater than - // corresponding elements in seqlens_q - if(seqlens_k[to_decrease] == 1 || - (seqlens_q_sum < seqlens_k_sum && - seqlens_k[to_decrease] <= seqlens_q[to_decrease] + 1)) - { - continue; - } - - const size_type to_increase = (to_decrease + next_step()) % count; - - --seqlens_k[to_decrease]; - ++seqlens_k[to_increase]; - } - } - - return seqlens_k; -} - -std::vector generate_seqstarts_k(mode_enum mode, - unsigned count, - int32_t seqlens_k_sum, - ck::span seqlens_q, - int32_t seqlens_q_sum, - std::optional seed = std::nullopt) -{ - return to_seqstarts( - generate_seqlens_k(mode, count, seqlens_k_sum, seqlens_q, seqlens_q_sum, seed)); + return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); } int env_get_int(const char* var_name, int default_int) diff --git a/example/91_tile_program/reference/reference_batched_softmax.hpp b/example/91_tile_program/reference/reference_batched_softmax.hpp index 0f5447cff..ae6c861a4 100644 --- a/example/91_tile_program/reference/reference_batched_softmax.hpp +++ b/example/91_tile_program/reference/reference_batched_softmax.hpp @@ -3,7 +3,10 @@ #pragma once +#include +#include #include + #include "ck/utility/common_header.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -16,7 +19,7 @@ void reference_batched_softmax( const int N = a_b_m_n.mDesc.GetLengths()[2]; auto f = [&](auto batch, auto m) { - CompDataType v_max = ck::NumericLimits::Lowest(); + CompDataType v_max = -ck::NumericLimits::Infinity(); // max for(int n = 0; n < N; ++n) @@ -27,6 +30,11 @@ void reference_batched_softmax( } CompDataType v_exp_sum = 0; + // validate v_max if all the elements within a row are -INF + if(std::isinf(v_max) && v_max < 0) + { + v_max = ck::type_convert(0.f); + } // sum for(int n = 0; n < N; ++n) @@ -37,7 +45,7 @@ void reference_batched_softmax( } // if sum is zero(masked), or nan/inf(other computation error), don't do divide - CompDataType inv_sum = (v_exp_sum == 0.f || v_exp_sum != v_exp_sum) ? 1.f : 1.f / v_exp_sum; + CompDataType inv_sum = (v_exp_sum == 0.f ? 1.f : 1.f / v_exp_sum); // elementwise for(int n = 0; n < N; ++n) diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp index c1c48a58d..1400a2b20 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -170,6 +170,15 @@ struct BlockFmhaPipelineQRKSVS { if(num_total_loop <= 0) { + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + // Note: here occ are all cleard, return it // Note: q loaded but no fence, ignore it. 
return o_acc; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index b8a488ab9..eb33aea53 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -204,6 +204,15 @@ struct BlockFmhaPipelineQRKSVSAsync { if(num_total_loop <= 0) { + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + // Note: here occ are all cleard, return it // Note: q loaded but no fence, ignore it. return o_acc; diff --git a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 3c94597aa..87332d5fc 100644 --- a/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -174,6 +174,15 @@ struct BlockFmhaPipelineQSKSVS { if(num_total_loop <= 0) { + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.GetTileDistribution()); + + set_tile(lse, -NumericLimits::Infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + // Note: here occ are all cleard, return it // Note: q loaded but no fence, ignore it. return o_acc; diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index a3df884ee..33bc06ea0 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -31,9 +31,10 @@ typename std::enable_if< bool>::type check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-5, - double atol = 3e-6) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-6, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -42,6 +43,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -51,7 +59,7 @@ check_err(const Range& out, const double o = *std::next(std::begin(out), i); const double r = *std::next(std::begin(ref), i); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? 
err : max_err; err_count++; @@ -81,9 +89,10 @@ typename std::enable_if< bool>::type check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -92,6 +101,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -102,7 +118,7 @@ check_err(const Range& out, const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; @@ -132,9 +148,10 @@ typename std::enable_if< bool>::type check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -143,6 +160,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -152,7 +176,7 @@ check_err(const Range& out, const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? 
err : max_err; err_count++; @@ -236,9 +260,10 @@ std::enable_if_t<(std::is_same_v, ranges::range_val bool> check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -247,6 +272,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -256,7 +288,7 @@ check_err(const Range& out, const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++; @@ -281,9 +313,10 @@ std::enable_if_t<(std::is_same_v, ranges::range_val bool> check_err(const Range& out, const RefRange& ref, - const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3, + bool allow_infinity_ref = false) { if(out.size() != ref.size()) { @@ -292,6 +325,13 @@ check_err(const Range& out, return false; } + const auto is_infinity_error = [=](auto o, auto r) { + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + + return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); + }; + bool res{true}; int err_count = 0; double err = 0; @@ -301,7 +341,7 @@ check_err(const Range& out, const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r)) { max_err = err > max_err ? err : max_err; err_count++;