Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0
2.2.0.post0
29 changes: 29 additions & 0 deletions tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols

# Only run FP8 tests on supported devices.
Expand Down Expand Up @@ -1265,3 +1266,31 @@ def test_fp8_model_init_high_precision_init_val():
assert not hasattr(
weight, "._high_precision_init_val"
), "clear_high_precision_init_val() not work"


def test_sanity_checkpointing_on_callables():
"""Test that TE checkpointing works correctly on callable modules."""

# torch.autograf.function
class MyFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, inp):
return inp

@staticmethod
def backward(ctx, grad_output):
return grad_output

module = MyFunction.apply
inp = torch.randn(10, 10, device="cuda", requires_grad=True)

out_checkpoint = checkpoint(module, inp)
out_checkpoint.sum().backward()
grad_checkpoint = inp.grad

out_standard = module(inp)
out_standard.sum().backward()
grad_standard = inp.grad

# Assert that gradients are the same
torch.testing.assert_close(grad_checkpoint, grad_standard)
11 changes: 7 additions & 4 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,10 +662,13 @@ def checkpoint(
**kwargs,
)

# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
setattr(function, "fsdp_wrapped", False)
setattr(function, "fsdp_group", None)
from .module.base import TransformerEngineBaseModule

if isinstance(function, TransformerEngineBaseModule):
# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
setattr(function, "fsdp_wrapped", False)
setattr(function, "fsdp_group", None)

# Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
# and execute TE's own checkpointing
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -44,7 +44,7 @@
from ..triton_kernels.cast import te_quantize_triton

from ..utils import non_tn_fp8_gemm_supported
from ..tensor.float8_tensor import Float8Quantizer
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer

__all__ = ["initialize_ub", "destroy_ub"]

Expand Down Expand Up @@ -1046,7 +1046,7 @@ def get_weight_workspace(
else:
current_quantizer = quantizer

assert isinstance(current_quantizer, Float8Quantizer), "`create_tranpose_buffer=False` only availabe in `Float8Quantizer`."
assert isinstance(current_quantizer, Float8Quantizer) or isinstance(current_quantizer, Float8CurrentScalingQuantizer), f"`create_tranpose_buffer=False` only availabe in `Float8Quantizer`. Not available in {current_quantizer.__class__.__name__}."

# NOTE: Not create transpose buffer internally.
current_quantizer.columnwise_usage = False
Expand Down