From 001f77c363710e3f62e05c5aacbed4b2ff7c8c97 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 06:30:00 +0000 Subject: [PATCH 1/9] Initial plan From b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 06:36:18 +0000 Subject: [PATCH 2/9] Revert "fix: update 1 file reformatted." This reverts commit ff886701c392ab03863c227de14fbe1d671d4173. Co-authored-by: nathon-lee <248585198+nathon-lee@users.noreply.github.com> --- deepspeed/runtime/zero/stage_1_and_2.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 107e47a44042..183fd077f8a9 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -283,11 +283,18 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 + # Check for Muon optimizer usage + self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) + if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + + # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) + if self.reduce_scatter and self.uses_muon: + assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1187,7 +1194,9 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - if not self.reduce_scatter: + # Check if current configuration requires full all-reduce + if not self.reduce_scatter or any(self.group_uses_muon): + # Force full all-reduce for Muon parameters or when reduce_scatter is disabled self.gradient_reduction_w_predivide(tensor, communication_data_type) return From cbc816c90f4bd6e10ab5b67f4d471002ade8cba7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:40:53 +0000 Subject: [PATCH 3/9] Initial plan From 5fcc9a7e4bf58b1d935dcfeab53143d3cf9dbdf7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:43:32 +0000 Subject: [PATCH 4/9] Reapply "fix: update 1 file reformatted." This reverts commit b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35. --- deepspeed/runtime/zero/stage_1_and_2.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 1efea00bcbbd..12f97348a21f 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -284,18 +284,11 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 - # Check for Muon optimizer usage - self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) - if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" - - # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) - if self.reduce_scatter and self.uses_muon: - assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1224,9 +1217,7 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - # Check if current configuration requires full all-reduce - if not self.reduce_scatter or any(self.group_uses_muon): - # Force full all-reduce for Muon parameters or when reduce_scatter is disabled + if not self.reduce_scatter: self.gradient_reduction_w_predivide(tensor, communication_data_type) return From b41bb4cc1b2c364ccf32a0d6477cca4dcbbc1a5a Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Wed, 3 Jun 2026 13:30:24 +0800 Subject: [PATCH 5/9] zero3: fix retained-graph second backward (#7352) Signed-off-by: nathon-lee fix: add ZeRO-3 second backward after retain_graph=True fails with tensor size mismatch Signed-off-by: nathon-lee fix: Stage 3 Temporarily change the exemption from xfail to skip (for this test case only) Signed-off-by: nathon-lee fix: Fix ZeRO-3 behavior for two separate backward passes on the same forward graph. Signed-off-by: nathon-lee --- deepspeed/runtime/base_optimizer.py | 11 +++- deepspeed/runtime/engine.py | 32 +++++----- deepspeed/runtime/zero/parameter_offload.py | 8 ++- deepspeed/runtime/zero/stage3.py | 1 + tests/unit/v1/zero/test_zero_user_backward.py | 60 +++++++++++++++++++ 5 files changed, 94 insertions(+), 18 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index c9dbfd0a4e81..2aa101477577 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -237,6 +237,7 @@ class ZeROOptimizer(DeepSpeedOptimizer): def __init__(self): self._backward_hook_state = BackwardHookStateManager() + self.retain_graph_on_current_backward = False # Delegate backward hook state management to the manager. # These properties provide backward compatibility with code that accesses @@ -399,10 +400,14 @@ def backward(self, loss, **kwargs): scaled_loss = self.backward_prologue(loss) retain_graph = kwargs.pop('retain_graph', False) + self.retain_graph_on_current_backward = retain_graph self.enter_backward() - scaled_loss.backward(retain_graph=retain_graph) - self.backward_epilogue() - self.exit_backward() + try: + scaled_loss.backward(retain_graph=retain_graph) + self.backward_epilogue() + finally: + self.exit_backward() + self.retain_graph_on_current_backward = False def register_grad_acc_post_hook(self, hook): """Register a callback to run when all gradient hooks have fired.""" diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 82a7592f14cb..6c324703e93e 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2745,23 +2745,27 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True): # TODO: handle these scaling with direct calls to loss.backward() if isinstance(self.optimizer, ZeROOptimizer): loss = self.optimizer.scale_if_loss(loss) + self.optimizer.retain_graph_on_current_backward = retain_graph elif self.torch_autocast_z0_gradscaler: loss = self.torch_autocast_z0_gradscaler.scale(loss) - with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs): - if self.zero_optimization() or not self.amp_enabled(): - loss.backward(**backward_kwargs) - elif self.amp_enabled(): - # AMP requires delaying unscale when inside gradient accumulation boundaries - # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations - delay_unscale = not self.is_gradient_accumulation_boundary() - with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: - scaled_loss.backward(**backward_kwargs) - - # backward_epilogue is not called in a hook when self._support_torch_style_backward is False - self._backward_epilogue() - - self._running_engine_backward = False + try: + with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs): + if self.zero_optimization() or not self.amp_enabled(): + loss.backward(**backward_kwargs) + elif self.amp_enabled(): + # AMP requires delaying unscale when inside gradient accumulation boundaries + # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations + delay_unscale = not self.is_gradient_accumulation_boundary() + with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: + scaled_loss.backward(**backward_kwargs) + + # backward_epilogue is not called in a hook when self._support_torch_style_backward is False + self._backward_epilogue() + finally: + self._running_engine_backward = False + if isinstance(self.optimizer, ZeROOptimizer): + self.optimizer.retain_graph_on_current_backward = False return gas_scaled_loss diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index aba0cde6266d..722d43420a6b 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -552,7 +552,13 @@ def post_sub_module_backward_function(self, sub_module): for param in params_to_fetch: param.data = param.data.t() if len(param.ds_shape) != 1 else param.data - self.get_param_coordinator().release_sub_module(sub_module, forward=False) + # Keep gathered params alive when the current backward retains the graph, + # so a second backward over the same forward can reuse valid saved tensors. + zero_optimizer = getattr(self, "zero_optimizer", None) + retain_graph_backward = bool(zero_optimizer is not None + and getattr(zero_optimizer, "retain_graph_on_current_backward", False)) + if not retain_graph_backward: + self.get_param_coordinator().release_sub_module(sub_module, forward=False) see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release", diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 7ebb42905456..fc25c904f181 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -274,6 +274,7 @@ def __init__( zero_module_granularity_threshold=zero_module_granularity_threshold, log_trace_cache_warnings=log_trace_cache_warnings, ) + self.parameter_offload.zero_optimizer = self self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config) diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index b094c22c2fb5..d166efca06b6 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -539,6 +539,66 @@ def test_separate_loss_function(self, zero_stage): model_engine.destroy() + def test_two_losses_separate_backward_gas1(self, zero_stage): + """Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7352 + + A single forward followed by two separate backward passes, with + zero_grad() in between, must produce independent gradients for each + loss when gradient_accumulation_steps == 1. Previously the first + backward was treated as the accumulation boundary, which froze loss1's + gradients so that zero_grad() had no effect and loss2's gradients were + doubled (grad2 == grad1 + grad2). Each DeepSpeed gradient set is + compared against an equivalent PyTorch DDP baseline. + """ + hidden_dim = 4 + batch_size = 2 + + # Default gradient_accumulation_steps=1, which is the failing case. + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines( + model_class=SimpleOutputModel, zero_stage=zero_stage, hidden_dim=hidden_dim) + + loss_fn = torch.nn.CrossEntropyLoss() + + # Two different targets so loss1 and loss2 yield distinct gradients; + # this makes accidental accumulation (grad1 + grad2) detectable. + torch.manual_seed(456) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + y1 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + y2 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP baseline: separate backward for each loss with zero_grad in between. + output_ddp = model_ddp(x) + loss1_ddp = loss_fn(output_ddp, y1) + loss2_ddp = loss_fn(output_ddp, y2) + + optimizer_ddp.zero_grad() + loss1_ddp.backward(retain_graph=True) + grads1_ddp = collect_ddp_gradients(model_ddp) + + optimizer_ddp.zero_grad() + loss2_ddp.backward() + grads2_ddp = collect_ddp_gradients(model_ddp) + + # DeepSpeed: identical sequence. + output_ds = model_engine(x) + loss1_ds = loss_fn(output_ds, y1) + loss2_ds = loss_fn(output_ds, y2) + + model_engine.zero_grad() + model_engine.backward(loss1_ds, retain_graph=True) + grads1_ds = collect_gradients_safe(model_engine) + + model_engine.zero_grad() + model_engine.backward(loss2_ds) + grads2_ds = collect_gradients_safe(model_engine) + + # The second backward must NOT accumulate loss1's gradients on top of + # loss2; both gradient sets must match their DDP counterparts. + compare_gradients(grads1_ddp, grads1_ds, "loss1") + compare_gradients(grads2_ddp, grads2_ds, "loss2") + + model_engine.destroy() + class LeafModuleModel(torch.nn.Module): """Model with ModuleList that uses all parameters - for testing leaf module compatibility""" From 5c75f999c4832a9f4e0be9f6fbd185e9b937eb6e Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Wed, 3 Jun 2026 15:28:08 +0800 Subject: [PATCH 6/9] zero3: cover manual scale().backward() for retain_graph release deferral Signed-off-by: nathon-lee --- deepspeed/runtime/engine.py | 16 ++++- tests/unit/v1/zero/test_zero_user_backward.py | 60 +++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6c324703e93e..4abb54794d16 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2504,6 +2504,11 @@ def _backward_epilogue(self): if not bf16_optimizer: self.optimizer.backward_epilogue() self.optimizer.exit_backward() + # Clear the retained-backward flag here so it is reset for both the + # engine.backward() path and the torch-style manual path + # (engine.scale(loss).backward(retain_graph=True)), which both reach + # this epilogue after the gradient hooks have run. + self.optimizer.retain_graph_on_current_backward = False see_memory_usage("Engine after backward", force=self.memory_breakdown()) self._stop_timers(self.engine_timers.backward_reduce_timers) @@ -2662,7 +2667,7 @@ def _flush_coalesced_reduction_zero3(self, optimizer): optimizer.reduce_ready_partitions_and_remove_grads(param) optimizer.independent_gradient_partition_epilogue() - def scale(self, loss): + def scale(self, loss, retain_graph=False): r"""Apply loss scaler for manual backward pass. Use this method when calling loss.backward() directly instead of engine.backward(). @@ -2680,6 +2685,12 @@ def scale(self, loss): Arguments: loss: Scalar loss tensor to be scaled + retain_graph: bool, default: false + Set to true when the upcoming manual backward keeps the graph alive + (``scaled_loss.backward(retain_graph=True)``) so a second backward can + run over the same forward. For ZeRO-3 this defers parameter release so + the retained graph's saved tensors stay valid, matching the behavior of + ``engine.backward(loss, retain_graph=True)``. Returns: Scaled loss tensor ready for .backward() call @@ -2706,6 +2717,9 @@ def scale(self, loss): scaled_loss = loss if isinstance(self.optimizer, ZeROOptimizer): scaled_loss = self.optimizer.scale_if_loss(loss) + # The manual path bypasses engine.backward(), so propagate retain_graph here. + # Cleared in _backward_epilogue() once the gradient hooks have run. + self.optimizer.retain_graph_on_current_backward = retain_graph elif self.torch_autocast_z0_gradscaler: scaled_loss = self.torch_autocast_z0_gradscaler.scale(loss) diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index d166efca06b6..333c74a146fe 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -599,6 +599,66 @@ def test_two_losses_separate_backward_gas1(self, zero_stage): model_engine.destroy() + def test_two_losses_separate_manual_backward_gas1(self, zero_stage): + """Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7352 + + Same scenario as test_two_losses_separate_backward_gas1, but using the + torch-style manual path engine.scale(loss, retain_graph=True).backward( + retain_graph=True) instead of engine.backward(loss, retain_graph=True). + The manual path bypasses engine.backward(), so retain_graph must be + propagated through scale(). For ZeRO-3 this defers parameter release so + the retained graph's saved tensors stay valid for the second backward + over the same forward. + """ + hidden_dim = 4 + batch_size = 2 + + # Default gradient_accumulation_steps=1, which is the failing case. + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines( + model_class=SimpleOutputModel, zero_stage=zero_stage, hidden_dim=hidden_dim) + + loss_fn = torch.nn.CrossEntropyLoss() + + # Two different targets so loss1 and loss2 yield distinct gradients; + # this makes accidental accumulation (grad1 + grad2) detectable. + torch.manual_seed(456) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + y1 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + y2 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP baseline: separate backward for each loss with zero_grad in between. + output_ddp = model_ddp(x) + loss1_ddp = loss_fn(output_ddp, y1) + loss2_ddp = loss_fn(output_ddp, y2) + + optimizer_ddp.zero_grad() + loss1_ddp.backward(retain_graph=True) + grads1_ddp = collect_ddp_gradients(model_ddp) + + optimizer_ddp.zero_grad() + loss2_ddp.backward() + grads2_ddp = collect_ddp_gradients(model_ddp) + + # DeepSpeed: identical sequence via the manual scale().backward() path. + output_ds = model_engine(x) + loss1_ds = loss_fn(output_ds, y1) + loss2_ds = loss_fn(output_ds, y2) + + model_engine.zero_grad() + model_engine.scale(loss1_ds, retain_graph=True).backward(retain_graph=True) + grads1_ds = collect_gradients_safe(model_engine) + + model_engine.zero_grad() + model_engine.scale(loss2_ds).backward() + grads2_ds = collect_gradients_safe(model_engine) + + # The second backward must NOT accumulate loss1's gradients on top of + # loss2; both gradient sets must match their DDP counterparts. + compare_gradients(grads1_ddp, grads1_ds, "loss1") + compare_gradients(grads2_ddp, grads2_ds, "loss2") + + model_engine.destroy() + class LeafModuleModel(torch.nn.Module): """Model with ModuleList that uses all parameters - for testing leaf module compatibility""" From 7c5f2693c45936ec1f561ff1fe68818244401863 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Sun, 28 Jun 2026 22:07:07 +0800 Subject: [PATCH 7/9] refactor(zero3): replace offload optimizer back-reference with retain_graph checker callback Signed-off-by: nathon-lee --- deepspeed/runtime/zero/parameter_offload.py | 6 +++--- deepspeed/runtime/zero/stage3.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 722d43420a6b..b500b564778a 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -131,6 +131,7 @@ def __init__( zero_quantized_nontrainable_weights=False, zero_module_granularity_threshold=0, log_trace_cache_warnings=False, + retain_graph_checker=None, ): see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=False) @@ -148,6 +149,7 @@ def __init__( self.zero_quantized_weights = zero_quantized_weights self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights self.log_trace_cache_warnings = log_trace_cache_warnings + self.retain_graph_checker = retain_graph_checker if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none: self.offload_device = offload_param_config.device @@ -554,9 +556,7 @@ def post_sub_module_backward_function(self, sub_module): # Keep gathered params alive when the current backward retains the graph, # so a second backward over the same forward can reuse valid saved tensors. - zero_optimizer = getattr(self, "zero_optimizer", None) - retain_graph_backward = bool(zero_optimizer is not None - and getattr(zero_optimizer, "retain_graph_on_current_backward", False)) + retain_graph_backward = bool(self.retain_graph_checker()) if self.retain_graph_checker is not None else False if not retain_graph_backward: self.get_param_coordinator().release_sub_module(sub_module, forward=False) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index fc25c904f181..fde721d87ecd 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -273,8 +273,8 @@ def __init__( zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, zero_module_granularity_threshold=zero_module_granularity_threshold, log_trace_cache_warnings=log_trace_cache_warnings, + retain_graph_checker=lambda: self.retain_graph_on_current_backward, ) - self.parameter_offload.zero_optimizer = self self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config) @@ -558,6 +558,7 @@ def initialize_ds_offload( zero_quantized_nontrainable_weights, zero_module_granularity_threshold, log_trace_cache_warnings, + retain_graph_checker=None, ): return DeepSpeedZeRoOffload(module=module, timers=timers, @@ -576,7 +577,8 @@ def initialize_ds_offload( zero_quantized_weights=zero_quantized_weights, zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, zero_module_granularity_threshold=zero_module_granularity_threshold, - log_trace_cache_warnings=log_trace_cache_warnings) + log_trace_cache_warnings=log_trace_cache_warnings, + retain_graph_checker=retain_graph_checker) def _get_trainable_parameter_groups(self): param_groups = [] From 3cfd15304d109f289176ebb9a441347d9f1217dd Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Sun, 28 Jun 2026 14:27:06 +0000 Subject: [PATCH 8/9] fix: fix some format errs. Signed-off-by: nathon-lee --- tests/unit/v1/zero/test_zero_user_backward.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index 333c74a146fe..55f2bcb946ad 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -554,8 +554,9 @@ def test_two_losses_separate_backward_gas1(self, zero_stage): batch_size = 2 # Default gradient_accumulation_steps=1, which is the failing case. - model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines( - model_class=SimpleOutputModel, zero_stage=zero_stage, hidden_dim=hidden_dim) + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleOutputModel, + zero_stage=zero_stage, + hidden_dim=hidden_dim) loss_fn = torch.nn.CrossEntropyLoss() @@ -614,8 +615,9 @@ def test_two_losses_separate_manual_backward_gas1(self, zero_stage): batch_size = 2 # Default gradient_accumulation_steps=1, which is the failing case. - model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines( - model_class=SimpleOutputModel, zero_stage=zero_stage, hidden_dim=hidden_dim) + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleOutputModel, + zero_stage=zero_stage, + hidden_dim=hidden_dim) loss_fn = torch.nn.CrossEntropyLoss() From 563acb4de33d1840a5431b5a0f4753ba47c96246 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Wed, 1 Jul 2026 16:45:16 +0800 Subject: [PATCH 9/9] fix(engine): resolve ZeRO manual backward scale merge and preserve retain_graph flag - resolve merge conflict in engine.scale for ZeRO and autocast grad scaler paths - keep retain_graph propagation on manual backward when using ZeRO optimizer - use scaled_loss consistently in subsequent scaling calls Signed-off-by: nathon-lee fix(engine): resolve ZeRO manual backward scale merge and preserve retain_graph flag Signed-off-by: nathon-lee --- deepspeed/runtime/engine.py | 45 ++++++++++++++----------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 4abb54794d16..863ce30736da 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2504,10 +2504,6 @@ def _backward_epilogue(self): if not bf16_optimizer: self.optimizer.backward_epilogue() self.optimizer.exit_backward() - # Clear the retained-backward flag here so it is reset for both the - # engine.backward() path and the torch-style manual path - # (engine.scale(loss).backward(retain_graph=True)), which both reach - # this epilogue after the gradient hooks have run. self.optimizer.retain_graph_on_current_backward = False see_memory_usage("Engine after backward", force=self.memory_breakdown()) @@ -2686,11 +2682,7 @@ def scale(self, loss, retain_graph=False): Arguments: loss: Scalar loss tensor to be scaled retain_graph: bool, default: false - Set to true when the upcoming manual backward keeps the graph alive - (``scaled_loss.backward(retain_graph=True)``) so a second backward can - run over the same forward. For ZeRO-3 this defers parameter release so - the retained graph's saved tensors stay valid, matching the behavior of - ``engine.backward(loss, retain_graph=True)``. + forward on user defined choice of retain_graph Returns: Scaled loss tensor ready for .backward() call @@ -2717,8 +2709,6 @@ def scale(self, loss, retain_graph=False): scaled_loss = loss if isinstance(self.optimizer, ZeROOptimizer): scaled_loss = self.optimizer.scale_if_loss(loss) - # The manual path bypasses engine.backward(), so propagate retain_graph here. - # Cleared in _backward_epilogue() once the gradient hooks have run. self.optimizer.retain_graph_on_current_backward = retain_graph elif self.torch_autocast_z0_gradscaler: scaled_loss = self.torch_autocast_z0_gradscaler.scale(loss) @@ -2763,23 +2753,22 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True): elif self.torch_autocast_z0_gradscaler: loss = self.torch_autocast_z0_gradscaler.scale(loss) - try: - with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs): - if self.zero_optimization() or not self.amp_enabled(): - loss.backward(**backward_kwargs) - elif self.amp_enabled(): - # AMP requires delaying unscale when inside gradient accumulation boundaries - # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations - delay_unscale = not self.is_gradient_accumulation_boundary() - with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: - scaled_loss.backward(**backward_kwargs) - - # backward_epilogue is not called in a hook when self._support_torch_style_backward is False - self._backward_epilogue() - finally: - self._running_engine_backward = False - if isinstance(self.optimizer, ZeROOptimizer): - self.optimizer.retain_graph_on_current_backward = False + with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs): + if self.zero_optimization() or not self.amp_enabled(): + loss.backward(**backward_kwargs) + elif self.amp_enabled(): + # AMP requires delaying unscale when inside gradient accumulation boundaries + # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations + delay_unscale = not self.is_gradient_accumulation_boundary() + with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: + scaled_loss.backward(**backward_kwargs) + + # backward_epilogue is not called in a hook when self._support_torch_style_backward is False + self._backward_epilogue() + + self._running_engine_backward = False + if isinstance(self.optimizer, ZeROOptimizer): + self.optimizer.retain_graph_on_current_backward = False return gas_scaled_loss