diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt
index ccbccc3dc..a8d39b550 100644
--- a/build_tools/VERSION.txt
+++ b/build_tools/VERSION.txt
@@ -1 +1 @@
-2.2.0
+2.2.0.post0
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 403164205..ece123490 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -46,6 +46,7 @@
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
+from transformer_engine.pytorch.distributed import checkpoint
 from test_numerics import reset_rng_states, dtype_tols
 
 # Only run FP8 tests on supported devices.
@@ -1265,3 +1266,32 @@ def test_fp8_model_init_high_precision_init_val():
     assert not hasattr(
         weight, "._high_precision_init_val"
     ), "clear_high_precision_init_val() not work"
+
+
+def test_sanity_checkpointing_on_callables():
+    """Test that TE checkpointing works on plain callables, not just nn.Modules."""
+
+    # Plain callable: torch.autograd.Function.apply
+    class MyFunction(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, inp):
+            return inp
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            return grad_output
+
+    module = MyFunction.apply
+    inp = torch.randn(10, 10, device="cuda", requires_grad=True)
+
+    out_checkpoint = checkpoint(module, inp)
+    out_checkpoint.sum().backward()
+    grad_checkpoint = inp.grad
+    inp.grad = None  # reset so the second backward does not accumulate into it
+
+    out_standard = module(inp)
+    out_standard.sum().backward()
+    grad_standard = inp.grad
+
+    # Gradients from the checkpointed and standard runs must match
+    torch.testing.assert_close(grad_checkpoint, grad_standard)
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 52ef6cbff..4304f8d3e 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -662,10 +662,13 @@ def checkpoint(
             **kwargs,
         )
 
-    # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
-    # to scatter/gather activations that we will recompute anyway.
-    setattr(function, "fsdp_wrapped", False)
-    setattr(function, "fsdp_group", None)
+    from .module.base import TransformerEngineBaseModule  # local import to avoid a circular import
+
+    if isinstance(function, TransformerEngineBaseModule):
+        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's
+        # no need to scatter/gather activations that we will recompute anyway.
+        setattr(function, "fsdp_wrapped", False)
+        setattr(function, "fsdp_group", None)
 
     # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
     # and execute TE's own checkpointing
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 1d88355ca..669869595 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1,5 +1,5 @@
 # This file was modified for portability to AMDGPU
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -44,7 +44,7 @@
 from ..triton_kernels.cast import te_quantize_triton
 from ..utils import non_tn_fp8_gemm_supported
 
-from ..tensor.float8_tensor import Float8Quantizer
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 
 __all__ = ["initialize_ub", "destroy_ub"]
 
@@ -1046,7 +1046,7 @@ def get_weight_workspace(
         else:
             current_quantizer = quantizer
 
-        assert isinstance(current_quantizer, Float8Quantizer), "`create_tranpose_buffer=False` only availabe in `Float8Quantizer`."
+        assert isinstance(current_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)), f"`create_tranpose_buffer=False` is only available with `Float8Quantizer` or `Float8CurrentScalingQuantizer`, not {current_quantizer.__class__.__name__}."
 
         # NOTE: Not create transpose buffer internally.
         current_quantizer.columnwise_usage = False