diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 9583ac8a3f1..217f8f1301a 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -661,39 +661,150 @@ struct UniversalGemmKernel return AsTensorIsValid && BsTensorIsValid && DTensorIsValid; } + template CK_TILE_DEVICE static auto - MakeABlockWindows(const std::array& as_ptr, - const KernelArgs& kargs, - const index_t k_size, - const index_t i_m) + MakeDefaultATensorDescriptor(const index_t M, const index_t stride, const index_t k_size) { - // Step 1: Create tensor views for A tensors (from MakeGemmTensorViews) - const auto& as_tensor_view = generate_tuple( - [&](auto i) { - using AiLayout = remove_cvref_t>; - using AiDataType = remove_cvref_t>; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, k_size), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_descriptor(make_tuple(k_size, M), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + } + + template + CK_TILE_DEVICE static auto MakeDefaultBTensorDescriptor(const index_t N, + const index_t K, + const index_t stride, + const index_t k_size) + { + if constexpr(std::is_same_v) + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1), + make_tuple(N * K1, K1, I1), + number{}, + number<1>{}); + return transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + return make_naive_tensor_descriptor(make_tuple(k_size, N), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1), + make_tuple(N * K1, K1, I1), + number{}, + number<1>{}); + return transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + } + else + { + if constexpr(GemmPipeline::Preshuffle) { - return make_naive_tensor_view( - static_cast(as_ptr[i]), - make_tuple(kargs.M, k_size), - make_tuple(kargs.stride_As[i], 1), - number{}, - number<1>{}); + index_t kFlatK = + GemmPipeline::BlockGemmShape::flatKPerWarp * + (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = N * K / kFlatK; + + return make_naive_tensor_descriptor(make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); } else { - return make_naive_tensor_view( - static_cast(as_ptr[i]), - make_tuple(k_size, kargs.M), - make_tuple(kargs.stride_As[i], 1), - number{}, - number<1>{}); + return make_naive_tensor_descriptor(make_tuple(N, k_size), + make_tuple(stride, 1), + number{}, + number<1>{}); } + } + } + } + + template + CK_TILE_DEVICE static auto + MakeDefaultDTensorDescriptor(const index_t M, const index_t N, const index_t stride) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(stride, 1), number{}, number<1>{}); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, M), make_tuple(stride, 1), number{}, number<1>{}); + } + } + + CK_TILE_DEVICE static auto + MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, N), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(1, stride), number<1>{}, number<1>{}); + } + } + + template + CK_TILE_DEVICE static auto + MakeABlockWindows(const std::array& as_ptr, + const AsTensorDesc& as_desc, + const index_t i_m) + { + // Step 1: Create tensor views + const auto& as_tensor_view = generate_tuple( + [&](auto i) { + using AiDataType = remove_cvref_t>; + return make_tensor_view( + static_cast(as_ptr[i]), as_desc[i]); }, number{}); - // Step 2: Create padded views (from MakeGemmPadViews) + // Step 2: Create padded views const auto& as_pad_view = generate_tuple( [&](auto i) { using AiLayout = remove_cvref_t>; @@ -714,7 +825,7 @@ struct UniversalGemmKernel }, number{}); - // Step 3: Create tile windows (from MakeGemmTileWindows) + // Step 3: Create tile windows const auto& as_block_window = generate_tuple( [&](auto i) { using AiLayout = remove_cvref_t>; @@ -739,101 +850,38 @@ struct UniversalGemmKernel } CK_TILE_DEVICE static auto - MakeBBlockWindows(const std::array& bs_ptr, + MakeABlockWindows(const std::array& as_ptr, const KernelArgs& kargs, const index_t k_size, + const index_t i_m) + { + // Step 1: Create tensor descriptors for A tensors + const auto& as_tensor_desc = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + return MakeDefaultATensorDescriptor(kargs.M, kargs.stride_As[i], k_size); + }, + number{}); + + return MakeABlockWindows(as_ptr, as_tensor_desc, i_m); + } + + template + CK_TILE_DEVICE static auto + MakeBBlockWindows(const std::array& bs_ptr, + const BsTensorDesc& bs_desc, const index_t i_n) { - // Step 1: Create tensor views for B tensors (from MakeGemmTensorViews) + // Step 1: Create tensor views const auto& bs_tensor_view = generate_tuple( [&](auto i) { - using BiLayout = remove_cvref_t>; using BiDataType = remove_cvref_t>; - if constexpr(std::is_same_v) - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = k_size / K1; - constexpr index_t VectorSizeB = - std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view( - static_cast(bs_ptr[i]), b_n_k_desc); - } - else - { - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(k_size, kargs.N), - make_tuple(kargs.stride_Bs[i], 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = k_size / K1; - constexpr index_t VectorSizeB = - std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view( - static_cast(bs_ptr[i]), b_n_k_desc); - } - else - { - if constexpr(GemmPipeline::Preshuffle) - { - index_t kFlatK = - GemmPipeline::BlockGemmShape::flatKPerWarp * - (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(kargs.N, k_size), - make_tuple(kargs.stride_Bs[i], 1), - number{}, - number<1>{}); - } - } - } + return make_tensor_view( + static_cast(bs_ptr[i]), bs_desc[i]); }, number{}); - // Step 2: Create padded views (from MakeGemmPadViews) + // Step 2: Create padded views const auto& bs_pad_view = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -854,7 +902,7 @@ struct UniversalGemmKernel }, number{}); - // Step 3: Create tile windows (from MakeGemmTileWindows) + // Step 3: Create tile windows const auto& bs_block_window = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -890,38 +938,39 @@ struct UniversalGemmKernel return bs_block_window; } + CK_TILE_DEVICE static auto + MakeBBlockWindows(const std::array& bs_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t i_n) + { + const auto& bs_tensor_desc = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + return MakeDefaultBTensorDescriptor( + kargs.N, kargs.K, kargs.stride_Bs[i], k_size); + }, + number{}); + + return MakeBBlockWindows(bs_ptr, bs_tensor_desc, i_n); + } + + template CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, - const KernelArgs& kargs, + const DsTensorDesc& ds_desc, const index_t i_m, const index_t i_n) { - // Step 1: Create tensor views for D tensors (from MakeGemmTensorViews) + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { - using DiLayout = remove_cvref_t>; using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } + return make_tensor_view( + static_cast(ds_ptr[i]), ds_desc[i]); }, number{}); - // Step 2: Create padded views (from MakeGemmPadViews) + // Step 2: Create padded views const auto& ds_pad_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -942,7 +991,7 @@ struct UniversalGemmKernel }, number{}); - // Step 3: Create tile windows (from MakeGemmTileWindows) + // Step 3: Create tile windows const auto& ds_block_window = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -966,35 +1015,34 @@ struct UniversalGemmKernel return ds_block_window; } - template - CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, const KernelArgs& kargs, const index_t i_m, const index_t i_n) { - // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews) - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); + const auto& ds_tensor_desc = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + return MakeDefaultDTensorDescriptor( + kargs.M, kargs.N, kargs.stride_Ds[i]); + }, + number{}); - // Step 2: Create padded view (from MakeGemmPadViews) + return MakeDBlockWindows(ds_ptr, ds_tensor_desc, i_m, i_n); + } + + template + CK_TILE_DEVICE static auto MakeCBlockWindows( + EDataType* e_ptr, + const index_t i_m, + const index_t i_n, + const ETensorDesc& e_desc) // Argument order differs from A,B,D to disambiguate overloads + { + // Step 1: Create tensor view for E/C tensor + const auto& e_tensor_view = + make_tensor_view(e_ptr, e_desc); + + // Step 2: Create padded view const auto& e_pad_view = [&]() { if constexpr(std::is_same_v) { @@ -1012,7 +1060,7 @@ struct UniversalGemmKernel } }(); - // Step 3: Create tile window (from MakeGemmTileWindows) + // Step 3: Create tile window auto e_block_window = make_tile_window( e_pad_view, make_tuple(number{}, number{}), @@ -1021,6 +1069,17 @@ struct UniversalGemmKernel return e_block_window; } + template + CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + + const auto& e_tensor_desc = MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E); + return MakeCBlockWindows(e_ptr, i_m, i_n, e_tensor_desc); + } + /** * @brief Runs single GEMM problem cooperatively by whole workgroup. *