MXFP4 Cast Transpose Triton [WIP] #422
base: dev
Conversation
wangye805 left a comment
    import numpy as np
    import os

    os.environ["USE_TRITON_FUSED_CAST_TRANSPOSE"] = "1"
We already defined the env var NVTE_USE_CAST_TRANSPOSE_TRITON.
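A minimal sketch of gating on that existing switch instead of introducing a second flag; the default value and parsing shown here are assumptions, not TE's actual logic:

    import os

    def use_cast_transpose_triton() -> bool:
        # Reuse the existing NVTE_USE_CAST_TRANSPOSE_TRITON switch rather than
        # adding a second flag such as USE_TRITON_FUSED_CAST_TRANSPOSE.
        return os.environ.get("NVTE_USE_CAST_TRANSPOSE_TRITON", "0").lower() in ("1", "true")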
    def test_quantize_mxfp4(shape, in_dtype, rowwise, columnwise, shuffle_B_matrix):
        """Test MXFP4 quantization for rowwise/columnwise modes with/without FP4 shuffle.

        Note: FP4 data shuffle (shuffle_B_matrix_for_aiter) is not yet supported in Triton kernel.
If the FP4 data shuffle is not yet supported in the Triton kernel, why do we need to add it here?
This is kept to ensure API consistency between Triton and the upcoming HIP kernel, for which I'll create a separate PR. In the HIP kernel we were able to fuse the shuffle.
HIP vs Triton flow:
Input: BF16 [M, N]
↓
MXFP4Quantizer.update_quantized()
↓
tex.cast_transpose_mxfp4_fused_shuffle() [Single HIP kernel]
↓
├─→ Rowwise FP4 [M, K/2] (MFMA shuffled)
├─→ Rowwise Scale [M_pad, K/32_pad] (shuffled)
├─→ Colwise FP4 [N, M/2] (MFMA shuffled)
└─→ Colwise Scale [N_pad, M/32_pad] (shuffled)
↓
AITER gemm_a4w4 (zero-copy)
vs
Input: BF16 [M, N]
↓
MXFP4Quantizer.update_quantized()
↓
te_cast_transpose_mxfp4_triton() [Triton JIT kernel]
↓
├─→ Rowwise FP4 [M, K/2] (linear layout)
├─→ Rowwise Scale [M_pad, K/32_pad] (shuffled)
├─→ Colwise FP4 [N, M/2] (linear layout)
└─→ Colwise Scale [N_pad, M/32_pad] (shuffled)
↓
aiter.ops.shuffle.shuffle_weight() [External call]
↓
FP4 data → MFMA layout
↓
AITER gemm_a4w4
    (32768, 160),
    (4096, 1632),
    (8, 32, 1024),
    (16, 8, 4, 512),
Can we add some shapes with prime dimensions, like in TransformerEngine/tests/cpp/operator/test_cast_transpose.cu, lines 90 to 92 (9d6b0e5):

    {1, 3221},     // Prime 3221
    {2333, 1},     // Prime 2333
    {1481, 677}};  // Primes 1481, 677
MXFP4 requires dimensions divisible by 32 for per-block scaling compatibility with AITER gemm_a4w4. I have added shapes that should raise the expected assertion error.
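A sketch of the alignment precondition behind those negative cases; the helper name is hypothetical and the real check lives inside the quantizer/kernel:

    MXFP4_BLOCK_SCALING_SIZE = 32  # one E8M0 scale per 32-element block

    def check_mxfp4_shape(M: int, N: int) -> None:
        # Both dims must be divisible by 32 so every scale block is full and the
        # output stays compatible with AITER gemm_a4w4.
        assert M % MXFP4_BLOCK_SCALING_SIZE == 0 and N % MXFP4_BLOCK_SCALING_SIZE == 0, (
            f"MXFP4 requires dims divisible by {MXFP4_BLOCK_SCALING_SIZE}, got ({M}, {N})"
        )

    check_mxfp4_shape(4096, 1632)   # OK: both multiples of 32
    # check_mxfp4_shape(1481, 677)  # would raise: prime dims are not multiples of 32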
    data_atol = 20.0 if in_dtype != torch.float32 else 16.0
    scale_atol = 2.0 if in_dtype != torch.float32 else 1.0
The data tolerance seems quite large. You can follow our MXFP8 scale and data adjustment scheme in TransformerEngine/tests/cpp/test_common.cu, line 730 (9d6b0e5):

    void adjust_ref_for_e8m0_scale_error(const std::string &name,
@wangye805 I closely followed the example and updated the pytest.
        use_torch_semantics=True
    )

    # Compare only valid (non-padded) region - no shuffle extraction needed
What is fp4 shuffle?
The FP4 shuffle rearranges the [M, K/2] linear layout into the MFMA instruction layout (16×16 tiles).
The current training workflow when the TE MXFP4 quantization kernel is used is:
TE Triton kernel → linear FP4 [N, K/2] → aiter.ops.shuffle_weight() → MFMA FP4 → aiter.gemm_a4w4()
You can find the shuffle code in aiter/aiter/ops/shuffle.py.
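For illustration only, a sketch of what a tiled layout rearrangement looks like on the packed byte matrix. It shows the general idea of regrouping data into contiguous 16×16 tiles and is not AITER's actual shuffle_weight permutation (see aiter/aiter/ops/shuffle.py for that):

    import numpy as np

    def tile_16x16(packed: np.ndarray) -> np.ndarray:
        # packed: [N, K/2] uint8 in linear (row-major) layout; both dims assumed
        # divisible by 16. Regroup into 16x16 tiles stored contiguously so a
        # matrix-core instruction can read one tile as a contiguous chunk.
        rows, cols = packed.shape
        tiles = packed.reshape(rows // 16, 16, cols // 16, 16)
        tiles = tiles.transpose(0, 2, 1, 3)                  # (row-tile, col-tile, 16, 16)
        return np.ascontiguousarray(tiles).reshape(rows, cols)

    b_linear = (np.arange(64 * 32) % 256).astype(np.uint8).reshape(64, 32)
    b_tiled = tile_16x16(b_linear)                           # same bytes, tile-major order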
| .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ | ||
| .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ | ||
| .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ | ||
| .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ |
If we are going to enable kFloat4E2M1, there are other related changes needed; see https://github.com/search?q=repo%3AROCm%2FTransformerEngine%20kFloat4E2M1&type=code for more details.
    - Data: [M, K/2] uint8 tensor (2 FP4 values packed per byte)
    - Scale: [M, K/32] uint8 tensor (E8M0 format, one scale per 32-element block)
Are there alignment/padding requirements for M and K?
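For context, a minimal decode sketch for the format quoted above. The E2M1 value table and the E8M0 bias of 127 follow the OCP MX spec; the low-nibble-first packing order is an assumption, not something taken from this kernel:

    import numpy as np

    E2M1_LUT = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=np.float32)

    def dequantize_mxfp4(data: np.ndarray, scale: np.ndarray) -> np.ndarray:
        # data:  [M, K/2] uint8, two FP4 (E2M1) codes per byte
        # scale: [M, K/32] uint8, one E8M0 scale per 32-element block
        lo = data & 0x0F                   # assumed: even elements in the low nibble
        hi = data >> 4                     # assumed: odd elements in the high nibble
        codes = np.stack([lo, hi], axis=-1).reshape(data.shape[0], -1)   # [M, K]
        vals = E2M1_LUT[codes]
        scales = np.exp2(scale.astype(np.float32) - 127.0)               # E8M0: 2**(e - 127)
        return vals * np.repeat(scales, 32, axis=1)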
| if inp.ndim < 2: | ||
| return False |
TE currently supports 2D matrices flattened from higher-dimensional tensors; see TransformerEngine/transformer_engine/common/common.h, lines 238 to 262 (9d6b0e5):
    size_t flat_first_dim() const {
      const auto &full_shape = shape();
      size_t ret = 1;
      if (!full_shape.empty()) {
        for (size_t i = 0; i < full_shape.size() - 1; i++) {
          ret *= full_shape[i];
        }
      }
      return ret;
    }
    /*! Matrix width after tensor is flattened to 2D
     *
     * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
     * as a (D1*D2*...*D(n-1), Dn) matrix.
     */
    size_t flat_last_dim() const {
      const auto &full_shape = shape();
      if (full_shape.empty()) {
        return 1;
      } else {
        return full_shape.back();
      }
    }
  };
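The same convention in Python terms, as a sketch; the shape (16, 8, 4, 512) is one of the test shapes listed in this PR:

    import torch

    def flatten_to_2d(t: torch.Tensor) -> torch.Tensor:
        # (D1, D2, ..., Dn) is reinterpreted as (D1*...*D(n-1), Dn),
        # mirroring flat_first_dim()/flat_last_dim() above.
        return t.reshape(-1, t.shape[-1])

    x = torch.empty(16, 8, 4, 512, dtype=torch.bfloat16)
    assert flatten_to_2d(x).shape == (16 * 8 * 4, 512)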
|
|
    # Allocate PADDED scale tensors for shuffle compatibility
    rowwise_scale_N = K // MXFP4_BLOCK_SCALING_SIZE
    rowwise_scale_M_pad = cdiv(M, 256) * 256
I presume this 256 is from some alignment/padding requirement?
The 256 alignment is required by AITER's CK-based MXFP4 GEMM kernels for the scale tensor swizzle/shuffle layout: 256 = ScaleBlockSize (32) × 8 waves.
See aiter/aiter/utility/fp4_utils.py:398 and gemm_a4w4_blockscale_common.cuh:66.
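A small sketch of that padding arithmetic, using shape values from the test list above; the pad multiple for the K/32 dimension is not shown in the quoted snippet:

    MXFP4_BLOCK_SCALING_SIZE = 32   # elements covered by one E8M0 scale

    def cdiv(a: int, b: int) -> int:
        return (a + b - 1) // b

    M, K = 4096, 1632
    rowwise_scale_N = K // MXFP4_BLOCK_SCALING_SIZE   # 51 scale columns
    rowwise_scale_M_pad = cdiv(M, 256) * 256          # rows padded to 256 = 32-elem block x 8 waves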
    @@ -0,0 +1,178 @@
    # Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
You will need to add this pytest to our CI script, somewhere near TransformerEngine/ci/pytorch.sh, line 74 (9d6b0e5):

    run_default_fa 1 triton_kernels/test_norms.py
done
Description
- Implements the rowwise and columnwise FP32/BF16 -> MXFP4 fused quantization + cast kernel
- Verifies tolerances and adds functional unit tests
The Triton te_cast_transpose_mxfp4_triton kernel currently outputs FP4 data in linear layout [M, N/2] with contiguous byte packing. AITER's gemm_a4w4 requires the B matrix in MFMA shuffle layout for the tensor cores. This layout shuffle can be fused into the Triton kernel in the future.
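A hedged sketch of the interim workflow this implies: run AITER's external shuffle on the Triton kernel's linear-layout output before the GEMM. shuffle_weight is the routine named in the review thread (aiter/aiter/ops/shuffle.py); calling it with a single packed-uint8 tensor is an assumption about its signature.

    import torch
    from aiter.ops.shuffle import shuffle_weight   # external FP4 shuffle (per review thread)

    N, K = 4096, 4096
    # Stand-in for the [N, K/2] uint8 output of te_cast_transpose_mxfp4_triton.
    b_fp4_linear = torch.zeros(N, K // 2, dtype=torch.uint8, device="cuda")
    b_fp4_mfma = shuffle_weight(b_fp4_linear)      # linear layout -> MFMA tile layout for gemm_a4w4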