-
Notifications
You must be signed in to change notification settings - Fork 270
Update/add to qr_ks_vs_whole_k_prefetch pipeline #3485
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
5fada1c
98f9b4a
c32949b
25521a7
8b85919
5722f8a
044f554
2ea8d83
12c8873
c3d3487
409ec3b
370d386
d281c51
384f470
eb598a9
57abd10
3f6d26e
1ef76a6
db5c12d
57cf989
b77fdbf
e7e6ebc
6c91b0c
489e255
f5b4d5d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,52 @@ | |
|
|
||
| namespace ck_tile { | ||
|
|
||
| namespace detail { | ||
|
|
||
| template <typename DataType, index_t ElemPerThread> | ||
| CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize() | ||
| { | ||
| if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>) | ||
| { | ||
| // ToDo: need support in ck_tile for using buffer_load_dwordx3 | ||
| // if constexpr(ElemPerThread % 6 == 0) | ||
| // return 6; | ||
| if constexpr(ElemPerThread % 8 == 0) | ||
| return 8; | ||
| else if constexpr(ElemPerThread % 4 == 0) | ||
| return 4; | ||
| else if constexpr(ElemPerThread % 2 == 0) | ||
| return 2; | ||
| return 1; | ||
| } | ||
| else if constexpr(std::is_same_v<DataType, float>) | ||
| { | ||
| // ToDo: need support in ck_tile for using buffer_load_dwordx3 | ||
| // if constexpr(ElemPerThread % 3 == 0) | ||
| // return 3; | ||
| if constexpr(ElemPerThread % 4 == 0) | ||
| return 4; | ||
| else if constexpr(ElemPerThread % 2 == 0) | ||
| return 2; | ||
| return 1; | ||
| } | ||
| else | ||
| static_assert(false, "The data type is not supported!"); | ||
|
Comment on lines
+40
to
+41
|
||
| }; | ||
|
|
||
| template <typename DataType, | ||
| index_t kThreadBlockSize, | ||
| index_t kHigherDimSize, | ||
| index_t kLowerDimSize> | ||
| CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize() | ||
| { | ||
| constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize; | ||
|
|
||
| return GetMaxVectorSize<DataType, ElemPerThread>(); | ||
| } | ||
|
|
||
| }; // namespace detail | ||
|
|
||
| template <typename QDataType_, | ||
| typename KDataType_, | ||
| typename VDataType_, | ||
|
|
@@ -62,6 +108,33 @@ struct BlockFmhaPipelineProblem | |
| static constexpr bool kHasDropout = Traits::kHasDropout; | ||
| static constexpr auto QScaleEnum = Traits::QScaleEnum; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
|
|
||
| CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize() | ||
| { | ||
| constexpr index_t kMPerBlock = BlockFmhaShape::kM0; | ||
| constexpr index_t kKPerBlock = BlockFmhaShape::kQKHeaddim; | ||
|
|
||
| return detail:: | ||
| GetDramTileAccessMaxVectorSize<QDataType, kBlockSize, kMPerBlock, kKPerBlock>(); | ||
| } | ||
|
|
||
| CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize() | ||
| { | ||
| constexpr index_t kNPerBlock = BlockFmhaShape::kN0; | ||
| constexpr index_t kKPerBlock = BlockFmhaShape::kK0; | ||
|
|
||
| return detail:: | ||
| GetDramTileAccessMaxVectorSize<KDataType, kBlockSize, kNPerBlock, kKPerBlock>(); | ||
| } | ||
|
|
||
| CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize() | ||
| { | ||
| constexpr index_t kNPerBlock = BlockFmhaShape::kN1; | ||
| constexpr index_t kKPerBlock = BlockFmhaShape::kK1; | ||
|
|
||
| return detail:: | ||
| GetDramTileAccessMaxVectorSize<VDataType, kBlockSize, kNPerBlock, kKPerBlock>(); | ||
| }; | ||
| }; | ||
|
|
||
| template <typename QDataType_, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Corrected spelling of 'detechting' to 'detecting'.