From 4e0fd5241ae68487f8d40fb4453bda8f57c603e1 Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Thu, 15 Jan 2026 11:32:50 +0000 Subject: [PATCH 1/3] Separate tensor descriptor creation from the tensor view creation This adds utility functions to construct default tensor descriptors for A, B, C and D tensors and refactors the Make{A,B,C,D}BlockWindows to call make_tensor_view using the utility functions instead of directly calling make_naive_tensor_view, allowing for further refactors later. --- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 281 +++++++++--------- 1 file changed, 144 insertions(+), 137 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 9583ac8a3f1..a028989c825 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -661,6 +661,136 @@ struct UniversalGemmKernel return AsTensorIsValid && BsTensorIsValid && DTensorIsValid; } + template + CK_TILE_DEVICE static auto + MakeDefaultATensorDescriptor(const index_t M, const index_t stride, const index_t k_size) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, k_size), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_descriptor(make_tuple(k_size, M), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + } + + template + CK_TILE_DEVICE static auto MakeDefaultBTensorDescriptor(const index_t N, + const index_t K, + const index_t stride, + const index_t k_size) + { + if constexpr(std::is_same_v) + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1), + make_tuple(N * K1, K1, I1), + number{}, + number<1>{}); + return transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + else + { + return make_naive_tensor_descriptor(make_tuple(k_size, N), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = make_naive_tensor_descriptor(make_tuple(K0, N, K1), + make_tuple(N * K1, K1, I1), + number{}, + number<1>{}); + return transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + } + else + { + if constexpr(GemmPipeline::Preshuffle) + { + index_t kFlatK = + GemmPipeline::BlockGemmShape::flatKPerWarp * + (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = N * K / kFlatK; + + return make_naive_tensor_descriptor(make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, k_size), + make_tuple(stride, 1), + number{}, + number<1>{}); + } + } + } + } + + template + CK_TILE_DEVICE static auto + MakeDefaultDTensorDescriptor(const index_t M, const index_t N, const index_t stride) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(stride, 1), number{}, number<1>{}); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, M), make_tuple(stride, 1), number{}, number<1>{}); + } + } + + CK_TILE_DEVICE static auto + MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride) + { + // TODO: enable vector write for C in ColMajor + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(M, N), // arguments not matching with flatmm. + make_tuple(stride, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(1, stride), number<1>{}, number<1>{}); + } + } + CK_TILE_DEVICE static auto MakeABlockWindows(const std::array& as_ptr, const KernelArgs& kargs, @@ -672,24 +802,10 @@ struct UniversalGemmKernel [&](auto i) { using AiLayout = remove_cvref_t>; using AiDataType = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(as_ptr[i]), - make_tuple(kargs.M, k_size), - make_tuple(kargs.stride_As[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(as_ptr[i]), - make_tuple(k_size, kargs.M), - make_tuple(kargs.stride_As[i], 1), - number{}, - number<1>{}); - } + + return make_tensor_view( + static_cast(as_ptr[i]), + MakeDefaultATensorDescriptor(kargs.M, kargs.stride_As[i], k_size)); }, number{}); @@ -749,87 +865,10 @@ struct UniversalGemmKernel [&](auto i) { using BiLayout = remove_cvref_t>; using BiDataType = remove_cvref_t>; - if constexpr(std::is_same_v) - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = k_size / K1; - constexpr index_t VectorSizeB = - std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view( - static_cast(bs_ptr[i]), b_n_k_desc); - } - else - { - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(k_size, kargs.N), - make_tuple(kargs.stride_Bs[i], 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = k_size / K1; - constexpr index_t VectorSizeB = - std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view( - static_cast(bs_ptr[i]), b_n_k_desc); - } - else - { - if constexpr(GemmPipeline::Preshuffle) - { - index_t kFlatK = - GemmPipeline::BlockGemmShape::flatKPerWarp * - (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - bs_ptr[i], - make_tuple(kargs.N, k_size), - make_tuple(kargs.stride_Bs[i], 1), - number{}, - number<1>{}); - } - } - } + return make_tensor_view( + static_cast(bs_ptr[i]), + MakeDefaultBTensorDescriptor( + kargs.N, kargs.K, kargs.stride_Bs[i], k_size)); }, number{}); @@ -900,24 +939,10 @@ struct UniversalGemmKernel [&](auto i) { using DiLayout = remove_cvref_t>; using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } + return make_tensor_view( + static_cast(ds_ptr[i]), + MakeDefaultDTensorDescriptor( + kargs.M, kargs.N, kargs.stride_Ds[i])); }, number{}); @@ -973,26 +998,8 @@ struct UniversalGemmKernel const index_t i_n) { // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews) - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); + const auto& e_tensor_view = make_tensor_view( + e_ptr, MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E)); // Step 2: Create padded view (from MakeGemmPadViews) const auto& e_pad_view = [&]() { From fc3835528e4a0cbba37f33d867c8988ce8bad9c6 Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Thu, 15 Jan 2026 13:48:44 +0000 Subject: [PATCH 2/3] Generate the descriptors explicitly as separate tuples --- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 72 ++++++++++++------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index a028989c825..23767a935c1 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -797,19 +797,24 @@ struct UniversalGemmKernel const index_t k_size, const index_t i_m) { - // Step 1: Create tensor views for A tensors (from MakeGemmTensorViews) + // Step 1: Create tensor descriptors for A tensors + const auto& as_tensor_desc = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + return MakeDefaultATensorDescriptor(kargs.M, kargs.stride_As[i], k_size); + }, + number{}); + + // Step 1: Create tensor views const auto& as_tensor_view = generate_tuple( [&](auto i) { - using AiLayout = remove_cvref_t>; using AiDataType = remove_cvref_t>; - return make_tensor_view( - static_cast(as_ptr[i]), - MakeDefaultATensorDescriptor(kargs.M, kargs.stride_As[i], k_size)); + static_cast(as_ptr[i]), as_tensor_desc[i]); }, number{}); - // Step 2: Create padded views (from MakeGemmPadViews) + // Step 2: Create padded views const auto& as_pad_view = generate_tuple( [&](auto i) { using AiLayout = remove_cvref_t>; @@ -830,7 +835,7 @@ struct UniversalGemmKernel }, number{}); - // Step 3: Create tile windows (from MakeGemmTileWindows) + // Step 3: Create tile windows const auto& as_block_window = generate_tuple( [&](auto i) { using AiLayout = remove_cvref_t>; @@ -860,19 +865,25 @@ struct UniversalGemmKernel const index_t k_size, const index_t i_n) { - // Step 1: Create tensor views for B tensors (from MakeGemmTensorViews) + // Step 1: Create tensor descriptors for B tensors + const auto& bs_tensor_desc = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + return MakeDefaultBTensorDescriptor( + kargs.N, kargs.K, kargs.stride_Bs[i], k_size); + }, + number{}); + + // Step 2: Create tensor views const auto& bs_tensor_view = generate_tuple( [&](auto i) { - using BiLayout = remove_cvref_t>; using BiDataType = remove_cvref_t>; return make_tensor_view( - static_cast(bs_ptr[i]), - MakeDefaultBTensorDescriptor( - kargs.N, kargs.K, kargs.stride_Bs[i], k_size)); + static_cast(bs_ptr[i]), bs_tensor_desc[i]) }, number{}); - // Step 2: Create padded views (from MakeGemmPadViews) + // Step 3: Create padded views const auto& bs_pad_view = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -893,7 +904,7 @@ struct UniversalGemmKernel }, number{}); - // Step 3: Create tile windows (from MakeGemmTileWindows) + // Step 4: Create tile windows const auto& bs_block_window = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -934,19 +945,25 @@ struct UniversalGemmKernel const index_t i_m, const index_t i_n) { - // Step 1: Create tensor views for D tensors (from MakeGemmTensorViews) + // Step 1: Create tensor descriptors for D tensors + const auto& ds_tensor_desc = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + return MakeDefaultDTensorDescriptor( + kargs.M, kargs.N, kargs.stride_Ds[i]); + }, + number{}); + + // Step 2: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { - using DiLayout = remove_cvref_t>; using DDataType_ = remove_cvref_t>; return make_tensor_view( - static_cast(ds_ptr[i]), - MakeDefaultDTensorDescriptor( - kargs.M, kargs.N, kargs.stride_Ds[i])); + static_cast(ds_ptr[i]), ds_tensor_desc[i]); }, number{}); - // Step 2: Create padded views (from MakeGemmPadViews) + // Step 3: Create padded views const auto& ds_pad_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -967,7 +984,7 @@ struct UniversalGemmKernel }, number{}); - // Step 3: Create tile windows (from MakeGemmTileWindows) + // Step 4: Create tile windows const auto& ds_block_window = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -997,11 +1014,14 @@ struct UniversalGemmKernel const index_t i_m, const index_t i_n) { - // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews) - const auto& e_tensor_view = make_tensor_view( - e_ptr, MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E)); + // Step 1: Create tensor descriptor for E/C tensor + const auto& e_tensor_desc = MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E); + + // Step 1: Create tensor view + const auto& e_tensor_view = + make_tensor_view(e_ptr, e_tensor_desc); - // Step 2: Create padded view (from MakeGemmPadViews) + // Step 2: Create padded view const auto& e_pad_view = [&]() { if constexpr(std::is_same_v) { @@ -1019,7 +1039,7 @@ struct UniversalGemmKernel } }(); - // Step 3: Create tile window (from MakeGemmTileWindows) + // Step 3: Create tile window auto e_block_window = make_tile_window( e_pad_view, make_tuple(number{}, number{}), From 767530856adafa6612bb1232318baa546b21b25d Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Thu, 15 Jan 2026 14:19:27 +0000 Subject: [PATCH 3/3] Add overloads of MakeA/B/C/DBlockWindows that accept descriptors This adds overloaded versions of the block window creation functions that allow the caller to specify explicit descriptors instead of the default ones, and reimplements the existing definitions by calling the new ones using default descriptors. --- .../ops/gemm/kernel/universal_gemm_kernel.hpp | 130 +++++++++++------- 1 file changed, 81 insertions(+), 49 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 23767a935c1..217f8f1301a 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -775,14 +775,12 @@ struct UniversalGemmKernel CK_TILE_DEVICE static auto MakeDefaultETensorDescriptor(const index_t M, const index_t N, const index_t stride) { - // TODO: enable vector write for C in ColMajor if constexpr(std::is_same_v) { - return make_naive_tensor_descriptor( - make_tuple(M, N), // arguments not matching with flatmm. - make_tuple(stride, 1), - number{}, - number<1>{}); + return make_naive_tensor_descriptor(make_tuple(M, N), + make_tuple(stride, 1), + number{}, + number<1>{}); } else { @@ -791,26 +789,18 @@ struct UniversalGemmKernel } } + template CK_TILE_DEVICE static auto MakeABlockWindows(const std::array& as_ptr, - const KernelArgs& kargs, - const index_t k_size, + const AsTensorDesc& as_desc, const index_t i_m) { - // Step 1: Create tensor descriptors for A tensors - const auto& as_tensor_desc = generate_tuple( - [&](auto i) { - using AiLayout = remove_cvref_t>; - return MakeDefaultATensorDescriptor(kargs.M, kargs.stride_As[i], k_size); - }, - number{}); - // Step 1: Create tensor views const auto& as_tensor_view = generate_tuple( [&](auto i) { using AiDataType = remove_cvref_t>; return make_tensor_view( - static_cast(as_ptr[i]), as_tensor_desc[i]); + static_cast(as_ptr[i]), as_desc[i]); }, number{}); @@ -860,30 +850,38 @@ struct UniversalGemmKernel } CK_TILE_DEVICE static auto - MakeBBlockWindows(const std::array& bs_ptr, + MakeABlockWindows(const std::array& as_ptr, const KernelArgs& kargs, const index_t k_size, - const index_t i_n) + const index_t i_m) { - // Step 1: Create tensor descriptors for B tensors - const auto& bs_tensor_desc = generate_tuple( + // Step 1: Create tensor descriptors for A tensors + const auto& as_tensor_desc = generate_tuple( [&](auto i) { - using BiLayout = remove_cvref_t>; - return MakeDefaultBTensorDescriptor( - kargs.N, kargs.K, kargs.stride_Bs[i], k_size); + using AiLayout = remove_cvref_t>; + return MakeDefaultATensorDescriptor(kargs.M, kargs.stride_As[i], k_size); }, - number{}); + number{}); - // Step 2: Create tensor views + return MakeABlockWindows(as_ptr, as_tensor_desc, i_m); + } + + template + CK_TILE_DEVICE static auto + MakeBBlockWindows(const std::array& bs_ptr, + const BsTensorDesc& bs_desc, + const index_t i_n) + { + // Step 1: Create tensor views const auto& bs_tensor_view = generate_tuple( [&](auto i) { using BiDataType = remove_cvref_t>; return make_tensor_view( - static_cast(bs_ptr[i]), bs_tensor_desc[i]) + static_cast(bs_ptr[i]), bs_desc[i]); }, number{}); - // Step 3: Create padded views + // Step 2: Create padded views const auto& bs_pad_view = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -904,7 +902,7 @@ struct UniversalGemmKernel }, number{}); - // Step 4: Create tile windows + // Step 3: Create tile windows const auto& bs_block_window = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -940,30 +938,39 @@ struct UniversalGemmKernel return bs_block_window; } - CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, - const KernelArgs& kargs, - const index_t i_m, - const index_t i_n) + CK_TILE_DEVICE static auto + MakeBBlockWindows(const std::array& bs_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t i_n) { - // Step 1: Create tensor descriptors for D tensors - const auto& ds_tensor_desc = generate_tuple( + const auto& bs_tensor_desc = generate_tuple( [&](auto i) { - using DiLayout = remove_cvref_t>; - return MakeDefaultDTensorDescriptor( - kargs.M, kargs.N, kargs.stride_Ds[i]); + using BiLayout = remove_cvref_t>; + return MakeDefaultBTensorDescriptor( + kargs.N, kargs.K, kargs.stride_Bs[i], k_size); }, - number{}); + number{}); - // Step 2: Create tensor views + return MakeBBlockWindows(bs_ptr, bs_tensor_desc, i_n); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const DsTensorDesc& ds_desc, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DDataType_ = remove_cvref_t>; return make_tensor_view( - static_cast(ds_ptr[i]), ds_tensor_desc[i]); + static_cast(ds_ptr[i]), ds_desc[i]); }, number{}); - // Step 3: Create padded views + // Step 2: Create padded views const auto& ds_pad_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -984,7 +991,7 @@ struct UniversalGemmKernel }, number{}); - // Step 4: Create tile windows + // Step 3: Create tile windows const auto& ds_block_window = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -1008,18 +1015,32 @@ struct UniversalGemmKernel return ds_block_window; } - template - CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, const KernelArgs& kargs, const index_t i_m, const index_t i_n) { - // Step 1: Create tensor descriptor for E/C tensor - const auto& e_tensor_desc = MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E); + const auto& ds_tensor_desc = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + return MakeDefaultDTensorDescriptor( + kargs.M, kargs.N, kargs.stride_Ds[i]); + }, + number{}); + + return MakeDBlockWindows(ds_ptr, ds_tensor_desc, i_m, i_n); + } - // Step 1: Create tensor view + template + CK_TILE_DEVICE static auto MakeCBlockWindows( + EDataType* e_ptr, + const index_t i_m, + const index_t i_n, + const ETensorDesc& e_desc) // Argument order differs from A,B,D to disambiguate overloads + { + // Step 1: Create tensor view for E/C tensor const auto& e_tensor_view = - make_tensor_view(e_ptr, e_tensor_desc); + make_tensor_view(e_ptr, e_desc); // Step 2: Create padded view const auto& e_pad_view = [&]() { @@ -1048,6 +1069,17 @@ struct UniversalGemmKernel return e_block_window; } + template + CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + + const auto& e_tensor_desc = MakeDefaultETensorDescriptor(kargs.M, kargs.N, kargs.stride_E); + return MakeCBlockWindows(e_ptr, i_m, i_n, e_tensor_desc); + } + /** * @brief Runs single GEMM problem cooperatively by whole workgroup. *