diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu
index 801b52712..cd0e124e8 100644
--- a/tests/cpp/operator/test_cublaslt_gemm.cu
+++ b/tests/cpp/operator/test_cublaslt_gemm.cu
@@ -29,6 +29,7 @@ std::vector> test_case_sizes = {
 };
 
 std::vector> test_case_sizes_mxfp8 = {
+  {32, 128, 16},
   {768, 3072, 4096},
 };
 
@@ -345,8 +346,11 @@ void performTest(const TestParams& params) {
     if (!has_fp8) {
      GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types";
     }
-    if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) {
-      GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32";
+    if (params.m % 16 || params.n % 16) {
+      GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
+    }
+    if (params.k % 128) {
+      GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
     }
   }
 
@@ -560,8 +564,11 @@ void performDqTest(const TestParams &params) {
   GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected";
   GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected";
 
-  if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) {
-    GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32";
+  if (params.m % 16 || params.n % 16) {
+    GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
+  }
+  if (params.k % 128) {
+    GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
   }
 
   cudaDeviceProp prop;
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 7a89148fd..3ddd9047d 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -147,7 +147,11 @@ std::pair get_scales(const NVTEShape& shape,
   scale_inv_meta ret_rowwise, ret_colwise;
 
+#ifdef __HIP_PLATFORM_AMD__
+  auto block_alignment = std::vector<size_t>{1ul, 1ul};
+#else
   auto block_alignment = std::vector<size_t>{128ul, 4ul};
+#endif
 
   {
     auto alignment = block_alignment[0];
     auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
@@ -181,12 +185,20 @@ std::pair get_scales(const NVTEShape& shape,
 
   {
     auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
+#ifdef __HIP_PLATFORM_AMD__
+    auto scale_dim_1 = DIVUP(last_dim, static_cast<size_t>(128));
+#else
     auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
+#endif
     ret_rowwise.shape = {scale_dim_0, scale_dim_1};
   }
   {
     auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
+#ifdef __HIP_PLATFORM_AMD__
+    auto scale_dim_1 = DIVUP(first_dim, static_cast<size_t>(128));
+#else
     auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
+#endif
     ret_colwise.shape = {scale_dim_0, scale_dim_1};
   }
   ret_rowwise.type = DType::kFloat32;
@@ -207,12 +219,20 @@ std::pair get_scales(const NVTEShape& shape,
 
   {
     auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
+#ifdef __HIP_PLATFORM_AMD__
+    auto scale_dim_1 = first_dim;
+#else
     auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
+#endif
     ret_rowwise.shape = {scale_dim_0, scale_dim_1};
   }
   {
     auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
+#ifdef __HIP_PLATFORM_AMD__
+    auto scale_dim_1 = last_dim;
+#else
     auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
+#endif
     ret_colwise.shape = {scale_dim_0, scale_dim_1};
   }
   ret_rowwise.type = DType::kFloat32;
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index b824f8d4d..9e3c0f2a4 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -330,10 +330,17 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127;
 constexpr uint32_t FP32_MANTISSA_BITS = 23;
 
 // [128,4] rowwise and [4,128] colwise alignment requirement
+#ifdef __HIP_PLATFORM_AMD__
+constexpr size_t scale_tensor_alignment_X_rowwise = 1;
+constexpr size_t scale_tensor_alignment_Y_rowwise = 1;
+constexpr size_t scale_tensor_alignment_X_colwise = 1;
+constexpr size_t scale_tensor_alignment_Y_colwise = 1;
+#else
 constexpr size_t scale_tensor_alignment_X_rowwise = 4;
 constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
 constexpr size_t scale_tensor_alignment_X_colwise = 128;
 constexpr size_t scale_tensor_alignment_Y_colwise = 4;
+#endif
 
 inline size_t divide_round_up(const size_t N, const size_t M) {
   return (N - 1 + M) / M;
diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py
index 1ce7d3e42..9ffbf9452 100644
--- a/tests/pytorch/references/blockwise_quantizer_reference.py
+++ b/tests/pytorch/references/blockwise_quantizer_reference.py
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -8,6 +10,7 @@ from typing import Optional, Protocol, Tuple
 
 from references.quantize_scale_calc import scale_from_amax_tensor
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 
 @dataclasses.dataclass()
 class QuantizeResult:
@@ -36,6 +39,8 @@ def munge_scale_shapes_for_backend(
     def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor:
         if transpose:
             s = s.transpose(-1, -2).contiguous()
+        if IS_HIP_EXTENSION:  # HIP does not use scale padding
+            return s
         M, K = s.shape
         if K % 4 == 0:
             return s
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index a7d762c3d..0c6b65329 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -41,7 +41,7 @@
     Float8Quantizer,
     Float8Tensor,
 )
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from transformer_engine.pytorch.distributed import checkpoint
 from utils import ModelConfig
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index ce510334b..0015e9155 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -1,6 +1,6 @@
 /*************************************************************************
  * This file was modified for portability to AMDGPU
- * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -715,16 +715,22 @@ template <> struct is_fp4 : std::true_type {};
 #endif
 
+#ifndef __HIP_PLATFORM_AMD__
 // [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
 constexpr size_t scale_tensor_alignment_X_rowwise = 4;
 constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
 constexpr size_t scale_tensor_alignment_X_colwise = 128;
 constexpr size_t scale_tensor_alignment_Y_colwise = 4;
 
-#ifndef __HIP_PLATFORM_AMD__
 // Alignment requirements for the Tensor Memory Accelerator (TMA)
 constexpr size_t TMA_GMEM_ALIGNMENT = 16;    // global memory address alignment
 constexpr size_t TMA_SHMEM_ALIGNMENT = 128;  // shared memory address alignment
+#else
+// HIP does not use scale padding
+constexpr size_t scale_tensor_alignment_X_rowwise = 1;
+constexpr size_t scale_tensor_alignment_Y_rowwise = 1;
+constexpr size_t scale_tensor_alignment_X_colwise = 1;
+constexpr size_t scale_tensor_alignment_Y_colwise = 1;
 #endif
 
 inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 68d1f0ec5..78af6061f 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -1,6 +1,6 @@
 /*************************************************************************
  * This file was modified for portability to AMDGPU
- * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -98,8 +98,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
   } else {
     if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
         t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
+#ifndef __HIP_PLATFORM_AMD__
       // Need (4, 128) alignment even for e8 scaling factor
       auto block_alignment = std::vector<size_t>{128ul, 4ul};
+#else
+      // HIP does not use scale padding
+      auto block_alignment = std::vector<size_t>{1ul, 1ul};
+#endif
       size_t expected_x, expected_y, alignment;
       const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
       const size_t block_size_colwise = 32;
diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh
index e39e0a4a7..eb0c9b94d 100644
--- a/transformer_engine/common/util/rocm_cast_kernels.cuh
+++ b/transformer_engine/common/util/rocm_cast_kernels.cuh
@@ -233,7 +233,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
       const e8m0_t biased_exponent =
           ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp);
       // Only single thread writes the computed scaling factor
-      if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) {
+      const bool col_out_of_bounds = dbias_rowwise_offset_X >= cols;
+      if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && !(row_out_of_bounds || col_out_of_bounds)) {
         const int global_scales_offset_Y =
             iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y;
         const int global_scales_offset_X =
@@ -297,7 +298,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
       const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
       const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
-      scales_colwise[scale_idx] = biased_exponent;
+      const bool row_out_of_bounds = row_base >= rows;
+      if (!(row_out_of_bounds || col_out_of_bounds)) {
+        scales_colwise[scale_idx] = biased_exponent;
+      }
 
       const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
 
 #pragma unroll
diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h
index af7f54feb..c71bb1306 100644
--- a/transformer_engine/jax/csrc/extensions/misc.h
+++ b/transformer_engine/jax/csrc/extensions/misc.h
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -83,7 +85,11 @@ constexpr struct BlockSize {
 constexpr struct Alignment {
   size_t x;
   size_t y;
+#ifndef __HIP_PLATFORM_AMD__
 } MXFP8_ALIGNMENT{128, 4};
+#else
+} MXFP8_ALIGNMENT{1, 1};
+#endif
 
 std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py
index e81a614f0..deb2320eb 100644
--- a/transformer_engine/jax/quantize/scaling_modes.py
+++ b/transformer_engine/jax/quantize/scaling_modes.py
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -23,6 +25,7 @@
 from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
 
 from .device_utils import is_fp8_gemm_with_all_layouts_supported
+from ..util import is_hip_extension
 
 
 __all__ = [
@@ -366,7 +369,11 @@ def __init__(self, block_dims: Tuple[int]):
             block_dims: Dimensions of the scaling blocks
         """
         self._block_dims = block_dims
-        self._block_alignment = (128, 4)
+        if is_hip_extension():
+            self._block_alignment = (1, 1)
+        else:
+            self._block_alignment = (128, 4)
+
 
     def get_scale_dtype(self) -> jnp.dtype:
         """Get the data type for scale tensors in block scaling.
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index 37c13362c..9935f3c90 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -842,13 +842,21 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector Float8BlockQuantizer::get_scale_shape(const std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s
   NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")");
-
+#ifdef __HIP_PLATFORM_AMD__
+  return !columnwise
+             ? std::vector<size_t>{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE}
+             : std::vector<size_t>{numel / last_dim / MXFP8_BLOCK_SIZE, last_dim};
+#else
   std::vector<size_t> scale_shape;
 
   bool rowwise_usage = !columnwise;
-
   if (rowwise_usage) {
     // rowwise scaling factor shape
     size_t sinv0 = roundup(numel / last_dim, 128);
@@ -1124,6 +1143,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s
     scale_shape = {sinv0, sinv1};
   }
   return scale_shape;
+#endif
 }
 
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index b49e38544..265109d6e 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1,5 +1,5 @@
 # This file was modified for portability to AMDGPU
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -602,7 +602,7 @@ def fill_userbuffers_buffer_for_all_gather(
         comm.copy_into_buffer(local_data, local_chunk=True)
 
         # Gather scaling-inverses
-        if math.prod(local_shape[:-1]) % 128 != 0:
+        if math.prod(local_shape[:-1]) % 128 != 0 and not IS_HIP_EXTENSION:
             raise ValueError(
                 "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
                 f"but got MXFP8 tensor with shape={tuple(local_shape)}"
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index 0e41fc9c5..e51da3223 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -17,6 +19,8 @@
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..utils import devices_match, round_up_to_nearest_multiple
 
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+
 aten = torch.ops.aten
 
@@ -137,11 +141,17 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int,
         if self.block_scaling_dim == 2:
             if columnwise:
                 outer = math.ceil(K / self.block_len)
-                inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
+                if IS_HIP_EXTENSION:
+                    inner = math.ceil(M / self.block_len)
+                else:
+                    inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
                 return (outer, inner)
             # rowwise
             outer = math.ceil(M / self.block_len)
-            inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
+            if IS_HIP_EXTENSION:
+                inner = math.ceil(K / self.block_len)
+            else:
+                inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
             return (outer, inner)
         # 1D 1x128 quantization block scaling
         # CuBLAS requries 1x128 scaling factor to be padded and transposed
@@ -149,7 +159,7 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int,
         if columnwise:
             columnwise_compact = self.all_gather_usage
             outer = math.ceil(M / self.block_len)
-            inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
+            inner = round_up_to_nearest_multiple(K, 4) if not IS_HIP_EXTENSION and not columnwise_compact else K
             # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS
             # for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner]
             # so no need to swap inner outer here
@@ -157,7 +167,7 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int,
         # rowwise
         rowwise_compact = self.all_gather_usage
         outer = math.ceil(K / self.block_len)
-        inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
+        inner = round_up_to_nearest_multiple(M, 4) if not IS_HIP_EXTENSION and not rowwise_compact else M
         # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need
         # for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here
         return (outer, inner) if not rowwise_compact else (inner, outer)
diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
index 16b1568cb..485787bc5 100644
--- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py
+++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -113,12 +113,20 @@ def make_empty(
         # Allocate FP8 data
         data = torch.empty(shape, dtype=torch.uint8, device=device)
         # ROCm TE does not implement fuse padding zeros so use zero tensor here
-        scale_inv = torch.zeros(
-            round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
-            round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
-            dtype=torch.uint8,
-            device=device,
-        )
+        if IS_HIP_EXTENSION:
+            scale_inv = torch.zeros(
+                math.prod(shape[:-1]),
+                math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE),
+                dtype=torch.uint8,
+                device=device,
+            )
+        else:
+            scale_inv = torch.empty(
+                round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
+                round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
+                dtype=torch.uint8,
+                device=device,
+            )
 
         # Allocate FP8 data transpose if needed
         columnwise_data = None
@@ -126,12 +134,20 @@
         if self.columnwise_usage:
             columnwise_data = torch.empty_like(data)
             # ROCm TE does not implement fuse padding zeros so use zero tensor here
-            columnwise_scale_inv = torch.zeros(
-                round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
-                round_up_to_nearest_multiple(shape[-1], 128),
-                dtype=torch.uint8,
-                device=device,
-            )
+            if IS_HIP_EXTENSION:
+                columnwise_scale_inv = torch.zeros(
+                    math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE),
+                    shape[-1],
+                    dtype=torch.uint8,
+                    device=device,
+                )
+            else:
+                columnwise_scale_inv = torch.empty(
+                    round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
+                    round_up_to_nearest_multiple(shape[-1], 128),
+                    dtype=torch.uint8,
+                    device=device,
+                )
 
         # Construct FP8 tensor
         return MXFP8Tensor(
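
For orientation (this note is not part of the patch): the common thread across these hunks is that the CUDA path pads MXFP8 scale-inverse tensors to a (128, 4) rowwise / (4, 128) colwise alignment for cuBLASLt, while the ROCm path stores one E8M0 scale per 32-element block with no padding. A minimal sketch of the rowwise shape arithmetic under that assumption; the helper names below are illustrative, not TE APIs.

    # Illustrative sketch only: rowwise MXFP8 scale-inverse shape, padded (CUDA)
    # vs. unpadded (ROCm), mirroring the make_empty / get_scale_shape changes above.
    import math

    MXFP8_BLOCK_SIZE = 32  # one E8M0 scale per 32 contiguous elements


    def round_up(x: int, multiple: int) -> int:
        """Round x up to the nearest multiple."""
        return math.ceil(x / multiple) * multiple


    def rowwise_scale_shape(rows: int, cols: int, hip: bool) -> tuple:
        """Shape of the rowwise scale-inverse tensor for a (rows, cols) MXFP8 tensor."""
        if hip:
            # ROCm: one scale per block, no extra alignment padding.
            return (rows, math.ceil(cols / MXFP8_BLOCK_SIZE))
        # CUDA: scale tensor padded to a (128, 4) alignment for cuBLASLt.
        return (round_up(rows, 128), round_up(cols // MXFP8_BLOCK_SIZE, 4))


    print(rowwise_scale_shape(768, 3072, hip=True))   # (768, 96)
    print(rowwise_scale_shape(768, 3072, hip=False))  # (768, 96) -- already aligned
    print(rowwise_scale_shape(100, 64, hip=True))     # (100, 2)
    print(rowwise_scale_shape(100, 64, hip=False))    # (128, 4)  -- padded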