-
Notifications
You must be signed in to change notification settings - Fork 23
Grouped GEMM with ck_tile #434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
This reverts commit 86fbbac.
tests/pytorch/test_numerics.py
Outdated
| delay_wgrad_compute, | ||
| ): | ||
| os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" | ||
| if IS_HIP_EXTENSION: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is our CK grouped gemm a drop-in replacement with NV upstream CUTLASS grouped gemm? If so, we can share the same env. It's like cublaslt vs hipblaslt...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It mostly is a drop-in replacement for upstream, so I changed the envs to the upstream versions in 259645c
| struct TileCfg_basic { | ||
| static constexpr ck_tile::index_t M_Tile = 256; | ||
| static constexpr ck_tile::index_t N_Tile = 128; | ||
| static constexpr ck_tile::index_t K_Tile = 64; | ||
|
|
||
| static constexpr ck_tile::index_t M_Warp = 2; | ||
| static constexpr ck_tile::index_t N_Warp = 2; | ||
| static constexpr ck_tile::index_t K_Warp = 1; | ||
|
|
||
| static constexpr ck_tile::index_t M_Warp_Tile = 32; | ||
| static constexpr ck_tile::index_t N_Warp_Tile = 32; | ||
| static constexpr ck_tile::index_t K_Warp_Tile = 16; | ||
|
|
||
| static constexpr bool kPadM = true; | ||
| static constexpr bool kPadN = true; | ||
| static constexpr bool kPadK = true; | ||
|
|
||
| static constexpr bool DoubleSmemBuffer = false; | ||
|
|
||
| static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; | ||
| static constexpr ck_tile::index_t TilePartitionerM01 = 1; | ||
| }; | ||
|
|
||
| template <typename AType, typename BType, typename CType, | ||
| typename ALayout, typename BLayout, typename CLayout, | ||
| typename TileCfg, ck_tile::memory_operation_enum MemOp, | ||
| typename AccType = float> | ||
| class Runner{ | ||
| public: | ||
| using GemmShape = ck_tile::TileGemmShape< | ||
| ck_tile::sequence<TileCfg::M_Tile, TileCfg::N_Tile, TileCfg::K_Tile>, | ||
| ck_tile::sequence<TileCfg::M_Warp, TileCfg::N_Warp, TileCfg::K_Warp>, | ||
| ck_tile::sequence<TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile>>; | ||
|
|
||
| using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< | ||
| GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; | ||
|
|
||
| using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< | ||
| TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, | ||
| TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; | ||
|
|
||
| static constexpr ck_tile::GemmPipelineScheduler Scheduler = | ||
| ck_tile::GemmPipelineScheduler::Intrawave; | ||
|
|
||
| using Problem = ck_tile::UniversalGemmPipelineProblem< | ||
| AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; | ||
|
|
||
| using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>; | ||
|
|
||
| using Epilogue = ck_tile::CShuffleEpilogue< | ||
| ck_tile::CShuffleEpilogueProblem< | ||
| AType, BType, ck_tile::tuple<>, AccType, | ||
| CType, ck_tile::tuple<>, CLayout, | ||
| ck_tile::element_wise::PassThrough, | ||
| Partitioner::MPerBlock, Partitioner::NPerBlock, | ||
| TileCfg::M_Warp, TileCfg::N_Warp, | ||
| TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, | ||
| Problem::TransposeC, MemOp>>; | ||
|
|
||
| using Kernel = ck_tile::GroupedGemmKernel<Partitioner, Pipeline, Epilogue>; | ||
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these codes from CK repo? If so, can you add a comment to point to the reference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a reference in fac7c11.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can see the comment with the reference to CK repo, so I am resolving this.
| std::vector<ck_tile::GroupedGemmHostArgs<0>> descs; | ||
| descs.reserve(group_num); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not put group_num inside the desc vector definition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used reserve() here instead of std::vector<ck_tile::GroupedGemmHostArgs<0>> descs(group_num); to avoid default-constructing GroupedGemmHostArgs objects that are immediately overwritten, to reduce construction overhead.
| using R = Runner<T, T, T, ALayout, BLayout, CLayout, TileCfg_basic, MemOp>; | ||
| using Kernel = typename R::Kernel; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This R is not used anywhere else
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I merged R into the next line in fac7c11.
| if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) { | ||
| NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); | ||
| return false; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does grouped gemm support generalized matrices from high-dimensional tensors? Regular gemm supports that. And TE treat the last dim as col with other dimensions as row:
TransformerEngine/transformer_engine/common/common.h
Lines 238 to 262 in 9d6b0e5
| size_t flat_first_dim() const { | |
| const auto &full_shape = shape(); | |
| size_t ret = 1; | |
| if (!full_shape.empty()) { | |
| for (size_t i = 0; i < full_shape.size() - 1; i++) { | |
| ret *= full_shape[i]; | |
| } | |
| } | |
| return ret; | |
| } | |
| /*! Matrix width after tensor is flattened to 2D | |
| * | |
| * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted | |
| * as a (D1*D2*...*D(n-1), Dn) matrix. | |
| */ | |
| size_t flat_last_dim() const { | |
| const auto &full_shape = shape(); | |
| if (full_shape.empty()) { | |
| return 1; | |
| } else { | |
| return full_shape.back(); | |
| } | |
| } | |
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added (untested) support for higher-dim tensors in dd3ed2f
| } | ||
| } | ||
|
|
||
| bool grouped_gemm_ck_tile(const NVTETensor* A, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we overload this function? In cublaslt_gemm.cu, it's only called by this signature. Perhaps we can rename the grouped_gemm_ck_tile in line 255
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I simplified this in 259645c so that there is no more overload (only this signature remains).
| transformer_engine::getenv<bool>("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", false); | ||
|
|
||
| auto is_supported_dtype = [&]() -> bool { | ||
| auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible that num_group=0 so A[0] access not valid?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) | ||
|
|
||
| target_include_directories(transformer_engine | ||
| BEFORE PRIVATE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why using keyword BEFORE in this target_include_directories? Is it because cmake will not be able to find the correct header files without prioritizing the ck include dirs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed BEFORE in 259645c, compilation still seems to work fine.
| target_include_directories(transformer_engine PUBLIC | ||
| "${CMAKE_CURRENT_SOURCE_DIR}/include") | ||
|
|
||
| set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CMAKE_SOURCE_DIR --> CMAKE_CURRENT_SOURCE_DIR? Not sure whether other upstream libs will depend on us but let's make it future proof
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to in CMAKE_CURRENT_SOURCE_DIR in 259645c.
| #include "common/util/cuda_runtime.h" | ||
| #include "common/util/system.h" | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include "cutlass_grouped_gemm.cuh" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NV upstream made another .cu file for their cutlass_grouped_gemm and compiled it separately. Maybe we can follow their structure for better isolation (avoid CK defining some macros contaminating our cublaslt_gemm.cu)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I restructured this to a cpp file and a header file in 259645c.
2095d3f to
ebc005f
Compare
d1ab38e to
0b16287
Compare
Description
See https://github.com/ROCm/frameworks-internal/issues/13792 for context.
Primus-Turbo implementation: https://github.com/AMD-AGI/Primus-Turbo/blob/5bcd13785ef380fec0eec0911b7d6db5e606143e/csrc/kernels/grouped_gemm
TODOs:
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: