diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
index 968d5d6ac2b..04c8f92f2db 100644
--- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
+++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
@@ -71,7 +71,6 @@
  *
  * **Architecture:**
  * - Uses TensorDescriptorUtils for stride-aware descriptor creation
- * - Custom RunGemm implementation with descriptor-based tensor views
  * - Reuses GemmPipeline and EpiloguePipeline for computation
  * - Split-K support via UniversalGemmKernel utilities
  */
@@ -375,104 +374,6 @@ struct BatchedContractionKernel
             TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
     }
 
-    /// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride
-    /// support
-    ///
-    /// @details This function performs the core GEMM computation using tensor descriptors to handle
-    /// arbitrary multi-dimensional stride patterns. It creates tensor views from
-    /// pre-computed descriptors (stored in kargs), applies padding, creates tile windows,
-    /// and executes the GemmPipeline and EpiloguePipeline.
-    ///
-    /// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied)
-    /// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied)
-    /// @param ds_ptr Array of pointers to auxiliary D tensor data
-    /// @param e_ptr Pointer to output tensor E data (after batch offset applied)
-    /// @param smem_ptr Pointer to shared memory for tile operations
-    /// @param kargs Kernel arguments containing tensor descriptors and dimension information
-    /// @param k_size Size of K dimension for this split (for split-K support)
-    /// @param i_m Starting M index for this block's tile
-    /// @param i_n Starting N index for this block's tile
-    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
-                                       const BDataType* b_ptr,
-                                       const std::array<const void*, NumDTensor>& ds_ptr,
-                                       EDataType* e_ptr,
-                                       void* smem_ptr,
-                                       const KernelArgs& kargs,
-                                       const index_t k_size,
-                                       const index_t i_m,
-                                       const index_t i_n)
-    {
-        // Create tensor views from descriptors (supports arbitrary stride patterns)
-        auto a_tensor_view =
-            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
-        auto b_tensor_view =
-            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
-        auto e_tensor_view =
-            make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
-
-        // Pad views for boundary handling and optimization (like UniversalGemmKernel)
-        auto a_pad_view = pad_tensor_view(
-            a_tensor_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            sequence<GemmPipeline::kPadM, GemmPipeline::kPadK>{});
-
-        auto b_pad_view = pad_tensor_view(
-            b_tensor_view,
-            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            sequence<GemmPipeline::kPadN, GemmPipeline::kPadK>{});
-
-        auto e_pad_view = pad_tensor_view(
-            e_tensor_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-            sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
-
-        // Create tile windows from PADDED views
-        auto a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_m, 0});
-
-        auto b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_n, 0});
-
-        auto e_block_window = make_tile_window(
-            e_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-            {i_m, i_n});
-
-        // Calculate number of K loops
-        const index_t num_loop =
-            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
-
-        // Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
-        using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
-        using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
-
-        const auto& c_block_tile = GemmPipeline{}(
-            a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
-
-        // Create D windows from descriptors (for each D tensor)
-        auto ds_block_windows = generate_tuple(
-            [&](auto i) {
-                using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
-                const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
-
-                auto d_tensor_view =
-                    make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
-
-                return make_tile_window(d_tensor_view,
-                                        make_tuple(number<TilePartitioner::MPerBlock>{},
-                                                   number<TilePartitioner::NPerBlock>{}),
-                                        {i_m, i_n});
-            },
-            number<NumDTensor>{});
-
-        // Run Epilogue Pipeline with descriptor-based D windows
-        EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
-    }
-
     CK_TILE_HOST static constexpr KernelArgs
     MakeKernelArgs(const BatchedContractionHostArgs& host_args)
     {
@@ -671,18 +572,28 @@ struct BatchedContractionKernel
             i_splitk);
 
         // Apply K-split offsets and run descriptor-based RunGemm
-        const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
-        const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
-
-        RunGemm(a_ptr_split,
-                b_ptr_split,
-                ds_batch_ptr,
-                e_ptr,
-                smem_ptr,
-                kargs,
-                splitk_batch_offset.splitted_k,
-                i_m,
-                i_n);
+        const std::array<const ADataType*, number<1>{}> a_ptr_split = {
+            a_ptr + splitk_batch_offset.as_k_split_offset[0]};
+        const std::array<const BDataType*, number<1>{}> b_ptr_split = {
+            b_ptr + splitk_batch_offset.bs_k_split_offset[0]};
+
+        const std::array<decltype(kargs.a_grid_desc_m_k), number<1>{}> a_grid_desc = {
+            kargs.a_grid_desc_m_k};
+        const std::array<decltype(kargs.b_grid_desc_n_k), number<1>{}> b_grid_desc = {
+            kargs.b_grid_desc_n_k};
+
+        UniversalGemmKernel::RunGemmDesc(a_ptr_split,
+                                         b_ptr_split,
+                                         ds_batch_ptr,
+                                         e_ptr,
+                                         smem_ptr,
+                                         splitk_batch_offset,
+                                         i_m,
+                                         i_n,
+                                         a_grid_desc,
+                                         b_grid_desc,
+                                         kargs.ds_grid_desc_m_n,
+                                         kargs.e_grid_desc_m_n);
     }
 };
 
diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
index 5f7e78fac28..1b1e5d11dcd 100644
--- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
@@ -936,6 +936,84 @@ struct UniversalGemmKernel
         return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
     }
 
+    // Version of RunGemm using descriptors
+    // FIXME: Currently templated on XsList to allow both arrays and tuples for convenience;
+    // this enforces neither a common size nor matching element types (as arrays would)
+    template <typename AsList,
+              typename BsList,
+              typename DsList,
+              typename AGridDescs,
+              typename BGridDescs,
+              typename DGridDescs,
+              typename EGridDesc>
+    CK_TILE_DEVICE static void RunGemmDesc(const AsList& as_ptr,
+                                           const BsList& bs_ptr,
+                                           const DsList& ds_ptr,
+                                           EDataType* e_ptr,
+                                           void* smem_ptr_0,
+                                           const SplitKBatchOffset& splitk_batch_offset,
+                                           const index_t block_idx_m,
+                                           const index_t block_idx_n,
+                                           const AGridDescs& as_desc,
+                                           const BGridDescs& bs_desc,
+                                           const DGridDescs& ds_desc,
+                                           const EGridDesc& e_desc)
+    {
+        // Create tensor views from descriptors (supports arbitrary stride patterns)
+        const auto& as_tensor_view = generate_tuple(
+            [&](auto i) {
+                using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
+                return make_tensor_view<address_space_enum::global>(
+                    static_cast<const AiDataType*>(as_ptr[i]), as_desc[i]);
+            },
+            number<NumATensor>{});
+
+        const auto& bs_tensor_view = generate_tuple(
+            [&](auto i) {
+                using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
+                return make_tensor_view<address_space_enum::global>(
+                    static_cast<const BiDataType*>(bs_ptr[i]), bs_desc[i]);
+            },
+            number<NumBTensor>{});
+
+        const auto& ds_tensor_view = generate_tuple(
+            [&](auto i) {
+                using DiDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
+                return make_tensor_view<address_space_enum::global>(
+                    static_cast<const DiDataType*>(ds_ptr[i]), ds_desc[i]);
+            },
+            number<NumDTensor>{});
+
+        auto e_tensor_view =
+            make_tensor_view<address_space_enum::global>(static_cast<EDataType*>(e_ptr), e_desc);
+
+        const auto& gemm_tensors_views_tuple =
+            make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
+
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensors_views_tuple);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop =
+            amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& as_block_window = gemm_tile_windows.at(I0);
+        const auto& bs_block_window = gemm_tile_windows.at(I1);
+        const auto& ds_block_window = gemm_tile_windows.at(I2);
+
+        const auto& c_block_tile = GemmPipeline{}.template operator()(
+            as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
+
+        if(UseDefaultScheduler || (get_warp_id() == 0))
+        {
+            // Run Epilogue Pipeline
+            auto& c_block_window = gemm_tile_windows.at(I3);
+
+            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
+        }
+    }
+
     /**
      * @brief Runs single GEMM problem cooperatively by whole workgroup.
      *
@@ -961,32 +1039,38 @@ struct UniversalGemmKernel
                                        const index_t block_idx_m,
                                        const index_t block_idx_n)
     {
-        // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
             as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
-        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
-
-        const index_t num_loop =
-            amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
-
-        // Run GEMM cooperatively by whole workgroup.
-        const auto& as_block_window = gemm_tile_windows.at(I0);
-        const auto& bs_block_window = gemm_tile_windows.at(I1);
-        const auto& ds_block_window = gemm_tile_windows.at(I2);
+        // FIXME: Refactor to generate descriptors and views separately, then rework signatures
+        // FIXME: Pointers need to be extracted as well
+        // FIXME: Fails (at least) 1024x1024x256_splitk2 and 1024x1024x256_splitk4 in
+        // test_gemm_tile_engine_fp16_rcr_quick_coverage_config_compv3_cshuffle_intrawave_False_False_False_False_32x64x16_2x2x1_16x16x16
 
-        const auto& c_block_tile = GemmPipeline{}.template operator()(
-            as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
-
-        if(UseDefaultScheduler || (get_warp_id() == 0))
-        {
-            // Run Epilogue Pipeline
-            auto& c_block_window = gemm_tile_windows.at(I3);
+        auto as_desc = generate_tuple(
+            [&](auto i) { return gemm_tensor_views_tuple.at(I0)[i].get_tensor_descriptor(); },
+            number<NumATensor>{});
+        auto bs_desc = generate_tuple(
+            [&](auto i) { return gemm_tensor_views_tuple.at(I1)[i].get_tensor_descriptor(); },
+            number<NumBTensor>{});
+        auto ds_desc = generate_tuple(
+            [&](auto i) { return gemm_tensor_views_tuple.at(I2)[i].get_tensor_descriptor(); },
+            number<NumDTensor>{});
+        auto e_desc = gemm_tensor_views_tuple.at(I3).get_tensor_descriptor();
 
-            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
-        }
+        RunGemmDesc(as_ptr,
+                    bs_ptr,
+                    ds_ptr,
+                    e_ptr,
+                    smem_ptr_0,
+                    splitk_batch_offset,
+                    block_idx_m,
+                    block_idx_n,
+                    as_desc,
+                    bs_desc,
+                    ds_desc,
+                    e_desc);
     }
 
     /**
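
Note on the XsList FIXME in RunGemmDesc above: because the pointer lists and the descriptor lists are independent template parameters, nothing ties their extents or element types together, and a mismatched call only fails (if at all) once generate_tuple indexes into the shorter list. The following standalone toy program (plain C++, not ck_tile code; run_loose and run_strict are illustrative names invented for this sketch) shows the distinction the FIXME is drawing:

    #include <array>
    #include <cstddef>

    // Loose form, analogous to the XsList template parameters in the patch:
    // nothing relates the two list types, so mismatched sizes still compile.
    template <typename PtrList, typename DescList>
    void run_loose(const PtrList& ptrs, const DescList& descs)
    {
        (void)ptrs;
        (void)descs;
    }

    // Strict form, analogous to what the FIXME suggests: one extent N is
    // shared by both std::array parameters and checked at every call site.
    template <typename Ptr, typename Desc, std::size_t N>
    void run_strict(const std::array<Ptr, N>& ptrs, const std::array<Desc, N>& descs)
    {
        (void)ptrs;
        (void)descs;
    }

    int main()
    {
        std::array<const float*, 2> ptrs{};
        std::array<int, 1> one_desc{};
        std::array<int, 2> two_descs{};

        run_loose(ptrs, one_desc);   // compiles despite the size mismatch
        run_strict(ptrs, two_descs); // OK: both extents deduce N = 2
        // run_strict(ptrs, one_desc); // would not compile: N deduces to both 2 and 1
        return 0;
    }

Constraining RunGemmDesc to fixed-extent arrays in this style would reject mismatched pointer/descriptor counts at compile time (the BatchedContractionKernel caller above already passes std::arrays), at the cost of the tuple convenience the current XsList form permits.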