Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
001f77c
Initial plan
Copilot Feb 27, 2026
b90aee5
Revert "fix: update 1 file reformatted."
Copilot Feb 27, 2026
b6da9af
Merge pull request #5 from nathon-lee/copilot/git-revert-ff886701
nathon-lee Feb 27, 2026
bb7f64f
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 6, 2026
cbc816c
Initial plan
Copilot Mar 6, 2026
5fcc9a7
Reapply "fix: update 1 file reformatted."
Copilot Mar 6, 2026
f7c5d75
Merge pull request #6 from nathon-lee/copilot/remove-commits-from-master
nathon-lee Mar 6, 2026
18efbcc
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 25, 2026
e2ac74d
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 27, 2026
da07382
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 30, 2026
5d8875c
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 31, 2026
316b6dd
Merge branch 'deepspeedai:master' into master
nathon-lee Apr 1, 2026
2020543
Merge branch 'deepspeedai:master' into master
nathon-lee Apr 2, 2026
1a8694c
Merge branch 'deepspeedai:master' into master
nathon-lee Apr 16, 2026
d6725be
Merge branch 'deepspeedai:master' into master
nathon-lee Apr 23, 2026
a06c548
Merge branch 'deepspeedai:master' into master
nathon-lee May 1, 2026
6959eb4
Merge branch 'deepspeedai:master' into master
nathon-lee May 5, 2026
e88eb3e
Merge branch 'deepspeedai:master' into master
nathon-lee May 7, 2026
683bd0b
Merge branch 'deepspeedai:master' into master
nathon-lee Jun 3, 2026
b41bb4c
zero3: fix retained-graph second backward (#7352)
nathon-lee Jun 3, 2026
5c75f99
zero3: cover manual scale().backward() for retain_graph release deferral
nathon-lee Jun 3, 2026
7c5f269
refactor(zero3): replace offload optimizer back-reference with retain…
nathon-lee Jun 28, 2026
3cfd153
fix: fix some format errs.
nathon-lee Jun 28, 2026
f53eaf0
Merge pull request #21 from nathon-lee/test/zero-multi-loss-separate-…
nathon-lee Jun 30, 2026
563acb4
fix(engine): resolve ZeRO manual backward scale merge and preserve re…
nathon-lee Jul 1, 2026
4849fd9
Merge branch 'master' into test/zero-multi-loss-separate-backward-7352
nathon-lee Jul 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions deepspeed/runtime/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ class ZeROOptimizer(DeepSpeedOptimizer):

def __init__(self):
self._backward_hook_state = BackwardHookStateManager()
self.retain_graph_on_current_backward = False

# Delegate backward hook state management to the manager.
# These properties provide backward compatibility with code that accesses
Expand Down Expand Up @@ -419,10 +420,14 @@ def backward(self, loss, **kwargs):

scaled_loss = self.backward_prologue(loss)
retain_graph = kwargs.pop('retain_graph', False)
self.retain_graph_on_current_backward = retain_graph
self.enter_backward()
scaled_loss.backward(retain_graph=retain_graph)
self.backward_epilogue()
self.exit_backward()
try:
scaled_loss.backward(retain_graph=retain_graph)
self.backward_epilogue()
finally:
self.exit_backward()
self.retain_graph_on_current_backward = False

def register_grad_acc_post_hook(self, hook):
"""Register a callback to run when all gradient hooks have fired."""
Expand Down
10 changes: 9 additions & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2754,6 +2754,7 @@ def _backward_epilogue(self):
if not bf16_optimizer:
self.optimizer.backward_epilogue()
self.optimizer.exit_backward()
self.optimizer.retain_graph_on_current_backward = False

if self.is_deepcompile_active():
deepcompile_backward_epilogue()
Expand Down Expand Up @@ -2915,7 +2916,7 @@ def _flush_coalesced_reduction_zero3(self, optimizer):
optimizer.reduce_ready_partitions_and_remove_grads(param)
optimizer.independent_gradient_partition_epilogue()

def scale(self, loss):
def scale(self, loss, retain_graph=False):
r"""Apply loss scaler for manual backward pass.

Use this method when calling loss.backward() directly instead of engine.backward().
Expand All @@ -2933,6 +2934,8 @@ def scale(self, loss):

Arguments:
loss: Scalar loss tensor to be scaled
retain_graph: bool, default: false
forward on user defined choice of retain_graph

Returns:
Scaled loss tensor ready for .backward() call
Expand All @@ -2958,6 +2961,8 @@ def scale(self, loss):
# Apply loss scaler based on optimizer type
scaled_loss = loss
if isinstance(self.optimizer, ZeROOptimizer):
scaled_loss = self.optimizer.scale_if_loss(loss)
self.optimizer.retain_graph_on_current_backward = retain_graph
scaled_loss = self.optimizer.scale_if_loss(scaled_loss)
elif self.torch_autocast_z0_gradscaler:
scaled_loss = self.torch_autocast_z0_gradscaler.scale(scaled_loss)
Expand Down Expand Up @@ -2998,6 +3003,7 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
# TODO: handle these scaling with direct calls to loss.backward()
if isinstance(self.optimizer, ZeROOptimizer):
loss = self.optimizer.scale_if_loss(loss)
self.optimizer.retain_graph_on_current_backward = retain_graph
elif self.torch_autocast_z0_gradscaler:
loss = self.torch_autocast_z0_gradscaler.scale(loss)

Expand All @@ -3015,6 +3021,8 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
self._backward_epilogue()

self._running_engine_backward = False
if isinstance(self.optimizer, ZeROOptimizer):
self.optimizer.retain_graph_on_current_backward = False

return gas_scaled_loss

Expand Down
8 changes: 7 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
zero_quantized_nontrainable_weights=False,
zero_module_granularity_threshold=0,
log_trace_cache_warnings=False,
retain_graph_checker=None,
):

see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=False)
Expand All @@ -153,6 +154,7 @@ def __init__(
self.zero_quantized_weights = zero_quantized_weights
self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
self.log_trace_cache_warnings = log_trace_cache_warnings
self.retain_graph_checker = retain_graph_checker

if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
self.offload_device = offload_param_config.device
Expand Down Expand Up @@ -563,7 +565,11 @@ def post_sub_module_backward_function(self, sub_module):
for param in params_to_fetch:
param.data = param.data.t() if len(param.ds_shape) != 1 else param.data

self.get_param_coordinator().release_sub_module(sub_module, forward=False)
# Keep gathered params alive when the current backward retains the graph,
# so a second backward over the same forward can reuse valid saved tensors.
retain_graph_backward = bool(self.retain_graph_checker()) if self.retain_graph_checker is not None else False
if not retain_graph_backward:
self.get_param_coordinator().release_sub_module(sub_module, forward=False)

see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
Expand Down
5 changes: 4 additions & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def __init__(
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
zero_module_granularity_threshold=zero_module_granularity_threshold,
log_trace_cache_warnings=log_trace_cache_warnings,
retain_graph_checker=lambda: self.retain_graph_on_current_backward,
)

self.persistent_parameters = self.parameter_offload.persistent_parameters
Expand Down Expand Up @@ -571,6 +572,7 @@ def initialize_ds_offload(
zero_quantized_nontrainable_weights,
zero_module_granularity_threshold,
log_trace_cache_warnings,
retain_graph_checker=None,
):
return DeepSpeedZeRoOffload(module=module,
timers=timers,
Expand All @@ -589,7 +591,8 @@ def initialize_ds_offload(
zero_quantized_weights=zero_quantized_weights,
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
zero_module_granularity_threshold=zero_module_granularity_threshold,
log_trace_cache_warnings=log_trace_cache_warnings)
log_trace_cache_warnings=log_trace_cache_warnings,
retain_graph_checker=retain_graph_checker)

def _get_param_partition_group(self, param):
return getattr(param, "ds_process_group", self.dp_process_group)
Expand Down
122 changes: 122 additions & 0 deletions tests/unit/v1/zero/test_zero_user_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,128 @@ def test_separate_loss_function(self, zero_stage):

model_engine.destroy()

def test_two_losses_separate_backward_gas1(self, zero_stage):
"""Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7352

A single forward followed by two separate backward passes, with
zero_grad() in between, must produce independent gradients for each
loss when gradient_accumulation_steps == 1. Previously the first
backward was treated as the accumulation boundary, which froze loss1's
gradients so that zero_grad() had no effect and loss2's gradients were
doubled (grad2 == grad1 + grad2). Each DeepSpeed gradient set is
compared against an equivalent PyTorch DDP baseline.
"""
hidden_dim = 4
batch_size = 2

# Default gradient_accumulation_steps=1, which is the failing case.
model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleOutputModel,
zero_stage=zero_stage,
hidden_dim=hidden_dim)

loss_fn = torch.nn.CrossEntropyLoss()

# Two different targets so loss1 and loss2 yield distinct gradients;
# this makes accidental accumulation (grad1 + grad2) detectable.
torch.manual_seed(456)
x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype)
y1 = torch.randint(0, hidden_dim, (batch_size, ), device=device)
y2 = torch.randint(0, hidden_dim, (batch_size, ), device=device)

# DDP baseline: separate backward for each loss with zero_grad in between.
output_ddp = model_ddp(x)
loss1_ddp = loss_fn(output_ddp, y1)
loss2_ddp = loss_fn(output_ddp, y2)

optimizer_ddp.zero_grad()
loss1_ddp.backward(retain_graph=True)
grads1_ddp = collect_ddp_gradients(model_ddp)

optimizer_ddp.zero_grad()
loss2_ddp.backward()
grads2_ddp = collect_ddp_gradients(model_ddp)

# DeepSpeed: identical sequence.
output_ds = model_engine(x)
loss1_ds = loss_fn(output_ds, y1)
loss2_ds = loss_fn(output_ds, y2)

model_engine.zero_grad()
model_engine.backward(loss1_ds, retain_graph=True)
grads1_ds = collect_gradients_safe(model_engine)

model_engine.zero_grad()
model_engine.backward(loss2_ds)
grads2_ds = collect_gradients_safe(model_engine)

# The second backward must NOT accumulate loss1's gradients on top of
# loss2; both gradient sets must match their DDP counterparts.
compare_gradients(grads1_ddp, grads1_ds, "loss1")
compare_gradients(grads2_ddp, grads2_ds, "loss2")

model_engine.destroy()

def test_two_losses_separate_manual_backward_gas1(self, zero_stage):
"""Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7352

Same scenario as test_two_losses_separate_backward_gas1, but using the
torch-style manual path engine.scale(loss, retain_graph=True).backward(
retain_graph=True) instead of engine.backward(loss, retain_graph=True).
The manual path bypasses engine.backward(), so retain_graph must be
propagated through scale(). For ZeRO-3 this defers parameter release so
the retained graph's saved tensors stay valid for the second backward
over the same forward.
"""
hidden_dim = 4
batch_size = 2

# Default gradient_accumulation_steps=1, which is the failing case.
model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleOutputModel,
zero_stage=zero_stage,
hidden_dim=hidden_dim)

loss_fn = torch.nn.CrossEntropyLoss()

# Two different targets so loss1 and loss2 yield distinct gradients;
# this makes accidental accumulation (grad1 + grad2) detectable.
torch.manual_seed(456)
x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype)
y1 = torch.randint(0, hidden_dim, (batch_size, ), device=device)
y2 = torch.randint(0, hidden_dim, (batch_size, ), device=device)

# DDP baseline: separate backward for each loss with zero_grad in between.
output_ddp = model_ddp(x)
loss1_ddp = loss_fn(output_ddp, y1)
loss2_ddp = loss_fn(output_ddp, y2)

optimizer_ddp.zero_grad()
loss1_ddp.backward(retain_graph=True)
grads1_ddp = collect_ddp_gradients(model_ddp)

optimizer_ddp.zero_grad()
loss2_ddp.backward()
grads2_ddp = collect_ddp_gradients(model_ddp)

# DeepSpeed: identical sequence via the manual scale().backward() path.
output_ds = model_engine(x)
loss1_ds = loss_fn(output_ds, y1)
loss2_ds = loss_fn(output_ds, y2)

model_engine.zero_grad()
model_engine.scale(loss1_ds, retain_graph=True).backward(retain_graph=True)
grads1_ds = collect_gradients_safe(model_engine)

model_engine.zero_grad()
model_engine.scale(loss2_ds).backward()
grads2_ds = collect_gradients_safe(model_engine)

# The second backward must NOT accumulate loss1's gradients on top of
# loss2; both gradient sets must match their DDP counterparts.
compare_gradients(grads1_ddp, grads1_ds, "loss1")
compare_gradients(grads2_ddp, grads2_ds, "loss2")

model_engine.destroy()


class LeafModuleModel(torch.nn.Module):
"""Model with ModuleList that uses all parameters - for testing leaf module compatibility"""
Expand Down
Loading