From 96820bf5a8157ae1ae3ec04d778f844530e3a318 Mon Sep 17 00:00:00 2001
From: Matti Eskelinen
Date: Tue, 16 Dec 2025 12:12:52 +0000
Subject: [PATCH 1/4] Implement RunGemmDesc that allows passing descriptors
 directly

---
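Notes (below the cut, ignored by git am): a minimal device-side sketch of the
intended call shape, assuming a kernel instantiated with a single A and B
tensor and no D tensors (NumATensor == NumBTensor == 1, NumDTensor == 0);
`Kernel`, the descriptor names, and the pointer names are illustrative, not
identifiers from this patch.

    // Descriptors are built by the caller; kargs is never consulted.
    const std::array<const void*, 1> as_ptr = {a_ptr};
    const std::array<const void*, 1> bs_ptr = {b_ptr};
    const std::array<const void*, 0> ds_ptr = {};
    const std::array<decltype(a_desc), 1> as_desc = {a_desc};
    const std::array<decltype(b_desc), 1> bs_desc = {b_desc};
    const std::array<decltype(e_desc), 0> ds_desc = {};

    Kernel::RunGemmDesc(as_ptr, bs_ptr, ds_ptr, e_ptr, smem_ptr,
                        splitk_batch_offset, i_m, i_n,
                        as_desc, bs_desc, ds_desc, e_desc);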
 .../ops/gemm/kernel/universal_gemm_kernel.hpp | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
index 5f7e78fac28..e4a2908a531 100644
--- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
@@ -989,6 +989,78 @@ struct UniversalGemmKernel
         }
     }
 
+    // Version of RunGemm using descriptors
+    template <bool UseDefaultScheduler = true,
+              typename AElementWise = element_wise::PassThrough,
+              typename BElementWise = element_wise::PassThrough,
+              typename AGridDesc, typename BGridDesc, typename DGridDesc, typename EGridDesc>
+    CK_TILE_DEVICE static void RunGemmDesc(const std::array<const void*, NumATensor>& as_ptr,
+                                           const std::array<const void*, NumBTensor>& bs_ptr,
+                                           const std::array<const void*, NumDTensor>& ds_ptr,
+                                           EDataType* e_ptr,
+                                           void* smem_ptr_0,
+                                           const SplitKBatchOffset& splitk_batch_offset,
+                                           const index_t block_idx_m,
+                                           const index_t block_idx_n,
+                                           const std::array<AGridDesc, NumATensor>& as_desc,
+                                           const std::array<BGridDesc, NumBTensor>& bs_desc,
+                                           const std::array<DGridDesc, NumDTensor>& ds_desc,
+                                           const EGridDesc& e_desc)
+    {
+        // Create tensor views from descriptors (supports arbitrary stride patterns)
+        const auto& as_tensor_view = generate_tuple(
+            [&](auto i) {
+                using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
+                return make_tensor_view<address_space_enum::global>(
+                    static_cast<const AiDataType*>(as_ptr[i]), as_desc[i]);
+            },
+            number<NumATensor>{});
+
+        const auto& bs_tensor_view = generate_tuple(
+            [&](auto i) {
+                using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
+                return make_tensor_view<address_space_enum::global>(
+                    static_cast<const BiDataType*>(bs_ptr[i]), bs_desc[i]);
+            },
+            number<NumBTensor>{});
+
+        const auto& ds_tensor_view = generate_tuple(
+            [&](auto i) {
+                using DiDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
+                return make_tensor_view<address_space_enum::global>(
+                    static_cast<const DiDataType*>(ds_ptr[i]), ds_desc[i]);
+            },
+            number<NumDTensor>{});
+
+        auto e_tensor_view =
+            make_tensor_view<address_space_enum::global>(static_cast<EDataType*>(e_ptr), e_desc);
+
+        const auto& gemm_tensors_views_tuple =
+            make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
+
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensors_views_tuple);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop =
+            amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& as_block_window = gemm_tile_windows.at(I0);
+        const auto& bs_block_window = gemm_tile_windows.at(I1);
+        const auto& ds_block_window = gemm_tile_windows.at(I2);
+
+        const auto& c_block_tile = GemmPipeline{}.template operator()(
+            as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
+
+        if(UseDefaultScheduler || (get_warp_id() == 0))
+        {
+            // Run Epilogue Pipeline
+            auto& c_block_window = gemm_tile_windows.at(I3);
+
+            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
+        }
+    }
+
     /**
      * @brief Runs single GEMM problem cooperatively by whole workgroup.
      *

From ccf4558f9a5c055ce415db44fe810cb68aef4919 Mon Sep 17 00:00:00 2001
From: Matti Eskelinen
Date: Wed, 17 Dec 2025 13:36:35 +0000
Subject: [PATCH 2/4] Use RunGemmDesc instead of custom RunGemm in
 BatchedContractionKernel

---
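Notes (below the cut, ignored by git am): the reason descriptor-based dispatch
fits this kernel is that the collapsed contraction dimensions may carry
non-contiguous strides. A sketch of how such a descriptor is expressed with
the existing ck_tile API; M_total/K_total and the stride values are
illustrative host-side quantities, not code from this patch.

    // make_naive_tensor_descriptor maps a 2-D (M, K) index space onto memory
    // purely through explicit strides, so permuted or strided collapsed
    // dimensions need no special handling inside the GEMM itself.
    const auto a_grid_desc_m_k = make_naive_tensor_descriptor(
        make_tuple(M_total, K_total),    // collapsed lengths
        make_tuple(stride_M, stride_K)); // arbitrary element strides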
 .../kernel/batched_contraction_kernel.hpp     | 34 ++++++++++++-------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
index 968d5d6ac2b..f0f40828d33 100644
--- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
+++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
@@ -671,18 +671,28 @@ struct BatchedContractionKernel
             i_splitk);
 
         // Apply K-split offsets and run descriptor-based RunGemm
-        const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
-        const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
-
-        RunGemm(a_ptr_split,
-                b_ptr_split,
-                ds_batch_ptr,
-                e_ptr,
-                smem_ptr,
-                kargs,
-                splitk_batch_offset.splitted_k,
-                i_m,
-                i_n);
+        const std::array<const void*, number<1>{}> a_ptr_split = {
+            a_ptr + splitk_batch_offset.as_k_split_offset[0]};
+        const std::array<const void*, number<1>{}> b_ptr_split = {
+            b_ptr + splitk_batch_offset.bs_k_split_offset[0]};
+
+        const std::array<decltype(kargs.a_grid_desc_m_k), number<1>{}> a_grid_desc = {
+            kargs.a_grid_desc_m_k};
+        const std::array<decltype(kargs.b_grid_desc_n_k), number<1>{}> b_grid_desc = {
+            kargs.b_grid_desc_n_k};
+
+        UniversalGemmKernel::RunGemmDesc(a_ptr_split,
+                                         b_ptr_split,
+                                         ds_batch_ptr,
+                                         e_ptr,
+                                         smem_ptr,
+                                         splitk_batch_offset,
+                                         i_m,
+                                         i_n,
+                                         a_grid_desc,
+                                         b_grid_desc,
+                                         kargs.ds_grid_desc_m_n,
+                                         kargs.e_grid_desc_m_n);
     }
 };
 

From 26cdb3e65f18a2c202354d10ec32e76d4882e79d Mon Sep 17 00:00:00 2001
From: Matti Eskelinen
Date: Wed, 17 Dec 2025 13:38:48 +0000
Subject: [PATCH 3/4] Remove custom RunGemm implementation

---
 .../kernel/batched_contraction_kernel.hpp     | 99 -------------------
 1 file changed, 99 deletions(-)

diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
index f0f40828d33..04c8f92f2db 100644
--- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
+++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
@@ -71,7 +71,6 @@
  *
  * **Architecture:**
  * - Uses TensorDescriptorUtils for stride-aware descriptor creation
- * - Custom RunGemm implementation with descriptor-based tensor views
  * - Reuses GemmPipeline and EpiloguePipeline for computation
  * - Split-K support via UniversalGemmKernel utilities
  */
@@ -375,104 +374,6 @@ struct BatchedContractionKernel
             TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
     }
 
-    /// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride
-    ///        support
-    ///
-    /// @details This function performs the core GEMM computation using tensor descriptors to
-    ///          handle arbitrary multi-dimensional stride patterns. It creates tensor views from
-    ///          pre-computed descriptors (stored in kargs), applies padding, creates tile windows,
-    ///          and executes the GemmPipeline and EpiloguePipeline.
-    ///
-    /// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied)
-    /// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied)
-    /// @param ds_ptr Array of pointers to auxiliary D tensor data
-    /// @param e_ptr Pointer to output tensor E data (after batch offset applied)
-    /// @param smem_ptr Pointer to shared memory for tile operations
-    /// @param kargs Kernel arguments containing tensor descriptors and dimension information
-    /// @param k_size Size of K dimension for this split (for split-K support)
-    /// @param i_m Starting M index for this block's tile
-    /// @param i_n Starting N index for this block's tile
-    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
-                                       const BDataType* b_ptr,
-                                       const std::array<const void*, NumDTensor>& ds_ptr,
-                                       EDataType* e_ptr,
-                                       void* smem_ptr,
-                                       const KernelArgs& kargs,
-                                       const index_t k_size,
-                                       const index_t i_m,
-                                       const index_t i_n)
-    {
-        // Create tensor views from descriptors (supports arbitrary stride patterns)
-        auto a_tensor_view =
-            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
-        auto b_tensor_view =
-            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
-        auto e_tensor_view =
-            make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
-
-        // Pad views for boundary handling and optimization (like UniversalGemmKernel)
-        auto a_pad_view = pad_tensor_view(
-            a_tensor_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            sequence<false, GemmPipeline::kPadK>{});
-
-        auto b_pad_view = pad_tensor_view(
-            b_tensor_view,
-            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            sequence<false, GemmPipeline::kPadK>{});
-
-        auto e_pad_view = pad_tensor_view(
-            e_tensor_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-            sequence<false, GemmPipeline::kPadN>{});
-
-        // Create tile windows from PADDED views
-        auto a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_m, 0});
-
-        auto b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_n, 0});
-
-        auto e_block_window = make_tile_window(
-            e_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-            {i_m, i_n});
-
-        // Calculate number of K loops
-        const index_t num_loop =
-            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
-
-        // Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
-        using AElementWise = remove_cvref_t<element_wise::PassThrough>;
-        using BElementWise = remove_cvref_t<element_wise::PassThrough>;
-
-        const auto& c_block_tile = GemmPipeline{}(
-            a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
-
-        // Create D windows from descriptors (for each D tensor)
-        auto ds_block_windows = generate_tuple(
-            [&](auto i) {
-                using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
-                const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
-
-                auto d_tensor_view =
-                    make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
-
-                return make_tile_window(d_tensor_view,
-                                        make_tuple(number<TilePartitioner::MPerBlock>{},
-                                                   number<TilePartitioner::NPerBlock>{}),
-                                        {i_m, i_n});
-            },
-            number<NumDTensor>{});
-
-        // Run Epilogue Pipeline with descriptor-based D windows
-        EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
-    }
-
     CK_TILE_HOST static constexpr KernelArgs
     MakeKernelArgs(const BatchedContractionHostArgs& host_args)
     {

From ee9ba8cb5684eae6539550ab83bd933452260e6c Mon Sep 17 00:00:00 2001
From: Matti Eskelinen
Date: Thu, 18 Dec 2025 13:31:29 +0000
Subject: [PATCH 4/4] [WIP] Partial attempt at implementing RunGemm using
 RunGemmDesc

---
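Notes (below the cut, ignored by git am): the FIXMEs below are the crux of the
WIP state: the descriptors read back from MakeGemmTensorViews encode the
split-K geometry, while the raw as_ptr/bs_ptr are forwarded unchanged, so
pointer and descriptor can disagree for k_batch > 1, which is consistent with
the splitk2/splitk4 failures. One possible shape for the suggested refactor,
pairing each pointer with the descriptor taken from the same view;
get_buffer_view()/p_data_ follow ck_tile's buffer_view naming but are an
untested assumption here, not code from this series.

    // Sketch: read pointer and descriptor back from the same created view so
    // the two cannot go out of sync.
    auto as_ptr_and_desc = generate_tuple(
        [&](auto i) {
            const auto& view = gemm_tensor_views_tuple.at(I0)[i];
            return make_tuple(view.get_buffer_view().p_data_,
                              view.get_tensor_descriptor());
        },
        number<NumATensor>{});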
 .../ops/gemm/kernel/universal_gemm_kernel.hpp | 134 ++++++++++--------
 1 file changed, 73 insertions(+), 61 deletions(-)

diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
index e4a2908a531..1b1e5d11dcd 100644
--- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
@@ -936,75 +936,28 @@ struct UniversalGemmKernel
         return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
     }
 
-    /**
-     * @brief Runs single GEMM problem cooperatively by whole workgroup.
-     *
-     * @param as_ptr input As pointer
-     * @param bs_ptr input Bs pointer
-     * @param ds_ptr input Ds pointer
-     * @param e_ptr output E pointer
-     * @param smem_ptr_0 The start memory pointer of the shared memory block.
-     * @param kargs GEMM kernel arguments
-     * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
-     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
-     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
-     *
-     */
-    template <bool UseDefaultScheduler = true,
-              typename AElementWise = element_wise::PassThrough,
-              typename BElementWise = element_wise::PassThrough>
-    CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
-                                       const std::array<const BDataType*, NumBTensor>& bs_ptr,
-                                       const std::array<const void*, NumDTensor>& ds_ptr,
-                                       EDataType* e_ptr,
-                                       void* smem_ptr_0,
-                                       const KernelArgs& kargs,
-                                       const SplitKBatchOffset& splitk_batch_offset,
-                                       const index_t block_idx_m,
-                                       const index_t block_idx_n)
-    {
-        // Create Gemm tensor views, pad views and tile windows
-        const auto& gemm_tensor_views_tuple =
-            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
-
-        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
-
-        const index_t num_loop =
-            amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
-
-        // Run GEMM cooperatively by whole workgroup.
-        const auto& as_block_window = gemm_tile_windows.at(I0);
-        const auto& bs_block_window = gemm_tile_windows.at(I1);
-        const auto& ds_block_window = gemm_tile_windows.at(I2);
-
-        const auto& c_block_tile = GemmPipeline{}.template operator()(
-            as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
-
-        if(UseDefaultScheduler || (get_warp_id() == 0))
-        {
-            // Run Epilogue Pipeline
-            auto& c_block_window = gemm_tile_windows.at(I3);
-
-            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
-        }
-    }
-
     // Version of RunGemm using descriptors
-    template <bool UseDefaultScheduler = true,
-              typename AElementWise = element_wise::PassThrough,
-              typename BElementWise = element_wise::PassThrough,
-              typename AGridDesc, typename BGridDesc, typename DGridDesc, typename EGridDesc>
-    CK_TILE_DEVICE static void RunGemmDesc(const std::array<const void*, NumATensor>& as_ptr,
-                                           const std::array<const void*, NumBTensor>& bs_ptr,
-                                           const std::array<const void*, NumDTensor>& ds_ptr,
+    template <bool UseDefaultScheduler = true,
+              typename AElementWise    = element_wise::PassThrough,
+              typename BElementWise    = element_wise::PassThrough,
+              typename AsList,
+              typename BsList,
+              typename DsList,
+              typename AGridDescs,
+              typename BGridDescs,
+              typename DGridDescs,
+              typename EGridDesc>
+    CK_TILE_DEVICE static void RunGemmDesc(const AsList& as_ptr,
+                                           const BsList& bs_ptr,
+                                           const DsList& ds_ptr,
                                            EDataType* e_ptr,
                                            void* smem_ptr_0,
                                            const SplitKBatchOffset& splitk_batch_offset,
                                            const index_t block_idx_m,
                                            const index_t block_idx_n,
-                                           const std::array<AGridDesc, NumATensor>& as_desc,
-                                           const std::array<BGridDesc, NumBTensor>& bs_desc,
-                                           const std::array<DGridDesc, NumDTensor>& ds_desc,
+                                           const AGridDescs& as_desc,
+                                           const BGridDescs& bs_desc,
+                                           const DGridDescs& ds_desc,
                                            const EGridDesc& e_desc)
     {
         // Create tensor views from descriptors (supports arbitrary stride patterns)
@@ -1061,6 +1014,65 @@ struct UniversalGemmKernel
         }
     }
 
+    /**
+     * @brief Runs single GEMM problem cooperatively by whole workgroup.
+     *
+     * @param as_ptr input As pointer
+     * @param bs_ptr input Bs pointer
+     * @param ds_ptr input Ds pointer
+     * @param e_ptr output E pointer
+     * @param smem_ptr_0 The start memory pointer of the shared memory block.
+     * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset Utility structure used to calculate k batch.
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     */
+    template <bool UseDefaultScheduler = true,
+              typename AElementWise = element_wise::PassThrough,
+              typename BElementWise = element_wise::PassThrough>
+    CK_TILE_DEVICE static void RunGemm(const std::array<const ADataType*, NumATensor>& as_ptr,
+                                       const std::array<const BDataType*, NumBTensor>& bs_ptr,
+                                       const std::array<const void*, NumDTensor>& ds_ptr,
+                                       EDataType* e_ptr,
+                                       void* smem_ptr_0,
+                                       const KernelArgs& kargs,
+                                       const SplitKBatchOffset& splitk_batch_offset,
+                                       const index_t block_idx_m,
+                                       const index_t block_idx_n)
+    {
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
+                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
+
+        // FIXME: Refactor to generate descriptors and views separately, then rework signatures.
+        // FIXME: The pointers need to be extracted from the views as well.
+        // FIXME: Fails (at least) 1024x1024x256_splitk2 and 1024x1024x256_splitk4 in
+        //        test_gemm_tile_engine_fp16_rcr_quick_coverage_config_compv3_cshuffle_intrawave_False_False_False_False_32x64x16_2x2x1_16x16x16
+
+        auto as_desc = generate_tuple(
+            [&](auto i) { return gemm_tensor_views_tuple.at(I0)[i].get_tensor_descriptor(); },
+            number<NumATensor>{});
+        auto bs_desc = generate_tuple(
+            [&](auto i) { return gemm_tensor_views_tuple.at(I1)[i].get_tensor_descriptor(); },
+            number<NumBTensor>{});
+        auto ds_desc = generate_tuple(
+            [&](auto i) { return gemm_tensor_views_tuple.at(I2)[i].get_tensor_descriptor(); },
+            number<NumDTensor>{});
+        auto e_desc = gemm_tensor_views_tuple.at(I3).get_tensor_descriptor();
+
+        RunGemmDesc<UseDefaultScheduler, AElementWise, BElementWise>(as_ptr,
+                                                                     bs_ptr,
+                                                                     ds_ptr,
+                                                                     e_ptr,
+                                                                     smem_ptr_0,
+                                                                     splitk_batch_offset,
+                                                                     block_idx_m,
+                                                                     block_idx_n,
+                                                                     as_desc,
+                                                                     bs_desc,
+                                                                     ds_desc,
+                                                                     e_desc);
+    }
+
     /**
      * @brief Runs single GEMM problem cooperatively by whole workgroup.
      *