From 70c6e4033624db14a6bc0d978aae89fa9a60133e Mon Sep 17 00:00:00 2001
From: sugovind
Date: Tue, 27 Jan 2026 04:13:19 +0000
Subject: [PATCH 1/4] Fix quantizer assertion to support
 Float8CurrentScalingQuantizer in base.py

---
 transformer_engine/pytorch/module/base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 1d88355ca..b5bc947d0 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -44,7 +44,7 @@
 from ..triton_kernels.cast import te_quantize_triton
 from ..utils import non_tn_fp8_gemm_supported
 
-from ..tensor.float8_tensor import Float8Quantizer
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 
 __all__ = ["initialize_ub", "destroy_ub"]
 
@@ -1046,7 +1046,7 @@ def get_weight_workspace(
         else:
             current_quantizer = quantizer
 
-        assert isinstance(current_quantizer, Float8Quantizer), "`create_tranpose_buffer=False` only availabe in `Float8Quantizer`."
+        assert isinstance(current_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)), f"`create_tranpose_buffer=False` is only available with `Float8Quantizer` or `Float8CurrentScalingQuantizer`, not `{current_quantizer.__class__.__name__}`."
 
         # NOTE: Not create transpose buffer internally.
         current_quantizer.columnwise_usage = False
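Note on PATCH 1/4: the relaxed guard in get_weight_workspace() now accepts either FP8 quantizer flavor before it disables the column-wise (transpose) buffer. The sketch below is illustrative only and not part of the patch; it assumes an FP8-capable GPU, and the constructor signatures shown (Float8Quantizer(scale, amax, fp8_dtype) and Float8CurrentScalingQuantizer(fp8_dtype, device)) are assumptions based on recent Transformer Engine releases.

    # Illustrative sketch, not part of the patch. Constructor signatures below are
    # assumptions based on recent Transformer Engine releases.
    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import (
        Float8CurrentScalingQuantizer,
        Float8Quantizer,
    )

    # Delayed-scaling quantizer: the caller supplies scale and amax buffers.
    delayed = Float8Quantizer(
        scale=torch.ones(1, device="cuda"),
        amax=torch.zeros(1, device="cuda"),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
    # Current-scaling quantizer: the scale is derived from the tensor being quantized.
    current = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")

    for quantizer in (delayed, current):
        # Same check as the relaxed assertion: both quantizer flavors pass.
        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
        # Mirrors what get_weight_workspace() does when no transpose buffer is wanted.
        quantizer.columnwise_usage = False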
From 1814c78be3a4bd7edbbce717a56ac35d0d7e7f53 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com>
Date: Tue, 15 Apr 2025 09:00:46 +0200
Subject: [PATCH 2/4] [PyTorch] Fix for checkpointing for callables. (#1679)

* fix

Signed-off-by: Pawel Gadzinski

* added test

Signed-off-by: Pawel Gadzinski

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test change

Signed-off-by: Pawel Gadzinski

* changed the test

Signed-off-by: Pawel Gadzinski

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pawel Gadzinski
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 tests/pytorch/test_sanity.py              | 29 +++++++++++++++++++++++
 transformer_engine/pytorch/distributed.py | 11 +++++----
 2 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 403164205..ece123490 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -46,6 +46,7 @@
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
+from transformer_engine.pytorch.distributed import checkpoint
 from test_numerics import reset_rng_states, dtype_tols
 
 # Only run FP8 tests on supported devices.
@@ -1265,3 +1266,31 @@ def test_fp8_model_init_high_precision_init_val():
     assert not hasattr(
         weight, "._high_precision_init_val"
     ), "clear_high_precision_init_val() not work"
+
+
+def test_sanity_checkpointing_on_callables():
+    """Test that TE checkpointing works correctly on callable modules."""
+
+    # torch.autograd.Function
+    class MyFunction(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, inp):
+            return inp
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            return grad_output
+
+    module = MyFunction.apply
+    inp = torch.randn(10, 10, device="cuda", requires_grad=True)
+
+    out_checkpoint = checkpoint(module, inp)
+    out_checkpoint.sum().backward()
+    grad_checkpoint = inp.grad
+
+    out_standard = module(inp)
+    out_standard.sum().backward()
+    grad_standard = inp.grad
+
+    # Assert that gradients are the same
+    torch.testing.assert_close(grad_checkpoint, grad_standard)
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 52ef6cbff..4304f8d3e 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -662,10 +662,13 @@ def checkpoint(
             **kwargs,
         )
 
-    # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
-    # to scatter/gather activations that we will recompute anyway.
-    setattr(function, "fsdp_wrapped", False)
-    setattr(function, "fsdp_group", None)
+    from .module.base import TransformerEngineBaseModule
+
+    if isinstance(function, TransformerEngineBaseModule):
+        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
+        # to scatter/gather activations that we will recompute anyway.
+        setattr(function, "fsdp_wrapped", False)
+        setattr(function, "fsdp_group", None)
 
     # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
     # and execute TE's own checkpointing

From c49fe342963a4d0efbce1244af99ffe716a53fdf Mon Sep 17 00:00:00 2001
From: sugovind
Date: Wed, 11 Feb 2026 19:33:09 +0000
Subject: [PATCH 3/4] Update copyright year in base.py to reflect 2026

---
 transformer_engine/pytorch/module/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index b5bc947d0..669869595 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1,5 +1,5 @@
 # This file was modified for portability to AMDGPU
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

From 30aebedb39a2ceb5b9bd2c98488677a01cb52365 Mon Sep 17 00:00:00 2001
From: sugovind
Date: Thu, 12 Feb 2026 16:44:14 +0000
Subject: [PATCH 4/4] Update VERSION.txt to version 2.2.0.post0

---
 build_tools/VERSION.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt
index ccbccc3dc..a8d39b550 100644
--- a/build_tools/VERSION.txt
+++ b/build_tools/VERSION.txt
@@ -1 +1 @@
-2.2.0
+2.2.0.post0
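Note on PATCH 2/4: with the isinstance guard in place, transformer_engine.pytorch.distributed.checkpoint() only clears FSDP attributes on TransformerEngineBaseModule instances and no longer calls setattr on plain callables such as torch.autograd.Function.apply, which previously could fail. A minimal usage sketch follows; it is illustrative only, assumes a CUDA device, and the Square class, tensor shapes, and variable names are made up for the example (te.Linear and the checkpoint import match the patch's test).

    # Illustrative sketch, not part of the series. Assumes a CUDA device.
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.distributed import checkpoint


    class Square(torch.autograd.Function):
        """A plain callable (via .apply) that is not a TE module."""

        @staticmethod
        def forward(ctx, inp):
            ctx.save_for_backward(inp)
            return inp * inp

        @staticmethod
        def backward(ctx, grad_output):
            (inp,) = ctx.saved_tensors
            return 2 * inp * grad_output


    linear = te.Linear(16, 16)  # TE module: the FSDP attributes get cleared
    x = torch.randn(4, 16, device="cuda", requires_grad=True)

    y = checkpoint(linear, x)        # isinstance branch taken
    z = checkpoint(Square.apply, y)  # plain callable: branch skipped, no setattr
    z.sum().backward()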