-
Notifications
You must be signed in to change notification settings - Fork 270
[CK_TILE] MX GEMM, non-preshuffled and RCR layout #3709
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop_deprecated
Are you sure you want to change the base?
Conversation
…eline from flatmm
ThomasNing
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR @samremes Could I know when will be the estimation we could have the mxfp4 and mxfp8 working?
| (std::is_same<T, pk_fp6x16_t>::value && (N == 1)), | ||
| (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) || | ||
| (std::is_same<T, pk_fp4_t>::value && | ||
| (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The largest granularity is b128, so I don't think we need the N = 32.
| merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); | ||
|
|
||
| // get B scale for this N-K tile using get_y_sliced_thread_data | ||
| auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the packed data of scale, it is better that we could use the load_tile_offset + get_thread_buffer() method
| // warp GEMM with MX scaling | ||
| // Cast e8m0_t to int32_t, use OpSel=0 (least significant byte) | ||
| constexpr index_t kOpSel = 0; // Always use OpSel=0 | ||
| WarpGemm{}.template operator()<kOpSel, kOpSel>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could load 32 bits together and select based on the iterations.
| }); | ||
| } | ||
|
|
||
| // C += A * B with MX scaling |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could migrate that to the gemm_mx block folder.
| // C Distributed tensor: register | ||
| // MX scaling support with OpSel | ||
| template <typename Problem> | ||
| struct BaseMXGemmPipelineAgBgCrCompAsync |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we need, we could first limit it without async feature :)
Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered