Skip to content

Conversation

@samremes
Copy link
Contributor

@samremes samremes commented Feb 6, 2026

Proposed changes

Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

Copy link
Contributor

@ThomasNing ThomasNing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR @samremes Could I know when will be the estimation we could have the mxfp4 and mxfp8 working?

(std::is_same<T, pk_fp6x16_t>::value && (N == 1)),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The largest granularity is b128, so I don't think we need the N = 32.

merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

// get B scale for this N-K tile using get_y_sliced_thread_data
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the packed data of scale, it is better that we could use the load_tile_offset + get_thread_buffer() method

// warp GEMM with MX scaling
// Cast e8m0_t to int32_t, use OpSel=0 (least significant byte)
constexpr index_t kOpSel = 0; // Always use OpSel=0
WarpGemm{}.template operator()<kOpSel, kOpSel>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could load 32 bits together and select based on the iterations.

});
}

// C += A * B with MX scaling
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could migrate that to the gemm_mx block folder.

// C Distributed tensor: register
// MX scaling support with OpSel
template <typename Problem>
struct BaseMXGemmPipelineAgBgCrCompAsync
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we need, we could first limit it without async feature :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants