From 013df2e4060934aac58efc45427c6d2bd4bfe8c6 Mon Sep 17 00:00:00 2001
From: Sami Remes
Date: Wed, 19 Feb 2025 12:53:55 +0000
Subject: [PATCH] Add support for gemm pipeline v5 to grouped gemm tile loop

---
 .../blockwise_gemm_pipeline_xdlops_v5.hpp | 4 +-
 ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 37 +++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
index b6a4f05502a..ea8e15fc0cf 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
@@ -149,12 +149,12 @@ struct BlockwiseGemmXdlops_pipeline_v5
 PrefetchStages;
 }
- __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
+ __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
 {
 if(num_loop % HotloopUnroll == 1)
 {
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
index 61058dec2b2..06d7d573e72 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
@@ -371,6 +371,43 @@ __global__ void
 b2c_tile_map);
 }
 }
+ else
+ {
+ if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+ {
+ GridwiseGemm::template Run(
+ static_cast(gemm_desc_ptr[group_id].p_a_grid),
+ static_cast(gemm_desc_ptr[group_id].p_b_grid),
+ p_ds_grid,
+ static_cast(gemm_desc_ptr[group_id].p_e_grid),
+ static_cast(p_shared),
+ problem,
+ a_element_op,
+ b_element_op,
+ cde_element_op,
+ b2c_tile_map);
+ }
+ else
+ {
+ GridwiseGemm::template Run(
+ static_cast(gemm_desc_ptr[group_id].p_a_grid),
+ static_cast(gemm_desc_ptr[group_id].p_b_grid),
+ p_ds_grid,
+ static_cast(gemm_desc_ptr[group_id].p_e_grid),
+ static_cast(p_shared),
+ problem,
+ a_element_op,
+ b_element_op,
+ cde_element_op,
+ b2c_tile_map);
+ }
+ }
 }
 else
 {