From 1a24ff22254711b9fc59b62e9c68d82579b3dee8 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Feb 2026 10:18:31 -0600 Subject: [PATCH 01/10] Remove padding from scales for hipBLASlt calls --- tests/cpp/operator/test_cublaslt_gemm.cu | 16 ++++++-- tests/cpp/test_common.cu | 20 ++++++++++ tests/cpp/test_common.h | 7 ++++ .../blockwise_quantizer_reference.py | 5 +++ .../comm_gemm_overlap/comm_gemm_overlap.cpp | 2 + transformer_engine/common/common.h | 8 +++- .../common/transformer_engine.cpp | 5 +++ .../common/util/cast_kernels.cuh | 1 - .../common/util/rocm_cast_kernels.cuh | 8 +++- transformer_engine/jax/csrc/extensions/misc.h | 4 ++ .../jax/quantize/scaling_modes.py | 9 ++++- transformer_engine/pytorch/csrc/quantizer.cpp | 24 ++++++++++- transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/tensor/float8_blockwise_tensor.py | 18 +++++++-- .../pytorch/tensor/mxfp8_tensor.py | 40 +++++++++++++------ 15 files changed, 141 insertions(+), 28 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 801b52712..2fa882bf4 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -29,6 +29,8 @@ std::vector> test_case_sizes = { }; std::vector> test_case_sizes_mxfp8 = { + {16, 128, 16}, + {32, 128, 32}, {768, 3072, 4096}, }; @@ -345,8 +347,11 @@ void performTest(const TestParams& params) { if (!has_fp8) { GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types"; } - if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) { - GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32"; + if (params.m % 16 || params.n % 16) { + GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; + } + if (params.k % 128) { + GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; } } @@ -560,8 +565,11 @@ void performDqTest(const TestParams ¶ms) { GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected"; 
GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected"; - if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) { - GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32"; + if (params.m % 16 || params.n % 16) { + GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; + } + if (params.k % 128) { + GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; } cudaDeviceProp prop; diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 7a89148fd..9d023cdce 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -147,7 +147,11 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; +#ifndef __HIP_PLATFORM_AMD__ auto block_alignment = std::vector{128ul, 4ul}; +#else + auto block_alignment = std::vector{1ul, 1ul}; +#endif { auto alignment = block_alignment[0]; auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; @@ -181,12 +185,20 @@ std::pair get_scales(const NVTEShape& shape, { auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); +#ifdef __HIP_PLATFORM_AMD__ + auto scale_dim_1 = DIVUP(last_dim, static_cast(128)); +#else auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; +#endif ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); +#ifdef __HIP_PLATFORM_AMD__ + auto scale_dim_1 = DIVUP(first_dim, static_cast(128)); +#else auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; +#endif ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -207,12 +219,20 @@ std::pair get_scales(const NVTEShape& shape, { auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); +#ifdef __HIP_PLATFORM_AMD__ + auto scale_dim_1 = first_dim; +#else auto scale_dim_1 = DIVUP(first_dim, 4) * 4; +#endif ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); +#ifdef 
__HIP_PLATFORM_AMD__ + auto scale_dim_1 = last_dim; +#else auto scale_dim_1 = DIVUP(last_dim, 4) * 4; +#endif ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index b824f8d4d..9e3c0f2a4 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -330,10 +330,17 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t scale_tensor_alignment_X_rowwise = 1; +constexpr size_t scale_tensor_alignment_Y_rowwise = 1; +constexpr size_t scale_tensor_alignment_X_colwise = 1; +constexpr size_t scale_tensor_alignment_Y_colwise = 1; +#else constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +#endif inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index 1ce7d3e42..9ffbf9452 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -8,6 +10,7 @@ from typing import Optional, Protocol, Tuple from references.quantize_scale_calc import scale_from_amax_tensor +from torch.utils.cpp_extension import IS_HIP_EXTENSION @dataclasses.dataclass() class QuantizeResult: @@ -36,6 +39,8 @@ def munge_scale_shapes_for_backend( def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: if transpose: s = s.transpose(-1, -2).contiguous() + if IS_HIP_EXTENSION: # HIP does not use scale padding + return s M, K = s.shape if K % 4 == 0: return s diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index ec29e6e12..b8167fad5 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -189,6 +189,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk"); NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width, "Attempted to get out-of-bounds tensor chunk"); +#ifndef __HIP_PLATFORM_AMD__ if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { // MXFP8 scale-inverses are padded to a 2D matrix with dims that // are divisible by 128. UB doesn't handle this padding yet. 
@@ -197,6 +198,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0, "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128"); } +#endif #undef NVTE_DIM_CHECK // Construct tensor chunk diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ce510334b..5c646e45d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -715,16 +715,22 @@ template <> struct is_fp4 : std::true_type {}; #endif +#ifndef __HIP_PLATFORM_AMD__ // [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; -#ifndef __HIP_PLATFORM_AMD__ // Alignment requirements for the Tensor Memory Accelerator (TMA) constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment +#else +// HIP does not use scale padding +constexpr size_t scale_tensor_alignment_X_rowwise = 1; +constexpr size_t scale_tensor_alignment_Y_rowwise = 1; +constexpr size_t scale_tensor_alignment_X_colwise = 1; +constexpr size_t scale_tensor_alignment_Y_colwise = 1; #endif inline bool is_aligned_ptr(const void *ptr, size_t alignment) { diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 68d1f0ec5..f6fbec2c7 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -98,8 +98,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { } else { if (t.scaling_mode == NVTE_MXFP8_1D_SCALING || t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { +#ifndef 
__HIP_PLATFORM_AMD__ // Need (4, 128) alignment even for e8 scaling factor auto block_alignment = std::vector{128ul, 4ul}; +#else + // HIP does not use scale padding + auto block_alignment = std::vector{1ul, 1ul}; +#endif size_t expected_x, expected_y, alignment; const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16; const size_t block_size_colwise = 32; diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index b7c4cf837..37d90ae3b 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1061,7 +1061,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - e8m0_t *const scales_rowwise_ptr = use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; e8m0_t *const scales_colwise_ptr = diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index e39e0a4a7..eb0c9b94d 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -233,7 +233,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const e8m0_t biased_exponent = ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); // Only single thread writes the computed scaling factor - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const bool col_out_of_bounds = dbias_rowwise_offset_X >= cols; + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && !(row_out_of_bounds || col_out_of_bounds)) { const int global_scales_offset_Y = iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; const int global_scales_offset_X = @@ -297,7 +298,10 @@ __global__ void 
__launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + const bool row_out_of_bounds = row_base >= rows; + if (!(row_out_of_bounds || col_out_of_bounds)) { + scales_colwise[scale_idx] = biased_exponent; + } const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index af7f54feb..596312582 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -83,7 +83,11 @@ constexpr struct BlockSize { constexpr struct Alignment { size_t x; size_t y; +#ifndef __HIP_PLATFORM_AMD__ } MXFP8_ALIGNMENT{128, 4}; +#else +} MXFP8_ALIGNMENT{1, 1}; +#endif std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index e81a614f0..deb2320eb 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -23,6 +25,7 @@ from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout from .device_utils import is_fp8_gemm_with_all_layouts_supported +from ..util import is_hip_extension __all__ = [ @@ -366,7 +369,11 @@ def __init__(self, block_dims: Tuple[int]): block_dims: Dimensions of the scaling blocks """ self._block_dims = block_dims - self._block_alignment = (128, 4) + if is_hip_extension(): + self._block_alignment = (1, 1) + else: + self._block_alignment = (128, 4) + def get_scale_dtype(self) -> jnp.dtype: """Get the data type for scale tensors in block scaling. diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 37c13362c..9935f3c90 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -842,13 +842,21 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector Float8BlockQuantizer::get_scale_shape(const std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); - +#ifdef __HIP_PLATFORM_AMD__ + return !columnwise + ? 
std::vector{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE} + : std::vector{numel / last_dim / MXFP8_BLOCK_SIZE, last_dim}; +#else std::vector scale_shape; bool rowwise_usage = !columnwise; - if (rowwise_usage) { // rowwise scaling factor shape size_t sinv0 = roundup(numel / last_dim, 128); @@ -1124,6 +1143,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s scale_shape = {sinv0, sinv1}; } return scale_shape; +#endif } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b49e38544..add07bb07 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -602,7 +602,7 @@ def fill_userbuffers_buffer_for_all_gather( comm.copy_into_buffer(local_data, local_chunk=True) # Gather scaling-inverses - if math.prod(local_shape[:-1]) % 128 != 0: + if math.prod(local_shape[:-1]) % 128 != 0 and not IS_HIP_EXTENSION: raise ValueError( "Userbuffers requires MXFP8 tensor dims that are divisible by 128, " f"but got MXFP8 tensor with shape={tuple(local_shape)}" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 0e41fc9c5..e51da3223 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -17,6 +19,8 @@ from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple +from torch.utils.cpp_extension import IS_HIP_EXTENSION + aten = torch.ops.aten @@ -137,11 +141,17 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, if self.block_scaling_dim == 2: if columnwise: outer = math.ceil(K / self.block_len) - inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) + if IS_HIP_EXTENSION: + inner = math.ceil(M / self.block_len) + else: + inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) return (outer, inner) # rowwise outer = math.ceil(M / self.block_len) - inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) + if IS_HIP_EXTENSION: + inner = math.ceil(K / self.block_len) + else: + inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) return (outer, inner) # 1D 1x128 quantization block scaling # CuBLAS requries 1x128 scaling factor to be padded and transposed @@ -149,7 +159,7 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, if columnwise: columnwise_compact = self.all_gather_usage outer = math.ceil(M / self.block_len) - inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K + inner = round_up_to_nearest_multiple(K, 4) if not IS_HIP_EXTENSION and not columnwise_compact else K # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS # for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner] # so no need to swap inner outer here @@ -157,7 +167,7 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, # rowwise rowwise_compact = self.all_gather_usage outer = math.ceil(K / self.block_len) - inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M + inner = round_up_to_nearest_multiple(M, 4) if not
IS_HIP_EXTENSION and not rowwise_compact else M # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need # for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here return (outer, inner) if not rowwise_compact else (inner, outer) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 16b1568cb..485787bc5 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -113,12 +113,20 @@ def make_empty( # Allocate FP8 data data = torch.empty(shape, dtype=torch.uint8, device=device) # ROCm TE does not implement fuse padding zeros so use zero tensor here - scale_inv = torch.zeros( - round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), - dtype=torch.uint8, - device=device, - ) + if IS_HIP_EXTENSION: + scale_inv = torch.zeros( + math.prod(shape[:-1]), + math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), + dtype=torch.uint8, + device=device, + ) + else: + scale_inv = torch.empty( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + dtype=torch.uint8, + device=device, + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -126,12 +134,20 @@ def make_empty( if self.columnwise_usage: columnwise_data = torch.empty_like(data) # ROCm TE does not implement fuse padding zeros so use zero tensor here - columnwise_scale_inv = torch.zeros( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), - round_up_to_nearest_multiple(shape[-1], 128), - dtype=torch.uint8, - device=device, - ) + if IS_HIP_EXTENSION: + columnwise_scale_inv = torch.zeros( + math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE), + shape[-1], + dtype=torch.uint8, + device=device, + ) + 
else: + columnwise_scale_inv = torch.empty( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + dtype=torch.uint8, + device=device, + ) # Construct FP8 tensor return MXFP8Tensor( From 5bbfb4b99e1f441f527218b0f10775ff14e02c5a Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Feb 2026 20:44:02 -0600 Subject: [PATCH 02/10] Unpadding for checkpoints --- transformer_engine/common/gemm/rocm_gemm.cu | 58 ++++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index c2c4c502a..51a011411 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -197,9 +197,59 @@ struct GemmParam { int ldb = 0; // B column strides }; +__global__ void unpad_mxfp8_scales_kernel(float* data, const size_t rows, const size_t cols, + const size_t padded_stride) { + int total_elements = rows * cols; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid < total_elements) { + int r = tid / cols; + int c = tid % cols; + + int src_idx = r * padded_stride + c; + + float val = data[src_idx]; + __threadfence(); + data[tid] = val; + } +} + +// Checks if scales have been padded and unpads. +// Necessary for NV checkpoint load with extra cuBLASlt padding +void unpad_mxfp8_checkpoint(const transformer_engine::Tensor &M_const, bool transM, + const int m, const int n, const int k, hipStream_t stream) { + + auto &M_tensor = const_cast(M_const); + auto &scale_tensor = transM ? M_tensor.scale_inv : M_tensor.columnwise_scale_inv; + + size_t unpadded_rows = transM ? k : m; + size_t unpadded_cols = ((transM ? m : k) + 31) / 32; // Change if MXFP8 block size changes + size_t padded_rows = scale_tensor.shape[0]; + size_t padded_cols = scale_tensor.shape.size() > 1 ? 
scale_tensor.shape[1] : 1; + + bool is_padded = ( padded_rows == (unpadded_rows + 127) / 128*128 && + padded_cols == (unpadded_cols + 3) / 4*4 && + (padded_rows > unpadded_rows || padded_cols > unpadded_cols)); + + if (is_padded) { + float *scale_dptr = (float*)(scale_tensor.dptr); + + const size_t total_elements = unpadded_rows * unpadded_cols; + const size_t threads = 256; + const size_t blocks = (total_elements + threads - 1) / threads; + + unpad_mxfp8_scales_kernel<<>> + (scale_dptr, unpadded_rows, unpadded_cols, padded_cols); + + NVTE_CHECK_CUDA(hipStreamSynchronize(stream)); + + scale_tensor.shape = {unpadded_rows, unpadded_cols}; + } +} + GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - const int m, const int n, const int k) { + const int m, const int n, const int k, hipStream_t stream) { using namespace transformer_engine; NVTE_CHECK(A.scaling_mode == B.scaling_mode, "Inputs A and B to GEMM need to have the same scaling mode!"); @@ -235,6 +285,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). + + unpad_mxfp8_checkpoint(A, is_A_transposed, m, n, k, stream); + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { @@ -273,6 +326,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). 
+ unpad_mxfp8_checkpoint(B, is_B_transposed, m, n, k, stream); if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { @@ -951,7 +1005,7 @@ void hipblaslt_gemm(const Tensor *inputA, } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); + const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k, stream); bool nvte_log_gemm_config = false; if (const char* env_p = std::getenv("NVTE_LOG_GEMM_CONFIG") ) { From cbbd027ce40e8415e1897117fab675511b2ef176 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Feb 2026 20:49:53 -0600 Subject: [PATCH 03/10] copyrights --- transformer_engine/common/common.h | 2 +- transformer_engine/common/gemm/rocm_gemm.cu | 2 +- transformer_engine/common/transformer_engine.cpp | 2 +- transformer_engine/common/util/cast_kernels.cuh | 1 + transformer_engine/jax/csrc/extensions/misc.h | 2 ++ transformer_engine/pytorch/module/base.py | 2 +- 6 files changed, 7 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 5c646e45d..0015e9155 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 51a011411..d9dc5f755 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index f6fbec2c7..78af6061f 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 37d90ae3b..b7c4cf837 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1061,6 +1061,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + e8m0_t *const scales_rowwise_ptr = use_rowwise_scaling ? 
reinterpret_cast(output->scale_inv.dptr) : nullptr; e8m0_t *const scales_colwise_ptr = diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 596312582..c71bb1306 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index add07bb07..265109d6e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
From 9f8f6114cde9c803011a293da4402f0e650a8043 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Fri, 6 Feb 2026 17:37:11 -0600 Subject: [PATCH 04/10] Updated unpadding func --- tests/pytorch/test_sanity.py | 74 ++++++++++++++++++- transformer_engine/common/gemm/rocm_gemm.cu | 49 ------------ .../pytorch/cpp_extensions/gemm.py | 48 ++++++++++++ 3 files changed, 121 insertions(+), 50 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index a7d762c3d..3d28083a2 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -41,7 +41,7 @@ Float8Quantizer, Float8Tensor, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint from utils import ModelConfig @@ -913,6 +913,78 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): ) torch.cuda.synchronize() +@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) +@pytest.mark.parametrize("N", [32]) +@pytest.mark.parametrize("K", [128]) +@pytest.mark.parametrize("M", [32]) +@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) +def test_sanity_mxfp8_gemm_with_padding(N, K, M, datatype): + """Test the unpadding functionality in rocm""" + dtype = tex.DType.kFloat8E4M3 + quantizer = MXFP8Quantizer(dtype) + + input_dtype = torch.randn(M, K, device="cuda", dtype=datatype) + weight_dtype = torch.randn(N, K, device="cuda", dtype=datatype) + + input_data = quantizer.make_empty((M, K), device="cuda") + weight_data = quantizer.make_empty((N, K), device="cuda") + + quantizer.update_quantized(input_dtype, input_data) + quantizer.update_quantized(weight_dtype, weight_data) + + out_ref = general_gemm( + weight_data, + input_data, + get_workspace(), + datatype, + bias=None, + use_split_accumulator=False, + ) + 
torch.cuda.synchronize() + + row_scale_inv = input_data._rowwise_scale_inv + rows, cols = row_scale_inv.shape + row_padded_scale_inv = torch.zeros((128, 4), dtype=row_scale_inv.dtype, device="cuda") + row_padded_scale_inv[:rows, :cols] = row_scale_inv + + col_scale_inv = input_data._columnwise_scale_inv + rows, cols = col_scale_inv.shape + col_padded_scale_inv = torch.zeros((4, 128), dtype=col_scale_inv.dtype, device="cuda") + col_padded_scale_inv[:rows, :cols] = col_scale_inv + + + input_padded = MXFP8Tensor( + shape=input_data.shape, + rowwise_data=input_data._rowwise_data.clone(), + rowwise_scale_inv=row_padded_scale_inv, + columnwise_data=input_data._columnwise_data.clone(), + columnwise_scale_inv=col_padded_scale_inv, + fp8_dtype=tex.DType.kFloat8E4M3, + quantizer=quantizer, + dtype=datatype + ) + + out_pass1 = general_gemm( + weight_data, + input_padded, + get_workspace(), + datatype, + bias=None, + use_split_accumulator=False + ) + torch.cuda.synchronize() + + assert row_scale_inv.shape == input_padded._rowwise_scale_inv.shape, \ + ("Shape mismatch in rowwise scales") + assert col_scale_inv.shape == input_padded._columnwise_scale_inv.shape, \ + ("Shape mismatch in colwise scales") + torch.testing.assert_close(row_scale_inv, input_padded._rowwise_scale_inv, + rtol=1e-7, atol=1e-7, msg="rowwise scale mismatch") + torch.testing.assert_close(col_scale_inv, input_padded._columnwise_scale_inv, + rtol=1e-7, atol=1e-7, msg="colwise scale mismatch") + torch.testing.assert_close(out_pass1[0], out_ref[0], + rtol=1e-2, atol=1e-2, msg="GEMM output mismatch") + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_replace_raw_data_for_float8tensor(): diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index d9dc5f755..91f2b2733 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -197,55 +197,6 @@ struct GemmParam { int ldb = 0; // B column 
strides }; -__global__ void unpad_mxfp8_scales_kernel(float* data, const size_t rows, const size_t cols, - const size_t padded_stride) { - int total_elements = rows * cols; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - if (tid < total_elements) { - int r = tid / cols; - int c = tid % cols; - - int src_idx = r * padded_stride + c; - - float val = data[src_idx]; - __threadfence(); - data[tid] = val; - } -} - -// Checks if scales have been padded and unpads. -// Necessary for NV checkpoint load with extra cuBLASlt padding -void unpad_mxfp8_checkpoint(const transformer_engine::Tensor &M_const, bool transM, - const int m, const int n, const int k, hipStream_t stream) { - - auto &M_tensor = const_cast(M_const); - auto &scale_tensor = transM ? M_tensor.scale_inv : M_tensor.columnwise_scale_inv; - - size_t unpadded_rows = transM ? k : m; - size_t unpadded_cols = ((transM ? m : k) + 31) / 32; // Change if MXFP8 block size changes - size_t padded_rows = scale_tensor.shape[0]; - size_t padded_cols = scale_tensor.shape.size() > 1 ? 
scale_tensor.shape[1] : 1; - - bool is_padded = ( padded_rows == (unpadded_rows + 127) / 128*128 && - padded_cols == (unpadded_cols + 3) / 4*4 && - (padded_rows > unpadded_rows || padded_cols > unpadded_cols)); - - if (is_padded) { - float *scale_dptr = (float*)(scale_tensor.dptr); - - const size_t total_elements = unpadded_rows * unpadded_cols; - const size_t threads = 256; - const size_t blocks = (total_elements + threads - 1) / threads; - - unpad_mxfp8_scales_kernel<<>> - (scale_dptr, unpadded_rows, unpadded_cols, padded_cols); - - NVTE_CHECK_CUDA(hipStreamSynchronize(stream)); - - scale_tensor.shape = {unpadded_rows, unpadded_cols}; - } -} GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index e4f4e619f..a64ef7106 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -5,13 +7,17 @@ """Python interface for GEMM extensions""" from typing import Iterable, Optional, Tuple, Union, List +import math import os import torch import transformer_engine_torch as tex from ..constants import TE_DType from ..utils import get_sm_count, _empty_tensor +from torch.utils.cpp_extension import IS_HIP_EXTENSION + from ..tensor.quantized_tensor import Quantizer +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -30,6 +36,42 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 +def unpad_scales(tensor: torch.Tensor, transpose: bool) -> torch.Tensor: + """Removes padding from scales in MXFP8 Tensors if present""" + if isinstance(tensor, MXFP8TensorBase): + block_size = 32 + elif isinstance(tensor, Float8BlockwiseQTensorBase): + block_size = 128 + else: + raise ValueError("Only MXFP8 and FP8 Block scaling can be unpadded") + + if tensor._rowwise_scale_inv is not None: + if transpose: + rows, cols = tensor._rowwise_data.shape[1], tensor._rowwise_data.shape[0] + else: + rows, cols = tensor._rowwise_data.shape[0], tensor._rowwise_data.shape[1] + + actual_scale_shape = tensor._rowwise_scale_inv.shape + expected_scale_shape = (rows, math.ceil(cols / block_size)) + + if actual_scale_shape != expected_scale_shape: + tensor._rowwise_scale_inv = tensor._rowwise_scale_inv[:expected_scale_shape[0], :expected_scale_shape[1]] + + if tensor._columnwise_scale_inv is not None: + if transpose: + rows, cols = tensor._columnwise_data.shape[1], tensor._columnwise_data.shape[0] + else: + rows, cols = tensor._columnwise_data.shape[0], tensor._columnwise_data.shape[1] + + actual_scale_shape = tensor._columnwise_scale_inv.shape + expected_scale_shape = (math.ceil(rows / block_size), cols) + + if actual_scale_shape != expected_scale_shape: + tensor._columnwise_scale_inv = 
tensor._columnwise_scale_inv[:expected_scale_shape[0], :expected_scale_shape[1]] + + return tensor + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -61,6 +103,12 @@ def general_gemm( alpha = validate_gemm_scale(alpha, True) beta = validate_gemm_scale(beta, accumulate) + if IS_HIP_EXTENSION: + if isinstance(A, (MXFP8TensorBase, Float8BlockwiseQTensorBase)): + A = unpad_scales(A, transa) + if isinstance(B, (MXFP8TensorBase, Float8BlockwiseQTensorBase)): + A = unpad_scales(B, transb) + if ub_type is not None: assert ub is not None, ( f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" From 9a5980fa1efc867d8a1b6e4ec39d8715d9966068 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Fri, 6 Feb 2026 17:40:22 -0600 Subject: [PATCH 05/10] rocm_gemm.cu revert --- transformer_engine/common/gemm/rocm_gemm.cu | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 91f2b2733..c2c4c502a 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ @@ -197,10 +197,9 @@ struct GemmParam { int ldb = 0; // B column strides }; - GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, - const int m, const int n, const int k, hipStream_t stream) { + const int m, const int n, const int k) { using namespace transformer_engine; NVTE_CHECK(A.scaling_mode == B.scaling_mode, "Inputs A and B to GEMM need to have the same scaling mode!"); @@ -236,9 +235,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). - - unpad_mxfp8_checkpoint(A, is_A_transposed, m, n, k, stream); - if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { @@ -277,7 +273,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). 
- unpad_mxfp8_checkpoint(B, is_B_transposed, m, n, k, stream); if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { @@ -956,7 +951,7 @@ void hipblaslt_gemm(const Tensor *inputA, } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k, stream); + const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); bool nvte_log_gemm_config = false; if (const char* env_p = std::getenv("NVTE_LOG_GEMM_CONFIG") ) { From ff7214232c0d142549d2333422245c9418e226d0 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 10 Feb 2026 10:11:48 -0600 Subject: [PATCH 06/10] minor fixes --- tests/cpp/test_common.cu | 6 +++--- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 9d023cdce..3ddd9047d 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -147,10 +147,10 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; -#ifndef __HIP_PLATFORM_AMD__ - auto block_alignment = std::vector{128ul, 4ul}; -#else +#ifdef __HIP_PLATFORM_AMD__ auto block_alignment = std::vector{1ul, 1ul}; +#else + auto block_alignment = std::vector{128ul, 4ul}; #endif { auto alignment = block_alignment[0]; diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index b8167fad5..ec29e6e12 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -189,7 +189,6 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk"); NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width, "Attempted to get out-of-bounds 
tensor chunk"); -#ifndef __HIP_PLATFORM_AMD__ if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { // MXFP8 scale-inverses are padded to a 2D matrix with dims that // are divisible by 128. UB doesn't handle this padding yet. @@ -198,7 +197,6 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0, "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128"); } -#endif #undef NVTE_DIM_CHECK // Construct tensor chunk From b16611da133da075e7b24a2bf178d278f32d1cbb Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 10 Feb 2026 15:35:27 -0600 Subject: [PATCH 07/10] Moved unpadding check to load_from_state_dict --- tests/cpp/operator/test_cublaslt_gemm.cu | 3 +- tests/pytorch/test_sanity.py | 72 ------------------- .../pytorch/cpp_extensions/gemm.py | 44 ------------ transformer_engine/pytorch/fp8.py | 33 ++++++++- transformer_engine/pytorch/module/base.py | 6 ++ transformer_engine/pytorch/ops/op.py | 10 +++ 6 files changed, 48 insertions(+), 120 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 2fa882bf4..cd0e124e8 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -29,8 +29,7 @@ std::vector> test_case_sizes = { }; std::vector> test_case_sizes_mxfp8 = { - {16, 128, 16}, - {32, 128, 32}, + {32, 128, 16}, {768, 3072, 4096}, }; diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 3d28083a2..0c6b65329 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -913,78 +913,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): ) torch.cuda.synchronize() -@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) -@pytest.mark.parametrize("N", [32]) -@pytest.mark.parametrize("K", [128]) -@pytest.mark.parametrize("M", [32]) -@pytest.mark.parametrize("datatype", [torch.float16, 
torch.bfloat16]) -def test_sanity_mxfp8_gemm_with_padding(N, K, M, datatype): - """Test the unpadding functionality in rocm""" - dtype = tex.DType.kFloat8E4M3 - quantizer = MXFP8Quantizer(dtype) - - input_dtype = torch.randn(M, K, device="cuda", dtype=datatype) - weight_dtype = torch.randn(N, K, device="cuda", dtype=datatype) - - input_data = quantizer.make_empty((M, K), device="cuda") - weight_data = quantizer.make_empty((N, K), device="cuda") - - quantizer.update_quantized(input_dtype, input_data) - quantizer.update_quantized(weight_dtype, weight_data) - - out_ref = general_gemm( - weight_data, - input_data, - get_workspace(), - datatype, - bias=None, - use_split_accumulator=False, - ) - torch.cuda.synchronize() - - row_scale_inv = input_data._rowwise_scale_inv - rows, cols = row_scale_inv.shape - row_padded_scale_inv = torch.zeros((128, 4), dtype=row_scale_inv.dtype, device="cuda") - row_padded_scale_inv[:rows, :cols] = row_scale_inv - - col_scale_inv = input_data._columnwise_scale_inv - rows, cols = col_scale_inv.shape - col_padded_scale_inv = torch.zeros((4, 128), dtype=col_scale_inv.dtype, device="cuda") - col_padded_scale_inv[:rows, :cols] = col_scale_inv - - - input_padded = MXFP8Tensor( - shape=input_data.shape, - rowwise_data=input_data._rowwise_data.clone(), - rowwise_scale_inv=row_padded_scale_inv, - columnwise_data=input_data._columnwise_data.clone(), - columnwise_scale_inv=col_padded_scale_inv, - fp8_dtype=tex.DType.kFloat8E4M3, - quantizer=quantizer, - dtype=datatype - ) - - out_pass1 = general_gemm( - weight_data, - input_padded, - get_workspace(), - datatype, - bias=None, - use_split_accumulator=False - ) - torch.cuda.synchronize() - - assert row_scale_inv.shape == input_padded._rowwise_scale_inv.shape, \ - ("Shape mismatch in rowwise scales") - assert col_scale_inv.shape == input_padded._columnwise_scale_inv.shape, \ - ("Shape mismatch in colwise scales") - torch.testing.assert_close(row_scale_inv, input_padded._rowwise_scale_inv, - rtol=1e-7, 
atol=1e-7, msg="rowwise scale mismatch") - torch.testing.assert_close(col_scale_inv, input_padded._columnwise_scale_inv, - rtol=1e-7, atol=1e-7, msg="colwise scale mismatch") - torch.testing.assert_close(out_pass1[0], out_ref[0], - rtol=1e-2, atol=1e-2, msg="GEMM output mismatch") - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_replace_raw_data_for_float8tensor(): diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index a64ef7106..58e60f19f 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -17,7 +15,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION from ..tensor.quantized_tensor import Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -36,42 +33,6 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 -def unpad_scales(tensor: torch.Tensor, transpose: bool) -> torch.Tensor: - """Removes padding from scales in MXFP8 Tensors if present""" - if isinstance(tensor, MXFP8TensorBase): - block_size = 32 - elif isinstance(tensor, Float8BlockwiseQTensorBase): - block_size = 128 - else: - raise ValueError("Only MXFP8 and FP8 Block scaling can be unpadded") - - if tensor._rowwise_scale_inv is not None: - if transpose: - rows, cols = tensor._rowwise_data.shape[1], tensor._rowwise_data.shape[0] - else: - rows, cols = tensor._rowwise_data.shape[0], tensor._rowwise_data.shape[1] - - actual_scale_shape = tensor._rowwise_scale_inv.shape - 
expected_scale_shape = (rows, math.ceil(cols / block_size)) - - if actual_scale_shape != expected_scale_shape: - tensor._rowwise_scale_inv = tensor._rowwise_scale_inv[:expected_scale_shape[0], :expected_scale_shape[1]] - - if tensor._columnwise_scale_inv is not None: - if transpose: - rows, cols = tensor._columnwise_data.shape[1], tensor._columnwise_data.shape[0] - else: - rows, cols = tensor._columnwise_data.shape[0], tensor._columnwise_data.shape[1] - - actual_scale_shape = tensor._columnwise_scale_inv.shape - expected_scale_shape = (math.ceil(rows / block_size), cols) - - if actual_scale_shape != expected_scale_shape: - tensor._columnwise_scale_inv = tensor._columnwise_scale_inv[:expected_scale_shape[0], :expected_scale_shape[1]] - - return tensor - - def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -103,11 +64,6 @@ def general_gemm( alpha = validate_gemm_scale(alpha, True) beta = validate_gemm_scale(beta, accumulate) - if IS_HIP_EXTENSION: - if isinstance(A, (MXFP8TensorBase, Float8BlockwiseQTensorBase)): - A = unpad_scales(A, transa) - if isinstance(B, (MXFP8TensorBase, Float8BlockwiseQTensorBase)): - A = unpad_scales(B, transb) if ub_type is not None: assert ub is not None, ( diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 15cb88b00..ed6cfe4c4 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -128,7 +128,7 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType return tex.DType.kFloat8E5M2 -def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> texs.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -137,6 +137,35 @@ def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: return Format.E5M2.value.max_fwd +def unpad_scales(tensor: torch.Tensor, transpose: bool, block_size: int) -> torch.Tensor: + """Removes padding from scales in Tensors if present""" + if tensor._rowwise_scale_inv is not None: + if transpose: + rows, cols = tensor._rowwise_data.shape[1], tensor._rowwise_data.shape[0] + else: + rows, cols = tensor._rowwise_data.shape[0], tensor._rowwise_data.shape[1] + + actual_scale_shape = tensor._rowwise_scale_inv.shape + expected_scale_shape = (rows, math.ceil(cols / block_size)) + + if actual_scale_shape != expected_scale_shape: + tensor._rowwise_scale_inv = tensor._rowwise_scale_inv[:expected_scale_shape[0], :expected_scale_shape[1]] + + if tensor._columnwise_scale_inv is not None: + if transpose: + rows, cols = tensor._columnwise_data.shape[1], tensor._columnwise_data.shape[0] + else: + rows, cols = tensor._columnwise_data.shape[0], tensor._columnwise_data.shape[1] + + actual_scale_shape = tensor._columnwise_scale_inv.shape + expected_scale_shape = (math.ceil(rows / block_size), cols) + + if actual_scale_shape != expected_scale_shape: + tensor._columnwise_scale_inv = tensor._columnwise_scale_inv[:expected_scale_shape[0], :expected_scale_shape[1]] + + return tensor + + class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 265109d6e..fadcd213a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -32,6 +32,7 @@ Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, + unpad_scales, ) from ..distributed import ( gather_along_first_dim, @@ -1496,6 +1497,11 @@ def _load_from_state_dict( super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) + for name, param in self.named_parameters(recurse=False): + if isinstance(param, MXFP8TensorBase): + unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=32) + elif isinstance(param, Float8BlockwiseQTensorBase): + unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=128) def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook): """ diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 903bc49d5..a56561fe4 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -18,7 +18,12 @@ FP8GlobalStateManager, RecipeState, fp8_autocast, + unpad_scales, ) + +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase + from ..tensor import Quantizer @@ -665,6 +670,11 @@ def _load_from_state_dict(self, *args, **kwargs) -> None: if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) super()._load_from_state_dict(*args, **kwargs) + for name, param in self.named_parameters(recurse=False): + if isinstance(param, MXFP8TensorBase): + unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=32) + elif isinstance(param, Float8BlockwiseQTensorBase): + unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=128) class 
FusedOperation(FusibleOperation): From 1cfd235c1c9862d4e1bd4604f4430e50505b6b96 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 10 Feb 2026 15:48:27 -0600 Subject: [PATCH 08/10] gemm.py revert --- transformer_engine/pytorch/cpp_extensions/gemm.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 58e60f19f..e4f4e619f 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -5,15 +5,12 @@ """Python interface for GEMM extensions""" from typing import Iterable, Optional, Tuple, Union, List -import math import os import torch import transformer_engine_torch as tex from ..constants import TE_DType from ..utils import get_sm_count, _empty_tensor -from torch.utils.cpp_extension import IS_HIP_EXTENSION - from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -64,7 +61,6 @@ def general_gemm( alpha = validate_gemm_scale(alpha, True) beta = validate_gemm_scale(beta, accumulate) - if ub_type is not None: assert ub is not None, ( f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" From 04a84095310fee0a3d911928bf9cb0889c1fefac Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 10 Feb 2026 16:15:22 -0600 Subject: [PATCH 09/10] HIP guards --- transformer_engine/pytorch/fp8.py | 2 +- transformer_engine/pytorch/module/base.py | 11 ++++++----- transformer_engine/pytorch/ops/op.py | 15 ++++++++++----- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index ed6cfe4c4..d340b9c47 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -128,7 +128,7 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) 
-> tex.DType return tex.DType.kFloat8E5M2 -def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> texs.DType: +def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index fadcd213a..d86e6df87 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1497,11 +1497,12 @@ def _load_from_state_dict( super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) - for name, param in self.named_parameters(recurse=False): - if isinstance(param, MXFP8TensorBase): - unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=32) - elif isinstance(param, Float8BlockwiseQTensorBase): - unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=128) + if IS_HIP_EXTENSION: + for name, param in self.named_parameters(recurse=False): + if isinstance(param, MXFP8TensorBase): + unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=32) + elif isinstance(param, Float8BlockwiseQTensorBase): + unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=128) def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook): """ diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index a56561fe4..c417b8f56 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -13,6 +15,8 @@ import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION + from transformer_engine.common.recipe import Recipe from ..fp8 import ( FP8GlobalStateManager, @@ -670,11 +674,12 @@ def _load_from_state_dict(self, *args, **kwargs) -> None: if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) super()._load_from_state_dict(*args, **kwargs) - for name, param in self.named_parameters(recurse=False): - if isinstance(param, MXFP8TensorBase): - unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=32) - elif isinstance(param, Float8BlockwiseQTensorBase): - unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=128) + if IS_HIP_EXTENSION: + for name, param in self.named_parameters(recurse=False): + if isinstance(param, MXFP8TensorBase): + unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=32) + elif isinstance(param, Float8BlockwiseQTensorBase): + unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=128) class FusedOperation(FusibleOperation): From 8ff5c15f0325045055a8bfde740642030bc5c4ef Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Wed, 11 Feb 2026 11:27:41 -0600 Subject: [PATCH 10/10] Revert checkpointing logic --- transformer_engine/pytorch/fp8.py | 31 +---------------------- transformer_engine/pytorch/module/base.py | 7 ----- transformer_engine/pytorch/ops/op.py | 15 ----------- 3 files changed, 1 insertion(+), 52 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index d340b9c47..15cb88b00 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # See LICENSE for license information. @@ -137,35 +137,6 @@ def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: return Format.E5M2.value.max_fwd -def unpad_scales(tensor: torch.Tensor, transpose: bool, block_size: int) -> torch.Tensor: - """Removes padding from scales in Tensors if present""" - if tensor._rowwise_scale_inv is not None: - if transpose: - rows, cols = tensor._rowwise_data.shape[1], tensor._rowwise_data.shape[0] - else: - rows, cols = tensor._rowwise_data.shape[0], tensor._rowwise_data.shape[1] - - actual_scale_shape = tensor._rowwise_scale_inv.shape - expected_scale_shape = (rows, math.ceil(cols / block_size)) - - if actual_scale_shape != expected_scale_shape: - tensor._rowwise_scale_inv = tensor._rowwise_scale_inv[:expected_scale_shape[0], :expected_scale_shape[1]] - - if tensor._columnwise_scale_inv is not None: - if transpose: - rows, cols = tensor._columnwise_data.shape[1], tensor._columnwise_data.shape[0] - else: - rows, cols = tensor._columnwise_data.shape[0], tensor._columnwise_data.shape[1] - - actual_scale_shape = tensor._columnwise_scale_inv.shape - expected_scale_shape = (math.ceil(rows / block_size), cols) - - if actual_scale_shape != expected_scale_shape: - tensor._columnwise_scale_inv = tensor._columnwise_scale_inv[:expected_scale_shape[0], :expected_scale_shape[1]] - - return tensor - - class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. 
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d86e6df87..265109d6e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -32,7 +32,6 @@ Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, - unpad_scales, ) from ..distributed import ( gather_along_first_dim, @@ -1497,12 +1496,6 @@ def _load_from_state_dict( super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) - if IS_HIP_EXTENSION: - for name, param in self.named_parameters(recurse=False): - if isinstance(param, MXFP8TensorBase): - unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=32) - elif isinstance(param, Float8BlockwiseQTensorBase): - unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=128) def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook): """ diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index c417b8f56..903bc49d5 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -15,19 +13,12 @@ import torch -from torch.utils.cpp_extension import IS_HIP_EXTENSION - from transformer_engine.common.recipe import Recipe from ..fp8 import ( FP8GlobalStateManager, RecipeState, fp8_autocast, - unpad_scales, ) - -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase - from ..tensor import Quantizer @@ -674,12 +665,6 @@ def _load_from_state_dict(self, *args, **kwargs) -> None: if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) super()._load_from_state_dict(*args, **kwargs) - if IS_HIP_EXTENSION: - for name, param in self.named_parameters(recurse=False): - if isinstance(param, MXFP8TensorBase): - unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=32) - elif isinstance(param, Float8BlockwiseQTensorBase): - unpad_scales(param, transpose=getattr(self, "layout", "N")=="T", block_size=128) class FusedOperation(FusibleOperation):